feat(routes): extract orchestration use cases

This commit is contained in:
pixelpaws
2025-09-22 05:25:27 +08:00
parent 8cf99dd928
commit c063854b51
9 changed files with 609 additions and 112 deletions

View File

@@ -19,7 +19,15 @@ from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings as default_settings
from ..services.tag_update_service import TagUpdateService
from ..services.websocket_manager import ws_manager as default_ws_manager
from ..services.websocket_progress_callback import WebSocketProgressCallback
from ..services.use_cases import (
AutoOrganizeUseCase,
BulkMetadataRefreshUseCase,
DownloadModelUseCase,
)
from ..services.websocket_progress_callback import (
WebSocketBroadcastCallback,
WebSocketProgressCallback,
)
from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager
from ..utils.routes_common import ModelRouteUtils
@@ -68,6 +76,7 @@ class BaseModelRoutes(ABC):
self.model_file_service: ModelFileService | None = None
self.model_move_service: ModelMoveService | None = None
self.websocket_progress_callback = WebSocketProgressCallback()
self.metadata_progress_callback = WebSocketBroadcastCallback()
self._handler_set: ModelHandlerSet | None = None
self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None
@@ -132,11 +141,19 @@ class BaseModelRoutes(ABC):
tag_update_service=self._tag_update_service,
)
query = ModelQueryHandler(service=service, logger=logger)
download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator)
download = ModelDownloadHandler(
ws_manager=self._ws_manager,
logger=logger,
download_use_case=download_use_case,
download_coordinator=self._download_coordinator,
)
metadata_refresh_use_case = BulkMetadataRefreshUseCase(
service=service,
metadata_sync=self._metadata_sync_service,
settings_service=self._settings,
logger=logger,
)
civitai = ModelCivitaiHandler(
service=service,
settings_service=self._settings,
@@ -147,10 +164,16 @@ class BaseModelRoutes(ABC):
expected_model_types=self._get_expected_model_types,
find_model_file=self._find_model_file,
metadata_sync=self._metadata_sync_service,
metadata_refresh_use_case=metadata_refresh_use_case,
metadata_progress_callback=self.metadata_progress_callback,
)
move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
auto_organize = ModelAutoOrganizeHandler(
auto_organize_use_case = AutoOrganizeUseCase(
file_service=self._ensure_file_service(),
lock_provider=self._ws_manager,
)
auto_organize = ModelAutoOrganizeHandler(
use_case=auto_organize_use_case,
progress_callback=self.websocket_progress_callback,
ws_manager=self._ws_manager,
logger=logger,

View File

@@ -14,10 +14,19 @@ import jinja2
from ...config import config
from ...services.download_coordinator import DownloadCoordinator
from ...services.metadata_sync_service import MetadataSyncService
from ...services.model_file_service import ModelFileService, ModelMoveService
from ...services.model_file_service import ModelMoveService
from ...services.preview_asset_service import PreviewAssetService
from ...services.settings_manager import SettingsManager
from ...services.tag_update_service import TagUpdateService
from ...services.use_cases import (
AutoOrganizeInProgressError,
AutoOrganizeUseCase,
BulkMetadataRefreshUseCase,
DownloadModelEarlyAccessError,
DownloadModelUseCase,
DownloadModelValidationError,
MetadataRefreshProgressReporter,
)
from ...services.websocket_manager import WebSocketManager
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...utils.file_utils import calculate_sha256
@@ -600,33 +609,29 @@ class ModelDownloadHandler:
*,
ws_manager: WebSocketManager,
logger: logging.Logger,
download_use_case: DownloadModelUseCase,
download_coordinator: DownloadCoordinator,
) -> None:
self._ws_manager = ws_manager
self._logger = logger
self._download_use_case = download_use_case
self._download_coordinator = download_coordinator
async def download_model(self, request: web.Request) -> web.Response:
try:
payload = await request.json()
result = await self._download_coordinator.schedule_download(payload)
result = await self._download_use_case.execute(payload)
if not result.get("success", False):
return web.json_response(result, status=500)
return web.json_response(result)
except ValueError as exc:
except DownloadModelValidationError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except DownloadModelEarlyAccessError as exc:
self._logger.warning("Early access error: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=401)
except Exception as exc:
error_message = str(exc)
if "401" in error_message:
self._logger.warning("Early access error (401): %s", error_message)
return web.json_response(
{
"success": False,
"error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.",
},
status=401,
)
self._logger.error("Error downloading model: %s", error_message)
self._logger.error("Error downloading model: %s", error_message, exc_info=True)
return web.json_response({"success": False, "error": error_message}, status=500)
async def download_model_get(self, request: web.Request) -> web.Response:
@@ -653,12 +658,15 @@ class ModelDownloadHandler:
future.set_result(data)
mock_request = type("MockRequest", (), {"json": lambda self=None: future})()
result = await self._download_coordinator.schedule_download(data)
result = await self._download_use_case.execute(data)
if not result.get("success", False):
return web.json_response(result, status=500)
return web.json_response(result)
except ValueError as exc:
except DownloadModelValidationError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except DownloadModelEarlyAccessError as exc:
self._logger.warning("Early access error: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=401)
except Exception as exc:
self._logger.error("Error downloading model via GET: %s", exc, exc_info=True)
return web.Response(status=500, text=str(exc))
@@ -703,6 +711,8 @@ class ModelCivitaiHandler:
expected_model_types: Callable[[], str],
find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]],
metadata_sync: MetadataSyncService,
metadata_refresh_use_case: BulkMetadataRefreshUseCase,
metadata_progress_callback: MetadataRefreshProgressReporter,
) -> None:
self._service = service
self._settings = settings_service
@@ -713,75 +723,16 @@ class ModelCivitaiHandler:
self._expected_model_types = expected_model_types
self._find_model_file = find_model_file
self._metadata_sync = metadata_sync
self._metadata_refresh_use_case = metadata_refresh_use_case
self._metadata_progress_callback = metadata_progress_callback
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
try:
cache = await self._service.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
success = 0
needs_resort = False
enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False)
to_process = [
model
for model in cache.raw_data
if model.get("sha256")
and (not model.get("civitai") or not model["civitai"].get("id"))
and (
(enable_metadata_archive_db and not model.get("db_checked", False))
or (not enable_metadata_archive_db and model.get("from_civitai") is True)
)
]
total_to_process = len(to_process)
await self._ws_manager.broadcast({
"status": "started",
"total": total_to_process,
"processed": 0,
"success": 0,
})
for model in to_process:
try:
original_name = model.get("model_name")
result, error = await self._metadata_sync.fetch_and_update_model(
sha256=model["sha256"],
file_path=model["file_path"],
model_data=model,
update_cache_func=self._service.scanner.update_single_model_cache,
)
if result:
success += 1
if original_name != model.get("model_name"):
needs_resort = True
processed += 1
await self._ws_manager.broadcast({
"status": "processing",
"total": total_to_process,
"processed": processed,
"success": success,
"current_name": model.get("model_name", "Unknown"),
})
except Exception as exc: # pragma: no cover - logging path
self._logger.error("Error fetching CivitAI data for %s: %s", model["file_path"], exc)
if needs_resort:
await cache.resort()
await self._ws_manager.broadcast({
"status": "completed",
"total": total_to_process,
"processed": processed,
"success": success,
})
return web.json_response({
"success": True,
"message": f"Successfully updated {success} of {processed} processed {self._service.model_type}s (total: {total})",
})
result = await self._metadata_refresh_use_case.execute_with_error_handling(
progress_callback=self._metadata_progress_callback
)
return web.json_response(result)
except Exception as exc:
await self._ws_manager.broadcast({"status": "error", "error": str(exc)})
self._logger.error("Error in fetch_all_civitai for %ss: %s", self._service.model_type, exc)
return web.Response(text=str(exc), status=500)
@@ -887,31 +838,18 @@ class ModelAutoOrganizeHandler:
def __init__(
self,
*,
file_service: ModelFileService,
use_case: AutoOrganizeUseCase,
progress_callback: WebSocketProgressCallback,
ws_manager: WebSocketManager,
logger: logging.Logger,
) -> None:
self._file_service = file_service
self._use_case = use_case
self._progress_callback = progress_callback
self._ws_manager = ws_manager
self._logger = logger
async def auto_organize_models(self, request: web.Request) -> web.Response:
try:
if self._ws_manager.is_auto_organize_running():
return web.json_response(
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
status=409,
)
auto_organize_lock = await self._ws_manager.get_auto_organize_lock()
if auto_organize_lock.locked():
return web.json_response(
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
status=409,
)
file_paths = None
if request.method == "POST":
try:
@@ -920,17 +858,24 @@ class ModelAutoOrganizeHandler:
except Exception: # pragma: no cover - permissive path
pass
async with auto_organize_lock:
result = await self._file_service.auto_organize_models(
file_paths=file_paths,
progress_callback=self._progress_callback,
)
return web.json_response(result.to_dict())
result = await self._use_case.execute(
file_paths=file_paths,
progress_callback=self._progress_callback,
)
return web.json_response(result.to_dict())
except AutoOrganizeInProgressError:
return web.json_response(
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
status=409,
)
except Exception as exc:
self._logger.error("Error in auto_organize_models: %s", exc, exc_info=True)
await self._ws_manager.broadcast_auto_organize_progress(
{"type": "auto_organize_progress", "status": "error", "error": str(exc)}
)
try:
await self._progress_callback.on_progress(
{"type": "auto_organize_progress", "status": "error", "error": str(exc)}
)
except Exception: # pragma: no cover - defensive reporting
pass
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_auto_organize_progress(self, request: web.Request) -> web.Response: