mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-04-10 04:42:14 -03:00
feat(download-history): track downloaded model versions
This commit is contained in:
@@ -751,6 +751,7 @@ class ServiceRegistryAdapter:
|
|||||||
get_lora_scanner: Callable[[], Awaitable]
|
get_lora_scanner: Callable[[], Awaitable]
|
||||||
get_checkpoint_scanner: Callable[[], Awaitable]
|
get_checkpoint_scanner: Callable[[], Awaitable]
|
||||||
get_embedding_scanner: Callable[[], Awaitable]
|
get_embedding_scanner: Callable[[], Awaitable]
|
||||||
|
get_downloaded_version_history_service: Callable[[], Awaitable]
|
||||||
|
|
||||||
|
|
||||||
class ModelLibraryHandler:
|
class ModelLibraryHandler:
|
||||||
@@ -764,6 +765,41 @@ class ModelLibraryHandler:
|
|||||||
self._service_registry = service_registry
|
self._service_registry = service_registry
|
||||||
self._metadata_provider_factory = metadata_provider_factory
|
self._metadata_provider_factory = metadata_provider_factory
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_model_type(model_type: str | None) -> str | None:
|
||||||
|
if not isinstance(model_type, str):
|
||||||
|
return None
|
||||||
|
normalized = model_type.strip().lower()
|
||||||
|
if normalized in {"lora", "locon", "dora"}:
|
||||||
|
return "lora"
|
||||||
|
if normalized == "checkpoint":
|
||||||
|
return "checkpoint"
|
||||||
|
if normalized in {"embedding", "textualinversion"}:
|
||||||
|
return "embedding"
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get_scanner_for_type(self, model_type: str | None):
|
||||||
|
normalized_type = self._normalize_model_type(model_type)
|
||||||
|
if normalized_type == "lora":
|
||||||
|
return normalized_type, await self._service_registry.get_lora_scanner()
|
||||||
|
if normalized_type == "checkpoint":
|
||||||
|
return normalized_type, await self._service_registry.get_checkpoint_scanner()
|
||||||
|
if normalized_type == "embedding":
|
||||||
|
return normalized_type, await self._service_registry.get_embedding_scanner()
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
async def _get_download_history_service(self):
|
||||||
|
return await self._service_registry.get_downloaded_version_history_service()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _with_downloaded_flag(versions: list[dict]) -> list[dict]:
|
||||||
|
enriched: list[dict] = []
|
||||||
|
for version in versions:
|
||||||
|
entry = dict(version)
|
||||||
|
entry.setdefault("hasBeenDownloaded", True)
|
||||||
|
enriched.append(entry)
|
||||||
|
return enriched
|
||||||
|
|
||||||
async def check_model_exists(self, request: web.Request) -> web.Response:
|
async def check_model_exists(self, request: web.Request) -> web.Response:
|
||||||
try:
|
try:
|
||||||
model_id_str = request.query.get("modelId")
|
model_id_str = request.query.get("modelId")
|
||||||
@@ -819,11 +855,30 @@ class ModelLibraryHandler:
|
|||||||
exists = True
|
exists = True
|
||||||
model_type = "embedding"
|
model_type = "embedding"
|
||||||
|
|
||||||
|
history_service = await self._get_download_history_service()
|
||||||
|
has_been_downloaded = False
|
||||||
|
history_type = model_type
|
||||||
|
if history_type:
|
||||||
|
has_been_downloaded = await history_service.has_been_downloaded(
|
||||||
|
history_type,
|
||||||
|
model_version_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for candidate_type in ("lora", "checkpoint", "embedding"):
|
||||||
|
if await history_service.has_been_downloaded(
|
||||||
|
candidate_type,
|
||||||
|
model_version_id,
|
||||||
|
):
|
||||||
|
has_been_downloaded = True
|
||||||
|
history_type = candidate_type
|
||||||
|
break
|
||||||
|
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{
|
{
|
||||||
"success": True,
|
"success": True,
|
||||||
"exists": exists,
|
"exists": exists,
|
||||||
"modelType": model_type if exists else None,
|
"modelType": model_type if exists else history_type,
|
||||||
|
"hasBeenDownloaded": has_been_downloaded,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -843,13 +898,13 @@ class ModelLibraryHandler:
|
|||||||
versions = []
|
versions = []
|
||||||
if lora_versions:
|
if lora_versions:
|
||||||
model_type = "lora"
|
model_type = "lora"
|
||||||
versions = lora_versions
|
versions = self._with_downloaded_flag(lora_versions)
|
||||||
elif checkpoint_versions:
|
elif checkpoint_versions:
|
||||||
model_type = "checkpoint"
|
model_type = "checkpoint"
|
||||||
versions = checkpoint_versions
|
versions = self._with_downloaded_flag(checkpoint_versions)
|
||||||
elif embedding_versions:
|
elif embedding_versions:
|
||||||
model_type = "embedding"
|
model_type = "embedding"
|
||||||
versions = embedding_versions
|
versions = self._with_downloaded_flag(embedding_versions)
|
||||||
|
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{"success": True, "modelType": model_type, "versions": versions}
|
{"success": True, "modelType": model_type, "versions": versions}
|
||||||
@@ -858,6 +913,108 @@ class ModelLibraryHandler:
|
|||||||
logger.error("Failed to check model existence: %s", exc, exc_info=True)
|
logger.error("Failed to check model existence: %s", exc, exc_info=True)
|
||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def get_model_version_download_status(
|
||||||
|
self, request: web.Request
|
||||||
|
) -> web.Response:
|
||||||
|
try:
|
||||||
|
model_type, _ = await self._get_scanner_for_type(request.query.get("modelType"))
|
||||||
|
if not model_type:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Parameter modelType is required"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_version_id_str = request.query.get("modelVersionId")
|
||||||
|
if not model_version_id_str:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Missing required parameter: modelVersionId"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
model_version_id = int(model_version_id_str)
|
||||||
|
except ValueError:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Parameter modelVersionId must be an integer"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
history_service = await self._get_download_history_service()
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"modelType": model_type,
|
||||||
|
"modelVersionId": model_version_id,
|
||||||
|
"hasBeenDownloaded": await history_service.has_been_downloaded(
|
||||||
|
model_type,
|
||||||
|
model_version_id,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error(
|
||||||
|
"Failed to get model version download status: %s",
|
||||||
|
exc,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def set_model_version_download_status(
|
||||||
|
self, request: web.Request
|
||||||
|
) -> web.Response:
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
model_type, _ = await self._get_scanner_for_type(data.get("modelType"))
|
||||||
|
if not model_type:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Parameter modelType is required"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_version_id = int(data.get("modelVersionId"))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Parameter modelVersionId must be an integer"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
downloaded = data.get("downloaded")
|
||||||
|
if not isinstance(downloaded, bool):
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Parameter downloaded must be a boolean"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
history_service = await self._get_download_history_service()
|
||||||
|
if downloaded:
|
||||||
|
model_id = data.get("modelId")
|
||||||
|
file_path = data.get("filePath")
|
||||||
|
await history_service.mark_downloaded(
|
||||||
|
model_type,
|
||||||
|
model_version_id,
|
||||||
|
model_id=model_id,
|
||||||
|
source="manual",
|
||||||
|
file_path=file_path if isinstance(file_path, str) else None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await history_service.mark_not_downloaded(model_type, model_version_id)
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"modelType": model_type,
|
||||||
|
"modelVersionId": model_version_id,
|
||||||
|
"hasBeenDownloaded": downloaded,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error(
|
||||||
|
"Failed to set model version download status: %s",
|
||||||
|
exc,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
async def get_model_versions_status(self, request: web.Request) -> web.Response:
|
async def get_model_versions_status(self, request: web.Request) -> web.Response:
|
||||||
try:
|
try:
|
||||||
model_id_str = request.query.get("modelId")
|
model_id_str = request.query.get("modelId")
|
||||||
@@ -896,18 +1053,8 @@ class ModelLibraryHandler:
|
|||||||
model_name = response.get("name", "")
|
model_name = response.get("name", "")
|
||||||
model_type = response.get("type", "").lower()
|
model_type = response.get("type", "").lower()
|
||||||
|
|
||||||
scanner = None
|
normalized_type, scanner = await self._get_scanner_for_type(model_type)
|
||||||
normalized_type = None
|
if not normalized_type:
|
||||||
if model_type in {"lora", "locon", "dora"}:
|
|
||||||
scanner = await self._service_registry.get_lora_scanner()
|
|
||||||
normalized_type = "lora"
|
|
||||||
elif model_type == "checkpoint":
|
|
||||||
scanner = await self._service_registry.get_checkpoint_scanner()
|
|
||||||
normalized_type = "checkpoint"
|
|
||||||
elif model_type == "textualinversion":
|
|
||||||
scanner = await self._service_registry.get_embedding_scanner()
|
|
||||||
normalized_type = "embedding"
|
|
||||||
else:
|
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{
|
{
|
||||||
"success": False,
|
"success": False,
|
||||||
@@ -925,8 +1072,14 @@ class ModelLibraryHandler:
|
|||||||
status=503,
|
status=503,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
history_service = await self._get_download_history_service()
|
||||||
local_versions = await scanner.get_model_versions_by_id(model_id)
|
local_versions = await scanner.get_model_versions_by_id(model_id)
|
||||||
local_version_ids = {version["versionId"] for version in local_versions}
|
local_version_ids = {version["versionId"] for version in local_versions}
|
||||||
|
downloaded_version_ids = await history_service.get_downloaded_version_ids(
|
||||||
|
normalized_type,
|
||||||
|
model_id,
|
||||||
|
)
|
||||||
|
downloaded_version_id_set = set(downloaded_version_ids)
|
||||||
|
|
||||||
enriched_versions = []
|
enriched_versions = []
|
||||||
for version in versions:
|
for version in versions:
|
||||||
@@ -939,6 +1092,7 @@ class ModelLibraryHandler:
|
|||||||
if version.get("images")
|
if version.get("images")
|
||||||
else None,
|
else None,
|
||||||
"inLibrary": version_id in local_version_ids,
|
"inLibrary": version_id in local_version_ids,
|
||||||
|
"hasBeenDownloaded": version_id in downloaded_version_id_set,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1007,6 +1161,33 @@ class ModelLibraryHandler:
|
|||||||
}
|
}
|
||||||
|
|
||||||
versions: list[dict] = []
|
versions: list[dict] = []
|
||||||
|
history_service = await self._get_download_history_service()
|
||||||
|
model_ids: list[int] = []
|
||||||
|
for model in models:
|
||||||
|
try:
|
||||||
|
model_ids.append(int(model.get("id")))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
lora_downloaded = await history_service.get_downloaded_version_ids_bulk(
|
||||||
|
"lora",
|
||||||
|
model_ids,
|
||||||
|
)
|
||||||
|
checkpoint_downloaded = await history_service.get_downloaded_version_ids_bulk(
|
||||||
|
"checkpoint",
|
||||||
|
model_ids,
|
||||||
|
)
|
||||||
|
embedding_downloaded = await history_service.get_downloaded_version_ids_bulk(
|
||||||
|
"embedding",
|
||||||
|
model_ids,
|
||||||
|
)
|
||||||
|
downloaded_version_map: Dict[str, Dict[int, set[int]]] = {
|
||||||
|
"lora": lora_downloaded,
|
||||||
|
"locon": lora_downloaded,
|
||||||
|
"dora": lora_downloaded,
|
||||||
|
"checkpoint": checkpoint_downloaded,
|
||||||
|
"textualinversion": embedding_downloaded,
|
||||||
|
}
|
||||||
for model in models:
|
for model in models:
|
||||||
if not isinstance(model, dict):
|
if not isinstance(model, dict):
|
||||||
continue
|
continue
|
||||||
@@ -1061,6 +1242,8 @@ class ModelLibraryHandler:
|
|||||||
in_library = await scanner.check_model_version_exists(
|
in_library = await scanner.check_model_version_exists(
|
||||||
version_id_int
|
version_id_int
|
||||||
)
|
)
|
||||||
|
downloaded_versions = downloaded_version_map.get(model_type, {})
|
||||||
|
downloaded_version_ids = downloaded_versions.get(model_id_int, set())
|
||||||
|
|
||||||
versions.append(
|
versions.append(
|
||||||
{
|
{
|
||||||
@@ -1073,6 +1256,7 @@ class ModelLibraryHandler:
|
|||||||
"baseModel": version.get("baseModel"),
|
"baseModel": version.get("baseModel"),
|
||||||
"thumbnailUrl": thumbnail_url,
|
"thumbnailUrl": thumbnail_url,
|
||||||
"inLibrary": in_library,
|
"inLibrary": in_library,
|
||||||
|
"hasBeenDownloaded": version_id_int in downloaded_version_ids,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1655,6 +1839,8 @@ class MiscHandlerSet:
|
|||||||
"update_node_widget": self.node_registry.update_node_widget,
|
"update_node_widget": self.node_registry.update_node_widget,
|
||||||
"get_registry": self.node_registry.get_registry,
|
"get_registry": self.node_registry.get_registry,
|
||||||
"check_model_exists": self.model_library.check_model_exists,
|
"check_model_exists": self.model_library.check_model_exists,
|
||||||
|
"get_model_version_download_status": self.model_library.get_model_version_download_status,
|
||||||
|
"set_model_version_download_status": self.model_library.set_model_version_download_status,
|
||||||
"get_civitai_user_models": self.model_library.get_civitai_user_models,
|
"get_civitai_user_models": self.model_library.get_civitai_user_models,
|
||||||
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
|
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
|
||||||
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
|
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
|
||||||
@@ -1679,4 +1865,5 @@ def build_service_registry_adapter() -> ServiceRegistryAdapter:
|
|||||||
get_lora_scanner=ServiceRegistry.get_lora_scanner,
|
get_lora_scanner=ServiceRegistry.get_lora_scanner,
|
||||||
get_checkpoint_scanner=ServiceRegistry.get_checkpoint_scanner,
|
get_checkpoint_scanner=ServiceRegistry.get_checkpoint_scanner,
|
||||||
get_embedding_scanner=ServiceRegistry.get_embedding_scanner,
|
get_embedding_scanner=ServiceRegistry.get_embedding_scanner,
|
||||||
|
get_downloaded_version_history_service=ServiceRegistry.get_downloaded_version_history_service,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -37,6 +37,16 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"),
|
RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"),
|
||||||
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
||||||
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
||||||
|
RouteDefinition(
|
||||||
|
"GET",
|
||||||
|
"/api/lm/model-version-download-status",
|
||||||
|
"get_model_version_download_status",
|
||||||
|
),
|
||||||
|
RouteDefinition(
|
||||||
|
"POST",
|
||||||
|
"/api/lm/model-version-download-status",
|
||||||
|
"set_model_version_download_status",
|
||||||
|
),
|
||||||
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
|
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
|
||||||
RouteDefinition(
|
RouteDefinition(
|
||||||
"POST", "/api/lm/download-metadata-archive", "download_metadata_archive"
|
"POST", "/api/lm/download-metadata-archive", "download_metadata_archive"
|
||||||
|
|||||||
@@ -640,6 +640,13 @@ class DownloadManager:
|
|||||||
or version_info.get("modelId")
|
or version_info.get("modelId")
|
||||||
or (version_info.get("model") or {}).get("id")
|
or (version_info.get("model") or {}).get("id")
|
||||||
)
|
)
|
||||||
|
await self._record_downloaded_version_history(
|
||||||
|
model_type,
|
||||||
|
resolved_model_id,
|
||||||
|
version_info,
|
||||||
|
model_version_id,
|
||||||
|
save_path,
|
||||||
|
)
|
||||||
await self._sync_downloaded_version(
|
await self._sync_downloaded_version(
|
||||||
model_type,
|
model_type,
|
||||||
resolved_model_id,
|
resolved_model_id,
|
||||||
@@ -669,6 +676,55 @@ class DownloadManager:
|
|||||||
}
|
}
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
async def _record_downloaded_version_history(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
model_id_value,
|
||||||
|
version_info: Dict,
|
||||||
|
fallback_version_id=None,
|
||||||
|
file_path: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
history_service = await ServiceRegistry.get_downloaded_version_history_service()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(
|
||||||
|
"Skipping download history sync; failed to acquire history service: %s",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if history_service is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
resolved_model_id = model_id_value
|
||||||
|
if resolved_model_id is None:
|
||||||
|
resolved_model_id = version_info.get("modelId")
|
||||||
|
if resolved_model_id is None:
|
||||||
|
model_info = version_info.get("model")
|
||||||
|
if isinstance(model_info, dict):
|
||||||
|
resolved_model_id = model_info.get("id")
|
||||||
|
|
||||||
|
version_id = version_info.get("id")
|
||||||
|
if version_id is None:
|
||||||
|
version_id = fallback_version_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
await history_service.mark_downloaded(
|
||||||
|
model_type,
|
||||||
|
int(version_id),
|
||||||
|
model_id=int(resolved_model_id) if resolved_model_id is not None else None,
|
||||||
|
source="download",
|
||||||
|
file_path=file_path,
|
||||||
|
)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
logger.debug(
|
||||||
|
"Skipping download history sync; invalid identifiers model=%s version=%s",
|
||||||
|
resolved_model_id,
|
||||||
|
version_id,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Failed to sync download history for %s: %s", model_type, exc)
|
||||||
|
|
||||||
async def _sync_downloaded_version(
|
async def _sync_downloaded_version(
|
||||||
self,
|
self,
|
||||||
model_type: str,
|
model_type: str,
|
||||||
|
|||||||
313
py/services/downloaded_version_history_service.py
Normal file
313
py/services/downloaded_version_history_service.py
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
from typing import Iterable, Mapping, Optional, Sequence
|
||||||
|
|
||||||
|
from ..utils.cache_paths import get_cache_base_dir
|
||||||
|
from .settings_manager import get_settings_manager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_model_type(model_type: str | None) -> Optional[str]:
|
||||||
|
if not isinstance(model_type, str):
|
||||||
|
return None
|
||||||
|
normalized = model_type.strip().lower()
|
||||||
|
if normalized in {"lora", "locon", "dora"}:
|
||||||
|
return "lora"
|
||||||
|
if normalized == "checkpoint":
|
||||||
|
return "checkpoint"
|
||||||
|
if normalized in {"embedding", "textualinversion"}:
|
||||||
|
return "embedding"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_int(value) -> Optional[int]:
|
||||||
|
try:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_database_path() -> str:
|
||||||
|
base_dir = get_cache_base_dir(create=True)
|
||||||
|
history_dir = os.path.join(base_dir, "download_history")
|
||||||
|
os.makedirs(history_dir, exist_ok=True)
|
||||||
|
return os.path.join(history_dir, "downloaded_versions.sqlite")
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadedVersionHistoryService:
|
||||||
|
_SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS downloaded_model_versions (
|
||||||
|
model_type TEXT NOT NULL,
|
||||||
|
version_id INTEGER NOT NULL,
|
||||||
|
model_id INTEGER,
|
||||||
|
first_seen_at REAL NOT NULL,
|
||||||
|
last_seen_at REAL NOT NULL,
|
||||||
|
source TEXT NOT NULL,
|
||||||
|
last_file_path TEXT,
|
||||||
|
last_library_name TEXT,
|
||||||
|
is_deleted_override INTEGER NOT NULL DEFAULT 0,
|
||||||
|
PRIMARY KEY (model_type, version_id)
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_downloaded_model_versions_model
|
||||||
|
ON downloaded_model_versions(model_type, model_id);
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str | None = None, *, settings_manager=None) -> None:
|
||||||
|
self._db_path = db_path or _resolve_database_path()
|
||||||
|
self._settings = settings_manager or get_settings_manager()
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._schema_initialized = False
|
||||||
|
self._ensure_directory()
|
||||||
|
self._initialize_schema()
|
||||||
|
|
||||||
|
def _ensure_directory(self) -> None:
|
||||||
|
directory = os.path.dirname(self._db_path)
|
||||||
|
if directory:
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
|
||||||
|
def _connect(self) -> sqlite3.Connection:
|
||||||
|
conn = sqlite3.connect(self._db_path, check_same_thread=False)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def _initialize_schema(self) -> None:
|
||||||
|
if self._schema_initialized:
|
||||||
|
return
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.executescript(self._SCHEMA)
|
||||||
|
conn.commit()
|
||||||
|
self._schema_initialized = True
|
||||||
|
|
||||||
|
def get_database_path(self) -> str:
|
||||||
|
return self._db_path
|
||||||
|
|
||||||
|
def _get_active_library_name(self) -> str | None:
|
||||||
|
try:
|
||||||
|
value = self._settings.get_active_library_name()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return value or None
|
||||||
|
|
||||||
|
async def mark_downloaded(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
version_id: int,
|
||||||
|
*,
|
||||||
|
model_id: int | None = None,
|
||||||
|
source: str = "manual",
|
||||||
|
file_path: str | None = None,
|
||||||
|
library_name: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
normalized_type = _normalize_model_type(model_type)
|
||||||
|
normalized_version_id = _normalize_int(version_id)
|
||||||
|
normalized_model_id = _normalize_int(model_id)
|
||||||
|
if normalized_type is None or normalized_version_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
active_library_name = library_name or self._get_active_library_name()
|
||||||
|
timestamp = time.time()
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO downloaded_model_versions (
|
||||||
|
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
||||||
|
source, last_file_path, last_library_name, is_deleted_override
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
|
||||||
|
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
||||||
|
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
|
||||||
|
last_seen_at = excluded.last_seen_at,
|
||||||
|
source = excluded.source,
|
||||||
|
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
|
||||||
|
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
||||||
|
is_deleted_override = 0
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
normalized_type,
|
||||||
|
normalized_version_id,
|
||||||
|
normalized_model_id,
|
||||||
|
timestamp,
|
||||||
|
timestamp,
|
||||||
|
source,
|
||||||
|
file_path,
|
||||||
|
active_library_name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
async def mark_downloaded_bulk(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
records: Sequence[Mapping[str, object]],
|
||||||
|
*,
|
||||||
|
source: str = "scan",
|
||||||
|
library_name: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
normalized_type = _normalize_model_type(model_type)
|
||||||
|
if normalized_type is None or not records:
|
||||||
|
return
|
||||||
|
|
||||||
|
timestamp = time.time()
|
||||||
|
active_library_name = library_name or self._get_active_library_name()
|
||||||
|
payload: list[tuple[object, ...]] = []
|
||||||
|
for record in records:
|
||||||
|
version_id = _normalize_int(record.get("version_id"))
|
||||||
|
if version_id is None:
|
||||||
|
continue
|
||||||
|
payload.append(
|
||||||
|
(
|
||||||
|
normalized_type,
|
||||||
|
version_id,
|
||||||
|
_normalize_int(record.get("model_id")),
|
||||||
|
timestamp,
|
||||||
|
timestamp,
|
||||||
|
source,
|
||||||
|
record.get("file_path"),
|
||||||
|
active_library_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not payload:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.executemany(
|
||||||
|
"""
|
||||||
|
INSERT INTO downloaded_model_versions (
|
||||||
|
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
||||||
|
source, last_file_path, last_library_name, is_deleted_override
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
|
||||||
|
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
||||||
|
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
|
||||||
|
last_seen_at = excluded.last_seen_at,
|
||||||
|
source = excluded.source,
|
||||||
|
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
|
||||||
|
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
||||||
|
is_deleted_override = 0
|
||||||
|
""",
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
async def mark_not_downloaded(self, model_type: str, version_id: int) -> None:
|
||||||
|
normalized_type = _normalize_model_type(model_type)
|
||||||
|
normalized_version_id = _normalize_int(version_id)
|
||||||
|
if normalized_type is None or normalized_version_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
timestamp = time.time()
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO downloaded_model_versions (
|
||||||
|
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
||||||
|
source, last_file_path, last_library_name, is_deleted_override
|
||||||
|
) VALUES (?, ?, NULL, ?, ?, 'manual', NULL, ?, 1)
|
||||||
|
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
||||||
|
last_seen_at = excluded.last_seen_at,
|
||||||
|
source = excluded.source,
|
||||||
|
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
||||||
|
is_deleted_override = 1
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
normalized_type,
|
||||||
|
normalized_version_id,
|
||||||
|
timestamp,
|
||||||
|
timestamp,
|
||||||
|
self._get_active_library_name(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
async def has_been_downloaded(self, model_type: str, version_id: int) -> bool:
|
||||||
|
normalized_type = _normalize_model_type(model_type)
|
||||||
|
normalized_version_id = _normalize_int(version_id)
|
||||||
|
if normalized_type is None or normalized_version_id is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT is_deleted_override
|
||||||
|
FROM downloaded_model_versions
|
||||||
|
WHERE model_type = ? AND version_id = ?
|
||||||
|
""",
|
||||||
|
(normalized_type, normalized_version_id),
|
||||||
|
).fetchone()
|
||||||
|
return bool(row) and not bool(row["is_deleted_override"])
|
||||||
|
|
||||||
|
async def get_downloaded_version_ids(
|
||||||
|
self, model_type: str, model_id: int
|
||||||
|
) -> list[int]:
|
||||||
|
normalized_type = _normalize_model_type(model_type)
|
||||||
|
normalized_model_id = _normalize_int(model_id)
|
||||||
|
if normalized_type is None or normalized_model_id is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT version_id
|
||||||
|
FROM downloaded_model_versions
|
||||||
|
WHERE model_type = ? AND model_id = ? AND is_deleted_override = 0
|
||||||
|
ORDER BY version_id ASC
|
||||||
|
""",
|
||||||
|
(normalized_type, normalized_model_id),
|
||||||
|
).fetchall()
|
||||||
|
return [int(row["version_id"]) for row in rows]
|
||||||
|
|
||||||
|
async def get_downloaded_version_ids_bulk(
|
||||||
|
self, model_type: str, model_ids: Iterable[int]
|
||||||
|
) -> dict[int, set[int]]:
|
||||||
|
normalized_type = _normalize_model_type(model_type)
|
||||||
|
if normalized_type is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
normalized_model_ids = sorted(
|
||||||
|
{
|
||||||
|
value
|
||||||
|
for value in (_normalize_int(model_id) for model_id in model_ids)
|
||||||
|
if value is not None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if not normalized_model_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
placeholders = ", ".join(["?"] * len(normalized_model_ids))
|
||||||
|
params: list[object] = [normalized_type, *normalized_model_ids]
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
f"""
|
||||||
|
SELECT model_id, version_id
|
||||||
|
FROM downloaded_model_versions
|
||||||
|
WHERE model_type = ?
|
||||||
|
AND model_id IN ({placeholders})
|
||||||
|
AND is_deleted_override = 0
|
||||||
|
""",
|
||||||
|
params,
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
result: dict[int, set[int]] = {}
|
||||||
|
for row in rows:
|
||||||
|
model_id = _normalize_int(row["model_id"])
|
||||||
|
version_id = _normalize_int(row["version_id"])
|
||||||
|
if model_id is None or version_id is None:
|
||||||
|
continue
|
||||||
|
result.setdefault(model_id, set()).add(version_id)
|
||||||
|
return result
|
||||||
@@ -411,6 +411,7 @@ class ModelScanner:
|
|||||||
if scan_result:
|
if scan_result:
|
||||||
await self._apply_scan_result(scan_result)
|
await self._apply_scan_result(scan_result)
|
||||||
await self._save_persistent_cache(scan_result)
|
await self._save_persistent_cache(scan_result)
|
||||||
|
await self._sync_download_history(scan_result.raw_data, source='scan')
|
||||||
|
|
||||||
# Send final progress update
|
# Send final progress update
|
||||||
await ws_manager.broadcast_init_progress({
|
await ws_manager.broadcast_init_progress({
|
||||||
@@ -516,6 +517,7 @@ class ModelScanner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
await self._apply_scan_result(scan_result)
|
await self._apply_scan_result(scan_result)
|
||||||
|
await self._sync_download_history(adjusted_raw_data, source='scan')
|
||||||
|
|
||||||
await ws_manager.broadcast_init_progress({
|
await ws_manager.broadcast_init_progress({
|
||||||
'stage': 'loading_cache',
|
'stage': 'loading_cache',
|
||||||
@@ -576,6 +578,7 @@ class ModelScanner:
|
|||||||
excluded_models=list(self._excluded_models)
|
excluded_models=list(self._excluded_models)
|
||||||
)
|
)
|
||||||
await self._save_persistent_cache(snapshot)
|
await self._save_persistent_cache(snapshot)
|
||||||
|
await self._sync_download_history(snapshot.raw_data, source='scan')
|
||||||
def _count_model_files(self) -> int:
|
def _count_model_files(self) -> int:
|
||||||
"""Count all model files with supported extensions in all roots
|
"""Count all model files with supported extensions in all roots
|
||||||
|
|
||||||
@@ -704,6 +707,7 @@ class ModelScanner:
|
|||||||
scan_result = await self._gather_model_data()
|
scan_result = await self._gather_model_data()
|
||||||
await self._apply_scan_result(scan_result)
|
await self._apply_scan_result(scan_result)
|
||||||
await self._save_persistent_cache(scan_result)
|
await self._save_persistent_cache(scan_result)
|
||||||
|
await self._sync_download_history(scan_result.raw_data, source='scan')
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
|
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
|
||||||
@@ -1101,6 +1105,49 @@ class ModelScanner:
|
|||||||
|
|
||||||
await self._cache.resort()
|
await self._cache.resort()
|
||||||
|
|
||||||
|
async def _sync_download_history(
|
||||||
|
self,
|
||||||
|
raw_data: List[Mapping[str, Any]],
|
||||||
|
*,
|
||||||
|
source: str,
|
||||||
|
) -> None:
|
||||||
|
records: List[Dict[str, Any]] = []
|
||||||
|
for item in raw_data or []:
|
||||||
|
if not isinstance(item, Mapping):
|
||||||
|
continue
|
||||||
|
civitai = item.get('civitai')
|
||||||
|
if not isinstance(civitai, Mapping):
|
||||||
|
continue
|
||||||
|
|
||||||
|
version_id = civitai.get('id')
|
||||||
|
if version_id in (None, ''):
|
||||||
|
continue
|
||||||
|
|
||||||
|
records.append(
|
||||||
|
{
|
||||||
|
'version_id': version_id,
|
||||||
|
'model_id': civitai.get('modelId'),
|
||||||
|
'file_path': item.get('file_path'),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
history_service = await ServiceRegistry.get_downloaded_version_history_service()
|
||||||
|
await history_service.mark_downloaded_bulk(
|
||||||
|
self.model_type,
|
||||||
|
records,
|
||||||
|
source=source,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(
|
||||||
|
"%s Scanner: Failed to sync download history: %s",
|
||||||
|
self.model_type.capitalize(),
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
async def _gather_model_data(
|
async def _gather_model_data(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|||||||
@@ -167,6 +167,28 @@ class ServiceRegistry:
|
|||||||
logger.debug(f"Created and registered {service_name}")
|
logger.debug(f"Created and registered {service_name}")
|
||||||
return service
|
return service
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_downloaded_version_history_service(cls):
|
||||||
|
"""Get or create the downloaded-version history service."""
|
||||||
|
|
||||||
|
service_name = "downloaded_version_history_service"
|
||||||
|
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
async with cls._get_lock(service_name):
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
from .downloaded_version_history_service import (
|
||||||
|
DownloadedVersionHistoryService,
|
||||||
|
)
|
||||||
|
|
||||||
|
service = DownloadedVersionHistoryService()
|
||||||
|
cls._services[service_name] = service
|
||||||
|
logger.debug(f"Created and registered {service_name}")
|
||||||
|
return service
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_civarchive_client(cls):
|
async def get_civarchive_client(cls):
|
||||||
"""Get or create CivArchive client instance"""
|
"""Get or create CivArchive client instance"""
|
||||||
|
|||||||
@@ -66,6 +66,27 @@ class FakePromptServer:
|
|||||||
instance = Instance()
|
instance = Instance()
|
||||||
|
|
||||||
|
|
||||||
|
class FakeDownloadHistoryService:
|
||||||
|
async def has_been_downloaded(self, _model_type, _version_id):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_downloaded_version_ids(self, _model_type, _model_id):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_downloaded_version_ids_bulk(self, _model_type, _model_ids):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mark_downloaded(self, *_args, **_kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def mark_not_downloaded(self, *_args, **_kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def fake_download_history_service_factory():
|
||||||
|
return FakeDownloadHistoryService()
|
||||||
|
|
||||||
|
|
||||||
class TestSettingsHandlerSnapshots:
|
class TestSettingsHandlerSnapshots:
|
||||||
"""Snapshot tests for SettingsHandler responses."""
|
"""Snapshot tests for SettingsHandler responses."""
|
||||||
|
|
||||||
@@ -223,6 +244,7 @@ class TestModelLibraryHandlerSnapshots:
|
|||||||
get_lora_scanner=scanner_factory,
|
get_lora_scanner=scanner_factory,
|
||||||
get_checkpoint_scanner=scanner_factory,
|
get_checkpoint_scanner=scanner_factory,
|
||||||
get_embedding_scanner=scanner_factory,
|
get_embedding_scanner=scanner_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
),
|
),
|
||||||
metadata_provider_factory=lambda: None,
|
metadata_provider_factory=lambda: None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -438,6 +438,46 @@ async def fake_metadata_archive_manager_factory():
|
|||||||
return FakeMetadataArchiveManager()
|
return FakeMetadataArchiveManager()
|
||||||
|
|
||||||
|
|
||||||
|
class FakeDownloadHistoryService:
|
||||||
|
def __init__(self, downloaded_by_type=None):
|
||||||
|
self.downloaded_by_type = downloaded_by_type or {}
|
||||||
|
self.marked_downloaded: list[tuple] = []
|
||||||
|
self.marked_not_downloaded: list[tuple] = []
|
||||||
|
|
||||||
|
async def has_been_downloaded(self, model_type, version_id):
|
||||||
|
return version_id in self.downloaded_by_type.get(model_type, set())
|
||||||
|
|
||||||
|
async def get_downloaded_version_ids(self, model_type, model_id):
|
||||||
|
entries = self.downloaded_by_type.get(model_type, {})
|
||||||
|
if isinstance(entries, dict):
|
||||||
|
return sorted(entries.get(model_id, set()))
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_downloaded_version_ids_bulk(self, model_type, model_ids):
|
||||||
|
entries = self.downloaded_by_type.get(model_type, {})
|
||||||
|
if not isinstance(entries, dict):
|
||||||
|
return {}
|
||||||
|
return {
|
||||||
|
model_id: set(entries.get(model_id, set()))
|
||||||
|
for model_id in model_ids
|
||||||
|
if model_id in entries
|
||||||
|
}
|
||||||
|
|
||||||
|
async def mark_downloaded(
|
||||||
|
self, model_type, version_id, *, model_id=None, source="manual", file_path=None
|
||||||
|
):
|
||||||
|
self.marked_downloaded.append(
|
||||||
|
(model_type, version_id, model_id, source, file_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mark_not_downloaded(self, model_type, version_id):
|
||||||
|
self.marked_not_downloaded.append((model_type, version_id))
|
||||||
|
|
||||||
|
|
||||||
|
async def fake_download_history_service_factory():
|
||||||
|
return FakeDownloadHistoryService()
|
||||||
|
|
||||||
|
|
||||||
class RecordingRegistrar:
|
class RecordingRegistrar:
|
||||||
def __init__(self, _app):
|
def __init__(self, _app):
|
||||||
self.registered_mapping = None
|
self.registered_mapping = None
|
||||||
@@ -452,6 +492,7 @@ async def test_misc_routes_bind_produces_expected_handlers():
|
|||||||
get_lora_scanner=fake_scanner_factory,
|
get_lora_scanner=fake_scanner_factory,
|
||||||
get_checkpoint_scanner=fake_scanner_factory,
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
get_embedding_scanner=fake_scanner_factory,
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
)
|
)
|
||||||
|
|
||||||
recorded_registrars = []
|
recorded_registrars = []
|
||||||
@@ -578,6 +619,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
|||||||
get_lora_scanner=lora_factory,
|
get_lora_scanner=lora_factory,
|
||||||
get_checkpoint_scanner=checkpoint_factory,
|
get_checkpoint_scanner=checkpoint_factory,
|
||||||
get_embedding_scanner=embedding_factory,
|
get_embedding_scanner=embedding_factory,
|
||||||
|
get_downloaded_version_history_service=lambda: fake_download_history_service_factory(),
|
||||||
),
|
),
|
||||||
metadata_provider_factory=provider_factory,
|
metadata_provider_factory=provider_factory,
|
||||||
)
|
)
|
||||||
@@ -600,6 +642,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
|||||||
"baseModel": "Flux.1",
|
"baseModel": "Flux.1",
|
||||||
"thumbnailUrl": "http://example.com/a1.jpg",
|
"thumbnailUrl": "http://example.com/a1.jpg",
|
||||||
"inLibrary": False,
|
"inLibrary": False,
|
||||||
|
"hasBeenDownloaded": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"modelId": 1,
|
"modelId": 1,
|
||||||
@@ -611,6 +654,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
|||||||
"baseModel": "Flux.1",
|
"baseModel": "Flux.1",
|
||||||
"thumbnailUrl": "http://example.com/a2.jpg",
|
"thumbnailUrl": "http://example.com/a2.jpg",
|
||||||
"inLibrary": True,
|
"inLibrary": True,
|
||||||
|
"hasBeenDownloaded": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"modelId": 2,
|
"modelId": 2,
|
||||||
@@ -622,6 +666,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
|||||||
"baseModel": None,
|
"baseModel": None,
|
||||||
"thumbnailUrl": "http://example.com/e1.jpg",
|
"thumbnailUrl": "http://example.com/e1.jpg",
|
||||||
"inLibrary": False,
|
"inLibrary": False,
|
||||||
|
"hasBeenDownloaded": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"modelId": 2,
|
"modelId": 2,
|
||||||
@@ -633,6 +678,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
|||||||
"baseModel": None,
|
"baseModel": None,
|
||||||
"thumbnailUrl": None,
|
"thumbnailUrl": None,
|
||||||
"inLibrary": True,
|
"inLibrary": True,
|
||||||
|
"hasBeenDownloaded": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"modelId": 3,
|
"modelId": 3,
|
||||||
@@ -644,6 +690,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
|||||||
"baseModel": "SDXL",
|
"baseModel": "SDXL",
|
||||||
"thumbnailUrl": None,
|
"thumbnailUrl": None,
|
||||||
"inLibrary": False,
|
"inLibrary": False,
|
||||||
|
"hasBeenDownloaded": False,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -692,6 +739,7 @@ async def test_get_civitai_user_models_rewrites_civitai_previews():
|
|||||||
get_lora_scanner=fake_scanner_factory,
|
get_lora_scanner=fake_scanner_factory,
|
||||||
get_checkpoint_scanner=fake_scanner_factory,
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
get_embedding_scanner=fake_scanner_factory,
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
),
|
),
|
||||||
metadata_provider_factory=provider_factory,
|
metadata_provider_factory=provider_factory,
|
||||||
)
|
)
|
||||||
@@ -727,6 +775,7 @@ async def test_get_civitai_user_models_requires_username():
|
|||||||
get_lora_scanner=fake_scanner_factory,
|
get_lora_scanner=fake_scanner_factory,
|
||||||
get_checkpoint_scanner=fake_scanner_factory,
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
get_embedding_scanner=fake_scanner_factory,
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
),
|
),
|
||||||
metadata_provider_factory=provider_factory,
|
metadata_provider_factory=provider_factory,
|
||||||
)
|
)
|
||||||
@@ -760,6 +809,7 @@ def test_ensure_handler_mapping_caches_result():
|
|||||||
get_lora_scanner=fake_scanner_factory,
|
get_lora_scanner=fake_scanner_factory,
|
||||||
get_checkpoint_scanner=fake_scanner_factory,
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
get_embedding_scanner=fake_scanner_factory,
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
),
|
),
|
||||||
metadata_provider_factory=fake_metadata_provider_factory,
|
metadata_provider_factory=fake_metadata_provider_factory,
|
||||||
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
|
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
|
||||||
@@ -802,6 +852,7 @@ async def test_check_model_exists_returns_local_versions():
|
|||||||
get_lora_scanner=lora_factory,
|
get_lora_scanner=lora_factory,
|
||||||
get_checkpoint_scanner=checkpoint_factory,
|
get_checkpoint_scanner=checkpoint_factory,
|
||||||
get_embedding_scanner=embedding_factory,
|
get_embedding_scanner=embedding_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
),
|
),
|
||||||
metadata_provider_factory=fake_metadata_provider_factory,
|
metadata_provider_factory=fake_metadata_provider_factory,
|
||||||
)
|
)
|
||||||
@@ -811,10 +862,94 @@ async def test_check_model_exists_returns_local_versions():
|
|||||||
|
|
||||||
assert payload["success"] is True
|
assert payload["success"] is True
|
||||||
assert payload["modelType"] == "lora"
|
assert payload["modelType"] == "lora"
|
||||||
assert payload["versions"] == versions
|
assert payload["versions"] == [
|
||||||
|
{"versionId": 11, "name": "v1", "fileName": "model-one", "hasBeenDownloaded": True},
|
||||||
|
{"versionId": 12, "name": "v2", "fileName": "model-two", "hasBeenDownloaded": True},
|
||||||
|
]
|
||||||
assert lora_scanner.version_calls == [5]
|
assert lora_scanner.version_calls == [5]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_model_exists_returns_download_history_when_file_missing():
|
||||||
|
history_service = FakeDownloadHistoryService({"checkpoint": {999}})
|
||||||
|
|
||||||
|
async def history_factory():
|
||||||
|
return history_service
|
||||||
|
|
||||||
|
handler = ModelLibraryHandler(
|
||||||
|
ServiceRegistryAdapter(
|
||||||
|
get_lora_scanner=fake_scanner_factory,
|
||||||
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=history_factory,
|
||||||
|
),
|
||||||
|
metadata_provider_factory=fake_metadata_provider_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.check_model_exists(
|
||||||
|
FakeRequest(query={"modelId": "5", "modelVersionId": "999"})
|
||||||
|
)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert payload == {
|
||||||
|
"success": True,
|
||||||
|
"exists": False,
|
||||||
|
"modelType": "checkpoint",
|
||||||
|
"hasBeenDownloaded": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_version_download_status_endpoints():
|
||||||
|
history_service = FakeDownloadHistoryService({"lora": {123}})
|
||||||
|
|
||||||
|
async def history_factory():
|
||||||
|
return history_service
|
||||||
|
|
||||||
|
handler = ModelLibraryHandler(
|
||||||
|
ServiceRegistryAdapter(
|
||||||
|
get_lora_scanner=fake_scanner_factory,
|
||||||
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=history_factory,
|
||||||
|
),
|
||||||
|
metadata_provider_factory=fake_metadata_provider_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
get_response = await handler.get_model_version_download_status(
|
||||||
|
FakeRequest(query={"modelType": "lora", "modelVersionId": "123"})
|
||||||
|
)
|
||||||
|
get_payload = json.loads(get_response.text)
|
||||||
|
assert get_payload == {
|
||||||
|
"success": True,
|
||||||
|
"modelType": "lora",
|
||||||
|
"modelVersionId": 123,
|
||||||
|
"hasBeenDownloaded": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
set_response = await handler.set_model_version_download_status(
|
||||||
|
FakeRequest(
|
||||||
|
json_data={
|
||||||
|
"modelType": "checkpoint",
|
||||||
|
"modelVersionId": 456,
|
||||||
|
"modelId": 78,
|
||||||
|
"downloaded": True,
|
||||||
|
"filePath": "/tmp/model.safetensors",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
set_payload = json.loads(set_response.text)
|
||||||
|
assert set_payload == {
|
||||||
|
"success": True,
|
||||||
|
"modelType": "checkpoint",
|
||||||
|
"modelVersionId": 456,
|
||||||
|
"hasBeenDownloaded": True,
|
||||||
|
}
|
||||||
|
assert history_service.marked_downloaded == [
|
||||||
|
("checkpoint", 456, 78, "manual", "/tmp/model.safetensors")
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_create_handler_set_uses_provided_dependencies():
|
def test_create_handler_set_uses_provided_dependencies():
|
||||||
recorded_handlers: list[dict] = []
|
recorded_handlers: list[dict] = []
|
||||||
|
|
||||||
@@ -845,6 +980,7 @@ def test_create_handler_set_uses_provided_dependencies():
|
|||||||
get_lora_scanner=fake_scanner_factory,
|
get_lora_scanner=fake_scanner_factory,
|
||||||
get_checkpoint_scanner=fake_scanner_factory,
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
get_embedding_scanner=fake_scanner_factory,
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
),
|
),
|
||||||
metadata_provider_factory=fake_metadata_provider_factory,
|
metadata_provider_factory=fake_metadata_provider_factory,
|
||||||
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
|
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
|
||||||
|
|||||||
70
tests/services/test_downloaded_version_history_service.py
Normal file
70
tests/services/test_downloaded_version_history_service.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.downloaded_version_history_service import (
|
||||||
|
DownloadedVersionHistoryService,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummySettings:
|
||||||
|
def get_active_library_name(self) -> str:
|
||||||
|
return "alpha"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_history_roundtrip_and_manual_override(tmp_path: Path) -> None:
|
||||||
|
db_path = tmp_path / "download-history.sqlite"
|
||||||
|
service = DownloadedVersionHistoryService(
|
||||||
|
str(db_path),
|
||||||
|
settings_manager=DummySettings(),
|
||||||
|
)
|
||||||
|
|
||||||
|
await service.mark_downloaded(
|
||||||
|
"lora",
|
||||||
|
101,
|
||||||
|
model_id=11,
|
||||||
|
source="scan",
|
||||||
|
file_path="/models/a.safetensors",
|
||||||
|
)
|
||||||
|
assert await service.has_been_downloaded("lora", 101) is True
|
||||||
|
assert await service.get_downloaded_version_ids("lora", 11) == [101]
|
||||||
|
|
||||||
|
await service.mark_not_downloaded("lora", 101)
|
||||||
|
assert await service.has_been_downloaded("lora", 101) is False
|
||||||
|
assert await service.get_downloaded_version_ids("lora", 11) == []
|
||||||
|
|
||||||
|
await service.mark_downloaded(
|
||||||
|
"lora",
|
||||||
|
101,
|
||||||
|
model_id=11,
|
||||||
|
source="download",
|
||||||
|
file_path="/models/a.safetensors",
|
||||||
|
)
|
||||||
|
assert await service.has_been_downloaded("lora", 101) is True
|
||||||
|
assert await service.get_downloaded_version_ids("lora", 11) == [101]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_history_bulk_lookup(tmp_path: Path) -> None:
|
||||||
|
db_path = tmp_path / "download-history.sqlite"
|
||||||
|
service = DownloadedVersionHistoryService(
|
||||||
|
str(db_path),
|
||||||
|
settings_manager=DummySettings(),
|
||||||
|
)
|
||||||
|
|
||||||
|
await service.mark_downloaded_bulk(
|
||||||
|
"checkpoint",
|
||||||
|
[
|
||||||
|
{"model_id": 5, "version_id": 501, "file_path": "/m/one.safetensors"},
|
||||||
|
{"model_id": 5, "version_id": 502, "file_path": "/m/two.safetensors"},
|
||||||
|
{"model_id": 6, "version_id": 601, "file_path": "/m/three.safetensors"},
|
||||||
|
],
|
||||||
|
source="scan",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await service.get_downloaded_version_ids("checkpoint", 5) == [501, 502]
|
||||||
|
assert await service.get_downloaded_version_ids_bulk("checkpoint", [5, 6, 7]) == {
|
||||||
|
5: {501, 502},
|
||||||
|
6: {601},
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user