Merge pull request #452 from willmiao/codex/create-application-level-use-case-services

feat(routes): extract orchestration use cases
This commit is contained in:
pixelpaws
2025-09-22 05:27:19 +08:00
committed by GitHub
9 changed files with 609 additions and 112 deletions

View File

@@ -19,7 +19,15 @@ from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings as default_settings
from ..services.tag_update_service import TagUpdateService
from ..services.websocket_manager import ws_manager as default_ws_manager
from ..services.websocket_progress_callback import WebSocketProgressCallback
from ..services.use_cases import (
AutoOrganizeUseCase,
BulkMetadataRefreshUseCase,
DownloadModelUseCase,
)
from ..services.websocket_progress_callback import (
WebSocketBroadcastCallback,
WebSocketProgressCallback,
)
from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager
from ..utils.routes_common import ModelRouteUtils
@@ -68,6 +76,7 @@ class BaseModelRoutes(ABC):
self.model_file_service: ModelFileService | None = None
self.model_move_service: ModelMoveService | None = None
self.websocket_progress_callback = WebSocketProgressCallback()
self.metadata_progress_callback = WebSocketBroadcastCallback()
self._handler_set: ModelHandlerSet | None = None
self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None
@@ -132,11 +141,19 @@ class BaseModelRoutes(ABC):
tag_update_service=self._tag_update_service,
)
query = ModelQueryHandler(service=service, logger=logger)
download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator)
download = ModelDownloadHandler(
ws_manager=self._ws_manager,
logger=logger,
download_use_case=download_use_case,
download_coordinator=self._download_coordinator,
)
metadata_refresh_use_case = BulkMetadataRefreshUseCase(
service=service,
metadata_sync=self._metadata_sync_service,
settings_service=self._settings,
logger=logger,
)
civitai = ModelCivitaiHandler(
service=service,
settings_service=self._settings,
@@ -147,10 +164,16 @@ class BaseModelRoutes(ABC):
expected_model_types=self._get_expected_model_types,
find_model_file=self._find_model_file,
metadata_sync=self._metadata_sync_service,
metadata_refresh_use_case=metadata_refresh_use_case,
metadata_progress_callback=self.metadata_progress_callback,
)
move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
auto_organize = ModelAutoOrganizeHandler(
auto_organize_use_case = AutoOrganizeUseCase(
file_service=self._ensure_file_service(),
lock_provider=self._ws_manager,
)
auto_organize = ModelAutoOrganizeHandler(
use_case=auto_organize_use_case,
progress_callback=self.websocket_progress_callback,
ws_manager=self._ws_manager,
logger=logger,

View File

@@ -14,10 +14,19 @@ import jinja2
from ...config import config
from ...services.download_coordinator import DownloadCoordinator
from ...services.metadata_sync_service import MetadataSyncService
from ...services.model_file_service import ModelFileService, ModelMoveService
from ...services.model_file_service import ModelMoveService
from ...services.preview_asset_service import PreviewAssetService
from ...services.settings_manager import SettingsManager
from ...services.tag_update_service import TagUpdateService
from ...services.use_cases import (
AutoOrganizeInProgressError,
AutoOrganizeUseCase,
BulkMetadataRefreshUseCase,
DownloadModelEarlyAccessError,
DownloadModelUseCase,
DownloadModelValidationError,
MetadataRefreshProgressReporter,
)
from ...services.websocket_manager import WebSocketManager
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...utils.file_utils import calculate_sha256
@@ -600,33 +609,29 @@ class ModelDownloadHandler:
*,
ws_manager: WebSocketManager,
logger: logging.Logger,
download_use_case: DownloadModelUseCase,
download_coordinator: DownloadCoordinator,
) -> None:
self._ws_manager = ws_manager
self._logger = logger
self._download_use_case = download_use_case
self._download_coordinator = download_coordinator
async def download_model(self, request: web.Request) -> web.Response:
try:
payload = await request.json()
result = await self._download_coordinator.schedule_download(payload)
result = await self._download_use_case.execute(payload)
if not result.get("success", False):
return web.json_response(result, status=500)
return web.json_response(result)
except ValueError as exc:
except DownloadModelValidationError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except DownloadModelEarlyAccessError as exc:
self._logger.warning("Early access error: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=401)
except Exception as exc:
error_message = str(exc)
if "401" in error_message:
self._logger.warning("Early access error (401): %s", error_message)
return web.json_response(
{
"success": False,
"error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.",
},
status=401,
)
self._logger.error("Error downloading model: %s", error_message)
self._logger.error("Error downloading model: %s", error_message, exc_info=True)
return web.json_response({"success": False, "error": error_message}, status=500)
async def download_model_get(self, request: web.Request) -> web.Response:
@@ -653,12 +658,15 @@ class ModelDownloadHandler:
future.set_result(data)
mock_request = type("MockRequest", (), {"json": lambda self=None: future})()
result = await self._download_coordinator.schedule_download(data)
result = await self._download_use_case.execute(data)
if not result.get("success", False):
return web.json_response(result, status=500)
return web.json_response(result)
except ValueError as exc:
except DownloadModelValidationError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except DownloadModelEarlyAccessError as exc:
self._logger.warning("Early access error: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=401)
except Exception as exc:
self._logger.error("Error downloading model via GET: %s", exc, exc_info=True)
return web.Response(status=500, text=str(exc))
@@ -703,6 +711,8 @@ class ModelCivitaiHandler:
expected_model_types: Callable[[], str],
find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]],
metadata_sync: MetadataSyncService,
metadata_refresh_use_case: BulkMetadataRefreshUseCase,
metadata_progress_callback: MetadataRefreshProgressReporter,
) -> None:
self._service = service
self._settings = settings_service
@@ -713,75 +723,16 @@ class ModelCivitaiHandler:
self._expected_model_types = expected_model_types
self._find_model_file = find_model_file
self._metadata_sync = metadata_sync
self._metadata_refresh_use_case = metadata_refresh_use_case
self._metadata_progress_callback = metadata_progress_callback
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
try:
cache = await self._service.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
success = 0
needs_resort = False
enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False)
to_process = [
model
for model in cache.raw_data
if model.get("sha256")
and (not model.get("civitai") or not model["civitai"].get("id"))
and (
(enable_metadata_archive_db and not model.get("db_checked", False))
or (not enable_metadata_archive_db and model.get("from_civitai") is True)
)
]
total_to_process = len(to_process)
await self._ws_manager.broadcast({
"status": "started",
"total": total_to_process,
"processed": 0,
"success": 0,
})
for model in to_process:
try:
original_name = model.get("model_name")
result, error = await self._metadata_sync.fetch_and_update_model(
sha256=model["sha256"],
file_path=model["file_path"],
model_data=model,
update_cache_func=self._service.scanner.update_single_model_cache,
)
if result:
success += 1
if original_name != model.get("model_name"):
needs_resort = True
processed += 1
await self._ws_manager.broadcast({
"status": "processing",
"total": total_to_process,
"processed": processed,
"success": success,
"current_name": model.get("model_name", "Unknown"),
})
except Exception as exc: # pragma: no cover - logging path
self._logger.error("Error fetching CivitAI data for %s: %s", model["file_path"], exc)
if needs_resort:
await cache.resort()
await self._ws_manager.broadcast({
"status": "completed",
"total": total_to_process,
"processed": processed,
"success": success,
})
return web.json_response({
"success": True,
"message": f"Successfully updated {success} of {processed} processed {self._service.model_type}s (total: {total})",
})
result = await self._metadata_refresh_use_case.execute_with_error_handling(
progress_callback=self._metadata_progress_callback
)
return web.json_response(result)
except Exception as exc:
await self._ws_manager.broadcast({"status": "error", "error": str(exc)})
self._logger.error("Error in fetch_all_civitai for %ss: %s", self._service.model_type, exc)
return web.Response(text=str(exc), status=500)
@@ -887,31 +838,18 @@ class ModelAutoOrganizeHandler:
def __init__(
self,
*,
file_service: ModelFileService,
use_case: AutoOrganizeUseCase,
progress_callback: WebSocketProgressCallback,
ws_manager: WebSocketManager,
logger: logging.Logger,
) -> None:
self._file_service = file_service
self._use_case = use_case
self._progress_callback = progress_callback
self._ws_manager = ws_manager
self._logger = logger
async def auto_organize_models(self, request: web.Request) -> web.Response:
try:
if self._ws_manager.is_auto_organize_running():
return web.json_response(
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
status=409,
)
auto_organize_lock = await self._ws_manager.get_auto_organize_lock()
if auto_organize_lock.locked():
return web.json_response(
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
status=409,
)
file_paths = None
if request.method == "POST":
try:
@@ -920,17 +858,24 @@ class ModelAutoOrganizeHandler:
except Exception: # pragma: no cover - permissive path
pass
async with auto_organize_lock:
result = await self._file_service.auto_organize_models(
file_paths=file_paths,
progress_callback=self._progress_callback,
)
return web.json_response(result.to_dict())
result = await self._use_case.execute(
file_paths=file_paths,
progress_callback=self._progress_callback,
)
return web.json_response(result.to_dict())
except AutoOrganizeInProgressError:
return web.json_response(
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
status=409,
)
except Exception as exc:
self._logger.error("Error in auto_organize_models: %s", exc, exc_info=True)
await self._ws_manager.broadcast_auto_organize_progress(
{"type": "auto_organize_progress", "status": "error", "error": str(exc)}
)
try:
await self._progress_callback.on_progress(
{"type": "auto_organize_progress", "status": "error", "error": str(exc)}
)
except Exception: # pragma: no cover - defensive reporting
pass
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_auto_organize_progress(self, request: web.Request) -> web.Response:

View File

@@ -0,0 +1,25 @@
"""Application-level orchestration services for model routes."""
from .auto_organize_use_case import (
AutoOrganizeInProgressError,
AutoOrganizeUseCase,
)
from .bulk_metadata_refresh_use_case import (
BulkMetadataRefreshUseCase,
MetadataRefreshProgressReporter,
)
from .download_model_use_case import (
DownloadModelEarlyAccessError,
DownloadModelUseCase,
DownloadModelValidationError,
)
__all__ = [
"AutoOrganizeInProgressError",
"AutoOrganizeUseCase",
"BulkMetadataRefreshUseCase",
"MetadataRefreshProgressReporter",
"DownloadModelEarlyAccessError",
"DownloadModelUseCase",
"DownloadModelValidationError",
]

View File

@@ -0,0 +1,56 @@
"""Auto-organize use case orchestrating concurrency and progress handling."""
from __future__ import annotations
import asyncio
from typing import Optional, Protocol, Sequence
from ..model_file_service import AutoOrganizeResult, ModelFileService, ProgressCallback
class AutoOrganizeLockProvider(Protocol):
"""Minimal protocol for objects exposing auto-organize locking primitives."""
def is_auto_organize_running(self) -> bool:
"""Return ``True`` when an auto-organize operation is in-flight."""
async def get_auto_organize_lock(self) -> asyncio.Lock:
"""Return the asyncio lock guarding auto-organize operations."""
class AutoOrganizeInProgressError(RuntimeError):
"""Raised when an auto-organize run is already active."""
class AutoOrganizeUseCase:
"""Coordinate auto-organize execution behind a shared lock."""
def __init__(
self,
*,
file_service: ModelFileService,
lock_provider: AutoOrganizeLockProvider,
) -> None:
self._file_service = file_service
self._lock_provider = lock_provider
async def execute(
self,
*,
file_paths: Optional[Sequence[str]] = None,
progress_callback: Optional[ProgressCallback] = None,
) -> AutoOrganizeResult:
"""Run the auto-organize routine guarded by a shared lock."""
if self._lock_provider.is_auto_organize_running():
raise AutoOrganizeInProgressError("Auto-organize is already running")
lock = await self._lock_provider.get_auto_organize_lock()
if lock.locked():
raise AutoOrganizeInProgressError("Auto-organize is already running")
async with lock:
return await self._file_service.auto_organize_models(
file_paths=list(file_paths) if file_paths is not None else None,
progress_callback=progress_callback,
)

View File

@@ -0,0 +1,122 @@
"""Use case encapsulating the bulk metadata refresh orchestration."""
from __future__ import annotations
import logging
from typing import Any, Dict, Optional, Protocol, Sequence
from ..metadata_sync_service import MetadataSyncService
class MetadataRefreshProgressReporter(Protocol):
    """Structural type for progress reporters used during metadata refresh."""

    async def on_progress(self, payload: Dict[str, Any]) -> None:
        """Consume a single metadata refresh progress update."""


class BulkMetadataRefreshUseCase:
    """Coordinate bulk metadata refreshes with progress emission."""

    def __init__(
        self,
        *,
        service,
        metadata_sync: MetadataSyncService,
        settings_service,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        # NOTE(review): ``service`` and ``settings_service`` are duck-typed here;
        # service must expose ``scanner`` and ``model_type``, settings must expose ``get``.
        self._service = service
        self._metadata_sync = metadata_sync
        self._settings = settings_service
        self._logger = logger or logging.getLogger(__name__)

    def _eligible_models(self, raw_data: Sequence[Dict[str, Any]]) -> list:
        """Pick models with a hash but no CivitAI id that still qualify for refresh."""
        archive_db_enabled = self._settings.get("enable_metadata_archive_db", False)
        candidates = []
        for model in raw_data:
            if not model.get("sha256"):
                continue
            civitai = model.get("civitai")
            if civitai and civitai.get("id"):
                continue
            # With the archive DB enabled we retry anything not yet checked there;
            # otherwise we only touch models explicitly flagged as from CivitAI.
            if archive_db_enabled:
                if not model.get("db_checked", False):
                    candidates.append(model)
            elif model.get("from_civitai") is True:
                candidates.append(model)
        return candidates

    async def execute(
        self,
        *,
        progress_callback: Optional[MetadataRefreshProgressReporter] = None,
    ) -> Dict[str, Any]:
        """Refresh metadata for all qualifying models, reporting progress as we go."""
        cache = await self._service.scanner.get_cached_data()
        catalog_size = len(cache.raw_data)
        queue = self._eligible_models(cache.raw_data)
        pending_total = len(queue)
        processed = 0
        updated = 0
        resort_needed = False

        async def notify(status: str, **details: Any) -> None:
            # Progress reporting is optional; do nothing when no reporter is wired.
            if progress_callback is None:
                return
            event = {
                "status": status,
                "total": pending_total,
                "processed": processed,
                "success": updated,
            }
            event.update(details)
            await progress_callback.on_progress(event)

        await notify("started")
        for model in queue:
            try:
                name_before = model.get("model_name")
                ok, _ = await self._metadata_sync.fetch_and_update_model(
                    sha256=model["sha256"],
                    file_path=model["file_path"],
                    model_data=model,
                    update_cache_func=self._service.scanner.update_single_model_cache,
                )
                if ok:
                    updated += 1
                    # A rename invalidates the cached ordering; resort once at the end.
                    if name_before != model.get("model_name"):
                        resort_needed = True
                processed += 1
                await notify(
                    "processing",
                    processed=processed,
                    success=updated,
                    current_name=model.get("model_name", "Unknown"),
                )
            except Exception as exc:  # pragma: no cover - logging path
                processed += 1
                self._logger.error(
                    "Error fetching CivitAI data for %s: %s",
                    model.get("file_path"),
                    exc,
                )
        if resort_needed:
            await cache.resort()
        await notify("completed", processed=processed, success=updated)
        message = (
            "Successfully updated "
            f"{updated} of {processed} processed {self._service.model_type}s (total: {catalog_size})"
        )
        return {
            "success": True,
            "message": message,
            "processed": processed,
            "updated": updated,
            "total": catalog_size,
        }

    async def execute_with_error_handling(
        self,
        *,
        progress_callback: Optional[MetadataRefreshProgressReporter] = None,
    ) -> Dict[str, Any]:
        """Run :meth:`execute`, emitting an ``error`` event before re-raising failures."""
        try:
            return await self.execute(progress_callback=progress_callback)
        except Exception as exc:
            if progress_callback is not None:
                await progress_callback.on_progress({"status": "error", "error": str(exc)})
            raise

View File

@@ -0,0 +1,37 @@
"""Use case for scheduling model downloads with consistent error handling."""
from __future__ import annotations
from typing import Any, Dict
from ..download_coordinator import DownloadCoordinator
class DownloadModelValidationError(ValueError):
    """Raised when incoming payload validation fails."""


class DownloadModelEarlyAccessError(RuntimeError):
    """Raised when the download is gated behind Civitai early access."""


# User-facing explanation surfaced when the coordinator reports an HTTP 401.
_EARLY_ACCESS_MESSAGE = (
    "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
)


class DownloadModelUseCase:
    """Schedule model downloads through the coordinator with normalized errors."""

    def __init__(self, *, download_coordinator: DownloadCoordinator) -> None:
        self._download_coordinator = download_coordinator

    async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Schedule ``payload`` for download, translating known failure modes.

        Raises:
            DownloadModelValidationError: when the coordinator rejects the payload.
            DownloadModelEarlyAccessError: when the failure message mentions 401.
        """
        try:
            return await self._download_coordinator.schedule_download(payload)
        except ValueError as exc:
            raise DownloadModelValidationError(str(exc)) from exc
        except Exception as exc:  # pragma: no cover - defensive logging path
            # A "401" anywhere in the message is treated as Civitai's early-access gate.
            if "401" in str(exc):
                raise DownloadModelEarlyAccessError(_EARLY_ACCESS_MESSAGE) from exc
            raise

View File

@@ -1,11 +1,29 @@
from typing import Dict, Any
"""Progress callback implementations backed by the shared WebSocket manager."""
from typing import Any, Dict, Protocol
from .model_file_service import ProgressCallback
from .websocket_manager import ws_manager
class WebSocketProgressCallback(ProgressCallback):
"""WebSocket implementation of progress callback"""
class ProgressReporter(Protocol):
"""Protocol representing an async progress callback."""
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
"""Send progress data via WebSocket"""
await ws_manager.broadcast_auto_organize_progress(progress_data)
"""Handle a progress update payload."""
class WebSocketProgressCallback(ProgressCallback):
"""WebSocket implementation of progress callback."""
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
"""Send progress data via WebSocket."""
await ws_manager.broadcast_auto_organize_progress(progress_data)
class WebSocketBroadcastCallback:
"""Generic WebSocket progress callback broadcasting to all clients."""
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
"""Send the provided payload to all connected clients."""
await ws_manager.broadcast(progress_data)

View File

@@ -28,6 +28,7 @@ spec.loader.exec_module(py_local)
sys.modules.setdefault("py_local", py_local)
from py_local.routes.base_model_routes import BaseModelRoutes
from py_local.services.model_file_service import AutoOrganizeResult
from py_local.services.service_registry import ServiceRegistry
from py_local.services.websocket_manager import ws_manager
from py_local.utils.routes_common import ExifUtils
@@ -222,6 +223,25 @@ def test_download_model_invokes_download_manager(
asyncio.run(scenario())
def test_download_model_requires_identifier(mock_service, download_manager_stub):
    """A download request without a model identifier must yield a 400 response."""

    async def scenario():
        client = await create_test_client(mock_service)
        try:
            response = await client.post(
                "/api/lm/download-model",
                json={"model_root": "/tmp"},
            )
            body = await response.json()
            assert response.status == 400
            assert body["success"] is False
            assert "Missing required" in body["error"]
        finally:
            await client.close()

    asyncio.run(scenario())
def test_auto_organize_progress_returns_latest_snapshot(mock_service):
async def scenario():
client = await create_test_client(mock_service)
@@ -235,5 +255,65 @@ def test_auto_organize_progress_returns_latest_snapshot(mock_service):
assert payload == {"success": True, "progress": {"status": "processing", "percent": 50}}
finally:
await client.close()
asyncio.run(scenario())
def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.MonkeyPatch):
    """The auto-organize route should publish progress events via the WS manager."""

    async def fake_auto_organize(self, file_paths=None, progress_callback=None):
        outcome = AutoOrganizeResult()
        outcome.total = 1
        outcome.processed = 1
        outcome.success_count = 1
        outcome.skipped_count = 0
        outcome.failure_count = 0
        outcome.operation_type = "bulk"
        if progress_callback is not None:
            await progress_callback.on_progress({"type": "auto_organize_progress", "status": "started"})
            await progress_callback.on_progress({"type": "auto_organize_progress", "status": "completed"})
        return outcome

    monkeypatch.setattr(
        py_local.services.model_file_service.ModelFileService,
        "auto_organize_models",
        fake_auto_organize,
    )

    async def scenario():
        client = await create_test_client(mock_service)
        try:
            response = await client.post("/api/lm/test-models/auto-organize", json={"file_paths": []})
            body = await response.json()
            assert response.status == 200
            assert body["success"] is True
            snapshot = ws_manager.get_auto_organize_progress()
            assert snapshot is not None
            assert snapshot["status"] == "completed"
        finally:
            await client.close()

    asyncio.run(scenario())
def test_auto_organize_conflict_when_running(mock_service):
    """A second auto-organize request while one is active must return 409."""

    async def scenario():
        client = await create_test_client(mock_service)
        try:
            await ws_manager.broadcast_auto_organize_progress(
                {"type": "auto_organize_progress", "status": "started"}
            )
            response = await client.post("/api/lm/test-models/auto-organize")
            body = await response.json()
            assert response.status == 409
            assert body == {
                "success": False,
                "error": "Auto-organize is already running. Please wait for it to complete.",
            }
        finally:
            await client.close()

    asyncio.run(scenario())

View File

@@ -0,0 +1,191 @@
import asyncio
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import pytest
from py_local.services.model_file_service import AutoOrganizeResult
from py_local.services.use_cases import (
AutoOrganizeInProgressError,
AutoOrganizeUseCase,
BulkMetadataRefreshUseCase,
DownloadModelEarlyAccessError,
DownloadModelUseCase,
DownloadModelValidationError,
)
from tests.conftest import MockModelService, MockScanner
class StubLockProvider:
    """Lock-provider double with a directly toggleable ``running`` flag."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self.running = False

    def is_auto_organize_running(self) -> bool:
        return self.running

    async def get_auto_organize_lock(self) -> asyncio.Lock:
        return self._lock
class StubFileService:
    """File-service double that records each auto-organize invocation."""

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    async def auto_organize_models(
        self,
        *,
        file_paths: Optional[List[str]] = None,
        progress_callback=None,
    ) -> AutoOrganizeResult:
        outcome = AutoOrganizeResult()
        outcome.total = len(file_paths or [])
        self.calls.append({"file_paths": file_paths, "progress_callback": progress_callback})
        return outcome
class StubMetadataSync:
    """Metadata-sync double that always succeeds and renames the model."""

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    async def fetch_and_update_model(self, **kwargs: Any):
        self.calls.append(kwargs)
        record = kwargs["model_data"]
        # Mutate the name so callers can observe the resort-triggering rename.
        record["model_name"] = record.get("model_name", "model") + "-updated"
        return True, None
@dataclass
class StubSettings:
    """Settings double exposing only the metadata-archive toggle."""

    enable_metadata_archive_db: bool = False

    def get(self, key: str, default: Any = None) -> Any:
        # Only the archive toggle is modeled; every other key falls through.
        if key == "enable_metadata_archive_db":
            return self.enable_metadata_archive_db
        return default
class ProgressCollector:
    """Progress-reporter double that simply records each payload."""

    def __init__(self) -> None:
        self.events: List[Dict[str, Any]] = []

    async def on_progress(self, payload: Dict[str, Any]) -> None:
        self.events.append(payload)
class StubDownloadCoordinator:
    """Coordinator double able to simulate validation and 401 failures."""

    def __init__(self, *, error: Optional[str] = None) -> None:
        self.error = error
        self.payloads: List[Dict[str, Any]] = []

    async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        self.payloads.append(payload)
        if self.error == "validation":
            raise ValueError("Missing required parameter: Please provide either 'model_id' or 'model_version_id'")
        if self.error == "401":
            raise RuntimeError("401 Unauthorized")
        return {"success": True, "download_id": "abc123"}
async def test_auto_organize_use_case_executes_with_lock() -> None:
    """The use case should run the file service and record the requested paths."""
    organizer = StubFileService()
    use_case = AutoOrganizeUseCase(file_service=organizer, lock_provider=StubLockProvider())

    outcome = await use_case.execute(file_paths=["model1"], progress_callback=None)

    assert isinstance(outcome, AutoOrganizeResult)
    assert organizer.calls[0]["file_paths"] == ["model1"]
async def test_auto_organize_use_case_rejects_when_running() -> None:
    """An already-running flag must make the use case raise instead of organize."""
    locks = StubLockProvider()
    locks.running = True
    use_case = AutoOrganizeUseCase(file_service=StubFileService(), lock_provider=locks)

    with pytest.raises(AutoOrganizeInProgressError):
        await use_case.execute(file_paths=None, progress_callback=None)
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
    """A successful refresh should emit start/complete events and resort the cache."""
    scanner = MockScanner()
    scanner._cache.raw_data = [
        {
            "file_path": "model1.safetensors",
            "sha256": "hash",
            "from_civitai": True,
            "model_name": "Demo",
        }
    ]
    metadata_sync = StubMetadataSync()
    collector = ProgressCollector()
    use_case = BulkMetadataRefreshUseCase(
        service=MockModelService(scanner),
        metadata_sync=metadata_sync,
        settings_service=StubSettings(),
        logger=logging.getLogger("test"),
    )

    result = await use_case.execute_with_error_handling(progress_callback=collector)

    assert result["success"] is True
    assert collector.events[0]["status"] == "started"
    assert collector.events[-1]["status"] == "completed"
    assert metadata_sync.calls
    # The stub renames the model, which must trigger exactly one cache resort.
    assert scanner._cache.resort_calls == 1
async def test_bulk_metadata_refresh_reports_errors() -> None:
    """An unexpected failure should emit an error event before re-raising."""

    class FailingScanner(MockScanner):
        async def get_cached_data(self, force_refresh: bool = False):
            raise RuntimeError("boom")

    collector = ProgressCollector()
    use_case = BulkMetadataRefreshUseCase(
        service=MockModelService(FailingScanner()),
        metadata_sync=StubMetadataSync(),
        settings_service=StubSettings(),
        logger=logging.getLogger("test"),
    )

    with pytest.raises(RuntimeError):
        await use_case.execute_with_error_handling(progress_callback=collector)

    assert collector.events
    assert collector.events[-1]["status"] == "error"
    assert collector.events[-1]["error"] == "boom"
async def test_download_model_use_case_raises_validation_error() -> None:
    """Coordinator ``ValueError``s surface as ``DownloadModelValidationError``."""
    use_case = DownloadModelUseCase(
        download_coordinator=StubDownloadCoordinator(error="validation")
    )
    with pytest.raises(DownloadModelValidationError):
        await use_case.execute({})
async def test_download_model_use_case_raises_early_access() -> None:
    """401-style failures surface as ``DownloadModelEarlyAccessError``."""
    use_case = DownloadModelUseCase(
        download_coordinator=StubDownloadCoordinator(error="401")
    )
    with pytest.raises(DownloadModelEarlyAccessError):
        await use_case.execute({"model_id": 1})
async def test_download_model_use_case_returns_result() -> None:
    """A clean schedule should pass the coordinator result straight through."""
    use_case = DownloadModelUseCase(download_coordinator=StubDownloadCoordinator())

    outcome = await use_case.execute({"model_id": 1})

    assert outcome["success"] is True
    assert outcome["download_id"] == "abc123"