diff --git a/py/routes/handlers/misc_handlers.py b/py/routes/handlers/misc_handlers.py index 3f9d0ffa..554f5627 100644 --- a/py/routes/handlers/misc_handlers.py +++ b/py/routes/handlers/misc_handlers.py @@ -751,6 +751,7 @@ class ServiceRegistryAdapter: get_lora_scanner: Callable[[], Awaitable] get_checkpoint_scanner: Callable[[], Awaitable] get_embedding_scanner: Callable[[], Awaitable] + get_downloaded_version_history_service: Callable[[], Awaitable] class ModelLibraryHandler: @@ -764,6 +765,41 @@ class ModelLibraryHandler: self._service_registry = service_registry self._metadata_provider_factory = metadata_provider_factory + @staticmethod + def _normalize_model_type(model_type: str | None) -> str | None: + if not isinstance(model_type, str): + return None + normalized = model_type.strip().lower() + if normalized in {"lora", "locon", "dora"}: + return "lora" + if normalized == "checkpoint": + return "checkpoint" + if normalized in {"embedding", "textualinversion"}: + return "embedding" + return None + + async def _get_scanner_for_type(self, model_type: str | None): + normalized_type = self._normalize_model_type(model_type) + if normalized_type == "lora": + return normalized_type, await self._service_registry.get_lora_scanner() + if normalized_type == "checkpoint": + return normalized_type, await self._service_registry.get_checkpoint_scanner() + if normalized_type == "embedding": + return normalized_type, await self._service_registry.get_embedding_scanner() + return None, None + + async def _get_download_history_service(self): + return await self._service_registry.get_downloaded_version_history_service() + + @staticmethod + def _with_downloaded_flag(versions: list[dict]) -> list[dict]: + enriched: list[dict] = [] + for version in versions: + entry = dict(version) + entry.setdefault("hasBeenDownloaded", True) + enriched.append(entry) + return enriched + async def check_model_exists(self, request: web.Request) -> web.Response: try: model_id_str = request.query.get("modelId") @@ -819,11 +855,30 @@ class ModelLibraryHandler: exists = True model_type = "embedding" + history_service = await self._get_download_history_service() + has_been_downloaded = False + history_type = model_type + if history_type: + has_been_downloaded = await history_service.has_been_downloaded( + history_type, + model_version_id, + ) + else: + for candidate_type in ("lora", "checkpoint", "embedding"): + if await history_service.has_been_downloaded( + candidate_type, + model_version_id, + ): + has_been_downloaded = True + history_type = candidate_type + break + return web.json_response( { "success": True, "exists": exists, - "modelType": model_type if exists else None, + "modelType": model_type if exists else history_type, + "hasBeenDownloaded": has_been_downloaded, } ) @@ -843,13 +898,13 @@ class ModelLibraryHandler: versions = [] if lora_versions: model_type = "lora" - versions = lora_versions + versions = self._with_downloaded_flag(lora_versions) elif checkpoint_versions: model_type = "checkpoint" - versions = checkpoint_versions + versions = self._with_downloaded_flag(checkpoint_versions) elif embedding_versions: model_type = "embedding" - versions = embedding_versions + versions = self._with_downloaded_flag(embedding_versions) return web.json_response( {"success": True, "modelType": model_type, "versions": versions} @@ -858,6 +913,108 @@ class ModelLibraryHandler: logger.error("Failed to check model existence: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) + async def get_model_version_download_status( + self, request: web.Request + ) -> web.Response: + try: + model_type, _ = await self._get_scanner_for_type(request.query.get("modelType")) + if not model_type: + return web.json_response( + {"success": False, "error": "Parameter modelType is required"}, + status=400, + ) + + model_version_id_str = request.query.get("modelVersionId") + if not model_version_id_str: + return web.json_response( + {"success": False, "error": "Missing required parameter: modelVersionId"}, + status=400, + ) + try: + model_version_id = int(model_version_id_str) + except ValueError: + return web.json_response( + {"success": False, "error": "Parameter modelVersionId must be an integer"}, + status=400, + ) + + history_service = await self._get_download_history_service() + return web.json_response( + { + "success": True, + "modelType": model_type, + "modelVersionId": model_version_id, + "hasBeenDownloaded": await history_service.has_been_downloaded( + model_type, + model_version_id, + ), + } + ) + except Exception as exc: # pragma: no cover - defensive logging + logger.error( + "Failed to get model version download status: %s", + exc, + exc_info=True, + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def set_model_version_download_status( + self, request: web.Request + ) -> web.Response: + try: + data = await request.json() + model_type, _ = await self._get_scanner_for_type(data.get("modelType")) + if not model_type: + return web.json_response( + {"success": False, "error": "Parameter modelType is required"}, + status=400, + ) + + try: + model_version_id = int(data.get("modelVersionId")) + except (TypeError, ValueError): + return web.json_response( + {"success": False, "error": "Parameter modelVersionId must be an integer"}, + status=400, + ) + + downloaded = data.get("downloaded") + if not isinstance(downloaded, bool): + return web.json_response( + {"success": False, "error": "Parameter downloaded must be a boolean"}, + status=400, + ) + + history_service = await self._get_download_history_service() + if downloaded: + model_id = data.get("modelId") + file_path = data.get("filePath") + await history_service.mark_downloaded( + model_type, + model_version_id, + model_id=model_id, + source="manual", + file_path=file_path if isinstance(file_path, str) else None, + ) + else: + await history_service.mark_not_downloaded(model_type, model_version_id) + + return web.json_response( + { + "success": True, + "modelType": model_type, + "modelVersionId": model_version_id, + "hasBeenDownloaded": downloaded, + } + ) + except Exception as exc: # pragma: no cover - defensive logging + logger.error( + "Failed to set model version download status: %s", + exc, + exc_info=True, + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + async def get_model_versions_status(self, request: web.Request) -> web.Response: try: model_id_str = request.query.get("modelId") @@ -896,18 +1053,8 @@ class ModelLibraryHandler: model_name = response.get("name", "") model_type = response.get("type", "").lower() - scanner = None - normalized_type = None - if model_type in {"lora", "locon", "dora"}: - scanner = await self._service_registry.get_lora_scanner() - normalized_type = "lora" - elif model_type == "checkpoint": - scanner = await self._service_registry.get_checkpoint_scanner() - normalized_type = "checkpoint" - elif model_type == "textualinversion": - scanner = await self._service_registry.get_embedding_scanner() - normalized_type = "embedding" - else: + normalized_type, scanner = await self._get_scanner_for_type(model_type) + if not normalized_type: return web.json_response( { "success": False, @@ -925,8 +1072,14 @@ class ModelLibraryHandler: status=503, ) + history_service = await self._get_download_history_service() local_versions = await scanner.get_model_versions_by_id(model_id) local_version_ids = {version["versionId"] for version in local_versions} + downloaded_version_ids = await history_service.get_downloaded_version_ids( + normalized_type, + model_id, + ) + downloaded_version_id_set = set(downloaded_version_ids) enriched_versions = [] for version in versions: @@ -939,6 +1092,7 @@ class ModelLibraryHandler: if version.get("images") else None, "inLibrary": version_id in local_version_ids, + "hasBeenDownloaded": version_id in downloaded_version_id_set, } ) @@ -1007,6 +1161,33 @@ class ModelLibraryHandler: } versions: list[dict] = [] + history_service = await self._get_download_history_service() + model_ids: list[int] = [] + for model in models: + try: + model_ids.append(int(model.get("id"))) + except (TypeError, ValueError): + continue + + lora_downloaded = await history_service.get_downloaded_version_ids_bulk( + "lora", + model_ids, + ) + checkpoint_downloaded = await history_service.get_downloaded_version_ids_bulk( + "checkpoint", + model_ids, + ) + embedding_downloaded = await history_service.get_downloaded_version_ids_bulk( + "embedding", + model_ids, + ) + downloaded_version_map: Dict[str, Dict[int, set[int]]] = { + "lora": lora_downloaded, + "locon": lora_downloaded, + "dora": lora_downloaded, + "checkpoint": checkpoint_downloaded, + "textualinversion": embedding_downloaded, + } for model in models: if not isinstance(model, dict): continue @@ -1061,6 +1242,8 @@ class ModelLibraryHandler: in_library = await scanner.check_model_version_exists( version_id_int ) + downloaded_versions = downloaded_version_map.get(model_type, {}) + downloaded_version_ids = downloaded_versions.get(model_id_int, set()) versions.append( { @@ -1073,6 +1256,7 @@ class ModelLibraryHandler: "baseModel": version.get("baseModel"), "thumbnailUrl": thumbnail_url, "inLibrary": in_library, + "hasBeenDownloaded": version_id_int in downloaded_version_ids, } ) @@ -1655,6 +1839,8 @@ class MiscHandlerSet: "update_node_widget": self.node_registry.update_node_widget, "get_registry": self.node_registry.get_registry, "check_model_exists": self.model_library.check_model_exists, + "get_model_version_download_status": self.model_library.get_model_version_download_status, + "set_model_version_download_status": self.model_library.set_model_version_download_status, "get_civitai_user_models": self.model_library.get_civitai_user_models, "download_metadata_archive": self.metadata_archive.download_metadata_archive, "remove_metadata_archive": self.metadata_archive.remove_metadata_archive, @@ -1679,4 +1865,5 @@ def build_service_registry_adapter() -> ServiceRegistryAdapter: get_lora_scanner=ServiceRegistry.get_lora_scanner, get_checkpoint_scanner=ServiceRegistry.get_checkpoint_scanner, get_embedding_scanner=ServiceRegistry.get_embedding_scanner, + get_downloaded_version_history_service=ServiceRegistry.get_downloaded_version_history_service, ) diff --git a/py/routes/misc_route_registrar.py b/py/routes/misc_route_registrar.py index e77ed579..9f7a35c9 100644 --- a/py/routes/misc_route_registrar.py +++ b/py/routes/misc_route_registrar.py @@ -37,6 +37,16 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"), RouteDefinition("GET", "/api/lm/get-registry", "get_registry"), RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"), + RouteDefinition( + "GET", + "/api/lm/model-version-download-status", + "get_model_version_download_status", + ), + RouteDefinition( + "POST", + "/api/lm/model-version-download-status", + "set_model_version_download_status", + ), RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"), RouteDefinition( "POST", "/api/lm/download-metadata-archive", "download_metadata_archive" diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 7cdfdb02..4f32faee 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -640,6 +640,13 @@ class DownloadManager: or version_info.get("modelId") or (version_info.get("model") or {}).get("id") ) + await self._record_downloaded_version_history( + model_type, + resolved_model_id, + version_info, + model_version_id, + save_path, + ) await self._sync_downloaded_version( model_type, resolved_model_id, @@ -669,6 +676,55 @@ class DownloadManager: } return {"success": False, "error": str(e)} + async def _record_downloaded_version_history( + self, + model_type: str, + model_id_value, + version_info: Dict, + fallback_version_id=None, + file_path: str | None = None, + ) -> None: + try: + history_service = await ServiceRegistry.get_downloaded_version_history_service() + except Exception as exc: + logger.debug( + "Skipping download history sync; failed to acquire history service: %s", + exc, + ) + return + + if history_service is None: + return + + resolved_model_id = model_id_value + if resolved_model_id is None: + resolved_model_id = version_info.get("modelId") + if resolved_model_id is None: + model_info = version_info.get("model") + if isinstance(model_info, dict): + resolved_model_id = model_info.get("id") + + version_id = version_info.get("id") + if version_id is None: + version_id = fallback_version_id + + try: + await history_service.mark_downloaded( + model_type, + int(version_id), + model_id=int(resolved_model_id) if resolved_model_id is not None else None, + source="download", + file_path=file_path, + ) + except (TypeError, ValueError): + logger.debug( + "Skipping download history sync; invalid identifiers model=%s version=%s", + resolved_model_id, + version_id, + ) + except Exception as exc: + logger.debug("Failed to sync download history for %s: %s", model_type, exc) + async def _sync_downloaded_version( self, model_type: str, diff --git a/py/services/downloaded_version_history_service.py b/py/services/downloaded_version_history_service.py new file mode 100644 index 00000000..dc1d3cff --- /dev/null +++ b/py/services/downloaded_version_history_service.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +import asyncio +import logging +import os +import sqlite3 +import time +from typing import Iterable, Mapping, Optional, Sequence + +from ..utils.cache_paths import get_cache_base_dir +from .settings_manager import get_settings_manager + +logger = logging.getLogger(__name__) + + +def _normalize_model_type(model_type: str | None) -> Optional[str]: + if not isinstance(model_type, str): + return None + normalized = model_type.strip().lower() + if normalized in {"lora", "locon", "dora"}: + return "lora" + if normalized == "checkpoint": + return "checkpoint" + if normalized in {"embedding", "textualinversion"}: + return "embedding" + return None + + +def _normalize_int(value) -> Optional[int]: + try: + if value is None: + return None + return int(value) + except (TypeError, ValueError): + return None + + +def _resolve_database_path() -> str: + base_dir = get_cache_base_dir(create=True) + history_dir = os.path.join(base_dir, "download_history") + os.makedirs(history_dir, exist_ok=True) + return os.path.join(history_dir, "downloaded_versions.sqlite") + + +class DownloadedVersionHistoryService: + _SCHEMA = """ + CREATE TABLE IF NOT EXISTS downloaded_model_versions ( + model_type TEXT NOT NULL, + version_id INTEGER NOT NULL, + model_id INTEGER, + first_seen_at REAL NOT NULL, + last_seen_at REAL NOT NULL, + source TEXT NOT NULL, + last_file_path TEXT, + last_library_name TEXT, + is_deleted_override INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (model_type, version_id) + ); + CREATE INDEX IF NOT EXISTS idx_downloaded_model_versions_model + ON downloaded_model_versions(model_type, model_id); + """ + + def __init__(self, db_path: str | None = None, *, settings_manager=None) -> None: + self._db_path = db_path or _resolve_database_path() + self._settings = settings_manager or get_settings_manager() + self._lock = asyncio.Lock() + self._schema_initialized = False + self._ensure_directory() + self._initialize_schema() + + def _ensure_directory(self) -> None: + directory = os.path.dirname(self._db_path) + if directory: + os.makedirs(directory, exist_ok=True) + + def _connect(self) -> sqlite3.Connection: + conn = sqlite3.connect(self._db_path, check_same_thread=False) + conn.row_factory = sqlite3.Row + return conn + + def _initialize_schema(self) -> None: + if self._schema_initialized: + return + with self._connect() as conn: + conn.executescript(self._SCHEMA) + conn.commit() + self._schema_initialized = True + + def get_database_path(self) -> str: + return self._db_path + + def _get_active_library_name(self) -> str | None: + try: + value = self._settings.get_active_library_name() + except Exception: + return None + return value or None + + async def mark_downloaded( + self, + model_type: str, + version_id: int, + *, + model_id: int | None = None, + source: str = "manual", + file_path: str | None = None, + library_name: str | None = None, + ) -> None: + normalized_type = _normalize_model_type(model_type) + normalized_version_id = _normalize_int(version_id) + normalized_model_id = _normalize_int(model_id) + if normalized_type is None or normalized_version_id is None: + return + + active_library_name = library_name or self._get_active_library_name() + timestamp = time.time() + + async with self._lock: + with self._connect() as conn: + conn.execute( + """ + INSERT INTO downloaded_model_versions ( + model_type, version_id, model_id, first_seen_at, last_seen_at, + source, last_file_path, last_library_name, is_deleted_override + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0) + ON CONFLICT(model_type, version_id) DO UPDATE SET + model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id), + last_seen_at = excluded.last_seen_at, + source = excluded.source, + last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path), + last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name), + is_deleted_override = 0 + """, + ( + normalized_type, + normalized_version_id, + normalized_model_id, + timestamp, + timestamp, + source, + file_path, + active_library_name, + ), + ) + conn.commit() + + async def mark_downloaded_bulk( + self, + model_type: str, + records: Sequence[Mapping[str, object]], + *, + source: str = "scan", + library_name: str | None = None, + ) -> None: + normalized_type = _normalize_model_type(model_type) + if normalized_type is None or not records: + return + + timestamp = time.time() + active_library_name = library_name or self._get_active_library_name() + payload: list[tuple[object, ...]] = [] + for record in records: + version_id = _normalize_int(record.get("version_id")) + if version_id is None: + continue + payload.append( + ( + normalized_type, + version_id, + _normalize_int(record.get("model_id")), + timestamp, + timestamp, + source, + record.get("file_path"), + active_library_name, + ) + ) + + if not payload: + return + + async with self._lock: + with self._connect() as conn: + conn.executemany( + """ + INSERT INTO downloaded_model_versions ( + model_type, version_id, model_id, first_seen_at, last_seen_at, + source, last_file_path, last_library_name, is_deleted_override + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0) + ON CONFLICT(model_type, version_id) DO UPDATE SET + model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id), + last_seen_at = excluded.last_seen_at, + source = excluded.source, + last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path), + last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name), + is_deleted_override = 0 + """, + payload, + ) + conn.commit() + + async def mark_not_downloaded(self, model_type: str, version_id: int) -> None: + normalized_type = _normalize_model_type(model_type) + normalized_version_id = _normalize_int(version_id) + if normalized_type is None or normalized_version_id is None: + return + + timestamp = time.time() + + async with self._lock: + with self._connect() as conn: + conn.execute( + """ + INSERT INTO downloaded_model_versions ( + model_type, version_id, model_id, first_seen_at, last_seen_at, + source, last_file_path, last_library_name, is_deleted_override + ) VALUES (?, ?, NULL, ?, ?, 'manual', NULL, ?, 1) + ON CONFLICT(model_type, version_id) DO UPDATE SET + last_seen_at = excluded.last_seen_at, + source = excluded.source, + last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name), + is_deleted_override = 1 + """, + ( + normalized_type, + normalized_version_id, + timestamp, + timestamp, + self._get_active_library_name(), + ), + ) + conn.commit() + + async def has_been_downloaded(self, model_type: str, version_id: int) -> bool: + normalized_type = _normalize_model_type(model_type) + normalized_version_id = _normalize_int(version_id) + if normalized_type is None or normalized_version_id is None: + return False + + async with self._lock: + with self._connect() as conn: + row = conn.execute( + """ + SELECT is_deleted_override + FROM downloaded_model_versions + WHERE model_type = ? AND version_id = ? + """, + (normalized_type, normalized_version_id), + ).fetchone() + return bool(row) and not bool(row["is_deleted_override"]) + + async def get_downloaded_version_ids( + self, model_type: str, model_id: int + ) -> list[int]: + normalized_type = _normalize_model_type(model_type) + normalized_model_id = _normalize_int(model_id) + if normalized_type is None or normalized_model_id is None: + return [] + + async with self._lock: + with self._connect() as conn: + rows = conn.execute( + """ + SELECT version_id + FROM downloaded_model_versions + WHERE model_type = ? AND model_id = ? AND is_deleted_override = 0 + ORDER BY version_id ASC + """, + (normalized_type, normalized_model_id), + ).fetchall() + return [int(row["version_id"]) for row in rows] + + async def get_downloaded_version_ids_bulk( + self, model_type: str, model_ids: Iterable[int] + ) -> dict[int, set[int]]: + normalized_type = _normalize_model_type(model_type) + if normalized_type is None: + return {} + + normalized_model_ids = sorted( + { + value + for value in (_normalize_int(model_id) for model_id in model_ids) + if value is not None + } + ) + if not normalized_model_ids: + return {} + + placeholders = ", ".join(["?"] * len(normalized_model_ids)) + params: list[object] = [normalized_type, *normalized_model_ids] + + async with self._lock: + with self._connect() as conn: + rows = conn.execute( + f""" + SELECT model_id, version_id + FROM downloaded_model_versions + WHERE model_type = ? + AND model_id IN ({placeholders}) + AND is_deleted_override = 0 + """, + params, + ).fetchall() + + result: dict[int, set[int]] = {} + for row in rows: + model_id = _normalize_int(row["model_id"]) + version_id = _normalize_int(row["version_id"]) + if model_id is None or version_id is None: + continue + result.setdefault(model_id, set()).add(version_id) + return result diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index f208b790..a364214c 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -411,6 +411,7 @@ class ModelScanner: if scan_result: await self._apply_scan_result(scan_result) await self._save_persistent_cache(scan_result) + await self._sync_download_history(scan_result.raw_data, source='scan') # Send final progress update await ws_manager.broadcast_init_progress({ @@ -516,6 +517,7 @@ class ModelScanner: ) await self._apply_scan_result(scan_result) + await self._sync_download_history(adjusted_raw_data, source='scan') await ws_manager.broadcast_init_progress({ 'stage': 'loading_cache', @@ -576,6 +578,7 @@ class ModelScanner: excluded_models=list(self._excluded_models) ) await self._save_persistent_cache(snapshot) + await self._sync_download_history(snapshot.raw_data, source='scan') def _count_model_files(self) -> int: """Count all model files with supported extensions in all roots @@ -704,6 +707,7 @@ class ModelScanner: scan_result = await self._gather_model_data() await self._apply_scan_result(scan_result) await self._save_persistent_cache(scan_result) + await self._sync_download_history(scan_result.raw_data, source='scan') logger.info( f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, " @@ -1101,6 +1105,49 @@ class ModelScanner: await self._cache.resort() + async def _sync_download_history( + self, + raw_data: List[Mapping[str, Any]], + *, + source: str, + ) -> None: + records: List[Dict[str, Any]] = [] + for item in raw_data or []: + if not isinstance(item, Mapping): + continue + civitai = item.get('civitai') + if not isinstance(civitai, Mapping): + continue + + version_id = civitai.get('id') + if version_id in (None, ''): + continue + + records.append( + { + 'version_id': version_id, + 'model_id': civitai.get('modelId'), + 'file_path': item.get('file_path'), + } + ) + + if not records: + return + + try: + history_service = await ServiceRegistry.get_downloaded_version_history_service() + await history_service.mark_downloaded_bulk( + self.model_type, + records, + source=source, + ) + except Exception as exc: + logger.debug( + "%s Scanner: Failed to sync download history: %s", + self.model_type.capitalize(), + exc, + ) + async def _gather_model_data( self, *, diff --git a/py/services/service_registry.py b/py/services/service_registry.py index 4e3bea57..8d9319de 100644 --- a/py/services/service_registry.py +++ b/py/services/service_registry.py @@ -167,6 +167,28 @@ class ServiceRegistry: logger.debug(f"Created and registered {service_name}") return service + @classmethod + async def get_downloaded_version_history_service(cls): + """Get or create the downloaded-version history service.""" + + service_name = "downloaded_version_history_service" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + if service_name in cls._services: + return cls._services[service_name] + + from .downloaded_version_history_service import ( + DownloadedVersionHistoryService, + ) + + service = DownloadedVersionHistoryService() + cls._services[service_name] = service + logger.debug(f"Created and registered {service_name}") + return service + @classmethod async def get_civarchive_client(cls): """Get or create CivArchive client instance""" @@ -255,4 +277,4 @@ class ServiceRegistry: """Clear all registered services - mainly for testing""" cls._services.clear() cls._locks.clear() - logger.info("Cleared all registered services") \ No newline at end of file + logger.info("Cleared all registered services") diff --git a/tests/routes/test_api_snapshots.py b/tests/routes/test_api_snapshots.py index 90239b85..4452524b 100644 --- a/tests/routes/test_api_snapshots.py +++ b/tests/routes/test_api_snapshots.py @@ -66,6 +66,27 @@ class FakePromptServer: instance = Instance() +class FakeDownloadHistoryService: + async def has_been_downloaded(self, _model_type, _version_id): + return False + + async def get_downloaded_version_ids(self, _model_type, _model_id): + return [] + + async def get_downloaded_version_ids_bulk(self, _model_type, _model_ids): + return {} + + async def mark_downloaded(self, *_args, **_kwargs): + return None + + async def mark_not_downloaded(self, *_args, **_kwargs): + return None + + +async def fake_download_history_service_factory(): + return FakeDownloadHistoryService() + + class TestSettingsHandlerSnapshots: """Snapshot tests for SettingsHandler responses.""" @@ -223,6 +244,7 @@ class TestModelLibraryHandlerSnapshots: get_lora_scanner=scanner_factory, get_checkpoint_scanner=scanner_factory, get_embedding_scanner=scanner_factory, + get_downloaded_version_history_service=fake_download_history_service_factory, ), metadata_provider_factory=lambda: None, ) diff --git a/tests/routes/test_misc_routes.py b/tests/routes/test_misc_routes.py index 1cf4dad1..67828a33 100644 --- a/tests/routes/test_misc_routes.py +++ b/tests/routes/test_misc_routes.py @@ -438,6 +438,46 @@ async def fake_metadata_archive_manager_factory(): return FakeMetadataArchiveManager() +class FakeDownloadHistoryService: + def __init__(self, downloaded_by_type=None): + self.downloaded_by_type = downloaded_by_type or {} + self.marked_downloaded: list[tuple] = [] + self.marked_not_downloaded: list[tuple] = [] + + async def has_been_downloaded(self, model_type, version_id): + return version_id in self.downloaded_by_type.get(model_type, set()) + + async def get_downloaded_version_ids(self, model_type, model_id): + entries = self.downloaded_by_type.get(model_type, {}) + if isinstance(entries, dict): + return sorted(entries.get(model_id, set())) + return [] + + async def get_downloaded_version_ids_bulk(self, model_type, model_ids): + entries = self.downloaded_by_type.get(model_type, {}) + if not isinstance(entries, dict): + return {} + return { + model_id: set(entries.get(model_id, set())) + for model_id in model_ids + if model_id in entries + } + + async def mark_downloaded( + self, model_type, version_id, *, model_id=None, source="manual", file_path=None + ): + self.marked_downloaded.append( + (model_type, version_id, model_id, source, file_path) + ) + + async def mark_not_downloaded(self, model_type, version_id): + self.marked_not_downloaded.append((model_type, version_id)) + + +async def fake_download_history_service_factory(): + return FakeDownloadHistoryService() + + class RecordingRegistrar: def __init__(self, _app): self.registered_mapping = None @@ -452,6 +492,7 @@ async def test_misc_routes_bind_produces_expected_handlers(): get_lora_scanner=fake_scanner_factory, get_checkpoint_scanner=fake_scanner_factory, get_embedding_scanner=fake_scanner_factory, + get_downloaded_version_history_service=fake_download_history_service_factory, ) recorded_registrars = [] @@ -578,6 +619,7 @@ async def test_get_civitai_user_models_marks_library_versions(): get_lora_scanner=lora_factory, get_checkpoint_scanner=checkpoint_factory, get_embedding_scanner=embedding_factory, + get_downloaded_version_history_service=lambda: fake_download_history_service_factory(), ), metadata_provider_factory=provider_factory, ) @@ -600,6 +642,7 @@ async def test_get_civitai_user_models_marks_library_versions(): "baseModel": "Flux.1", "thumbnailUrl": "http://example.com/a1.jpg", "inLibrary": False, + "hasBeenDownloaded": False, }, { "modelId": 1, @@ -611,6 +654,7 @@ async def test_get_civitai_user_models_marks_library_versions(): "baseModel": "Flux.1", "thumbnailUrl": "http://example.com/a2.jpg", "inLibrary": True, + "hasBeenDownloaded": False, }, { "modelId": 2, @@ -622,6 +666,7 @@ async def test_get_civitai_user_models_marks_library_versions(): "baseModel": None, "thumbnailUrl": "http://example.com/e1.jpg", "inLibrary": False, + "hasBeenDownloaded": False, }, { "modelId": 2, @@ -633,6 +678,7 @@ async def test_get_civitai_user_models_marks_library_versions(): "baseModel": None, "thumbnailUrl": None, "inLibrary": True, + "hasBeenDownloaded": False, }, { "modelId": 3, @@ -644,6 +690,7 @@ async def test_get_civitai_user_models_marks_library_versions(): "baseModel": "SDXL", "thumbnailUrl": None, "inLibrary": False, + "hasBeenDownloaded": False, }, ] @@ -692,6 +739,7 @@ async def test_get_civitai_user_models_rewrites_civitai_previews(): get_lora_scanner=fake_scanner_factory, get_checkpoint_scanner=fake_scanner_factory, get_embedding_scanner=fake_scanner_factory, + get_downloaded_version_history_service=fake_download_history_service_factory, ), metadata_provider_factory=provider_factory, ) @@ -727,6 +775,7 @@ async def test_get_civitai_user_models_requires_username(): get_lora_scanner=fake_scanner_factory, get_checkpoint_scanner=fake_scanner_factory, get_embedding_scanner=fake_scanner_factory, + get_downloaded_version_history_service=fake_download_history_service_factory, ), metadata_provider_factory=provider_factory, ) @@ -760,6 +809,7 @@ def test_ensure_handler_mapping_caches_result(): get_lora_scanner=fake_scanner_factory, get_checkpoint_scanner=fake_scanner_factory, get_embedding_scanner=fake_scanner_factory, + get_downloaded_version_history_service=fake_download_history_service_factory, ), metadata_provider_factory=fake_metadata_provider_factory, metadata_archive_manager_factory=fake_metadata_archive_manager_factory, @@ -802,6 +852,7 @@ async def test_check_model_exists_returns_local_versions(): get_lora_scanner=lora_factory, get_checkpoint_scanner=checkpoint_factory, get_embedding_scanner=embedding_factory, + get_downloaded_version_history_service=fake_download_history_service_factory, ), metadata_provider_factory=fake_metadata_provider_factory, ) @@ -811,10 +862,94 @@ async def test_check_model_exists_returns_local_versions(): assert payload["success"] is True assert payload["modelType"] == "lora" - assert payload["versions"] == versions + assert payload["versions"] == [ + {"versionId": 11, "name": "v1", "fileName": "model-one", "hasBeenDownloaded": True}, + {"versionId": 12, "name": "v2", "fileName": "model-two", "hasBeenDownloaded": True}, + ] assert lora_scanner.version_calls == [5] +@pytest.mark.asyncio +async def test_check_model_exists_returns_download_history_when_file_missing(): + history_service = FakeDownloadHistoryService({"checkpoint": {999}}) + + async def history_factory(): + return history_service + + handler = ModelLibraryHandler( + ServiceRegistryAdapter( + get_lora_scanner=fake_scanner_factory, + get_checkpoint_scanner=fake_scanner_factory, + get_embedding_scanner=fake_scanner_factory, + get_downloaded_version_history_service=history_factory, + ), + metadata_provider_factory=fake_metadata_provider_factory, + ) + + response = await handler.check_model_exists( + FakeRequest(query={"modelId": "5", "modelVersionId": "999"}) + ) + payload = json.loads(response.text) + + assert payload == { + "success": True, + "exists": False, + "modelType": "checkpoint", + "hasBeenDownloaded": True, + } + + +@pytest.mark.asyncio +async def test_model_version_download_status_endpoints(): + history_service = FakeDownloadHistoryService({"lora": {123}}) + + async def history_factory(): + return history_service + + handler = ModelLibraryHandler( + ServiceRegistryAdapter( + get_lora_scanner=fake_scanner_factory, + get_checkpoint_scanner=fake_scanner_factory, + get_embedding_scanner=fake_scanner_factory, + get_downloaded_version_history_service=history_factory, + ), + metadata_provider_factory=fake_metadata_provider_factory, + ) + + get_response = await handler.get_model_version_download_status( + FakeRequest(query={"modelType": "lora", "modelVersionId": "123"}) + ) + get_payload = json.loads(get_response.text) + assert get_payload == { + "success": True, + "modelType": "lora", + "modelVersionId": 123, + "hasBeenDownloaded": True, + } + + set_response = await handler.set_model_version_download_status( + FakeRequest( + json_data={ + "modelType": "checkpoint", + "modelVersionId": 456, + "modelId": 78, + "downloaded": True, + "filePath": "/tmp/model.safetensors", + } + ) + ) + set_payload = json.loads(set_response.text) + assert set_payload == { + "success": True, + "modelType": "checkpoint", + "modelVersionId": 456, + "hasBeenDownloaded": True, + } + assert history_service.marked_downloaded == [ + ("checkpoint", 456, 78, "manual", "/tmp/model.safetensors") + ] + + def test_create_handler_set_uses_provided_dependencies(): recorded_handlers: list[dict] = [] @@ -845,6 +980,7 @@ def test_create_handler_set_uses_provided_dependencies(): get_lora_scanner=fake_scanner_factory, get_checkpoint_scanner=fake_scanner_factory, get_embedding_scanner=fake_scanner_factory, + get_downloaded_version_history_service=fake_download_history_service_factory, ), metadata_provider_factory=fake_metadata_provider_factory, metadata_archive_manager_factory=fake_metadata_archive_manager_factory, diff --git a/tests/services/test_downloaded_version_history_service.py b/tests/services/test_downloaded_version_history_service.py new file mode 100644 index 00000000..3e58f73c --- /dev/null +++ b/tests/services/test_downloaded_version_history_service.py @@ -0,0 +1,70 @@ +from pathlib import Path + +import pytest + +from py.services.downloaded_version_history_service import ( + DownloadedVersionHistoryService, +) + + +class DummySettings: + def get_active_library_name(self) -> str: + return "alpha" + + +@pytest.mark.asyncio +async def test_download_history_roundtrip_and_manual_override(tmp_path: Path) -> None: + db_path = tmp_path / "download-history.sqlite" + service = DownloadedVersionHistoryService( + str(db_path), + settings_manager=DummySettings(), + ) + + await service.mark_downloaded( + "lora", + 101, + model_id=11, + source="scan", + file_path="/models/a.safetensors", + ) + assert await service.has_been_downloaded("lora", 101) is True + assert await service.get_downloaded_version_ids("lora", 11) == [101] + + await service.mark_not_downloaded("lora", 101) + assert await service.has_been_downloaded("lora", 101) is False + assert await service.get_downloaded_version_ids("lora", 11) == [] + + await service.mark_downloaded( + "lora", + 101, + model_id=11, + source="download", + file_path="/models/a.safetensors", + ) + assert await service.has_been_downloaded("lora", 101) is True + assert await service.get_downloaded_version_ids("lora", 11) == [101] + + +@pytest.mark.asyncio +async def test_download_history_bulk_lookup(tmp_path: Path) -> None: + db_path = tmp_path / "download-history.sqlite" + service = DownloadedVersionHistoryService( + str(db_path), + settings_manager=DummySettings(), + ) + + await service.mark_downloaded_bulk( + "checkpoint", + [ + {"model_id": 5, "version_id": 501, "file_path": "/m/one.safetensors"}, + {"model_id": 5, "version_id": 502, "file_path": "/m/two.safetensors"}, + {"model_id": 6, "version_id": 601, "file_path": "/m/three.safetensors"}, + ], + source="scan", + ) + + assert await service.get_downloaded_version_ids("checkpoint", 5) == [501, 502] + assert await service.get_downloaded_version_ids_bulk("checkpoint", [5, 6, 7]) == { + 5: {501, 502}, + 6: {601}, + }