Merge pull request #683 from willmiao/deletion-sync, see #673

feat(model-lifecycle): integrate model update service for deletion sync
This commit is contained in:
pixelpaws
2025-11-18 23:28:17 +08:00
committed by GitHub
3 changed files with 192 additions and 6 deletions

View File

@@ -126,6 +126,7 @@ class BaseModelRoutes(ABC):
metadata_manager=MetadataManager,
metadata_loader=self._metadata_sync_service.load_local_metadata,
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
update_service=self._model_update_service,
)
self._handler_set = None
self._handler_mapping = None
@@ -297,4 +298,3 @@ class BaseModelRoutes(ABC):
if self._model_update_service is None:
raise RuntimeError("Model update service has not been attached")
return self._model_update_service

View File

@@ -4,13 +4,16 @@ from __future__ import annotations
import logging
import os
from typing import Awaitable, Callable, Dict, Iterable, List, Optional
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional, TYPE_CHECKING
from ..services.service_registry import ServiceRegistry
from ..utils.constants import PREVIEW_EXTENSIONS
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..services.model_update_service import ModelUpdateService
async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]:
"""Delete the primary model artefacts within ``target_dir``."""
@@ -54,6 +57,7 @@ class ModelLifecycleService:
metadata_manager,
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
recipe_scanner_factory: Callable[[], Awaitable] | None = None,
update_service: "ModelUpdateService" | None = None,
) -> None:
self._scanner = scanner
self._metadata_manager = metadata_manager
@@ -61,6 +65,7 @@ class ModelLifecycleService:
self._recipe_scanner_factory = (
recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
)
self._update_service = update_service
async def delete_model(self, file_path: str) -> Dict[str, object]:
"""Delete a model file and associated artefacts."""
@@ -68,20 +73,100 @@ class ModelLifecycleService:
if not file_path:
raise ValueError("Model path is required")
cache = await self._scanner.get_cached_data()
cached_entry = None
if cache and hasattr(cache, "raw_data"):
cached_entry = next(
(item for item in cache.raw_data if item.get("file_path") == file_path),
None,
)
metadata_payload = {}
try:
metadata_payload = await self._metadata_manager.load_metadata_payload(file_path)
except Exception as exc: # pragma: no cover - defensive guard
logger.debug("Failed to load metadata payload for %s: %s", file_path, exc)
model_id = (
self._extract_model_id_from_payload(metadata_payload)
or self._extract_model_id_from_payload(cached_entry)
)
target_dir = os.path.dirname(file_path)
file_name = os.path.splitext(os.path.basename(file_path))[0]
deleted_files = await delete_model_artifacts(target_dir, file_name)
cache = await self._scanner.get_cached_data()
cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path]
await cache.resort()
if cache:
cache.raw_data = [
item for item in cache.raw_data if item.get("file_path") != file_path
]
await cache.resort()
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
self._scanner._hash_index.remove_by_path(file_path)
await self._sync_update_for_model(model_id)
return {"success": True, "deleted_files": deleted_files}
@staticmethod
def _extract_model_id_from_payload(payload: Any) -> Optional[int]:
if not isinstance(payload, Mapping):
return None
civitai = payload.get("civitai")
if isinstance(civitai, Mapping):
candidate = civitai.get("modelId") or civitai.get("model_id")
if candidate is None:
model_section = civitai.get("model")
if isinstance(model_section, Mapping):
candidate = model_section.get("id")
normalized = ModelLifecycleService._coerce_int(candidate)
if normalized is not None:
return normalized
fallback = payload.get("model_id") or payload.get("civitai_model_id")
return ModelLifecycleService._coerce_int(fallback)
@staticmethod
def _coerce_int(value: Any) -> Optional[int]:
try:
return int(value)
except (TypeError, ValueError):
return None
async def _sync_update_for_model(self, model_id: Optional[int]) -> None:
if self._update_service is None or model_id is None:
return
try:
versions = await self._scanner.get_model_versions_by_id(model_id)
except Exception as exc: # pragma: no cover - defensive log
logger.debug(
"Failed to collect local versions for model %s: %s", model_id, exc
)
versions = []
version_ids = set()
for version in versions or []:
candidate = (
version.get("versionId")
or version.get("id")
or version.get("version_id")
)
normalized = ModelLifecycleService._coerce_int(candidate)
if normalized is not None:
version_ids.add(normalized)
try:
await self._update_service.update_in_library_versions(
self._scanner.model_type,
model_id,
sorted(version_ids),
)
except Exception as exc: # pragma: no cover - defensive log
logger.debug(
"Failed to sync update record for model %s: %s", model_id, exc
)
async def exclude_model(self, file_path: str) -> Dict[str, object]:
"""Mark a model as excluded and prune cache references."""

View File

@@ -7,6 +7,67 @@ from py.services.model_lifecycle_service import ModelLifecycleService
from py.utils.metadata_manager import MetadataManager
class DummyCache:
def __init__(self, raw_data):
self.raw_data = raw_data
async def resort(self):
return
class DummyHashIndex:
def __init__(self):
self.removed = []
def remove_by_path(self, path, *args):
self.removed.append(path)
class VersionAwareScanner:
def __init__(self, raw_data, model_type="lora"):
self.model_type = model_type
self.cache = DummyCache(raw_data)
self._hash_index = DummyHashIndex()
async def get_cached_data(self):
return self.cache
async def get_model_versions_by_id(self, model_id):
collected = []
for item in self.cache.raw_data:
civitai = item.get("civitai")
if not isinstance(civitai, dict):
continue
candidate = civitai.get("modelId")
try:
normalized = int(candidate)
except (TypeError, ValueError):
continue
if normalized != model_id:
continue
version_id = civitai.get("id")
if version_id is None:
continue
collected.append({"versionId": version_id})
return collected
class DummyMetadataManager:
def __init__(self, payload):
self._payload = dict(payload)
async def load_metadata_payload(self, file_path: str):
return dict(self._payload)
class DummyUpdateService:
def __init__(self):
self.calls = []
async def update_in_library_versions(self, model_type, model_id, version_ids):
self.calls.append((model_type, model_id, version_ids))
class DummyScanner:
def __init__(self):
self.calls = []
@@ -81,3 +142,43 @@ async def test_rename_model_preserves_compound_extensions(tmp_path: Path):
assert old_call_path.endswith(f"{old_name}.safetensors")
assert new_call_path.endswith(f"{new_name}.safetensors")
assert payload["file_name"] == new_name
@pytest.mark.asyncio
async def test_delete_model_updates_update_service(tmp_path: Path):
model_path = tmp_path / "sample.safetensors"
model_path.write_bytes(b"content")
other_path = tmp_path / "another.safetensors"
other_path.write_bytes(b"other")
raw_data = [
{
"file_path": model_path.as_posix(),
"civitai": {"modelId": 42, "id": 1001},
},
{
"file_path": other_path.as_posix(),
"civitai": {"modelId": 42, "id": 1002},
},
]
scanner = VersionAwareScanner(raw_data)
metadata_manager = DummyMetadataManager({"civitai": {"modelId": 42, "id": 1001}})
async def metadata_loader(path: str):
return {}
update_service = DummyUpdateService()
service = ModelLifecycleService(
scanner=scanner,
metadata_manager=metadata_manager,
metadata_loader=metadata_loader,
update_service=update_service,
)
result = await service.delete_model(model_path.as_posix())
assert result["success"] is True
assert not model_path.exists()
assert update_service.calls == [("lora", 42, [1002])]