mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(model-lifecycle): integrate model update service for deletion sync
This commit is contained in:
@@ -126,6 +126,7 @@ class BaseModelRoutes(ABC):
|
||||
metadata_manager=MetadataManager,
|
||||
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
||||
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
||||
update_service=self._model_update_service,
|
||||
)
|
||||
self._handler_set = None
|
||||
self._handler_mapping = None
|
||||
@@ -297,4 +298,3 @@ class BaseModelRoutes(ABC):
|
||||
if self._model_update_service is None:
|
||||
raise RuntimeError("Model update service has not been attached")
|
||||
return self._model_update_service
|
||||
|
||||
|
||||
@@ -4,13 +4,16 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Awaitable, Callable, Dict, Iterable, List, Optional
|
||||
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional, TYPE_CHECKING
|
||||
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.model_update_service import ModelUpdateService
|
||||
|
||||
|
||||
async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]:
|
||||
"""Delete the primary model artefacts within ``target_dir``."""
|
||||
@@ -54,6 +57,7 @@ class ModelLifecycleService:
|
||||
metadata_manager,
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||
recipe_scanner_factory: Callable[[], Awaitable] | None = None,
|
||||
update_service: "ModelUpdateService" | None = None,
|
||||
) -> None:
|
||||
self._scanner = scanner
|
||||
self._metadata_manager = metadata_manager
|
||||
@@ -61,6 +65,7 @@ class ModelLifecycleService:
|
||||
self._recipe_scanner_factory = (
|
||||
recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
|
||||
)
|
||||
self._update_service = update_service
|
||||
|
||||
async def delete_model(self, file_path: str) -> Dict[str, object]:
|
||||
"""Delete a model file and associated artefacts."""
|
||||
@@ -68,20 +73,100 @@ class ModelLifecycleService:
|
||||
if not file_path:
|
||||
raise ValueError("Model path is required")
|
||||
|
||||
cache = await self._scanner.get_cached_data()
|
||||
|
||||
cached_entry = None
|
||||
if cache and hasattr(cache, "raw_data"):
|
||||
cached_entry = next(
|
||||
(item for item in cache.raw_data if item.get("file_path") == file_path),
|
||||
None,
|
||||
)
|
||||
|
||||
metadata_payload = {}
|
||||
try:
|
||||
metadata_payload = await self._metadata_manager.load_metadata_payload(file_path)
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("Failed to load metadata payload for %s: %s", file_path, exc)
|
||||
|
||||
model_id = (
|
||||
self._extract_model_id_from_payload(metadata_payload)
|
||||
or self._extract_model_id_from_payload(cached_entry)
|
||||
)
|
||||
|
||||
target_dir = os.path.dirname(file_path)
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
deleted_files = await delete_model_artifacts(target_dir, file_name)
|
||||
|
||||
cache = await self._scanner.get_cached_data()
|
||||
cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path]
|
||||
await cache.resort()
|
||||
if cache:
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data if item.get("file_path") != file_path
|
||||
]
|
||||
await cache.resort()
|
||||
|
||||
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
||||
self._scanner._hash_index.remove_by_path(file_path)
|
||||
|
||||
await self._sync_update_for_model(model_id)
|
||||
return {"success": True, "deleted_files": deleted_files}
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_id_from_payload(payload: Any) -> Optional[int]:
|
||||
if not isinstance(payload, Mapping):
|
||||
return None
|
||||
civitai = payload.get("civitai")
|
||||
if isinstance(civitai, Mapping):
|
||||
candidate = civitai.get("modelId") or civitai.get("model_id")
|
||||
if candidate is None:
|
||||
model_section = civitai.get("model")
|
||||
if isinstance(model_section, Mapping):
|
||||
candidate = model_section.get("id")
|
||||
normalized = ModelLifecycleService._coerce_int(candidate)
|
||||
if normalized is not None:
|
||||
return normalized
|
||||
fallback = payload.get("model_id") or payload.get("civitai_model_id")
|
||||
return ModelLifecycleService._coerce_int(fallback)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any) -> Optional[int]:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
async def _sync_update_for_model(self, model_id: Optional[int]) -> None:
|
||||
if self._update_service is None or model_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
versions = await self._scanner.get_model_versions_by_id(model_id)
|
||||
except Exception as exc: # pragma: no cover - defensive log
|
||||
logger.debug(
|
||||
"Failed to collect local versions for model %s: %s", model_id, exc
|
||||
)
|
||||
versions = []
|
||||
|
||||
version_ids = set()
|
||||
for version in versions or []:
|
||||
candidate = (
|
||||
version.get("versionId")
|
||||
or version.get("id")
|
||||
or version.get("version_id")
|
||||
)
|
||||
normalized = ModelLifecycleService._coerce_int(candidate)
|
||||
if normalized is not None:
|
||||
version_ids.add(normalized)
|
||||
|
||||
try:
|
||||
await self._update_service.update_in_library_versions(
|
||||
self._scanner.model_type,
|
||||
model_id,
|
||||
sorted(version_ids),
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive log
|
||||
logger.debug(
|
||||
"Failed to sync update record for model %s: %s", model_id, exc
|
||||
)
|
||||
|
||||
async def exclude_model(self, file_path: str) -> Dict[str, object]:
|
||||
"""Mark a model as excluded and prune cache references."""
|
||||
|
||||
|
||||
@@ -7,6 +7,67 @@ from py.services.model_lifecycle_service import ModelLifecycleService
|
||||
from py.utils.metadata_manager import MetadataManager
|
||||
|
||||
|
||||
class DummyCache:
|
||||
def __init__(self, raw_data):
|
||||
self.raw_data = raw_data
|
||||
|
||||
async def resort(self):
|
||||
return
|
||||
|
||||
|
||||
class DummyHashIndex:
|
||||
def __init__(self):
|
||||
self.removed = []
|
||||
|
||||
def remove_by_path(self, path, *args):
|
||||
self.removed.append(path)
|
||||
|
||||
|
||||
class VersionAwareScanner:
|
||||
def __init__(self, raw_data, model_type="lora"):
|
||||
self.model_type = model_type
|
||||
self.cache = DummyCache(raw_data)
|
||||
self._hash_index = DummyHashIndex()
|
||||
|
||||
async def get_cached_data(self):
|
||||
return self.cache
|
||||
|
||||
async def get_model_versions_by_id(self, model_id):
|
||||
collected = []
|
||||
for item in self.cache.raw_data:
|
||||
civitai = item.get("civitai")
|
||||
if not isinstance(civitai, dict):
|
||||
continue
|
||||
candidate = civitai.get("modelId")
|
||||
try:
|
||||
normalized = int(candidate)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if normalized != model_id:
|
||||
continue
|
||||
version_id = civitai.get("id")
|
||||
if version_id is None:
|
||||
continue
|
||||
collected.append({"versionId": version_id})
|
||||
return collected
|
||||
|
||||
|
||||
class DummyMetadataManager:
|
||||
def __init__(self, payload):
|
||||
self._payload = dict(payload)
|
||||
|
||||
async def load_metadata_payload(self, file_path: str):
|
||||
return dict(self._payload)
|
||||
|
||||
|
||||
class DummyUpdateService:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def update_in_library_versions(self, model_type, model_id, version_ids):
|
||||
self.calls.append((model_type, model_id, version_ids))
|
||||
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
@@ -81,3 +142,43 @@ async def test_rename_model_preserves_compound_extensions(tmp_path: Path):
|
||||
assert old_call_path.endswith(f"{old_name}.safetensors")
|
||||
assert new_call_path.endswith(f"{new_name}.safetensors")
|
||||
assert payload["file_name"] == new_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_model_updates_update_service(tmp_path: Path):
|
||||
model_path = tmp_path / "sample.safetensors"
|
||||
model_path.write_bytes(b"content")
|
||||
|
||||
other_path = tmp_path / "another.safetensors"
|
||||
other_path.write_bytes(b"other")
|
||||
|
||||
raw_data = [
|
||||
{
|
||||
"file_path": model_path.as_posix(),
|
||||
"civitai": {"modelId": 42, "id": 1001},
|
||||
},
|
||||
{
|
||||
"file_path": other_path.as_posix(),
|
||||
"civitai": {"modelId": 42, "id": 1002},
|
||||
},
|
||||
]
|
||||
|
||||
scanner = VersionAwareScanner(raw_data)
|
||||
metadata_manager = DummyMetadataManager({"civitai": {"modelId": 42, "id": 1001}})
|
||||
|
||||
async def metadata_loader(path: str):
|
||||
return {}
|
||||
|
||||
update_service = DummyUpdateService()
|
||||
service = ModelLifecycleService(
|
||||
scanner=scanner,
|
||||
metadata_manager=metadata_manager,
|
||||
metadata_loader=metadata_loader,
|
||||
update_service=update_service,
|
||||
)
|
||||
|
||||
result = await service.delete_model(model_path.as_posix())
|
||||
|
||||
assert result["success"] is True
|
||||
assert not model_path.exists()
|
||||
assert update_service.calls == [("lora", 42, [1002])]
|
||||
|
||||
Reference in New Issue
Block a user