mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(routes): extract orchestration use cases
This commit is contained in:
@@ -28,6 +28,7 @@ spec.loader.exec_module(py_local)
|
||||
sys.modules.setdefault("py_local", py_local)
|
||||
|
||||
from py_local.routes.base_model_routes import BaseModelRoutes
|
||||
from py_local.services.model_file_service import AutoOrganizeResult
|
||||
from py_local.services.service_registry import ServiceRegistry
|
||||
from py_local.services.websocket_manager import ws_manager
|
||||
from py_local.utils.routes_common import ExifUtils
|
||||
@@ -222,6 +223,25 @@ def test_download_model_invokes_download_manager(
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_download_model_requires_identifier(mock_service, download_manager_stub):
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/lm/download-model",
|
||||
json={"model_root": "/tmp"},
|
||||
)
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
assert "Missing required" in payload["error"]
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_auto_organize_progress_returns_latest_snapshot(mock_service):
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
@@ -235,5 +255,65 @@ def test_auto_organize_progress_returns_latest_snapshot(mock_service):
|
||||
assert payload == {"success": True, "progress": {"status": "processing", "percent": 50}}
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.MonkeyPatch):
|
||||
async def fake_auto_organize(self, file_paths=None, progress_callback=None):
|
||||
result = AutoOrganizeResult()
|
||||
result.total = 1
|
||||
result.processed = 1
|
||||
result.success_count = 1
|
||||
result.skipped_count = 0
|
||||
result.failure_count = 0
|
||||
result.operation_type = "bulk"
|
||||
if progress_callback is not None:
|
||||
await progress_callback.on_progress({"type": "auto_organize_progress", "status": "started"})
|
||||
await progress_callback.on_progress({"type": "auto_organize_progress", "status": "completed"})
|
||||
return result
|
||||
|
||||
monkeypatch.setattr(
|
||||
py_local.services.model_file_service.ModelFileService,
|
||||
"auto_organize_models",
|
||||
fake_auto_organize,
|
||||
)
|
||||
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
response = await client.post("/api/lm/test-models/auto-organize", json={"file_paths": []})
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert payload["success"] is True
|
||||
|
||||
progress = ws_manager.get_auto_organize_progress()
|
||||
assert progress is not None
|
||||
assert progress["status"] == "completed"
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_auto_organize_conflict_when_running(mock_service):
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
await ws_manager.broadcast_auto_organize_progress(
|
||||
{"type": "auto_organize_progress", "status": "started"}
|
||||
)
|
||||
|
||||
response = await client.post("/api/lm/test-models/auto-organize")
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 409
|
||||
assert payload == {
|
||||
"success": False,
|
||||
"error": "Auto-organize is already running. Please wait for it to complete.",
|
||||
}
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
191
tests/services/test_use_cases.py
Normal file
191
tests/services/test_use_cases.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from py_local.services.model_file_service import AutoOrganizeResult
|
||||
from py_local.services.use_cases import (
|
||||
AutoOrganizeInProgressError,
|
||||
AutoOrganizeUseCase,
|
||||
BulkMetadataRefreshUseCase,
|
||||
DownloadModelEarlyAccessError,
|
||||
DownloadModelUseCase,
|
||||
DownloadModelValidationError,
|
||||
)
|
||||
from tests.conftest import MockModelService, MockScanner
|
||||
|
||||
|
||||
class StubLockProvider:
|
||||
def __init__(self) -> None:
|
||||
self._lock = asyncio.Lock()
|
||||
self.running = False
|
||||
|
||||
def is_auto_organize_running(self) -> bool:
|
||||
return self.running
|
||||
|
||||
async def get_auto_organize_lock(self) -> asyncio.Lock:
|
||||
return self._lock
|
||||
|
||||
|
||||
class StubFileService:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def auto_organize_models(
|
||||
self,
|
||||
*,
|
||||
file_paths: Optional[List[str]] = None,
|
||||
progress_callback=None,
|
||||
) -> AutoOrganizeResult:
|
||||
result = AutoOrganizeResult()
|
||||
result.total = len(file_paths or [])
|
||||
self.calls.append({"file_paths": file_paths, "progress_callback": progress_callback})
|
||||
return result
|
||||
|
||||
|
||||
class StubMetadataSync:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def fetch_and_update_model(self, **kwargs: Any):
|
||||
self.calls.append(kwargs)
|
||||
model_data = kwargs["model_data"]
|
||||
model_data["model_name"] = model_data.get("model_name", "model") + "-updated"
|
||||
return True, None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubSettings:
|
||||
enable_metadata_archive_db: bool = False
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
if key == "enable_metadata_archive_db":
|
||||
return self.enable_metadata_archive_db
|
||||
return default
|
||||
|
||||
|
||||
class ProgressCollector:
|
||||
def __init__(self) -> None:
|
||||
self.events: List[Dict[str, Any]] = []
|
||||
|
||||
async def on_progress(self, payload: Dict[str, Any]) -> None:
|
||||
self.events.append(payload)
|
||||
|
||||
|
||||
class StubDownloadCoordinator:
|
||||
def __init__(self, *, error: Optional[str] = None) -> None:
|
||||
self.error = error
|
||||
self.payloads: List[Dict[str, Any]] = []
|
||||
|
||||
async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
self.payloads.append(payload)
|
||||
if self.error == "validation":
|
||||
raise ValueError("Missing required parameter: Please provide either 'model_id' or 'model_version_id'")
|
||||
if self.error == "401":
|
||||
raise RuntimeError("401 Unauthorized")
|
||||
return {"success": True, "download_id": "abc123"}
|
||||
|
||||
|
||||
async def test_auto_organize_use_case_executes_with_lock() -> None:
|
||||
file_service = StubFileService()
|
||||
lock_provider = StubLockProvider()
|
||||
use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider)
|
||||
|
||||
result = await use_case.execute(file_paths=["model1"], progress_callback=None)
|
||||
|
||||
assert isinstance(result, AutoOrganizeResult)
|
||||
assert file_service.calls[0]["file_paths"] == ["model1"]
|
||||
|
||||
|
||||
async def test_auto_organize_use_case_rejects_when_running() -> None:
|
||||
file_service = StubFileService()
|
||||
lock_provider = StubLockProvider()
|
||||
lock_provider.running = True
|
||||
use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider)
|
||||
|
||||
with pytest.raises(AutoOrganizeInProgressError):
|
||||
await use_case.execute(file_paths=None, progress_callback=None)
|
||||
|
||||
|
||||
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
|
||||
scanner = MockScanner()
|
||||
scanner._cache.raw_data = [
|
||||
{
|
||||
"file_path": "model1.safetensors",
|
||||
"sha256": "hash",
|
||||
"from_civitai": True,
|
||||
"model_name": "Demo",
|
||||
}
|
||||
]
|
||||
service = MockModelService(scanner)
|
||||
metadata_sync = StubMetadataSync()
|
||||
settings = StubSettings()
|
||||
progress = ProgressCollector()
|
||||
|
||||
use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=metadata_sync,
|
||||
settings_service=settings,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
result = await use_case.execute_with_error_handling(progress_callback=progress)
|
||||
|
||||
assert result["success"] is True
|
||||
assert progress.events[0]["status"] == "started"
|
||||
assert progress.events[-1]["status"] == "completed"
|
||||
assert metadata_sync.calls
|
||||
assert scanner._cache.resort_calls == 1
|
||||
|
||||
|
||||
async def test_bulk_metadata_refresh_reports_errors() -> None:
|
||||
class FailingScanner(MockScanner):
|
||||
async def get_cached_data(self, force_refresh: bool = False):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
service = MockModelService(FailingScanner())
|
||||
metadata_sync = StubMetadataSync()
|
||||
settings = StubSettings()
|
||||
progress = ProgressCollector()
|
||||
|
||||
use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=metadata_sync,
|
||||
settings_service=settings,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await use_case.execute_with_error_handling(progress_callback=progress)
|
||||
|
||||
assert progress.events
|
||||
assert progress.events[-1]["status"] == "error"
|
||||
assert progress.events[-1]["error"] == "boom"
|
||||
|
||||
|
||||
async def test_download_model_use_case_raises_validation_error() -> None:
|
||||
coordinator = StubDownloadCoordinator(error="validation")
|
||||
use_case = DownloadModelUseCase(download_coordinator=coordinator)
|
||||
|
||||
with pytest.raises(DownloadModelValidationError):
|
||||
await use_case.execute({})
|
||||
|
||||
|
||||
async def test_download_model_use_case_raises_early_access() -> None:
|
||||
coordinator = StubDownloadCoordinator(error="401")
|
||||
use_case = DownloadModelUseCase(download_coordinator=coordinator)
|
||||
|
||||
with pytest.raises(DownloadModelEarlyAccessError):
|
||||
await use_case.execute({"model_id": 1})
|
||||
|
||||
|
||||
async def test_download_model_use_case_returns_result() -> None:
|
||||
coordinator = StubDownloadCoordinator()
|
||||
use_case = DownloadModelUseCase(download_coordinator=coordinator)
|
||||
|
||||
result = await use_case.execute({"model_id": 1})
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["download_id"] == "abc123"
|
||||
Reference in New Issue
Block a user