From ee0d241c75adbe6cd63bba636d4b0f0229bde8be Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Wed, 15 Oct 2025 15:37:35 +0800 Subject: [PATCH] refactor(routes): limit update endpoints to essentials --- py/routes/base_model_routes.py | 27 +- py/routes/checkpoint_routes.py | 4 +- py/routes/embedding_routes.py | 4 +- py/routes/handlers/model_handlers.py | 155 +++++++ py/routes/lora_routes.py | 4 +- py/routes/model_route_registrar.py | 3 + py/services/model_update_service.py | 404 +++++++++++++++++++ py/services/persistent_model_cache.py | 5 + py/services/service_registry.py | 30 +- tests/routes/test_base_model_routes_smoke.py | 36 ++ tests/services/test_model_update_service.py | 85 ++++ 11 files changed, 749 insertions(+), 8 deletions(-) create mode 100644 py/services/model_update_service.py create mode 100644 tests/services/test_model_update_service.py diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index cb4b5d02..5eb6d1cc 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging from abc import ABC, abstractmethod -from typing import Callable, Dict, Mapping +from typing import TYPE_CHECKING, Callable, Dict, Mapping import jinja2 from aiohttp import web @@ -42,8 +42,12 @@ from .handlers.model_handlers import ( ModelMoveHandler, ModelPageView, ModelQueryHandler, + ModelUpdateHandler, ) +if TYPE_CHECKING: + from ..services.model_update_service import ModelUpdateService + logger = logging.getLogger(__name__) @@ -99,10 +103,18 @@ class BaseModelRoutes(ABC): ws_manager=self._ws_manager, download_manager_factory=ServiceRegistry.get_download_manager, ) + self._model_update_service: ModelUpdateService | None = None if service is not None: self.attach_service(service) + def set_model_update_service(self, service: "ModelUpdateService") -> None: + """Attach the model update tracking service.""" + + self._model_update_service = service + self._handler_set = None + self._handler_mapping = None + def attach_service(self, service) -> None: """Attach a model service and rebuild handler dependencies.""" self.service = service @@ -127,6 +139,7 @@ class BaseModelRoutes(ABC): def _create_handler_set(self) -> ModelHandlerSet: service = self._ensure_service() + update_service = self._ensure_model_update_service() page_view = ModelPageView( template_env=self.template_env, template_name=self.template_name or "", @@ -186,6 +199,12 @@ class BaseModelRoutes(ABC): ws_manager=self._ws_manager, logger=logger, ) + updates = ModelUpdateHandler( + service=service, + update_service=update_service, + metadata_provider_selector=get_metadata_provider, + logger=logger, + ) return ModelHandlerSet( page_view=page_view, listing=listing, @@ -195,6 +214,7 @@ class BaseModelRoutes(ABC): civitai=civitai, move=move, auto_organize=auto_organize, + updates=updates, ) @property @@ -273,3 +293,8 @@ class BaseModelRoutes(ABC): return proxy + def _ensure_model_update_service(self) -> "ModelUpdateService": + if self._model_update_service is None: + raise RuntimeError("Model update service has not been attached") + return self._model_update_service + diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py index ad4c538a..5a17d79a 100644 --- a/py/routes/checkpoint_routes.py +++ b/py/routes/checkpoint_routes.py @@ -21,7 +21,9 @@ class CheckpointRoutes(BaseModelRoutes): """Initialize services from ServiceRegistry""" checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() self.service = CheckpointService(checkpoint_scanner) - + update_service = await ServiceRegistry.get_model_update_service() + self.set_model_update_service(update_service) + # Attach service dependencies self.attach_service(self.service) diff --git a/py/routes/embedding_routes.py b/py/routes/embedding_routes.py index d7d361ce..80b15525 100644 --- a/py/routes/embedding_routes.py +++ b/py/routes/embedding_routes.py @@ -20,7 +20,9 @@ class EmbeddingRoutes(BaseModelRoutes): """Initialize services from ServiceRegistry""" embedding_scanner = await ServiceRegistry.get_embedding_scanner() self.service = EmbeddingService(embedding_scanner) - + update_service = await ServiceRegistry.get_model_update_service() + self.set_model_update_service(update_service) + # Attach service dependencies self.attach_service(self.service) diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 16089b66..4df9941a 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -29,6 +29,7 @@ from ...services.use_cases import ( ) from ...services.websocket_manager import WebSocketManager from ...services.websocket_progress_callback import WebSocketProgressCallback +from ...services.errors import RateLimitError from ...utils.file_utils import calculate_sha256 from ...utils.metadata_manager import MetadataManager @@ -1017,6 +1018,156 @@ class ModelAutoOrganizeHandler: return web.json_response({"success": False, "error": str(exc)}, status=500) +class ModelUpdateHandler: + """Handle update tracking requests.""" + + def __init__( + self, + *, + service, + update_service, + metadata_provider_selector, + logger: logging.Logger, + ) -> None: + self._service = service + self._update_service = update_service + self._metadata_provider_selector = metadata_provider_selector + self._logger = logger + + async def refresh_model_updates(self, request: web.Request) -> web.Response: + payload = await self._read_json(request) + force_refresh = self._parse_bool(request.query.get("force")) or self._parse_bool( + payload.get("force") + ) + provider = await self._get_civitai_provider() + if provider is None: + return web.json_response( + {"success": False, "error": "Civitai provider not available"}, status=503 + ) + + try: + records = await self._update_service.refresh_for_model_type( + self._service.model_type, + self._service.scanner, + provider, + force_refresh=force_refresh, + ) + except RateLimitError as exc: + return web.json_response( + {"success": False, "error": str(exc) or "Rate limited"}, status=429 + ) + except Exception as exc: # pragma: no cover - defensive logging + self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + return web.json_response( + { + "success": True, + "records": [self._serialize_record(record) for record in records.values()], + } + ) + + async def set_model_update_ignore(self, request: web.Request) -> web.Response: + payload = await self._read_json(request) + model_id = self._normalize_model_id(payload.get("modelId")) + if model_id is None: + return web.json_response({"success": False, "error": "modelId is required"}, status=400) + + should_ignore = self._parse_bool(payload.get("shouldIgnore")) + record = await self._update_service.set_should_ignore( + self._service.model_type, model_id, should_ignore + ) + return web.json_response({"success": True, "record": self._serialize_record(record)}) + + async def get_model_update_status(self, request: web.Request) -> web.Response: + model_id = self._normalize_model_id(request.match_info.get("model_id")) + if model_id is None: + return web.json_response({"success": False, "error": "model_id must be an integer"}, status=400) + + refresh = self._parse_bool(request.query.get("refresh")) + force = self._parse_bool(request.query.get("force")) + + try: + record = await self._get_or_refresh_record(model_id, refresh=refresh, force=force) + except RateLimitError as exc: + return web.json_response( + {"success": False, "error": str(exc) or "Rate limited"}, status=429 + ) + + if record is None: + return web.json_response( + {"success": False, "error": "Model not tracked"}, status=404 + ) + + return web.json_response({"success": True, "record": self._serialize_record(record)}) + + async def _get_or_refresh_record( + self, model_id: int, *, refresh: bool, force: bool + ) -> Optional[object]: + record = await self._update_service.get_record(self._service.model_type, model_id) + if record and not refresh and not force: + return record + + provider = await self._get_civitai_provider() + if provider is None: + return record + + return await self._update_service.refresh_single_model( + self._service.model_type, + model_id, + self._service.scanner, + provider, + force_refresh=force or refresh, + ) + + async def _get_civitai_provider(self): + try: + return await self._metadata_provider_selector("civitai_api") + except Exception as exc: # pragma: no cover - defensive log + self._logger.error("Failed to acquire civitai provider: %s", exc, exc_info=True) + return None + + async def _read_json(self, request: web.Request) -> Dict: + if not request.can_read_body: + return {} + try: + return await request.json() + except Exception: + return {} + + @staticmethod + def _parse_bool(value) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.lower() in {"1", "true", "yes"} + if isinstance(value, (int, float)): + return bool(value) + return False + + @staticmethod + def _normalize_model_id(value) -> Optional[int]: + try: + if value is None: + return None + return int(value) + except (TypeError, ValueError): + return None + + @staticmethod + def _serialize_record(record) -> Dict: + return { + "modelType": record.model_type, + "modelId": record.model_id, + "largestVersionId": record.largest_version_id, + "versionIds": record.version_ids, + "inLibraryVersionIds": record.in_library_version_ids, + "lastCheckedAt": record.last_checked_at, + "shouldIgnore": record.should_ignore, + "hasUpdate": record.has_update(), + } + + @dataclass class ModelHandlerSet: """Aggregate concrete handlers into a flat mapping.""" @@ -1029,6 +1180,7 @@ class ModelHandlerSet: civitai: ModelCivitaiHandler move: ModelMoveHandler auto_organize: ModelAutoOrganizeHandler + updates: ModelUpdateHandler def to_route_mapping(self) -> Dict[str, Callable[[web.Request], Awaitable[web.Response]]]: return { @@ -1073,5 +1225,8 @@ class ModelHandlerSet: "get_model_metadata": self.query.get_model_metadata, "get_model_description": self.query.get_model_description, "get_relative_paths": self.query.get_relative_paths, + "refresh_model_updates": self.updates.refresh_model_updates, + "set_model_update_ignore": self.updates.set_model_update_ignore, + "get_model_update_status": self.updates.get_model_update_status, } diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 15b6a0b7..c966cd5f 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -24,7 +24,9 @@ class LoraRoutes(BaseModelRoutes): """Initialize services from ServiceRegistry""" lora_scanner = await ServiceRegistry.get_lora_scanner() self.service = LoraService(lora_scanner) - + update_service = await ServiceRegistry.get_model_update_service() + self.set_model_update_service(update_service) + # Attach service dependencies self.attach_service(self.service) diff --git a/py/routes/model_route_registrar.py b/py/routes/model_route_registrar.py index 105e5f09..ff57672a 100644 --- a/py/routes/model_route_registrar.py +++ b/py/routes/model_route_registrar.py @@ -55,6 +55,9 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"), RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"), RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"), + RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"), + RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"), + RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"), RouteDefinition("POST", "/api/lm/download-model", "download_model"), RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"), RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"), diff --git a/py/services/model_update_service.py b/py/services/model_update_service.py new file mode 100644 index 00000000..11251eed --- /dev/null +++ b/py/services/model_update_service.py @@ -0,0 +1,404 @@ +"""Service for tracking remote model version updates.""" +from __future__ import annotations + +import asyncio +import json +import logging +import os +import sqlite3 +import time +from dataclasses import dataclass +from typing import Dict, Iterable, List, Mapping, Optional, Sequence + +from .errors import RateLimitError + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelUpdateRecord: + """Representation of a persisted update record.""" + + model_type: str + model_id: int + largest_version_id: Optional[int] + version_ids: List[int] + in_library_version_ids: List[int] + last_checked_at: Optional[float] + should_ignore: bool + + def has_update(self) -> bool: + """Return True when remote versions exceed the local library.""" + + if self.should_ignore or not self.version_ids: + return False + local_versions = set(self.in_library_version_ids) + return any(version_id not in local_versions for version_id in self.version_ids) + + +class ModelUpdateService: + """Persist and query remote model version metadata.""" + + _SCHEMA = """ + CREATE TABLE IF NOT EXISTS model_update_status ( + model_type TEXT NOT NULL, + model_id INTEGER NOT NULL, + largest_version_id INTEGER, + version_ids TEXT, + in_library_version_ids TEXT, + last_checked_at REAL, + should_ignore INTEGER DEFAULT 0, + PRIMARY KEY (model_type, model_id) + ) + """ + + def __init__(self, db_path: str, *, ttl_seconds: int = 24 * 60 * 60) -> None: + self._db_path = db_path + self._ttl_seconds = ttl_seconds + self._lock = asyncio.Lock() + self._schema_initialized = False + self._ensure_directory() + self._initialize_schema() + + def _ensure_directory(self) -> None: + directory = os.path.dirname(self._db_path) + if directory: + os.makedirs(directory, exist_ok=True) + + def _connect(self) -> sqlite3.Connection: + conn = sqlite3.connect(self._db_path, check_same_thread=False) + conn.row_factory = sqlite3.Row + return conn + + def _initialize_schema(self) -> None: + if self._schema_initialized: + return + try: + with self._connect() as conn: + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA foreign_keys = ON") + conn.executescript(self._SCHEMA) + self._schema_initialized = True + except Exception as exc: # pragma: no cover - defensive guard + logger.error("Failed to initialize update schema: %s", exc, exc_info=True) + raise + + async def refresh_for_model_type( + self, + model_type: str, + scanner, + metadata_provider, + *, + force_refresh: bool = False, + ) -> Dict[int, ModelUpdateRecord]: + """Refresh update information for every model present in the cache.""" + + local_versions = await self._collect_local_versions(scanner) + results: Dict[int, ModelUpdateRecord] = {} + for model_id, version_ids in local_versions.items(): + record = await self._refresh_single_model( + model_type, + model_id, + version_ids, + metadata_provider, + force_refresh=force_refresh, + ) + if record: + results[model_id] = record + return results + + async def refresh_single_model( + self, + model_type: str, + model_id: int, + scanner, + metadata_provider, + *, + force_refresh: bool = False, + ) -> Optional[ModelUpdateRecord]: + """Refresh update information for a specific model id.""" + + local_versions = await self._collect_local_versions(scanner) + version_ids = local_versions.get(model_id, []) + return await self._refresh_single_model( + model_type, + model_id, + version_ids, + metadata_provider, + force_refresh=force_refresh, + ) + + async def update_in_library_versions( + self, + model_type: str, + model_id: int, + version_ids: Sequence[int], + ) -> ModelUpdateRecord: + """Persist a new set of in-library version identifiers.""" + + normalized_versions = self._normalize_sequence(version_ids) + async with self._lock: + existing = self._get_record(model_type, model_id) + record = ModelUpdateRecord( + model_type=model_type, + model_id=model_id, + largest_version_id=existing.largest_version_id if existing else None, + version_ids=list(existing.version_ids) if existing else [], + in_library_version_ids=normalized_versions, + last_checked_at=existing.last_checked_at if existing else None, + should_ignore=existing.should_ignore if existing else False, + ) + self._upsert_record(record) + return record + + async def set_should_ignore( + self, model_type: str, model_id: int, should_ignore: bool + ) -> ModelUpdateRecord: + """Toggle the ignore flag for a model.""" + + async with self._lock: + existing = self._get_record(model_type, model_id) + if existing: + record = ModelUpdateRecord( + model_type=model_type, + model_id=model_id, + largest_version_id=existing.largest_version_id, + version_ids=list(existing.version_ids), + in_library_version_ids=list(existing.in_library_version_ids), + last_checked_at=existing.last_checked_at, + should_ignore=should_ignore, + ) + else: + record = ModelUpdateRecord( + model_type=model_type, + model_id=model_id, + largest_version_id=None, + version_ids=[], + in_library_version_ids=[], + last_checked_at=None, + should_ignore=should_ignore, + ) + self._upsert_record(record) + return record + + async def get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]: + """Return a cached record without triggering remote fetches.""" + + async with self._lock: + return self._get_record(model_type, model_id) + + async def has_update(self, model_type: str, model_id: int) -> bool: + """Determine if a model has updates pending.""" + + record = await self.get_record(model_type, model_id) + return record.has_update() if record else False + + async def _refresh_single_model( + self, + model_type: str, + model_id: int, + local_versions: Sequence[int], + metadata_provider, + *, + force_refresh: bool = False, + ) -> Optional[ModelUpdateRecord]: + normalized_local = self._normalize_sequence(local_versions) + now = time.time() + async with self._lock: + existing = self._get_record(model_type, model_id) + if existing and existing.should_ignore and not force_refresh: + record = ModelUpdateRecord( + model_type=model_type, + model_id=model_id, + largest_version_id=existing.largest_version_id, + version_ids=list(existing.version_ids), + in_library_version_ids=normalized_local, + last_checked_at=existing.last_checked_at, + should_ignore=True, + ) + self._upsert_record(record) + return record + + should_fetch = force_refresh or not existing or self._is_stale(existing, now) + # release lock during network request + fetched_versions: List[int] | None = None + if metadata_provider and should_fetch: + try: + response = await metadata_provider.get_model_versions(model_id) + except RateLimitError: + raise + except Exception as exc: # pragma: no cover - defensive log + logger.error( + "Failed to fetch versions for model %s (%s): %s", + model_id, + model_type, + exc, + exc_info=True, + ) + else: + fetched_versions = self._extract_version_ids(response) + + async with self._lock: + existing = self._get_record(model_type, model_id) + if existing and existing.should_ignore and not force_refresh: + # Ignore state could have flipped while awaiting provider + record = ModelUpdateRecord( + model_type=model_type, + model_id=model_id, + largest_version_id=existing.largest_version_id, + version_ids=list(existing.version_ids), + in_library_version_ids=normalized_local, + last_checked_at=existing.last_checked_at, + should_ignore=True, + ) + self._upsert_record(record) + return record + + version_ids = ( + fetched_versions + if fetched_versions is not None + else (list(existing.version_ids) if existing else []) + ) + largest = max(version_ids) if version_ids else None + last_checked = now if fetched_versions is not None else ( + existing.last_checked_at if existing else None + ) + record = ModelUpdateRecord( + model_type=model_type, + model_id=model_id, + largest_version_id=largest, + version_ids=version_ids, + in_library_version_ids=normalized_local, + last_checked_at=last_checked, + should_ignore=existing.should_ignore if existing else False, + ) + self._upsert_record(record) + return record + + async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]: + cache = await scanner.get_cached_data() + mapping: Dict[int, set[int]] = {} + if not cache or not getattr(cache, "raw_data", None): + return {} + + for item in cache.raw_data: + civitai = item.get("civitai") if isinstance(item, dict) else None + if not isinstance(civitai, dict): + continue + model_id = self._normalize_int(civitai.get("modelId")) + version_id = self._normalize_int(civitai.get("id")) + if model_id is None or version_id is None: + continue + mapping.setdefault(model_id, set()).add(version_id) + + return {model_id: sorted(ids) for model_id, ids in mapping.items()} + + def _is_stale(self, record: ModelUpdateRecord, now: float) -> bool: + if record.last_checked_at is None: + return True + return (now - record.last_checked_at) >= self._ttl_seconds + + @staticmethod + def _normalize_int(value) -> Optional[int]: + try: + if value is None: + return None + return int(value) + except (TypeError, ValueError): + return None + + def _normalize_sequence(self, values: Sequence[int]) -> List[int]: + normalized = [ + item + for item in (self._normalize_int(value) for value in values) + if item is not None + ] + return sorted(dict.fromkeys(normalized)) + + def _extract_version_ids(self, response) -> List[int]: + if not isinstance(response, Mapping): + return [] + versions = response.get("modelVersions") + if not isinstance(versions, Iterable): + return [] + normalized = [] + for entry in versions: + if isinstance(entry, Mapping): + normalized_id = self._normalize_int(entry.get("id")) + else: + normalized_id = self._normalize_int(entry) + if normalized_id is not None: + normalized.append(normalized_id) + return sorted(dict.fromkeys(normalized)) + + def _get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]: + with self._connect() as conn: + row = conn.execute( + """ + SELECT model_type, model_id, largest_version_id, version_ids, + in_library_version_ids, last_checked_at, should_ignore + FROM model_update_status + WHERE model_type = ? AND model_id = ? + """, + (model_type, model_id), + ).fetchone() + if not row: + return None + return ModelUpdateRecord( + model_type=row["model_type"], + model_id=int(row["model_id"]), + largest_version_id=self._normalize_int(row["largest_version_id"]), + version_ids=self._deserialize_json_array(row["version_ids"]), + in_library_version_ids=self._deserialize_json_array( + row["in_library_version_ids"] + ), + last_checked_at=row["last_checked_at"], + should_ignore=bool(row["should_ignore"]), + ) + + def _upsert_record(self, record: ModelUpdateRecord) -> None: + payload = ( + record.model_type, + record.model_id, + record.largest_version_id, + json.dumps(record.version_ids), + json.dumps(record.in_library_version_ids), + record.last_checked_at, + 1 if record.should_ignore else 0, + ) + with self._connect() as conn: + conn.execute( + """ + INSERT INTO model_update_status ( + model_type, model_id, largest_version_id, version_ids, + in_library_version_ids, last_checked_at, should_ignore + ) VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(model_type, model_id) DO UPDATE SET + largest_version_id = excluded.largest_version_id, + version_ids = excluded.version_ids, + in_library_version_ids = excluded.in_library_version_ids, + last_checked_at = excluded.last_checked_at, + should_ignore = excluded.should_ignore + """, + payload, + ) + conn.commit() + + @staticmethod + def _deserialize_json_array(value) -> List[int]: + if not value: + return [] + try: + data = json.loads(value) + except (TypeError, json.JSONDecodeError): + return [] + if isinstance(data, list): + normalized = [] + for entry in data: + try: + normalized.append(int(entry)) + except (TypeError, ValueError): + continue + return sorted(dict.fromkeys(normalized)) + return [] + diff --git a/py/services/persistent_model_cache.py b/py/services/persistent_model_cache.py index 7dfb21ac..60ff2d64 100644 --- a/py/services/persistent_model_cache.py +++ b/py/services/persistent_model_cache.py @@ -81,6 +81,11 @@ class PersistentModelCache: def is_enabled(self) -> bool: return os.environ.get("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0") != "1" + def get_database_path(self) -> str: + """Expose the resolved SQLite database path.""" + + return self._db_path + def load_cache(self, model_type: str) -> Optional[PersistedCacheData]: if not self.is_enabled(): return None diff --git a/py/services/service_registry.py b/py/services/service_registry.py index d3d65e65..4e3bea57 100644 --- a/py/services/service_registry.py +++ b/py/services/service_registry.py @@ -128,23 +128,45 @@ class ServiceRegistry: async def get_civitai_client(cls): """Get or create CivitAI client instance""" service_name = "civitai_client" - + if service_name in cls._services: return cls._services[service_name] - + async with cls._get_lock(service_name): # Double-check after acquiring lock if service_name in cls._services: return cls._services[service_name] - + # Import here to avoid circular imports from .civitai_client import CivitaiClient - + client = await CivitaiClient.get_instance() cls._services[service_name] = client logger.debug(f"Created and registered {service_name}") return client + @classmethod + async def get_model_update_service(cls): + """Get or create the model update tracking service.""" + + service_name = "model_update_service" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + if service_name in cls._services: + return cls._services[service_name] + + from .model_update_service import ModelUpdateService + from .persistent_model_cache import get_persistent_cache + + cache = get_persistent_cache() + service = ModelUpdateService(cache.get_database_path()) + cls._services[service_name] = service + logger.debug(f"Created and registered {service_name}") + return service + @classmethod async def get_civarchive_client(cls): """Get or create CivArchive client instance""" diff --git a/tests/routes/test_base_model_routes_smoke.py b/tests/routes/test_base_model_routes_smoke.py index 88bff245..90438b17 100644 --- a/tests/routes/test_base_model_routes_smoke.py +++ b/tests/routes/test_base_model_routes_smoke.py @@ -5,6 +5,7 @@ import sys from pathlib import Path import types +from dataclasses import dataclass, field from typing import Optional folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: []) @@ -32,6 +33,41 @@ class DummyRoutes(BaseModelRoutes): def setup_specific_routes(self, registrar, prefix: str) -> None: # pragma: no cover - no extra routes in smoke tests return None + def __init__(self, service=None): + super().__init__(service) + self.set_model_update_service(NullModelUpdateService()) + + +@dataclass +class NullUpdateRecord: + model_type: str + model_id: int + largest_version_id: int | None = None + version_ids: list[int] = field(default_factory=list) + in_library_version_ids: list[int] = field(default_factory=list) + last_checked_at: float | None = None + should_ignore: bool = False + + def has_update(self) -> bool: + return False + + +class NullModelUpdateService: + async def refresh_for_model_type(self, *args, **kwargs): + return {} + + async def refresh_single_model(self, *args, **kwargs): + return None + + async def update_in_library_versions(self, model_type, model_id, version_ids): + return NullUpdateRecord(model_type=model_type, model_id=model_id, in_library_version_ids=list(version_ids)) + + async def set_should_ignore(self, model_type, model_id, should_ignore): + return NullUpdateRecord(model_type=model_type, model_id=model_id, should_ignore=should_ignore) + + async def get_record(self, *args, **kwargs): + return None + async def create_test_client(service) -> TestClient: routes = DummyRoutes(service) diff --git a/tests/services/test_model_update_service.py b/tests/services/test_model_update_service.py new file mode 100644 index 00000000..3cabcd38 --- /dev/null +++ b/tests/services/test_model_update_service.py @@ -0,0 +1,85 @@ +import asyncio +from types import SimpleNamespace + +import pytest + +from py.services.model_update_service import ModelUpdateService + + +class DummyScanner: + def __init__(self, raw_data): + self._cache = SimpleNamespace(raw_data=raw_data) + + async def get_cached_data(self, *args, **kwargs): + return self._cache + + +class DummyProvider: + def __init__(self, response): + self.response = response + self.calls: int = 0 + + async def get_model_versions(self, model_id): + self.calls += 1 + return self.response + + +@pytest.mark.asyncio +async def test_refresh_persists_versions_and_uses_cache(tmp_path): + db_path = tmp_path / "updates.sqlite" + service = ModelUpdateService(str(db_path), ttl_seconds=3600) + raw_data = [ + {"civitai": {"modelId": 1, "id": 11}}, + {"civitai": {"modelId": 1, "id": 15}}, + ] + scanner = DummyScanner(raw_data) + provider = DummyProvider({"modelVersions": [{"id": 11}, {"id": 15}]}) + + await service.refresh_for_model_type("lora", scanner, provider) + record = await service.get_record("lora", 1) + + assert provider.calls == 1 + assert record is not None + assert record.version_ids == [11, 15] + assert record.in_library_version_ids == [11, 15] + assert record.has_update() is False + + await service.refresh_for_model_type("lora", scanner, provider) + assert provider.calls == 1, "provider should not be called again within TTL" + + +@pytest.mark.asyncio +async def test_refresh_respects_ignore_flag(tmp_path): + db_path = tmp_path / "updates.sqlite" + service = ModelUpdateService(str(db_path), ttl_seconds=3600) + raw_data = [{"civitai": {"modelId": 2, "id": 21}}] + scanner = DummyScanner(raw_data) + provider = DummyProvider({"modelVersions": [{"id": 21}, {"id": 22}]}) + + await service.refresh_for_model_type("lora", scanner, provider) + await service.set_should_ignore("lora", 2, True) + + provider.calls = 0 + await service.refresh_for_model_type("lora", scanner, provider) + assert provider.calls == 0 + + +@pytest.mark.asyncio +async def test_update_in_library_versions_changes_update_state(tmp_path): + db_path = tmp_path / "updates.sqlite" + service = ModelUpdateService(str(db_path), ttl_seconds=1) + raw_data = [{"civitai": {"modelId": 3, "id": 31}}] + scanner = DummyScanner(raw_data) + provider = DummyProvider({"modelVersions": [{"id": 31}, {"id": 35}]}) + + await service.refresh_for_model_type("lora", scanner, provider) + await service.update_in_library_versions("lora", 3, [31]) + record = await service.get_record("lora", 3) + + assert record is not None + assert record.has_update() is True + + await service.update_in_library_versions("lora", 3, [31, 35]) + record = await service.get_record("lora", 3) + + assert record.has_update() is False