diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 5eb6d1cc..b6aba1e4 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -126,6 +126,7 @@ class BaseModelRoutes(ABC): metadata_manager=MetadataManager, metadata_loader=self._metadata_sync_service.load_local_metadata, recipe_scanner_factory=ServiceRegistry.get_recipe_scanner, + update_service=self._model_update_service, ) self._handler_set = None self._handler_mapping = None @@ -297,4 +298,3 @@ class BaseModelRoutes(ABC): if self._model_update_service is None: raise RuntimeError("Model update service has not been attached") return self._model_update_service - diff --git a/py/services/model_lifecycle_service.py b/py/services/model_lifecycle_service.py index 1416768d..dd67226c 100644 --- a/py/services/model_lifecycle_service.py +++ b/py/services/model_lifecycle_service.py @@ -4,13 +4,16 @@ from __future__ import annotations import logging import os -from typing import Awaitable, Callable, Dict, Iterable, List, Optional +from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional, TYPE_CHECKING from ..services.service_registry import ServiceRegistry from ..utils.constants import PREVIEW_EXTENSIONS logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from ..services.model_update_service import ModelUpdateService + async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]: """Delete the primary model artefacts within ``target_dir``.""" @@ -54,6 +57,7 @@ class ModelLifecycleService: metadata_manager, metadata_loader: Callable[[str], Awaitable[Dict[str, object]]], recipe_scanner_factory: Callable[[], Awaitable] | None = None, + update_service: "ModelUpdateService" | None = None, ) -> None: self._scanner = scanner self._metadata_manager = metadata_manager @@ -61,6 +65,7 @@ class ModelLifecycleService: self._recipe_scanner_factory = ( recipe_scanner_factory or ServiceRegistry.get_recipe_scanner ) + self._update_service = update_service async def delete_model(self, file_path: str) -> Dict[str, object]: """Delete a model file and associated artefacts.""" @@ -68,20 +73,100 @@ class ModelLifecycleService: if not file_path: raise ValueError("Model path is required") + cache = await self._scanner.get_cached_data() + + cached_entry = None + if cache and hasattr(cache, "raw_data"): + cached_entry = next( + (item for item in cache.raw_data if item.get("file_path") == file_path), + None, + ) + + metadata_payload = {} + try: + metadata_payload = await self._metadata_manager.load_metadata_payload(file_path) + except Exception as exc: # pragma: no cover - defensive guard + logger.debug("Failed to load metadata payload for %s: %s", file_path, exc) + + model_id = ( + self._extract_model_id_from_payload(metadata_payload) + or self._extract_model_id_from_payload(cached_entry) + ) + target_dir = os.path.dirname(file_path) file_name = os.path.splitext(os.path.basename(file_path))[0] - deleted_files = await delete_model_artifacts(target_dir, file_name) - cache = await self._scanner.get_cached_data() - cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path] - await cache.resort() + if cache: + cache.raw_data = [ + item for item in cache.raw_data if item.get("file_path") != file_path + ] + await cache.resort() if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index: self._scanner._hash_index.remove_by_path(file_path) + await self._sync_update_for_model(model_id) return {"success": True, "deleted_files": deleted_files} + @staticmethod + def _extract_model_id_from_payload(payload: Any) -> Optional[int]: + if not isinstance(payload, Mapping): + return None + civitai = payload.get("civitai") + if isinstance(civitai, Mapping): + candidate = civitai.get("modelId") or civitai.get("model_id") + if candidate is None: + model_section = civitai.get("model") + if isinstance(model_section, Mapping): + candidate = model_section.get("id") + normalized = ModelLifecycleService._coerce_int(candidate) + if normalized is not None: + return normalized + fallback = payload.get("model_id") or payload.get("civitai_model_id") + return ModelLifecycleService._coerce_int(fallback) + + @staticmethod + def _coerce_int(value: Any) -> Optional[int]: + try: + return int(value) + except (TypeError, ValueError): + return None + + async def _sync_update_for_model(self, model_id: Optional[int]) -> None: + if self._update_service is None or model_id is None: + return + + try: + versions = await self._scanner.get_model_versions_by_id(model_id) + except Exception as exc: # pragma: no cover - defensive log + logger.debug( + "Failed to collect local versions for model %s: %s", model_id, exc + ) + versions = [] + + version_ids = set() + for version in versions or []: + candidate = ( + version.get("versionId") + or version.get("id") + or version.get("version_id") + ) + normalized = ModelLifecycleService._coerce_int(candidate) + if normalized is not None: + version_ids.add(normalized) + + try: + await self._update_service.update_in_library_versions( + self._scanner.model_type, + model_id, + sorted(version_ids), + ) + except Exception as exc: # pragma: no cover - defensive log + logger.debug( + "Failed to sync update record for model %s: %s", model_id, exc + ) + async def exclude_model(self, file_path: str) -> Dict[str, object]: """Mark a model as excluded and prune cache references.""" diff --git a/tests/services/test_model_lifecycle_service.py b/tests/services/test_model_lifecycle_service.py index e96f8236..e29bfc0b 100644 --- a/tests/services/test_model_lifecycle_service.py +++ b/tests/services/test_model_lifecycle_service.py @@ -7,6 +7,67 @@ from py.services.model_lifecycle_service import ModelLifecycleService from py.utils.metadata_manager import MetadataManager +class DummyCache: + def __init__(self, raw_data): + self.raw_data = raw_data + + async def resort(self): + return + + +class DummyHashIndex: + def __init__(self): + self.removed = [] + + def remove_by_path(self, path, *args): + self.removed.append(path) + + +class VersionAwareScanner: + def __init__(self, raw_data, model_type="lora"): + self.model_type = model_type + self.cache = DummyCache(raw_data) + self._hash_index = DummyHashIndex() + + async def get_cached_data(self): + return self.cache + + async def get_model_versions_by_id(self, model_id): + collected = [] + for item in self.cache.raw_data: + civitai = item.get("civitai") + if not isinstance(civitai, dict): + continue + candidate = civitai.get("modelId") + try: + normalized = int(candidate) + except (TypeError, ValueError): + continue + if normalized != model_id: + continue + version_id = civitai.get("id") + if version_id is None: + continue + collected.append({"versionId": version_id}) + return collected + + +class DummyMetadataManager: + def __init__(self, payload): + self._payload = dict(payload) + + async def load_metadata_payload(self, file_path: str): + return dict(self._payload) + + +class DummyUpdateService: + def __init__(self): + self.calls = [] + + async def update_in_library_versions(self, model_type, model_id, version_ids): + self.calls.append((model_type, model_id, version_ids)) + + class DummyScanner: def __init__(self): self.calls = [] @@ -81,3 +142,43 @@ async def test_rename_model_preserves_compound_extensions(tmp_path: Path): assert old_call_path.endswith(f"{old_name}.safetensors") assert new_call_path.endswith(f"{new_name}.safetensors") assert payload["file_name"] == new_name + + +@pytest.mark.asyncio +async def test_delete_model_updates_update_service(tmp_path: Path): + model_path = tmp_path / "sample.safetensors" + model_path.write_bytes(b"content") + + other_path = tmp_path / "another.safetensors" + other_path.write_bytes(b"other") + + raw_data = [ + { + "file_path": model_path.as_posix(), + "civitai": {"modelId": 42, "id": 1001}, + }, + { + "file_path": other_path.as_posix(), + "civitai": {"modelId": 42, "id": 1002}, + }, + ] + + scanner = VersionAwareScanner(raw_data) + metadata_manager = DummyMetadataManager({"civitai": {"modelId": 42, "id": 1001}}) + + async def metadata_loader(path: str): + return {} + + update_service = DummyUpdateService() + service = ModelLifecycleService( + scanner=scanner, + metadata_manager=metadata_manager, + metadata_loader=metadata_loader, + update_service=update_service, + ) + + result = await service.delete_model(model_path.as_posix()) + + assert result["success"] is True + assert not model_path.exists() + assert update_service.calls == [("lora", 42, [1002])]