Merge pull request #683 from willmiao/deletion-sync, see #673

feat(model-lifecycle): integrate model update service for deletion sync
This commit is contained in:
pixelpaws
2025-11-18 23:28:17 +08:00
committed by GitHub
3 changed files with 192 additions and 6 deletions

View File

@@ -126,6 +126,7 @@ class BaseModelRoutes(ABC):
metadata_manager=MetadataManager, metadata_manager=MetadataManager,
metadata_loader=self._metadata_sync_service.load_local_metadata, metadata_loader=self._metadata_sync_service.load_local_metadata,
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner, recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
update_service=self._model_update_service,
) )
self._handler_set = None self._handler_set = None
self._handler_mapping = None self._handler_mapping = None
@@ -297,4 +298,3 @@ class BaseModelRoutes(ABC):
if self._model_update_service is None: if self._model_update_service is None:
raise RuntimeError("Model update service has not been attached") raise RuntimeError("Model update service has not been attached")
return self._model_update_service return self._model_update_service

View File

@@ -4,13 +4,16 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import Awaitable, Callable, Dict, Iterable, List, Optional from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional, TYPE_CHECKING
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..utils.constants import PREVIEW_EXTENSIONS from ..utils.constants import PREVIEW_EXTENSIONS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..services.model_update_service import ModelUpdateService
async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]: async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]:
"""Delete the primary model artefacts within ``target_dir``.""" """Delete the primary model artefacts within ``target_dir``."""
@@ -54,6 +57,7 @@ class ModelLifecycleService:
metadata_manager, metadata_manager,
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]], metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
recipe_scanner_factory: Callable[[], Awaitable] | None = None, recipe_scanner_factory: Callable[[], Awaitable] | None = None,
update_service: "ModelUpdateService" | None = None,
) -> None: ) -> None:
self._scanner = scanner self._scanner = scanner
self._metadata_manager = metadata_manager self._metadata_manager = metadata_manager
@@ -61,6 +65,7 @@ class ModelLifecycleService:
self._recipe_scanner_factory = ( self._recipe_scanner_factory = (
recipe_scanner_factory or ServiceRegistry.get_recipe_scanner recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
) )
self._update_service = update_service
async def delete_model(self, file_path: str) -> Dict[str, object]: async def delete_model(self, file_path: str) -> Dict[str, object]:
"""Delete a model file and associated artefacts.""" """Delete a model file and associated artefacts."""
@@ -68,20 +73,100 @@ class ModelLifecycleService:
if not file_path: if not file_path:
raise ValueError("Model path is required") raise ValueError("Model path is required")
cache = await self._scanner.get_cached_data()
cached_entry = None
if cache and hasattr(cache, "raw_data"):
cached_entry = next(
(item for item in cache.raw_data if item.get("file_path") == file_path),
None,
)
metadata_payload = {}
try:
metadata_payload = await self._metadata_manager.load_metadata_payload(file_path)
except Exception as exc: # pragma: no cover - defensive guard
logger.debug("Failed to load metadata payload for %s: %s", file_path, exc)
model_id = (
self._extract_model_id_from_payload(metadata_payload)
or self._extract_model_id_from_payload(cached_entry)
)
target_dir = os.path.dirname(file_path) target_dir = os.path.dirname(file_path)
file_name = os.path.splitext(os.path.basename(file_path))[0] file_name = os.path.splitext(os.path.basename(file_path))[0]
deleted_files = await delete_model_artifacts(target_dir, file_name) deleted_files = await delete_model_artifacts(target_dir, file_name)
cache = await self._scanner.get_cached_data() if cache:
cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path] cache.raw_data = [
await cache.resort() item for item in cache.raw_data if item.get("file_path") != file_path
]
await cache.resort()
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index: if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
self._scanner._hash_index.remove_by_path(file_path) self._scanner._hash_index.remove_by_path(file_path)
await self._sync_update_for_model(model_id)
return {"success": True, "deleted_files": deleted_files} return {"success": True, "deleted_files": deleted_files}
@staticmethod
def _extract_model_id_from_payload(payload: Any) -> Optional[int]:
    """Pull a Civitai model id out of a metadata payload.

    Looks under ``payload["civitai"]`` first (``modelId`` then
    ``model_id``, then ``civitai["model"]["id"]``) and falls back to the
    top-level ``model_id`` / ``civitai_model_id`` keys.

    Returns the id as ``int``, or ``None`` when *payload* is not a
    mapping or no coercible id is present.
    """
    if not isinstance(payload, Mapping):
        return None

    civitai = payload.get("civitai")
    if isinstance(civitai, Mapping):
        # Test each key explicitly against None instead of chaining with
        # `or`, so a literal 0 id is not treated as "missing".
        candidate = civitai.get("modelId")
        if candidate is None:
            candidate = civitai.get("model_id")
        if candidate is None:
            model_section = civitai.get("model")
            if isinstance(model_section, Mapping):
                candidate = model_section.get("id")
        normalized = ModelLifecycleService._coerce_int(candidate)
        if normalized is not None:
            return normalized

    # Top-level fallbacks, again with None checks rather than truthiness.
    fallback = payload.get("model_id")
    if fallback is None:
        fallback = payload.get("civitai_model_id")
    return ModelLifecycleService._coerce_int(fallback)
@staticmethod
def _coerce_int(value: Any) -> Optional[int]:
    """Return ``int(value)`` when conversion succeeds, otherwise ``None``."""
    try:
        coerced = int(value)
    except (TypeError, ValueError):
        # Covers None, non-numeric strings, and other unconvertible types.
        return None
    return coerced
async def _sync_update_for_model(self, model_id: Optional[int]) -> None:
    """Push the set of still-installed version ids for *model_id* to the
    attached update service.

    No-op when no update service is attached or *model_id* is unknown.
    Both the version lookup and the service call are best-effort: any
    exception is logged at debug level and swallowed so that a failed
    sync never breaks the deletion itself.
    """
    if model_id is None or self._update_service is None:
        return

    try:
        local_versions = await self._scanner.get_model_versions_by_id(model_id)
    except Exception as exc:  # pragma: no cover - defensive log
        logger.debug(
            "Failed to collect local versions for model %s: %s", model_id, exc
        )
        local_versions = []

    # Collect the distinct integer version ids from whichever key each
    # record happens to use.
    remaining: set = set()
    for record in local_versions or []:
        raw_id = (
            record.get("versionId")
            or record.get("id")
            or record.get("version_id")
        )
        as_int = ModelLifecycleService._coerce_int(raw_id)
        if as_int is not None:
            remaining.add(as_int)

    try:
        await self._update_service.update_in_library_versions(
            self._scanner.model_type,
            model_id,
            sorted(remaining),
        )
    except Exception as exc:  # pragma: no cover - defensive log
        logger.debug(
            "Failed to sync update record for model %s: %s", model_id, exc
        )
async def exclude_model(self, file_path: str) -> Dict[str, object]: async def exclude_model(self, file_path: str) -> Dict[str, object]:
"""Mark a model as excluded and prune cache references.""" """Mark a model as excluded and prune cache references."""

View File

@@ -7,6 +7,67 @@ from py.services.model_lifecycle_service import ModelLifecycleService
from py.utils.metadata_manager import MetadataManager from py.utils.metadata_manager import MetadataManager
class DummyCache:
    """Minimal stand-in for the scanner cache: holds raw entries, no-op resort."""

    def __init__(self, raw_data):
        self.raw_data = raw_data

    async def resort(self):
        # Ordering is irrelevant for these tests; keep entries as-is.
        return None
class DummyHashIndex:
    """Records every removed path so tests can assert on the removals."""

    def __init__(self):
        self.removed = []

    def remove_by_path(self, path, *extra):
        # Extra positional arguments are accepted (and ignored) to stay
        # signature-compatible with the real hash index.
        self.removed.append(path)
class VersionAwareScanner:
    """Fake scanner serving cached entries and grouping versions by model id."""

    def __init__(self, raw_data, model_type="lora"):
        self.model_type = model_type
        self.cache = DummyCache(raw_data)
        self._hash_index = DummyHashIndex()

    async def get_cached_data(self):
        return self.cache

    async def get_model_versions_by_id(self, model_id):
        """Return ``{"versionId": ...}`` records for cache entries whose
        civitai ``modelId`` matches *model_id*."""
        matches = []
        for entry in self.cache.raw_data:
            civitai = entry.get("civitai")
            if not isinstance(civitai, dict):
                continue
            try:
                entry_model_id = int(civitai.get("modelId"))
            except (TypeError, ValueError):
                # Missing or non-numeric modelId: not a match.
                continue
            if entry_model_id != model_id:
                continue
            version_id = civitai.get("id")
            if version_id is not None:
                matches.append({"versionId": version_id})
        return matches
class DummyMetadataManager:
    """Serves a fixed metadata payload, handing out a fresh copy per call."""

    def __init__(self, payload):
        # Copy on construction so later mutation of the caller's dict
        # cannot leak into the served payload.
        self._payload = dict(payload)

    async def load_metadata_payload(self, file_path: str):
        # Fresh copy per call so callers may freely mutate the result.
        return dict(self._payload)
class DummyUpdateService:
    """Captures ``update_in_library_versions`` invocations for assertions."""

    def __init__(self):
        self.calls = []

    async def update_in_library_versions(self, model_type, model_id, version_ids):
        self.calls.append((model_type, model_id, version_ids))
class DummyScanner: class DummyScanner:
def __init__(self): def __init__(self):
self.calls = [] self.calls = []
@@ -81,3 +142,43 @@ async def test_rename_model_preserves_compound_extensions(tmp_path: Path):
assert old_call_path.endswith(f"{old_name}.safetensors") assert old_call_path.endswith(f"{old_name}.safetensors")
assert new_call_path.endswith(f"{new_name}.safetensors") assert new_call_path.endswith(f"{new_name}.safetensors")
assert payload["file_name"] == new_name assert payload["file_name"] == new_name
@pytest.mark.asyncio
async def test_delete_model_updates_update_service(tmp_path: Path):
    """Deleting a model pushes the surviving version ids to the update service."""
    doomed = tmp_path / "sample.safetensors"
    doomed.write_bytes(b"content")
    survivor = tmp_path / "another.safetensors"
    survivor.write_bytes(b"other")

    # Two local files share model 42: versions 1001 (deleted) and 1002.
    entries = [
        {
            "file_path": doomed.as_posix(),
            "civitai": {"modelId": 42, "id": 1001},
        },
        {
            "file_path": survivor.as_posix(),
            "civitai": {"modelId": 42, "id": 1002},
        },
    ]

    scanner = VersionAwareScanner(entries)
    update_service = DummyUpdateService()

    async def metadata_loader(path: str):
        return {}

    service = ModelLifecycleService(
        scanner=scanner,
        metadata_manager=DummyMetadataManager(
            {"civitai": {"modelId": 42, "id": 1001}}
        ),
        metadata_loader=metadata_loader,
        update_service=update_service,
    )

    result = await service.delete_model(doomed.as_posix())

    assert result["success"] is True
    assert not doomed.exists()
    # Only the surviving file's version (1002) should remain registered.
    assert update_service.calls == [("lora", 42, [1002])]