mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-04-10 04:42:14 -03:00
feat(download-history): track downloaded model versions
This commit is contained in:
@@ -751,6 +751,7 @@ class ServiceRegistryAdapter:
|
||||
get_lora_scanner: Callable[[], Awaitable]
|
||||
get_checkpoint_scanner: Callable[[], Awaitable]
|
||||
get_embedding_scanner: Callable[[], Awaitable]
|
||||
get_downloaded_version_history_service: Callable[[], Awaitable]
|
||||
|
||||
|
||||
class ModelLibraryHandler:
|
||||
@@ -764,6 +765,41 @@ class ModelLibraryHandler:
|
||||
self._service_registry = service_registry
|
||||
self._metadata_provider_factory = metadata_provider_factory
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model_type(model_type: str | None) -> str | None:
|
||||
if not isinstance(model_type, str):
|
||||
return None
|
||||
normalized = model_type.strip().lower()
|
||||
if normalized in {"lora", "locon", "dora"}:
|
||||
return "lora"
|
||||
if normalized == "checkpoint":
|
||||
return "checkpoint"
|
||||
if normalized in {"embedding", "textualinversion"}:
|
||||
return "embedding"
|
||||
return None
|
||||
|
||||
async def _get_scanner_for_type(self, model_type: str | None):
|
||||
normalized_type = self._normalize_model_type(model_type)
|
||||
if normalized_type == "lora":
|
||||
return normalized_type, await self._service_registry.get_lora_scanner()
|
||||
if normalized_type == "checkpoint":
|
||||
return normalized_type, await self._service_registry.get_checkpoint_scanner()
|
||||
if normalized_type == "embedding":
|
||||
return normalized_type, await self._service_registry.get_embedding_scanner()
|
||||
return None, None
|
||||
|
||||
async def _get_download_history_service(self):
|
||||
return await self._service_registry.get_downloaded_version_history_service()
|
||||
|
||||
@staticmethod
|
||||
def _with_downloaded_flag(versions: list[dict]) -> list[dict]:
|
||||
enriched: list[dict] = []
|
||||
for version in versions:
|
||||
entry = dict(version)
|
||||
entry.setdefault("hasBeenDownloaded", True)
|
||||
enriched.append(entry)
|
||||
return enriched
|
||||
|
||||
async def check_model_exists(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
model_id_str = request.query.get("modelId")
|
||||
@@ -819,11 +855,30 @@ class ModelLibraryHandler:
|
||||
exists = True
|
||||
model_type = "embedding"
|
||||
|
||||
history_service = await self._get_download_history_service()
|
||||
has_been_downloaded = False
|
||||
history_type = model_type
|
||||
if history_type:
|
||||
has_been_downloaded = await history_service.has_been_downloaded(
|
||||
history_type,
|
||||
model_version_id,
|
||||
)
|
||||
else:
|
||||
for candidate_type in ("lora", "checkpoint", "embedding"):
|
||||
if await history_service.has_been_downloaded(
|
||||
candidate_type,
|
||||
model_version_id,
|
||||
):
|
||||
has_been_downloaded = True
|
||||
history_type = candidate_type
|
||||
break
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"exists": exists,
|
||||
"modelType": model_type if exists else None,
|
||||
"modelType": model_type if exists else history_type,
|
||||
"hasBeenDownloaded": has_been_downloaded,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -843,13 +898,13 @@ class ModelLibraryHandler:
|
||||
versions = []
|
||||
if lora_versions:
|
||||
model_type = "lora"
|
||||
versions = lora_versions
|
||||
versions = self._with_downloaded_flag(lora_versions)
|
||||
elif checkpoint_versions:
|
||||
model_type = "checkpoint"
|
||||
versions = checkpoint_versions
|
||||
versions = self._with_downloaded_flag(checkpoint_versions)
|
||||
elif embedding_versions:
|
||||
model_type = "embedding"
|
||||
versions = embedding_versions
|
||||
versions = self._with_downloaded_flag(embedding_versions)
|
||||
|
||||
return web.json_response(
|
||||
{"success": True, "modelType": model_type, "versions": versions}
|
||||
@@ -858,6 +913,108 @@ class ModelLibraryHandler:
|
||||
logger.error("Failed to check model existence: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_model_version_download_status(
|
||||
self, request: web.Request
|
||||
) -> web.Response:
|
||||
try:
|
||||
model_type, _ = await self._get_scanner_for_type(request.query.get("modelType"))
|
||||
if not model_type:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Parameter modelType is required"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
model_version_id_str = request.query.get("modelVersionId")
|
||||
if not model_version_id_str:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Missing required parameter: modelVersionId"},
|
||||
status=400,
|
||||
)
|
||||
try:
|
||||
model_version_id = int(model_version_id_str)
|
||||
except ValueError:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Parameter modelVersionId must be an integer"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
history_service = await self._get_download_history_service()
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"modelType": model_type,
|
||||
"modelVersionId": model_version_id,
|
||||
"hasBeenDownloaded": await history_service.has_been_downloaded(
|
||||
model_type,
|
||||
model_version_id,
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error(
|
||||
"Failed to get model version download status: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def set_model_version_download_status(
|
||||
self, request: web.Request
|
||||
) -> web.Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
model_type, _ = await self._get_scanner_for_type(data.get("modelType"))
|
||||
if not model_type:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Parameter modelType is required"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
try:
|
||||
model_version_id = int(data.get("modelVersionId"))
|
||||
except (TypeError, ValueError):
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Parameter modelVersionId must be an integer"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
downloaded = data.get("downloaded")
|
||||
if not isinstance(downloaded, bool):
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Parameter downloaded must be a boolean"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
history_service = await self._get_download_history_service()
|
||||
if downloaded:
|
||||
model_id = data.get("modelId")
|
||||
file_path = data.get("filePath")
|
||||
await history_service.mark_downloaded(
|
||||
model_type,
|
||||
model_version_id,
|
||||
model_id=model_id,
|
||||
source="manual",
|
||||
file_path=file_path if isinstance(file_path, str) else None,
|
||||
)
|
||||
else:
|
||||
await history_service.mark_not_downloaded(model_type, model_version_id)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"modelType": model_type,
|
||||
"modelVersionId": model_version_id,
|
||||
"hasBeenDownloaded": downloaded,
|
||||
}
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error(
|
||||
"Failed to set model version download status: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_model_versions_status(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
model_id_str = request.query.get("modelId")
|
||||
@@ -896,18 +1053,8 @@ class ModelLibraryHandler:
|
||||
model_name = response.get("name", "")
|
||||
model_type = response.get("type", "").lower()
|
||||
|
||||
scanner = None
|
||||
normalized_type = None
|
||||
if model_type in {"lora", "locon", "dora"}:
|
||||
scanner = await self._service_registry.get_lora_scanner()
|
||||
normalized_type = "lora"
|
||||
elif model_type == "checkpoint":
|
||||
scanner = await self._service_registry.get_checkpoint_scanner()
|
||||
normalized_type = "checkpoint"
|
||||
elif model_type == "textualinversion":
|
||||
scanner = await self._service_registry.get_embedding_scanner()
|
||||
normalized_type = "embedding"
|
||||
else:
|
||||
normalized_type, scanner = await self._get_scanner_for_type(model_type)
|
||||
if not normalized_type:
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
@@ -925,8 +1072,14 @@ class ModelLibraryHandler:
|
||||
status=503,
|
||||
)
|
||||
|
||||
history_service = await self._get_download_history_service()
|
||||
local_versions = await scanner.get_model_versions_by_id(model_id)
|
||||
local_version_ids = {version["versionId"] for version in local_versions}
|
||||
downloaded_version_ids = await history_service.get_downloaded_version_ids(
|
||||
normalized_type,
|
||||
model_id,
|
||||
)
|
||||
downloaded_version_id_set = set(downloaded_version_ids)
|
||||
|
||||
enriched_versions = []
|
||||
for version in versions:
|
||||
@@ -939,6 +1092,7 @@ class ModelLibraryHandler:
|
||||
if version.get("images")
|
||||
else None,
|
||||
"inLibrary": version_id in local_version_ids,
|
||||
"hasBeenDownloaded": version_id in downloaded_version_id_set,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1007,6 +1161,33 @@ class ModelLibraryHandler:
|
||||
}
|
||||
|
||||
versions: list[dict] = []
|
||||
history_service = await self._get_download_history_service()
|
||||
model_ids: list[int] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_ids.append(int(model.get("id")))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
|
||||
lora_downloaded = await history_service.get_downloaded_version_ids_bulk(
|
||||
"lora",
|
||||
model_ids,
|
||||
)
|
||||
checkpoint_downloaded = await history_service.get_downloaded_version_ids_bulk(
|
||||
"checkpoint",
|
||||
model_ids,
|
||||
)
|
||||
embedding_downloaded = await history_service.get_downloaded_version_ids_bulk(
|
||||
"embedding",
|
||||
model_ids,
|
||||
)
|
||||
downloaded_version_map: Dict[str, Dict[int, set[int]]] = {
|
||||
"lora": lora_downloaded,
|
||||
"locon": lora_downloaded,
|
||||
"dora": lora_downloaded,
|
||||
"checkpoint": checkpoint_downloaded,
|
||||
"textualinversion": embedding_downloaded,
|
||||
}
|
||||
for model in models:
|
||||
if not isinstance(model, dict):
|
||||
continue
|
||||
@@ -1061,6 +1242,8 @@ class ModelLibraryHandler:
|
||||
in_library = await scanner.check_model_version_exists(
|
||||
version_id_int
|
||||
)
|
||||
downloaded_versions = downloaded_version_map.get(model_type, {})
|
||||
downloaded_version_ids = downloaded_versions.get(model_id_int, set())
|
||||
|
||||
versions.append(
|
||||
{
|
||||
@@ -1073,6 +1256,7 @@ class ModelLibraryHandler:
|
||||
"baseModel": version.get("baseModel"),
|
||||
"thumbnailUrl": thumbnail_url,
|
||||
"inLibrary": in_library,
|
||||
"hasBeenDownloaded": version_id_int in downloaded_version_ids,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1655,6 +1839,8 @@ class MiscHandlerSet:
|
||||
"update_node_widget": self.node_registry.update_node_widget,
|
||||
"get_registry": self.node_registry.get_registry,
|
||||
"check_model_exists": self.model_library.check_model_exists,
|
||||
"get_model_version_download_status": self.model_library.get_model_version_download_status,
|
||||
"set_model_version_download_status": self.model_library.set_model_version_download_status,
|
||||
"get_civitai_user_models": self.model_library.get_civitai_user_models,
|
||||
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
|
||||
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
|
||||
@@ -1679,4 +1865,5 @@ def build_service_registry_adapter() -> ServiceRegistryAdapter:
|
||||
get_lora_scanner=ServiceRegistry.get_lora_scanner,
|
||||
get_checkpoint_scanner=ServiceRegistry.get_checkpoint_scanner,
|
||||
get_embedding_scanner=ServiceRegistry.get_embedding_scanner,
|
||||
get_downloaded_version_history_service=ServiceRegistry.get_downloaded_version_history_service,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user