mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #462 from willmiao/codex/wrap-long-running-flows-in-use-cases
feat(example-images): add use case orchestration
This commit is contained in:
@@ -12,6 +12,10 @@ from .handlers.example_images_handlers import (
|
||||
ExampleImagesHandlerSet,
|
||||
ExampleImagesManagementHandler,
|
||||
)
|
||||
from ..services.use_cases.example_images import (
|
||||
DownloadExampleImagesUseCase,
|
||||
ImportExampleImagesUseCase,
|
||||
)
|
||||
from ..utils.example_images_download_manager import DownloadManager
|
||||
from ..utils.example_images_file_manager import ExampleImagesFileManager
|
||||
from ..utils.example_images_processor import ExampleImagesProcessor
|
||||
@@ -59,8 +63,10 @@ class ExampleImagesRoutes:
|
||||
|
||||
def _build_handler_set(self) -> ExampleImagesHandlerSet:
|
||||
logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager)
|
||||
download_handler = ExampleImagesDownloadHandler(self._download_manager)
|
||||
management_handler = ExampleImagesManagementHandler(self._processor)
|
||||
download_use_case = DownloadExampleImagesUseCase(download_manager=self._download_manager)
|
||||
download_handler = ExampleImagesDownloadHandler(download_use_case, self._download_manager)
|
||||
import_use_case = ImportExampleImagesUseCase(processor=self._processor)
|
||||
management_handler = ExampleImagesManagementHandler(import_use_case, self._processor)
|
||||
file_handler = ExampleImagesFileHandler(self._file_manager)
|
||||
return ExampleImagesHandlerSet(
|
||||
download=download_handler,
|
||||
|
||||
@@ -6,37 +6,101 @@ from typing import Callable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...services.use_cases.example_images import (
|
||||
DownloadExampleImagesConfigurationError,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesUseCase,
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
from ...utils.example_images_download_manager import (
|
||||
DownloadConfigurationError,
|
||||
DownloadInProgressError,
|
||||
DownloadNotRunningError,
|
||||
ExampleImagesDownloadError,
|
||||
)
|
||||
from ...utils.example_images_processor import ExampleImagesImportError
|
||||
|
||||
|
||||
class ExampleImagesDownloadHandler:
|
||||
"""HTTP adapters for download-related example image endpoints."""
|
||||
|
||||
def __init__(self, download_manager) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
download_use_case: DownloadExampleImagesUseCase,
|
||||
download_manager,
|
||||
) -> None:
|
||||
self._download_use_case = download_use_case
|
||||
self._download_manager = download_manager
|
||||
|
||||
async def download_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._download_manager.start_download(request)
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_use_case.execute(payload)
|
||||
return web.json_response(result)
|
||||
except DownloadExampleImagesInProgressError as exc:
|
||||
response = {
|
||||
'success': False,
|
||||
'error': str(exc),
|
||||
'status': exc.progress,
|
||||
}
|
||||
return web.json_response(response, status=400)
|
||||
except DownloadExampleImagesConfigurationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesDownloadError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
async def get_example_images_status(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._download_manager.get_status(request)
|
||||
result = await self._download_manager.get_status(request)
|
||||
return web.json_response(result)
|
||||
|
||||
async def pause_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._download_manager.pause_download(request)
|
||||
try:
|
||||
result = await self._download_manager.pause_download(request)
|
||||
return web.json_response(result)
|
||||
except DownloadNotRunningError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
|
||||
async def resume_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._download_manager.resume_download(request)
|
||||
try:
|
||||
result = await self._download_manager.resume_download(request)
|
||||
return web.json_response(result)
|
||||
except DownloadNotRunningError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
|
||||
async def force_download_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._download_manager.start_force_download(request)
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_manager.start_force_download(payload)
|
||||
return web.json_response(result)
|
||||
except DownloadInProgressError as exc:
|
||||
response = {
|
||||
'success': False,
|
||||
'error': str(exc),
|
||||
'status': exc.progress_snapshot,
|
||||
}
|
||||
return web.json_response(response, status=400)
|
||||
except DownloadConfigurationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesDownloadError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
|
||||
class ExampleImagesManagementHandler:
|
||||
"""HTTP adapters for import/delete endpoints."""
|
||||
|
||||
def __init__(self, processor) -> None:
|
||||
def __init__(self, import_use_case: ImportExampleImagesUseCase, processor) -> None:
|
||||
self._import_use_case = import_use_case
|
||||
self._processor = processor
|
||||
|
||||
async def import_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._processor.import_images(request)
|
||||
try:
|
||||
result = await self._import_use_case.execute(request)
|
||||
return web.json_response(result)
|
||||
except ImportExampleImagesValidationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesImportError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
async def delete_example_image(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._processor.delete_custom_image(request)
|
||||
|
||||
@@ -13,6 +13,13 @@ from .download_model_use_case import (
|
||||
DownloadModelUseCase,
|
||||
DownloadModelValidationError,
|
||||
)
|
||||
from .example_images import (
|
||||
DownloadExampleImagesConfigurationError,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesUseCase,
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AutoOrganizeInProgressError",
|
||||
@@ -22,4 +29,9 @@ __all__ = [
|
||||
"DownloadModelEarlyAccessError",
|
||||
"DownloadModelUseCase",
|
||||
"DownloadModelValidationError",
|
||||
"DownloadExampleImagesConfigurationError",
|
||||
"DownloadExampleImagesInProgressError",
|
||||
"DownloadExampleImagesUseCase",
|
||||
"ImportExampleImagesUseCase",
|
||||
"ImportExampleImagesValidationError",
|
||||
]
|
||||
|
||||
19
py/services/use_cases/example_images/__init__.py
Normal file
19
py/services/use_cases/example_images/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Example image specific use case exports."""
|
||||
|
||||
from .download_example_images_use_case import (
|
||||
DownloadExampleImagesUseCase,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesConfigurationError,
|
||||
)
|
||||
from .import_example_images_use_case import (
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DownloadExampleImagesUseCase",
|
||||
"DownloadExampleImagesInProgressError",
|
||||
"DownloadExampleImagesConfigurationError",
|
||||
"ImportExampleImagesUseCase",
|
||||
"ImportExampleImagesValidationError",
|
||||
]
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Use case coordinating example image downloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from ....utils.example_images_download_manager import (
|
||||
DownloadConfigurationError,
|
||||
DownloadInProgressError,
|
||||
ExampleImagesDownloadError,
|
||||
)
|
||||
|
||||
|
||||
class DownloadExampleImagesInProgressError(RuntimeError):
|
||||
"""Raised when a download is already running."""
|
||||
|
||||
def __init__(self, progress: Dict[str, Any]) -> None:
|
||||
super().__init__("Download already in progress")
|
||||
self.progress = progress
|
||||
|
||||
|
||||
class DownloadExampleImagesConfigurationError(ValueError):
|
||||
"""Raised when settings prevent downloads from starting."""
|
||||
|
||||
|
||||
class DownloadExampleImagesUseCase:
|
||||
"""Validate payloads and trigger the download manager."""
|
||||
|
||||
def __init__(self, *, download_manager) -> None:
|
||||
self._download_manager = download_manager
|
||||
|
||||
async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Start a download and translate manager errors."""
|
||||
|
||||
try:
|
||||
return await self._download_manager.start_download(payload)
|
||||
except DownloadInProgressError as exc:
|
||||
raise DownloadExampleImagesInProgressError(exc.progress_snapshot) from exc
|
||||
except DownloadConfigurationError as exc:
|
||||
raise DownloadExampleImagesConfigurationError(str(exc)) from exc
|
||||
except ExampleImagesDownloadError:
|
||||
raise
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Use case for importing example images."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from contextlib import suppress
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ....utils.example_images_processor import (
|
||||
ExampleImagesImportError,
|
||||
ExampleImagesProcessor,
|
||||
ExampleImagesValidationError,
|
||||
)
|
||||
|
||||
|
||||
class ImportExampleImagesValidationError(ValueError):
|
||||
"""Raised when request validation fails."""
|
||||
|
||||
|
||||
class ImportExampleImagesUseCase:
|
||||
"""Parse upload payloads and delegate to the processor service."""
|
||||
|
||||
def __init__(self, *, processor: ExampleImagesProcessor) -> None:
|
||||
self._processor = processor
|
||||
|
||||
async def execute(self, request: web.Request) -> Dict[str, Any]:
|
||||
model_hash: str | None = None
|
||||
files_to_import: List[str] = []
|
||||
temp_files: List[str] = []
|
||||
|
||||
try:
|
||||
if request.content_type and "multipart/form-data" in request.content_type:
|
||||
reader = await request.multipart()
|
||||
|
||||
first_field = await reader.next()
|
||||
if first_field and first_field.name == "model_hash":
|
||||
model_hash = await first_field.text()
|
||||
else:
|
||||
# Support clients that send files first and hash later
|
||||
if first_field is not None:
|
||||
await self._collect_upload_file(first_field, files_to_import, temp_files)
|
||||
|
||||
async for field in reader:
|
||||
if field.name == "model_hash" and not model_hash:
|
||||
model_hash = await field.text()
|
||||
elif field.name == "files":
|
||||
await self._collect_upload_file(field, files_to_import, temp_files)
|
||||
else:
|
||||
data = await request.json()
|
||||
model_hash = data.get("model_hash")
|
||||
files_to_import = list(data.get("file_paths", []))
|
||||
|
||||
result = await self._processor.import_images(model_hash, files_to_import)
|
||||
return result
|
||||
except ExampleImagesValidationError as exc:
|
||||
raise ImportExampleImagesValidationError(str(exc)) from exc
|
||||
except ExampleImagesImportError:
|
||||
raise
|
||||
finally:
|
||||
for path in temp_files:
|
||||
with suppress(Exception):
|
||||
os.remove(path)
|
||||
|
||||
async def _collect_upload_file(
|
||||
self,
|
||||
field: Any,
|
||||
files_to_import: List[str],
|
||||
temp_files: List[str],
|
||||
) -> None:
|
||||
"""Persist an uploaded file to disk and add it to the import list."""
|
||||
|
||||
filename = field.filename or "upload"
|
||||
file_ext = os.path.splitext(filename)[1].lower()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
|
||||
temp_files.append(tmp_file.name)
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
tmp_file.write(chunk)
|
||||
|
||||
files_to_import.append(tmp_file.name)
|
||||
@@ -3,7 +3,6 @@ import os
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from aiohttp import web
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .example_images_processor import ExampleImagesProcessor
|
||||
@@ -12,6 +11,30 @@ from ..services.websocket_manager import ws_manager # Add this import at the to
|
||||
from ..services.downloader import get_downloader
|
||||
from ..services.settings_manager import settings
|
||||
|
||||
|
||||
class ExampleImagesDownloadError(RuntimeError):
|
||||
"""Base error for example image download operations."""
|
||||
|
||||
|
||||
class DownloadInProgressError(ExampleImagesDownloadError):
|
||||
"""Raised when a download is already running."""
|
||||
|
||||
def __init__(self, progress_snapshot: dict) -> None:
|
||||
super().__init__("Download already in progress")
|
||||
self.progress_snapshot = progress_snapshot
|
||||
|
||||
|
||||
class DownloadNotRunningError(ExampleImagesDownloadError):
|
||||
"""Raised when pause/resume is requested without an active download."""
|
||||
|
||||
def __init__(self, message: str = "No download in progress") -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class DownloadConfigurationError(ExampleImagesDownloadError):
|
||||
"""Raised when configuration prevents starting a download."""
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Download status tracking
|
||||
@@ -31,11 +54,21 @@ download_progress = {
|
||||
'failed_models': set() # Track models that failed to download after metadata refresh
|
||||
}
|
||||
|
||||
|
||||
def _serialize_progress() -> dict:
|
||||
"""Return a JSON-serialisable snapshot of the current progress."""
|
||||
|
||||
snapshot = download_progress.copy()
|
||||
snapshot['processed_models'] = list(download_progress['processed_models'])
|
||||
snapshot['refreshed_models'] = list(download_progress['refreshed_models'])
|
||||
snapshot['failed_models'] = list(download_progress['failed_models'])
|
||||
return snapshot
|
||||
|
||||
class DownloadManager:
|
||||
"""Manages downloading example images for models"""
|
||||
|
||||
@staticmethod
|
||||
async def start_download(request):
|
||||
async def start_download(options: dict):
|
||||
"""
|
||||
Start downloading example images for models
|
||||
|
||||
@@ -50,25 +83,14 @@ class DownloadManager:
|
||||
global download_task, is_downloading, download_progress
|
||||
|
||||
if is_downloading:
|
||||
# Create a copy for JSON serialization
|
||||
response_progress = download_progress.copy()
|
||||
response_progress['processed_models'] = list(download_progress['processed_models'])
|
||||
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
|
||||
response_progress['failed_models'] = list(download_progress['failed_models'])
|
||||
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download already in progress',
|
||||
'status': response_progress
|
||||
}, status=400)
|
||||
|
||||
raise DownloadInProgressError(_serialize_progress())
|
||||
|
||||
try:
|
||||
# Parse the request body
|
||||
data = await request.json()
|
||||
data = options or {}
|
||||
auto_mode = data.get('auto_mode', False)
|
||||
optimize = data.get('optimize', True)
|
||||
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
||||
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
|
||||
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
|
||||
|
||||
# Get output directory from settings
|
||||
output_dir = settings.get('example_images_path')
|
||||
@@ -78,15 +100,11 @@ class DownloadManager:
|
||||
if auto_mode:
|
||||
# For auto mode, just log and return success to avoid showing error toasts
|
||||
logger.debug(error_msg)
|
||||
return web.json_response({
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Example images path not configured, skipping auto download'
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': error_msg
|
||||
}, status=400)
|
||||
}
|
||||
raise DownloadConfigurationError(error_msg)
|
||||
|
||||
# Create the output directory
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@@ -129,41 +147,29 @@ class DownloadManager:
|
||||
)
|
||||
)
|
||||
|
||||
# Create a copy for JSON serialization
|
||||
response_progress = download_progress.copy()
|
||||
response_progress['processed_models'] = list(download_progress['processed_models'])
|
||||
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
|
||||
response_progress['failed_models'] = list(download_progress['failed_models'])
|
||||
|
||||
return web.json_response({
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Download started',
|
||||
'status': response_progress
|
||||
})
|
||||
|
||||
'status': _serialize_progress()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start example images download: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
raise ExampleImagesDownloadError(str(e)) from e
|
||||
|
||||
@staticmethod
|
||||
async def get_status(request):
|
||||
"""Get the current status of example images download"""
|
||||
global download_progress
|
||||
|
||||
|
||||
# Create a copy of the progress dict with the set converted to a list for JSON serialization
|
||||
response_progress = download_progress.copy()
|
||||
response_progress['processed_models'] = list(download_progress['processed_models'])
|
||||
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
|
||||
response_progress['failed_models'] = list(download_progress['failed_models'])
|
||||
|
||||
return web.json_response({
|
||||
response_progress = _serialize_progress()
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'is_downloading': is_downloading,
|
||||
'status': response_progress
|
||||
})
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def pause_download(request):
|
||||
@@ -171,17 +177,14 @@ class DownloadManager:
|
||||
global download_progress
|
||||
|
||||
if not is_downloading:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No download in progress'
|
||||
}, status=400)
|
||||
|
||||
raise DownloadNotRunningError()
|
||||
|
||||
download_progress['status'] = 'paused'
|
||||
|
||||
return web.json_response({
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Download paused'
|
||||
})
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def resume_download(request):
|
||||
@@ -189,23 +192,19 @@ class DownloadManager:
|
||||
global download_progress
|
||||
|
||||
if not is_downloading:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No download in progress'
|
||||
}, status=400)
|
||||
|
||||
raise DownloadNotRunningError()
|
||||
|
||||
if download_progress['status'] == 'paused':
|
||||
download_progress['status'] = 'running'
|
||||
|
||||
return web.json_response({
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Download resumed'
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"Download is in '{download_progress['status']}' state, cannot resume"
|
||||
}, status=400)
|
||||
}
|
||||
|
||||
raise DownloadNotRunningError(
|
||||
f"Download is in '{download_progress['status']}' state, cannot resume"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _download_all_example_images(output_dir, optimize, model_types, delay):
|
||||
@@ -432,7 +431,7 @@ class DownloadManager:
|
||||
logger.error(f"Failed to save progress file: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def start_force_download(request):
|
||||
async def start_force_download(options: dict):
|
||||
"""
|
||||
Force download example images for specific models
|
||||
|
||||
@@ -447,33 +446,23 @@ class DownloadManager:
|
||||
global download_task, is_downloading, download_progress
|
||||
|
||||
if is_downloading:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download already in progress'
|
||||
}, status=400)
|
||||
raise DownloadInProgressError(_serialize_progress())
|
||||
|
||||
try:
|
||||
# Parse the request body
|
||||
data = await request.json()
|
||||
data = options or {}
|
||||
model_hashes = data.get('model_hashes', [])
|
||||
optimize = data.get('optimize', True)
|
||||
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
||||
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
|
||||
|
||||
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
|
||||
|
||||
if not model_hashes:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing model_hashes parameter'
|
||||
}, status=400)
|
||||
|
||||
raise DownloadConfigurationError('Missing model_hashes parameter')
|
||||
|
||||
# Get output directory from settings
|
||||
output_dir = settings.get('example_images_path')
|
||||
|
||||
|
||||
if not output_dir:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Example images path not configured in settings'
|
||||
}, status=400)
|
||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||
|
||||
# Create the output directory
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@@ -506,20 +495,17 @@ class DownloadManager:
|
||||
# Set download status to not downloading
|
||||
is_downloading = False
|
||||
|
||||
return web.json_response({
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Force download completed',
|
||||
'result': result
|
||||
})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# Set download status to not downloading
|
||||
is_downloading = False
|
||||
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
raise ExampleImagesDownloadError(str(e)) from e
|
||||
|
||||
@staticmethod
|
||||
async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import random
|
||||
import string
|
||||
from aiohttp import web
|
||||
@@ -13,6 +12,14 @@ from ..utils.metadata_manager import MetadataManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExampleImagesImportError(RuntimeError):
|
||||
"""Base error for example image import operations."""
|
||||
|
||||
|
||||
class ExampleImagesValidationError(ExampleImagesImportError):
|
||||
"""Raised when input validation fails."""
|
||||
|
||||
class ExampleImagesProcessor:
|
||||
"""Processes and manipulates example images"""
|
||||
|
||||
@@ -299,90 +306,29 @@ class ExampleImagesProcessor:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def import_images(request):
|
||||
"""
|
||||
Import local example images
|
||||
|
||||
Accepts:
|
||||
- multipart/form-data form with model_hash and files fields
|
||||
or
|
||||
- JSON request with model_hash and file_paths
|
||||
|
||||
Returns:
|
||||
- Success status and list of imported files
|
||||
"""
|
||||
async def import_images(model_hash: str, files_to_import: list[str]):
|
||||
"""Import local example images for a model."""
|
||||
|
||||
if not model_hash:
|
||||
raise ExampleImagesValidationError('Missing model_hash parameter')
|
||||
|
||||
if not files_to_import:
|
||||
raise ExampleImagesValidationError('No files provided to import')
|
||||
|
||||
try:
|
||||
model_hash = None
|
||||
files_to_import = []
|
||||
temp_files_to_cleanup = []
|
||||
|
||||
# Check if it's a multipart form-data request (direct file upload)
|
||||
if request.content_type and 'multipart/form-data' in request.content_type:
|
||||
reader = await request.multipart()
|
||||
|
||||
# First get model_hash
|
||||
field = await reader.next()
|
||||
if field.name == 'model_hash':
|
||||
model_hash = await field.text()
|
||||
|
||||
# Then process all files
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
if field.name == 'files':
|
||||
# Create a temporary file with appropriate suffix for type detection
|
||||
file_name = field.filename
|
||||
file_ext = os.path.splitext(file_name)[1].lower()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
|
||||
temp_path = tmp_file.name
|
||||
temp_files_to_cleanup.append(temp_path) # Track for cleanup
|
||||
|
||||
# Write chunks to the temporary file
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
tmp_file.write(chunk)
|
||||
|
||||
# Add to the list of files to process
|
||||
files_to_import.append(temp_path)
|
||||
else:
|
||||
# Parse JSON request (legacy method using file paths)
|
||||
data = await request.json()
|
||||
model_hash = data.get('model_hash')
|
||||
files_to_import = data.get('file_paths', [])
|
||||
|
||||
if not model_hash:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing model_hash parameter'
|
||||
}, status=400)
|
||||
|
||||
if not files_to_import:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No files provided to import'
|
||||
}, status=400)
|
||||
|
||||
# Get example images path
|
||||
example_images_path = settings.get('example_images_path')
|
||||
if not example_images_path:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No example images path configured'
|
||||
}, status=400)
|
||||
|
||||
raise ExampleImagesValidationError('No example images path configured')
|
||||
|
||||
# Find the model and get current metadata
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
|
||||
model_data = None
|
||||
scanner = None
|
||||
|
||||
|
||||
# Check both scanners to find the model
|
||||
for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]:
|
||||
cache = await scan_obj.get_cached_data()
|
||||
@@ -393,21 +339,20 @@ class ExampleImagesProcessor:
|
||||
break
|
||||
if model_data:
|
||||
break
|
||||
|
||||
|
||||
if not model_data:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"Model with hash {model_hash} not found in cache"
|
||||
}, status=404)
|
||||
|
||||
raise ExampleImagesImportError(
|
||||
f"Model with hash {model_hash} not found in cache"
|
||||
)
|
||||
|
||||
# Create model folder
|
||||
model_folder = os.path.join(example_images_path, model_hash)
|
||||
os.makedirs(model_folder, exist_ok=True)
|
||||
|
||||
|
||||
imported_files = []
|
||||
errors = []
|
||||
newly_imported_paths = []
|
||||
|
||||
|
||||
# Process each file path
|
||||
for file_path in files_to_import:
|
||||
try:
|
||||
@@ -415,26 +360,26 @@ class ExampleImagesProcessor:
|
||||
if not os.path.isfile(file_path):
|
||||
errors.append(f"File not found: {file_path}")
|
||||
continue
|
||||
|
||||
|
||||
# Check if file type is supported
|
||||
file_ext = os.path.splitext(file_path)[1].lower()
|
||||
if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
|
||||
if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
|
||||
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']):
|
||||
errors.append(f"Unsupported file type: {file_path}")
|
||||
continue
|
||||
|
||||
|
||||
# Generate new filename using short ID instead of UUID
|
||||
short_id = ExampleImagesProcessor.generate_short_id()
|
||||
new_filename = f"custom_{short_id}{file_ext}"
|
||||
|
||||
|
||||
dest_path = os.path.join(model_folder, new_filename)
|
||||
|
||||
|
||||
# Copy the file
|
||||
import shutil
|
||||
shutil.copy2(file_path, dest_path)
|
||||
# Store both the dest_path and the short_id
|
||||
newly_imported_paths.append((dest_path, short_id))
|
||||
|
||||
|
||||
# Add to imported files list
|
||||
imported_files.append({
|
||||
'name': new_filename,
|
||||
@@ -444,39 +389,31 @@ class ExampleImagesProcessor:
|
||||
})
|
||||
except Exception as e:
|
||||
errors.append(f"Error importing {file_path}: {str(e)}")
|
||||
|
||||
|
||||
# Update metadata with new example images
|
||||
regular_images, custom_images = await MetadataUpdater.update_metadata_after_import(
|
||||
model_hash,
|
||||
model_hash,
|
||||
model_data,
|
||||
scanner,
|
||||
newly_imported_paths
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
|
||||
return {
|
||||
'success': len(imported_files) > 0,
|
||||
'message': f'Successfully imported {len(imported_files)} files' +
|
||||
'message': f'Successfully imported {len(imported_files)} files' +
|
||||
(f' with {len(errors)} errors' if errors else ''),
|
||||
'files': imported_files,
|
||||
'errors': errors,
|
||||
'regular_images': regular_images,
|
||||
'custom_images': custom_images,
|
||||
"model_file_path": model_data.get('file_path', ''),
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
except ExampleImagesImportError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to import example images: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
finally:
|
||||
# Clean up temporary files
|
||||
for temp_file in temp_files_to_cleanup:
|
||||
try:
|
||||
os.remove(temp_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove temporary file {temp_file}: {e}")
|
||||
raise ExampleImagesImportError(str(e)) from e
|
||||
|
||||
@staticmethod
|
||||
async def delete_custom_image(request):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Tuple
|
||||
@@ -33,37 +34,35 @@ class StubDownloadManager:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def start_download(self, request: web.Request) -> web.StreamResponse:
|
||||
payload = await request.json()
|
||||
async def start_download(self, payload: Any) -> dict:
|
||||
self.calls.append(("start_download", payload))
|
||||
return web.json_response({"operation": "start_download", "payload": payload})
|
||||
return {"operation": "start_download", "payload": payload}
|
||||
|
||||
async def get_status(self, request: web.Request) -> web.StreamResponse:
|
||||
async def get_status(self, request: web.Request) -> dict:
|
||||
self.calls.append(("get_status", dict(request.query)))
|
||||
return web.json_response({"operation": "get_status"})
|
||||
return {"operation": "get_status"}
|
||||
|
||||
async def pause_download(self, request: web.Request) -> web.StreamResponse:
|
||||
async def pause_download(self, request: web.Request) -> dict:
|
||||
self.calls.append(("pause_download", None))
|
||||
return web.json_response({"operation": "pause_download"})
|
||||
return {"operation": "pause_download"}
|
||||
|
||||
async def resume_download(self, request: web.Request) -> web.StreamResponse:
|
||||
async def resume_download(self, request: web.Request) -> dict:
|
||||
self.calls.append(("resume_download", None))
|
||||
return web.json_response({"operation": "resume_download"})
|
||||
return {"operation": "resume_download"}
|
||||
|
||||
async def start_force_download(self, request: web.Request) -> web.StreamResponse:
|
||||
payload = await request.json()
|
||||
async def start_force_download(self, payload: Any) -> dict:
|
||||
self.calls.append(("start_force_download", payload))
|
||||
return web.json_response({"operation": "start_force_download", "payload": payload})
|
||||
return {"operation": "start_force_download", "payload": payload}
|
||||
|
||||
|
||||
class StubExampleImagesProcessor:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def import_images(self, request: web.Request) -> web.StreamResponse:
|
||||
payload = await request.json()
|
||||
async def import_images(self, model_hash: str, files: List[str]) -> dict:
|
||||
payload = {"model_hash": model_hash, "file_paths": files}
|
||||
self.calls.append(("import_images", payload))
|
||||
return web.json_response({"operation": "import_images", "payload": payload})
|
||||
return {"operation": "import_images", "payload": payload}
|
||||
|
||||
async def delete_custom_image(self, request: web.Request) -> web.StreamResponse:
|
||||
payload = await request.json()
|
||||
@@ -184,7 +183,7 @@ async def test_pause_and_resume_routes_delegate():
|
||||
|
||||
|
||||
async def test_import_route_delegates_to_processor():
|
||||
payload = {"model_hash": "abc123", "files": ["/path/image.png"]}
|
||||
payload = {"model_hash": "abc123", "file_paths": ["/path/image.png"]}
|
||||
async with example_images_app() as harness:
|
||||
response = await harness.client.post(
|
||||
"/api/lm/import-example-images", json=payload
|
||||
@@ -193,7 +192,8 @@ async def test_import_route_delegates_to_processor():
|
||||
|
||||
assert response.status == 200
|
||||
assert body == {"operation": "import_images", "payload": payload}
|
||||
assert harness.processor.calls == [("import_images", payload)]
|
||||
expected_call = ("import_images", payload)
|
||||
assert expected_call in harness.processor.calls
|
||||
|
||||
|
||||
async def test_delete_route_delegates_to_processor():
|
||||
@@ -251,70 +251,91 @@ async def test_download_handler_methods_delegate() -> None:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def start_download(self, request) -> str:
|
||||
self.calls.append(("start_download", request))
|
||||
return "download"
|
||||
|
||||
async def get_status(self, request) -> str:
|
||||
async def get_status(self, request) -> dict:
|
||||
self.calls.append(("get_status", request))
|
||||
return "status"
|
||||
return {"status": "ok"}
|
||||
|
||||
async def pause_download(self, request) -> str:
|
||||
async def pause_download(self, request) -> dict:
|
||||
self.calls.append(("pause_download", request))
|
||||
return "pause"
|
||||
return {"status": "paused"}
|
||||
|
||||
async def resume_download(self, request) -> str:
|
||||
async def resume_download(self, request) -> dict:
|
||||
self.calls.append(("resume_download", request))
|
||||
return "resume"
|
||||
return {"status": "running"}
|
||||
|
||||
async def start_force_download(self, request) -> str:
|
||||
self.calls.append(("start_force_download", request))
|
||||
return "force"
|
||||
async def start_force_download(self, payload) -> dict:
|
||||
self.calls.append(("start_force_download", payload))
|
||||
return {"status": "force", "payload": payload}
|
||||
|
||||
class StubDownloadUseCase:
|
||||
def __init__(self) -> None:
|
||||
self.payloads: List[Any] = []
|
||||
|
||||
async def execute(self, payload: dict) -> dict:
|
||||
self.payloads.append(payload)
|
||||
return {"status": "started", "payload": payload}
|
||||
|
||||
class DummyRequest:
|
||||
def __init__(self, payload: dict) -> None:
|
||||
self._payload = payload
|
||||
self.query = {}
|
||||
|
||||
async def json(self) -> dict:
|
||||
return self._payload
|
||||
|
||||
recorder = Recorder()
|
||||
handler = ExampleImagesDownloadHandler(recorder)
|
||||
request = object()
|
||||
use_case = StubDownloadUseCase()
|
||||
handler = ExampleImagesDownloadHandler(use_case, recorder)
|
||||
request = DummyRequest({"foo": "bar"})
|
||||
|
||||
assert await handler.download_example_images(request) == "download"
|
||||
assert await handler.get_example_images_status(request) == "status"
|
||||
assert await handler.pause_example_images(request) == "pause"
|
||||
assert await handler.resume_example_images(request) == "resume"
|
||||
assert await handler.force_download_example_images(request) == "force"
|
||||
download_response = await handler.download_example_images(request)
|
||||
assert json.loads(download_response.text) == {"status": "started", "payload": {"foo": "bar"}}
|
||||
status_response = await handler.get_example_images_status(request)
|
||||
assert json.loads(status_response.text) == {"status": "ok"}
|
||||
pause_response = await handler.pause_example_images(request)
|
||||
assert json.loads(pause_response.text) == {"status": "paused"}
|
||||
resume_response = await handler.resume_example_images(request)
|
||||
assert json.loads(resume_response.text) == {"status": "running"}
|
||||
force_response = await handler.force_download_example_images(request)
|
||||
assert json.loads(force_response.text) == {"status": "force", "payload": {"foo": "bar"}}
|
||||
|
||||
expected = [
|
||||
("start_download", request),
|
||||
assert use_case.payloads == [{"foo": "bar"}]
|
||||
assert recorder.calls == [
|
||||
("get_status", request),
|
||||
("pause_download", request),
|
||||
("resume_download", request),
|
||||
("start_force_download", request),
|
||||
("start_force_download", {"foo": "bar"}),
|
||||
]
|
||||
assert recorder.calls == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_management_handler_methods_delegate() -> None:
|
||||
class StubImportUseCase:
|
||||
def __init__(self) -> None:
|
||||
self.requests: List[Any] = []
|
||||
|
||||
async def execute(self, request: Any) -> dict:
|
||||
self.requests.append(request)
|
||||
return {"status": "imported"}
|
||||
|
||||
class Recorder:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def import_images(self, request) -> str:
|
||||
self.calls.append(("import_images", request))
|
||||
return "import"
|
||||
|
||||
async def delete_custom_image(self, request) -> str:
|
||||
self.calls.append(("delete_custom_image", request))
|
||||
return "delete"
|
||||
|
||||
recorder = Recorder()
|
||||
handler = ExampleImagesManagementHandler(recorder)
|
||||
use_case = StubImportUseCase()
|
||||
handler = ExampleImagesManagementHandler(use_case, recorder)
|
||||
request = object()
|
||||
|
||||
assert await handler.import_example_images(request) == "import"
|
||||
import_response = await handler.import_example_images(request)
|
||||
assert json.loads(import_response.text) == {"status": "imported"}
|
||||
assert await handler.delete_example_image(request) == "delete"
|
||||
assert recorder.calls == [
|
||||
("import_images", request),
|
||||
("delete_custom_image", request),
|
||||
]
|
||||
assert use_case.requests == [request]
|
||||
assert recorder.calls == [("delete_custom_image", request)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -350,8 +371,29 @@ async def test_file_handler_methods_delegate() -> None:
|
||||
|
||||
|
||||
def test_handler_set_route_mapping_includes_all_handlers() -> None:
|
||||
download = ExampleImagesDownloadHandler(object())
|
||||
management = ExampleImagesManagementHandler(object())
|
||||
class DummyUseCase:
|
||||
async def execute(self, payload):
|
||||
return payload
|
||||
|
||||
class DummyManager:
|
||||
async def get_status(self, request):
|
||||
return {}
|
||||
|
||||
async def pause_download(self, request):
|
||||
return {}
|
||||
|
||||
async def resume_download(self, request):
|
||||
return {}
|
||||
|
||||
async def start_force_download(self, payload):
|
||||
return payload
|
||||
|
||||
class DummyProcessor:
|
||||
async def delete_custom_image(self, request):
|
||||
return {}
|
||||
|
||||
download = ExampleImagesDownloadHandler(DummyUseCase(), DummyManager())
|
||||
management = ExampleImagesManagementHandler(DummyUseCase(), DummyProcessor())
|
||||
files = ExampleImagesFileHandler(object())
|
||||
handler_set = ExampleImagesHandlerSet(
|
||||
download=download,
|
||||
|
||||
@@ -10,9 +10,23 @@ from py_local.services.use_cases import (
|
||||
AutoOrganizeInProgressError,
|
||||
AutoOrganizeUseCase,
|
||||
BulkMetadataRefreshUseCase,
|
||||
DownloadExampleImagesConfigurationError,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesUseCase,
|
||||
DownloadModelEarlyAccessError,
|
||||
DownloadModelUseCase,
|
||||
DownloadModelValidationError,
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
from py_local.utils.example_images_download_manager import (
|
||||
DownloadConfigurationError,
|
||||
DownloadInProgressError,
|
||||
ExampleImagesDownloadError,
|
||||
)
|
||||
from py_local.utils.example_images_processor import (
|
||||
ExampleImagesImportError,
|
||||
ExampleImagesValidationError,
|
||||
)
|
||||
from tests.conftest import MockModelService, MockScanner
|
||||
|
||||
@@ -88,6 +102,38 @@ class StubDownloadCoordinator:
|
||||
return {"success": True, "download_id": "abc123"}
|
||||
|
||||
|
||||
class StubExampleImagesDownloadManager:
|
||||
def __init__(self) -> None:
|
||||
self.payloads: List[Dict[str, Any]] = []
|
||||
self.error: Optional[str] = None
|
||||
self.progress_snapshot = {"status": "running"}
|
||||
|
||||
async def start_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
self.payloads.append(payload)
|
||||
if self.error == "in_progress":
|
||||
raise DownloadInProgressError(self.progress_snapshot)
|
||||
if self.error == "configuration":
|
||||
raise DownloadConfigurationError("path missing")
|
||||
if self.error == "generic":
|
||||
raise ExampleImagesDownloadError("boom")
|
||||
return {"success": True, "message": "ok"}
|
||||
|
||||
|
||||
class StubExampleImagesProcessor:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Dict[str, Any]] = []
|
||||
self.error: Optional[str] = None
|
||||
self.response: Dict[str, Any] = {"success": True}
|
||||
|
||||
async def import_images(self, model_hash: str, files: List[str]) -> Dict[str, Any]:
|
||||
self.calls.append({"model_hash": model_hash, "files": files})
|
||||
if self.error == "validation":
|
||||
raise ExampleImagesValidationError("missing")
|
||||
if self.error == "generic":
|
||||
raise ExampleImagesImportError("boom")
|
||||
return self.response
|
||||
|
||||
|
||||
async def test_auto_organize_use_case_executes_with_lock() -> None:
|
||||
file_service = StubFileService()
|
||||
lock_provider = StubLockProvider()
|
||||
@@ -189,3 +235,83 @@ async def test_download_model_use_case_returns_result() -> None:
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["download_id"] == "abc123"
|
||||
|
||||
|
||||
async def test_download_example_images_use_case_triggers_manager() -> None:
|
||||
manager = StubExampleImagesDownloadManager()
|
||||
use_case = DownloadExampleImagesUseCase(download_manager=manager)
|
||||
|
||||
payload = {"optimize": True}
|
||||
result = await use_case.execute(payload)
|
||||
|
||||
assert manager.payloads == [payload]
|
||||
assert result == {"success": True, "message": "ok"}
|
||||
|
||||
|
||||
async def test_download_example_images_use_case_maps_in_progress() -> None:
|
||||
manager = StubExampleImagesDownloadManager()
|
||||
manager.error = "in_progress"
|
||||
use_case = DownloadExampleImagesUseCase(download_manager=manager)
|
||||
|
||||
with pytest.raises(DownloadExampleImagesInProgressError) as exc:
|
||||
await use_case.execute({})
|
||||
|
||||
assert exc.value.progress == manager.progress_snapshot
|
||||
|
||||
|
||||
async def test_download_example_images_use_case_maps_configuration() -> None:
|
||||
manager = StubExampleImagesDownloadManager()
|
||||
manager.error = "configuration"
|
||||
use_case = DownloadExampleImagesUseCase(download_manager=manager)
|
||||
|
||||
with pytest.raises(DownloadExampleImagesConfigurationError):
|
||||
await use_case.execute({})
|
||||
|
||||
|
||||
async def test_download_example_images_use_case_propagates_generic_error() -> None:
|
||||
manager = StubExampleImagesDownloadManager()
|
||||
manager.error = "generic"
|
||||
use_case = DownloadExampleImagesUseCase(download_manager=manager)
|
||||
|
||||
with pytest.raises(ExampleImagesDownloadError):
|
||||
await use_case.execute({})
|
||||
|
||||
|
||||
class DummyJsonRequest:
|
||||
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||
self._payload = payload
|
||||
self.content_type = "application/json"
|
||||
|
||||
async def json(self) -> Dict[str, Any]:
|
||||
return self._payload
|
||||
|
||||
|
||||
async def test_import_example_images_use_case_delegates() -> None:
|
||||
processor = StubExampleImagesProcessor()
|
||||
use_case = ImportExampleImagesUseCase(processor=processor)
|
||||
|
||||
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
|
||||
result = await use_case.execute(request)
|
||||
|
||||
assert processor.calls == [{"model_hash": "abc", "files": ["/tmp/file"]}]
|
||||
assert result == {"success": True}
|
||||
|
||||
|
||||
async def test_import_example_images_use_case_maps_validation_error() -> None:
|
||||
processor = StubExampleImagesProcessor()
|
||||
processor.error = "validation"
|
||||
use_case = ImportExampleImagesUseCase(processor=processor)
|
||||
request = DummyJsonRequest({"model_hash": None, "file_paths": []})
|
||||
|
||||
with pytest.raises(ImportExampleImagesValidationError):
|
||||
await use_case.execute(request)
|
||||
|
||||
|
||||
async def test_import_example_images_use_case_propagates_generic_error() -> None:
|
||||
processor = StubExampleImagesProcessor()
|
||||
processor.error = "generic"
|
||||
use_case = ImportExampleImagesUseCase(processor=processor)
|
||||
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
|
||||
|
||||
with pytest.raises(ExampleImagesImportError):
|
||||
await use_case.execute(request)
|
||||
|
||||
Reference in New Issue
Block a user