mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
- Add optional main_extension parameter to delete_model_artifacts function
- Extract file extension from model filename to handle different file types
- Update model scanner to pass file extension when deleting models
- Add test case for GGUF file deletion to ensure proper cleanup
- Maintain backward compatibility with existing safetensors models

This change allows the model lifecycle service to properly delete GGUF model files along with their associated metadata and preview files, expanding support beyond just the safetensors format.
344 lines
12 KiB
Python
344 lines
12 KiB
Python
"""Service routines for model lifecycle mutations."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional, TYPE_CHECKING
|
|
|
|
from ..services.service_registry import ServiceRegistry
|
|
from ..utils.constants import PREVIEW_EXTENSIONS
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
if TYPE_CHECKING:
|
|
from ..services.model_update_service import ModelUpdateService
|
|
|
|
|
|
async def delete_model_artifacts(
    target_dir: str, file_name: str, main_extension: str | None = None
) -> List[str]:
    """Remove a model file and its companion files from ``target_dir``.

    Args:
        target_dir: Directory that holds the model and its companions.
        file_name: Model base name, without extension.
        main_extension: Extension of the primary model file, including the
            leading dot. ``None`` selects ``".safetensors"``; an empty string
            means the primary file is ``file_name`` itself.

    Returns:
        The entries that were removed. The primary file is reported by its
        full path (separators normalised to ``/``); the metadata file and
        previews are reported by bare file name.

    Raises:
        OSError: If the primary model file exists but cannot be removed.
            Failures on companion files are logged and skipped instead.
    """
    if main_extension is None:
        main_extension = ".safetensors"

    if main_extension:
        primary_name = f"{file_name}{main_extension}"
    else:
        primary_name = file_name

    # Companion files: the metadata sidecar followed by every preview variant.
    companions = [
        f"{file_name}.metadata.json",
        *(f"{file_name}{ext}" for ext in PREVIEW_EXTENSIONS),
    ]

    removed: List[str] = []
    primary_path = os.path.join(target_dir, primary_name).replace(os.sep, "/")

    if not os.path.exists(primary_path):
        logger.warning("Model file not found: %s", primary_name)
    else:
        # Deliberately unguarded: a failure on the main file propagates.
        os.remove(primary_path)
        removed.append(primary_path)

    for companion in companions:
        companion_path = os.path.join(target_dir, companion)
        if not os.path.exists(companion_path):
            continue
        try:
            os.remove(companion_path)
        except Exception as exc:  # pragma: no cover - defensive path
            logger.warning("Failed to delete %s: %s", companion, exc)
        else:
            removed.append(companion)

    return removed
|
|
|
|
|
|
class ModelLifecycleService:
    """Co-ordinate destructive and mutating model operations.

    The service combines a scanner (queried for cached data and asked to
    update it), a metadata manager, a metadata loader coroutine and,
    optionally, an update service that is told which versions of a model
    remain after a deletion.
    """

    def __init__(
        self,
        *,
        scanner,
        metadata_manager,
        metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
        recipe_scanner_factory: Optional[Callable[[], Awaitable]] = None,
        update_service: Optional["ModelUpdateService"] = None,
    ) -> None:
        """Store the collaborators used by the lifecycle operations.

        Args:
            scanner: Object providing ``get_cached_data``,
                ``bulk_delete_models``, ``update_single_model_cache``,
                ``get_model_versions_by_id`` and ``model_type``.
            metadata_manager: Object providing ``load_metadata_payload`` and
                ``save_metadata``.
            metadata_loader: Coroutine function that loads a metadata mapping
                from a ``.metadata.json`` path.
            recipe_scanner_factory: Coroutine function returning the recipe
                scanner; defaults to ``ServiceRegistry.get_recipe_scanner``.
            update_service: Optional service whose
                ``update_in_library_versions`` is called after a deletion.
        """
        self._scanner = scanner
        self._metadata_manager = metadata_manager
        self._metadata_loader = metadata_loader
        self._recipe_scanner_factory = (
            recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
        )
        self._update_service = update_service

    async def delete_model(self, file_path: str) -> Dict[str, object]:
        """Delete a model file and associated artefacts.

        Args:
            file_path: Path of the model file to delete.

        Returns:
            ``{"success": True, "deleted_files": [...]}`` where the list is
            what ``delete_model_artifacts`` reported as removed.

        Raises:
            ValueError: If ``file_path`` is empty.
        """
        if not file_path:
            raise ValueError("Model path is required")

        cache = await self._scanner.get_cached_data()

        # Capture the cached entry before deletion: it is the fallback source
        # for the model id if the metadata payload cannot provide one.
        cached_entry = None
        if cache and hasattr(cache, "raw_data"):
            cached_entry = next(
                (item for item in cache.raw_data if item.get("file_path") == file_path),
                None,
            )

        metadata_payload = {}
        try:
            metadata_payload = await self._metadata_manager.load_metadata_payload(
                file_path
            )
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.debug("Failed to load metadata payload for %s: %s", file_path, exc)

        model_id = (
            self._extract_model_id_from_payload(metadata_payload)
            or self._extract_model_id_from_payload(cached_entry)
        )

        # Pass the real extension along so non-safetensors files are removed.
        target_dir = os.path.dirname(file_path)
        base_name = os.path.basename(file_path)
        file_name, main_extension = os.path.splitext(base_name)
        deleted_files = await delete_model_artifacts(
            target_dir, file_name, main_extension=main_extension
        )

        if cache:
            cache.raw_data = [
                item for item in cache.raw_data if item.get("file_path") != file_path
            ]
            await cache.resort()

        if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
            self._scanner._hash_index.remove_by_path(file_path)

        await self._sync_update_for_model(model_id)
        return {"success": True, "deleted_files": deleted_files}

    @staticmethod
    def _extract_model_id_from_payload(payload: Any) -> Optional[int]:
        """Return the model id found in a metadata-like mapping, if any.

        Lookup order: ``civitai.modelId``, ``civitai.model_id``,
        ``civitai.model.id``, then top-level ``model_id`` and
        ``civitai_model_id``. Non-mapping payloads yield ``None``.
        """
        if not isinstance(payload, Mapping):
            return None

        civitai = payload.get("civitai")
        if isinstance(civitai, Mapping):
            candidate = civitai.get("modelId") or civitai.get("model_id")
            if candidate is None:
                model_section = civitai.get("model")
                if isinstance(model_section, Mapping):
                    candidate = model_section.get("id")
            normalized = ModelLifecycleService._coerce_int(candidate)
            if normalized is not None:
                return normalized

        fallback = payload.get("model_id") or payload.get("civitai_model_id")
        return ModelLifecycleService._coerce_int(fallback)

    @staticmethod
    def _coerce_int(value: Any) -> Optional[int]:
        """Convert ``value`` to ``int``, returning ``None`` when impossible."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers float infinities, which JSON payloads
            # parsed by Python's ``json`` module may contain.
            return None

    async def _sync_update_for_model(self, model_id: Optional[int]) -> None:
        """Report the versions of ``model_id`` still known to the scanner.

        Does nothing when no update service is configured or ``model_id`` is
        ``None``. Failures from the scanner or the update service are logged
        at debug level and never raised.
        """
        if self._update_service is None or model_id is None:
            return

        try:
            versions = await self._scanner.get_model_versions_by_id(model_id)
        except Exception as exc:  # pragma: no cover - defensive log
            logger.debug(
                "Failed to collect local versions for model %s: %s", model_id, exc
            )
            versions = []

        version_ids = set()
        for version in versions or []:
            candidate = (
                version.get("versionId")
                or version.get("id")
                or version.get("version_id")
            )
            normalized = ModelLifecycleService._coerce_int(candidate)
            if normalized is not None:
                version_ids.add(normalized)

        try:
            await self._update_service.update_in_library_versions(
                self._scanner.model_type,
                model_id,
                sorted(version_ids),
            )
        except Exception as exc:  # pragma: no cover - defensive log
            logger.debug(
                "Failed to sync update record for model %s: %s", model_id, exc
            )

    async def exclude_model(self, file_path: str) -> Dict[str, object]:
        """Mark a model as excluded and prune cache references.

        Sets ``exclude`` in the model's metadata, saves it, decrements the
        scanner's tag counts for the model's tags, and removes the model from
        the hash index and the cached list.

        Args:
            file_path: Path of the model file to exclude.

        Returns:
            ``{"success": True, "message": ...}``.

        Raises:
            ValueError: If ``file_path`` is empty.
        """
        if not file_path:
            raise ValueError("Model path is required")

        metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
        metadata = await self._metadata_loader(metadata_path)
        metadata["exclude"] = True

        await self._metadata_manager.save_metadata(file_path, metadata)

        cache = await self._scanner.get_cached_data()
        model_to_remove = next(
            (item for item in cache.raw_data if item["file_path"] == file_path),
            None,
        )

        if model_to_remove:
            for tag in model_to_remove.get("tags", []):
                if tag in getattr(self._scanner, "_tags_count", {}):
                    self._scanner._tags_count[tag] = max(
                        0, self._scanner._tags_count[tag] - 1
                    )
                    # Drop tags that no longer belong to any model.
                    if self._scanner._tags_count[tag] == 0:
                        del self._scanner._tags_count[tag]

            if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
                self._scanner._hash_index.remove_by_path(file_path)

            cache.raw_data = [
                item for item in cache.raw_data if item["file_path"] != file_path
            ]
            await cache.resort()

        excluded = getattr(self._scanner, "_excluded_models", None)
        if isinstance(excluded, list):
            excluded.append(file_path)

        message = f"Model {os.path.basename(file_path)} excluded"
        return {"success": True, "message": message}

    async def bulk_delete_models(self, file_paths: Iterable[str]) -> Dict[str, object]:
        """Delete a collection of models via the scanner bulk operation.

        Args:
            file_paths: Paths of the model files to delete.

        Returns:
            Whatever the scanner's ``bulk_delete_models`` returns.

        Raises:
            ValueError: If ``file_paths`` is empty.
        """
        file_paths = list(file_paths)
        if not file_paths:
            raise ValueError("No file paths provided for deletion")

        return await self._scanner.bulk_delete_models(file_paths)

    async def rename_model(
        self, *, file_path: str, new_file_name: str
    ) -> Dict[str, object]:
        """Rename a model and its companion artefacts.

        The model keeps its original extension (``.safetensors`` is assumed
        only when ``file_path`` has none). The metadata file, its ``.bak``
        copy and every preview variant that exists beside the model are
        renamed with it; the metadata contents, the scanner cache and, for
        LoRA scanners, recipe references are then updated.

        Args:
            file_path: Current path of the model file.
            new_file_name: New base name, without directory or extension.

        Returns:
            A dict with ``success``, ``new_file_path``, ``new_preview_path``
            (``None`` when the metadata has no preview), ``renamed_files``
            and ``reload_required``.

        Raises:
            ValueError: If an argument is empty, ``new_file_name`` contains a
                character from the invalid set, or the target model path
                already exists.
        """
        if not file_path or not new_file_name:
            raise ValueError("File path and new file name are required")

        invalid_chars = {"/", "\\", ":", "*", "?", '"', "<", ">", "|"}
        if any(char in new_file_name for char in invalid_chars):
            raise ValueError("Invalid characters in file name")

        target_dir = os.path.dirname(file_path)
        old_file_name, main_extension = os.path.splitext(os.path.basename(file_path))
        if not main_extension:
            # Historical default for paths supplied without an extension.
            main_extension = ".safetensors"

        new_file_path = os.path.join(
            target_dir, f"{new_file_name}{main_extension}"
        ).replace(os.sep, "/")

        if os.path.exists(new_file_path):
            raise ValueError("A file with this name already exists")

        # Every artefact is "<base name><suffix>". Carrying the suffix with
        # each path avoids re-deriving it from the file name, which is
        # ambiguous when the base name itself contains dots. ``dict.fromkeys``
        # removes duplicates while preserving order.
        suffixes = list(
            dict.fromkeys(
                [
                    main_extension,
                    ".metadata.json",
                    ".metadata.json.bak",
                    *PREVIEW_EXTENSIONS,
                ]
            )
        )

        existing_files: List[tuple[str, str]] = []
        for suffix in suffixes:
            path = os.path.join(target_dir, f"{old_file_name}{suffix}")
            if os.path.exists(path):
                existing_files.append((path, suffix))

        metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
        metadata: Optional[Dict[str, object]] = None
        hash_value: Optional[str] = None

        # Load metadata before renaming, while the old path is still valid.
        if os.path.exists(metadata_path):
            metadata = await self._metadata_loader(metadata_path)
            hash_value = metadata.get("sha256") if isinstance(metadata, dict) else None

        renamed_files: List[str] = []
        new_metadata_path: Optional[str] = None
        new_preview: Optional[str] = None

        for old_path, suffix in existing_files:
            new_path = os.path.join(target_dir, f"{new_file_name}{suffix}").replace(
                os.sep, "/"
            )
            os.rename(old_path, new_path)
            renamed_files.append(new_path)

            if suffix == ".metadata.json":
                new_metadata_path = new_path

        if metadata and new_metadata_path:
            metadata["file_name"] = new_file_name
            metadata["file_path"] = new_file_path

            if metadata.get("preview_url"):
                old_preview = str(metadata["preview_url"])
                ext = self._get_multipart_ext(old_preview)
                new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(
                    os.sep, "/"
                )
                metadata["preview_url"] = new_preview

            await self._metadata_manager.save_metadata(new_file_path, metadata)

        if metadata:
            await self._scanner.update_single_model_cache(
                file_path, new_file_path, metadata
            )

            if hash_value and getattr(self._scanner, "model_type", "") == "lora":
                recipe_scanner = await self._recipe_scanner_factory()
                if recipe_scanner:
                    try:
                        await recipe_scanner.update_lora_filename_by_hash(
                            hash_value, new_file_name
                        )
                    except Exception as exc:  # pragma: no cover - defensive logging
                        logger.error(
                            "Error updating recipe references for %s: %s",
                            file_path,
                            exc,
                        )

        return {
            "success": True,
            "new_file_path": new_file_path,
            "new_preview_path": new_preview,
            "renamed_files": renamed_files,
            "reload_required": False,
        }

    @staticmethod
    def _get_multipart_ext(filename: str) -> str:
        """Return the extension for files with compound suffixes.

        The longest matching known suffix (metadata, ``.safetensors`` or a
        preview extension) wins. Otherwise everything from the first dot of
        the base name is returned, or ``""`` when the base name has no dot.
        """
        known_suffixes = [
            ".metadata.json.bak",
            ".metadata.json",
            ".safetensors",
            *PREVIEW_EXTENSIONS,
        ]

        # Longest first so ".metadata.json.bak" is not mistaken for ".bak".
        for suffix in sorted(known_suffixes, key=len, reverse=True):
            if filename.endswith(suffix):
                return suffix

        basename = os.path.basename(filename)
        dot_index = basename.find(".")
        if dot_index != -1:
            return basename[dot_index:]

        return os.path.splitext(basename)[1]
|