feat(download-history): track downloaded model versions

This commit is contained in:
Will Miao
2026-04-03 16:13:14 +08:00
parent 4f599aeced
commit 33a7f07558
9 changed files with 881 additions and 18 deletions

View File

@@ -751,6 +751,7 @@ class ServiceRegistryAdapter:
get_lora_scanner: Callable[[], Awaitable] get_lora_scanner: Callable[[], Awaitable]
get_checkpoint_scanner: Callable[[], Awaitable] get_checkpoint_scanner: Callable[[], Awaitable]
get_embedding_scanner: Callable[[], Awaitable] get_embedding_scanner: Callable[[], Awaitable]
get_downloaded_version_history_service: Callable[[], Awaitable]
class ModelLibraryHandler: class ModelLibraryHandler:
@@ -764,6 +765,41 @@ class ModelLibraryHandler:
self._service_registry = service_registry self._service_registry = service_registry
self._metadata_provider_factory = metadata_provider_factory self._metadata_provider_factory = metadata_provider_factory
@staticmethod
def _normalize_model_type(model_type: str | None) -> str | None:
    """Map a raw model-type string onto a canonical scanner category.

    Returns "lora", "checkpoint", or "embedding", or None when the
    input is not a string or does not match any known alias.
    """
    if not isinstance(model_type, str):
        return None
    canonical_by_alias = {
        "lora": "lora",
        "locon": "lora",
        "dora": "lora",
        "checkpoint": "checkpoint",
        "embedding": "embedding",
        "textualinversion": "embedding",
    }
    return canonical_by_alias.get(model_type.strip().lower())
async def _get_scanner_for_type(self, model_type: str | None):
    """Resolve *model_type* to ``(normalized_type, scanner)``.

    Returns ``(None, None)`` when the type is unrecognized; otherwise
    awaits the matching scanner factory on the service registry.
    """
    normalized = self._normalize_model_type(model_type)
    registry = self._service_registry
    factories = {
        "lora": registry.get_lora_scanner,
        "checkpoint": registry.get_checkpoint_scanner,
        "embedding": registry.get_embedding_scanner,
    }
    factory = factories.get(normalized)
    if factory is None:
        return None, None
    return normalized, await factory()
async def _get_download_history_service(self):
    """Fetch the shared downloaded-version history service from the registry."""
    registry = self._service_registry
    return await registry.get_downloaded_version_history_service()
@staticmethod
def _with_downloaded_flag(versions: list[dict]) -> list[dict]:
    """Return shallow copies of *versions*, each guaranteed to carry a
    "hasBeenDownloaded" key (defaulting to True when absent).

    The input list and its dicts are never mutated.
    """
    # Later keys win in a dict literal, so an existing value overrides
    # the leading default — same effect as setdefault on a copy.
    return [{"hasBeenDownloaded": True, **version} for version in versions]
async def check_model_exists(self, request: web.Request) -> web.Response: async def check_model_exists(self, request: web.Request) -> web.Response:
try: try:
model_id_str = request.query.get("modelId") model_id_str = request.query.get("modelId")
@@ -819,11 +855,30 @@ class ModelLibraryHandler:
exists = True exists = True
model_type = "embedding" model_type = "embedding"
history_service = await self._get_download_history_service()
has_been_downloaded = False
history_type = model_type
if history_type:
has_been_downloaded = await history_service.has_been_downloaded(
history_type,
model_version_id,
)
else:
for candidate_type in ("lora", "checkpoint", "embedding"):
if await history_service.has_been_downloaded(
candidate_type,
model_version_id,
):
has_been_downloaded = True
history_type = candidate_type
break
return web.json_response( return web.json_response(
{ {
"success": True, "success": True,
"exists": exists, "exists": exists,
"modelType": model_type if exists else None, "modelType": model_type if exists else history_type,
"hasBeenDownloaded": has_been_downloaded,
} }
) )
@@ -843,13 +898,13 @@ class ModelLibraryHandler:
versions = [] versions = []
if lora_versions: if lora_versions:
model_type = "lora" model_type = "lora"
versions = lora_versions versions = self._with_downloaded_flag(lora_versions)
elif checkpoint_versions: elif checkpoint_versions:
model_type = "checkpoint" model_type = "checkpoint"
versions = checkpoint_versions versions = self._with_downloaded_flag(checkpoint_versions)
elif embedding_versions: elif embedding_versions:
model_type = "embedding" model_type = "embedding"
versions = embedding_versions versions = self._with_downloaded_flag(embedding_versions)
return web.json_response( return web.json_response(
{"success": True, "modelType": model_type, "versions": versions} {"success": True, "modelType": model_type, "versions": versions}
@@ -858,6 +913,108 @@ class ModelLibraryHandler:
logger.error("Failed to check model existence: %s", exc, exc_info=True) logger.error("Failed to check model existence: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500) return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_version_download_status(
    self, request: web.Request
) -> web.Response:
    """GET handler: report whether one model version was ever downloaded.

    Query params: ``modelType`` (required, any known alias) and
    ``modelVersionId`` (required integer). Responds 400 on bad input,
    500 on unexpected failure.
    """
    try:
        model_type, _ = await self._get_scanner_for_type(request.query.get("modelType"))
        if not model_type:
            return web.json_response(
                {"success": False, "error": "Parameter modelType is required"},
                status=400,
            )
        raw_version_id = request.query.get("modelVersionId")
        if not raw_version_id:
            return web.json_response(
                {"success": False, "error": "Missing required parameter: modelVersionId"},
                status=400,
            )
        try:
            version_id = int(raw_version_id)
        except ValueError:
            return web.json_response(
                {"success": False, "error": "Parameter modelVersionId must be an integer"},
                status=400,
            )
        history_service = await self._get_download_history_service()
        downloaded = await history_service.has_been_downloaded(
            model_type,
            version_id,
        )
        payload = {
            "success": True,
            "modelType": model_type,
            "modelVersionId": version_id,
            "hasBeenDownloaded": downloaded,
        }
        return web.json_response(payload)
    except Exception as exc:  # pragma: no cover - defensive logging
        logger.error(
            "Failed to get model version download status: %s",
            exc,
            exc_info=True,
        )
        return web.json_response({"success": False, "error": str(exc)}, status=500)
async def set_model_version_download_status(
    self, request: web.Request
) -> web.Response:
    """POST handler: manually mark a version downloaded / not downloaded.

    JSON body: ``modelType`` (required alias), ``modelVersionId``
    (required integer), ``downloaded`` (required boolean), and optional
    ``modelId`` / ``filePath`` recorded when marking as downloaded.
    """
    try:
        data = await request.json()
        model_type, _ = await self._get_scanner_for_type(data.get("modelType"))
        if not model_type:
            return web.json_response(
                {"success": False, "error": "Parameter modelType is required"},
                status=400,
            )
        try:
            version_id = int(data.get("modelVersionId"))
        except (TypeError, ValueError):
            return web.json_response(
                {"success": False, "error": "Parameter modelVersionId must be an integer"},
                status=400,
            )
        downloaded = data.get("downloaded")
        if not isinstance(downloaded, bool):
            return web.json_response(
                {"success": False, "error": "Parameter downloaded must be a boolean"},
                status=400,
            )
        history_service = await self._get_download_history_service()
        if not downloaded:
            await history_service.mark_not_downloaded(model_type, version_id)
        else:
            file_path = data.get("filePath")
            await history_service.mark_downloaded(
                model_type,
                version_id,
                model_id=data.get("modelId"),
                source="manual",
                # Non-string paths are dropped rather than stored.
                file_path=file_path if isinstance(file_path, str) else None,
            )
        return web.json_response(
            {
                "success": True,
                "modelType": model_type,
                "modelVersionId": version_id,
                "hasBeenDownloaded": downloaded,
            }
        )
    except Exception as exc:  # pragma: no cover - defensive logging
        logger.error(
            "Failed to set model version download status: %s",
            exc,
            exc_info=True,
        )
        return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_versions_status(self, request: web.Request) -> web.Response: async def get_model_versions_status(self, request: web.Request) -> web.Response:
try: try:
model_id_str = request.query.get("modelId") model_id_str = request.query.get("modelId")
@@ -896,18 +1053,8 @@ class ModelLibraryHandler:
model_name = response.get("name", "") model_name = response.get("name", "")
model_type = response.get("type", "").lower() model_type = response.get("type", "").lower()
scanner = None normalized_type, scanner = await self._get_scanner_for_type(model_type)
normalized_type = None if not normalized_type:
if model_type in {"lora", "locon", "dora"}:
scanner = await self._service_registry.get_lora_scanner()
normalized_type = "lora"
elif model_type == "checkpoint":
scanner = await self._service_registry.get_checkpoint_scanner()
normalized_type = "checkpoint"
elif model_type == "textualinversion":
scanner = await self._service_registry.get_embedding_scanner()
normalized_type = "embedding"
else:
return web.json_response( return web.json_response(
{ {
"success": False, "success": False,
@@ -925,8 +1072,14 @@ class ModelLibraryHandler:
status=503, status=503,
) )
history_service = await self._get_download_history_service()
local_versions = await scanner.get_model_versions_by_id(model_id) local_versions = await scanner.get_model_versions_by_id(model_id)
local_version_ids = {version["versionId"] for version in local_versions} local_version_ids = {version["versionId"] for version in local_versions}
downloaded_version_ids = await history_service.get_downloaded_version_ids(
normalized_type,
model_id,
)
downloaded_version_id_set = set(downloaded_version_ids)
enriched_versions = [] enriched_versions = []
for version in versions: for version in versions:
@@ -939,6 +1092,7 @@ class ModelLibraryHandler:
if version.get("images") if version.get("images")
else None, else None,
"inLibrary": version_id in local_version_ids, "inLibrary": version_id in local_version_ids,
"hasBeenDownloaded": version_id in downloaded_version_id_set,
} }
) )
@@ -1007,6 +1161,33 @@ class ModelLibraryHandler:
} }
versions: list[dict] = [] versions: list[dict] = []
history_service = await self._get_download_history_service()
model_ids: list[int] = []
for model in models:
try:
model_ids.append(int(model.get("id")))
except (TypeError, ValueError):
continue
lora_downloaded = await history_service.get_downloaded_version_ids_bulk(
"lora",
model_ids,
)
checkpoint_downloaded = await history_service.get_downloaded_version_ids_bulk(
"checkpoint",
model_ids,
)
embedding_downloaded = await history_service.get_downloaded_version_ids_bulk(
"embedding",
model_ids,
)
downloaded_version_map: Dict[str, Dict[int, set[int]]] = {
"lora": lora_downloaded,
"locon": lora_downloaded,
"dora": lora_downloaded,
"checkpoint": checkpoint_downloaded,
"textualinversion": embedding_downloaded,
}
for model in models: for model in models:
if not isinstance(model, dict): if not isinstance(model, dict):
continue continue
@@ -1061,6 +1242,8 @@ class ModelLibraryHandler:
in_library = await scanner.check_model_version_exists( in_library = await scanner.check_model_version_exists(
version_id_int version_id_int
) )
downloaded_versions = downloaded_version_map.get(model_type, {})
downloaded_version_ids = downloaded_versions.get(model_id_int, set())
versions.append( versions.append(
{ {
@@ -1073,6 +1256,7 @@ class ModelLibraryHandler:
"baseModel": version.get("baseModel"), "baseModel": version.get("baseModel"),
"thumbnailUrl": thumbnail_url, "thumbnailUrl": thumbnail_url,
"inLibrary": in_library, "inLibrary": in_library,
"hasBeenDownloaded": version_id_int in downloaded_version_ids,
} }
) )
@@ -1655,6 +1839,8 @@ class MiscHandlerSet:
"update_node_widget": self.node_registry.update_node_widget, "update_node_widget": self.node_registry.update_node_widget,
"get_registry": self.node_registry.get_registry, "get_registry": self.node_registry.get_registry,
"check_model_exists": self.model_library.check_model_exists, "check_model_exists": self.model_library.check_model_exists,
"get_model_version_download_status": self.model_library.get_model_version_download_status,
"set_model_version_download_status": self.model_library.set_model_version_download_status,
"get_civitai_user_models": self.model_library.get_civitai_user_models, "get_civitai_user_models": self.model_library.get_civitai_user_models,
"download_metadata_archive": self.metadata_archive.download_metadata_archive, "download_metadata_archive": self.metadata_archive.download_metadata_archive,
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive, "remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
@@ -1679,4 +1865,5 @@ def build_service_registry_adapter() -> ServiceRegistryAdapter:
get_lora_scanner=ServiceRegistry.get_lora_scanner, get_lora_scanner=ServiceRegistry.get_lora_scanner,
get_checkpoint_scanner=ServiceRegistry.get_checkpoint_scanner, get_checkpoint_scanner=ServiceRegistry.get_checkpoint_scanner,
get_embedding_scanner=ServiceRegistry.get_embedding_scanner, get_embedding_scanner=ServiceRegistry.get_embedding_scanner,
get_downloaded_version_history_service=ServiceRegistry.get_downloaded_version_history_service,
) )

View File

@@ -37,6 +37,16 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"), RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"),
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"), RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"), RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
RouteDefinition(
"GET",
"/api/lm/model-version-download-status",
"get_model_version_download_status",
),
RouteDefinition(
"POST",
"/api/lm/model-version-download-status",
"set_model_version_download_status",
),
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"), RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
RouteDefinition( RouteDefinition(
"POST", "/api/lm/download-metadata-archive", "download_metadata_archive" "POST", "/api/lm/download-metadata-archive", "download_metadata_archive"

View File

@@ -640,6 +640,13 @@ class DownloadManager:
or version_info.get("modelId") or version_info.get("modelId")
or (version_info.get("model") or {}).get("id") or (version_info.get("model") or {}).get("id")
) )
await self._record_downloaded_version_history(
model_type,
resolved_model_id,
version_info,
model_version_id,
save_path,
)
await self._sync_downloaded_version( await self._sync_downloaded_version(
model_type, model_type,
resolved_model_id, resolved_model_id,
@@ -669,6 +676,55 @@ class DownloadManager:
} }
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
async def _record_downloaded_version_history(
    self,
    model_type: str,
    model_id_value,
    version_info: Dict,
    fallback_version_id=None,
    file_path: str | None = None,
) -> None:
    """Best-effort recording of a completed download into the history service.

    Never raises: every failure path logs at debug level only, so a
    history-service outage can never break the download flow itself.

    Args:
        model_type: Raw model-type string; normalized by the service.
        model_id_value: Caller-resolved model id, may be None.
        version_info: Civitai version payload; mined for ids as fallback.
        fallback_version_id: Used when version_info carries no "id".
        file_path: Saved file location, stored with the history row.
    """
    try:
        history_service = await ServiceRegistry.get_downloaded_version_history_service()
    except Exception as exc:
        logger.debug(
            "Skipping download history sync; failed to acquire history service: %s",
            exc,
        )
        return
    if history_service is None:
        return
    # Resolve the model id with decreasing preference: explicit argument,
    # then version_info["modelId"], then version_info["model"]["id"].
    resolved_model_id = model_id_value
    if resolved_model_id is None:
        resolved_model_id = version_info.get("modelId")
    if resolved_model_id is None:
        model_info = version_info.get("model")
        if isinstance(model_info, dict):
            resolved_model_id = model_info.get("id")
    version_id = version_info.get("id")
    if version_id is None:
        version_id = fallback_version_id
    try:
        # int() may raise TypeError (None) or ValueError (non-numeric);
        # both are treated as "identifiers unusable, skip quietly".
        await history_service.mark_downloaded(
            model_type,
            int(version_id),
            model_id=int(resolved_model_id) if resolved_model_id is not None else None,
            source="download",
            file_path=file_path,
        )
    except (TypeError, ValueError):
        logger.debug(
            "Skipping download history sync; invalid identifiers model=%s version=%s",
            resolved_model_id,
            version_id,
        )
    except Exception as exc:
        logger.debug("Failed to sync download history for %s: %s", model_type, exc)
async def _sync_downloaded_version( async def _sync_downloaded_version(
self, self,
model_type: str, model_type: str,

View File

@@ -0,0 +1,313 @@
from __future__ import annotations
import asyncio
import logging
import os
import sqlite3
import time
from typing import Iterable, Mapping, Optional, Sequence
from ..utils.cache_paths import get_cache_base_dir
from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
def _normalize_model_type(model_type: str | None) -> Optional[str]:
if not isinstance(model_type, str):
return None
normalized = model_type.strip().lower()
if normalized in {"lora", "locon", "dora"}:
return "lora"
if normalized == "checkpoint":
return "checkpoint"
if normalized in {"embedding", "textualinversion"}:
return "embedding"
return None
def _normalize_int(value) -> Optional[int]:
try:
if value is None:
return None
return int(value)
except (TypeError, ValueError):
return None
def _resolve_database_path() -> str:
base_dir = get_cache_base_dir(create=True)
history_dir = os.path.join(base_dir, "download_history")
os.makedirs(history_dir, exist_ok=True)
return os.path.join(history_dir, "downloaded_versions.sqlite")
class DownloadedVersionHistoryService:
_SCHEMA = """
CREATE TABLE IF NOT EXISTS downloaded_model_versions (
model_type TEXT NOT NULL,
version_id INTEGER NOT NULL,
model_id INTEGER,
first_seen_at REAL NOT NULL,
last_seen_at REAL NOT NULL,
source TEXT NOT NULL,
last_file_path TEXT,
last_library_name TEXT,
is_deleted_override INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (model_type, version_id)
);
CREATE INDEX IF NOT EXISTS idx_downloaded_model_versions_model
ON downloaded_model_versions(model_type, model_id);
"""
def __init__(self, db_path: str | None = None, *, settings_manager=None) -> None:
self._db_path = db_path or _resolve_database_path()
self._settings = settings_manager or get_settings_manager()
self._lock = asyncio.Lock()
self._schema_initialized = False
self._ensure_directory()
self._initialize_schema()
def _ensure_directory(self) -> None:
directory = os.path.dirname(self._db_path)
if directory:
os.makedirs(directory, exist_ok=True)
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(self._db_path, check_same_thread=False)
conn.row_factory = sqlite3.Row
return conn
def _initialize_schema(self) -> None:
if self._schema_initialized:
return
with self._connect() as conn:
conn.executescript(self._SCHEMA)
conn.commit()
self._schema_initialized = True
def get_database_path(self) -> str:
return self._db_path
def _get_active_library_name(self) -> str | None:
try:
value = self._settings.get_active_library_name()
except Exception:
return None
return value or None
async def mark_downloaded(
self,
model_type: str,
version_id: int,
*,
model_id: int | None = None,
source: str = "manual",
file_path: str | None = None,
library_name: str | None = None,
) -> None:
normalized_type = _normalize_model_type(model_type)
normalized_version_id = _normalize_int(version_id)
normalized_model_id = _normalize_int(model_id)
if normalized_type is None or normalized_version_id is None:
return
active_library_name = library_name or self._get_active_library_name()
timestamp = time.time()
async with self._lock:
with self._connect() as conn:
conn.execute(
"""
INSERT INTO downloaded_model_versions (
model_type, version_id, model_id, first_seen_at, last_seen_at,
source, last_file_path, last_library_name, is_deleted_override
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
ON CONFLICT(model_type, version_id) DO UPDATE SET
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
last_seen_at = excluded.last_seen_at,
source = excluded.source,
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
is_deleted_override = 0
""",
(
normalized_type,
normalized_version_id,
normalized_model_id,
timestamp,
timestamp,
source,
file_path,
active_library_name,
),
)
conn.commit()
async def mark_downloaded_bulk(
self,
model_type: str,
records: Sequence[Mapping[str, object]],
*,
source: str = "scan",
library_name: str | None = None,
) -> None:
normalized_type = _normalize_model_type(model_type)
if normalized_type is None or not records:
return
timestamp = time.time()
active_library_name = library_name or self._get_active_library_name()
payload: list[tuple[object, ...]] = []
for record in records:
version_id = _normalize_int(record.get("version_id"))
if version_id is None:
continue
payload.append(
(
normalized_type,
version_id,
_normalize_int(record.get("model_id")),
timestamp,
timestamp,
source,
record.get("file_path"),
active_library_name,
)
)
if not payload:
return
async with self._lock:
with self._connect() as conn:
conn.executemany(
"""
INSERT INTO downloaded_model_versions (
model_type, version_id, model_id, first_seen_at, last_seen_at,
source, last_file_path, last_library_name, is_deleted_override
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
ON CONFLICT(model_type, version_id) DO UPDATE SET
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
last_seen_at = excluded.last_seen_at,
source = excluded.source,
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
is_deleted_override = 0
""",
payload,
)
conn.commit()
async def mark_not_downloaded(self, model_type: str, version_id: int) -> None:
normalized_type = _normalize_model_type(model_type)
normalized_version_id = _normalize_int(version_id)
if normalized_type is None or normalized_version_id is None:
return
timestamp = time.time()
async with self._lock:
with self._connect() as conn:
conn.execute(
"""
INSERT INTO downloaded_model_versions (
model_type, version_id, model_id, first_seen_at, last_seen_at,
source, last_file_path, last_library_name, is_deleted_override
) VALUES (?, ?, NULL, ?, ?, 'manual', NULL, ?, 1)
ON CONFLICT(model_type, version_id) DO UPDATE SET
last_seen_at = excluded.last_seen_at,
source = excluded.source,
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
is_deleted_override = 1
""",
(
normalized_type,
normalized_version_id,
timestamp,
timestamp,
self._get_active_library_name(),
),
)
conn.commit()
async def has_been_downloaded(self, model_type: str, version_id: int) -> bool:
normalized_type = _normalize_model_type(model_type)
normalized_version_id = _normalize_int(version_id)
if normalized_type is None or normalized_version_id is None:
return False
async with self._lock:
with self._connect() as conn:
row = conn.execute(
"""
SELECT is_deleted_override
FROM downloaded_model_versions
WHERE model_type = ? AND version_id = ?
""",
(normalized_type, normalized_version_id),
).fetchone()
return bool(row) and not bool(row["is_deleted_override"])
async def get_downloaded_version_ids(
self, model_type: str, model_id: int
) -> list[int]:
normalized_type = _normalize_model_type(model_type)
normalized_model_id = _normalize_int(model_id)
if normalized_type is None or normalized_model_id is None:
return []
async with self._lock:
with self._connect() as conn:
rows = conn.execute(
"""
SELECT version_id
FROM downloaded_model_versions
WHERE model_type = ? AND model_id = ? AND is_deleted_override = 0
ORDER BY version_id ASC
""",
(normalized_type, normalized_model_id),
).fetchall()
return [int(row["version_id"]) for row in rows]
async def get_downloaded_version_ids_bulk(
self, model_type: str, model_ids: Iterable[int]
) -> dict[int, set[int]]:
normalized_type = _normalize_model_type(model_type)
if normalized_type is None:
return {}
normalized_model_ids = sorted(
{
value
for value in (_normalize_int(model_id) for model_id in model_ids)
if value is not None
}
)
if not normalized_model_ids:
return {}
placeholders = ", ".join(["?"] * len(normalized_model_ids))
params: list[object] = [normalized_type, *normalized_model_ids]
async with self._lock:
with self._connect() as conn:
rows = conn.execute(
f"""
SELECT model_id, version_id
FROM downloaded_model_versions
WHERE model_type = ?
AND model_id IN ({placeholders})
AND is_deleted_override = 0
""",
params,
).fetchall()
result: dict[int, set[int]] = {}
for row in rows:
model_id = _normalize_int(row["model_id"])
version_id = _normalize_int(row["version_id"])
if model_id is None or version_id is None:
continue
result.setdefault(model_id, set()).add(version_id)
return result

View File

@@ -411,6 +411,7 @@ class ModelScanner:
if scan_result: if scan_result:
await self._apply_scan_result(scan_result) await self._apply_scan_result(scan_result)
await self._save_persistent_cache(scan_result) await self._save_persistent_cache(scan_result)
await self._sync_download_history(scan_result.raw_data, source='scan')
# Send final progress update # Send final progress update
await ws_manager.broadcast_init_progress({ await ws_manager.broadcast_init_progress({
@@ -516,6 +517,7 @@ class ModelScanner:
) )
await self._apply_scan_result(scan_result) await self._apply_scan_result(scan_result)
await self._sync_download_history(adjusted_raw_data, source='scan')
await ws_manager.broadcast_init_progress({ await ws_manager.broadcast_init_progress({
'stage': 'loading_cache', 'stage': 'loading_cache',
@@ -576,6 +578,7 @@ class ModelScanner:
excluded_models=list(self._excluded_models) excluded_models=list(self._excluded_models)
) )
await self._save_persistent_cache(snapshot) await self._save_persistent_cache(snapshot)
await self._sync_download_history(snapshot.raw_data, source='scan')
def _count_model_files(self) -> int: def _count_model_files(self) -> int:
"""Count all model files with supported extensions in all roots """Count all model files with supported extensions in all roots
@@ -704,6 +707,7 @@ class ModelScanner:
scan_result = await self._gather_model_data() scan_result = await self._gather_model_data()
await self._apply_scan_result(scan_result) await self._apply_scan_result(scan_result)
await self._save_persistent_cache(scan_result) await self._save_persistent_cache(scan_result)
await self._sync_download_history(scan_result.raw_data, source='scan')
logger.info( logger.info(
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, " f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
@@ -1101,6 +1105,49 @@ class ModelScanner:
await self._cache.resort() await self._cache.resort()
async def _sync_download_history(
    self,
    raw_data: List[Mapping[str, Any]],
    *,
    source: str,
) -> None:
    """Push every civitai-identified model in a scan snapshot into the
    downloaded-version history service (best-effort; never raises)."""
    # Only items with a civitai mapping carrying a non-empty version id
    # are recordable; everything else is skipped.
    records: List[Dict[str, Any]] = [
        {
            'version_id': civitai.get('id'),
            'model_id': civitai.get('modelId'),
            'file_path': item.get('file_path'),
        }
        for item in (raw_data or [])
        if isinstance(item, Mapping)
        and isinstance((civitai := item.get('civitai')), Mapping)
        and civitai.get('id') not in (None, '')
    ]
    if not records:
        return
    try:
        history_service = await ServiceRegistry.get_downloaded_version_history_service()
        await history_service.mark_downloaded_bulk(
            self.model_type,
            records,
            source=source,
        )
    except Exception as exc:
        logger.debug(
            "%s Scanner: Failed to sync download history: %s",
            self.model_type.capitalize(),
            exc,
        )
async def _gather_model_data( async def _gather_model_data(
self, self,
*, *,

View File

@@ -167,6 +167,28 @@ class ServiceRegistry:
logger.debug(f"Created and registered {service_name}") logger.debug(f"Created and registered {service_name}")
return service return service
@classmethod
async def get_downloaded_version_history_service(cls):
    """Get or create the downloaded-version history service.

    Double-checked locking: a lock-free fast path for the common case,
    then a re-check under the per-service lock so concurrent first
    callers still end up sharing a single instance.
    """
    service_name = "downloaded_version_history_service"
    if service_name in cls._services:
        return cls._services[service_name]
    async with cls._get_lock(service_name):
        if service_name in cls._services:
            return cls._services[service_name]
        # Local import defers loading the module until first use.
        from .downloaded_version_history_service import (
            DownloadedVersionHistoryService,
        )
        service = DownloadedVersionHistoryService()
        cls._services[service_name] = service
        logger.debug(f"Created and registered {service_name}")
        return service
@classmethod @classmethod
async def get_civarchive_client(cls): async def get_civarchive_client(cls):
"""Get or create CivArchive client instance""" """Get or create CivArchive client instance"""
@@ -255,4 +277,4 @@ class ServiceRegistry:
"""Clear all registered services - mainly for testing""" """Clear all registered services - mainly for testing"""
cls._services.clear() cls._services.clear()
cls._locks.clear() cls._locks.clear()
logger.info("Cleared all registered services") logger.info("Cleared all registered services")

View File

@@ -66,6 +66,27 @@ class FakePromptServer:
instance = Instance() instance = Instance()
class FakeDownloadHistoryService:
    """Inert test double for the download-history service: every query
    reports "nothing has been downloaded" and mutations are no-ops."""

    async def has_been_downloaded(self, _model_type, _version_id):
        # Always "not downloaded" so handlers exercise the False branch.
        return False

    async def get_downloaded_version_ids(self, _model_type, _model_id):
        return []

    async def get_downloaded_version_ids_bulk(self, _model_type, _model_ids):
        return {}

    async def mark_downloaded(self, *_args, **_kwargs):
        return None

    async def mark_not_downloaded(self, *_args, **_kwargs):
        return None


async def fake_download_history_service_factory():
    """Async factory matching the ServiceRegistryAdapter callable shape."""
    return FakeDownloadHistoryService()
class TestSettingsHandlerSnapshots: class TestSettingsHandlerSnapshots:
"""Snapshot tests for SettingsHandler responses.""" """Snapshot tests for SettingsHandler responses."""
@@ -223,6 +244,7 @@ class TestModelLibraryHandlerSnapshots:
get_lora_scanner=scanner_factory, get_lora_scanner=scanner_factory,
get_checkpoint_scanner=scanner_factory, get_checkpoint_scanner=scanner_factory,
get_embedding_scanner=scanner_factory, get_embedding_scanner=scanner_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
), ),
metadata_provider_factory=lambda: None, metadata_provider_factory=lambda: None,
) )

View File

@@ -438,6 +438,46 @@ async def fake_metadata_archive_manager_factory():
return FakeMetadataArchiveManager() return FakeMetadataArchiveManager()
class FakeDownloadHistoryService:
    """Configurable download-history test double.

    ``downloaded_by_type`` maps model_type -> either a set of version
    ids (for has_been_downloaded) or a dict of model_id -> version-id
    sets (for the per-model queries). Mutation calls are recorded on
    ``marked_downloaded`` / ``marked_not_downloaded`` for assertions.
    """

    def __init__(self, downloaded_by_type=None):
        self.downloaded_by_type = downloaded_by_type or {}
        self.marked_downloaded: list[tuple] = []
        self.marked_not_downloaded: list[tuple] = []

    async def has_been_downloaded(self, model_type, version_id):
        known = self.downloaded_by_type.get(model_type, set())
        return version_id in known

    async def get_downloaded_version_ids(self, model_type, model_id):
        entries = self.downloaded_by_type.get(model_type, {})
        if not isinstance(entries, dict):
            return []
        return sorted(entries.get(model_id, set()))

    async def get_downloaded_version_ids_bulk(self, model_type, model_ids):
        entries = self.downloaded_by_type.get(model_type, {})
        if not isinstance(entries, dict):
            return {}
        result = {}
        for model_id in model_ids:
            if model_id in entries:
                result[model_id] = set(entries.get(model_id, set()))
        return result

    async def mark_downloaded(
        self, model_type, version_id, *, model_id=None, source="manual", file_path=None
    ):
        self.marked_downloaded.append(
            (model_type, version_id, model_id, source, file_path)
        )

    async def mark_not_downloaded(self, model_type, version_id):
        self.marked_not_downloaded.append((model_type, version_id))


async def fake_download_history_service_factory():
    return FakeDownloadHistoryService()
class RecordingRegistrar: class RecordingRegistrar:
def __init__(self, _app): def __init__(self, _app):
self.registered_mapping = None self.registered_mapping = None
@@ -452,6 +492,7 @@ async def test_misc_routes_bind_produces_expected_handlers():
get_lora_scanner=fake_scanner_factory, get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory, get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory, get_embedding_scanner=fake_scanner_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
) )
recorded_registrars = [] recorded_registrars = []
@@ -578,6 +619,7 @@ async def test_get_civitai_user_models_marks_library_versions():
get_lora_scanner=lora_factory, get_lora_scanner=lora_factory,
get_checkpoint_scanner=checkpoint_factory, get_checkpoint_scanner=checkpoint_factory,
get_embedding_scanner=embedding_factory, get_embedding_scanner=embedding_factory,
get_downloaded_version_history_service=lambda: fake_download_history_service_factory(),
), ),
metadata_provider_factory=provider_factory, metadata_provider_factory=provider_factory,
) )
@@ -600,6 +642,7 @@ async def test_get_civitai_user_models_marks_library_versions():
"baseModel": "Flux.1", "baseModel": "Flux.1",
"thumbnailUrl": "http://example.com/a1.jpg", "thumbnailUrl": "http://example.com/a1.jpg",
"inLibrary": False, "inLibrary": False,
"hasBeenDownloaded": False,
}, },
{ {
"modelId": 1, "modelId": 1,
@@ -611,6 +654,7 @@ async def test_get_civitai_user_models_marks_library_versions():
"baseModel": "Flux.1", "baseModel": "Flux.1",
"thumbnailUrl": "http://example.com/a2.jpg", "thumbnailUrl": "http://example.com/a2.jpg",
"inLibrary": True, "inLibrary": True,
"hasBeenDownloaded": False,
}, },
{ {
"modelId": 2, "modelId": 2,
@@ -622,6 +666,7 @@ async def test_get_civitai_user_models_marks_library_versions():
"baseModel": None, "baseModel": None,
"thumbnailUrl": "http://example.com/e1.jpg", "thumbnailUrl": "http://example.com/e1.jpg",
"inLibrary": False, "inLibrary": False,
"hasBeenDownloaded": False,
}, },
{ {
"modelId": 2, "modelId": 2,
@@ -633,6 +678,7 @@ async def test_get_civitai_user_models_marks_library_versions():
"baseModel": None, "baseModel": None,
"thumbnailUrl": None, "thumbnailUrl": None,
"inLibrary": True, "inLibrary": True,
"hasBeenDownloaded": False,
}, },
{ {
"modelId": 3, "modelId": 3,
@@ -644,6 +690,7 @@ async def test_get_civitai_user_models_marks_library_versions():
"baseModel": "SDXL", "baseModel": "SDXL",
"thumbnailUrl": None, "thumbnailUrl": None,
"inLibrary": False, "inLibrary": False,
"hasBeenDownloaded": False,
}, },
] ]
@@ -692,6 +739,7 @@ async def test_get_civitai_user_models_rewrites_civitai_previews():
get_lora_scanner=fake_scanner_factory, get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory, get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory, get_embedding_scanner=fake_scanner_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
), ),
metadata_provider_factory=provider_factory, metadata_provider_factory=provider_factory,
) )
@@ -727,6 +775,7 @@ async def test_get_civitai_user_models_requires_username():
get_lora_scanner=fake_scanner_factory, get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory, get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory, get_embedding_scanner=fake_scanner_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
), ),
metadata_provider_factory=provider_factory, metadata_provider_factory=provider_factory,
) )
@@ -760,6 +809,7 @@ def test_ensure_handler_mapping_caches_result():
get_lora_scanner=fake_scanner_factory, get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory, get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory, get_embedding_scanner=fake_scanner_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
), ),
metadata_provider_factory=fake_metadata_provider_factory, metadata_provider_factory=fake_metadata_provider_factory,
metadata_archive_manager_factory=fake_metadata_archive_manager_factory, metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
@@ -802,6 +852,7 @@ async def test_check_model_exists_returns_local_versions():
get_lora_scanner=lora_factory, get_lora_scanner=lora_factory,
get_checkpoint_scanner=checkpoint_factory, get_checkpoint_scanner=checkpoint_factory,
get_embedding_scanner=embedding_factory, get_embedding_scanner=embedding_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
), ),
metadata_provider_factory=fake_metadata_provider_factory, metadata_provider_factory=fake_metadata_provider_factory,
) )
@@ -811,10 +862,94 @@ async def test_check_model_exists_returns_local_versions():
assert payload["success"] is True assert payload["success"] is True
assert payload["modelType"] == "lora" assert payload["modelType"] == "lora"
assert payload["versions"] == versions assert payload["versions"] == [
{"versionId": 11, "name": "v1", "fileName": "model-one", "hasBeenDownloaded": True},
{"versionId": 12, "name": "v2", "fileName": "model-two", "hasBeenDownloaded": True},
]
assert lora_scanner.version_calls == [5] assert lora_scanner.version_calls == [5]
@pytest.mark.asyncio
async def test_check_model_exists_returns_download_history_when_file_missing():
    """A version absent on disk but present in history reports hasBeenDownloaded."""
    fake_history = FakeDownloadHistoryService({"checkpoint": {999}})

    async def provide_history():
        return fake_history

    registry = ServiceRegistryAdapter(
        get_lora_scanner=fake_scanner_factory,
        get_checkpoint_scanner=fake_scanner_factory,
        get_embedding_scanner=fake_scanner_factory,
        get_downloaded_version_history_service=provide_history,
    )
    handler = ModelLibraryHandler(
        registry,
        metadata_provider_factory=fake_metadata_provider_factory,
    )

    request = FakeRequest(query={"modelId": "5", "modelVersionId": "999"})
    response = await handler.check_model_exists(request)

    body = json.loads(response.text)
    expected = {
        "success": True,
        "exists": False,
        "modelType": "checkpoint",
        "hasBeenDownloaded": True,
    }
    assert body == expected
@pytest.mark.asyncio
async def test_model_version_download_status_endpoints():
    """GET reads history state; POST records a manual mark-downloaded call."""
    fake_history = FakeDownloadHistoryService({"lora": {123}})

    async def provide_history():
        return fake_history

    registry = ServiceRegistryAdapter(
        get_lora_scanner=fake_scanner_factory,
        get_checkpoint_scanner=fake_scanner_factory,
        get_embedding_scanner=fake_scanner_factory,
        get_downloaded_version_history_service=provide_history,
    )
    handler = ModelLibraryHandler(
        registry,
        metadata_provider_factory=fake_metadata_provider_factory,
    )

    status_request = FakeRequest(query={"modelType": "lora", "modelVersionId": "123"})
    get_response = await handler.get_model_version_download_status(status_request)
    assert json.loads(get_response.text) == {
        "success": True,
        "modelType": "lora",
        "modelVersionId": 123,
        "hasBeenDownloaded": True,
    }

    update_request = FakeRequest(
        json_data={
            "modelType": "checkpoint",
            "modelVersionId": 456,
            "modelId": 78,
            "downloaded": True,
            "filePath": "/tmp/model.safetensors",
        }
    )
    set_response = await handler.set_model_version_download_status(update_request)
    assert json.loads(set_response.text) == {
        "success": True,
        "modelType": "checkpoint",
        "modelVersionId": 456,
        "hasBeenDownloaded": True,
    }

    assert fake_history.marked_downloaded == [
        ("checkpoint", 456, 78, "manual", "/tmp/model.safetensors")
    ]
def test_create_handler_set_uses_provided_dependencies(): def test_create_handler_set_uses_provided_dependencies():
recorded_handlers: list[dict] = [] recorded_handlers: list[dict] = []
@@ -845,6 +980,7 @@ def test_create_handler_set_uses_provided_dependencies():
get_lora_scanner=fake_scanner_factory, get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory, get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory, get_embedding_scanner=fake_scanner_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
), ),
metadata_provider_factory=fake_metadata_provider_factory, metadata_provider_factory=fake_metadata_provider_factory,
metadata_archive_manager_factory=fake_metadata_archive_manager_factory, metadata_archive_manager_factory=fake_metadata_archive_manager_factory,

View File

@@ -0,0 +1,70 @@
from pathlib import Path
import pytest
from py.services.downloaded_version_history_service import (
DownloadedVersionHistoryService,
)
class DummySettings:
    """Minimal settings stand-in exposing only the active library name."""

    def get_active_library_name(self) -> str:
        """Always report the fixed test library name."""
        return "alpha"
@pytest.mark.asyncio
async def test_download_history_roundtrip_and_manual_override(tmp_path: Path) -> None:
    """Marking, unmarking, and re-marking a version round-trips through storage."""
    store_path = tmp_path / "download-history.sqlite"
    history = DownloadedVersionHistoryService(
        str(store_path),
        settings_manager=DummySettings(),
    )

    # Initial record via a scan.
    await history.mark_downloaded(
        "lora",
        101,
        model_id=11,
        source="scan",
        file_path="/models/a.safetensors",
    )
    assert await history.has_been_downloaded("lora", 101) is True
    assert await history.get_downloaded_version_ids("lora", 11) == [101]

    # Manual removal clears both lookups.
    await history.mark_not_downloaded("lora", 101)
    assert await history.has_been_downloaded("lora", 101) is False
    assert await history.get_downloaded_version_ids("lora", 11) == []

    # A later download re-establishes the record.
    await history.mark_downloaded(
        "lora",
        101,
        model_id=11,
        source="download",
        file_path="/models/a.safetensors",
    )
    assert await history.has_been_downloaded("lora", 101) is True
    assert await history.get_downloaded_version_ids("lora", 11) == [101]
@pytest.mark.asyncio
async def test_download_history_bulk_lookup(tmp_path: Path) -> None:
    """Bulk inserts are queryable per model and in a multi-model bulk lookup."""
    store_path = tmp_path / "download-history.sqlite"
    history = DownloadedVersionHistoryService(
        str(store_path),
        settings_manager=DummySettings(),
    )

    entries = [
        {"model_id": 5, "version_id": 501, "file_path": "/m/one.safetensors"},
        {"model_id": 5, "version_id": 502, "file_path": "/m/two.safetensors"},
        {"model_id": 6, "version_id": 601, "file_path": "/m/three.safetensors"},
    ]
    await history.mark_downloaded_bulk("checkpoint", entries, source="scan")

    assert await history.get_downloaded_version_ids("checkpoint", 5) == [501, 502]
    # Model 7 has no entries and is omitted from the bulk result.
    bulk = await history.get_downloaded_version_ids_bulk("checkpoint", [5, 6, 7])
    assert bulk == {
        5: {501, 502},
        6: {601},
    }