refactor(routes): introduce example images controller

This commit is contained in:
pixelpaws
2025-09-23 11:12:08 +08:00
parent 613cd81152
commit 85f79cd8d1
3 changed files with 371 additions and 162 deletions

View File

@@ -1,88 +1,69 @@
from __future__ import annotations
import logging
from typing import Callable
from typing import Callable, Mapping
from aiohttp import web
from .example_images_route_registrar import ExampleImagesRouteRegistrar
from .handlers.example_images_handlers import (
ExampleImagesDownloadHandler,
ExampleImagesFileHandler,
ExampleImagesHandlerSet,
ExampleImagesManagementHandler,
)
from ..utils.example_images_download_manager import DownloadManager
from ..utils.example_images_processor import ExampleImagesProcessor
from ..utils.example_images_file_manager import ExampleImagesFileManager
from ..services.websocket_manager import ws_manager
from ..utils.example_images_processor import ExampleImagesProcessor
logger = logging.getLogger(__name__)
class ExampleImagesRoutes:
"""Routes for example images related functionality"""
"""Route controller for example image endpoints."""
@staticmethod
def setup_routes(app: web.Application) -> None:
"""Register example images routes using the registrar."""
def __init__(
self,
*,
download_manager=DownloadManager,
processor=ExampleImagesProcessor,
file_manager=ExampleImagesFileManager,
) -> None:
self._download_manager = download_manager
self._processor = processor
self._file_manager = file_manager
self._handler_set: ExampleImagesHandlerSet | None = None
self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None
@classmethod
def setup_routes(cls, app: web.Application) -> None:
"""Register routes on the given aiohttp application using default wiring."""
controller = cls()
controller.register(app)
def register(self, app: web.Application) -> None:
"""Bind the controller's handlers to the aiohttp router."""
registrar = ExampleImagesRouteRegistrar(app)
registrar.register_routes(ExampleImagesRoutes._route_mapping())
registrar.register_routes(self.to_route_mapping())
@staticmethod
def _route_mapping() -> dict[str, Callable[[web.Request], object]]:
return {
"download_example_images": ExampleImagesRoutes.download_example_images,
"import_example_images": ExampleImagesRoutes.import_example_images,
"get_example_images_status": ExampleImagesRoutes.get_example_images_status,
"pause_example_images": ExampleImagesRoutes.pause_example_images,
"resume_example_images": ExampleImagesRoutes.resume_example_images,
"open_example_images_folder": ExampleImagesRoutes.open_example_images_folder,
"get_example_image_files": ExampleImagesRoutes.get_example_image_files,
"has_example_images": ExampleImagesRoutes.has_example_images,
"delete_example_image": ExampleImagesRoutes.delete_example_image,
"force_download_example_images": ExampleImagesRoutes.force_download_example_images,
}
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
"""Return the registrar-compatible mapping of handler names to callables."""
@staticmethod
async def download_example_images(request):
"""Download example images for models from Civitai"""
return await DownloadManager.start_download(request)
if self._handler_mapping is None:
handler_set = self._build_handler_set()
self._handler_set = handler_set
self._handler_mapping = handler_set.to_route_mapping()
return self._handler_mapping
@staticmethod
async def get_example_images_status(request):
"""Get the current status of example images download"""
return await DownloadManager.get_status(request)
@staticmethod
async def pause_example_images(request):
"""Pause the example images download"""
return await DownloadManager.pause_download(request)
@staticmethod
async def resume_example_images(request):
"""Resume the example images download"""
return await DownloadManager.resume_download(request)
@staticmethod
async def open_example_images_folder(request):
"""Open the example images folder for a specific model"""
return await ExampleImagesFileManager.open_folder(request)
@staticmethod
async def get_example_image_files(request):
"""Get list of example image files for a specific model"""
return await ExampleImagesFileManager.get_files(request)
@staticmethod
async def import_example_images(request):
"""Import local example images for a model"""
return await ExampleImagesProcessor.import_images(request)
@staticmethod
async def has_example_images(request):
"""Check if example images folder exists and is not empty for a model"""
return await ExampleImagesFileManager.has_images(request)
@staticmethod
async def delete_example_image(request):
"""Delete a custom example image for a model"""
return await ExampleImagesProcessor.delete_custom_image(request)
@staticmethod
async def force_download_example_images(request):
"""Force download example images for specific models"""
return await DownloadManager.start_force_download(request)
def _build_handler_set(self) -> ExampleImagesHandlerSet:
logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager)
download_handler = ExampleImagesDownloadHandler(self._download_manager)
management_handler = ExampleImagesManagementHandler(self._processor)
file_handler = ExampleImagesFileHandler(self._file_manager)
return ExampleImagesHandlerSet(
download=download_handler,
management=management_handler,
files=file_handler,
)

View File

@@ -0,0 +1,83 @@
"""Handler set for example image routes."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Mapping
from aiohttp import web
class ExampleImagesDownloadHandler:
"""HTTP adapters for download-related example image endpoints."""
def __init__(self, download_manager) -> None:
self._download_manager = download_manager
async def download_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.start_download(request)
async def get_example_images_status(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.get_status(request)
async def pause_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.pause_download(request)
async def resume_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.resume_download(request)
async def force_download_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.start_force_download(request)
class ExampleImagesManagementHandler:
"""HTTP adapters for import/delete endpoints."""
def __init__(self, processor) -> None:
self._processor = processor
async def import_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._processor.import_images(request)
async def delete_example_image(self, request: web.Request) -> web.StreamResponse:
return await self._processor.delete_custom_image(request)
class ExampleImagesFileHandler:
"""HTTP adapters for filesystem-centric endpoints."""
def __init__(self, file_manager) -> None:
self._file_manager = file_manager
async def open_example_images_folder(self, request: web.Request) -> web.StreamResponse:
return await self._file_manager.open_folder(request)
async def get_example_image_files(self, request: web.Request) -> web.StreamResponse:
return await self._file_manager.get_files(request)
async def has_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._file_manager.has_images(request)
@dataclass(frozen=True)
class ExampleImagesHandlerSet:
"""Aggregate of handlers exposed to the registrar."""
download: ExampleImagesDownloadHandler
management: ExampleImagesManagementHandler
files: ExampleImagesFileHandler
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
"""Flatten handler methods into the registrar mapping."""
return {
"download_example_images": self.download.download_example_images,
"get_example_images_status": self.download.get_example_images_status,
"pause_example_images": self.download.pause_example_images,
"resume_example_images": self.download.resume_example_images,
"force_download_example_images": self.download.force_download_example_images,
"import_example_images": self.management.import_example_images,
"delete_example_image": self.management.delete_example_image,
"open_example_images_folder": self.files.open_example_images_folder,
"get_example_image_files": self.files.get_example_image_files,
"has_example_images": self.files.has_example_images,
}

View File

@@ -1,14 +1,21 @@
from __future__ import annotations
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List, Set, Tuple
from typing import Any, List, Tuple
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
import pytest
from py.routes import example_images_routes
from py.routes.example_images_routes import ExampleImagesRoutes
from py.routes.example_images_route_registrar import ROUTE_DEFINITIONS
from py.routes.example_images_routes import ExampleImagesRoutes
from py.routes.handlers.example_images_handlers import (
ExampleImagesDownloadHandler,
ExampleImagesFileHandler,
ExampleImagesHandlerSet,
ExampleImagesManagementHandler,
)
@dataclass
@@ -16,85 +23,88 @@ class ExampleImagesHarness:
"""Container exposing the aiohttp client and stubbed collaborators."""
client: TestClient
download_manager: Any
processor: Any
file_manager: Any
download_manager: "StubDownloadManager"
processor: "StubExampleImagesProcessor"
file_manager: "StubExampleImagesFileManager"
controller: ExampleImagesRoutes
class StubDownloadManager:
def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = []
async def start_download(self, request: web.Request) -> web.StreamResponse:
payload = await request.json()
self.calls.append(("start_download", payload))
return web.json_response({"operation": "start_download", "payload": payload})
async def get_status(self, request: web.Request) -> web.StreamResponse:
self.calls.append(("get_status", dict(request.query)))
return web.json_response({"operation": "get_status"})
async def pause_download(self, request: web.Request) -> web.StreamResponse:
self.calls.append(("pause_download", None))
return web.json_response({"operation": "pause_download"})
async def resume_download(self, request: web.Request) -> web.StreamResponse:
self.calls.append(("resume_download", None))
return web.json_response({"operation": "resume_download"})
async def start_force_download(self, request: web.Request) -> web.StreamResponse:
payload = await request.json()
self.calls.append(("start_force_download", payload))
return web.json_response({"operation": "start_force_download", "payload": payload})
class StubExampleImagesProcessor:
def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = []
async def import_images(self, request: web.Request) -> web.StreamResponse:
payload = await request.json()
self.calls.append(("import_images", payload))
return web.json_response({"operation": "import_images", "payload": payload})
async def delete_custom_image(self, request: web.Request) -> web.StreamResponse:
payload = await request.json()
self.calls.append(("delete_custom_image", payload))
return web.json_response({"operation": "delete_custom_image", "payload": payload})
class StubExampleImagesFileManager:
def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = []
async def open_folder(self, request: web.Request) -> web.StreamResponse:
payload = await request.json()
self.calls.append(("open_folder", payload))
return web.json_response({"operation": "open_folder", "payload": payload})
async def get_files(self, request: web.Request) -> web.StreamResponse:
self.calls.append(("get_files", dict(request.query)))
return web.json_response({"operation": "get_files", "query": dict(request.query)})
async def has_images(self, request: web.Request) -> web.StreamResponse:
self.calls.append(("has_images", dict(request.query)))
return web.json_response({"operation": "has_images", "query": dict(request.query)})
@asynccontextmanager
async def example_images_app(monkeypatch: pytest.MonkeyPatch) -> ExampleImagesHarness:
async def example_images_app() -> ExampleImagesHarness:
"""Yield an ExampleImagesRoutes app wired with stubbed collaborators."""
class StubDownloadManager:
calls: List[Tuple[str, Any]] = []
download_manager = StubDownloadManager()
processor = StubExampleImagesProcessor()
file_manager = StubExampleImagesFileManager()
@staticmethod
async def start_download(request):
payload = await request.json()
StubDownloadManager.calls.append(("start_download", payload))
return web.json_response({"operation": "start_download", "payload": payload})
@staticmethod
async def get_status(request):
StubDownloadManager.calls.append(("get_status", dict(request.query)))
return web.json_response({"operation": "get_status"})
@staticmethod
async def pause_download(request):
StubDownloadManager.calls.append(("pause_download", None))
return web.json_response({"operation": "pause_download"})
@staticmethod
async def resume_download(request):
StubDownloadManager.calls.append(("resume_download", None))
return web.json_response({"operation": "resume_download"})
@staticmethod
async def start_force_download(request):
payload = await request.json()
StubDownloadManager.calls.append(("start_force_download", payload))
return web.json_response({"operation": "start_force_download", "payload": payload})
class StubExampleImagesProcessor:
calls: List[Tuple[str, Any]] = []
@staticmethod
async def import_images(request):
payload = await request.json()
StubExampleImagesProcessor.calls.append(("import_images", payload))
return web.json_response({"operation": "import_images", "payload": payload})
@staticmethod
async def delete_custom_image(request):
payload = await request.json()
StubExampleImagesProcessor.calls.append(("delete_custom_image", payload))
return web.json_response({"operation": "delete_custom_image", "payload": payload})
class StubExampleImagesFileManager:
calls: List[Tuple[str, Any]] = []
@staticmethod
async def open_folder(request):
payload = await request.json()
StubExampleImagesFileManager.calls.append(("open_folder", payload))
return web.json_response({"operation": "open_folder", "payload": payload})
@staticmethod
async def get_files(request):
StubExampleImagesFileManager.calls.append(("get_files", dict(request.query)))
return web.json_response({"operation": "get_files", "query": dict(request.query)})
@staticmethod
async def has_images(request):
StubExampleImagesFileManager.calls.append(("has_images", dict(request.query)))
return web.json_response({"operation": "has_images", "query": dict(request.query)})
monkeypatch.setattr(example_images_routes, "DownloadManager", StubDownloadManager)
monkeypatch.setattr(example_images_routes, "ExampleImagesProcessor", StubExampleImagesProcessor)
monkeypatch.setattr(example_images_routes, "ExampleImagesFileManager", StubExampleImagesFileManager)
controller = ExampleImagesRoutes(
download_manager=download_manager,
processor=processor,
file_manager=file_manager,
)
app = web.Application()
ExampleImagesRoutes.setup_routes(app)
controller.register(app)
server = TestServer(app)
client = TestClient(server)
@@ -103,17 +113,18 @@ async def example_images_app(monkeypatch: pytest.MonkeyPatch) -> ExampleImagesHa
try:
yield ExampleImagesHarness(
client=client,
download_manager=StubDownloadManager,
processor=StubExampleImagesProcessor,
file_manager=StubExampleImagesFileManager,
download_manager=download_manager,
processor=processor,
file_manager=file_manager,
controller=controller,
)
finally:
await client.close()
async def test_setup_routes_registers_all_definitions(monkeypatch: pytest.MonkeyPatch):
async with example_images_app(monkeypatch) as harness:
registered: Set[tuple[str, str]] = {
async def test_setup_routes_registers_all_definitions():
async with example_images_app() as harness:
registered = {
(route.method, route.resource.canonical)
for route in harness.client.app.router.routes()
if route.resource.canonical
@@ -131,8 +142,8 @@ async def test_setup_routes_registers_all_definitions(monkeypatch: pytest.Monkey
("/api/lm/force-download-example-images", {"model_hashes": ["abc123"]}),
],
)
async def test_download_routes_delegate_to_manager(endpoint, payload, monkeypatch: pytest.MonkeyPatch):
async with example_images_app(monkeypatch) as harness:
async def test_download_routes_delegate_to_manager(endpoint, payload):
async with example_images_app() as harness:
response = await harness.client.post(endpoint, json=payload)
body = await response.json()
@@ -144,8 +155,8 @@ async def test_download_routes_delegate_to_manager(endpoint, payload, monkeypatc
assert expected_call in harness.download_manager.calls
async def test_status_route_returns_manager_payload(monkeypatch: pytest.MonkeyPatch):
async with example_images_app(monkeypatch) as harness:
async def test_status_route_returns_manager_payload():
async with example_images_app() as harness:
response = await harness.client.get(
"/api/lm/example-images-status", params={"detail": "true"}
)
@@ -156,8 +167,8 @@ async def test_status_route_returns_manager_payload(monkeypatch: pytest.MonkeyPa
assert harness.download_manager.calls == [("get_status", {"detail": "true"})]
async def test_pause_and_resume_routes_delegate(monkeypatch: pytest.MonkeyPatch):
async with example_images_app(monkeypatch) as harness:
async def test_pause_and_resume_routes_delegate():
async with example_images_app() as harness:
pause_response = await harness.client.post("/api/lm/pause-example-images")
resume_response = await harness.client.post("/api/lm/resume-example-images")
@@ -172,9 +183,9 @@ async def test_pause_and_resume_routes_delegate(monkeypatch: pytest.MonkeyPatch)
]
async def test_import_route_delegates_to_processor(monkeypatch: pytest.MonkeyPatch):
async def test_import_route_delegates_to_processor():
payload = {"model_hash": "abc123", "files": ["/path/image.png"]}
async with example_images_app(monkeypatch) as harness:
async with example_images_app() as harness:
response = await harness.client.post(
"/api/lm/import-example-images", json=payload
)
@@ -185,9 +196,9 @@ async def test_import_route_delegates_to_processor(monkeypatch: pytest.MonkeyPat
assert harness.processor.calls == [("import_images", payload)]
async def test_delete_route_delegates_to_processor(monkeypatch: pytest.MonkeyPatch):
async def test_delete_route_delegates_to_processor():
payload = {"model_hash": "abc123", "short_id": "xyz"}
async with example_images_app(monkeypatch) as harness:
async with example_images_app() as harness:
response = await harness.client.post(
"/api/lm/delete-example-image", json=payload
)
@@ -198,11 +209,11 @@ async def test_delete_route_delegates_to_processor(monkeypatch: pytest.MonkeyPat
assert harness.processor.calls == [("delete_custom_image", payload)]
async def test_file_routes_delegate_to_file_manager(monkeypatch: pytest.MonkeyPatch):
async def test_file_routes_delegate_to_file_manager():
open_payload = {"model_hash": "abc123"}
files_params = {"model_hash": "def456"}
async with example_images_app(monkeypatch) as harness:
async with example_images_app() as harness:
open_response = await harness.client.post(
"/api/lm/open-example-images-folder", json=open_payload
)
@@ -232,3 +243,137 @@ async def test_file_routes_delegate_to_file_manager(monkeypatch: pytest.MonkeyPa
("get_files", files_params),
("has_images", files_params),
]
@pytest.mark.asyncio
async def test_download_handler_methods_delegate() -> None:
class Recorder:
def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = []
async def start_download(self, request) -> str:
self.calls.append(("start_download", request))
return "download"
async def get_status(self, request) -> str:
self.calls.append(("get_status", request))
return "status"
async def pause_download(self, request) -> str:
self.calls.append(("pause_download", request))
return "pause"
async def resume_download(self, request) -> str:
self.calls.append(("resume_download", request))
return "resume"
async def start_force_download(self, request) -> str:
self.calls.append(("start_force_download", request))
return "force"
recorder = Recorder()
handler = ExampleImagesDownloadHandler(recorder)
request = object()
assert await handler.download_example_images(request) == "download"
assert await handler.get_example_images_status(request) == "status"
assert await handler.pause_example_images(request) == "pause"
assert await handler.resume_example_images(request) == "resume"
assert await handler.force_download_example_images(request) == "force"
expected = [
("start_download", request),
("get_status", request),
("pause_download", request),
("resume_download", request),
("start_force_download", request),
]
assert recorder.calls == expected
@pytest.mark.asyncio
async def test_management_handler_methods_delegate() -> None:
class Recorder:
def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = []
async def import_images(self, request) -> str:
self.calls.append(("import_images", request))
return "import"
async def delete_custom_image(self, request) -> str:
self.calls.append(("delete_custom_image", request))
return "delete"
recorder = Recorder()
handler = ExampleImagesManagementHandler(recorder)
request = object()
assert await handler.import_example_images(request) == "import"
assert await handler.delete_example_image(request) == "delete"
assert recorder.calls == [
("import_images", request),
("delete_custom_image", request),
]
@pytest.mark.asyncio
async def test_file_handler_methods_delegate() -> None:
class Recorder:
def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = []
async def open_folder(self, request) -> str:
self.calls.append(("open_folder", request))
return "open"
async def get_files(self, request) -> str:
self.calls.append(("get_files", request))
return "files"
async def has_images(self, request) -> str:
self.calls.append(("has_images", request))
return "has"
recorder = Recorder()
handler = ExampleImagesFileHandler(recorder)
request = object()
assert await handler.open_example_images_folder(request) == "open"
assert await handler.get_example_image_files(request) == "files"
assert await handler.has_example_images(request) == "has"
assert recorder.calls == [
("open_folder", request),
("get_files", request),
("has_images", request),
]
def test_handler_set_route_mapping_includes_all_handlers() -> None:
download = ExampleImagesDownloadHandler(object())
management = ExampleImagesManagementHandler(object())
files = ExampleImagesFileHandler(object())
handler_set = ExampleImagesHandlerSet(
download=download,
management=management,
files=files,
)
mapping = handler_set.to_route_mapping()
expected_keys = {
"download_example_images",
"get_example_images_status",
"pause_example_images",
"resume_example_images",
"force_download_example_images",
"import_example_images",
"delete_example_image",
"open_example_images_folder",
"get_example_image_files",
"has_example_images",
}
assert mapping.keys() == expected_keys
for key in expected_keys:
assert callable(mapping[key])