mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #572 from willmiao/codex/design-ui-for-model-update-notifications
refactor: tighten civitai update endpoints
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Dict, Mapping
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Mapping
|
||||
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
@@ -42,8 +42,12 @@ from .handlers.model_handlers import (
|
||||
ModelMoveHandler,
|
||||
ModelPageView,
|
||||
ModelQueryHandler,
|
||||
ModelUpdateHandler,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.model_update_service import ModelUpdateService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -99,10 +103,18 @@ class BaseModelRoutes(ABC):
|
||||
ws_manager=self._ws_manager,
|
||||
download_manager_factory=ServiceRegistry.get_download_manager,
|
||||
)
|
||||
self._model_update_service: ModelUpdateService | None = None
|
||||
|
||||
if service is not None:
|
||||
self.attach_service(service)
|
||||
|
||||
def set_model_update_service(self, service: "ModelUpdateService") -> None:
    """Attach the model update tracking service.

    Cached handler structures are invalidated so the next access
    rebuilds them against the newly attached service.
    """

    # Drop memoized handlers first; they captured the previous service.
    self._handler_set = None
    self._handler_mapping = None
    self._model_update_service = service
|
||||
|
||||
def attach_service(self, service) -> None:
|
||||
"""Attach a model service and rebuild handler dependencies."""
|
||||
self.service = service
|
||||
@@ -127,6 +139,7 @@ class BaseModelRoutes(ABC):
|
||||
|
||||
def _create_handler_set(self) -> ModelHandlerSet:
|
||||
service = self._ensure_service()
|
||||
update_service = self._ensure_model_update_service()
|
||||
page_view = ModelPageView(
|
||||
template_env=self.template_env,
|
||||
template_name=self.template_name or "",
|
||||
@@ -186,6 +199,12 @@ class BaseModelRoutes(ABC):
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
updates = ModelUpdateHandler(
|
||||
service=service,
|
||||
update_service=update_service,
|
||||
metadata_provider_selector=get_metadata_provider,
|
||||
logger=logger,
|
||||
)
|
||||
return ModelHandlerSet(
|
||||
page_view=page_view,
|
||||
listing=listing,
|
||||
@@ -195,6 +214,7 @@ class BaseModelRoutes(ABC):
|
||||
civitai=civitai,
|
||||
move=move,
|
||||
auto_organize=auto_organize,
|
||||
updates=updates,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -273,3 +293,8 @@ class BaseModelRoutes(ABC):
|
||||
|
||||
return proxy
|
||||
|
||||
def _ensure_model_update_service(self) -> "ModelUpdateService":
    """Return the attached update service, failing fast when absent."""

    service = self._model_update_service
    if service is None:
        raise RuntimeError("Model update service has not been attached")
    return service
|
||||
|
||||
|
||||
@@ -21,7 +21,9 @@ class CheckpointRoutes(BaseModelRoutes):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.service = CheckpointService(checkpoint_scanner)
|
||||
|
||||
update_service = await ServiceRegistry.get_model_update_service()
|
||||
self.set_model_update_service(update_service)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
|
||||
@@ -20,7 +20,9 @@ class EmbeddingRoutes(BaseModelRoutes):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
self.service = EmbeddingService(embedding_scanner)
|
||||
|
||||
update_service = await ServiceRegistry.get_model_update_service()
|
||||
self.set_model_update_service(update_service)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from ...services.use_cases import (
|
||||
)
|
||||
from ...services.websocket_manager import WebSocketManager
|
||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ...services.errors import RateLimitError
|
||||
from ...utils.file_utils import calculate_sha256
|
||||
from ...utils.metadata_manager import MetadataManager
|
||||
|
||||
@@ -1017,6 +1018,156 @@ class ModelAutoOrganizeHandler:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class ModelUpdateHandler:
    """Handle update tracking requests (refresh, ignore toggle, status)."""

    def __init__(
        self,
        *,
        service,
        update_service,
        metadata_provider_selector,
        logger: logging.Logger,
    ) -> None:
        # service: per-model-type service exposing .model_type and .scanner
        # update_service: ModelUpdateService persisting version records
        # metadata_provider_selector: async callable returning a provider by name
        self._service = service
        self._update_service = update_service
        self._metadata_provider_selector = metadata_provider_selector
        self._logger = logger

    async def refresh_model_updates(self, request: web.Request) -> web.Response:
        """Refresh update records for every tracked model of this type.

        Accepts ``force`` either as a query parameter or in the JSON body.
        Returns 503 when the civitai provider is unavailable, 429 on rate
        limiting, 500 on unexpected failures.
        """
        payload = await self._read_json(request)
        force_refresh = self._parse_bool(request.query.get("force")) or self._parse_bool(
            payload.get("force")
        )
        provider = await self._get_civitai_provider()
        if provider is None:
            return web.json_response(
                {"success": False, "error": "Civitai provider not available"}, status=503
            )

        try:
            records = await self._update_service.refresh_for_model_type(
                self._service.model_type,
                self._service.scanner,
                provider,
                force_refresh=force_refresh,
            )
        except RateLimitError as exc:
            return web.json_response(
                {"success": False, "error": str(exc) or "Rate limited"}, status=429
            )
        except Exception as exc:  # pragma: no cover - defensive logging
            self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

        return web.json_response(
            {
                "success": True,
                "records": [self._serialize_record(record) for record in records.values()],
            }
        )

    async def set_model_update_ignore(self, request: web.Request) -> web.Response:
        """Toggle the ignore flag for a model; body needs ``modelId``."""
        payload = await self._read_json(request)
        model_id = self._normalize_model_id(payload.get("modelId"))
        if model_id is None:
            return web.json_response({"success": False, "error": "modelId is required"}, status=400)

        should_ignore = self._parse_bool(payload.get("shouldIgnore"))
        record = await self._update_service.set_should_ignore(
            self._service.model_type, model_id, should_ignore
        )
        return web.json_response({"success": True, "record": self._serialize_record(record)})

    async def get_model_update_status(self, request: web.Request) -> web.Response:
        """Return (optionally refreshing) the update record for one model.

        ``refresh``/``force`` query flags trigger a remote re-check.
        """
        model_id = self._normalize_model_id(request.match_info.get("model_id"))
        if model_id is None:
            return web.json_response({"success": False, "error": "model_id must be an integer"}, status=400)

        refresh = self._parse_bool(request.query.get("refresh"))
        force = self._parse_bool(request.query.get("force"))

        try:
            record = await self._get_or_refresh_record(model_id, refresh=refresh, force=force)
        except RateLimitError as exc:
            return web.json_response(
                {"success": False, "error": str(exc) or "Rate limited"}, status=429
            )
        # FIX: mirror refresh_model_updates — without this, unexpected
        # provider/DB errors escape the handler instead of honouring the
        # JSON error contract of this API.
        except Exception as exc:  # pragma: no cover - defensive logging
            self._logger.error(
                "Failed to fetch model update status: %s", exc, exc_info=True
            )
            return web.json_response({"success": False, "error": str(exc)}, status=500)

        if record is None:
            return web.json_response(
                {"success": False, "error": "Model not tracked"}, status=404
            )

        return web.json_response({"success": True, "record": self._serialize_record(record)})

    async def _get_or_refresh_record(
        self, model_id: int, *, refresh: bool, force: bool
    ) -> Optional[object]:
        """Return the cached record, refreshing remotely when requested/missing."""
        record = await self._update_service.get_record(self._service.model_type, model_id)
        if record and not refresh and not force:
            return record

        provider = await self._get_civitai_provider()
        if provider is None:
            # Best effort: fall back to whatever is cached.
            return record

        return await self._update_service.refresh_single_model(
            self._service.model_type,
            model_id,
            self._service.scanner,
            provider,
            force_refresh=force or refresh,
        )

    async def _get_civitai_provider(self):
        """Resolve the civitai metadata provider; None when unavailable."""
        try:
            return await self._metadata_provider_selector("civitai_api")
        except Exception as exc:  # pragma: no cover - defensive log
            self._logger.error("Failed to acquire civitai provider: %s", exc, exc_info=True)
            return None

    async def _read_json(self, request: web.Request) -> Dict:
        """Parse the request body as JSON, returning {} on any failure."""
        if not request.can_read_body:
            return {}
        try:
            return await request.json()
        except Exception:
            return {}

    @staticmethod
    def _parse_bool(value) -> bool:
        """Coerce query/body values ("1", "true", "yes", numbers) to bool."""
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            return value.lower() in {"1", "true", "yes"}
        if isinstance(value, (int, float)):
            return bool(value)
        return False

    @staticmethod
    def _normalize_model_id(value) -> Optional[int]:
        """Best-effort int conversion; None when value is missing/invalid."""
        try:
            if value is None:
                return None
            return int(value)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _serialize_record(record) -> Dict:
        """Map a ModelUpdateRecord to the camelCase JSON wire format."""
        return {
            "modelType": record.model_type,
            "modelId": record.model_id,
            "largestVersionId": record.largest_version_id,
            "versionIds": record.version_ids,
            "inLibraryVersionIds": record.in_library_version_ids,
            "lastCheckedAt": record.last_checked_at,
            "shouldIgnore": record.should_ignore,
            "hasUpdate": record.has_update(),
        }
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelHandlerSet:
|
||||
"""Aggregate concrete handlers into a flat mapping."""
|
||||
@@ -1029,6 +1180,7 @@ class ModelHandlerSet:
|
||||
civitai: ModelCivitaiHandler
|
||||
move: ModelMoveHandler
|
||||
auto_organize: ModelAutoOrganizeHandler
|
||||
updates: ModelUpdateHandler
|
||||
|
||||
def to_route_mapping(self) -> Dict[str, Callable[[web.Request], Awaitable[web.Response]]]:
|
||||
return {
|
||||
@@ -1073,5 +1225,8 @@ class ModelHandlerSet:
|
||||
"get_model_metadata": self.query.get_model_metadata,
|
||||
"get_model_description": self.query.get_model_description,
|
||||
"get_relative_paths": self.query.get_relative_paths,
|
||||
"refresh_model_updates": self.updates.refresh_model_updates,
|
||||
"set_model_update_ignore": self.updates.set_model_update_ignore,
|
||||
"get_model_update_status": self.updates.get_model_update_status,
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,9 @@ class LoraRoutes(BaseModelRoutes):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.service = LoraService(lora_scanner)
|
||||
|
||||
update_service = await ServiceRegistry.get_model_update_service()
|
||||
self.set_model_update_service(update_service)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
|
||||
@@ -55,6 +55,9 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"),
|
||||
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
||||
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
||||
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
||||
|
||||
411
py/services/model_update_service.py
Normal file
411
py/services/model_update_service.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""Service for tracking remote model version updates."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, List, Mapping, Optional, Sequence
|
||||
|
||||
from .errors import RateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ModelUpdateRecord:
    """Representation of a persisted update record."""

    model_type: str                      # e.g. "lora", "checkpoint"
    model_id: int                        # civitai model id
    largest_version_id: Optional[int]    # max remote version id seen
    version_ids: List[int]               # all remote version ids
    in_library_version_ids: List[int]    # version ids present locally
    last_checked_at: Optional[float]     # unix timestamp of last remote check
    should_ignore: bool                  # user opted out of update nagging

    def has_update(self) -> bool:
        """Return True when remote versions exceed the local library."""

        if self.should_ignore or not self.version_ids:
            return False
        local_versions = set(self.in_library_version_ids)
        return any(version_id not in local_versions for version_id in self.version_ids)


class ModelUpdateService:
    """Persist and query remote model version metadata."""

    _SCHEMA = """
    CREATE TABLE IF NOT EXISTS model_update_status (
        model_type TEXT NOT NULL,
        model_id INTEGER NOT NULL,
        largest_version_id INTEGER,
        version_ids TEXT,
        in_library_version_ids TEXT,
        last_checked_at REAL,
        should_ignore INTEGER DEFAULT 0,
        PRIMARY KEY (model_type, model_id)
    )
    """

    def __init__(self, db_path: str, *, ttl_seconds: int = 24 * 60 * 60) -> None:
        """Open (creating if needed) the SQLite store at *db_path*.

        ttl_seconds controls how long a remote check stays fresh before a
        refresh hits the provider again.
        """
        self._db_path = db_path
        self._ttl_seconds = ttl_seconds
        self._lock = asyncio.Lock()
        self._schema_initialized = False
        self._ensure_directory()
        self._initialize_schema()

    def _ensure_directory(self) -> None:
        """Create the parent directory of the DB file when missing."""
        directory = os.path.dirname(self._db_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection; callers are responsible for closing it."""
        conn = sqlite3.connect(self._db_path, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        return conn

    def _initialize_schema(self) -> None:
        """Create the backing table once; subsequent calls are no-ops."""
        if self._schema_initialized:
            return
        try:
            conn = self._connect()
            try:
                # FIX: sqlite3's ``with conn`` scopes a transaction only;
                # it never closes the connection, so close explicitly to
                # avoid leaking a handle per call.
                with conn:
                    conn.execute("PRAGMA journal_mode=WAL")
                    conn.execute("PRAGMA foreign_keys = ON")
                    conn.executescript(self._SCHEMA)
            finally:
                conn.close()
            self._schema_initialized = True
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.error("Failed to initialize update schema: %s", exc, exc_info=True)
            raise

    async def refresh_for_model_type(
        self,
        model_type: str,
        scanner,
        metadata_provider,
        *,
        force_refresh: bool = False,
    ) -> Dict[int, ModelUpdateRecord]:
        """Refresh update information for every model present in the cache."""

        local_versions = await self._collect_local_versions(scanner)
        results: Dict[int, ModelUpdateRecord] = {}
        for model_id, version_ids in local_versions.items():
            record = await self._refresh_single_model(
                model_type,
                model_id,
                version_ids,
                metadata_provider,
                force_refresh=force_refresh,
            )
            if record:
                results[model_id] = record
        return results

    async def refresh_single_model(
        self,
        model_type: str,
        model_id: int,
        scanner,
        metadata_provider,
        *,
        force_refresh: bool = False,
    ) -> Optional[ModelUpdateRecord]:
        """Refresh update information for a specific model id."""

        local_versions = await self._collect_local_versions(scanner)
        version_ids = local_versions.get(model_id, [])
        return await self._refresh_single_model(
            model_type,
            model_id,
            version_ids,
            metadata_provider,
            force_refresh=force_refresh,
        )

    async def update_in_library_versions(
        self,
        model_type: str,
        model_id: int,
        version_ids: Sequence[int],
    ) -> ModelUpdateRecord:
        """Persist a new set of in-library version identifiers."""

        normalized_versions = self._normalize_sequence(version_ids)
        async with self._lock:
            existing = self._get_record(model_type, model_id)
            record = ModelUpdateRecord(
                model_type=model_type,
                model_id=model_id,
                largest_version_id=existing.largest_version_id if existing else None,
                version_ids=list(existing.version_ids) if existing else [],
                in_library_version_ids=normalized_versions,
                last_checked_at=existing.last_checked_at if existing else None,
                should_ignore=existing.should_ignore if existing else False,
            )
            self._upsert_record(record)
            return record

    async def set_should_ignore(
        self, model_type: str, model_id: int, should_ignore: bool
    ) -> ModelUpdateRecord:
        """Toggle the ignore flag for a model, creating a stub record if new."""

        async with self._lock:
            existing = self._get_record(model_type, model_id)
            if existing:
                record = ModelUpdateRecord(
                    model_type=model_type,
                    model_id=model_id,
                    largest_version_id=existing.largest_version_id,
                    version_ids=list(existing.version_ids),
                    in_library_version_ids=list(existing.in_library_version_ids),
                    last_checked_at=existing.last_checked_at,
                    should_ignore=should_ignore,
                )
            else:
                record = ModelUpdateRecord(
                    model_type=model_type,
                    model_id=model_id,
                    largest_version_id=None,
                    version_ids=[],
                    in_library_version_ids=[],
                    last_checked_at=None,
                    should_ignore=should_ignore,
                )
            self._upsert_record(record)
            return record

    async def get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
        """Return a cached record without triggering remote fetches."""

        async with self._lock:
            return self._get_record(model_type, model_id)

    async def has_update(self, model_type: str, model_id: int) -> bool:
        """Determine if a model has updates pending."""

        record = await self.get_record(model_type, model_id)
        return record.has_update() if record else False

    def _persist_ignored(
        self, existing: ModelUpdateRecord, normalized_local: List[int]
    ) -> ModelUpdateRecord:
        """Re-persist an ignored record with refreshed local versions.

        Must be called while holding ``self._lock``.
        """
        record = ModelUpdateRecord(
            model_type=existing.model_type,
            model_id=existing.model_id,
            largest_version_id=existing.largest_version_id,
            version_ids=list(existing.version_ids),
            in_library_version_ids=normalized_local,
            last_checked_at=existing.last_checked_at,
            should_ignore=True,
        )
        self._upsert_record(record)
        return record

    async def _refresh_single_model(
        self,
        model_type: str,
        model_id: int,
        local_versions: Sequence[int],
        metadata_provider,
        *,
        force_refresh: bool = False,
    ) -> Optional[ModelUpdateRecord]:
        """Core refresh: fetch remote versions (respecting TTL/ignore) and persist."""
        normalized_local = self._normalize_sequence(local_versions)
        now = time.time()
        async with self._lock:
            existing = self._get_record(model_type, model_id)
            if existing and existing.should_ignore and not force_refresh:
                # User opted out; just keep the local version list current.
                return self._persist_ignored(existing, normalized_local)

            should_fetch = force_refresh or not existing or self._is_stale(existing, now)
        # release lock during network request
        fetched_versions: List[int] | None = None
        refresh_succeeded = False
        if metadata_provider and should_fetch:
            try:
                response = await metadata_provider.get_model_versions(model_id)
            except RateLimitError:
                raise
            except Exception as exc:  # pragma: no cover - defensive log
                logger.error(
                    "Failed to fetch versions for model %s (%s): %s",
                    model_id,
                    model_type,
                    exc,
                    exc_info=True,
                )
            else:
                if response is not None:
                    extracted = self._extract_version_ids(response)
                    if extracted is not None:
                        fetched_versions = extracted
                        refresh_succeeded = True

        async with self._lock:
            existing = self._get_record(model_type, model_id)
            if existing and existing.should_ignore and not force_refresh:
                # Ignore state could have flipped while awaiting provider
                return self._persist_ignored(existing, normalized_local)

            version_ids = (
                fetched_versions
                if refresh_succeeded
                else (list(existing.version_ids) if existing else [])
            )
            largest = max(version_ids) if version_ids else None
            # Only advance the freshness timestamp when the fetch succeeded.
            last_checked = now if refresh_succeeded else (
                existing.last_checked_at if existing else None
            )
            record = ModelUpdateRecord(
                model_type=model_type,
                model_id=model_id,
                largest_version_id=largest,
                version_ids=version_ids,
                in_library_version_ids=normalized_local,
                last_checked_at=last_checked,
                should_ignore=existing.should_ignore if existing else False,
            )
            self._upsert_record(record)
            return record

    async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]:
        """Build {model_id: sorted local version ids} from the scanner cache."""
        cache = await scanner.get_cached_data()
        mapping: Dict[int, set[int]] = {}
        if not cache or not getattr(cache, "raw_data", None):
            return {}

        for item in cache.raw_data:
            civitai = item.get("civitai") if isinstance(item, dict) else None
            if not isinstance(civitai, dict):
                continue
            model_id = self._normalize_int(civitai.get("modelId"))
            version_id = self._normalize_int(civitai.get("id"))
            if model_id is None or version_id is None:
                continue
            mapping.setdefault(model_id, set()).add(version_id)

        return {model_id: sorted(ids) for model_id, ids in mapping.items()}

    def _is_stale(self, record: ModelUpdateRecord, now: float) -> bool:
        """True when the record was never checked or its TTL expired."""
        if record.last_checked_at is None:
            return True
        return (now - record.last_checked_at) >= self._ttl_seconds

    @staticmethod
    def _normalize_int(value) -> Optional[int]:
        """Best-effort int conversion; None on missing/invalid input."""
        try:
            if value is None:
                return None
            return int(value)
        except (TypeError, ValueError):
            return None

    def _normalize_sequence(self, values: Sequence[int]) -> List[int]:
        """Return sorted, de-duplicated ints, dropping unparseable entries."""
        normalized = [
            item
            for item in (self._normalize_int(value) for value in values)
            if item is not None
        ]
        return sorted(dict.fromkeys(normalized))

    def _extract_version_ids(self, response) -> Optional[List[int]]:
        """Pull version ids from a provider response; None when malformed."""
        if not isinstance(response, Mapping):
            return None
        versions = response.get("modelVersions")
        if versions is None:
            # Key absent is treated as "model has no versions".
            return []
        if not isinstance(versions, Iterable):
            return None
        normalized = []
        for entry in versions:
            if isinstance(entry, Mapping):
                normalized_id = self._normalize_int(entry.get("id"))
            else:
                normalized_id = self._normalize_int(entry)
            if normalized_id is not None:
                normalized.append(normalized_id)
        return sorted(dict.fromkeys(normalized))

    def _get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
        """Load one row and hydrate it into a ModelUpdateRecord (or None)."""
        conn = self._connect()
        try:
            row = conn.execute(
                """
                SELECT model_type, model_id, largest_version_id, version_ids,
                       in_library_version_ids, last_checked_at, should_ignore
                FROM model_update_status
                WHERE model_type = ? AND model_id = ?
                """,
                (model_type, model_id),
            ).fetchone()
        finally:
            # FIX: close the connection instead of leaking it.
            conn.close()
        if not row:
            return None
        return ModelUpdateRecord(
            model_type=row["model_type"],
            model_id=int(row["model_id"]),
            largest_version_id=self._normalize_int(row["largest_version_id"]),
            version_ids=self._deserialize_json_array(row["version_ids"]),
            in_library_version_ids=self._deserialize_json_array(
                row["in_library_version_ids"]
            ),
            last_checked_at=row["last_checked_at"],
            should_ignore=bool(row["should_ignore"]),
        )

    def _upsert_record(self, record: ModelUpdateRecord) -> None:
        """Insert or update one row; lists are stored as JSON text."""
        payload = (
            record.model_type,
            record.model_id,
            record.largest_version_id,
            json.dumps(record.version_ids),
            json.dumps(record.in_library_version_ids),
            record.last_checked_at,
            1 if record.should_ignore else 0,
        )
        conn = self._connect()
        try:
            # ``with conn`` commits on success / rolls back on error; the
            # explicit close below prevents a connection leak per upsert.
            with conn:
                conn.execute(
                    """
                    INSERT INTO model_update_status (
                        model_type, model_id, largest_version_id, version_ids,
                        in_library_version_ids, last_checked_at, should_ignore
                    ) VALUES (?, ?, ?, ?, ?, ?, ?)
                    ON CONFLICT(model_type, model_id) DO UPDATE SET
                        largest_version_id = excluded.largest_version_id,
                        version_ids = excluded.version_ids,
                        in_library_version_ids = excluded.in_library_version_ids,
                        last_checked_at = excluded.last_checked_at,
                        should_ignore = excluded.should_ignore
                    """,
                    payload,
                )
        finally:
            conn.close()

    @staticmethod
    def _deserialize_json_array(value) -> List[int]:
        """Parse a JSON array column into sorted unique ints; [] on bad data."""
        if not value:
            return []
        try:
            data = json.loads(value)
        except (TypeError, json.JSONDecodeError):
            return []
        if isinstance(data, list):
            normalized = []
            for entry in data:
                try:
                    normalized.append(int(entry))
                except (TypeError, ValueError):
                    continue
            return sorted(dict.fromkeys(normalized))
        return []
|
||||
|
||||
@@ -81,6 +81,11 @@ class PersistentModelCache:
|
||||
def is_enabled(self) -> bool:
    """Report whether the persistent cache is active.

    The cache is on by default and can be opted out of by setting
    LORA_MANAGER_DISABLE_PERSISTENT_CACHE=1 in the environment.
    """
    disable_flag = os.environ.get("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0")
    return disable_flag != "1"
|
||||
|
||||
def get_database_path(self) -> str:
    """Expose the resolved SQLite database path.

    Returns the path computed at construction time so other services
    (e.g. the model update tracker) can share the same database file.
    """

    return self._db_path
|
||||
|
||||
def load_cache(self, model_type: str) -> Optional[PersistedCacheData]:
|
||||
if not self.is_enabled():
|
||||
return None
|
||||
|
||||
@@ -128,23 +128,45 @@ class ServiceRegistry:
|
||||
async def get_civitai_client(cls):
|
||||
"""Get or create CivitAI client instance"""
|
||||
service_name = "civitai_client"
|
||||
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .civitai_client import CivitaiClient
|
||||
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
cls._services[service_name] = client
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return client
|
||||
|
||||
@classmethod
async def get_model_update_service(cls):
    """Get or create the model update tracking service."""

    name = "model_update_service"

    # Fast path: an already-registered singleton needs no locking.
    if name in cls._services:
        return cls._services[name]

    async with cls._get_lock(name):
        # Double-check inside the lock; a concurrent task may have won.
        if name in cls._services:
            return cls._services[name]

        # Local imports avoid circular dependencies at module load time.
        from .model_update_service import ModelUpdateService
        from .persistent_model_cache import get_persistent_cache

        service = ModelUpdateService(get_persistent_cache().get_database_path())
        cls._services[name] = service
        logger.debug(f"Created and registered {name}")
        return service
||||
|
||||
@classmethod
|
||||
async def get_civarchive_client(cls):
|
||||
"""Get or create CivArchive client instance"""
|
||||
|
||||
@@ -5,6 +5,7 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
import types
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: [])
|
||||
@@ -32,6 +33,41 @@ class DummyRoutes(BaseModelRoutes):
|
||||
def setup_specific_routes(self, registrar, prefix: str) -> None:  # pragma: no cover - no extra routes in smoke tests
    """No-op override: the smoke-test routes need no type-specific endpoints."""
    return None
|
||||
|
||||
def __init__(self, service=None):
    """Build test routes with a no-op update service pre-attached.

    Attaching NullModelUpdateService up front lets the base class build
    its handler set without requiring a real database-backed service.
    """
    super().__init__(service)
    self.set_model_update_service(NullModelUpdateService())
|
||||
|
||||
|
||||
@dataclass
class NullUpdateRecord:
    """Inert stand-in for a persisted update record used in smoke tests."""

    model_type: str
    model_id: int
    largest_version_id: int | None = None
    version_ids: list[int] = field(default_factory=list)
    in_library_version_ids: list[int] = field(default_factory=list)
    last_checked_at: float | None = None
    should_ignore: bool = False

    def has_update(self) -> bool:
        """A null record never reports a pending update."""
        return False


class NullModelUpdateService:
    """No-op update service so routes can run without a real database."""

    async def refresh_for_model_type(self, *args, **kwargs):
        """Pretend nothing is tracked."""
        return {}

    async def refresh_single_model(self, *args, **kwargs):
        """No record is ever produced."""
        return None

    async def update_in_library_versions(self, model_type, model_id, version_ids):
        """Echo the request back as an inert record."""
        return NullUpdateRecord(
            model_type=model_type,
            model_id=model_id,
            in_library_version_ids=list(version_ids),
        )

    async def set_should_ignore(self, model_type, model_id, should_ignore):
        """Echo the ignore flag back as an inert record."""
        return NullUpdateRecord(
            model_type=model_type,
            model_id=model_id,
            should_ignore=should_ignore,
        )

    async def get_record(self, *args, **kwargs):
        """Nothing is ever cached."""
        return None
|
||||
|
||||
|
||||
async def create_test_client(service) -> TestClient:
|
||||
routes = DummyRoutes(service)
|
||||
|
||||
85
tests/services/test_model_update_service.py
Normal file
85
tests/services/test_model_update_service.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.model_update_service import ModelUpdateService
|
||||
|
||||
|
||||
class DummyScanner:
    """Scanner test double serving a fixed raw-data payload."""

    def __init__(self, raw_data):
        # Mimic the cache object shape the service reads (``.raw_data``).
        self._cached = SimpleNamespace(raw_data=raw_data)

    async def get_cached_data(self, *args, **kwargs):
        """Return the canned cache; any arguments are accepted and ignored."""
        return self._cached
|
||||
|
||||
|
||||
class DummyProvider:
    """Metadata-provider stub: returns a canned response and counts calls."""

    def __init__(self, response):
        self.calls: int = 0  # number of get_model_versions invocations
        self.response = response

    async def get_model_versions(self, model_id):
        """Record the call and hand back the fixed response."""
        self.calls = self.calls + 1
        return self.response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_refresh_persists_versions_and_uses_cache(tmp_path):
    """First refresh hits the provider and persists; a second within TTL does not."""
    service = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=3600)
    scanner = DummyScanner(
        [
            {"civitai": {"modelId": 1, "id": 11}},
            {"civitai": {"modelId": 1, "id": 15}},
        ]
    )
    provider = DummyProvider({"modelVersions": [{"id": 11}, {"id": 15}]})

    await service.refresh_for_model_type("lora", scanner, provider)
    record = await service.get_record("lora", 1)

    assert provider.calls == 1
    assert record is not None
    assert record.version_ids == [11, 15]
    assert record.in_library_version_ids == [11, 15]
    assert record.has_update() is False

    await service.refresh_for_model_type("lora", scanner, provider)
    assert provider.calls == 1, "provider should not be called again within TTL"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_refresh_respects_ignore_flag(tmp_path):
    """Once a model is ignored, bulk refreshes skip the remote provider."""
    service = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=3600)
    scanner = DummyScanner([{"civitai": {"modelId": 2, "id": 21}}])
    provider = DummyProvider({"modelVersions": [{"id": 21}, {"id": 22}]})

    await service.refresh_for_model_type("lora", scanner, provider)
    await service.set_should_ignore("lora", 2, True)

    # Reset the counter so only post-ignore provider traffic is measured.
    provider.calls = 0
    await service.refresh_for_model_type("lora", scanner, provider)
    assert provider.calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_in_library_versions_changes_update_state(tmp_path):
    """has_update flips as the in-library set catches up to the remote versions."""
    service = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=1)
    scanner = DummyScanner([{"civitai": {"modelId": 3, "id": 31}}])
    provider = DummyProvider({"modelVersions": [{"id": 31}, {"id": 35}]})

    await service.refresh_for_model_type("lora", scanner, provider)

    # Only one of the two remote versions is in the library -> update pending.
    await service.update_in_library_versions("lora", 3, [31])
    record = await service.get_record("lora", 3)
    assert record is not None
    assert record.has_update() is True

    # Library now covers every remote version -> nothing pending.
    await service.update_in_library_versions("lora", 3, [31, 35])
    record = await service.get_record("lora", 3)
    assert record.has_update() is False
|
||||
Reference in New Issue
Block a user