feat(example-images): add use case orchestration

This commit is contained in:
pixelpaws
2025-09-23 11:47:12 +08:00
parent bd10280736
commit aaad270822
10 changed files with 582 additions and 262 deletions

View File

@@ -0,0 +1,19 @@
"""Example image specific use case exports."""
from .download_example_images_use_case import (
DownloadExampleImagesUseCase,
DownloadExampleImagesInProgressError,
DownloadExampleImagesConfigurationError,
)
from .import_example_images_use_case import (
ImportExampleImagesUseCase,
ImportExampleImagesValidationError,
)
__all__ = [
"DownloadExampleImagesUseCase",
"DownloadExampleImagesInProgressError",
"DownloadExampleImagesConfigurationError",
"ImportExampleImagesUseCase",
"ImportExampleImagesValidationError",
]

View File

@@ -0,0 +1,42 @@
"""Use case coordinating example image downloads."""
from __future__ import annotations
from typing import Any, Dict
from ....utils.example_images_download_manager import (
DownloadConfigurationError,
DownloadInProgressError,
ExampleImagesDownloadError,
)
class DownloadExampleImagesInProgressError(RuntimeError):
"""Raised when a download is already running."""
def __init__(self, progress: Dict[str, Any]) -> None:
super().__init__("Download already in progress")
self.progress = progress
class DownloadExampleImagesConfigurationError(ValueError):
"""Raised when settings prevent downloads from starting."""
class DownloadExampleImagesUseCase:
"""Validate payloads and trigger the download manager."""
def __init__(self, *, download_manager) -> None:
self._download_manager = download_manager
async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""Start a download and translate manager errors."""
try:
return await self._download_manager.start_download(payload)
except DownloadInProgressError as exc:
raise DownloadExampleImagesInProgressError(exc.progress_snapshot) from exc
except DownloadConfigurationError as exc:
raise DownloadExampleImagesConfigurationError(str(exc)) from exc
except ExampleImagesDownloadError:
raise

View File

@@ -0,0 +1,86 @@
"""Use case for importing example images."""
from __future__ import annotations
import os
import tempfile
from contextlib import suppress
from typing import Any, Dict, List
from aiohttp import web
from ....utils.example_images_processor import (
ExampleImagesImportError,
ExampleImagesProcessor,
ExampleImagesValidationError,
)
class ImportExampleImagesValidationError(ValueError):
"""Raised when request validation fails."""
class ImportExampleImagesUseCase:
"""Parse upload payloads and delegate to the processor service."""
def __init__(self, *, processor: ExampleImagesProcessor) -> None:
self._processor = processor
async def execute(self, request: web.Request) -> Dict[str, Any]:
model_hash: str | None = None
files_to_import: List[str] = []
temp_files: List[str] = []
try:
if request.content_type and "multipart/form-data" in request.content_type:
reader = await request.multipart()
first_field = await reader.next()
if first_field and first_field.name == "model_hash":
model_hash = await first_field.text()
else:
# Support clients that send files first and hash later
if first_field is not None:
await self._collect_upload_file(first_field, files_to_import, temp_files)
async for field in reader:
if field.name == "model_hash" and not model_hash:
model_hash = await field.text()
elif field.name == "files":
await self._collect_upload_file(field, files_to_import, temp_files)
else:
data = await request.json()
model_hash = data.get("model_hash")
files_to_import = list(data.get("file_paths", []))
result = await self._processor.import_images(model_hash, files_to_import)
return result
except ExampleImagesValidationError as exc:
raise ImportExampleImagesValidationError(str(exc)) from exc
except ExampleImagesImportError:
raise
finally:
for path in temp_files:
with suppress(Exception):
os.remove(path)
async def _collect_upload_file(
self,
field: Any,
files_to_import: List[str],
temp_files: List[str],
) -> None:
"""Persist an uploaded file to disk and add it to the import list."""
filename = field.filename or "upload"
file_ext = os.path.splitext(filename)[1].lower()
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
temp_files.append(tmp_file.name)
while True:
chunk = await field.read_chunk()
if not chunk:
break
tmp_file.write(chunk)
files_to_import.append(tmp_file.name)