feat(download-history): track downloaded model versions

This commit is contained in:
Will Miao
2026-04-03 16:13:14 +08:00
parent 4f599aeced
commit 33a7f07558
9 changed files with 881 additions and 18 deletions

View File

@@ -751,6 +751,7 @@ class ServiceRegistryAdapter:
get_lora_scanner: Callable[[], Awaitable]
get_checkpoint_scanner: Callable[[], Awaitable]
get_embedding_scanner: Callable[[], Awaitable]
get_downloaded_version_history_service: Callable[[], Awaitable]
class ModelLibraryHandler:
@@ -764,6 +765,41 @@ class ModelLibraryHandler:
self._service_registry = service_registry
self._metadata_provider_factory = metadata_provider_factory
@staticmethod
def _normalize_model_type(model_type: str | None) -> str | None:
if not isinstance(model_type, str):
return None
normalized = model_type.strip().lower()
if normalized in {"lora", "locon", "dora"}:
return "lora"
if normalized == "checkpoint":
return "checkpoint"
if normalized in {"embedding", "textualinversion"}:
return "embedding"
return None
async def _get_scanner_for_type(self, model_type: str | None):
normalized_type = self._normalize_model_type(model_type)
if normalized_type == "lora":
return normalized_type, await self._service_registry.get_lora_scanner()
if normalized_type == "checkpoint":
return normalized_type, await self._service_registry.get_checkpoint_scanner()
if normalized_type == "embedding":
return normalized_type, await self._service_registry.get_embedding_scanner()
return None, None
async def _get_download_history_service(self):
return await self._service_registry.get_downloaded_version_history_service()
@staticmethod
def _with_downloaded_flag(versions: list[dict]) -> list[dict]:
enriched: list[dict] = []
for version in versions:
entry = dict(version)
entry.setdefault("hasBeenDownloaded", True)
enriched.append(entry)
return enriched
async def check_model_exists(self, request: web.Request) -> web.Response:
try:
model_id_str = request.query.get("modelId")
@@ -819,11 +855,30 @@ class ModelLibraryHandler:
exists = True
model_type = "embedding"
history_service = await self._get_download_history_service()
has_been_downloaded = False
history_type = model_type
if history_type:
has_been_downloaded = await history_service.has_been_downloaded(
history_type,
model_version_id,
)
else:
for candidate_type in ("lora", "checkpoint", "embedding"):
if await history_service.has_been_downloaded(
candidate_type,
model_version_id,
):
has_been_downloaded = True
history_type = candidate_type
break
return web.json_response(
{
"success": True,
"exists": exists,
"modelType": model_type if exists else None,
"modelType": model_type if exists else history_type,
"hasBeenDownloaded": has_been_downloaded,
}
)
@@ -843,13 +898,13 @@ class ModelLibraryHandler:
versions = []
if lora_versions:
model_type = "lora"
versions = lora_versions
versions = self._with_downloaded_flag(lora_versions)
elif checkpoint_versions:
model_type = "checkpoint"
versions = checkpoint_versions
versions = self._with_downloaded_flag(checkpoint_versions)
elif embedding_versions:
model_type = "embedding"
versions = embedding_versions
versions = self._with_downloaded_flag(embedding_versions)
return web.json_response(
{"success": True, "modelType": model_type, "versions": versions}
@@ -858,6 +913,108 @@ class ModelLibraryHandler:
logger.error("Failed to check model existence: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_version_download_status(
self, request: web.Request
) -> web.Response:
try:
model_type, _ = await self._get_scanner_for_type(request.query.get("modelType"))
if not model_type:
return web.json_response(
{"success": False, "error": "Parameter modelType is required"},
status=400,
)
model_version_id_str = request.query.get("modelVersionId")
if not model_version_id_str:
return web.json_response(
{"success": False, "error": "Missing required parameter: modelVersionId"},
status=400,
)
try:
model_version_id = int(model_version_id_str)
except ValueError:
return web.json_response(
{"success": False, "error": "Parameter modelVersionId must be an integer"},
status=400,
)
history_service = await self._get_download_history_service()
return web.json_response(
{
"success": True,
"modelType": model_type,
"modelVersionId": model_version_id,
"hasBeenDownloaded": await history_service.has_been_downloaded(
model_type,
model_version_id,
),
}
)
except Exception as exc: # pragma: no cover - defensive logging
logger.error(
"Failed to get model version download status: %s",
exc,
exc_info=True,
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def set_model_version_download_status(
self, request: web.Request
) -> web.Response:
try:
data = await request.json()
model_type, _ = await self._get_scanner_for_type(data.get("modelType"))
if not model_type:
return web.json_response(
{"success": False, "error": "Parameter modelType is required"},
status=400,
)
try:
model_version_id = int(data.get("modelVersionId"))
except (TypeError, ValueError):
return web.json_response(
{"success": False, "error": "Parameter modelVersionId must be an integer"},
status=400,
)
downloaded = data.get("downloaded")
if not isinstance(downloaded, bool):
return web.json_response(
{"success": False, "error": "Parameter downloaded must be a boolean"},
status=400,
)
history_service = await self._get_download_history_service()
if downloaded:
model_id = data.get("modelId")
file_path = data.get("filePath")
await history_service.mark_downloaded(
model_type,
model_version_id,
model_id=model_id,
source="manual",
file_path=file_path if isinstance(file_path, str) else None,
)
else:
await history_service.mark_not_downloaded(model_type, model_version_id)
return web.json_response(
{
"success": True,
"modelType": model_type,
"modelVersionId": model_version_id,
"hasBeenDownloaded": downloaded,
}
)
except Exception as exc: # pragma: no cover - defensive logging
logger.error(
"Failed to set model version download status: %s",
exc,
exc_info=True,
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_versions_status(self, request: web.Request) -> web.Response:
try:
model_id_str = request.query.get("modelId")
@@ -896,18 +1053,8 @@ class ModelLibraryHandler:
model_name = response.get("name", "")
model_type = response.get("type", "").lower()
scanner = None
normalized_type = None
if model_type in {"lora", "locon", "dora"}:
scanner = await self._service_registry.get_lora_scanner()
normalized_type = "lora"
elif model_type == "checkpoint":
scanner = await self._service_registry.get_checkpoint_scanner()
normalized_type = "checkpoint"
elif model_type == "textualinversion":
scanner = await self._service_registry.get_embedding_scanner()
normalized_type = "embedding"
else:
normalized_type, scanner = await self._get_scanner_for_type(model_type)
if not normalized_type:
return web.json_response(
{
"success": False,
@@ -925,8 +1072,14 @@ class ModelLibraryHandler:
status=503,
)
history_service = await self._get_download_history_service()
local_versions = await scanner.get_model_versions_by_id(model_id)
local_version_ids = {version["versionId"] for version in local_versions}
downloaded_version_ids = await history_service.get_downloaded_version_ids(
normalized_type,
model_id,
)
downloaded_version_id_set = set(downloaded_version_ids)
enriched_versions = []
for version in versions:
@@ -939,6 +1092,7 @@ class ModelLibraryHandler:
if version.get("images")
else None,
"inLibrary": version_id in local_version_ids,
"hasBeenDownloaded": version_id in downloaded_version_id_set,
}
)
@@ -1007,6 +1161,33 @@ class ModelLibraryHandler:
}
versions: list[dict] = []
history_service = await self._get_download_history_service()
model_ids: list[int] = []
for model in models:
try:
model_ids.append(int(model.get("id")))
except (TypeError, ValueError):
continue
lora_downloaded = await history_service.get_downloaded_version_ids_bulk(
"lora",
model_ids,
)
checkpoint_downloaded = await history_service.get_downloaded_version_ids_bulk(
"checkpoint",
model_ids,
)
embedding_downloaded = await history_service.get_downloaded_version_ids_bulk(
"embedding",
model_ids,
)
downloaded_version_map: Dict[str, Dict[int, set[int]]] = {
"lora": lora_downloaded,
"locon": lora_downloaded,
"dora": lora_downloaded,
"checkpoint": checkpoint_downloaded,
"textualinversion": embedding_downloaded,
}
for model in models:
if not isinstance(model, dict):
continue
@@ -1061,6 +1242,8 @@ class ModelLibraryHandler:
in_library = await scanner.check_model_version_exists(
version_id_int
)
downloaded_versions = downloaded_version_map.get(model_type, {})
downloaded_version_ids = downloaded_versions.get(model_id_int, set())
versions.append(
{
@@ -1073,6 +1256,7 @@ class ModelLibraryHandler:
"baseModel": version.get("baseModel"),
"thumbnailUrl": thumbnail_url,
"inLibrary": in_library,
"hasBeenDownloaded": version_id_int in downloaded_version_ids,
}
)
@@ -1655,6 +1839,8 @@ class MiscHandlerSet:
"update_node_widget": self.node_registry.update_node_widget,
"get_registry": self.node_registry.get_registry,
"check_model_exists": self.model_library.check_model_exists,
"get_model_version_download_status": self.model_library.get_model_version_download_status,
"set_model_version_download_status": self.model_library.set_model_version_download_status,
"get_civitai_user_models": self.model_library.get_civitai_user_models,
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
@@ -1679,4 +1865,5 @@ def build_service_registry_adapter() -> ServiceRegistryAdapter:
get_lora_scanner=ServiceRegistry.get_lora_scanner,
get_checkpoint_scanner=ServiceRegistry.get_checkpoint_scanner,
get_embedding_scanner=ServiceRegistry.get_embedding_scanner,
get_downloaded_version_history_service=ServiceRegistry.get_downloaded_version_history_service,
)

View File

@@ -37,6 +37,16 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"),
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
RouteDefinition(
"GET",
"/api/lm/model-version-download-status",
"get_model_version_download_status",
),
RouteDefinition(
"POST",
"/api/lm/model-version-download-status",
"set_model_version_download_status",
),
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
RouteDefinition(
"POST", "/api/lm/download-metadata-archive", "download_metadata_archive"