refactor(routes): extract route utilities into services

This commit is contained in:
pixelpaws
2025-09-21 23:34:46 +08:00
parent 2d00cfdd31
commit 21772feadd
11 changed files with 1269 additions and 30 deletions

View File

@@ -8,12 +8,20 @@ import jinja2
from aiohttp import web
from ..config import config
from ..services.metadata_service import get_default_metadata_provider
from ..services.download_coordinator import DownloadCoordinator
from ..services.downloader import get_downloader
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
from ..services.metadata_sync_service import MetadataSyncService
from ..services.model_file_service import ModelFileService, ModelMoveService
from ..services.preview_asset_service import PreviewAssetService
from ..services.server_i18n import server_i18n as default_server_i18n
from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings as default_settings
from ..services.tag_update_service import TagUpdateService
from ..services.websocket_manager import ws_manager as default_ws_manager
from ..services.websocket_progress_callback import WebSocketProgressCallback
from ..services.server_i18n import server_i18n as default_server_i18n
from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager
from ..utils.routes_common import ModelRouteUtils
from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar
from .handlers.model_handlers import (
@@ -64,6 +72,24 @@ class BaseModelRoutes(ABC):
self._handler_set: ModelHandlerSet | None = None
self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None
self._preview_service = PreviewAssetService(
metadata_manager=MetadataManager,
downloader_factory=get_downloader,
exif_utils=ExifUtils,
)
self._metadata_sync_service = MetadataSyncService(
metadata_manager=MetadataManager,
preview_service=self._preview_service,
settings=settings_service,
default_metadata_provider_factory=metadata_provider_factory,
metadata_provider_selector=get_metadata_provider,
)
self._tag_update_service = TagUpdateService(metadata_manager=MetadataManager)
self._download_coordinator = DownloadCoordinator(
ws_manager=self._ws_manager,
download_manager_factory=ServiceRegistry.get_download_manager,
)
if service is not None:
self.attach_service(service)
@@ -98,9 +124,19 @@ class BaseModelRoutes(ABC):
parse_specific_params=self._parse_specific_params,
logger=logger,
)
management = ModelManagementHandler(service=service, logger=logger)
management = ModelManagementHandler(
service=service,
logger=logger,
metadata_sync=self._metadata_sync_service,
preview_service=self._preview_service,
tag_update_service=self._tag_update_service,
)
query = ModelQueryHandler(service=service, logger=logger)
download = ModelDownloadHandler(ws_manager=self._ws_manager, logger=logger)
download = ModelDownloadHandler(
ws_manager=self._ws_manager,
logger=logger,
download_coordinator=self._download_coordinator,
)
civitai = ModelCivitaiHandler(
service=service,
settings_service=self._settings,
@@ -110,6 +146,7 @@ class BaseModelRoutes(ABC):
validate_model_type=self._validate_civitai_model_type,
expected_model_types=self._get_expected_model_types,
find_model_file=self._find_model_file,
metadata_sync=self._metadata_sync_service,
)
move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
auto_organize = ModelAutoOrganizeHandler(

View File

@@ -4,16 +4,23 @@ from __future__ import annotations
import asyncio
import json
import logging
import os
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional
from aiohttp import web
import jinja2
from ...config import config
from ...services.download_coordinator import DownloadCoordinator
from ...services.metadata_sync_service import MetadataSyncService
from ...services.model_file_service import ModelFileService, ModelMoveService
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...services.websocket_manager import WebSocketManager
from ...services.preview_asset_service import PreviewAssetService
from ...services.settings_manager import SettingsManager
from ...services.tag_update_service import TagUpdateService
from ...services.websocket_manager import WebSocketManager
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...utils.file_utils import calculate_sha256
from ...utils.routes_common import ModelRouteUtils
@@ -168,9 +175,20 @@ class ModelListingHandler:
class ModelManagementHandler:
"""Handle mutation operations on models."""
def __init__(self, *, service, logger: logging.Logger) -> None:
def __init__(
self,
*,
service,
logger: logging.Logger,
metadata_sync: MetadataSyncService,
preview_service: PreviewAssetService,
tag_update_service: TagUpdateService,
) -> None:
self._service = service
self._logger = logger
self._metadata_sync = metadata_sync
self._preview_service = preview_service
self._tag_update_service = tag_update_service
async def delete_model(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_delete_model(request, self._service.scanner)
@@ -192,7 +210,7 @@ class ModelManagementHandler:
if not model_data.get("sha256"):
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
success, error = await ModelRouteUtils.fetch_and_update_model(
success, error = await self._metadata_sync.fetch_and_update_model(
sha256=model_data["sha256"],
file_path=file_path,
model_data=model_data,
@@ -208,16 +226,144 @@ class ModelManagementHandler:
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def relink_civitai(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_relink_civitai(request, self._service.scanner)
try:
data = await request.json()
file_path = data.get("file_path")
model_id = data.get("model_id")
model_version_id = data.get("model_version_id")
if not file_path or model_id is None:
return web.json_response(
{"success": False, "error": "Both file_path and model_id are required"},
status=400,
)
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
local_metadata = await self._metadata_sync.load_local_metadata(metadata_path)
updated_metadata = await self._metadata_sync.relink_metadata(
file_path=file_path,
metadata=local_metadata,
model_id=int(model_id),
model_version_id=int(model_version_id) if model_version_id else None,
)
await self._service.scanner.update_single_model_cache(
file_path, file_path, updated_metadata
)
message = (
f"Model successfully re-linked to Civitai model {model_id}"
+ (f" version {model_version_id}" if model_version_id else "")
)
return web.json_response(
{"success": True, "message": message, "hash": updated_metadata.get("sha256", "")}
)
except Exception as exc:
self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def replace_preview(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_replace_preview(request, self._service.scanner)
try:
reader = await request.multipart()
field = await reader.next()
if field is None or field.name != "preview_file":
raise ValueError("Expected 'preview_file' field")
content_type = field.headers.get("Content-Type", "image/png")
content_disposition = field.headers.get("Content-Disposition", "")
original_filename = None
import re
match = re.search(r'filename="(.*?)"', content_disposition)
if match:
original_filename = match.group(1)
preview_data = await field.read()
field = await reader.next()
if field is None or field.name != "model_path":
raise ValueError("Expected 'model_path' field")
model_path = (await field.read()).decode()
nsfw_level = 0
field = await reader.next()
if field and field.name == "nsfw_level":
try:
nsfw_level = int((await field.read()).decode())
except (ValueError, TypeError):
self._logger.warning("Invalid NSFW level format, using default 0")
result = await self._preview_service.replace_preview(
model_path=model_path,
preview_data=preview_data,
content_type=content_type,
original_filename=original_filename,
nsfw_level=nsfw_level,
update_preview_in_cache=self._service.scanner.update_preview_in_cache,
metadata_loader=self._metadata_sync.load_local_metadata,
)
return web.json_response(
{
"success": True,
"preview_url": config.get_preview_static_url(result["preview_path"]),
"preview_nsfw_level": result["preview_nsfw_level"],
}
)
except Exception as exc:
self._logger.error("Error replacing preview: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def save_metadata(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_save_metadata(request, self._service.scanner)
try:
data = await request.json()
file_path = data.get("file_path")
if not file_path:
return web.Response(text="File path is required", status=400)
metadata_updates = {k: v for k, v in data.items() if k != "file_path"}
await self._metadata_sync.save_metadata_updates(
file_path=file_path,
updates=metadata_updates,
metadata_loader=self._metadata_sync.load_local_metadata,
update_cache=self._service.scanner.update_single_model_cache,
)
if "model_name" in metadata_updates:
cache = await self._service.scanner.get_cached_data()
await cache.resort()
return web.json_response({"success": True})
except Exception as exc:
self._logger.error("Error saving metadata: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def add_tags(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_add_tags(request, self._service.scanner)
try:
data = await request.json()
file_path = data.get("file_path")
new_tags = data.get("tags", [])
if not file_path:
return web.Response(text="File path is required", status=400)
if not isinstance(new_tags, list):
return web.Response(text="Tags must be a list", status=400)
tags = await self._tag_update_service.add_tags(
file_path=file_path,
new_tags=new_tags,
metadata_loader=self._metadata_sync.load_local_metadata,
update_cache=self._service.scanner.update_single_model_cache,
)
return web.json_response({"success": True, "tags": tags})
except Exception as exc:
self._logger.error("Error adding tags: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def rename_model(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_rename_model(request, self._service.scanner)
@@ -226,7 +372,27 @@ class ModelManagementHandler:
return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner)
async def verify_duplicates(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_verify_duplicates(request, self._service.scanner)
try:
data = await request.json()
file_paths = data.get("file_paths", [])
if not file_paths:
return web.json_response(
{"success": False, "error": "No file paths provided for verification"},
status=400,
)
results = await self._metadata_sync.verify_duplicate_hashes(
file_paths=file_paths,
metadata_loader=self._metadata_sync.load_local_metadata,
hash_calculator=calculate_sha256,
update_cache=self._service.scanner.update_single_model_cache,
)
return web.json_response({"success": True, **results})
except Exception as exc:
self._logger.error("Error verifying duplicate models: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class ModelQueryHandler:
@@ -429,12 +595,39 @@ class ModelQueryHandler:
class ModelDownloadHandler:
"""Coordinate downloads and progress reporting."""
def __init__(self, *, ws_manager: WebSocketManager, logger: logging.Logger) -> None:
def __init__(
self,
*,
ws_manager: WebSocketManager,
logger: logging.Logger,
download_coordinator: DownloadCoordinator,
) -> None:
self._ws_manager = ws_manager
self._logger = logger
self._download_coordinator = download_coordinator
async def download_model(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_download_model(request)
try:
payload = await request.json()
result = await self._download_coordinator.schedule_download(payload)
if not result.get("success", False):
return web.json_response(result, status=500)
return web.json_response(result)
except ValueError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except Exception as exc:
error_message = str(exc)
if "401" in error_message:
self._logger.warning("Early access error (401): %s", error_message)
return web.json_response(
{
"success": False,
"error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.",
},
status=401,
)
self._logger.error("Error downloading model: %s", error_message)
return web.json_response({"success": False, "error": error_message}, status=500)
async def download_model_get(self, request: web.Request) -> web.Response:
try:
@@ -460,7 +653,12 @@ class ModelDownloadHandler:
future.set_result(data)
mock_request = type("MockRequest", (), {"json": lambda self=None: future})()
return await ModelRouteUtils.handle_download_model(mock_request)
result = await self._download_coordinator.schedule_download(data)
if not result.get("success", False):
return web.json_response(result, status=500)
return web.json_response(result)
except ValueError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except Exception as exc:
self._logger.error("Error downloading model via GET: %s", exc, exc_info=True)
return web.Response(status=500, text=str(exc))
@@ -470,8 +668,8 @@ class ModelDownloadHandler:
download_id = request.query.get("download_id")
if not download_id:
return web.json_response({"success": False, "error": "Download ID is required"}, status=400)
mock_request = type("MockRequest", (), {"match_info": {"download_id": download_id}})()
return await ModelRouteUtils.handle_cancel_download(mock_request)
result = await self._download_coordinator.cancel_download(download_id)
return web.json_response(result)
except Exception as exc:
self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
@@ -504,6 +702,7 @@ class ModelCivitaiHandler:
validate_model_type: Callable[[str], bool],
expected_model_types: Callable[[], str],
find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]],
metadata_sync: MetadataSyncService,
) -> None:
self._service = service
self._settings = settings_service
@@ -513,6 +712,7 @@ class ModelCivitaiHandler:
self._validate_model_type = validate_model_type
self._expected_model_types = expected_model_types
self._find_model_file = find_model_file
self._metadata_sync = metadata_sync
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
try:
@@ -545,7 +745,7 @@ class ModelCivitaiHandler:
for model in to_process:
try:
original_name = model.get("model_name")
result, error = await ModelRouteUtils.fetch_and_update_model(
result, error = await self._metadata_sync.fetch_and_update_model(
sha256=model["sha256"],
file_path=model["file_path"],
model_data=model,