mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 05:32:12 -03:00
246 lines
8.8 KiB
Python
246 lines
8.8 KiB
Python
"""Service routines for model lifecycle mutations."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
from typing import Awaitable, Callable, Dict, Iterable, List, Optional
|
|
|
|
from ..services.service_registry import ServiceRegistry
|
|
from ..utils.constants import PREVIEW_EXTENSIONS
|
|
|
|
# Module-level logger, registered under this module's import name so its
# records can be filtered/configured per module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def delete_model_artifacts(
    target_dir: str, file_name: str, main_extension: str = ".safetensors"
) -> List[str]:
    """Delete the primary model artefacts within ``target_dir``.

    Removes ``<file_name><main_extension>`` plus its companion files: the
    ``.metadata.json`` sidecar, its ``.metadata.json.bak`` backup and every
    preview named ``<file_name><ext>`` for ``ext`` in ``PREVIEW_EXTENSIONS``.

    Args:
        target_dir: Directory holding the model and its companion files.
        file_name: Model file name without its extension.
        main_extension: Extension of the primary model file. Defaults to
            ``.safetensors`` for backward compatibility.

    Returns:
        The removed entries, in deletion order. The primary file (when it was
        removed) comes first as a full path normalised to forward slashes;
        companion files are reported by bare file name.

    Raises:
        OSError: If the primary file exists but cannot be removed. Failures on
            companion files are logged and skipped.
    """
    deleted: List[str] = []

    main_file = f"{file_name}{main_extension}"
    main_path = os.path.join(target_dir, main_file).replace(os.sep, "/")
    if os.path.exists(main_path):
        # Deliberately not guarded: failing to remove the model itself must
        # surface to the caller.
        os.remove(main_path)
        deleted.append(main_path)
    else:
        logger.warning("Model file not found: %s", main_file)

    # The ``.bak`` sidecar is included so deleting a model does not leave an
    # orphaned metadata backup behind.
    companions = [
        f"{file_name}.metadata.json",
        f"{file_name}.metadata.json.bak",
        *(f"{file_name}{ext}" for ext in PREVIEW_EXTENSIONS),
    ]

    for companion in companions:
        path = os.path.join(target_dir, companion)
        try:
            os.remove(path)
        except FileNotFoundError:
            # Companion files are optional.
            continue
        except OSError as exc:  # pragma: no cover - defensive path
            logger.warning("Failed to delete %s: %s", companion, exc)
            continue
        deleted.append(companion)

    return deleted
|
|
|
|
|
|
class ModelLifecycleService:
    """Co-ordinate destructive and mutating model operations.

    The service works on top of a model scanner and metadata helpers supplied
    by the caller. Besides the scanner's coroutines (``get_cached_data``,
    ``bulk_delete_models``, ``update_single_model_cache``) it also updates the
    optional scanner attributes ``_hash_index``, ``_tags_count`` and
    ``_excluded_models`` when they are present.
    """

    def __init__(
        self,
        *,
        scanner,
        metadata_manager,
        metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
        recipe_scanner_factory: Callable[[], Awaitable] | None = None,
    ) -> None:
        """Store the collaborators used by the lifecycle operations.

        Args:
            scanner: Model scanner whose cache and indexes are kept in sync.
            metadata_manager: Object whose awaitable
                ``save_metadata(path, metadata)`` persists metadata.
            metadata_loader: Coroutine function loading a metadata file.
            recipe_scanner_factory: Coroutine function returning the recipe
                scanner (or a falsy value). Defaults to
                ``ServiceRegistry.get_recipe_scanner``.
        """
        self._scanner = scanner
        self._metadata_manager = metadata_manager
        self._metadata_loader = metadata_loader
        self._recipe_scanner_factory = (
            recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
        )

    async def delete_model(self, file_path: str) -> Dict[str, object]:
        """Delete a model file and associated artefacts, then prune caches.

        Args:
            file_path: Path of the model file to delete.

        Returns:
            ``{"success": True, "deleted_files": [...]}`` where
            ``deleted_files`` is the list reported by ``delete_model_artifacts``.

        Raises:
            ValueError: If ``file_path`` is empty.
        """
        if not file_path:
            raise ValueError("Model path is required")

        target_dir = os.path.dirname(file_path)
        file_name = os.path.splitext(os.path.basename(file_path))[0]

        # NOTE(review): only the directory and stem are passed, so the helper
        # decides which primary-file extension to remove (``.safetensors`` in
        # this module); models stored as e.g. ``.ckpt`` may keep their main
        # file on disk — confirm.
        deleted_files = await delete_model_artifacts(target_dir, file_name)

        cache = await self._scanner.get_cached_data()
        await self._drop_from_cache(cache, file_path)
        self._remove_from_hash_index(file_path)

        return {"success": True, "deleted_files": deleted_files}

    async def exclude_model(self, file_path: str) -> Dict[str, object]:
        """Mark a model as excluded and prune cache references.

        Sets ``exclude: True`` in the model's metadata and saves it. If the
        model is cached, its tags are un-counted and it is removed from the
        hash index and the cache. The path is then recorded in the scanner's
        ``_excluded_models`` list when that list exists.

        Args:
            file_path: Path of the model file to exclude.

        Returns:
            ``{"success": True, "message": ...}``.

        Raises:
            ValueError: If ``file_path`` is empty.
        """
        if not file_path:
            raise ValueError("Model path is required")

        metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
        metadata = await self._metadata_loader(metadata_path)
        metadata["exclude"] = True
        await self._metadata_manager.save_metadata(file_path, metadata)

        cache = await self._scanner.get_cached_data()
        model_to_remove = next(
            (item for item in cache.raw_data if item["file_path"] == file_path),
            None,
        )

        if model_to_remove:
            # ``tags`` may be present but ``None``; treat that as "no tags".
            self._decrement_tag_counts(model_to_remove.get("tags") or [])
            self._remove_from_hash_index(file_path)
            await self._drop_from_cache(cache, file_path)

        excluded = getattr(self._scanner, "_excluded_models", None)
        # Avoid duplicate entries when the same model is excluded twice.
        if isinstance(excluded, list) and file_path not in excluded:
            excluded.append(file_path)

        message = f"Model {os.path.basename(file_path)} excluded"
        return {"success": True, "message": message}

    async def bulk_delete_models(self, file_paths: Iterable[str]) -> Dict[str, object]:
        """Delete a collection of models via the scanner bulk operation.

        Args:
            file_paths: Paths of the models to delete. They are materialised
                into a list before being handed to the scanner.

        Returns:
            Whatever ``scanner.bulk_delete_models`` returns.

        Raises:
            ValueError: If ``file_paths`` is empty.
        """
        paths = list(file_paths)
        if not paths:
            raise ValueError("No file paths provided for deletion")

        return await self._scanner.bulk_delete_models(paths)

    async def rename_model(
        self, *, file_path: str, new_file_name: str
    ) -> Dict[str, object]:
        """Rename a model and its companion artefacts.

        The primary file keeps its own extension. The ``.metadata.json``
        sidecar, its ``.bak`` backup and every preview sharing the model's
        stem are renamed alongside it. Any saved metadata is rewritten and the
        scanner cache is updated.

        Args:
            file_path: Current path of the model file.
            new_file_name: New base name, without directory or extension.

        Returns:
            A dict with ``success``, ``new_file_path``, ``new_preview_path``
            (``None`` unless the metadata carried a ``preview_url``),
            ``renamed_files`` and ``reload_required``.

        Raises:
            ValueError: If an argument is missing, ``new_file_name`` contains a
                path separator or reserved character, or any destination file
                already exists.
        """
        if not file_path or not new_file_name:
            raise ValueError("File path and new file name are required")

        invalid_chars = {"/", "\\", ":", "*", "?", '"', "<", ">", "|"}
        if any(char in new_file_name for char in invalid_chars):
            raise ValueError("Invalid characters in file name")

        target_dir = os.path.dirname(file_path)
        old_file_name, main_ext = os.path.splitext(os.path.basename(file_path))
        # Keep the model's real extension (e.g. ``.ckpt``). Fall back to the
        # historical ``.safetensors`` for paths without an extension.
        main_ext = main_ext or ".safetensors"
        new_file_path = os.path.join(
            target_dir, f"{new_file_name}{main_ext}"
        ).replace(os.sep, "/")

        if os.path.exists(new_file_path):
            raise ValueError("A file with this name already exists")

        # Rename by exact known suffix instead of re-parsing file names, so
        # dots inside the model's own name (``lora_v1.5``) are never mistaken
        # for part of the extension. dict.fromkeys de-duplicates in order.
        suffixes = dict.fromkeys(
            [main_ext, ".metadata.json", ".metadata.json.bak", *PREVIEW_EXTENSIONS]
        )
        planned: List[tuple[str, str, str]] = []
        for suffix in suffixes:
            old_path = os.path.join(target_dir, f"{old_file_name}{suffix}")
            if os.path.exists(old_path):
                new_path = os.path.join(
                    target_dir, f"{new_file_name}{suffix}"
                ).replace(os.sep, "/")
                planned.append((old_path, new_path, suffix))

        # Check every destination up front. os.rename silently overwrites on
        # POSIX and raises mid-loop on Windows; either would leave the model's
        # artefacts half-renamed.
        if any(os.path.exists(new_path) for _, new_path, _ in planned):
            raise ValueError("A file with this name already exists")

        metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
        metadata: Optional[Dict[str, object]] = None
        hash_value: Optional[str] = None
        if os.path.exists(metadata_path):
            metadata = await self._metadata_loader(metadata_path)
            hash_value = metadata.get("sha256") if isinstance(metadata, dict) else None

        renamed_files: List[str] = []
        new_metadata_path: Optional[str] = None
        for old_path, new_path, suffix in planned:
            os.rename(old_path, new_path)
            renamed_files.append(new_path)
            if suffix == ".metadata.json":
                new_metadata_path = new_path

        new_preview: Optional[str] = None
        if metadata and new_metadata_path:
            metadata["file_name"] = new_file_name
            metadata["file_path"] = new_file_path

            preview_url = metadata.get("preview_url")
            if preview_url:
                preview_ext = self._suffix_after_stem(str(preview_url), old_file_name)
                new_preview = os.path.join(
                    target_dir, f"{new_file_name}{preview_ext}"
                ).replace(os.sep, "/")
                metadata["preview_url"] = new_preview

            await self._metadata_manager.save_metadata(new_file_path, metadata)

        if metadata:
            await self._scanner.update_single_model_cache(
                file_path, new_file_path, metadata
            )

        if hash_value and getattr(self._scanner, "model_type", "") == "lora":
            recipe_scanner = await self._recipe_scanner_factory()
            if recipe_scanner:
                try:
                    await recipe_scanner.update_lora_filename_by_hash(
                        hash_value, new_file_name
                    )
                except Exception as exc:  # pragma: no cover - defensive logging
                    logger.error(
                        "Error updating recipe references for %s: %s",
                        file_path,
                        exc,
                    )

        return {
            "success": True,
            "new_file_path": new_file_path,
            "new_preview_path": new_preview,
            "renamed_files": renamed_files,
            "reload_required": False,
        }

    async def _drop_from_cache(self, cache, file_path: str) -> None:
        """Remove ``file_path`` from ``cache.raw_data`` and re-sort the cache."""
        cache.raw_data = [
            item for item in cache.raw_data if item["file_path"] != file_path
        ]
        await cache.resort()

    def _remove_from_hash_index(self, file_path: str) -> None:
        """Drop ``file_path`` from the scanner's hash index, if it has one."""
        hash_index = getattr(self._scanner, "_hash_index", None)
        if hash_index:
            hash_index.remove_by_path(file_path)

    def _decrement_tag_counts(self, tags: Iterable[str]) -> None:
        """Decrement the scanner's per-tag counters, dropping tags that hit zero."""
        tags_count = getattr(self._scanner, "_tags_count", None)
        if not tags_count:
            return
        for tag in tags:
            if tag in tags_count:
                remaining = max(0, tags_count[tag] - 1)
                if remaining:
                    tags_count[tag] = remaining
                else:
                    del tags_count[tag]

    @classmethod
    def _suffix_after_stem(cls, filename: str, stem: str) -> str:
        """Return the part of ``filename``'s basename that follows ``stem``.

        Falls back to :meth:`_get_multipart_ext` when the basename does not
        start with ``stem`` immediately followed by a dot.
        """
        name = os.path.basename(filename)
        remainder = name[len(stem):]
        if stem and name.startswith(stem) and remainder.startswith("."):
            return remainder
        return cls._get_multipart_ext(name)

    @staticmethod
    def _get_multipart_ext(filename: str) -> str:
        """Return the extension for files with compound suffixes.

        Only the basename is inspected, so dots in directory names cannot leak
        into the result. Known companion suffixes (``.safetensors``, metadata
        and backup sidecars, ``PREVIEW_EXTENSIONS``) are matched first, longest
        first and case-insensitively. This stops dots inside a model's own name
        from being treated as part of the extension. Other names fall back to
        the legacy heuristic: the last two dot-separated segments for
        three-part names, the last three for longer names, otherwise the plain
        extension.
        """
        name = os.path.basename(filename)
        lowered = name.lower()
        known_suffixes = sorted(
            {".safetensors", ".metadata.json", ".metadata.json.bak", *PREVIEW_EXTENSIONS},
            key=len,
            reverse=True,
        )
        for suffix in known_suffixes:
            if len(name) > len(suffix) and lowered.endswith(suffix.lower()):
                return name[-len(suffix):]

        parts = name.split(".")
        if len(parts) == 3:
            return "." + ".".join(parts[-2:])
        if len(parts) >= 4:
            return "." + ".".join(parts[-3:])
        return os.path.splitext(name)[1]
|
|
|