mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #453 from willmiao/codex/evaluate-need-for-further-refactoring
refactor: migrate model lifecycle handlers to dedicated service
This commit is contained in:
@@ -43,7 +43,7 @@ the set and the invariants that must hold after each handler returns.
|
|||||||
| --- | --- | --- | --- |
|
| --- | --- | --- | --- |
|
||||||
| `ModelPageView` | `/{prefix}` | `SettingsManager`, `server_i18n`, Jinja environment, `service.scanner` | Template is rendered with `is_initializing` flag when caches are cold; i18n filter is registered exactly once per environment instance. |
|
| `ModelPageView` | `/{prefix}` | `SettingsManager`, `server_i18n`, Jinja environment, `service.scanner` | Template is rendered with `is_initializing` flag when caches are cold; i18n filter is registered exactly once per environment instance. |
|
||||||
| `ModelListingHandler` | `/api/lm/{prefix}/list` | `service.get_paginated_data`, `service.format_response` | Listings respect pagination query parameters and cap `page_size` at 100; every item is formatted before response. |
|
| `ModelListingHandler` | `/api/lm/{prefix}/list` | `service.get_paginated_data`, `service.format_response` | Listings respect pagination query parameters and cap `page_size` at 100; every item is formatted before response. |
|
||||||
| `ModelManagementHandler` | Mutations (delete, exclude, metadata, preview, tags, rename, bulk delete, duplicate verification) | `ModelRouteUtils`, `MetadataSyncService`, `PreviewAssetService`, `TagUpdateService`, scanner cache/index | Cache state mirrors filesystem changes: deletes prune cache & hash index, preview replacements synchronise metadata and cache NSFW levels, metadata saves trigger cache resort when names change. |
|
| `ModelManagementHandler` | Mutations (delete, exclude, metadata, preview, tags, rename, bulk delete, duplicate verification) | `ModelLifecycleService`, `MetadataSyncService`, `PreviewAssetService`, `TagUpdateService`, scanner cache/index | Cache state mirrors filesystem changes: deletes prune cache & hash index, preview replacements synchronise metadata and cache NSFW levels, metadata saves trigger cache resort when names change. |
|
||||||
| `ModelQueryHandler` | Read-only queries (top tags, folders, duplicates, metadata, URLs) | Service query helpers & scanner cache | Outputs always wrapped in `{"success": True}` when no error; duplicate/filename grouping omits empty entries; invalid parameters (e.g. missing `model_root`) return HTTP 400. |
|
| `ModelQueryHandler` | Read-only queries (top tags, folders, duplicates, metadata, URLs) | Service query helpers & scanner cache | Outputs always wrapped in `{"success": True}` when no error; duplicate/filename grouping omits empty entries; invalid parameters (e.g. missing `model_root`) return HTTP 400. |
|
||||||
| `ModelDownloadHandler` | `/api/lm/download-model`, `/download-model-get`, `/download-progress/{id}`, `/cancel-download-get` | `DownloadModelUseCase`, `DownloadCoordinator`, `WebSocketManager` | Payload validation errors become HTTP 400 without mutating download progress cache; early-access failures surface as HTTP 401; successful downloads cache progress snapshots that back both WebSocket broadcasts and polling endpoints. |
|
| `ModelDownloadHandler` | `/api/lm/download-model`, `/download-model-get`, `/download-progress/{id}`, `/cancel-download-get` | `DownloadModelUseCase`, `DownloadCoordinator`, `WebSocketManager` | Payload validation errors become HTTP 400 without mutating download progress cache; early-access failures surface as HTTP 401; successful downloads cache progress snapshots that back both WebSocket broadcasts and polling endpoints. |
|
||||||
| `ModelCivitaiHandler` | CivitAI metadata routes | `MetadataSyncService`, metadata provider factory, `BulkMetadataRefreshUseCase` | `fetch_all_civitai` streams progress via `WebSocketBroadcastCallback`; version lookups validate model type before returning; local availability fields derive from hash lookups without mutating cache state. |
|
| `ModelCivitaiHandler` | CivitAI metadata routes | `MetadataSyncService`, metadata provider factory, `BulkMetadataRefreshUseCase` | `fetch_all_civitai` streams progress via `WebSocketBroadcastCallback`; version lookups validate model type before returning; local availability fields derive from hash lookups without mutating cache state. |
|
||||||
@@ -69,10 +69,10 @@ collaboration points:
|
|||||||
|
|
||||||
1. **Cache mutations** – Delete, exclude, rename, and bulk delete operations are
|
1. **Cache mutations** – Delete, exclude, rename, and bulk delete operations are
|
||||||
channelled through `ModelManagementHandler`. The handler delegates to
|
channelled through `ModelManagementHandler`. The handler delegates to
|
||||||
`ModelRouteUtils` or `MetadataSyncService`, and the scanner cache is mutated
|
`ModelLifecycleService` or `MetadataSyncService`, and the scanner cache is
|
||||||
in-place before the handler returns. The accompanying tests assert that
|
mutated in-place before the handler returns. The accompanying tests assert
|
||||||
`scanner._cache.raw_data` and `scanner._hash_index` stay in sync after each
|
that `scanner._cache.raw_data` and `scanner._hash_index` stay in sync after
|
||||||
mutation.
|
each mutation.
|
||||||
2. **Preview updates** – `PreviewAssetService.replace_preview` writes the new
|
2. **Preview updates** – `PreviewAssetService.replace_preview` writes the new
|
||||||
asset, `MetadataSyncService` persists the JSON metadata, and
|
asset, `MetadataSyncService` persists the JSON metadata, and
|
||||||
`scanner.update_preview_in_cache` mirrors the change. The handler returns
|
`scanner.update_preview_in_cache` mirrors the change. The handler returns
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from ..services.downloader import get_downloader
|
|||||||
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||||
from ..services.metadata_sync_service import MetadataSyncService
|
from ..services.metadata_sync_service import MetadataSyncService
|
||||||
from ..services.model_file_service import ModelFileService, ModelMoveService
|
from ..services.model_file_service import ModelFileService, ModelMoveService
|
||||||
|
from ..services.model_lifecycle_service import ModelLifecycleService
|
||||||
from ..services.preview_asset_service import PreviewAssetService
|
from ..services.preview_asset_service import PreviewAssetService
|
||||||
from ..services.server_i18n import server_i18n as default_server_i18n
|
from ..services.server_i18n import server_i18n as default_server_i18n
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
@@ -75,6 +76,7 @@ class BaseModelRoutes(ABC):
|
|||||||
|
|
||||||
self.model_file_service: ModelFileService | None = None
|
self.model_file_service: ModelFileService | None = None
|
||||||
self.model_move_service: ModelMoveService | None = None
|
self.model_move_service: ModelMoveService | None = None
|
||||||
|
self.model_lifecycle_service: ModelLifecycleService | None = None
|
||||||
self.websocket_progress_callback = WebSocketProgressCallback()
|
self.websocket_progress_callback = WebSocketProgressCallback()
|
||||||
self.metadata_progress_callback = WebSocketBroadcastCallback()
|
self.metadata_progress_callback = WebSocketBroadcastCallback()
|
||||||
|
|
||||||
@@ -108,6 +110,12 @@ class BaseModelRoutes(ABC):
|
|||||||
self.model_type = service.model_type
|
self.model_type = service.model_type
|
||||||
self.model_file_service = ModelFileService(service.scanner, service.model_type)
|
self.model_file_service = ModelFileService(service.scanner, service.model_type)
|
||||||
self.model_move_service = ModelMoveService(service.scanner)
|
self.model_move_service = ModelMoveService(service.scanner)
|
||||||
|
self.model_lifecycle_service = ModelLifecycleService(
|
||||||
|
scanner=service.scanner,
|
||||||
|
metadata_manager=MetadataManager,
|
||||||
|
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
||||||
|
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
||||||
|
)
|
||||||
self._handler_set = None
|
self._handler_set = None
|
||||||
self._handler_mapping = None
|
self._handler_mapping = None
|
||||||
|
|
||||||
@@ -139,6 +147,7 @@ class BaseModelRoutes(ABC):
|
|||||||
metadata_sync=self._metadata_sync_service,
|
metadata_sync=self._metadata_sync_service,
|
||||||
preview_service=self._preview_service,
|
preview_service=self._preview_service,
|
||||||
tag_update_service=self._tag_update_service,
|
tag_update_service=self._tag_update_service,
|
||||||
|
lifecycle_service=self._ensure_lifecycle_service(),
|
||||||
)
|
)
|
||||||
query = ModelQueryHandler(service=service, logger=logger)
|
query = ModelQueryHandler(service=service, logger=logger)
|
||||||
download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator)
|
download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator)
|
||||||
@@ -248,6 +257,17 @@ class BaseModelRoutes(ABC):
|
|||||||
self.model_move_service = ModelMoveService(service.scanner)
|
self.model_move_service = ModelMoveService(service.scanner)
|
||||||
return self.model_move_service
|
return self.model_move_service
|
||||||
|
|
||||||
|
def _ensure_lifecycle_service(self) -> ModelLifecycleService:
|
||||||
|
if self.model_lifecycle_service is None:
|
||||||
|
service = self._ensure_service()
|
||||||
|
self.model_lifecycle_service = ModelLifecycleService(
|
||||||
|
scanner=service.scanner,
|
||||||
|
metadata_manager=MetadataManager,
|
||||||
|
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
||||||
|
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
||||||
|
)
|
||||||
|
return self.model_lifecycle_service
|
||||||
|
|
||||||
def _make_handler_proxy(self, name: str) -> Callable[[web.Request], web.StreamResponse]:
|
def _make_handler_proxy(self, name: str) -> Callable[[web.Request], web.StreamResponse]:
|
||||||
async def proxy(request: web.Request) -> web.StreamResponse:
|
async def proxy(request: web.Request) -> web.StreamResponse:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ from ...services.use_cases import (
|
|||||||
from ...services.websocket_manager import WebSocketManager
|
from ...services.websocket_manager import WebSocketManager
|
||||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||||
from ...utils.file_utils import calculate_sha256
|
from ...utils.file_utils import calculate_sha256
|
||||||
from ...utils.routes_common import ModelRouteUtils
|
|
||||||
|
|
||||||
|
|
||||||
class ModelPageView:
|
class ModelPageView:
|
||||||
@@ -192,18 +191,44 @@ class ModelManagementHandler:
|
|||||||
metadata_sync: MetadataSyncService,
|
metadata_sync: MetadataSyncService,
|
||||||
preview_service: PreviewAssetService,
|
preview_service: PreviewAssetService,
|
||||||
tag_update_service: TagUpdateService,
|
tag_update_service: TagUpdateService,
|
||||||
|
lifecycle_service,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._service = service
|
self._service = service
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._metadata_sync = metadata_sync
|
self._metadata_sync = metadata_sync
|
||||||
self._preview_service = preview_service
|
self._preview_service = preview_service
|
||||||
self._tag_update_service = tag_update_service
|
self._tag_update_service = tag_update_service
|
||||||
|
self._lifecycle_service = lifecycle_service
|
||||||
|
|
||||||
async def delete_model(self, request: web.Request) -> web.Response:
|
async def delete_model(self, request: web.Request) -> web.Response:
|
||||||
return await ModelRouteUtils.handle_delete_model(request, self._service.scanner)
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_path = data.get("file_path")
|
||||||
|
if not file_path:
|
||||||
|
return web.Response(text="Model path is required", status=400)
|
||||||
|
|
||||||
|
result = await self._lifecycle_service.delete_model(file_path)
|
||||||
|
return web.json_response(result)
|
||||||
|
except ValueError as exc:
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error deleting model: %s", exc, exc_info=True)
|
||||||
|
return web.Response(text=str(exc), status=500)
|
||||||
|
|
||||||
async def exclude_model(self, request: web.Request) -> web.Response:
|
async def exclude_model(self, request: web.Request) -> web.Response:
|
||||||
return await ModelRouteUtils.handle_exclude_model(request, self._service.scanner)
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_path = data.get("file_path")
|
||||||
|
if not file_path:
|
||||||
|
return web.Response(text="Model path is required", status=400)
|
||||||
|
|
||||||
|
result = await self._lifecycle_service.exclude_model(file_path)
|
||||||
|
return web.json_response(result)
|
||||||
|
except ValueError as exc:
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error excluding model: %s", exc, exc_info=True)
|
||||||
|
return web.Response(text=str(exc), status=500)
|
||||||
|
|
||||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||||
try:
|
try:
|
||||||
@@ -375,10 +400,58 @@ class ModelManagementHandler:
|
|||||||
return web.Response(text=str(exc), status=500)
|
return web.Response(text=str(exc), status=500)
|
||||||
|
|
||||||
async def rename_model(self, request: web.Request) -> web.Response:
|
async def rename_model(self, request: web.Request) -> web.Response:
|
||||||
return await ModelRouteUtils.handle_rename_model(request, self._service.scanner)
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_path = data.get("file_path")
|
||||||
|
new_file_name = data.get("new_file_name")
|
||||||
|
|
||||||
|
if not file_path or not new_file_name:
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "File path and new file name are required",
|
||||||
|
},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self._lifecycle_service.rename_model(
|
||||||
|
file_path=file_path, new_file_name=new_file_name
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
**result,
|
||||||
|
"new_preview_path": config.get_preview_static_url(
|
||||||
|
result.get("new_preview_path")
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error renaming model: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
async def bulk_delete_models(self, request: web.Request) -> web.Response:
|
async def bulk_delete_models(self, request: web.Request) -> web.Response:
|
||||||
return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner)
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_paths = data.get("file_paths", [])
|
||||||
|
if not file_paths:
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "No file paths provided for deletion",
|
||||||
|
},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self._lifecycle_service.bulk_delete_models(file_paths)
|
||||||
|
return web.json_response(result)
|
||||||
|
except ValueError as exc:
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error in bulk delete: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
||||||
try:
|
try:
|
||||||
|
|||||||
245
py/services/model_lifecycle_service.py
Normal file
245
py/services/model_lifecycle_service.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
"""Service routines for model lifecycle mutations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Awaitable, Callable, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]:
|
||||||
|
"""Delete the primary model artefacts within ``target_dir``."""
|
||||||
|
|
||||||
|
patterns = [
|
||||||
|
f"{file_name}.safetensors",
|
||||||
|
f"{file_name}.metadata.json",
|
||||||
|
]
|
||||||
|
for ext in PREVIEW_EXTENSIONS:
|
||||||
|
patterns.append(f"{file_name}{ext}")
|
||||||
|
|
||||||
|
deleted: List[str] = []
|
||||||
|
main_file = patterns[0]
|
||||||
|
main_path = os.path.join(target_dir, main_file).replace(os.sep, "/")
|
||||||
|
|
||||||
|
if os.path.exists(main_path):
|
||||||
|
os.remove(main_path)
|
||||||
|
deleted.append(main_path)
|
||||||
|
else:
|
||||||
|
logger.warning("Model file not found: %s", main_file)
|
||||||
|
|
||||||
|
for pattern in patterns[1:]:
|
||||||
|
path = os.path.join(target_dir, pattern)
|
||||||
|
if os.path.exists(path):
|
||||||
|
try:
|
||||||
|
os.remove(path)
|
||||||
|
deleted.append(pattern)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive path
|
||||||
|
logger.warning("Failed to delete %s: %s", pattern, exc)
|
||||||
|
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLifecycleService:
|
||||||
|
"""Co-ordinate destructive and mutating model operations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
scanner,
|
||||||
|
metadata_manager,
|
||||||
|
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||||
|
recipe_scanner_factory: Callable[[], Awaitable] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._scanner = scanner
|
||||||
|
self._metadata_manager = metadata_manager
|
||||||
|
self._metadata_loader = metadata_loader
|
||||||
|
self._recipe_scanner_factory = (
|
||||||
|
recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_model(self, file_path: str) -> Dict[str, object]:
|
||||||
|
"""Delete a model file and associated artefacts."""
|
||||||
|
|
||||||
|
if not file_path:
|
||||||
|
raise ValueError("Model path is required")
|
||||||
|
|
||||||
|
target_dir = os.path.dirname(file_path)
|
||||||
|
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
|
||||||
|
deleted_files = await delete_model_artifacts(target_dir, file_name)
|
||||||
|
|
||||||
|
cache = await self._scanner.get_cached_data()
|
||||||
|
cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path]
|
||||||
|
await cache.resort()
|
||||||
|
|
||||||
|
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
||||||
|
self._scanner._hash_index.remove_by_path(file_path)
|
||||||
|
|
||||||
|
return {"success": True, "deleted_files": deleted_files}
|
||||||
|
|
||||||
|
async def exclude_model(self, file_path: str) -> Dict[str, object]:
|
||||||
|
"""Mark a model as excluded and prune cache references."""
|
||||||
|
|
||||||
|
if not file_path:
|
||||||
|
raise ValueError("Model path is required")
|
||||||
|
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||||
|
metadata = await self._metadata_loader(metadata_path)
|
||||||
|
metadata["exclude"] = True
|
||||||
|
|
||||||
|
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||||
|
|
||||||
|
cache = await self._scanner.get_cached_data()
|
||||||
|
model_to_remove = next(
|
||||||
|
(item for item in cache.raw_data if item["file_path"] == file_path),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_to_remove:
|
||||||
|
for tag in model_to_remove.get("tags", []):
|
||||||
|
if tag in getattr(self._scanner, "_tags_count", {}):
|
||||||
|
self._scanner._tags_count[tag] = max(
|
||||||
|
0, self._scanner._tags_count[tag] - 1
|
||||||
|
)
|
||||||
|
if self._scanner._tags_count[tag] == 0:
|
||||||
|
del self._scanner._tags_count[tag]
|
||||||
|
|
||||||
|
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
||||||
|
self._scanner._hash_index.remove_by_path(file_path)
|
||||||
|
|
||||||
|
cache.raw_data = [
|
||||||
|
item for item in cache.raw_data if item["file_path"] != file_path
|
||||||
|
]
|
||||||
|
await cache.resort()
|
||||||
|
|
||||||
|
excluded = getattr(self._scanner, "_excluded_models", None)
|
||||||
|
if isinstance(excluded, list):
|
||||||
|
excluded.append(file_path)
|
||||||
|
|
||||||
|
message = f"Model {os.path.basename(file_path)} excluded"
|
||||||
|
return {"success": True, "message": message}
|
||||||
|
|
||||||
|
async def bulk_delete_models(self, file_paths: Iterable[str]) -> Dict[str, object]:
|
||||||
|
"""Delete a collection of models via the scanner bulk operation."""
|
||||||
|
|
||||||
|
file_paths = list(file_paths)
|
||||||
|
if not file_paths:
|
||||||
|
raise ValueError("No file paths provided for deletion")
|
||||||
|
|
||||||
|
return await self._scanner.bulk_delete_models(file_paths)
|
||||||
|
|
||||||
|
async def rename_model(
|
||||||
|
self, *, file_path: str, new_file_name: str
|
||||||
|
) -> Dict[str, object]:
|
||||||
|
"""Rename a model and its companion artefacts."""
|
||||||
|
|
||||||
|
if not file_path or not new_file_name:
|
||||||
|
raise ValueError("File path and new file name are required")
|
||||||
|
|
||||||
|
invalid_chars = {"/", "\\", ":", "*", "?", '"', "<", ">", "|"}
|
||||||
|
if any(char in new_file_name for char in invalid_chars):
|
||||||
|
raise ValueError("Invalid characters in file name")
|
||||||
|
|
||||||
|
target_dir = os.path.dirname(file_path)
|
||||||
|
old_file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace(
|
||||||
|
os.sep, "/"
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.exists(new_file_path):
|
||||||
|
raise ValueError("A file with this name already exists")
|
||||||
|
|
||||||
|
patterns = [
|
||||||
|
f"{old_file_name}.safetensors",
|
||||||
|
f"{old_file_name}.metadata.json",
|
||||||
|
f"{old_file_name}.metadata.json.bak",
|
||||||
|
]
|
||||||
|
for ext in PREVIEW_EXTENSIONS:
|
||||||
|
patterns.append(f"{old_file_name}{ext}")
|
||||||
|
|
||||||
|
existing_files: List[tuple[str, str]] = []
|
||||||
|
for pattern in patterns:
|
||||||
|
path = os.path.join(target_dir, pattern)
|
||||||
|
if os.path.exists(path):
|
||||||
|
existing_files.append((path, pattern))
|
||||||
|
|
||||||
|
metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
|
||||||
|
metadata: Optional[Dict[str, object]] = None
|
||||||
|
hash_value: Optional[str] = None
|
||||||
|
|
||||||
|
if os.path.exists(metadata_path):
|
||||||
|
metadata = await self._metadata_loader(metadata_path)
|
||||||
|
hash_value = metadata.get("sha256") if isinstance(metadata, dict) else None
|
||||||
|
|
||||||
|
renamed_files: List[str] = []
|
||||||
|
new_metadata_path: Optional[str] = None
|
||||||
|
new_preview: Optional[str] = None
|
||||||
|
|
||||||
|
for old_path, pattern in existing_files:
|
||||||
|
ext = self._get_multipart_ext(pattern)
|
||||||
|
new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(
|
||||||
|
os.sep, "/"
|
||||||
|
)
|
||||||
|
os.rename(old_path, new_path)
|
||||||
|
renamed_files.append(new_path)
|
||||||
|
|
||||||
|
if ext == ".metadata.json":
|
||||||
|
new_metadata_path = new_path
|
||||||
|
|
||||||
|
if metadata and new_metadata_path:
|
||||||
|
metadata["file_name"] = new_file_name
|
||||||
|
metadata["file_path"] = new_file_path
|
||||||
|
|
||||||
|
if metadata.get("preview_url"):
|
||||||
|
old_preview = str(metadata["preview_url"])
|
||||||
|
ext = self._get_multipart_ext(old_preview)
|
||||||
|
new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(
|
||||||
|
os.sep, "/"
|
||||||
|
)
|
||||||
|
metadata["preview_url"] = new_preview
|
||||||
|
|
||||||
|
await self._metadata_manager.save_metadata(new_file_path, metadata)
|
||||||
|
|
||||||
|
if metadata:
|
||||||
|
await self._scanner.update_single_model_cache(
|
||||||
|
file_path, new_file_path, metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
if hash_value and getattr(self._scanner, "model_type", "") == "lora":
|
||||||
|
recipe_scanner = await self._recipe_scanner_factory()
|
||||||
|
if recipe_scanner:
|
||||||
|
try:
|
||||||
|
await recipe_scanner.update_lora_filename_by_hash(
|
||||||
|
hash_value, new_file_name
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error(
|
||||||
|
"Error updating recipe references for %s: %s",
|
||||||
|
file_path,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"new_file_path": new_file_path,
|
||||||
|
"new_preview_path": new_preview,
|
||||||
|
"renamed_files": renamed_files,
|
||||||
|
"reload_required": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_multipart_ext(filename: str) -> str:
|
||||||
|
"""Return the extension for files with compound suffixes."""
|
||||||
|
|
||||||
|
parts = filename.split(".")
|
||||||
|
if len(parts) == 3:
|
||||||
|
return "." + ".".join(parts[-2:])
|
||||||
|
if len(parts) >= 4:
|
||||||
|
return "." + ".".join(parts[-3:])
|
||||||
|
return os.path.splitext(filename)[1]
|
||||||
|
|
||||||
@@ -13,6 +13,7 @@ from ..utils.metadata_manager import MetadataManager
|
|||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .model_hash_index import ModelHashIndex
|
from .model_hash_index import ModelHashIndex
|
||||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||||
|
from .model_lifecycle_service import delete_model_artifacts
|
||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
from .websocket_manager import ws_manager
|
from .websocket_manager import ws_manager
|
||||||
|
|
||||||
@@ -1040,10 +1041,8 @@ class ModelScanner:
|
|||||||
target_dir = os.path.dirname(file_path)
|
target_dir = os.path.dirname(file_path)
|
||||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
|
||||||
# Delete all associated files for the model
|
deleted_files = await delete_model_artifacts(
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
target_dir,
|
||||||
deleted_files = await ModelRouteUtils.delete_model_files(
|
|
||||||
target_dir,
|
|
||||||
file_name
|
file_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Callable, Awaitable
|
from typing import Dict, Callable, Awaitable
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -284,104 +284,6 @@ class ModelRouteUtils:
|
|||||||
]
|
]
|
||||||
return {k: data[k] for k in fields if k in data}
|
return {k: data[k] for k in fields if k in data}
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def delete_model_files(target_dir: str, file_name: str) -> List[str]:
|
|
||||||
"""Delete model and associated files
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_dir: Directory containing the model files
|
|
||||||
file_name: Base name of the model file without extension
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of deleted file paths
|
|
||||||
"""
|
|
||||||
patterns = [
|
|
||||||
f"{file_name}.safetensors", # Required
|
|
||||||
f"{file_name}.metadata.json",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Add all preview file extensions
|
|
||||||
for ext in PREVIEW_EXTENSIONS:
|
|
||||||
patterns.append(f"{file_name}{ext}")
|
|
||||||
|
|
||||||
deleted = []
|
|
||||||
main_file = patterns[0]
|
|
||||||
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
|
|
||||||
|
|
||||||
if os.path.exists(main_path):
|
|
||||||
# Delete file
|
|
||||||
os.remove(main_path)
|
|
||||||
deleted.append(main_path)
|
|
||||||
else:
|
|
||||||
logger.warning(f"Model file not found: {main_file}")
|
|
||||||
|
|
||||||
# Delete optional files
|
|
||||||
for pattern in patterns[1:]:
|
|
||||||
path = os.path.join(target_dir, pattern)
|
|
||||||
if os.path.exists(path):
|
|
||||||
try:
|
|
||||||
os.remove(path)
|
|
||||||
deleted.append(pattern)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to delete {pattern}: {e}")
|
|
||||||
|
|
||||||
return deleted
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_multipart_ext(filename):
|
|
||||||
"""Get extension that may have multiple parts like .metadata.json or .metadata.json.bak"""
|
|
||||||
parts = filename.split(".")
|
|
||||||
if len(parts) == 3: # If contains 2-part extension
|
|
||||||
return "." + ".".join(parts[-2:]) # Take the last two parts, like ".metadata.json"
|
|
||||||
elif len(parts) >= 4: # If contains 3-part or more extensions
|
|
||||||
return "." + ".".join(parts[-3:]) # Take the last three parts, like ".metadata.json.bak"
|
|
||||||
return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors"
|
|
||||||
|
|
||||||
# New common endpoint handlers
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def handle_delete_model(request: web.Request, scanner) -> web.Response:
|
|
||||||
"""Handle model deletion request
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The aiohttp request
|
|
||||||
scanner: The model scanner instance with cache management methods
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
web.Response: The HTTP response
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
data = await request.json()
|
|
||||||
file_path = data.get('file_path')
|
|
||||||
if not file_path:
|
|
||||||
return web.Response(text='Model path is required', status=400)
|
|
||||||
|
|
||||||
target_dir = os.path.dirname(file_path)
|
|
||||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
|
||||||
|
|
||||||
deleted_files = await ModelRouteUtils.delete_model_files(
|
|
||||||
target_dir,
|
|
||||||
file_name
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remove from cache
|
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path]
|
|
||||||
await cache.resort()
|
|
||||||
|
|
||||||
# Update hash index if available
|
|
||||||
if hasattr(scanner, '_hash_index') and scanner._hash_index:
|
|
||||||
scanner._hash_index.remove_by_path(file_path)
|
|
||||||
|
|
||||||
return web.json_response({
|
|
||||||
'success': True,
|
|
||||||
'deleted_files': deleted_files
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error deleting model: {e}", exc_info=True)
|
|
||||||
return web.Response(text=str(e), status=500)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_fetch_civitai(request: web.Request, scanner) -> web.Response:
|
async def handle_fetch_civitai(request: web.Request, scanner) -> web.Response:
|
||||||
"""Handle CivitAI metadata fetch request
|
"""Handle CivitAI metadata fetch request
|
||||||
@@ -544,64 +446,6 @@ class ModelRouteUtils:
|
|||||||
logger.error(f"Error replacing preview: {e}", exc_info=True)
|
logger.error(f"Error replacing preview: {e}", exc_info=True)
|
||||||
return web.Response(text=str(e), status=500)
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def handle_exclude_model(request: web.Request, scanner) -> web.Response:
|
|
||||||
"""Handle model exclusion request
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The aiohttp request
|
|
||||||
scanner: The model scanner instance with cache management methods
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
web.Response: The HTTP response
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
data = await request.json()
|
|
||||||
file_path = data.get('file_path')
|
|
||||||
if not file_path:
|
|
||||||
return web.Response(text='Model path is required', status=400)
|
|
||||||
|
|
||||||
# Update metadata to mark as excluded
|
|
||||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
|
||||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
|
||||||
metadata['exclude'] = True
|
|
||||||
|
|
||||||
# Save updated metadata
|
|
||||||
await MetadataManager.save_metadata(file_path, metadata)
|
|
||||||
|
|
||||||
# Update cache
|
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
|
|
||||||
# Find and remove model from cache
|
|
||||||
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
|
||||||
if model_to_remove:
|
|
||||||
# Update tags count
|
|
||||||
for tag in model_to_remove.get('tags', []):
|
|
||||||
if tag in scanner._tags_count:
|
|
||||||
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
|
|
||||||
if scanner._tags_count[tag] == 0:
|
|
||||||
del scanner._tags_count[tag]
|
|
||||||
|
|
||||||
# Remove from hash index if available
|
|
||||||
if hasattr(scanner, '_hash_index') and scanner._hash_index:
|
|
||||||
scanner._hash_index.remove_by_path(file_path)
|
|
||||||
|
|
||||||
# Remove from cache data
|
|
||||||
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path]
|
|
||||||
await cache.resort()
|
|
||||||
|
|
||||||
# Add to excluded models list
|
|
||||||
scanner._excluded_models.append(file_path)
|
|
||||||
|
|
||||||
return web.json_response({
|
|
||||||
'success': True,
|
|
||||||
'message': f"Model {os.path.basename(file_path)} excluded"
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error excluding model: {e}", exc_info=True)
|
|
||||||
return web.Response(text=str(e), status=500)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_download_model(request: web.Request) -> web.Response:
|
async def handle_download_model(request: web.Request) -> web.Response:
|
||||||
"""Handle model download request"""
|
"""Handle model download request"""
|
||||||
@@ -755,44 +599,6 @@ class ModelRouteUtils:
|
|||||||
'error': str(e)
|
'error': str(e)
|
||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def handle_bulk_delete_models(request: web.Request, scanner) -> web.Response:
|
|
||||||
"""Handle bulk deletion of models
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The aiohttp request
|
|
||||||
scanner: The model scanner instance with cache management methods
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
web.Response: The HTTP response
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
data = await request.json()
|
|
||||||
file_paths = data.get('file_paths', [])
|
|
||||||
|
|
||||||
if not file_paths:
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': 'No file paths provided for deletion'
|
|
||||||
}, status=400)
|
|
||||||
|
|
||||||
# Use the scanner's bulk delete method to handle all cache and file operations
|
|
||||||
result = await scanner.bulk_delete_models(file_paths)
|
|
||||||
|
|
||||||
return web.json_response({
|
|
||||||
'success': result.get('success', False),
|
|
||||||
'total_deleted': result.get('total_deleted', 0),
|
|
||||||
'total_attempted': result.get('total_attempted', len(file_paths)),
|
|
||||||
'results': result.get('results', [])
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in bulk delete: {e}", exc_info=True)
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': str(e)
|
|
||||||
}, status=500)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_relink_civitai(request: web.Request, scanner) -> web.Response:
|
async def handle_relink_civitai(request: web.Request, scanner) -> web.Response:
|
||||||
"""Handle CivitAI metadata re-linking request by model ID and/or version ID
|
"""Handle CivitAI metadata re-linking request by model ID and/or version ID
|
||||||
@@ -948,137 +754,6 @@ class ModelRouteUtils:
|
|||||||
'error': str(e)
|
'error': str(e)
|
||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def handle_rename_model(request: web.Request, scanner) -> web.Response:
|
|
||||||
"""Handle renaming a model file and its associated files
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The aiohttp request
|
|
||||||
scanner: The model scanner instance
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
web.Response: The HTTP response
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
data = await request.json()
|
|
||||||
file_path = data.get('file_path')
|
|
||||||
new_file_name = data.get('new_file_name')
|
|
||||||
|
|
||||||
if not file_path or not new_file_name:
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': 'File path and new file name are required'
|
|
||||||
}, status=400)
|
|
||||||
|
|
||||||
# Validate the new file name (no path separators or invalid characters)
|
|
||||||
invalid_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|']
|
|
||||||
if any(char in new_file_name for char in invalid_chars):
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': 'Invalid characters in file name'
|
|
||||||
}, status=400)
|
|
||||||
|
|
||||||
# Get the directory and current file name
|
|
||||||
target_dir = os.path.dirname(file_path)
|
|
||||||
old_file_name = os.path.splitext(os.path.basename(file_path))[0]
|
|
||||||
|
|
||||||
# Check if the target file already exists
|
|
||||||
new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace(os.sep, '/')
|
|
||||||
if os.path.exists(new_file_path):
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': 'A file with this name already exists'
|
|
||||||
}, status=400)
|
|
||||||
|
|
||||||
# Define the patterns for associated files
|
|
||||||
patterns = [
|
|
||||||
f"{old_file_name}.safetensors", # Required
|
|
||||||
f"{old_file_name}.metadata.json",
|
|
||||||
f"{old_file_name}.metadata.json.bak",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Add all preview file extensions
|
|
||||||
for ext in PREVIEW_EXTENSIONS:
|
|
||||||
patterns.append(f"{old_file_name}{ext}")
|
|
||||||
|
|
||||||
# Find all matching files
|
|
||||||
existing_files = []
|
|
||||||
for pattern in patterns:
|
|
||||||
path = os.path.join(target_dir, pattern)
|
|
||||||
if os.path.exists(path):
|
|
||||||
existing_files.append((path, pattern))
|
|
||||||
|
|
||||||
# Get the hash from the main file to update hash index
|
|
||||||
hash_value = None
|
|
||||||
metadata = None
|
|
||||||
metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
|
|
||||||
|
|
||||||
if os.path.exists(metadata_path):
|
|
||||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
|
||||||
hash_value = metadata.get('sha256')
|
|
||||||
logger.info(f"hash_value: {hash_value}, metadata_path: {metadata_path}, metadata: {metadata}")
|
|
||||||
# Rename all files
|
|
||||||
renamed_files = []
|
|
||||||
new_metadata_path = None
|
|
||||||
new_preview = None
|
|
||||||
|
|
||||||
for old_path, pattern in existing_files:
|
|
||||||
# Get the file extension like .safetensors or .metadata.json
|
|
||||||
ext = ModelRouteUtils.get_multipart_ext(pattern)
|
|
||||||
|
|
||||||
# Create the new path
|
|
||||||
new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
|
|
||||||
|
|
||||||
# Rename the file
|
|
||||||
os.rename(old_path, new_path)
|
|
||||||
renamed_files.append(new_path)
|
|
||||||
|
|
||||||
# Keep track of metadata path for later update
|
|
||||||
if ext == '.metadata.json':
|
|
||||||
new_metadata_path = new_path
|
|
||||||
|
|
||||||
# Update the metadata file with new file name and paths
|
|
||||||
if new_metadata_path and metadata:
|
|
||||||
# Update file_name, file_path and preview_url in metadata
|
|
||||||
metadata['file_name'] = new_file_name
|
|
||||||
metadata['file_path'] = new_file_path
|
|
||||||
|
|
||||||
# Update preview_url if it exists
|
|
||||||
if 'preview_url' in metadata and metadata['preview_url']:
|
|
||||||
old_preview = metadata['preview_url']
|
|
||||||
ext = ModelRouteUtils.get_multipart_ext(old_preview)
|
|
||||||
new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
|
|
||||||
metadata['preview_url'] = new_preview
|
|
||||||
|
|
||||||
# Save updated metadata
|
|
||||||
await MetadataManager.save_metadata(new_file_path, metadata)
|
|
||||||
|
|
||||||
# Update the scanner cache
|
|
||||||
if metadata:
|
|
||||||
await scanner.update_single_model_cache(file_path, new_file_path, metadata)
|
|
||||||
|
|
||||||
# Update recipe files and cache if hash is available and recipe_scanner exists
|
|
||||||
if hash_value and hasattr(scanner, 'update_lora_filename_by_hash'):
|
|
||||||
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
|
||||||
if recipe_scanner:
|
|
||||||
recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name)
|
|
||||||
logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed model")
|
|
||||||
|
|
||||||
return web.json_response({
|
|
||||||
'success': True,
|
|
||||||
'new_file_path': new_file_path,
|
|
||||||
'new_preview_path': config.get_preview_static_url(new_preview),
|
|
||||||
'renamed_files': renamed_files,
|
|
||||||
'reload_required': False
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error renaming model: {e}", exc_info=True)
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': str(e)
|
|
||||||
}, status=500)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_save_metadata(request: web.Request, scanner) -> web.Response:
|
async def handle_save_metadata(request: web.Request, scanner) -> web.Response:
|
||||||
"""Handle saving metadata updates
|
"""Handle saving metadata updates
|
||||||
|
|||||||
Reference in New Issue
Block a user