diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index b1254e0b..76c4d040 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -1085,6 +1085,28 @@ class ModelUpdateHandler: ) return web.json_response({"success": True, "record": self._serialize_record(record)}) + async def set_version_update_ignore(self, request: web.Request) -> web.Response: + payload = await self._read_json(request) + model_id = self._normalize_model_id(payload.get("modelId")) + version_id = self._normalize_model_id(payload.get("versionId")) + if model_id is None or version_id is None: + return web.json_response( + {"success": False, "error": "modelId and versionId are required"}, + status=400, + ) + + should_ignore = self._parse_bool(payload.get("shouldIgnore")) + record = await self._update_service.set_version_should_ignore( + self._service.model_type, + model_id, + version_id, + should_ignore, + ) + overrides = await self._build_preview_overrides(record) + return web.json_response( + {"success": True, "record": self._serialize_record(record, preview_overrides=overrides)} + ) + async def get_model_update_status(self, request: web.Request) -> web.Response: model_id = self._normalize_model_id(request.match_info.get("model_id")) if model_id is None: @@ -1107,6 +1129,33 @@ class ModelUpdateHandler: return web.json_response({"success": True, "record": self._serialize_record(record)}) + async def get_model_versions(self, request: web.Request) -> web.Response: + model_id = self._normalize_model_id(request.match_info.get("model_id")) + if model_id is None: + return web.json_response( + {"success": False, "error": "model_id must be an integer"}, status=400 + ) + + refresh = self._parse_bool(request.query.get("refresh")) + force = self._parse_bool(request.query.get("force")) + + try: + record = await self._get_or_refresh_record(model_id, refresh=refresh, force=force) + except RateLimitError as exc: + return web.json_response( + {"success": False, "error": str(exc) or "Rate limited"}, status=429 + ) + + if record is None: + return web.json_response( + {"success": False, "error": "Model not tracked"}, status=404 + ) + + overrides = await self._build_preview_overrides(record) + return web.json_response( + {"success": True, "record": self._serialize_record(record, preview_overrides=overrides)} + ) + async def _get_or_refresh_record( self, model_id: int, *, refresh: bool, force: bool ) -> Optional[object]: @@ -1160,8 +1209,13 @@ class ModelUpdateHandler: except (TypeError, ValueError): return None - @staticmethod - def _serialize_record(record) -> Dict: + def _serialize_record( + self, + record, + *, + preview_overrides: Optional[Dict[int, Optional[str]]] = None, + ) -> Dict: + overrides = preview_overrides or {} return { "modelType": record.model_type, "modelId": record.model_id, @@ -1169,10 +1223,50 @@ class ModelUpdateHandler: "versionIds": record.version_ids, "inLibraryVersionIds": record.in_library_version_ids, "lastCheckedAt": record.last_checked_at, - "shouldIgnore": record.should_ignore, + "shouldIgnore": record.should_ignore_model, "hasUpdate": record.has_update(), + "versions": [ + self._serialize_version(version, overrides.get(version.version_id)) + for version in record.versions + ], } + @staticmethod + def _serialize_version(version, override_preview: Optional[str]) -> Dict: + preview_url = override_preview if override_preview is not None else version.preview_url + return { + "versionId": version.version_id, + "name": version.name, + "baseModel": version.base_model, + "releasedAt": version.released_at, + "sizeBytes": version.size_bytes, + "previewUrl": preview_url, + "isInLibrary": version.is_in_library, + "shouldIgnore": version.should_ignore, + } + + async def _build_preview_overrides(self, record) -> Dict[int, Optional[str]]: + overrides: Dict[int, Optional[str]] = {} + try: + cache = await self._service.scanner.get_cached_data() + except Exception as exc: # pragma: no cover - defensive logging + self._logger.debug("Failed to load cache while building preview overrides: %s", exc) + return overrides + + version_index = getattr(cache, "version_index", None) + if not version_index: + return overrides + + for version in record.versions: + if not version.is_in_library: + continue + cache_entry = version_index.get(version.version_id) + if isinstance(cache_entry, Mapping): + preview = cache_entry.get("preview_url") + if isinstance(preview, str) and preview: + overrides[version.version_id] = config.get_preview_static_url(preview) + return overrides + @dataclass class ModelHandlerSet: @@ -1233,6 +1327,8 @@ class ModelHandlerSet: "get_relative_paths": self.query.get_relative_paths, "refresh_model_updates": self.updates.refresh_model_updates, "set_model_update_ignore": self.updates.set_model_update_ignore, + "set_version_update_ignore": self.updates.set_version_update_ignore, "get_model_update_status": self.updates.get_model_update_status, + "get_model_versions": self.updates.get_model_versions, } diff --git a/py/routes/model_route_registrar.py b/py/routes/model_route_registrar.py index ff57672a..12b36850 100644 --- a/py/routes/model_route_registrar.py +++ b/py/routes/model_route_registrar.py @@ -57,7 +57,9 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"), RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"), RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"), + RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore-version", "set_version_update_ignore"), RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"), + RouteDefinition("GET", "/api/lm/{prefix}/updates/versions/{model_id}", "get_model_versions"), RouteDefinition("POST", "/api/lm/download-model", "download_model"), RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"), RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"), diff --git a/py/services/model_update_service.py b/py/services/model_update_service.py index 4f89e6e9..6392ff32 100644 --- a/py/services/model_update_service.py +++ b/py/services/model_update_service.py @@ -2,61 +2,112 @@ from __future__ import annotations import asyncio -import json import logging import os import sqlite3 import time -from dataclasses import dataclass +from dataclasses import dataclass, replace from typing import Dict, Iterable, List, Mapping, Optional, Sequence from .errors import RateLimitError +from .settings_manager import get_settings_manager +from ..utils.civitai_utils import rewrite_preview_url +from ..utils.preview_selection import select_preview_media logger = logging.getLogger(__name__) +@dataclass +class ModelVersionRecord: + """Persisted metadata for a single model version.""" + + version_id: int + name: Optional[str] + base_model: Optional[str] + released_at: Optional[str] + size_bytes: Optional[int] + preview_url: Optional[str] + is_in_library: bool + should_ignore: bool + sort_index: int = 0 + + @dataclass class ModelUpdateRecord: """Representation of a persisted update record.""" model_type: str model_id: int - largest_version_id: Optional[int] - version_ids: List[int] - in_library_version_ids: List[int] + versions: List[ModelVersionRecord] last_checked_at: Optional[float] - should_ignore: bool + should_ignore_model: bool + + @property + def largest_version_id(self) -> Optional[int]: + """Return the highest known version identifier for the model.""" + + if not self.versions: + return None + return max(version.version_id for version in self.versions) + + @property + def version_ids(self) -> List[int]: + """Return all known version identifiers.""" + + return [version.version_id for version in self.versions] + + @property + def in_library_version_ids(self) -> List[int]: + """Return the subset of version identifiers present in the local library.""" + + return [version.version_id for version in self.versions if version.is_in_library] def has_update(self) -> bool: - """Return True when remote versions exceed the local library.""" + """Return True when a non-ignored remote version is missing locally.""" - if self.should_ignore or not self.version_ids: + if self.should_ignore_model: return False - local_versions = set(self.in_library_version_ids) - return any(version_id not in local_versions for version_id in self.version_ids) + return any( + not version.is_in_library and not version.should_ignore for version in self.versions + ) class ModelUpdateService: """Persist and query remote model version metadata.""" _SCHEMA = """ - CREATE TABLE IF NOT EXISTS model_update_status ( + PRAGMA foreign_keys = ON; + DROP TABLE IF EXISTS model_update_versions; + DROP TABLE IF EXISTS model_update_status; + CREATE TABLE model_update_status ( + model_id INTEGER PRIMARY KEY, model_type TEXT NOT NULL, - model_id INTEGER NOT NULL, - largest_version_id INTEGER, - version_ids TEXT, - in_library_version_ids TEXT, last_checked_at REAL, - should_ignore INTEGER DEFAULT 0, - PRIMARY KEY (model_type, model_id) - ) + should_ignore_model INTEGER NOT NULL DEFAULT 0 + ); + CREATE TABLE model_update_versions ( + version_id INTEGER PRIMARY KEY, + model_id INTEGER NOT NULL, + sort_index INTEGER NOT NULL DEFAULT 0, + name TEXT, + base_model TEXT, + released_at TEXT, + size_bytes INTEGER, + preview_url TEXT, + is_in_library INTEGER NOT NULL DEFAULT 0, + should_ignore INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY(model_id) REFERENCES model_update_status(model_id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_model_update_versions_model_id + ON model_update_versions(model_id); """ - def __init__(self, db_path: str, *, ttl_seconds: int = 24 * 60 * 60) -> None: + def __init__(self, db_path: str, *, ttl_seconds: int = 24 * 60 * 60, settings_manager=None) -> None: self._db_path = db_path self._ttl_seconds = ttl_seconds self._lock = asyncio.Lock() self._schema_initialized = False + self._settings = settings_manager or get_settings_manager() self._ensure_directory() self._initialize_schema() @@ -103,7 +154,7 @@ class ModelUpdateService: async with self._lock: for model_id in local_versions.keys(): existing = self._get_record(model_type, model_id) - if existing and existing.should_ignore and not force_refresh: + if existing and existing.should_ignore_model and not force_refresh: continue if force_refresh or not existing or self._is_stale(existing, now): fetch_targets.append(model_id) @@ -162,14 +213,11 @@ class ModelUpdateService: normalized_versions = self._normalize_sequence(version_ids) async with self._lock: existing = self._get_record(model_type, model_id) - record = ModelUpdateRecord( + record = self._merge_with_local_versions( + existing, + normalized_versions, model_type=model_type, model_id=model_id, - largest_version_id=existing.largest_version_id if existing else None, - version_ids=list(existing.version_ids) if existing else [], - in_library_version_ids=normalized_versions, - last_checked_at=existing.last_checked_at if existing else None, - should_ignore=existing.should_ignore if existing else False, ) self._upsert_record(record) return record @@ -183,27 +231,70 @@ class ModelUpdateService: existing = self._get_record(model_type, model_id) if existing: record = ModelUpdateRecord( - model_type=model_type, - model_id=model_id, - largest_version_id=existing.largest_version_id, - version_ids=list(existing.version_ids), - in_library_version_ids=list(existing.in_library_version_ids), + model_type=existing.model_type, + model_id=existing.model_id, + versions=list(existing.versions), last_checked_at=existing.last_checked_at, - should_ignore=should_ignore, + should_ignore_model=should_ignore, ) else: record = ModelUpdateRecord( model_type=model_type, model_id=model_id, - largest_version_id=None, - version_ids=[], - in_library_version_ids=[], + versions=[], last_checked_at=None, - should_ignore=should_ignore, + should_ignore_model=should_ignore, ) self._upsert_record(record) return record + async def set_version_should_ignore( + self, + model_type: str, + model_id: int, + version_id: int, + should_ignore: bool, + ) -> ModelUpdateRecord: + """Toggle the ignore flag for an individual version.""" + + async with self._lock: + existing = self._get_record(model_type, model_id) + versions: List[ModelVersionRecord] = [] + found = False + if existing: + for record_version in existing.versions: + if record_version.version_id == version_id: + versions.append( + replace(record_version, should_ignore=should_ignore) + ) + found = True + else: + versions.append(record_version) + if not found: + versions.append( + ModelVersionRecord( + version_id=version_id, + name=None, + base_model=None, + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=False, + should_ignore=should_ignore, + sort_index=len(versions), + ) + ) + + record = ModelUpdateRecord( + model_type=existing.model_type if existing else model_type, + model_id=existing.model_id if existing else model_id, + versions=self._sorted_versions(versions), + last_checked_at=existing.last_checked_at if existing else None, + should_ignore_model=existing.should_ignore_model if existing else False, + ) + self._upsert_record(record) + return record + async def get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]: """Return a cached record without triggering remote fetches.""" @@ -230,22 +321,17 @@ class ModelUpdateService: now = time.time() async with self._lock: existing = self._get_record(model_type, model_id) - if existing and existing.should_ignore and not force_refresh: - record = ModelUpdateRecord( - model_type=model_type, - model_id=model_id, - largest_version_id=existing.largest_version_id, - version_ids=list(existing.version_ids), - in_library_version_ids=normalized_local, - last_checked_at=existing.last_checked_at, - should_ignore=True, + if existing and existing.should_ignore_model and not force_refresh: + record = self._merge_with_local_versions( + existing, + normalized_local, ) self._upsert_record(record) return record should_fetch = force_refresh or not existing or self._is_stale(existing, now) # release lock during network request - fetched_versions: List[int] | None = None + fetched_versions: List[ModelVersionRecord] | None = None refresh_succeeded = False response: Optional[Mapping] = None if metadata_provider and should_fetch: @@ -264,45 +350,38 @@ class ModelUpdateService: exc_info=True, ) if response is not None: - extracted = self._extract_version_ids(response) + extracted = self._extract_versions(response) if extracted is not None: fetched_versions = extracted refresh_succeeded = True async with self._lock: existing = self._get_record(model_type, model_id) - if existing and existing.should_ignore and not force_refresh: - # Ignore state could have flipped while awaiting provider - record = ModelUpdateRecord( - model_type=model_type, - model_id=model_id, - largest_version_id=existing.largest_version_id, - version_ids=list(existing.version_ids), - in_library_version_ids=normalized_local, - last_checked_at=existing.last_checked_at, - should_ignore=True, + if existing and existing.should_ignore_model and not force_refresh: + record = self._merge_with_local_versions( + existing, + normalized_local, ) self._upsert_record(record) return record - version_ids = ( - fetched_versions - if refresh_succeeded - else (list(existing.version_ids) if existing else []) - ) - largest = max(version_ids) if version_ids else None - last_checked = now if refresh_succeeded else ( - existing.last_checked_at if existing else None - ) - record = ModelUpdateRecord( - model_type=model_type, - model_id=model_id, - largest_version_id=largest, - version_ids=version_ids, - in_library_version_ids=normalized_local, - last_checked_at=last_checked, - should_ignore=existing.should_ignore if existing else False, - ) + if refresh_succeeded and isinstance(fetched_versions, list): + record = self._build_record_from_remote( + model_type, + model_id, + normalized_local, + fetched_versions, + existing, + now, + ) + else: + record = self._merge_with_local_versions( + existing, + normalized_local, + model_type=model_type, + model_id=model_id, + last_checked_at=existing.last_checked_at if existing else None, + ) self._upsert_record(record) return record @@ -358,6 +437,132 @@ class ModelUpdateService: return {model_id: sorted(ids) for model_id, ids in mapping.items()} + def _merge_with_local_versions( + self, + existing: Optional[ModelUpdateRecord], + normalized_local: Sequence[int], + *, + model_type: Optional[str] = None, + model_id: Optional[int] = None, + last_checked_at: Optional[float] = None, + ) -> ModelUpdateRecord: + local_set = set(normalized_local) + versions: List[ModelVersionRecord] = [] + ignore_map: Dict[int, bool] = {} + if existing: + model_type = existing.model_type + model_id = existing.model_id + last_checked_at = existing.last_checked_at if last_checked_at is None else last_checked_at + ignore_map = {version.version_id: version.should_ignore for version in existing.versions} + for version in existing.versions: + versions.append( + replace( + version, + is_in_library=version.version_id in local_set, + ) + ) + elif model_type is None or model_id is None: + raise ValueError("model_type and model_id are required when creating a new record") + + seen_ids = {version.version_id for version in versions} + for missing_id in sorted(local_set - seen_ids): + versions.append( + ModelVersionRecord( + version_id=missing_id, + name=None, + base_model=None, + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=True, + should_ignore=ignore_map.get(missing_id, False), + sort_index=len(versions), + ) + ) + + return ModelUpdateRecord( + model_type=model_type, + model_id=model_id, + versions=self._sorted_versions(versions), + last_checked_at=last_checked_at, + should_ignore_model=existing.should_ignore_model if existing else False, + ) + + def _build_record_from_remote( + self, + model_type: str, + model_id: int, + local_versions: Sequence[int], + remote_versions: Sequence[ModelVersionRecord], + existing: Optional[ModelUpdateRecord], + timestamp: float, + ) -> ModelUpdateRecord: + local_set = set(local_versions) + ignore_map = {version.version_id: version.should_ignore for version in existing.versions} if existing else {} + preview_map = {version.version_id: version.preview_url for version in existing.versions} if existing else {} + sort_map = {version.version_id: version.sort_index for version in existing.versions} if existing else {} + existing_map = {version.version_id: version for version in existing.versions} if existing else {} + + versions: List[ModelVersionRecord] = [] + seen_ids: set[int] = set() + for index, remote_version in enumerate(remote_versions): + version_id = remote_version.version_id + seen_ids.add(version_id) + versions.append( + ModelVersionRecord( + version_id=version_id, + name=remote_version.name, + base_model=remote_version.base_model, + released_at=remote_version.released_at, + size_bytes=remote_version.size_bytes, + preview_url=remote_version.preview_url or preview_map.get(version_id), + is_in_library=version_id in local_set, + should_ignore=ignore_map.get(version_id, remote_version.should_ignore), + sort_index=sort_map.get(version_id, index), + ) + ) + + missing_local = local_set - seen_ids + if missing_local: + for version_id in sorted(missing_local): + existing_version = existing_map.get(version_id) + if existing_version: + versions.append( + replace( + existing_version, + is_in_library=True, + ) + ) + else: + versions.append( + ModelVersionRecord( + version_id=version_id, + name=None, + base_model=None, + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=True, + should_ignore=ignore_map.get(version_id, False), + sort_index=len(versions), + ) + ) + + return ModelUpdateRecord( + model_type=model_type, + model_id=model_id, + versions=self._sorted_versions(versions), + last_checked_at=timestamp, + should_ignore_model=existing.should_ignore_model if existing else False, + ) + + def _sorted_versions(self, versions: Sequence[ModelVersionRecord]) -> List[ModelVersionRecord]: + ordered = sorted(versions, key=lambda version: (version.sort_index, version.version_id)) + normalized: List[ModelVersionRecord] = [] + for index, version in enumerate(ordered): + normalized.append(replace(version, sort_index=index)) + return normalized + def _is_stale(self, record: ModelUpdateRecord, now: float) -> bool: if record.last_checked_at is None: return True @@ -380,7 +585,19 @@ class ModelUpdateService: ] return sorted(dict.fromkeys(normalized)) - def _extract_version_ids(self, response) -> Optional[List[int]]: + @staticmethod + def _normalize_string(value) -> Optional[str]: + if value is None: + return None + if isinstance(value, str): + stripped = value.strip() + return stripped or None + try: + return str(value) + except Exception: # pragma: no cover - defensive conversion + return None + + def _extract_versions(self, response) -> Optional[List[ModelVersionRecord]]: if not isinstance(response, Mapping): return None versions = response.get("modelVersions") @@ -388,84 +605,175 @@ class ModelUpdateService: return [] if not isinstance(versions, Iterable): return None - normalized = [] - for entry in versions: - if isinstance(entry, Mapping): - normalized_id = self._normalize_int(entry.get("id")) - else: - normalized_id = self._normalize_int(entry) - if normalized_id is not None: - normalized.append(normalized_id) - return sorted(dict.fromkeys(normalized)) + extracted: List[ModelVersionRecord] = [] + for index, entry in enumerate(versions): + if not isinstance(entry, Mapping): + continue + version_id = self._normalize_int(entry.get("id")) + if version_id is None: + continue + name = self._normalize_string(entry.get("name")) + base_model = self._normalize_string(entry.get("baseModel")) + released_at = self._normalize_string(entry.get("publishedAt") or entry.get("createdAt")) + size_bytes = self._extract_size_bytes(entry.get("files")) + preview_url = self._extract_preview_url(entry.get("images")) + extracted.append( + ModelVersionRecord( + version_id=version_id, + name=name, + base_model=base_model, + released_at=released_at, + size_bytes=size_bytes, + preview_url=preview_url, + is_in_library=False, + should_ignore=False, + sort_index=index, + ) + ) + return extracted + + def _extract_size_bytes(self, files) -> Optional[int]: + if not isinstance(files, Iterable): + return None + for entry in files: + if not isinstance(entry, Mapping): + continue + size_kb = entry.get("sizeKB") + if size_kb is None: + continue + try: + size_float = float(size_kb) + except (TypeError, ValueError): + continue + return int(size_float * 1024) + return None + + def _extract_preview_url(self, images) -> Optional[str]: + if not isinstance(images, Iterable): + return None + + candidates = [entry for entry in images if isinstance(entry, Mapping)] + if not candidates: + return None + + blur_mature_content = True + settings = getattr(self, "_settings", None) + if settings is not None and hasattr(settings, "get"): + try: + blur_mature_content = bool(settings.get("blur_mature_content", True)) + except Exception: # pragma: no cover - defensive guard + blur_mature_content = True + + selected, _ = select_preview_media(candidates, blur_mature_content=blur_mature_content) + if not selected: + return None + + url = selected.get("url") + if not isinstance(url, str) or not url: + return None + + media_type = selected.get("type") + if not isinstance(media_type, str): + media_type = None + + rewritten, _ = rewrite_preview_url(url, media_type) + return rewritten or url def _get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]: with self._connect() as conn: - row = conn.execute( + status_row = conn.execute( """ - SELECT model_type, model_id, largest_version_id, version_ids, - in_library_version_ids, last_checked_at, should_ignore + SELECT model_id, model_type, last_checked_at, should_ignore_model FROM model_update_status - WHERE model_type = ? AND model_id = ? + WHERE model_id = ? """, - (model_type, model_id), + (model_id,), ).fetchone() - if not row: - return None + if not status_row: + return None + stored_type = status_row["model_type"] + if stored_type and stored_type != model_type: + logger.debug( + "Model id %s requested as %s but stored as %s", model_id, model_type, stored_type + ) + version_rows = conn.execute( + """ + SELECT version_id, sort_index, name, base_model, released_at, size_bytes, + preview_url, is_in_library, should_ignore + FROM model_update_versions + WHERE model_id = ? + ORDER BY sort_index ASC, version_id ASC + """, + (model_id,), + ).fetchall() + + versions: List[ModelVersionRecord] = [] + for row in version_rows: + versions.append( + ModelVersionRecord( + version_id=int(row["version_id"]), + name=row["name"], + base_model=row["base_model"], + released_at=row["released_at"], + size_bytes=self._normalize_int(row["size_bytes"]), + preview_url=row["preview_url"], + is_in_library=bool(row["is_in_library"]), + should_ignore=bool(row["should_ignore"]), + sort_index=self._normalize_int(row["sort_index"]) or 0, + ) + ) + return ModelUpdateRecord( - model_type=row["model_type"], - model_id=int(row["model_id"]), - largest_version_id=self._normalize_int(row["largest_version_id"]), - version_ids=self._deserialize_json_array(row["version_ids"]), - in_library_version_ids=self._deserialize_json_array( - row["in_library_version_ids"] - ), - last_checked_at=row["last_checked_at"], - should_ignore=bool(row["should_ignore"]), + model_type=stored_type or model_type, + model_id=int(status_row["model_id"]), + versions=self._sorted_versions(versions), + last_checked_at=status_row["last_checked_at"], + should_ignore_model=bool(status_row["should_ignore_model"]), ) def _upsert_record(self, record: ModelUpdateRecord) -> None: payload = ( - record.model_type, record.model_id, - record.largest_version_id, - json.dumps(record.version_ids), - json.dumps(record.in_library_version_ids), + record.model_type, record.last_checked_at, - 1 if record.should_ignore else 0, + 1 if record.should_ignore_model else 0, ) with self._connect() as conn: conn.execute( """ INSERT INTO model_update_status ( - model_type, model_id, largest_version_id, version_ids, - in_library_version_ids, last_checked_at, should_ignore - ) VALUES (?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(model_type, model_id) DO UPDATE SET - largest_version_id = excluded.largest_version_id, - version_ids = excluded.version_ids, - in_library_version_ids = excluded.in_library_version_ids, + model_id, model_type, last_checked_at, should_ignore_model + ) VALUES (?, ?, ?, ?) + ON CONFLICT(model_id) DO UPDATE SET + model_type = excluded.model_type, last_checked_at = excluded.last_checked_at, - should_ignore = excluded.should_ignore + should_ignore_model = excluded.should_ignore_model """, payload, ) + conn.execute( + "DELETE FROM model_update_versions WHERE model_id = ?", + (record.model_id,), + ) + for version in record.versions: + conn.execute( + """ + INSERT INTO model_update_versions ( + version_id, model_id, sort_index, name, base_model, released_at, + size_bytes, preview_url, is_in_library, should_ignore + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + version.version_id, + record.model_id, + version.sort_index, + version.name, + version.base_model, + version.released_at, + version.size_bytes, + version.preview_url, + 1 if version.is_in_library else 0, + 1 if version.should_ignore else 0, + ), + ) conn.commit() - - @staticmethod - def _deserialize_json_array(value) -> List[int]: - if not value: - return [] - try: - data = json.loads(value) - except (TypeError, json.JSONDecodeError): - return [] - if isinstance(data, list): - normalized = [] - for entry in data: - try: - normalized.append(int(entry)) - except (TypeError, ValueError): - continue - return sorted(dict.fromkeys(normalized)) - return [] - diff --git a/tests/routes/test_base_model_routes_smoke.py b/tests/routes/test_base_model_routes_smoke.py index 90438b17..bd4b8550 100644 --- a/tests/routes/test_base_model_routes_smoke.py +++ b/tests/routes/test_base_model_routes_smoke.py @@ -21,6 +21,7 @@ from py.services import model_file_service from py.services.downloader import DownloadProgress from py.services.metadata_sync_service import MetadataSyncService from py.services.model_file_service import AutoOrganizeResult +from py.services.model_update_service import ModelVersionRecord from py.services.service_registry import ServiceRegistry from py.services.websocket_manager import ws_manager from py.utils.exif_utils import ExifUtils @@ -42,11 +43,23 @@ class DummyRoutes(BaseModelRoutes): class NullUpdateRecord: model_type: str model_id: int - largest_version_id: int | None = None - version_ids: list[int] = field(default_factory=list) - in_library_version_ids: list[int] = field(default_factory=list) + versions: list[ModelVersionRecord] = field(default_factory=list) last_checked_at: float | None = None - should_ignore: bool = False + should_ignore_model: bool = False + + @property + def largest_version_id(self) -> int | None: + if not self.versions: + return None + return max(version.version_id for version in self.versions) + + @property + def version_ids(self) -> list[int]: + return [version.version_id for version in self.versions] + + @property + def in_library_version_ids(self) -> list[int]: + return [version.version_id for version in self.versions if version.is_in_library] def has_update(self) -> bool: return False @@ -60,10 +73,30 @@ class NullModelUpdateService: return None async def update_in_library_versions(self, model_type, model_id, version_ids): - return NullUpdateRecord(model_type=model_type, model_id=model_id, in_library_version_ids=list(version_ids)) + versions = [ + ModelVersionRecord( + version_id=version_id, + name=None, + base_model=None, + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=True, + should_ignore=False, + ) + for version_id in version_ids + ] + return NullUpdateRecord(model_type=model_type, model_id=model_id, versions=versions) async def set_should_ignore(self, model_type, model_id, should_ignore): - return NullUpdateRecord(model_type=model_type, model_id=model_id, should_ignore=should_ignore) + return NullUpdateRecord( + model_type=model_type, + model_id=model_id, + should_ignore_model=should_ignore, + ) + + async def set_version_should_ignore(self, model_type, model_id, version_id, should_ignore): + return await self.set_should_ignore(model_type, model_id, should_ignore) async def get_record(self, *args, **kwargs): return None diff --git a/tests/routes/test_model_update_handler.py b/tests/routes/test_model_update_handler.py new file mode 100644 index 00000000..864c1b78 --- /dev/null +++ b/tests/routes/test_model_update_handler.py @@ -0,0 +1,57 @@ +import logging +from types import SimpleNamespace + +import pytest + +from py.config import config +from py.routes.handlers.model_handlers import ModelUpdateHandler +from py.services.model_update_service import ModelUpdateRecord, ModelVersionRecord + + +class DummyScanner: + def __init__(self, cache): + self._cache = cache + + async def get_cached_data(self): + return self._cache + + +class DummyService: + def __init__(self, cache): + self.model_type = "lora" + self.scanner = DummyScanner(cache) + + +@pytest.mark.asyncio +async def test_build_preview_overrides_uses_static_urls(): + cache = SimpleNamespace(version_index={123: {"preview_url": "/tmp/previews/example.png"}}) + service = DummyService(cache) + handler = ModelUpdateHandler( + service=service, + update_service=SimpleNamespace(), + metadata_provider_selector=lambda *_: None, + logger=logging.getLogger(__name__), + ) + + record = ModelUpdateRecord( + model_type="lora", + model_id=42, + versions=[ + ModelVersionRecord( + version_id=123, + name=None, + base_model=None, + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=True, + should_ignore=False, + ) + ], + last_checked_at=None, + should_ignore_model=False, + ) + + overrides = await handler._build_preview_overrides(record) + expected = config.get_preview_static_url("/tmp/previews/example.png") + assert overrides == {123: expected} diff --git a/tests/services/test_model_update_service.py b/tests/services/test_model_update_service.py index b9e59e65..09e869c8 100644 --- a/tests/services/test_model_update_service.py +++ b/tests/services/test_model_update_service.py @@ -1,4 +1,3 @@ -import asyncio from types import SimpleNamespace import pytest @@ -8,7 +7,7 @@ from py.services.model_update_service import ModelUpdateService class DummyScanner: def __init__(self, raw_data): - self._cache = SimpleNamespace(raw_data=raw_data) + self._cache = SimpleNamespace(raw_data=raw_data, version_index={}) async def get_cached_data(self, *args, **kwargs): return self._cache @@ -41,7 +40,28 @@ async def test_refresh_persists_versions_and_uses_cache(tmp_path): {"civitai": {"modelId": 1, "id": 15}}, ] scanner = DummyScanner(raw_data) - provider = DummyProvider({"modelVersions": [{"id": 11}, {"id": 15}]}) + provider = DummyProvider( + { + "modelVersions": [ + { + "id": 11, + "name": "v1", + "baseModel": "SD15", + "publishedAt": "2024-01-01T00:00:00Z", + "files": [{"sizeKB": 1024}], + "images": [{"url": "https://example.com/1.png"}], + }, + { + "id": 15, + "name": "v1.5", + "baseModel": "SD15", + "publishedAt": "2024-02-01T00:00:00Z", + "files": [{"sizeKB": 512}], + "images": [{"url": "https://example.com/2.png"}], + }, + ] + } + ) await service.refresh_for_model_type("lora", scanner, provider) record = await service.get_record("lora", 1) @@ -51,6 +71,8 @@ async def test_refresh_persists_versions_and_uses_cache(tmp_path): assert record is not None assert record.version_ids == [11, 15] assert record.in_library_version_ids == [11, 15] + assert [version.name for version in record.versions] == ["v1", "v1.5"] + assert record.should_ignore_model is False assert record.has_update() is False await service.refresh_for_model_type("lora", scanner, provider) @@ -64,7 +86,14 @@ async def test_refresh_respects_ignore_flag(tmp_path): service = ModelUpdateService(str(db_path), ttl_seconds=3600) raw_data = [{"civitai": {"modelId": 2, "id": 21}}] scanner = DummyScanner(raw_data) - provider = DummyProvider({"modelVersions": [{"id": 21}, {"id": 22}]}) + provider = DummyProvider( + { + "modelVersions": [ + {"id": 21, "files": [], "images": []}, + {"id": 22, "files": [], "images": []}, + ] + } + ) await service.refresh_for_model_type("lora", scanner, provider) await service.set_should_ignore("lora", 2, True) @@ -74,6 +103,9 @@ async def test_refresh_respects_ignore_flag(tmp_path): await service.refresh_for_model_type("lora", scanner, provider) assert provider.calls == 0 assert provider.bulk_calls == [] + record = await service.get_record("lora", 2) + assert record is not None + assert record.should_ignore_model is True @pytest.mark.asyncio @@ -82,7 +114,10 @@ async def test_refresh_falls_back_when_bulk_not_supported(tmp_path): service = ModelUpdateService(str(db_path), ttl_seconds=3600) raw_data = [{"civitai": {"modelId": 4, "id": 41}}] scanner = DummyScanner(raw_data) - provider = DummyProvider({"modelVersions": [{"id": 41}]}, support_bulk=False) + provider = DummyProvider( + {"modelVersions": [{"id": 41, "files": [], "images": []}]}, + support_bulk=False, + ) await service.refresh_for_model_type("lora", scanner, provider) record = await service.get_record("lora", 4) @@ -117,7 +152,14 @@ async def test_update_in_library_versions_changes_update_state(tmp_path): service = ModelUpdateService(str(db_path), ttl_seconds=1) raw_data = [{"civitai": {"modelId": 3, "id": 31}}] scanner = DummyScanner(raw_data) - provider = DummyProvider({"modelVersions": [{"id": 31}, {"id": 35}]}) + provider = DummyProvider( + { + "modelVersions": [ + {"id": 31, "files": [], "images": []}, + {"id": 35, "files": [], "images": []}, + ] + } + ) await service.refresh_for_model_type("lora", scanner, provider) await service.update_in_library_versions("lora", 3, [31]) @@ -130,3 +172,70 @@ async def test_update_in_library_versions_changes_update_state(tmp_path): record = await service.get_record("lora", 3) assert record.has_update() is False + + +@pytest.mark.asyncio +async def test_version_ignore_blocks_update_flag(tmp_path): + db_path = tmp_path / "updates.sqlite" + service = ModelUpdateService(str(db_path), ttl_seconds=1) + raw_data = [{"civitai": {"modelId": 5, "id": 51}}] + scanner = DummyScanner(raw_data) + provider = DummyProvider( + { + "modelVersions": [ + {"id": 51, "files": [], "images": []}, + {"id": 55, "files": [], "images": []}, + ] + } + ) + + await service.refresh_for_model_type("lora", scanner, provider) + record = await service.get_record("lora", 5) + assert record is not None + assert record.has_update() is True + + await service.set_version_should_ignore("lora", 5, 55, True) + record = await service.get_record("lora", 5) + assert record is not None + assert record.has_update() is False + + +@pytest.mark.asyncio +async def test_refresh_rewrites_remote_preview_urls(tmp_path): + db_path = tmp_path / "updates.sqlite" + service = ModelUpdateService(str(db_path), ttl_seconds=1) + raw_data = [{"civitai": {"modelId": 7, "id": 71}}] + scanner = DummyScanner(raw_data) + provider = DummyProvider( + { + "modelVersions": [ + { + "id": 71, + "files": [], + "images": [ + { + "url": "https://image.civitai.com/high/original=true/sample.png", + "nsfwLevel": 6, + "type": "image", + }, + { + "url": "https://image.civitai.com/safe/original=true/preview.png", + "nsfwLevel": 1, + "type": "image", + }, + ], + } + ] + } + ) + + await service.refresh_for_model_type("lora", scanner, provider) + record = await service.get_record("lora", 7) + + assert record is not None + assert record.versions + preview_url = record.versions[0].preview_url + assert ( + preview_url + == "https://image.civitai.com/safe/width=450,optimized=true/preview.png" + )