From 727d0ef04399dcf35fe04b3faef21880ac046e5f Mon Sep 17 00:00:00 2001 From: Will Miao Date: Fri, 3 Apr 2026 22:17:09 +0800 Subject: [PATCH] feat(misc): add model download status aggregation --- py/routes/handlers/misc_handlers.py | 45 +++++++++++++++++++++++++++-- py/routes/misc_route_registrar.py | 5 ++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/py/routes/handlers/misc_handlers.py b/py/routes/handlers/misc_handlers.py index 554f5627..0dfaac19 100644 --- a/py/routes/handlers/misc_handlers.py +++ b/py/routes/handlers/misc_handlers.py @@ -896,18 +896,49 @@ class ModelLibraryHandler: model_type = None versions = [] + downloaded_version_ids = [] + history_service = await self._get_download_history_service() if lora_versions: model_type = "lora" versions = self._with_downloaded_flag(lora_versions) + downloaded_version_ids = await history_service.get_downloaded_version_ids( + model_type, + model_id, + ) elif checkpoint_versions: model_type = "checkpoint" versions = self._with_downloaded_flag(checkpoint_versions) + downloaded_version_ids = await history_service.get_downloaded_version_ids( + model_type, + model_id, + ) elif embedding_versions: model_type = "embedding" versions = self._with_downloaded_flag(embedding_versions) + downloaded_version_ids = await history_service.get_downloaded_version_ids( + model_type, + model_id, + ) + else: + for candidate_type in ("lora", "checkpoint", "embedding"): + candidate_downloaded_version_ids = ( + await history_service.get_downloaded_version_ids( + candidate_type, + model_id, + ) + ) + if candidate_downloaded_version_ids: + model_type = candidate_type + downloaded_version_ids = candidate_downloaded_version_ids + break return web.json_response( - {"success": True, "modelType": model_type, "versions": versions} + { + "success": True, + "modelType": model_type, + "versions": versions, + "downloadedVersionIds": downloaded_version_ids, + } ) except Exception as exc: # pragma: no cover - defensive logging logger.error("Failed to check model existence: %s", exc, exc_info=True) @@ -962,7 +993,10 @@ class ModelLibraryHandler: self, request: web.Request ) -> web.Response: try: - data = await request.json() + if request.method == "GET": + data = request.query + else: + data = await request.json() model_type, _ = await self._get_scanner_for_type(data.get("modelType")) if not model_type: return web.json_response( @@ -979,6 +1013,13 @@ class ModelLibraryHandler: ) downloaded = data.get("downloaded") + if isinstance(downloaded, str): + normalized_downloaded = downloaded.strip().lower() + if normalized_downloaded in {"true", "1"}: + downloaded = True + elif normalized_downloaded in {"false", "0"}: + downloaded = False + if not isinstance(downloaded, bool): return web.json_response( {"success": False, "error": "Parameter downloaded must be a boolean"}, diff --git a/py/routes/misc_route_registrar.py b/py/routes/misc_route_registrar.py index 9f7a35c9..9351d35a 100644 --- a/py/routes/misc_route_registrar.py +++ b/py/routes/misc_route_registrar.py @@ -47,6 +47,11 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( "/api/lm/model-version-download-status", "set_model_version_download_status", ), + RouteDefinition( + "GET", + "/api/lm/set-model-version-download-status", + "set_model_version_download_status", + ), RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"), RouteDefinition( "POST", "/api/lm/download-metadata-archive", "download_metadata_archive"