mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(model-lifecycle): integrate model update service for deletion sync
Add ModelUpdateService dependency to ModelLifecycleService to enable synchronization during model deletion. The service is now passed through BaseModelRoutes initialization and used in delete_model to trigger updates when a model is removed. This ensures external systems stay in sync with local model state changes. Key changes: - Inject update_service into ModelLifecycleService constructor - Extract model ID from metadata during deletion - Call update service sync method after successful deletion - Add proper type hints and TYPE_CHECKING imports
This commit is contained in:
@@ -126,6 +126,7 @@ class BaseModelRoutes(ABC):
|
||||
metadata_manager=MetadataManager,
|
||||
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
||||
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
||||
update_service=self._model_update_service,
|
||||
)
|
||||
self._handler_set = None
|
||||
self._handler_mapping = None
|
||||
@@ -297,4 +298,3 @@ class BaseModelRoutes(ABC):
|
||||
if self._model_update_service is None:
|
||||
raise RuntimeError("Model update service has not been attached")
|
||||
return self._model_update_service
|
||||
|
||||
|
||||
@@ -4,13 +4,16 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Awaitable, Callable, Dict, Iterable, List, Optional
|
||||
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional, TYPE_CHECKING
|
||||
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.model_update_service import ModelUpdateService
|
||||
|
||||
|
||||
async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]:
|
||||
"""Delete the primary model artefacts within ``target_dir``."""
|
||||
@@ -54,6 +57,7 @@ class ModelLifecycleService:
|
||||
metadata_manager,
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||
recipe_scanner_factory: Callable[[], Awaitable] | None = None,
|
||||
update_service: "ModelUpdateService" | None = None,
|
||||
) -> None:
|
||||
self._scanner = scanner
|
||||
self._metadata_manager = metadata_manager
|
||||
@@ -61,6 +65,7 @@ class ModelLifecycleService:
|
||||
self._recipe_scanner_factory = (
|
||||
recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
|
||||
)
|
||||
self._update_service = update_service
|
||||
|
||||
async def delete_model(self, file_path: str) -> Dict[str, object]:
|
||||
"""Delete a model file and associated artefacts."""
|
||||
@@ -68,20 +73,100 @@ class ModelLifecycleService:
|
||||
if not file_path:
|
||||
raise ValueError("Model path is required")
|
||||
|
||||
cache = await self._scanner.get_cached_data()
|
||||
|
||||
cached_entry = None
|
||||
if cache and hasattr(cache, "raw_data"):
|
||||
cached_entry = next(
|
||||
(item for item in cache.raw_data if item.get("file_path") == file_path),
|
||||
None,
|
||||
)
|
||||
|
||||
metadata_payload = {}
|
||||
try:
|
||||
metadata_payload = await self._metadata_manager.load_metadata_payload(file_path)
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("Failed to load metadata payload for %s: %s", file_path, exc)
|
||||
|
||||
model_id = (
|
||||
self._extract_model_id_from_payload(metadata_payload)
|
||||
or self._extract_model_id_from_payload(cached_entry)
|
||||
)
|
||||
|
||||
target_dir = os.path.dirname(file_path)
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
deleted_files = await delete_model_artifacts(target_dir, file_name)
|
||||
|
||||
cache = await self._scanner.get_cached_data()
|
||||
cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path]
|
||||
await cache.resort()
|
||||
if cache:
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data if item.get("file_path") != file_path
|
||||
]
|
||||
await cache.resort()
|
||||
|
||||
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
||||
self._scanner._hash_index.remove_by_path(file_path)
|
||||
|
||||
await self._sync_update_for_model(model_id)
|
||||
return {"success": True, "deleted_files": deleted_files}
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_id_from_payload(payload: Any) -> Optional[int]:
|
||||
if not isinstance(payload, Mapping):
|
||||
return None
|
||||
civitai = payload.get("civitai")
|
||||
if isinstance(civitai, Mapping):
|
||||
candidate = civitai.get("modelId") or civitai.get("model_id")
|
||||
if candidate is None:
|
||||
model_section = civitai.get("model")
|
||||
if isinstance(model_section, Mapping):
|
||||
candidate = model_section.get("id")
|
||||
normalized = ModelLifecycleService._coerce_int(candidate)
|
||||
if normalized is not None:
|
||||
return normalized
|
||||
fallback = payload.get("model_id") or payload.get("civitai_model_id")
|
||||
return ModelLifecycleService._coerce_int(fallback)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any) -> Optional[int]:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
async def _sync_update_for_model(self, model_id: Optional[int]) -> None:
|
||||
if self._update_service is None or model_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
versions = await self._scanner.get_model_versions_by_id(model_id)
|
||||
except Exception as exc: # pragma: no cover - defensive log
|
||||
logger.debug(
|
||||
"Failed to collect local versions for model %s: %s", model_id, exc
|
||||
)
|
||||
versions = []
|
||||
|
||||
version_ids = set()
|
||||
for version in versions or []:
|
||||
candidate = (
|
||||
version.get("versionId")
|
||||
or version.get("id")
|
||||
or version.get("version_id")
|
||||
)
|
||||
normalized = ModelLifecycleService._coerce_int(candidate)
|
||||
if normalized is not None:
|
||||
version_ids.add(normalized)
|
||||
|
||||
try:
|
||||
await self._update_service.update_in_library_versions(
|
||||
self._scanner.model_type,
|
||||
model_id,
|
||||
sorted(version_ids),
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive log
|
||||
logger.debug(
|
||||
"Failed to sync update record for model %s: %s", model_id, exc
|
||||
)
|
||||
|
||||
async def exclude_model(self, file_path: str) -> Dict[str, object]:
|
||||
"""Mark a model as excluded and prune cache references."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user