From c063854b511be529e8a4ef8b70c28b2d557d7db3 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Mon, 22 Sep 2025 05:25:27 +0800 Subject: [PATCH] feat(routes): extract orchestration use cases --- py/routes/base_model_routes.py | 27 ++- py/routes/handlers/model_handlers.py | 153 +++++--------- py/services/use_cases/__init__.py | 25 +++ .../use_cases/auto_organize_use_case.py | 56 +++++ .../bulk_metadata_refresh_use_case.py | 122 +++++++++++ .../use_cases/download_model_use_case.py | 37 ++++ py/services/websocket_progress_callback.py | 30 ++- tests/routes/test_base_model_routes_smoke.py | 80 ++++++++ tests/services/test_use_cases.py | 191 ++++++++++++++++++ 9 files changed, 609 insertions(+), 112 deletions(-) create mode 100644 py/services/use_cases/__init__.py create mode 100644 py/services/use_cases/auto_organize_use_case.py create mode 100644 py/services/use_cases/bulk_metadata_refresh_use_case.py create mode 100644 py/services/use_cases/download_model_use_case.py create mode 100644 tests/services/test_use_cases.py diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 65103ece..872dca8b 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -19,7 +19,15 @@ from ..services.service_registry import ServiceRegistry from ..services.settings_manager import settings as default_settings from ..services.tag_update_service import TagUpdateService from ..services.websocket_manager import ws_manager as default_ws_manager -from ..services.websocket_progress_callback import WebSocketProgressCallback +from ..services.use_cases import ( + AutoOrganizeUseCase, + BulkMetadataRefreshUseCase, + DownloadModelUseCase, +) +from ..services.websocket_progress_callback import ( + WebSocketBroadcastCallback, + WebSocketProgressCallback, +) from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager from ..utils.routes_common import ModelRouteUtils @@ -68,6 +76,7 @@ class BaseModelRoutes(ABC): self.model_file_service: ModelFileService | None = None self.model_move_service: ModelMoveService | None = None self.websocket_progress_callback = WebSocketProgressCallback() + self.metadata_progress_callback = WebSocketBroadcastCallback() self._handler_set: ModelHandlerSet | None = None self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None @@ -132,11 +141,19 @@ class BaseModelRoutes(ABC): tag_update_service=self._tag_update_service, ) query = ModelQueryHandler(service=service, logger=logger) + download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator) download = ModelDownloadHandler( ws_manager=self._ws_manager, logger=logger, + download_use_case=download_use_case, download_coordinator=self._download_coordinator, ) + metadata_refresh_use_case = BulkMetadataRefreshUseCase( + service=service, + metadata_sync=self._metadata_sync_service, + settings_service=self._settings, + logger=logger, + ) civitai = ModelCivitaiHandler( service=service, settings_service=self._settings, @@ -147,10 +164,16 @@ class BaseModelRoutes(ABC): expected_model_types=self._get_expected_model_types, find_model_file=self._find_model_file, metadata_sync=self._metadata_sync_service, + metadata_refresh_use_case=metadata_refresh_use_case, + metadata_progress_callback=self.metadata_progress_callback, ) move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger) - auto_organize = ModelAutoOrganizeHandler( + auto_organize_use_case = AutoOrganizeUseCase( file_service=self._ensure_file_service(), + lock_provider=self._ws_manager, + ) + auto_organize = ModelAutoOrganizeHandler( + use_case=auto_organize_use_case, progress_callback=self.websocket_progress_callback, ws_manager=self._ws_manager, logger=logger, diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 5f9eaf3b..ba0628c6 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -14,10 +14,19 @@ import jinja2 from ...config import config from ...services.download_coordinator import DownloadCoordinator from ...services.metadata_sync_service import MetadataSyncService -from ...services.model_file_service import ModelFileService, ModelMoveService +from ...services.model_file_service import ModelMoveService from ...services.preview_asset_service import PreviewAssetService from ...services.settings_manager import SettingsManager from ...services.tag_update_service import TagUpdateService +from ...services.use_cases import ( + AutoOrganizeInProgressError, + AutoOrganizeUseCase, + BulkMetadataRefreshUseCase, + DownloadModelEarlyAccessError, + DownloadModelUseCase, + DownloadModelValidationError, + MetadataRefreshProgressReporter, +) from ...services.websocket_manager import WebSocketManager from ...services.websocket_progress_callback import WebSocketProgressCallback from ...utils.file_utils import calculate_sha256 @@ -600,33 +609,29 @@ class ModelDownloadHandler: *, ws_manager: WebSocketManager, logger: logging.Logger, + download_use_case: DownloadModelUseCase, download_coordinator: DownloadCoordinator, ) -> None: self._ws_manager = ws_manager self._logger = logger + self._download_use_case = download_use_case self._download_coordinator = download_coordinator async def download_model(self, request: web.Request) -> web.Response: try: payload = await request.json() - result = await self._download_coordinator.schedule_download(payload) + result = await self._download_use_case.execute(payload) if not result.get("success", False): return web.json_response(result, status=500) return web.json_response(result) - except ValueError as exc: + except DownloadModelValidationError as exc: return web.json_response({"success": False, "error": str(exc)}, status=400) + except DownloadModelEarlyAccessError as exc: + self._logger.warning("Early access error: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=401) except Exception as exc: error_message = str(exc) - if "401" in error_message: - self._logger.warning("Early access error (401): %s", error_message) - return web.json_response( - { - "success": False, - "error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.", - }, - status=401, - ) - self._logger.error("Error downloading model: %s", error_message) + self._logger.error("Error downloading model: %s", error_message, exc_info=True) return web.json_response({"success": False, "error": error_message}, status=500) async def download_model_get(self, request: web.Request) -> web.Response: @@ -653,12 +658,15 @@ class ModelDownloadHandler: future.set_result(data) mock_request = type("MockRequest", (), {"json": lambda self=None: future})() - result = await self._download_coordinator.schedule_download(data) + result = await self._download_use_case.execute(data) if not result.get("success", False): return web.json_response(result, status=500) return web.json_response(result) - except ValueError as exc: + except DownloadModelValidationError as exc: return web.json_response({"success": False, "error": str(exc)}, status=400) + except DownloadModelEarlyAccessError as exc: + self._logger.warning("Early access error: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=401) except Exception as exc: self._logger.error("Error downloading model via GET: %s", exc, exc_info=True) return web.Response(status=500, text=str(exc)) @@ -703,6 +711,8 @@ class ModelCivitaiHandler: expected_model_types: Callable[[], str], find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]], metadata_sync: MetadataSyncService, + metadata_refresh_use_case: BulkMetadataRefreshUseCase, + metadata_progress_callback: MetadataRefreshProgressReporter, ) -> None: self._service = service self._settings = settings_service @@ -713,75 +723,16 @@ class ModelCivitaiHandler: self._expected_model_types = expected_model_types self._find_model_file = find_model_file self._metadata_sync = metadata_sync + self._metadata_refresh_use_case = metadata_refresh_use_case + self._metadata_progress_callback = metadata_progress_callback async def fetch_all_civitai(self, request: web.Request) -> web.Response: try: - cache = await self._service.scanner.get_cached_data() - total = len(cache.raw_data) - processed = 0 - success = 0 - needs_resort = False - - enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False) - to_process = [ - model - for model in cache.raw_data - if model.get("sha256") - and (not model.get("civitai") or not model["civitai"].get("id")) - and ( - (enable_metadata_archive_db and not model.get("db_checked", False)) - or (not enable_metadata_archive_db and model.get("from_civitai") is True) - ) - ] - total_to_process = len(to_process) - - await self._ws_manager.broadcast({ - "status": "started", - "total": total_to_process, - "processed": 0, - "success": 0, - }) - - for model in to_process: - try: - original_name = model.get("model_name") - result, error = await self._metadata_sync.fetch_and_update_model( - sha256=model["sha256"], - file_path=model["file_path"], - model_data=model, - update_cache_func=self._service.scanner.update_single_model_cache, - ) - if result: - success += 1 - if original_name != model.get("model_name"): - needs_resort = True - processed += 1 - await self._ws_manager.broadcast({ - "status": "processing", - "total": total_to_process, - "processed": processed, - "success": success, - "current_name": model.get("model_name", "Unknown"), - }) - except Exception as exc: # pragma: no cover - logging path - self._logger.error("Error fetching CivitAI data for %s: %s", model["file_path"], exc) - - if needs_resort: - await cache.resort() - - await self._ws_manager.broadcast({ - "status": "completed", - "total": total_to_process, - "processed": processed, - "success": success, - }) - - return web.json_response({ - "success": True, - "message": f"Successfully updated {success} of {processed} processed {self._service.model_type}s (total: {total})", - }) + result = await self._metadata_refresh_use_case.execute_with_error_handling( + progress_callback=self._metadata_progress_callback + ) + return web.json_response(result) except Exception as exc: - await self._ws_manager.broadcast({"status": "error", "error": str(exc)}) self._logger.error("Error in fetch_all_civitai for %ss: %s", self._service.model_type, exc) return web.Response(text=str(exc), status=500) @@ -887,31 +838,18 @@ class ModelAutoOrganizeHandler: def __init__( self, *, - file_service: ModelFileService, + use_case: AutoOrganizeUseCase, progress_callback: WebSocketProgressCallback, ws_manager: WebSocketManager, logger: logging.Logger, ) -> None: - self._file_service = file_service + self._use_case = use_case self._progress_callback = progress_callback self._ws_manager = ws_manager self._logger = logger async def auto_organize_models(self, request: web.Request) -> web.Response: try: - if self._ws_manager.is_auto_organize_running(): - return web.json_response( - {"success": False, "error": "Auto-organize is already running. Please wait for it to complete."}, - status=409, - ) - - auto_organize_lock = await self._ws_manager.get_auto_organize_lock() - if auto_organize_lock.locked(): - return web.json_response( - {"success": False, "error": "Auto-organize is already running. Please wait for it to complete."}, - status=409, - ) - file_paths = None if request.method == "POST": try: @@ -920,17 +858,24 @@ class ModelAutoOrganizeHandler: except Exception: # pragma: no cover - permissive path pass - async with auto_organize_lock: - result = await self._file_service.auto_organize_models( - file_paths=file_paths, - progress_callback=self._progress_callback, - ) - return web.json_response(result.to_dict()) + result = await self._use_case.execute( + file_paths=file_paths, + progress_callback=self._progress_callback, + ) + return web.json_response(result.to_dict()) + except AutoOrganizeInProgressError: + return web.json_response( + {"success": False, "error": "Auto-organize is already running. Please wait for it to complete."}, + status=409, + ) except Exception as exc: self._logger.error("Error in auto_organize_models: %s", exc, exc_info=True) - await self._ws_manager.broadcast_auto_organize_progress( - {"type": "auto_organize_progress", "status": "error", "error": str(exc)} - ) + try: + await self._progress_callback.on_progress( + {"type": "auto_organize_progress", "status": "error", "error": str(exc)} + ) + except Exception: # pragma: no cover - defensive reporting + pass return web.json_response({"success": False, "error": str(exc)}, status=500) async def get_auto_organize_progress(self, request: web.Request) -> web.Response: diff --git a/py/services/use_cases/__init__.py b/py/services/use_cases/__init__.py new file mode 100644 index 00000000..986f0f57 --- /dev/null +++ b/py/services/use_cases/__init__.py @@ -0,0 +1,25 @@ +"""Application-level orchestration services for model routes.""" + +from .auto_organize_use_case import ( + AutoOrganizeInProgressError, + AutoOrganizeUseCase, +) +from .bulk_metadata_refresh_use_case import ( + BulkMetadataRefreshUseCase, + MetadataRefreshProgressReporter, +) +from .download_model_use_case import ( + DownloadModelEarlyAccessError, + DownloadModelUseCase, + DownloadModelValidationError, +) + +__all__ = [ + "AutoOrganizeInProgressError", + "AutoOrganizeUseCase", + "BulkMetadataRefreshUseCase", + "MetadataRefreshProgressReporter", + "DownloadModelEarlyAccessError", + "DownloadModelUseCase", + "DownloadModelValidationError", +] diff --git a/py/services/use_cases/auto_organize_use_case.py b/py/services/use_cases/auto_organize_use_case.py new file mode 100644 index 00000000..0914739f --- /dev/null +++ b/py/services/use_cases/auto_organize_use_case.py @@ -0,0 +1,56 @@ +"""Auto-organize use case orchestrating concurrency and progress handling.""" + +from __future__ import annotations + +import asyncio +from typing import Optional, Protocol, Sequence + +from ..model_file_service import AutoOrganizeResult, ModelFileService, ProgressCallback + + +class AutoOrganizeLockProvider(Protocol): + """Minimal protocol for objects exposing auto-organize locking primitives.""" + + def is_auto_organize_running(self) -> bool: + """Return ``True`` when an auto-organize operation is in-flight.""" + + async def get_auto_organize_lock(self) -> asyncio.Lock: + """Return the asyncio lock guarding auto-organize operations.""" + + +class AutoOrganizeInProgressError(RuntimeError): + """Raised when an auto-organize run is already active.""" + + +class AutoOrganizeUseCase: + """Coordinate auto-organize execution behind a shared lock.""" + + def __init__( + self, + *, + file_service: ModelFileService, + lock_provider: AutoOrganizeLockProvider, + ) -> None: + self._file_service = file_service + self._lock_provider = lock_provider + + async def execute( + self, + *, + file_paths: Optional[Sequence[str]] = None, + progress_callback: Optional[ProgressCallback] = None, + ) -> AutoOrganizeResult: + """Run the auto-organize routine guarded by a shared lock.""" + + if self._lock_provider.is_auto_organize_running(): + raise AutoOrganizeInProgressError("Auto-organize is already running") + + lock = await self._lock_provider.get_auto_organize_lock() + if lock.locked(): + raise AutoOrganizeInProgressError("Auto-organize is already running") + + async with lock: + return await self._file_service.auto_organize_models( + file_paths=list(file_paths) if file_paths is not None else None, + progress_callback=progress_callback, + ) diff --git a/py/services/use_cases/bulk_metadata_refresh_use_case.py b/py/services/use_cases/bulk_metadata_refresh_use_case.py new file mode 100644 index 00000000..6a809955 --- /dev/null +++ b/py/services/use_cases/bulk_metadata_refresh_use_case.py @@ -0,0 +1,122 @@ +"""Use case encapsulating the bulk metadata refresh orchestration.""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, Optional, Protocol, Sequence + +from ..metadata_sync_service import MetadataSyncService + + +class MetadataRefreshProgressReporter(Protocol): + """Protocol for progress reporters used during metadata refresh.""" + + async def on_progress(self, payload: Dict[str, Any]) -> None: + """Handle a metadata refresh progress update.""" + + +class BulkMetadataRefreshUseCase: + """Coordinate bulk metadata refreshes with progress emission.""" + + def __init__( + self, + *, + service, + metadata_sync: MetadataSyncService, + settings_service, + logger: Optional[logging.Logger] = None, + ) -> None: + self._service = service + self._metadata_sync = metadata_sync + self._settings = settings_service + self._logger = logger or logging.getLogger(__name__) + + async def execute( + self, + *, + progress_callback: Optional[MetadataRefreshProgressReporter] = None, + ) -> Dict[str, Any]: + """Refresh metadata for all qualifying models.""" + + cache = await self._service.scanner.get_cached_data() + total_models = len(cache.raw_data) + + enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False) + to_process: Sequence[Dict[str, Any]] = [ + model + for model in cache.raw_data + if model.get("sha256") + and (not model.get("civitai") or not model["civitai"].get("id")) + and ( + (enable_metadata_archive_db and not model.get("db_checked", False)) + or (not enable_metadata_archive_db and model.get("from_civitai") is True) + ) + ] + + total_to_process = len(to_process) + processed = 0 + success = 0 + needs_resort = False + + async def emit(status: str, **extra: Any) -> None: + if progress_callback is None: + return + payload = {"status": status, "total": total_to_process, "processed": processed, "success": success} + payload.update(extra) + await progress_callback.on_progress(payload) + + await emit("started") + + for model in to_process: + try: + original_name = model.get("model_name") + result, _ = await self._metadata_sync.fetch_and_update_model( + sha256=model["sha256"], + file_path=model["file_path"], + model_data=model, + update_cache_func=self._service.scanner.update_single_model_cache, + ) + if result: + success += 1 + if original_name != model.get("model_name"): + needs_resort = True + processed += 1 + await emit( + "processing", + processed=processed, + success=success, + current_name=model.get("model_name", "Unknown"), + ) + except Exception as exc: # pragma: no cover - logging path + processed += 1 + self._logger.error( + "Error fetching CivitAI data for %s: %s", + model.get("file_path"), + exc, + ) + + if needs_resort: + await cache.resort() + + await emit("completed", processed=processed, success=success) + + message = ( + "Successfully updated " + f"{success} of {processed} processed {self._service.model_type}s (total: {total_models})" + ) + + return {"success": True, "message": message, "processed": processed, "updated": success, "total": total_models} + + async def execute_with_error_handling( + self, + *, + progress_callback: Optional[MetadataRefreshProgressReporter] = None, + ) -> Dict[str, Any]: + """Wrapper providing progress notification on unexpected failures.""" + + try: + return await self.execute(progress_callback=progress_callback) + except Exception as exc: + if progress_callback is not None: + await progress_callback.on_progress({"status": "error", "error": str(exc)}) + raise diff --git a/py/services/use_cases/download_model_use_case.py b/py/services/use_cases/download_model_use_case.py new file mode 100644 index 00000000..5aa25bda --- /dev/null +++ b/py/services/use_cases/download_model_use_case.py @@ -0,0 +1,37 @@ +"""Use case for scheduling model downloads with consistent error handling.""" + +from __future__ import annotations + +from typing import Any, Dict + +from ..download_coordinator import DownloadCoordinator + + +class DownloadModelValidationError(ValueError): + """Raised when incoming payload validation fails.""" + + +class DownloadModelEarlyAccessError(RuntimeError): + """Raised when the download is gated behind Civitai early access.""" + + +class DownloadModelUseCase: + """Coordinate download scheduling through the coordinator service.""" + + def __init__(self, *, download_coordinator: DownloadCoordinator) -> None: + self._download_coordinator = download_coordinator + + async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """Schedule a download and normalize error conditions.""" + + try: + return await self._download_coordinator.schedule_download(payload) + except ValueError as exc: + raise DownloadModelValidationError(str(exc)) from exc + except Exception as exc: # pragma: no cover - defensive logging path + message = str(exc) + if "401" in message: + raise DownloadModelEarlyAccessError( + "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com." + ) from exc + raise diff --git a/py/services/websocket_progress_callback.py b/py/services/websocket_progress_callback.py index 1a390f30..21423044 100644 --- a/py/services/websocket_progress_callback.py +++ b/py/services/websocket_progress_callback.py @@ -1,11 +1,29 @@ -from typing import Dict, Any +"""Progress callback implementations backed by the shared WebSocket manager.""" + +from typing import Any, Dict, Protocol + from .model_file_service import ProgressCallback from .websocket_manager import ws_manager -class WebSocketProgressCallback(ProgressCallback): - """WebSocket implementation of progress callback""" - +class ProgressReporter(Protocol): + """Protocol representing an async progress callback.""" + async def on_progress(self, progress_data: Dict[str, Any]) -> None: - """Send progress data via WebSocket""" - await ws_manager.broadcast_auto_organize_progress(progress_data) \ No newline at end of file + """Handle a progress update payload.""" + + +class WebSocketProgressCallback(ProgressCallback): + """WebSocket implementation of progress callback.""" + + async def on_progress(self, progress_data: Dict[str, Any]) -> None: + """Send progress data via WebSocket.""" + await ws_manager.broadcast_auto_organize_progress(progress_data) + + +class WebSocketBroadcastCallback: + """Generic WebSocket progress callback broadcasting to all clients.""" + + async def on_progress(self, progress_data: Dict[str, Any]) -> None: + """Send the provided payload to all connected clients.""" + await ws_manager.broadcast(progress_data) diff --git a/tests/routes/test_base_model_routes_smoke.py b/tests/routes/test_base_model_routes_smoke.py index 2b9ed805..25ebaabc 100644 --- a/tests/routes/test_base_model_routes_smoke.py +++ b/tests/routes/test_base_model_routes_smoke.py @@ -28,6 +28,7 @@ spec.loader.exec_module(py_local) sys.modules.setdefault("py_local", py_local) from py_local.routes.base_model_routes import BaseModelRoutes +from py_local.services.model_file_service import AutoOrganizeResult from py_local.services.service_registry import ServiceRegistry from py_local.services.websocket_manager import ws_manager from py_local.utils.routes_common import ExifUtils @@ -222,6 +223,25 @@ def test_download_model_invokes_download_manager( asyncio.run(scenario()) +def test_download_model_requires_identifier(mock_service, download_manager_stub): + async def scenario(): + client = await create_test_client(mock_service) + try: + response = await client.post( + "/api/lm/download-model", + json={"model_root": "/tmp"}, + ) + payload = await response.json() + + assert response.status == 400 + assert payload["success"] is False + assert "Missing required" in payload["error"] + finally: + await client.close() + + asyncio.run(scenario()) + + def test_auto_organize_progress_returns_latest_snapshot(mock_service): async def scenario(): client = await create_test_client(mock_service) @@ -235,5 +255,65 @@ def test_auto_organize_progress_returns_latest_snapshot(mock_service): assert payload == {"success": True, "progress": {"status": "processing", "percent": 50}} finally: await client.close() + + asyncio.run(scenario()) + + +def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.MonkeyPatch): + async def fake_auto_organize(self, file_paths=None, progress_callback=None): + result = AutoOrganizeResult() + result.total = 1 + result.processed = 1 + result.success_count = 1 + result.skipped_count = 0 + result.failure_count = 0 + result.operation_type = "bulk" + if progress_callback is not None: + await progress_callback.on_progress({"type": "auto_organize_progress", "status": "started"}) + await progress_callback.on_progress({"type": "auto_organize_progress", "status": "completed"}) + return result + + monkeypatch.setattr( + py_local.services.model_file_service.ModelFileService, + "auto_organize_models", + fake_auto_organize, + ) + + async def scenario(): + client = await create_test_client(mock_service) + try: + response = await client.post("/api/lm/test-models/auto-organize", json={"file_paths": []}) + payload = await response.json() + + assert response.status == 200 + assert payload["success"] is True + + progress = ws_manager.get_auto_organize_progress() + assert progress is not None + assert progress["status"] == "completed" + finally: + await client.close() + + asyncio.run(scenario()) + + +def test_auto_organize_conflict_when_running(mock_service): + async def scenario(): + client = await create_test_client(mock_service) + try: + await ws_manager.broadcast_auto_organize_progress( + {"type": "auto_organize_progress", "status": "started"} + ) + + response = await client.post("/api/lm/test-models/auto-organize") + payload = await response.json() + + assert response.status == 409 + assert payload == { + "success": False, + "error": "Auto-organize is already running. Please wait for it to complete.", + } + finally: + await client.close() asyncio.run(scenario()) diff --git a/tests/services/test_use_cases.py b/tests/services/test_use_cases.py new file mode 100644 index 00000000..64057fc6 --- /dev/null +++ b/tests/services/test_use_cases.py @@ -0,0 +1,191 @@ +import asyncio +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import pytest + +from py_local.services.model_file_service import AutoOrganizeResult +from py_local.services.use_cases import ( + AutoOrganizeInProgressError, + AutoOrganizeUseCase, + BulkMetadataRefreshUseCase, + DownloadModelEarlyAccessError, + DownloadModelUseCase, + DownloadModelValidationError, +) +from tests.conftest import MockModelService, MockScanner + + +class StubLockProvider: + def __init__(self) -> None: + self._lock = asyncio.Lock() + self.running = False + + def is_auto_organize_running(self) -> bool: + return self.running + + async def get_auto_organize_lock(self) -> asyncio.Lock: + return self._lock + + +class StubFileService: + def __init__(self) -> None: + self.calls: List[Dict[str, Any]] = [] + + async def auto_organize_models( + self, + *, + file_paths: Optional[List[str]] = None, + progress_callback=None, + ) -> AutoOrganizeResult: + result = AutoOrganizeResult() + result.total = len(file_paths or []) + self.calls.append({"file_paths": file_paths, "progress_callback": progress_callback}) + return result + + +class StubMetadataSync: + def __init__(self) -> None: + self.calls: List[Dict[str, Any]] = [] + + async def fetch_and_update_model(self, **kwargs: Any): + self.calls.append(kwargs) + model_data = kwargs["model_data"] + model_data["model_name"] = model_data.get("model_name", "model") + "-updated" + return True, None + + +@dataclass +class StubSettings: + enable_metadata_archive_db: bool = False + + def get(self, key: str, default: Any = None) -> Any: + if key == "enable_metadata_archive_db": + return self.enable_metadata_archive_db + return default + + +class ProgressCollector: + def __init__(self) -> None: + self.events: List[Dict[str, Any]] = [] + + async def on_progress(self, payload: Dict[str, Any]) -> None: + self.events.append(payload) + + +class StubDownloadCoordinator: + def __init__(self, *, error: Optional[str] = None) -> None: + self.error = error + self.payloads: List[Dict[str, Any]] = [] + + async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]: + self.payloads.append(payload) + if self.error == "validation": + raise ValueError("Missing required parameter: Please provide either 'model_id' or 'model_version_id'") + if self.error == "401": + raise RuntimeError("401 Unauthorized") + return {"success": True, "download_id": "abc123"} + + +async def test_auto_organize_use_case_executes_with_lock() -> None: + file_service = StubFileService() + lock_provider = StubLockProvider() + use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider) + + result = await use_case.execute(file_paths=["model1"], progress_callback=None) + + assert isinstance(result, AutoOrganizeResult) + assert file_service.calls[0]["file_paths"] == ["model1"] + + +async def test_auto_organize_use_case_rejects_when_running() -> None: + file_service = StubFileService() + lock_provider = StubLockProvider() + lock_provider.running = True + use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider) + + with pytest.raises(AutoOrganizeInProgressError): + await use_case.execute(file_paths=None, progress_callback=None) + + +async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None: + scanner = MockScanner() + scanner._cache.raw_data = [ + { + "file_path": "model1.safetensors", + "sha256": "hash", + "from_civitai": True, + "model_name": "Demo", + } + ] + service = MockModelService(scanner) + metadata_sync = StubMetadataSync() + settings = StubSettings() + progress = ProgressCollector() + + use_case = BulkMetadataRefreshUseCase( + service=service, + metadata_sync=metadata_sync, + settings_service=settings, + logger=logging.getLogger("test"), + ) + + result = await use_case.execute_with_error_handling(progress_callback=progress) + + assert result["success"] is True + assert progress.events[0]["status"] == "started" + assert progress.events[-1]["status"] == "completed" + assert metadata_sync.calls + assert scanner._cache.resort_calls == 1 + + +async def test_bulk_metadata_refresh_reports_errors() -> None: + class FailingScanner(MockScanner): + async def get_cached_data(self, force_refresh: bool = False): + raise RuntimeError("boom") + + service = MockModelService(FailingScanner()) + metadata_sync = StubMetadataSync() + settings = StubSettings() + progress = ProgressCollector() + + use_case = BulkMetadataRefreshUseCase( + service=service, + metadata_sync=metadata_sync, + settings_service=settings, + logger=logging.getLogger("test"), + ) + + with pytest.raises(RuntimeError): + await use_case.execute_with_error_handling(progress_callback=progress) + + assert progress.events + assert progress.events[-1]["status"] == "error" + assert progress.events[-1]["error"] == "boom" + + +async def test_download_model_use_case_raises_validation_error() -> None: + coordinator = StubDownloadCoordinator(error="validation") + use_case = DownloadModelUseCase(download_coordinator=coordinator) + + with pytest.raises(DownloadModelValidationError): + await use_case.execute({}) + + +async def test_download_model_use_case_raises_early_access() -> None: + coordinator = StubDownloadCoordinator(error="401") + use_case = DownloadModelUseCase(download_coordinator=coordinator) + + with pytest.raises(DownloadModelEarlyAccessError): + await use_case.execute({"model_id": 1}) + + +async def test_download_model_use_case_returns_result() -> None: + coordinator = StubDownloadCoordinator() + use_case = DownloadModelUseCase(download_coordinator=coordinator) + + result = await use_case.execute({"model_id": 1}) + + assert result["success"] is True + assert result["download_id"] == "abc123"