mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
refactor(routes): extract route utilities into services
This commit is contained in:
@@ -8,12 +8,20 @@ import jinja2
|
||||
from aiohttp import web
|
||||
|
||||
from ..config import config
|
||||
from ..services.metadata_service import get_default_metadata_provider
|
||||
from ..services.download_coordinator import DownloadCoordinator
|
||||
from ..services.downloader import get_downloader
|
||||
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||
from ..services.metadata_sync_service import MetadataSyncService
|
||||
from ..services.model_file_service import ModelFileService, ModelMoveService
|
||||
from ..services.preview_asset_service import PreviewAssetService
|
||||
from ..services.server_i18n import server_i18n as default_server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import settings as default_settings
|
||||
from ..services.tag_update_service import TagUpdateService
|
||||
from ..services.websocket_manager import ws_manager as default_ws_manager
|
||||
from ..services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ..services.server_i18n import server_i18n as default_server_i18n
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar
|
||||
from .handlers.model_handlers import (
|
||||
@@ -64,6 +72,24 @@ class BaseModelRoutes(ABC):
|
||||
self._handler_set: ModelHandlerSet | None = None
|
||||
self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
||||
|
||||
self._preview_service = PreviewAssetService(
|
||||
metadata_manager=MetadataManager,
|
||||
downloader_factory=get_downloader,
|
||||
exif_utils=ExifUtils,
|
||||
)
|
||||
self._metadata_sync_service = MetadataSyncService(
|
||||
metadata_manager=MetadataManager,
|
||||
preview_service=self._preview_service,
|
||||
settings=settings_service,
|
||||
default_metadata_provider_factory=metadata_provider_factory,
|
||||
metadata_provider_selector=get_metadata_provider,
|
||||
)
|
||||
self._tag_update_service = TagUpdateService(metadata_manager=MetadataManager)
|
||||
self._download_coordinator = DownloadCoordinator(
|
||||
ws_manager=self._ws_manager,
|
||||
download_manager_factory=ServiceRegistry.get_download_manager,
|
||||
)
|
||||
|
||||
if service is not None:
|
||||
self.attach_service(service)
|
||||
|
||||
@@ -98,9 +124,19 @@ class BaseModelRoutes(ABC):
|
||||
parse_specific_params=self._parse_specific_params,
|
||||
logger=logger,
|
||||
)
|
||||
management = ModelManagementHandler(service=service, logger=logger)
|
||||
management = ModelManagementHandler(
|
||||
service=service,
|
||||
logger=logger,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
preview_service=self._preview_service,
|
||||
tag_update_service=self._tag_update_service,
|
||||
)
|
||||
query = ModelQueryHandler(service=service, logger=logger)
|
||||
download = ModelDownloadHandler(ws_manager=self._ws_manager, logger=logger)
|
||||
download = ModelDownloadHandler(
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
download_coordinator=self._download_coordinator,
|
||||
)
|
||||
civitai = ModelCivitaiHandler(
|
||||
service=service,
|
||||
settings_service=self._settings,
|
||||
@@ -110,6 +146,7 @@ class BaseModelRoutes(ABC):
|
||||
validate_model_type=self._validate_civitai_model_type,
|
||||
expected_model_types=self._get_expected_model_types,
|
||||
find_model_file=self._find_model_file,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
)
|
||||
move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
|
||||
auto_organize = ModelAutoOrganizeHandler(
|
||||
|
||||
@@ -4,16 +4,23 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional
|
||||
|
||||
from aiohttp import web
|
||||
import jinja2
|
||||
|
||||
from ...config import config
|
||||
from ...services.download_coordinator import DownloadCoordinator
|
||||
from ...services.metadata_sync_service import MetadataSyncService
|
||||
from ...services.model_file_service import ModelFileService, ModelMoveService
|
||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ...services.websocket_manager import WebSocketManager
|
||||
from ...services.preview_asset_service import PreviewAssetService
|
||||
from ...services.settings_manager import SettingsManager
|
||||
from ...services.tag_update_service import TagUpdateService
|
||||
from ...services.websocket_manager import WebSocketManager
|
||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ...utils.file_utils import calculate_sha256
|
||||
from ...utils.routes_common import ModelRouteUtils
|
||||
|
||||
|
||||
@@ -168,9 +175,20 @@ class ModelListingHandler:
|
||||
class ModelManagementHandler:
|
||||
"""Handle mutation operations on models."""
|
||||
|
||||
def __init__(self, *, service, logger: logging.Logger) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
service,
|
||||
logger: logging.Logger,
|
||||
metadata_sync: MetadataSyncService,
|
||||
preview_service: PreviewAssetService,
|
||||
tag_update_service: TagUpdateService,
|
||||
) -> None:
|
||||
self._service = service
|
||||
self._logger = logger
|
||||
self._metadata_sync = metadata_sync
|
||||
self._preview_service = preview_service
|
||||
self._tag_update_service = tag_update_service
|
||||
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_delete_model(request, self._service.scanner)
|
||||
@@ -192,7 +210,7 @@ class ModelManagementHandler:
|
||||
if not model_data.get("sha256"):
|
||||
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
|
||||
|
||||
success, error = await ModelRouteUtils.fetch_and_update_model(
|
||||
success, error = await self._metadata_sync.fetch_and_update_model(
|
||||
sha256=model_data["sha256"],
|
||||
file_path=file_path,
|
||||
model_data=model_data,
|
||||
@@ -208,16 +226,144 @@ class ModelManagementHandler:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def relink_civitai(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_relink_civitai(request, self._service.scanner)
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get("file_path")
|
||||
model_id = data.get("model_id")
|
||||
model_version_id = data.get("model_version_id")
|
||||
|
||||
if not file_path or model_id is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Both file_path and model_id are required"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
local_metadata = await self._metadata_sync.load_local_metadata(metadata_path)
|
||||
|
||||
updated_metadata = await self._metadata_sync.relink_metadata(
|
||||
file_path=file_path,
|
||||
metadata=local_metadata,
|
||||
model_id=int(model_id),
|
||||
model_version_id=int(model_version_id) if model_version_id else None,
|
||||
)
|
||||
|
||||
await self._service.scanner.update_single_model_cache(
|
||||
file_path, file_path, updated_metadata
|
||||
)
|
||||
|
||||
message = (
|
||||
f"Model successfully re-linked to Civitai model {model_id}"
|
||||
+ (f" version {model_version_id}" if model_version_id else "")
|
||||
)
|
||||
return web.json_response(
|
||||
{"success": True, "message": message, "hash": updated_metadata.get("sha256", "")}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_replace_preview(request, self._service.scanner)
|
||||
try:
|
||||
reader = await request.multipart()
|
||||
|
||||
field = await reader.next()
|
||||
if field is None or field.name != "preview_file":
|
||||
raise ValueError("Expected 'preview_file' field")
|
||||
content_type = field.headers.get("Content-Type", "image/png")
|
||||
content_disposition = field.headers.get("Content-Disposition", "")
|
||||
|
||||
original_filename = None
|
||||
import re
|
||||
|
||||
match = re.search(r'filename="(.*?)"', content_disposition)
|
||||
if match:
|
||||
original_filename = match.group(1)
|
||||
|
||||
preview_data = await field.read()
|
||||
|
||||
field = await reader.next()
|
||||
if field is None or field.name != "model_path":
|
||||
raise ValueError("Expected 'model_path' field")
|
||||
model_path = (await field.read()).decode()
|
||||
|
||||
nsfw_level = 0
|
||||
field = await reader.next()
|
||||
if field and field.name == "nsfw_level":
|
||||
try:
|
||||
nsfw_level = int((await field.read()).decode())
|
||||
except (ValueError, TypeError):
|
||||
self._logger.warning("Invalid NSFW level format, using default 0")
|
||||
|
||||
result = await self._preview_service.replace_preview(
|
||||
model_path=model_path,
|
||||
preview_data=preview_data,
|
||||
content_type=content_type,
|
||||
original_filename=original_filename,
|
||||
nsfw_level=nsfw_level,
|
||||
update_preview_in_cache=self._service.scanner.update_preview_in_cache,
|
||||
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"preview_url": config.get_preview_static_url(result["preview_path"]),
|
||||
"preview_nsfw_level": result["preview_nsfw_level"],
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error replacing preview: %s", exc, exc_info=True)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_save_metadata(request, self._service.scanner)
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get("file_path")
|
||||
if not file_path:
|
||||
return web.Response(text="File path is required", status=400)
|
||||
|
||||
metadata_updates = {k: v for k, v in data.items() if k != "file_path"}
|
||||
|
||||
await self._metadata_sync.save_metadata_updates(
|
||||
file_path=file_path,
|
||||
updates=metadata_updates,
|
||||
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||
update_cache=self._service.scanner.update_single_model_cache,
|
||||
)
|
||||
|
||||
if "model_name" in metadata_updates:
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
await cache.resort()
|
||||
|
||||
return web.json_response({"success": True})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error saving metadata: %s", exc, exc_info=True)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
async def add_tags(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_add_tags(request, self._service.scanner)
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get("file_path")
|
||||
new_tags = data.get("tags", [])
|
||||
|
||||
if not file_path:
|
||||
return web.Response(text="File path is required", status=400)
|
||||
|
||||
if not isinstance(new_tags, list):
|
||||
return web.Response(text="Tags must be a list", status=400)
|
||||
|
||||
tags = await self._tag_update_service.add_tags(
|
||||
file_path=file_path,
|
||||
new_tags=new_tags,
|
||||
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||
update_cache=self._service.scanner.update_single_model_cache,
|
||||
)
|
||||
|
||||
return web.json_response({"success": True, "tags": tags})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error adding tags: %s", exc, exc_info=True)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
async def rename_model(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_rename_model(request, self._service.scanner)
|
||||
@@ -226,7 +372,27 @@ class ModelManagementHandler:
|
||||
return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner)
|
||||
|
||||
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_verify_duplicates(request, self._service.scanner)
|
||||
try:
|
||||
data = await request.json()
|
||||
file_paths = data.get("file_paths", [])
|
||||
|
||||
if not file_paths:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "No file paths provided for verification"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
results = await self._metadata_sync.verify_duplicate_hashes(
|
||||
file_paths=file_paths,
|
||||
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||
hash_calculator=calculate_sha256,
|
||||
update_cache=self._service.scanner.update_single_model_cache,
|
||||
)
|
||||
|
||||
return web.json_response({"success": True, **results})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error verifying duplicate models: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class ModelQueryHandler:
|
||||
@@ -429,12 +595,39 @@ class ModelQueryHandler:
|
||||
class ModelDownloadHandler:
|
||||
"""Coordinate downloads and progress reporting."""
|
||||
|
||||
def __init__(self, *, ws_manager: WebSocketManager, logger: logging.Logger) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ws_manager: WebSocketManager,
|
||||
logger: logging.Logger,
|
||||
download_coordinator: DownloadCoordinator,
|
||||
) -> None:
|
||||
self._ws_manager = ws_manager
|
||||
self._logger = logger
|
||||
self._download_coordinator = download_coordinator
|
||||
|
||||
async def download_model(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_download_model(request)
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_coordinator.schedule_download(payload)
|
||||
if not result.get("success", False):
|
||||
return web.json_response(result, status=500)
|
||||
return web.json_response(result)
|
||||
except ValueError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
error_message = str(exc)
|
||||
if "401" in error_message:
|
||||
self._logger.warning("Early access error (401): %s", error_message)
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.",
|
||||
},
|
||||
status=401,
|
||||
)
|
||||
self._logger.error("Error downloading model: %s", error_message)
|
||||
return web.json_response({"success": False, "error": error_message}, status=500)
|
||||
|
||||
async def download_model_get(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
@@ -460,7 +653,12 @@ class ModelDownloadHandler:
|
||||
future.set_result(data)
|
||||
|
||||
mock_request = type("MockRequest", (), {"json": lambda self=None: future})()
|
||||
return await ModelRouteUtils.handle_download_model(mock_request)
|
||||
result = await self._download_coordinator.schedule_download(data)
|
||||
if not result.get("success", False):
|
||||
return web.json_response(result, status=500)
|
||||
return web.json_response(result)
|
||||
except ValueError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error downloading model via GET: %s", exc, exc_info=True)
|
||||
return web.Response(status=500, text=str(exc))
|
||||
@@ -470,8 +668,8 @@ class ModelDownloadHandler:
|
||||
download_id = request.query.get("download_id")
|
||||
if not download_id:
|
||||
return web.json_response({"success": False, "error": "Download ID is required"}, status=400)
|
||||
mock_request = type("MockRequest", (), {"match_info": {"download_id": download_id}})()
|
||||
return await ModelRouteUtils.handle_cancel_download(mock_request)
|
||||
result = await self._download_coordinator.cancel_download(download_id)
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
@@ -504,6 +702,7 @@ class ModelCivitaiHandler:
|
||||
validate_model_type: Callable[[str], bool],
|
||||
expected_model_types: Callable[[], str],
|
||||
find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]],
|
||||
metadata_sync: MetadataSyncService,
|
||||
) -> None:
|
||||
self._service = service
|
||||
self._settings = settings_service
|
||||
@@ -513,6 +712,7 @@ class ModelCivitaiHandler:
|
||||
self._validate_model_type = validate_model_type
|
||||
self._expected_model_types = expected_model_types
|
||||
self._find_model_file = find_model_file
|
||||
self._metadata_sync = metadata_sync
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
@@ -545,7 +745,7 @@ class ModelCivitaiHandler:
|
||||
for model in to_process:
|
||||
try:
|
||||
original_name = model.get("model_name")
|
||||
result, error = await ModelRouteUtils.fetch_and_update_model(
|
||||
result, error = await self._metadata_sync.fetch_and_update_model(
|
||||
sha256=model["sha256"],
|
||||
file_path=model["file_path"],
|
||||
model_data=model,
|
||||
|
||||
Reference in New Issue
Block a user