refactor(routes): extract route utilities into services

This commit is contained in:
pixelpaws
2025-09-21 23:34:46 +08:00
parent 2d00cfdd31
commit 21772feadd
11 changed files with 1269 additions and 30 deletions

View File

@@ -4,16 +4,23 @@ from __future__ import annotations
import asyncio
import json
import logging
import os
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional
from aiohttp import web
import jinja2
from ...config import config
from ...services.download_coordinator import DownloadCoordinator
from ...services.metadata_sync_service import MetadataSyncService
from ...services.model_file_service import ModelFileService, ModelMoveService
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...services.websocket_manager import WebSocketManager
from ...services.preview_asset_service import PreviewAssetService
from ...services.settings_manager import SettingsManager
from ...services.tag_update_service import TagUpdateService
from ...services.websocket_manager import WebSocketManager
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...utils.file_utils import calculate_sha256
from ...utils.routes_common import ModelRouteUtils
@@ -168,9 +175,20 @@ class ModelListingHandler:
class ModelManagementHandler:
"""Handle mutation operations on models."""
def __init__(self, *, service, logger: logging.Logger) -> None:
def __init__(
self,
*,
service,
logger: logging.Logger,
metadata_sync: MetadataSyncService,
preview_service: PreviewAssetService,
tag_update_service: TagUpdateService,
) -> None:
self._service = service
self._logger = logger
self._metadata_sync = metadata_sync
self._preview_service = preview_service
self._tag_update_service = tag_update_service
async def delete_model(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_delete_model(request, self._service.scanner)
@@ -192,7 +210,7 @@ class ModelManagementHandler:
if not model_data.get("sha256"):
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
success, error = await ModelRouteUtils.fetch_and_update_model(
success, error = await self._metadata_sync.fetch_and_update_model(
sha256=model_data["sha256"],
file_path=file_path,
model_data=model_data,
@@ -208,16 +226,144 @@ class ModelManagementHandler:
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def relink_civitai(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_relink_civitai(request, self._service.scanner)
try:
data = await request.json()
file_path = data.get("file_path")
model_id = data.get("model_id")
model_version_id = data.get("model_version_id")
if not file_path or model_id is None:
return web.json_response(
{"success": False, "error": "Both file_path and model_id are required"},
status=400,
)
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
local_metadata = await self._metadata_sync.load_local_metadata(metadata_path)
updated_metadata = await self._metadata_sync.relink_metadata(
file_path=file_path,
metadata=local_metadata,
model_id=int(model_id),
model_version_id=int(model_version_id) if model_version_id else None,
)
await self._service.scanner.update_single_model_cache(
file_path, file_path, updated_metadata
)
message = (
f"Model successfully re-linked to Civitai model {model_id}"
+ (f" version {model_version_id}" if model_version_id else "")
)
return web.json_response(
{"success": True, "message": message, "hash": updated_metadata.get("sha256", "")}
)
except Exception as exc:
self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def replace_preview(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_replace_preview(request, self._service.scanner)
try:
reader = await request.multipart()
field = await reader.next()
if field is None or field.name != "preview_file":
raise ValueError("Expected 'preview_file' field")
content_type = field.headers.get("Content-Type", "image/png")
content_disposition = field.headers.get("Content-Disposition", "")
original_filename = None
import re
match = re.search(r'filename="(.*?)"', content_disposition)
if match:
original_filename = match.group(1)
preview_data = await field.read()
field = await reader.next()
if field is None or field.name != "model_path":
raise ValueError("Expected 'model_path' field")
model_path = (await field.read()).decode()
nsfw_level = 0
field = await reader.next()
if field and field.name == "nsfw_level":
try:
nsfw_level = int((await field.read()).decode())
except (ValueError, TypeError):
self._logger.warning("Invalid NSFW level format, using default 0")
result = await self._preview_service.replace_preview(
model_path=model_path,
preview_data=preview_data,
content_type=content_type,
original_filename=original_filename,
nsfw_level=nsfw_level,
update_preview_in_cache=self._service.scanner.update_preview_in_cache,
metadata_loader=self._metadata_sync.load_local_metadata,
)
return web.json_response(
{
"success": True,
"preview_url": config.get_preview_static_url(result["preview_path"]),
"preview_nsfw_level": result["preview_nsfw_level"],
}
)
except Exception as exc:
self._logger.error("Error replacing preview: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def save_metadata(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_save_metadata(request, self._service.scanner)
try:
data = await request.json()
file_path = data.get("file_path")
if not file_path:
return web.Response(text="File path is required", status=400)
metadata_updates = {k: v for k, v in data.items() if k != "file_path"}
await self._metadata_sync.save_metadata_updates(
file_path=file_path,
updates=metadata_updates,
metadata_loader=self._metadata_sync.load_local_metadata,
update_cache=self._service.scanner.update_single_model_cache,
)
if "model_name" in metadata_updates:
cache = await self._service.scanner.get_cached_data()
await cache.resort()
return web.json_response({"success": True})
except Exception as exc:
self._logger.error("Error saving metadata: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def add_tags(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_add_tags(request, self._service.scanner)
try:
data = await request.json()
file_path = data.get("file_path")
new_tags = data.get("tags", [])
if not file_path:
return web.Response(text="File path is required", status=400)
if not isinstance(new_tags, list):
return web.Response(text="Tags must be a list", status=400)
tags = await self._tag_update_service.add_tags(
file_path=file_path,
new_tags=new_tags,
metadata_loader=self._metadata_sync.load_local_metadata,
update_cache=self._service.scanner.update_single_model_cache,
)
return web.json_response({"success": True, "tags": tags})
except Exception as exc:
self._logger.error("Error adding tags: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def rename_model(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_rename_model(request, self._service.scanner)
@@ -226,7 +372,27 @@ class ModelManagementHandler:
return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner)
async def verify_duplicates(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_verify_duplicates(request, self._service.scanner)
try:
data = await request.json()
file_paths = data.get("file_paths", [])
if not file_paths:
return web.json_response(
{"success": False, "error": "No file paths provided for verification"},
status=400,
)
results = await self._metadata_sync.verify_duplicate_hashes(
file_paths=file_paths,
metadata_loader=self._metadata_sync.load_local_metadata,
hash_calculator=calculate_sha256,
update_cache=self._service.scanner.update_single_model_cache,
)
return web.json_response({"success": True, **results})
except Exception as exc:
self._logger.error("Error verifying duplicate models: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class ModelQueryHandler:
@@ -429,12 +595,39 @@ class ModelQueryHandler:
class ModelDownloadHandler:
"""Coordinate downloads and progress reporting."""
def __init__(self, *, ws_manager: WebSocketManager, logger: logging.Logger) -> None:
def __init__(
self,
*,
ws_manager: WebSocketManager,
logger: logging.Logger,
download_coordinator: DownloadCoordinator,
) -> None:
self._ws_manager = ws_manager
self._logger = logger
self._download_coordinator = download_coordinator
async def download_model(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_download_model(request)
try:
payload = await request.json()
result = await self._download_coordinator.schedule_download(payload)
if not result.get("success", False):
return web.json_response(result, status=500)
return web.json_response(result)
except ValueError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except Exception as exc:
error_message = str(exc)
if "401" in error_message:
self._logger.warning("Early access error (401): %s", error_message)
return web.json_response(
{
"success": False,
"error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.",
},
status=401,
)
self._logger.error("Error downloading model: %s", error_message)
return web.json_response({"success": False, "error": error_message}, status=500)
async def download_model_get(self, request: web.Request) -> web.Response:
try:
@@ -460,7 +653,12 @@ class ModelDownloadHandler:
future.set_result(data)
mock_request = type("MockRequest", (), {"json": lambda self=None: future})()
return await ModelRouteUtils.handle_download_model(mock_request)
result = await self._download_coordinator.schedule_download(data)
if not result.get("success", False):
return web.json_response(result, status=500)
return web.json_response(result)
except ValueError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except Exception as exc:
self._logger.error("Error downloading model via GET: %s", exc, exc_info=True)
return web.Response(status=500, text=str(exc))
@@ -470,8 +668,8 @@ class ModelDownloadHandler:
download_id = request.query.get("download_id")
if not download_id:
return web.json_response({"success": False, "error": "Download ID is required"}, status=400)
mock_request = type("MockRequest", (), {"match_info": {"download_id": download_id}})()
return await ModelRouteUtils.handle_cancel_download(mock_request)
result = await self._download_coordinator.cancel_download(download_id)
return web.json_response(result)
except Exception as exc:
self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
@@ -504,6 +702,7 @@ class ModelCivitaiHandler:
validate_model_type: Callable[[str], bool],
expected_model_types: Callable[[], str],
find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]],
metadata_sync: MetadataSyncService,
) -> None:
self._service = service
self._settings = settings_service
@@ -513,6 +712,7 @@ class ModelCivitaiHandler:
self._validate_model_type = validate_model_type
self._expected_model_types = expected_model_types
self._find_model_file = find_model_file
self._metadata_sync = metadata_sync
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
try:
@@ -545,7 +745,7 @@ class ModelCivitaiHandler:
for model in to_process:
try:
original_name = model.get("model_name")
result, error = await ModelRouteUtils.fetch_and_update_model(
result, error = await self._metadata_sync.fetch_and_update_model(
sha256=model["sha256"],
file_path=model["file_path"],
model_data=model,