mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
feat(example-images): add use case orchestration
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Tuple
|
||||
@@ -33,37 +34,35 @@ class StubDownloadManager:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def start_download(self, request: web.Request) -> web.StreamResponse:
|
||||
payload = await request.json()
|
||||
async def start_download(self, payload: Any) -> dict:
|
||||
self.calls.append(("start_download", payload))
|
||||
return web.json_response({"operation": "start_download", "payload": payload})
|
||||
return {"operation": "start_download", "payload": payload}
|
||||
|
||||
async def get_status(self, request: web.Request) -> web.StreamResponse:
|
||||
async def get_status(self, request: web.Request) -> dict:
|
||||
self.calls.append(("get_status", dict(request.query)))
|
||||
return web.json_response({"operation": "get_status"})
|
||||
return {"operation": "get_status"}
|
||||
|
||||
async def pause_download(self, request: web.Request) -> web.StreamResponse:
|
||||
async def pause_download(self, request: web.Request) -> dict:
|
||||
self.calls.append(("pause_download", None))
|
||||
return web.json_response({"operation": "pause_download"})
|
||||
return {"operation": "pause_download"}
|
||||
|
||||
async def resume_download(self, request: web.Request) -> web.StreamResponse:
|
||||
async def resume_download(self, request: web.Request) -> dict:
|
||||
self.calls.append(("resume_download", None))
|
||||
return web.json_response({"operation": "resume_download"})
|
||||
return {"operation": "resume_download"}
|
||||
|
||||
async def start_force_download(self, request: web.Request) -> web.StreamResponse:
|
||||
payload = await request.json()
|
||||
async def start_force_download(self, payload: Any) -> dict:
|
||||
self.calls.append(("start_force_download", payload))
|
||||
return web.json_response({"operation": "start_force_download", "payload": payload})
|
||||
return {"operation": "start_force_download", "payload": payload}
|
||||
|
||||
|
||||
class StubExampleImagesProcessor:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def import_images(self, request: web.Request) -> web.StreamResponse:
|
||||
payload = await request.json()
|
||||
async def import_images(self, model_hash: str, files: List[str]) -> dict:
|
||||
payload = {"model_hash": model_hash, "file_paths": files}
|
||||
self.calls.append(("import_images", payload))
|
||||
return web.json_response({"operation": "import_images", "payload": payload})
|
||||
return {"operation": "import_images", "payload": payload}
|
||||
|
||||
async def delete_custom_image(self, request: web.Request) -> web.StreamResponse:
|
||||
payload = await request.json()
|
||||
@@ -184,7 +183,7 @@ async def test_pause_and_resume_routes_delegate():
|
||||
|
||||
|
||||
async def test_import_route_delegates_to_processor():
|
||||
payload = {"model_hash": "abc123", "files": ["/path/image.png"]}
|
||||
payload = {"model_hash": "abc123", "file_paths": ["/path/image.png"]}
|
||||
async with example_images_app() as harness:
|
||||
response = await harness.client.post(
|
||||
"/api/lm/import-example-images", json=payload
|
||||
@@ -193,7 +192,8 @@ async def test_import_route_delegates_to_processor():
|
||||
|
||||
assert response.status == 200
|
||||
assert body == {"operation": "import_images", "payload": payload}
|
||||
assert harness.processor.calls == [("import_images", payload)]
|
||||
expected_call = ("import_images", payload)
|
||||
assert expected_call in harness.processor.calls
|
||||
|
||||
|
||||
async def test_delete_route_delegates_to_processor():
|
||||
@@ -251,70 +251,91 @@ async def test_download_handler_methods_delegate() -> None:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def start_download(self, request) -> str:
|
||||
self.calls.append(("start_download", request))
|
||||
return "download"
|
||||
|
||||
async def get_status(self, request) -> str:
|
||||
async def get_status(self, request) -> dict:
|
||||
self.calls.append(("get_status", request))
|
||||
return "status"
|
||||
return {"status": "ok"}
|
||||
|
||||
async def pause_download(self, request) -> str:
|
||||
async def pause_download(self, request) -> dict:
|
||||
self.calls.append(("pause_download", request))
|
||||
return "pause"
|
||||
return {"status": "paused"}
|
||||
|
||||
async def resume_download(self, request) -> str:
|
||||
async def resume_download(self, request) -> dict:
|
||||
self.calls.append(("resume_download", request))
|
||||
return "resume"
|
||||
return {"status": "running"}
|
||||
|
||||
async def start_force_download(self, request) -> str:
|
||||
self.calls.append(("start_force_download", request))
|
||||
return "force"
|
||||
async def start_force_download(self, payload) -> dict:
|
||||
self.calls.append(("start_force_download", payload))
|
||||
return {"status": "force", "payload": payload}
|
||||
|
||||
class StubDownloadUseCase:
|
||||
def __init__(self) -> None:
|
||||
self.payloads: List[Any] = []
|
||||
|
||||
async def execute(self, payload: dict) -> dict:
|
||||
self.payloads.append(payload)
|
||||
return {"status": "started", "payload": payload}
|
||||
|
||||
class DummyRequest:
|
||||
def __init__(self, payload: dict) -> None:
|
||||
self._payload = payload
|
||||
self.query = {}
|
||||
|
||||
async def json(self) -> dict:
|
||||
return self._payload
|
||||
|
||||
recorder = Recorder()
|
||||
handler = ExampleImagesDownloadHandler(recorder)
|
||||
request = object()
|
||||
use_case = StubDownloadUseCase()
|
||||
handler = ExampleImagesDownloadHandler(use_case, recorder)
|
||||
request = DummyRequest({"foo": "bar"})
|
||||
|
||||
assert await handler.download_example_images(request) == "download"
|
||||
assert await handler.get_example_images_status(request) == "status"
|
||||
assert await handler.pause_example_images(request) == "pause"
|
||||
assert await handler.resume_example_images(request) == "resume"
|
||||
assert await handler.force_download_example_images(request) == "force"
|
||||
download_response = await handler.download_example_images(request)
|
||||
assert json.loads(download_response.text) == {"status": "started", "payload": {"foo": "bar"}}
|
||||
status_response = await handler.get_example_images_status(request)
|
||||
assert json.loads(status_response.text) == {"status": "ok"}
|
||||
pause_response = await handler.pause_example_images(request)
|
||||
assert json.loads(pause_response.text) == {"status": "paused"}
|
||||
resume_response = await handler.resume_example_images(request)
|
||||
assert json.loads(resume_response.text) == {"status": "running"}
|
||||
force_response = await handler.force_download_example_images(request)
|
||||
assert json.loads(force_response.text) == {"status": "force", "payload": {"foo": "bar"}}
|
||||
|
||||
expected = [
|
||||
("start_download", request),
|
||||
assert use_case.payloads == [{"foo": "bar"}]
|
||||
assert recorder.calls == [
|
||||
("get_status", request),
|
||||
("pause_download", request),
|
||||
("resume_download", request),
|
||||
("start_force_download", request),
|
||||
("start_force_download", {"foo": "bar"}),
|
||||
]
|
||||
assert recorder.calls == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_management_handler_methods_delegate() -> None:
|
||||
class StubImportUseCase:
|
||||
def __init__(self) -> None:
|
||||
self.requests: List[Any] = []
|
||||
|
||||
async def execute(self, request: Any) -> dict:
|
||||
self.requests.append(request)
|
||||
return {"status": "imported"}
|
||||
|
||||
class Recorder:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def import_images(self, request) -> str:
|
||||
self.calls.append(("import_images", request))
|
||||
return "import"
|
||||
|
||||
async def delete_custom_image(self, request) -> str:
|
||||
self.calls.append(("delete_custom_image", request))
|
||||
return "delete"
|
||||
|
||||
recorder = Recorder()
|
||||
handler = ExampleImagesManagementHandler(recorder)
|
||||
use_case = StubImportUseCase()
|
||||
handler = ExampleImagesManagementHandler(use_case, recorder)
|
||||
request = object()
|
||||
|
||||
assert await handler.import_example_images(request) == "import"
|
||||
import_response = await handler.import_example_images(request)
|
||||
assert json.loads(import_response.text) == {"status": "imported"}
|
||||
assert await handler.delete_example_image(request) == "delete"
|
||||
assert recorder.calls == [
|
||||
("import_images", request),
|
||||
("delete_custom_image", request),
|
||||
]
|
||||
assert use_case.requests == [request]
|
||||
assert recorder.calls == [("delete_custom_image", request)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -350,8 +371,29 @@ async def test_file_handler_methods_delegate() -> None:
|
||||
|
||||
|
||||
def test_handler_set_route_mapping_includes_all_handlers() -> None:
|
||||
download = ExampleImagesDownloadHandler(object())
|
||||
management = ExampleImagesManagementHandler(object())
|
||||
class DummyUseCase:
|
||||
async def execute(self, payload):
|
||||
return payload
|
||||
|
||||
class DummyManager:
|
||||
async def get_status(self, request):
|
||||
return {}
|
||||
|
||||
async def pause_download(self, request):
|
||||
return {}
|
||||
|
||||
async def resume_download(self, request):
|
||||
return {}
|
||||
|
||||
async def start_force_download(self, payload):
|
||||
return payload
|
||||
|
||||
class DummyProcessor:
|
||||
async def delete_custom_image(self, request):
|
||||
return {}
|
||||
|
||||
download = ExampleImagesDownloadHandler(DummyUseCase(), DummyManager())
|
||||
management = ExampleImagesManagementHandler(DummyUseCase(), DummyProcessor())
|
||||
files = ExampleImagesFileHandler(object())
|
||||
handler_set = ExampleImagesHandlerSet(
|
||||
download=download,
|
||||
|
||||
@@ -10,9 +10,23 @@ from py_local.services.use_cases import (
|
||||
AutoOrganizeInProgressError,
|
||||
AutoOrganizeUseCase,
|
||||
BulkMetadataRefreshUseCase,
|
||||
DownloadExampleImagesConfigurationError,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesUseCase,
|
||||
DownloadModelEarlyAccessError,
|
||||
DownloadModelUseCase,
|
||||
DownloadModelValidationError,
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
from py_local.utils.example_images_download_manager import (
|
||||
DownloadConfigurationError,
|
||||
DownloadInProgressError,
|
||||
ExampleImagesDownloadError,
|
||||
)
|
||||
from py_local.utils.example_images_processor import (
|
||||
ExampleImagesImportError,
|
||||
ExampleImagesValidationError,
|
||||
)
|
||||
from tests.conftest import MockModelService, MockScanner
|
||||
|
||||
@@ -88,6 +102,38 @@ class StubDownloadCoordinator:
|
||||
return {"success": True, "download_id": "abc123"}
|
||||
|
||||
|
||||
class StubExampleImagesDownloadManager:
|
||||
def __init__(self) -> None:
|
||||
self.payloads: List[Dict[str, Any]] = []
|
||||
self.error: Optional[str] = None
|
||||
self.progress_snapshot = {"status": "running"}
|
||||
|
||||
async def start_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
self.payloads.append(payload)
|
||||
if self.error == "in_progress":
|
||||
raise DownloadInProgressError(self.progress_snapshot)
|
||||
if self.error == "configuration":
|
||||
raise DownloadConfigurationError("path missing")
|
||||
if self.error == "generic":
|
||||
raise ExampleImagesDownloadError("boom")
|
||||
return {"success": True, "message": "ok"}
|
||||
|
||||
|
||||
class StubExampleImagesProcessor:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Dict[str, Any]] = []
|
||||
self.error: Optional[str] = None
|
||||
self.response: Dict[str, Any] = {"success": True}
|
||||
|
||||
async def import_images(self, model_hash: str, files: List[str]) -> Dict[str, Any]:
|
||||
self.calls.append({"model_hash": model_hash, "files": files})
|
||||
if self.error == "validation":
|
||||
raise ExampleImagesValidationError("missing")
|
||||
if self.error == "generic":
|
||||
raise ExampleImagesImportError("boom")
|
||||
return self.response
|
||||
|
||||
|
||||
async def test_auto_organize_use_case_executes_with_lock() -> None:
|
||||
file_service = StubFileService()
|
||||
lock_provider = StubLockProvider()
|
||||
@@ -189,3 +235,83 @@ async def test_download_model_use_case_returns_result() -> None:
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["download_id"] == "abc123"
|
||||
|
||||
|
||||
async def test_download_example_images_use_case_triggers_manager() -> None:
|
||||
manager = StubExampleImagesDownloadManager()
|
||||
use_case = DownloadExampleImagesUseCase(download_manager=manager)
|
||||
|
||||
payload = {"optimize": True}
|
||||
result = await use_case.execute(payload)
|
||||
|
||||
assert manager.payloads == [payload]
|
||||
assert result == {"success": True, "message": "ok"}
|
||||
|
||||
|
||||
async def test_download_example_images_use_case_maps_in_progress() -> None:
|
||||
manager = StubExampleImagesDownloadManager()
|
||||
manager.error = "in_progress"
|
||||
use_case = DownloadExampleImagesUseCase(download_manager=manager)
|
||||
|
||||
with pytest.raises(DownloadExampleImagesInProgressError) as exc:
|
||||
await use_case.execute({})
|
||||
|
||||
assert exc.value.progress == manager.progress_snapshot
|
||||
|
||||
|
||||
async def test_download_example_images_use_case_maps_configuration() -> None:
|
||||
manager = StubExampleImagesDownloadManager()
|
||||
manager.error = "configuration"
|
||||
use_case = DownloadExampleImagesUseCase(download_manager=manager)
|
||||
|
||||
with pytest.raises(DownloadExampleImagesConfigurationError):
|
||||
await use_case.execute({})
|
||||
|
||||
|
||||
async def test_download_example_images_use_case_propagates_generic_error() -> None:
|
||||
manager = StubExampleImagesDownloadManager()
|
||||
manager.error = "generic"
|
||||
use_case = DownloadExampleImagesUseCase(download_manager=manager)
|
||||
|
||||
with pytest.raises(ExampleImagesDownloadError):
|
||||
await use_case.execute({})
|
||||
|
||||
|
||||
class DummyJsonRequest:
|
||||
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||
self._payload = payload
|
||||
self.content_type = "application/json"
|
||||
|
||||
async def json(self) -> Dict[str, Any]:
|
||||
return self._payload
|
||||
|
||||
|
||||
async def test_import_example_images_use_case_delegates() -> None:
|
||||
processor = StubExampleImagesProcessor()
|
||||
use_case = ImportExampleImagesUseCase(processor=processor)
|
||||
|
||||
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
|
||||
result = await use_case.execute(request)
|
||||
|
||||
assert processor.calls == [{"model_hash": "abc", "files": ["/tmp/file"]}]
|
||||
assert result == {"success": True}
|
||||
|
||||
|
||||
async def test_import_example_images_use_case_maps_validation_error() -> None:
|
||||
processor = StubExampleImagesProcessor()
|
||||
processor.error = "validation"
|
||||
use_case = ImportExampleImagesUseCase(processor=processor)
|
||||
request = DummyJsonRequest({"model_hash": None, "file_paths": []})
|
||||
|
||||
with pytest.raises(ImportExampleImagesValidationError):
|
||||
await use_case.execute(request)
|
||||
|
||||
|
||||
async def test_import_example_images_use_case_propagates_generic_error() -> None:
|
||||
processor = StubExampleImagesProcessor()
|
||||
processor.error = "generic"
|
||||
use_case = ImportExampleImagesUseCase(processor=processor)
|
||||
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
|
||||
|
||||
with pytest.raises(ExampleImagesImportError):
|
||||
await use_case.execute(request)
|
||||
|
||||
Reference in New Issue
Block a user