feat(routes): extract orchestration use cases

This commit is contained in:
pixelpaws
2025-09-22 05:25:27 +08:00
parent 8cf99dd928
commit c063854b51
9 changed files with 609 additions and 112 deletions

View File

@@ -0,0 +1,25 @@
"""Application-level orchestration services for model routes."""
from .auto_organize_use_case import (
AutoOrganizeInProgressError,
AutoOrganizeUseCase,
)
from .bulk_metadata_refresh_use_case import (
BulkMetadataRefreshUseCase,
MetadataRefreshProgressReporter,
)
from .download_model_use_case import (
DownloadModelEarlyAccessError,
DownloadModelUseCase,
DownloadModelValidationError,
)
__all__ = [
"AutoOrganizeInProgressError",
"AutoOrganizeUseCase",
"BulkMetadataRefreshUseCase",
"MetadataRefreshProgressReporter",
"DownloadModelEarlyAccessError",
"DownloadModelUseCase",
"DownloadModelValidationError",
]

View File

@@ -0,0 +1,56 @@
"""Auto-organize use case orchestrating concurrency and progress handling."""
from __future__ import annotations
import asyncio
from typing import Optional, Protocol, Sequence
from ..model_file_service import AutoOrganizeResult, ModelFileService, ProgressCallback
class AutoOrganizeLockProvider(Protocol):
"""Minimal protocol for objects exposing auto-organize locking primitives."""
def is_auto_organize_running(self) -> bool:
"""Return ``True`` when an auto-organize operation is in-flight."""
async def get_auto_organize_lock(self) -> asyncio.Lock:
"""Return the asyncio lock guarding auto-organize operations."""
class AutoOrganizeInProgressError(RuntimeError):
"""Raised when an auto-organize run is already active."""
class AutoOrganizeUseCase:
"""Coordinate auto-organize execution behind a shared lock."""
def __init__(
self,
*,
file_service: ModelFileService,
lock_provider: AutoOrganizeLockProvider,
) -> None:
self._file_service = file_service
self._lock_provider = lock_provider
async def execute(
self,
*,
file_paths: Optional[Sequence[str]] = None,
progress_callback: Optional[ProgressCallback] = None,
) -> AutoOrganizeResult:
"""Run the auto-organize routine guarded by a shared lock."""
if self._lock_provider.is_auto_organize_running():
raise AutoOrganizeInProgressError("Auto-organize is already running")
lock = await self._lock_provider.get_auto_organize_lock()
if lock.locked():
raise AutoOrganizeInProgressError("Auto-organize is already running")
async with lock:
return await self._file_service.auto_organize_models(
file_paths=list(file_paths) if file_paths is not None else None,
progress_callback=progress_callback,
)

View File

@@ -0,0 +1,122 @@
"""Use case encapsulating the bulk metadata refresh orchestration."""
from __future__ import annotations
import logging
from typing import Any, Dict, Optional, Protocol, Sequence
from ..metadata_sync_service import MetadataSyncService
class MetadataRefreshProgressReporter(Protocol):
"""Protocol for progress reporters used during metadata refresh."""
async def on_progress(self, payload: Dict[str, Any]) -> None:
"""Handle a metadata refresh progress update."""
class BulkMetadataRefreshUseCase:
"""Coordinate bulk metadata refreshes with progress emission."""
def __init__(
self,
*,
service,
metadata_sync: MetadataSyncService,
settings_service,
logger: Optional[logging.Logger] = None,
) -> None:
self._service = service
self._metadata_sync = metadata_sync
self._settings = settings_service
self._logger = logger or logging.getLogger(__name__)
async def execute(
self,
*,
progress_callback: Optional[MetadataRefreshProgressReporter] = None,
) -> Dict[str, Any]:
"""Refresh metadata for all qualifying models."""
cache = await self._service.scanner.get_cached_data()
total_models = len(cache.raw_data)
enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False)
to_process: Sequence[Dict[str, Any]] = [
model
for model in cache.raw_data
if model.get("sha256")
and (not model.get("civitai") or not model["civitai"].get("id"))
and (
(enable_metadata_archive_db and not model.get("db_checked", False))
or (not enable_metadata_archive_db and model.get("from_civitai") is True)
)
]
total_to_process = len(to_process)
processed = 0
success = 0
needs_resort = False
async def emit(status: str, **extra: Any) -> None:
if progress_callback is None:
return
payload = {"status": status, "total": total_to_process, "processed": processed, "success": success}
payload.update(extra)
await progress_callback.on_progress(payload)
await emit("started")
for model in to_process:
try:
original_name = model.get("model_name")
result, _ = await self._metadata_sync.fetch_and_update_model(
sha256=model["sha256"],
file_path=model["file_path"],
model_data=model,
update_cache_func=self._service.scanner.update_single_model_cache,
)
if result:
success += 1
if original_name != model.get("model_name"):
needs_resort = True
processed += 1
await emit(
"processing",
processed=processed,
success=success,
current_name=model.get("model_name", "Unknown"),
)
except Exception as exc: # pragma: no cover - logging path
processed += 1
self._logger.error(
"Error fetching CivitAI data for %s: %s",
model.get("file_path"),
exc,
)
if needs_resort:
await cache.resort()
await emit("completed", processed=processed, success=success)
message = (
"Successfully updated "
f"{success} of {processed} processed {self._service.model_type}s (total: {total_models})"
)
return {"success": True, "message": message, "processed": processed, "updated": success, "total": total_models}
async def execute_with_error_handling(
self,
*,
progress_callback: Optional[MetadataRefreshProgressReporter] = None,
) -> Dict[str, Any]:
"""Wrapper providing progress notification on unexpected failures."""
try:
return await self.execute(progress_callback=progress_callback)
except Exception as exc:
if progress_callback is not None:
await progress_callback.on_progress({"status": "error", "error": str(exc)})
raise

View File

@@ -0,0 +1,37 @@
"""Use case for scheduling model downloads with consistent error handling."""
from __future__ import annotations
from typing import Any, Dict
from ..download_coordinator import DownloadCoordinator
class DownloadModelValidationError(ValueError):
"""Raised when incoming payload validation fails."""
class DownloadModelEarlyAccessError(RuntimeError):
"""Raised when the download is gated behind Civitai early access."""
class DownloadModelUseCase:
"""Coordinate download scheduling through the coordinator service."""
def __init__(self, *, download_coordinator: DownloadCoordinator) -> None:
self._download_coordinator = download_coordinator
async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""Schedule a download and normalize error conditions."""
try:
return await self._download_coordinator.schedule_download(payload)
except ValueError as exc:
raise DownloadModelValidationError(str(exc)) from exc
except Exception as exc: # pragma: no cover - defensive logging path
message = str(exc)
if "401" in message:
raise DownloadModelEarlyAccessError(
"Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
) from exc
raise

View File

@@ -1,11 +1,29 @@
from typing import Dict, Any
"""Progress callback implementations backed by the shared WebSocket manager."""
from typing import Any, Dict, Protocol
from .model_file_service import ProgressCallback
from .websocket_manager import ws_manager
class WebSocketProgressCallback(ProgressCallback):
"""WebSocket implementation of progress callback"""
class ProgressReporter(Protocol):
"""Protocol representing an async progress callback."""
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
"""Send progress data via WebSocket"""
await ws_manager.broadcast_auto_organize_progress(progress_data)
"""Handle a progress update payload."""
class WebSocketProgressCallback(ProgressCallback):
"""WebSocket implementation of progress callback."""
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
"""Send progress data via WebSocket."""
await ws_manager.broadcast_auto_organize_progress(progress_data)
class WebSocketBroadcastCallback:
"""Generic WebSocket progress callback broadcasting to all clients."""
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
"""Send the provided payload to all connected clients."""
await ws_manager.broadcast(progress_data)