diff --git a/py/routes/handlers/misc_handlers.py b/py/routes/handlers/misc_handlers.py index dbc35924..e0ffc1f5 100644 --- a/py/routes/handlers/misc_handlers.py +++ b/py/routes/handlers/misc_handlers.py @@ -1890,6 +1890,86 @@ class ModelLibraryHandler: logger.error("Failed to check model existence: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) + async def check_models_exist(self, request: web.Request) -> web.Response: + try: + model_ids_raw = request.query.get("modelIds", "") + if not model_ids_raw: + return web.json_response( + {"success": True, "results": []} + ) + + raw_ids = model_ids_raw.split(",") + seen: set[int] = set() + model_ids: list[int] = [] + for raw in raw_ids: + stripped = raw.strip() + if not stripped: + continue + try: + mid = int(stripped) + except ValueError: + continue + if mid not in seen: + seen.add(mid) + model_ids.append(mid) + + if not model_ids: + return web.json_response( + {"success": True, "results": []} + ) + + lora_scanner = await self._service_registry.get_lora_scanner() + checkpoint_scanner = await self._service_registry.get_checkpoint_scanner() + embedding_scanner = await self._service_registry.get_embedding_scanner() + + results: list[dict] = [] + for model_id in model_ids: + lora_versions = await lora_scanner.get_model_versions_by_id(model_id) + if lora_versions: + results.append({ + "modelId": model_id, + "modelType": "lora", + "versions": self._with_downloaded_flag(lora_versions), + "downloadedVersionIds": [], + }) + continue + + if checkpoint_scanner: + checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id) + if checkpoint_versions: + results.append({ + "modelId": model_id, + "modelType": "checkpoint", + "versions": self._with_downloaded_flag(checkpoint_versions), + "downloadedVersionIds": [], + }) + continue + + if embedding_scanner: + embedding_versions = await embedding_scanner.get_model_versions_by_id(model_id) + if embedding_versions: + results.append({ + "modelId": model_id, + "modelType": "embedding", + "versions": self._with_downloaded_flag(embedding_versions), + "downloadedVersionIds": [], + }) + continue + + results.append({ + "modelId": model_id, + "modelType": None, + "versions": [], + "downloadedVersionIds": [], + }) + + return web.json_response( + {"success": True, "results": results} + ) + except Exception as exc: + logger.error("Failed to check models existence: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + async def get_model_version_download_status( self, request: web.Request ) -> web.Response: @@ -3035,6 +3115,7 @@ class MiscHandlerSet: "update_node_widget": self.node_registry.update_node_widget, "get_registry": self.node_registry.get_registry, "check_model_exists": self.model_library.check_model_exists, + "check_models_exist": self.model_library.check_models_exist, "get_model_version_download_status": self.model_library.get_model_version_download_status, "set_model_version_download_status": self.model_library.set_model_version_download_status, "get_civitai_user_models": self.model_library.get_civitai_user_models, diff --git a/py/routes/misc_route_registrar.py b/py/routes/misc_route_registrar.py index 447a36a9..4c98de4d 100644 --- a/py/routes/misc_route_registrar.py +++ b/py/routes/misc_route_registrar.py @@ -43,6 +43,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"), RouteDefinition("GET", "/api/lm/get-registry", "get_registry"), RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"), + RouteDefinition("GET", "/api/lm/check-models-exist", "check_models_exist"), RouteDefinition( "GET", "/api/lm/model-version-download-status",