feat(routes): extract orchestration use cases

This commit is contained in:
pixelpaws
2025-09-22 05:25:27 +08:00
parent 8cf99dd928
commit c063854b51
9 changed files with 609 additions and 112 deletions

View File

@@ -0,0 +1,191 @@
import asyncio
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import pytest
from py_local.services.model_file_service import AutoOrganizeResult
from py_local.services.use_cases import (
AutoOrganizeInProgressError,
AutoOrganizeUseCase,
BulkMetadataRefreshUseCase,
DownloadModelEarlyAccessError,
DownloadModelUseCase,
DownloadModelValidationError,
)
from tests.conftest import MockModelService, MockScanner
class StubLockProvider:
def __init__(self) -> None:
self._lock = asyncio.Lock()
self.running = False
def is_auto_organize_running(self) -> bool:
return self.running
async def get_auto_organize_lock(self) -> asyncio.Lock:
return self._lock
class StubFileService:
def __init__(self) -> None:
self.calls: List[Dict[str, Any]] = []
async def auto_organize_models(
self,
*,
file_paths: Optional[List[str]] = None,
progress_callback=None,
) -> AutoOrganizeResult:
result = AutoOrganizeResult()
result.total = len(file_paths or [])
self.calls.append({"file_paths": file_paths, "progress_callback": progress_callback})
return result
class StubMetadataSync:
def __init__(self) -> None:
self.calls: List[Dict[str, Any]] = []
async def fetch_and_update_model(self, **kwargs: Any):
self.calls.append(kwargs)
model_data = kwargs["model_data"]
model_data["model_name"] = model_data.get("model_name", "model") + "-updated"
return True, None
@dataclass
class StubSettings:
enable_metadata_archive_db: bool = False
def get(self, key: str, default: Any = None) -> Any:
if key == "enable_metadata_archive_db":
return self.enable_metadata_archive_db
return default
class ProgressCollector:
def __init__(self) -> None:
self.events: List[Dict[str, Any]] = []
async def on_progress(self, payload: Dict[str, Any]) -> None:
self.events.append(payload)
class StubDownloadCoordinator:
def __init__(self, *, error: Optional[str] = None) -> None:
self.error = error
self.payloads: List[Dict[str, Any]] = []
async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
self.payloads.append(payload)
if self.error == "validation":
raise ValueError("Missing required parameter: Please provide either 'model_id' or 'model_version_id'")
if self.error == "401":
raise RuntimeError("401 Unauthorized")
return {"success": True, "download_id": "abc123"}
async def test_auto_organize_use_case_executes_with_lock() -> None:
file_service = StubFileService()
lock_provider = StubLockProvider()
use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider)
result = await use_case.execute(file_paths=["model1"], progress_callback=None)
assert isinstance(result, AutoOrganizeResult)
assert file_service.calls[0]["file_paths"] == ["model1"]
async def test_auto_organize_use_case_rejects_when_running() -> None:
file_service = StubFileService()
lock_provider = StubLockProvider()
lock_provider.running = True
use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider)
with pytest.raises(AutoOrganizeInProgressError):
await use_case.execute(file_paths=None, progress_callback=None)
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
scanner = MockScanner()
scanner._cache.raw_data = [
{
"file_path": "model1.safetensors",
"sha256": "hash",
"from_civitai": True,
"model_name": "Demo",
}
]
service = MockModelService(scanner)
metadata_sync = StubMetadataSync()
settings = StubSettings()
progress = ProgressCollector()
use_case = BulkMetadataRefreshUseCase(
service=service,
metadata_sync=metadata_sync,
settings_service=settings,
logger=logging.getLogger("test"),
)
result = await use_case.execute_with_error_handling(progress_callback=progress)
assert result["success"] is True
assert progress.events[0]["status"] == "started"
assert progress.events[-1]["status"] == "completed"
assert metadata_sync.calls
assert scanner._cache.resort_calls == 1
async def test_bulk_metadata_refresh_reports_errors() -> None:
class FailingScanner(MockScanner):
async def get_cached_data(self, force_refresh: bool = False):
raise RuntimeError("boom")
service = MockModelService(FailingScanner())
metadata_sync = StubMetadataSync()
settings = StubSettings()
progress = ProgressCollector()
use_case = BulkMetadataRefreshUseCase(
service=service,
metadata_sync=metadata_sync,
settings_service=settings,
logger=logging.getLogger("test"),
)
with pytest.raises(RuntimeError):
await use_case.execute_with_error_handling(progress_callback=progress)
assert progress.events
assert progress.events[-1]["status"] == "error"
assert progress.events[-1]["error"] == "boom"
async def test_download_model_use_case_raises_validation_error() -> None:
coordinator = StubDownloadCoordinator(error="validation")
use_case = DownloadModelUseCase(download_coordinator=coordinator)
with pytest.raises(DownloadModelValidationError):
await use_case.execute({})
async def test_download_model_use_case_raises_early_access() -> None:
coordinator = StubDownloadCoordinator(error="401")
use_case = DownloadModelUseCase(download_coordinator=coordinator)
with pytest.raises(DownloadModelEarlyAccessError):
await use_case.execute({"model_id": 1})
async def test_download_model_use_case_returns_result() -> None:
coordinator = StubDownloadCoordinator()
use_case = DownloadModelUseCase(download_coordinator=coordinator)
result = await use_case.execute({"model_id": 1})
assert result["success"] is True
assert result["download_id"] == "abc123"