mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
feat(example-images): add use case orchestration
This commit is contained in:
@@ -13,6 +13,13 @@ from .download_model_use_case import (
|
||||
DownloadModelUseCase,
|
||||
DownloadModelValidationError,
|
||||
)
|
||||
from .example_images import (
|
||||
DownloadExampleImagesConfigurationError,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesUseCase,
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AutoOrganizeInProgressError",
|
||||
@@ -22,4 +29,9 @@ __all__ = [
|
||||
"DownloadModelEarlyAccessError",
|
||||
"DownloadModelUseCase",
|
||||
"DownloadModelValidationError",
|
||||
"DownloadExampleImagesConfigurationError",
|
||||
"DownloadExampleImagesInProgressError",
|
||||
"DownloadExampleImagesUseCase",
|
||||
"ImportExampleImagesUseCase",
|
||||
"ImportExampleImagesValidationError",
|
||||
]
|
||||
|
||||
19
py/services/use_cases/example_images/__init__.py
Normal file
19
py/services/use_cases/example_images/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Example image specific use case exports."""
|
||||
|
||||
from .download_example_images_use_case import (
|
||||
DownloadExampleImagesUseCase,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesConfigurationError,
|
||||
)
|
||||
from .import_example_images_use_case import (
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DownloadExampleImagesUseCase",
|
||||
"DownloadExampleImagesInProgressError",
|
||||
"DownloadExampleImagesConfigurationError",
|
||||
"ImportExampleImagesUseCase",
|
||||
"ImportExampleImagesValidationError",
|
||||
]
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Use case coordinating example image downloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from ....utils.example_images_download_manager import (
|
||||
DownloadConfigurationError,
|
||||
DownloadInProgressError,
|
||||
ExampleImagesDownloadError,
|
||||
)
|
||||
|
||||
|
||||
class DownloadExampleImagesInProgressError(RuntimeError):
|
||||
"""Raised when a download is already running."""
|
||||
|
||||
def __init__(self, progress: Dict[str, Any]) -> None:
|
||||
super().__init__("Download already in progress")
|
||||
self.progress = progress
|
||||
|
||||
|
||||
class DownloadExampleImagesConfigurationError(ValueError):
|
||||
"""Raised when settings prevent downloads from starting."""
|
||||
|
||||
|
||||
class DownloadExampleImagesUseCase:
|
||||
"""Validate payloads and trigger the download manager."""
|
||||
|
||||
def __init__(self, *, download_manager) -> None:
|
||||
self._download_manager = download_manager
|
||||
|
||||
async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Start a download and translate manager errors."""
|
||||
|
||||
try:
|
||||
return await self._download_manager.start_download(payload)
|
||||
except DownloadInProgressError as exc:
|
||||
raise DownloadExampleImagesInProgressError(exc.progress_snapshot) from exc
|
||||
except DownloadConfigurationError as exc:
|
||||
raise DownloadExampleImagesConfigurationError(str(exc)) from exc
|
||||
except ExampleImagesDownloadError:
|
||||
raise
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Use case for importing example images."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from contextlib import suppress
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ....utils.example_images_processor import (
|
||||
ExampleImagesImportError,
|
||||
ExampleImagesProcessor,
|
||||
ExampleImagesValidationError,
|
||||
)
|
||||
|
||||
|
||||
class ImportExampleImagesValidationError(ValueError):
|
||||
"""Raised when request validation fails."""
|
||||
|
||||
|
||||
class ImportExampleImagesUseCase:
|
||||
"""Parse upload payloads and delegate to the processor service."""
|
||||
|
||||
def __init__(self, *, processor: ExampleImagesProcessor) -> None:
|
||||
self._processor = processor
|
||||
|
||||
async def execute(self, request: web.Request) -> Dict[str, Any]:
|
||||
model_hash: str | None = None
|
||||
files_to_import: List[str] = []
|
||||
temp_files: List[str] = []
|
||||
|
||||
try:
|
||||
if request.content_type and "multipart/form-data" in request.content_type:
|
||||
reader = await request.multipart()
|
||||
|
||||
first_field = await reader.next()
|
||||
if first_field and first_field.name == "model_hash":
|
||||
model_hash = await first_field.text()
|
||||
else:
|
||||
# Support clients that send files first and hash later
|
||||
if first_field is not None:
|
||||
await self._collect_upload_file(first_field, files_to_import, temp_files)
|
||||
|
||||
async for field in reader:
|
||||
if field.name == "model_hash" and not model_hash:
|
||||
model_hash = await field.text()
|
||||
elif field.name == "files":
|
||||
await self._collect_upload_file(field, files_to_import, temp_files)
|
||||
else:
|
||||
data = await request.json()
|
||||
model_hash = data.get("model_hash")
|
||||
files_to_import = list(data.get("file_paths", []))
|
||||
|
||||
result = await self._processor.import_images(model_hash, files_to_import)
|
||||
return result
|
||||
except ExampleImagesValidationError as exc:
|
||||
raise ImportExampleImagesValidationError(str(exc)) from exc
|
||||
except ExampleImagesImportError:
|
||||
raise
|
||||
finally:
|
||||
for path in temp_files:
|
||||
with suppress(Exception):
|
||||
os.remove(path)
|
||||
|
||||
async def _collect_upload_file(
|
||||
self,
|
||||
field: Any,
|
||||
files_to_import: List[str],
|
||||
temp_files: List[str],
|
||||
) -> None:
|
||||
"""Persist an uploaded file to disk and add it to the import list."""
|
||||
|
||||
filename = field.filename or "upload"
|
||||
file_ext = os.path.splitext(filename)[1].lower()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
|
||||
temp_files.append(tmp_file.name)
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
tmp_file.write(chunk)
|
||||
|
||||
files_to_import.append(tmp_file.name)
|
||||
Reference in New Issue
Block a user