mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(model-lifecycle): integrate model update service for deletion sync
This commit is contained in:
@@ -126,6 +126,7 @@ class BaseModelRoutes(ABC):
|
|||||||
metadata_manager=MetadataManager,
|
metadata_manager=MetadataManager,
|
||||||
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
||||||
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
||||||
|
update_service=self._model_update_service,
|
||||||
)
|
)
|
||||||
self._handler_set = None
|
self._handler_set = None
|
||||||
self._handler_mapping = None
|
self._handler_mapping = None
|
||||||
@@ -297,4 +298,3 @@ class BaseModelRoutes(ABC):
|
|||||||
if self._model_update_service is None:
|
if self._model_update_service is None:
|
||||||
raise RuntimeError("Model update service has not been attached")
|
raise RuntimeError("Model update service has not been attached")
|
||||||
return self._model_update_service
|
return self._model_update_service
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,16 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Awaitable, Callable, Dict, Iterable, List, Optional
|
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..services.model_update_service import ModelUpdateService
|
||||||
|
|
||||||
|
|
||||||
async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]:
|
async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]:
|
||||||
"""Delete the primary model artefacts within ``target_dir``."""
|
"""Delete the primary model artefacts within ``target_dir``."""
|
||||||
@@ -54,6 +57,7 @@ class ModelLifecycleService:
|
|||||||
metadata_manager,
|
metadata_manager,
|
||||||
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||||
recipe_scanner_factory: Callable[[], Awaitable] | None = None,
|
recipe_scanner_factory: Callable[[], Awaitable] | None = None,
|
||||||
|
update_service: "ModelUpdateService" | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._scanner = scanner
|
self._scanner = scanner
|
||||||
self._metadata_manager = metadata_manager
|
self._metadata_manager = metadata_manager
|
||||||
@@ -61,6 +65,7 @@ class ModelLifecycleService:
|
|||||||
self._recipe_scanner_factory = (
|
self._recipe_scanner_factory = (
|
||||||
recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
|
recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
|
||||||
)
|
)
|
||||||
|
self._update_service = update_service
|
||||||
|
|
||||||
async def delete_model(self, file_path: str) -> Dict[str, object]:
|
async def delete_model(self, file_path: str) -> Dict[str, object]:
|
||||||
"""Delete a model file and associated artefacts."""
|
"""Delete a model file and associated artefacts."""
|
||||||
@@ -68,20 +73,100 @@ class ModelLifecycleService:
|
|||||||
if not file_path:
|
if not file_path:
|
||||||
raise ValueError("Model path is required")
|
raise ValueError("Model path is required")
|
||||||
|
|
||||||
|
cache = await self._scanner.get_cached_data()
|
||||||
|
|
||||||
|
cached_entry = None
|
||||||
|
if cache and hasattr(cache, "raw_data"):
|
||||||
|
cached_entry = next(
|
||||||
|
(item for item in cache.raw_data if item.get("file_path") == file_path),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata_payload = {}
|
||||||
|
try:
|
||||||
|
metadata_payload = await self._metadata_manager.load_metadata_payload(file_path)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive guard
|
||||||
|
logger.debug("Failed to load metadata payload for %s: %s", file_path, exc)
|
||||||
|
|
||||||
|
model_id = (
|
||||||
|
self._extract_model_id_from_payload(metadata_payload)
|
||||||
|
or self._extract_model_id_from_payload(cached_entry)
|
||||||
|
)
|
||||||
|
|
||||||
target_dir = os.path.dirname(file_path)
|
target_dir = os.path.dirname(file_path)
|
||||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
|
||||||
deleted_files = await delete_model_artifacts(target_dir, file_name)
|
deleted_files = await delete_model_artifacts(target_dir, file_name)
|
||||||
|
|
||||||
cache = await self._scanner.get_cached_data()
|
if cache:
|
||||||
cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path]
|
cache.raw_data = [
|
||||||
await cache.resort()
|
item for item in cache.raw_data if item.get("file_path") != file_path
|
||||||
|
]
|
||||||
|
await cache.resort()
|
||||||
|
|
||||||
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
||||||
self._scanner._hash_index.remove_by_path(file_path)
|
self._scanner._hash_index.remove_by_path(file_path)
|
||||||
|
|
||||||
|
await self._sync_update_for_model(model_id)
|
||||||
return {"success": True, "deleted_files": deleted_files}
|
return {"success": True, "deleted_files": deleted_files}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_model_id_from_payload(payload: Any) -> Optional[int]:
|
||||||
|
if not isinstance(payload, Mapping):
|
||||||
|
return None
|
||||||
|
civitai = payload.get("civitai")
|
||||||
|
if isinstance(civitai, Mapping):
|
||||||
|
candidate = civitai.get("modelId") or civitai.get("model_id")
|
||||||
|
if candidate is None:
|
||||||
|
model_section = civitai.get("model")
|
||||||
|
if isinstance(model_section, Mapping):
|
||||||
|
candidate = model_section.get("id")
|
||||||
|
normalized = ModelLifecycleService._coerce_int(candidate)
|
||||||
|
if normalized is not None:
|
||||||
|
return normalized
|
||||||
|
fallback = payload.get("model_id") or payload.get("civitai_model_id")
|
||||||
|
return ModelLifecycleService._coerce_int(fallback)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce_int(value: Any) -> Optional[int]:
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _sync_update_for_model(self, model_id: Optional[int]) -> None:
|
||||||
|
if self._update_service is None or model_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
versions = await self._scanner.get_model_versions_by_id(model_id)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive log
|
||||||
|
logger.debug(
|
||||||
|
"Failed to collect local versions for model %s: %s", model_id, exc
|
||||||
|
)
|
||||||
|
versions = []
|
||||||
|
|
||||||
|
version_ids = set()
|
||||||
|
for version in versions or []:
|
||||||
|
candidate = (
|
||||||
|
version.get("versionId")
|
||||||
|
or version.get("id")
|
||||||
|
or version.get("version_id")
|
||||||
|
)
|
||||||
|
normalized = ModelLifecycleService._coerce_int(candidate)
|
||||||
|
if normalized is not None:
|
||||||
|
version_ids.add(normalized)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._update_service.update_in_library_versions(
|
||||||
|
self._scanner.model_type,
|
||||||
|
model_id,
|
||||||
|
sorted(version_ids),
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive log
|
||||||
|
logger.debug(
|
||||||
|
"Failed to sync update record for model %s: %s", model_id, exc
|
||||||
|
)
|
||||||
|
|
||||||
async def exclude_model(self, file_path: str) -> Dict[str, object]:
|
async def exclude_model(self, file_path: str) -> Dict[str, object]:
|
||||||
"""Mark a model as excluded and prune cache references."""
|
"""Mark a model as excluded and prune cache references."""
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,67 @@ from py.services.model_lifecycle_service import ModelLifecycleService
|
|||||||
from py.utils.metadata_manager import MetadataManager
|
from py.utils.metadata_manager import MetadataManager
|
||||||
|
|
||||||
|
|
||||||
|
class DummyCache:
|
||||||
|
def __init__(self, raw_data):
|
||||||
|
self.raw_data = raw_data
|
||||||
|
|
||||||
|
async def resort(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class DummyHashIndex:
|
||||||
|
def __init__(self):
|
||||||
|
self.removed = []
|
||||||
|
|
||||||
|
def remove_by_path(self, path, *args):
|
||||||
|
self.removed.append(path)
|
||||||
|
|
||||||
|
|
||||||
|
class VersionAwareScanner:
|
||||||
|
def __init__(self, raw_data, model_type="lora"):
|
||||||
|
self.model_type = model_type
|
||||||
|
self.cache = DummyCache(raw_data)
|
||||||
|
self._hash_index = DummyHashIndex()
|
||||||
|
|
||||||
|
async def get_cached_data(self):
|
||||||
|
return self.cache
|
||||||
|
|
||||||
|
async def get_model_versions_by_id(self, model_id):
|
||||||
|
collected = []
|
||||||
|
for item in self.cache.raw_data:
|
||||||
|
civitai = item.get("civitai")
|
||||||
|
if not isinstance(civitai, dict):
|
||||||
|
continue
|
||||||
|
candidate = civitai.get("modelId")
|
||||||
|
try:
|
||||||
|
normalized = int(candidate)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
if normalized != model_id:
|
||||||
|
continue
|
||||||
|
version_id = civitai.get("id")
|
||||||
|
if version_id is None:
|
||||||
|
continue
|
||||||
|
collected.append({"versionId": version_id})
|
||||||
|
return collected
|
||||||
|
|
||||||
|
|
||||||
|
class DummyMetadataManager:
|
||||||
|
def __init__(self, payload):
|
||||||
|
self._payload = dict(payload)
|
||||||
|
|
||||||
|
async def load_metadata_payload(self, file_path: str):
|
||||||
|
return dict(self._payload)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyUpdateService:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
async def update_in_library_versions(self, model_type, model_id, version_ids):
|
||||||
|
self.calls.append((model_type, model_id, version_ids))
|
||||||
|
|
||||||
|
|
||||||
class DummyScanner:
|
class DummyScanner:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.calls = []
|
self.calls = []
|
||||||
@@ -81,3 +142,43 @@ async def test_rename_model_preserves_compound_extensions(tmp_path: Path):
|
|||||||
assert old_call_path.endswith(f"{old_name}.safetensors")
|
assert old_call_path.endswith(f"{old_name}.safetensors")
|
||||||
assert new_call_path.endswith(f"{new_name}.safetensors")
|
assert new_call_path.endswith(f"{new_name}.safetensors")
|
||||||
assert payload["file_name"] == new_name
|
assert payload["file_name"] == new_name
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_model_updates_update_service(tmp_path: Path):
|
||||||
|
model_path = tmp_path / "sample.safetensors"
|
||||||
|
model_path.write_bytes(b"content")
|
||||||
|
|
||||||
|
other_path = tmp_path / "another.safetensors"
|
||||||
|
other_path.write_bytes(b"other")
|
||||||
|
|
||||||
|
raw_data = [
|
||||||
|
{
|
||||||
|
"file_path": model_path.as_posix(),
|
||||||
|
"civitai": {"modelId": 42, "id": 1001},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"file_path": other_path.as_posix(),
|
||||||
|
"civitai": {"modelId": 42, "id": 1002},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
scanner = VersionAwareScanner(raw_data)
|
||||||
|
metadata_manager = DummyMetadataManager({"civitai": {"modelId": 42, "id": 1001}})
|
||||||
|
|
||||||
|
async def metadata_loader(path: str):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
update_service = DummyUpdateService()
|
||||||
|
service = ModelLifecycleService(
|
||||||
|
scanner=scanner,
|
||||||
|
metadata_manager=metadata_manager,
|
||||||
|
metadata_loader=metadata_loader,
|
||||||
|
update_service=update_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await service.delete_model(model_path.as_posix())
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert not model_path.exists()
|
||||||
|
assert update_service.calls == [("lora", 42, [1002])]
|
||||||
|
|||||||
Reference in New Issue
Block a user