diff --git a/py/services/model_lifecycle_service.py b/py/services/model_lifecycle_service.py index dd67226c..88995b5f 100644 --- a/py/services/model_lifecycle_service.py +++ b/py/services/model_lifecycle_service.py @@ -15,18 +15,18 @@ if TYPE_CHECKING: from ..services.model_update_service import ModelUpdateService -async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]: +async def delete_model_artifacts( + target_dir: str, file_name: str, main_extension: str | None = None +) -> List[str]: """Delete the primary model artefacts within ``target_dir``.""" - patterns = [ - f"{file_name}.safetensors", - f"{file_name}.metadata.json", - ] + main_extension = ".safetensors" if main_extension is None else main_extension + main_file = f"{file_name}{main_extension}" if main_extension else file_name + patterns = [main_file, f"{file_name}.metadata.json"] for ext in PREVIEW_EXTENSIONS: patterns.append(f"{file_name}{ext}") deleted: List[str] = [] - main_file = patterns[0] main_path = os.path.join(target_dir, main_file).replace(os.sep, "/") if os.path.exists(main_path): @@ -94,8 +94,11 @@ class ModelLifecycleService: ) target_dir = os.path.dirname(file_path) - file_name = os.path.splitext(os.path.basename(file_path))[0] - deleted_files = await delete_model_artifacts(target_dir, file_name) + base_name = os.path.basename(file_path) + file_name, main_extension = os.path.splitext(base_name) + deleted_files = await delete_model_artifacts( + target_dir, file_name, main_extension=main_extension + ) if cache: cache.raw_data = [ diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 8acaff4b..80c74e62 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -1444,11 +1444,13 @@ class ModelScanner: for file_path in file_paths: try: target_dir = os.path.dirname(file_path) - file_name = os.path.splitext(os.path.basename(file_path))[0] - + base_name = os.path.basename(file_path) + file_name, main_extension = os.path.splitext(base_name) + deleted_files = await delete_model_artifacts( target_dir, - file_name + file_name, + main_extension=main_extension, ) if deleted_files: diff --git a/tests/services/test_model_lifecycle_service.py b/tests/services/test_model_lifecycle_service.py index e29bfc0b..fa677e2b 100644 --- a/tests/services/test_model_lifecycle_service.py +++ b/tests/services/test_model_lifecycle_service.py @@ -182,3 +182,42 @@ async def test_delete_model_updates_update_service(tmp_path: Path): assert result["success"] is True assert not model_path.exists() assert update_service.calls == [("lora", 42, [1002])] + + +@pytest.mark.asyncio +async def test_delete_model_removes_gguf_file(tmp_path: Path): + model_path = tmp_path / "model.gguf" + model_path.write_bytes(b"content") + + metadata_path = tmp_path / "model.metadata.json" + metadata_path.write_text(json.dumps({})) + + preview_path = tmp_path / "model.preview.png" + preview_path.write_bytes(b"preview") + + raw_data = [ + { + "file_path": model_path.as_posix(), + "civitai": {"modelId": 1, "id": 10}, + } + ] + + scanner = VersionAwareScanner(raw_data) + metadata_manager = DummyMetadataManager({"civitai": {"modelId": 1, "id": 10}}) + + async def metadata_loader(path: str): + return {} + + service = ModelLifecycleService( + scanner=scanner, + metadata_manager=metadata_manager, + metadata_loader=metadata_loader, + ) + + result = await service.delete_model(model_path.as_posix()) + + assert result["success"] is True + assert not model_path.exists() + assert not metadata_path.exists() + assert not preview_path.exists() + assert any(item.endswith("model.gguf") for item in result["deleted_files"])