feat(example-images): add use case orchestration

This commit is contained in:
pixelpaws
2025-09-23 11:47:12 +08:00
parent bd10280736
commit aaad270822
10 changed files with 582 additions and 262 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import json
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List, Tuple
@@ -33,37 +34,35 @@ class StubDownloadManager:
def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = []
async def start_download(self, request: web.Request) -> web.StreamResponse:
payload = await request.json()
async def start_download(self, payload: Any) -> dict:
self.calls.append(("start_download", payload))
return web.json_response({"operation": "start_download", "payload": payload})
return {"operation": "start_download", "payload": payload}
async def get_status(self, request: web.Request) -> web.StreamResponse:
async def get_status(self, request: web.Request) -> dict:
self.calls.append(("get_status", dict(request.query)))
return web.json_response({"operation": "get_status"})
return {"operation": "get_status"}
async def pause_download(self, request: web.Request) -> web.StreamResponse:
async def pause_download(self, request: web.Request) -> dict:
self.calls.append(("pause_download", None))
return web.json_response({"operation": "pause_download"})
return {"operation": "pause_download"}
async def resume_download(self, request: web.Request) -> web.StreamResponse:
async def resume_download(self, request: web.Request) -> dict:
self.calls.append(("resume_download", None))
return web.json_response({"operation": "resume_download"})
return {"operation": "resume_download"}
async def start_force_download(self, request: web.Request) -> web.StreamResponse:
payload = await request.json()
async def start_force_download(self, payload: Any) -> dict:
self.calls.append(("start_force_download", payload))
return web.json_response({"operation": "start_force_download", "payload": payload})
return {"operation": "start_force_download", "payload": payload}
class StubExampleImagesProcessor:
def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = []
async def import_images(self, request: web.Request) -> web.StreamResponse:
payload = await request.json()
async def import_images(self, model_hash: str, files: List[str]) -> dict:
payload = {"model_hash": model_hash, "file_paths": files}
self.calls.append(("import_images", payload))
return web.json_response({"operation": "import_images", "payload": payload})
return {"operation": "import_images", "payload": payload}
async def delete_custom_image(self, request: web.Request) -> web.StreamResponse:
payload = await request.json()
@@ -184,7 +183,7 @@ async def test_pause_and_resume_routes_delegate():
async def test_import_route_delegates_to_processor():
payload = {"model_hash": "abc123", "files": ["/path/image.png"]}
payload = {"model_hash": "abc123", "file_paths": ["/path/image.png"]}
async with example_images_app() as harness:
response = await harness.client.post(
"/api/lm/import-example-images", json=payload
@@ -193,7 +192,8 @@ async def test_import_route_delegates_to_processor():
assert response.status == 200
assert body == {"operation": "import_images", "payload": payload}
assert harness.processor.calls == [("import_images", payload)]
expected_call = ("import_images", payload)
assert expected_call in harness.processor.calls
async def test_delete_route_delegates_to_processor():
@@ -251,70 +251,91 @@ async def test_download_handler_methods_delegate() -> None:
def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = []
async def start_download(self, request) -> str:
self.calls.append(("start_download", request))
return "download"
async def get_status(self, request) -> str:
async def get_status(self, request) -> dict:
self.calls.append(("get_status", request))
return "status"
return {"status": "ok"}
async def pause_download(self, request) -> str:
async def pause_download(self, request) -> dict:
self.calls.append(("pause_download", request))
return "pause"
return {"status": "paused"}
async def resume_download(self, request) -> str:
async def resume_download(self, request) -> dict:
self.calls.append(("resume_download", request))
return "resume"
return {"status": "running"}
async def start_force_download(self, request) -> str:
self.calls.append(("start_force_download", request))
return "force"
async def start_force_download(self, payload) -> dict:
self.calls.append(("start_force_download", payload))
return {"status": "force", "payload": payload}
class StubDownloadUseCase:
def __init__(self) -> None:
self.payloads: List[Any] = []
async def execute(self, payload: dict) -> dict:
self.payloads.append(payload)
return {"status": "started", "payload": payload}
class DummyRequest:
def __init__(self, payload: dict) -> None:
self._payload = payload
self.query = {}
async def json(self) -> dict:
return self._payload
recorder = Recorder()
handler = ExampleImagesDownloadHandler(recorder)
request = object()
use_case = StubDownloadUseCase()
handler = ExampleImagesDownloadHandler(use_case, recorder)
request = DummyRequest({"foo": "bar"})
assert await handler.download_example_images(request) == "download"
assert await handler.get_example_images_status(request) == "status"
assert await handler.pause_example_images(request) == "pause"
assert await handler.resume_example_images(request) == "resume"
assert await handler.force_download_example_images(request) == "force"
download_response = await handler.download_example_images(request)
assert json.loads(download_response.text) == {"status": "started", "payload": {"foo": "bar"}}
status_response = await handler.get_example_images_status(request)
assert json.loads(status_response.text) == {"status": "ok"}
pause_response = await handler.pause_example_images(request)
assert json.loads(pause_response.text) == {"status": "paused"}
resume_response = await handler.resume_example_images(request)
assert json.loads(resume_response.text) == {"status": "running"}
force_response = await handler.force_download_example_images(request)
assert json.loads(force_response.text) == {"status": "force", "payload": {"foo": "bar"}}
expected = [
("start_download", request),
assert use_case.payloads == [{"foo": "bar"}]
assert recorder.calls == [
("get_status", request),
("pause_download", request),
("resume_download", request),
("start_force_download", request),
("start_force_download", {"foo": "bar"}),
]
assert recorder.calls == expected
@pytest.mark.asyncio
async def test_management_handler_methods_delegate() -> None:
class StubImportUseCase:
def __init__(self) -> None:
self.requests: List[Any] = []
async def execute(self, request: Any) -> dict:
self.requests.append(request)
return {"status": "imported"}
class Recorder:
def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = []
async def import_images(self, request) -> str:
self.calls.append(("import_images", request))
return "import"
async def delete_custom_image(self, request) -> str:
self.calls.append(("delete_custom_image", request))
return "delete"
recorder = Recorder()
handler = ExampleImagesManagementHandler(recorder)
use_case = StubImportUseCase()
handler = ExampleImagesManagementHandler(use_case, recorder)
request = object()
assert await handler.import_example_images(request) == "import"
import_response = await handler.import_example_images(request)
assert json.loads(import_response.text) == {"status": "imported"}
assert await handler.delete_example_image(request) == "delete"
assert recorder.calls == [
("import_images", request),
("delete_custom_image", request),
]
assert use_case.requests == [request]
assert recorder.calls == [("delete_custom_image", request)]
@pytest.mark.asyncio
@@ -350,8 +371,29 @@ async def test_file_handler_methods_delegate() -> None:
def test_handler_set_route_mapping_includes_all_handlers() -> None:
download = ExampleImagesDownloadHandler(object())
management = ExampleImagesManagementHandler(object())
class DummyUseCase:
async def execute(self, payload):
return payload
class DummyManager:
async def get_status(self, request):
return {}
async def pause_download(self, request):
return {}
async def resume_download(self, request):
return {}
async def start_force_download(self, payload):
return payload
class DummyProcessor:
async def delete_custom_image(self, request):
return {}
download = ExampleImagesDownloadHandler(DummyUseCase(), DummyManager())
management = ExampleImagesManagementHandler(DummyUseCase(), DummyProcessor())
files = ExampleImagesFileHandler(object())
handler_set = ExampleImagesHandlerSet(
download=download,

View File

@@ -10,9 +10,23 @@ from py_local.services.use_cases import (
AutoOrganizeInProgressError,
AutoOrganizeUseCase,
BulkMetadataRefreshUseCase,
DownloadExampleImagesConfigurationError,
DownloadExampleImagesInProgressError,
DownloadExampleImagesUseCase,
DownloadModelEarlyAccessError,
DownloadModelUseCase,
DownloadModelValidationError,
ImportExampleImagesUseCase,
ImportExampleImagesValidationError,
)
from py_local.utils.example_images_download_manager import (
DownloadConfigurationError,
DownloadInProgressError,
ExampleImagesDownloadError,
)
from py_local.utils.example_images_processor import (
ExampleImagesImportError,
ExampleImagesValidationError,
)
from tests.conftest import MockModelService, MockScanner
@@ -88,6 +102,38 @@ class StubDownloadCoordinator:
return {"success": True, "download_id": "abc123"}
class StubExampleImagesDownloadManager:
def __init__(self) -> None:
self.payloads: List[Dict[str, Any]] = []
self.error: Optional[str] = None
self.progress_snapshot = {"status": "running"}
async def start_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
self.payloads.append(payload)
if self.error == "in_progress":
raise DownloadInProgressError(self.progress_snapshot)
if self.error == "configuration":
raise DownloadConfigurationError("path missing")
if self.error == "generic":
raise ExampleImagesDownloadError("boom")
return {"success": True, "message": "ok"}
class StubExampleImagesProcessor:
def __init__(self) -> None:
self.calls: List[Dict[str, Any]] = []
self.error: Optional[str] = None
self.response: Dict[str, Any] = {"success": True}
async def import_images(self, model_hash: str, files: List[str]) -> Dict[str, Any]:
self.calls.append({"model_hash": model_hash, "files": files})
if self.error == "validation":
raise ExampleImagesValidationError("missing")
if self.error == "generic":
raise ExampleImagesImportError("boom")
return self.response
async def test_auto_organize_use_case_executes_with_lock() -> None:
file_service = StubFileService()
lock_provider = StubLockProvider()
@@ -189,3 +235,83 @@ async def test_download_model_use_case_returns_result() -> None:
assert result["success"] is True
assert result["download_id"] == "abc123"
async def test_download_example_images_use_case_triggers_manager() -> None:
manager = StubExampleImagesDownloadManager()
use_case = DownloadExampleImagesUseCase(download_manager=manager)
payload = {"optimize": True}
result = await use_case.execute(payload)
assert manager.payloads == [payload]
assert result == {"success": True, "message": "ok"}
async def test_download_example_images_use_case_maps_in_progress() -> None:
manager = StubExampleImagesDownloadManager()
manager.error = "in_progress"
use_case = DownloadExampleImagesUseCase(download_manager=manager)
with pytest.raises(DownloadExampleImagesInProgressError) as exc:
await use_case.execute({})
assert exc.value.progress == manager.progress_snapshot
async def test_download_example_images_use_case_maps_configuration() -> None:
manager = StubExampleImagesDownloadManager()
manager.error = "configuration"
use_case = DownloadExampleImagesUseCase(download_manager=manager)
with pytest.raises(DownloadExampleImagesConfigurationError):
await use_case.execute({})
async def test_download_example_images_use_case_propagates_generic_error() -> None:
manager = StubExampleImagesDownloadManager()
manager.error = "generic"
use_case = DownloadExampleImagesUseCase(download_manager=manager)
with pytest.raises(ExampleImagesDownloadError):
await use_case.execute({})
class DummyJsonRequest:
def __init__(self, payload: Dict[str, Any]) -> None:
self._payload = payload
self.content_type = "application/json"
async def json(self) -> Dict[str, Any]:
return self._payload
async def test_import_example_images_use_case_delegates() -> None:
processor = StubExampleImagesProcessor()
use_case = ImportExampleImagesUseCase(processor=processor)
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
result = await use_case.execute(request)
assert processor.calls == [{"model_hash": "abc", "files": ["/tmp/file"]}]
assert result == {"success": True}
async def test_import_example_images_use_case_maps_validation_error() -> None:
processor = StubExampleImagesProcessor()
processor.error = "validation"
use_case = ImportExampleImagesUseCase(processor=processor)
request = DummyJsonRequest({"model_hash": None, "file_paths": []})
with pytest.raises(ImportExampleImagesValidationError):
await use_case.execute(request)
async def test_import_example_images_use_case_propagates_generic_error() -> None:
processor = StubExampleImagesProcessor()
processor.error = "generic"
use_case = ImportExampleImagesUseCase(processor=processor)
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
with pytest.raises(ExampleImagesImportError):
await use_case.execute(request)