Merge pull request #462 from willmiao/codex/wrap-long-running-flows-in-use-cases

feat(example-images): add use case orchestration
This commit is contained in:
pixelpaws
2025-09-23 11:55:06 +08:00
committed by GitHub
10 changed files with 582 additions and 262 deletions

View File

@@ -12,6 +12,10 @@ from .handlers.example_images_handlers import (
ExampleImagesHandlerSet, ExampleImagesHandlerSet,
ExampleImagesManagementHandler, ExampleImagesManagementHandler,
) )
from ..services.use_cases.example_images import (
DownloadExampleImagesUseCase,
ImportExampleImagesUseCase,
)
from ..utils.example_images_download_manager import DownloadManager from ..utils.example_images_download_manager import DownloadManager
from ..utils.example_images_file_manager import ExampleImagesFileManager from ..utils.example_images_file_manager import ExampleImagesFileManager
from ..utils.example_images_processor import ExampleImagesProcessor from ..utils.example_images_processor import ExampleImagesProcessor
@@ -59,8 +63,10 @@ class ExampleImagesRoutes:
def _build_handler_set(self) -> ExampleImagesHandlerSet: def _build_handler_set(self) -> ExampleImagesHandlerSet:
logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager) logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager)
download_handler = ExampleImagesDownloadHandler(self._download_manager) download_use_case = DownloadExampleImagesUseCase(download_manager=self._download_manager)
management_handler = ExampleImagesManagementHandler(self._processor) download_handler = ExampleImagesDownloadHandler(download_use_case, self._download_manager)
import_use_case = ImportExampleImagesUseCase(processor=self._processor)
management_handler = ExampleImagesManagementHandler(import_use_case, self._processor)
file_handler = ExampleImagesFileHandler(self._file_manager) file_handler = ExampleImagesFileHandler(self._file_manager)
return ExampleImagesHandlerSet( return ExampleImagesHandlerSet(
download=download_handler, download=download_handler,

View File

@@ -6,37 +6,101 @@ from typing import Callable, Mapping
from aiohttp import web from aiohttp import web
from ...services.use_cases.example_images import (
DownloadExampleImagesConfigurationError,
DownloadExampleImagesInProgressError,
DownloadExampleImagesUseCase,
ImportExampleImagesUseCase,
ImportExampleImagesValidationError,
)
from ...utils.example_images_download_manager import (
DownloadConfigurationError,
DownloadInProgressError,
DownloadNotRunningError,
ExampleImagesDownloadError,
)
from ...utils.example_images_processor import ExampleImagesImportError
class ExampleImagesDownloadHandler: class ExampleImagesDownloadHandler:
"""HTTP adapters for download-related example image endpoints.""" """HTTP adapters for download-related example image endpoints."""
def __init__(self, download_manager) -> None: def __init__(
self,
download_use_case: DownloadExampleImagesUseCase,
download_manager,
) -> None:
self._download_use_case = download_use_case
self._download_manager = download_manager self._download_manager = download_manager
async def download_example_images(self, request: web.Request) -> web.StreamResponse: async def download_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.start_download(request) try:
payload = await request.json()
result = await self._download_use_case.execute(payload)
return web.json_response(result)
except DownloadExampleImagesInProgressError as exc:
response = {
'success': False,
'error': str(exc),
'status': exc.progress,
}
return web.json_response(response, status=400)
except DownloadExampleImagesConfigurationError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
except ExampleImagesDownloadError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=500)
async def get_example_images_status(self, request: web.Request) -> web.StreamResponse: async def get_example_images_status(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.get_status(request) result = await self._download_manager.get_status(request)
return web.json_response(result)
async def pause_example_images(self, request: web.Request) -> web.StreamResponse: async def pause_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.pause_download(request) try:
result = await self._download_manager.pause_download(request)
return web.json_response(result)
except DownloadNotRunningError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
async def resume_example_images(self, request: web.Request) -> web.StreamResponse: async def resume_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.resume_download(request) try:
result = await self._download_manager.resume_download(request)
return web.json_response(result)
except DownloadNotRunningError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
async def force_download_example_images(self, request: web.Request) -> web.StreamResponse: async def force_download_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.start_force_download(request) try:
payload = await request.json()
result = await self._download_manager.start_force_download(payload)
return web.json_response(result)
except DownloadInProgressError as exc:
response = {
'success': False,
'error': str(exc),
'status': exc.progress_snapshot,
}
return web.json_response(response, status=400)
except DownloadConfigurationError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
except ExampleImagesDownloadError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=500)
class ExampleImagesManagementHandler: class ExampleImagesManagementHandler:
"""HTTP adapters for import/delete endpoints.""" """HTTP adapters for import/delete endpoints."""
def __init__(self, processor) -> None: def __init__(self, import_use_case: ImportExampleImagesUseCase, processor) -> None:
self._import_use_case = import_use_case
self._processor = processor self._processor = processor
async def import_example_images(self, request: web.Request) -> web.StreamResponse: async def import_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._processor.import_images(request) try:
result = await self._import_use_case.execute(request)
return web.json_response(result)
except ImportExampleImagesValidationError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
except ExampleImagesImportError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=500)
async def delete_example_image(self, request: web.Request) -> web.StreamResponse: async def delete_example_image(self, request: web.Request) -> web.StreamResponse:
return await self._processor.delete_custom_image(request) return await self._processor.delete_custom_image(request)

View File

@@ -13,6 +13,13 @@ from .download_model_use_case import (
DownloadModelUseCase, DownloadModelUseCase,
DownloadModelValidationError, DownloadModelValidationError,
) )
from .example_images import (
DownloadExampleImagesConfigurationError,
DownloadExampleImagesInProgressError,
DownloadExampleImagesUseCase,
ImportExampleImagesUseCase,
ImportExampleImagesValidationError,
)
__all__ = [ __all__ = [
"AutoOrganizeInProgressError", "AutoOrganizeInProgressError",
@@ -22,4 +29,9 @@ __all__ = [
"DownloadModelEarlyAccessError", "DownloadModelEarlyAccessError",
"DownloadModelUseCase", "DownloadModelUseCase",
"DownloadModelValidationError", "DownloadModelValidationError",
"DownloadExampleImagesConfigurationError",
"DownloadExampleImagesInProgressError",
"DownloadExampleImagesUseCase",
"ImportExampleImagesUseCase",
"ImportExampleImagesValidationError",
] ]

View File

@@ -0,0 +1,19 @@
"""Example image specific use case exports."""
from .download_example_images_use_case import (
DownloadExampleImagesUseCase,
DownloadExampleImagesInProgressError,
DownloadExampleImagesConfigurationError,
)
from .import_example_images_use_case import (
ImportExampleImagesUseCase,
ImportExampleImagesValidationError,
)
__all__ = [
"DownloadExampleImagesUseCase",
"DownloadExampleImagesInProgressError",
"DownloadExampleImagesConfigurationError",
"ImportExampleImagesUseCase",
"ImportExampleImagesValidationError",
]

View File

@@ -0,0 +1,42 @@
"""Use case coordinating example image downloads."""
from __future__ import annotations
from typing import Any, Dict
from ....utils.example_images_download_manager import (
DownloadConfigurationError,
DownloadInProgressError,
ExampleImagesDownloadError,
)
class DownloadExampleImagesInProgressError(RuntimeError):
"""Raised when a download is already running."""
def __init__(self, progress: Dict[str, Any]) -> None:
super().__init__("Download already in progress")
self.progress = progress
class DownloadExampleImagesConfigurationError(ValueError):
"""Raised when settings prevent downloads from starting."""
class DownloadExampleImagesUseCase:
"""Validate payloads and trigger the download manager."""
def __init__(self, *, download_manager) -> None:
self._download_manager = download_manager
async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""Start a download and translate manager errors."""
try:
return await self._download_manager.start_download(payload)
except DownloadInProgressError as exc:
raise DownloadExampleImagesInProgressError(exc.progress_snapshot) from exc
except DownloadConfigurationError as exc:
raise DownloadExampleImagesConfigurationError(str(exc)) from exc
except ExampleImagesDownloadError:
raise

View File

@@ -0,0 +1,86 @@
"""Use case for importing example images."""
from __future__ import annotations
import os
import tempfile
from contextlib import suppress
from typing import Any, Dict, List
from aiohttp import web
from ....utils.example_images_processor import (
ExampleImagesImportError,
ExampleImagesProcessor,
ExampleImagesValidationError,
)
class ImportExampleImagesValidationError(ValueError):
"""Raised when request validation fails."""
class ImportExampleImagesUseCase:
"""Parse upload payloads and delegate to the processor service."""
def __init__(self, *, processor: ExampleImagesProcessor) -> None:
self._processor = processor
async def execute(self, request: web.Request) -> Dict[str, Any]:
model_hash: str | None = None
files_to_import: List[str] = []
temp_files: List[str] = []
try:
if request.content_type and "multipart/form-data" in request.content_type:
reader = await request.multipart()
first_field = await reader.next()
if first_field and first_field.name == "model_hash":
model_hash = await first_field.text()
else:
# Support clients that send files first and hash later
if first_field is not None:
await self._collect_upload_file(first_field, files_to_import, temp_files)
async for field in reader:
if field.name == "model_hash" and not model_hash:
model_hash = await field.text()
elif field.name == "files":
await self._collect_upload_file(field, files_to_import, temp_files)
else:
data = await request.json()
model_hash = data.get("model_hash")
files_to_import = list(data.get("file_paths", []))
result = await self._processor.import_images(model_hash, files_to_import)
return result
except ExampleImagesValidationError as exc:
raise ImportExampleImagesValidationError(str(exc)) from exc
except ExampleImagesImportError:
raise
finally:
for path in temp_files:
with suppress(Exception):
os.remove(path)
async def _collect_upload_file(
self,
field: Any,
files_to_import: List[str],
temp_files: List[str],
) -> None:
"""Persist an uploaded file to disk and add it to the import list."""
filename = field.filename or "upload"
file_ext = os.path.splitext(filename)[1].lower()
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
temp_files.append(tmp_file.name)
while True:
chunk = await field.read_chunk()
if not chunk:
break
tmp_file.write(chunk)
files_to_import.append(tmp_file.name)

View File

@@ -3,7 +3,6 @@ import os
import asyncio import asyncio
import json import json
import time import time
from aiohttp import web
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor from .example_images_processor import ExampleImagesProcessor
@@ -12,6 +11,30 @@ from ..services.websocket_manager import ws_manager # Add this import at the to
from ..services.downloader import get_downloader from ..services.downloader import get_downloader
from ..services.settings_manager import settings from ..services.settings_manager import settings
class ExampleImagesDownloadError(RuntimeError):
"""Base error for example image download operations."""
class DownloadInProgressError(ExampleImagesDownloadError):
"""Raised when a download is already running."""
def __init__(self, progress_snapshot: dict) -> None:
super().__init__("Download already in progress")
self.progress_snapshot = progress_snapshot
class DownloadNotRunningError(ExampleImagesDownloadError):
"""Raised when pause/resume is requested without an active download."""
def __init__(self, message: str = "No download in progress") -> None:
super().__init__(message)
class DownloadConfigurationError(ExampleImagesDownloadError):
"""Raised when configuration prevents starting a download."""
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Download status tracking # Download status tracking
@@ -31,11 +54,21 @@ download_progress = {
'failed_models': set() # Track models that failed to download after metadata refresh 'failed_models': set() # Track models that failed to download after metadata refresh
} }
def _serialize_progress() -> dict:
"""Return a JSON-serialisable snapshot of the current progress."""
snapshot = download_progress.copy()
snapshot['processed_models'] = list(download_progress['processed_models'])
snapshot['refreshed_models'] = list(download_progress['refreshed_models'])
snapshot['failed_models'] = list(download_progress['failed_models'])
return snapshot
class DownloadManager: class DownloadManager:
"""Manages downloading example images for models""" """Manages downloading example images for models"""
@staticmethod @staticmethod
async def start_download(request): async def start_download(options: dict):
""" """
Start downloading example images for models Start downloading example images for models
@@ -50,21 +83,10 @@ class DownloadManager:
global download_task, is_downloading, download_progress global download_task, is_downloading, download_progress
if is_downloading: if is_downloading:
# Create a copy for JSON serialization raise DownloadInProgressError(_serialize_progress())
response_progress = download_progress.copy()
response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
response_progress['failed_models'] = list(download_progress['failed_models'])
return web.json_response({
'success': False,
'error': 'Download already in progress',
'status': response_progress
}, status=400)
try: try:
# Parse the request body data = options or {}
data = await request.json()
auto_mode = data.get('auto_mode', False) auto_mode = data.get('auto_mode', False)
optimize = data.get('optimize', True) optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint']) model_types = data.get('model_types', ['lora', 'checkpoint'])
@@ -78,15 +100,11 @@ class DownloadManager:
if auto_mode: if auto_mode:
# For auto mode, just log and return success to avoid showing error toasts # For auto mode, just log and return success to avoid showing error toasts
logger.debug(error_msg) logger.debug(error_msg)
return web.json_response({ return {
'success': True, 'success': True,
'message': 'Example images path not configured, skipping auto download' 'message': 'Example images path not configured, skipping auto download'
}) }
else: raise DownloadConfigurationError(error_msg)
return web.json_response({
'success': False,
'error': error_msg
}, status=400)
# Create the output directory # Create the output directory
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
@@ -129,24 +147,15 @@ class DownloadManager:
) )
) )
# Create a copy for JSON serialization return {
response_progress = download_progress.copy()
response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
response_progress['failed_models'] = list(download_progress['failed_models'])
return web.json_response({
'success': True, 'success': True,
'message': 'Download started', 'message': 'Download started',
'status': response_progress 'status': _serialize_progress()
}) }
except Exception as e: except Exception as e:
logger.error(f"Failed to start example images download: {e}", exc_info=True) logger.error(f"Failed to start example images download: {e}", exc_info=True)
return web.json_response({ raise ExampleImagesDownloadError(str(e)) from e
'success': False,
'error': str(e)
}, status=500)
@staticmethod @staticmethod
async def get_status(request): async def get_status(request):
@@ -154,16 +163,13 @@ class DownloadManager:
global download_progress global download_progress
# Create a copy of the progress dict with the set converted to a list for JSON serialization # Create a copy of the progress dict with the set converted to a list for JSON serialization
response_progress = download_progress.copy() response_progress = _serialize_progress()
response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
response_progress['failed_models'] = list(download_progress['failed_models'])
return web.json_response({ return {
'success': True, 'success': True,
'is_downloading': is_downloading, 'is_downloading': is_downloading,
'status': response_progress 'status': response_progress
}) }
@staticmethod @staticmethod
async def pause_download(request): async def pause_download(request):
@@ -171,17 +177,14 @@ class DownloadManager:
global download_progress global download_progress
if not is_downloading: if not is_downloading:
return web.json_response({ raise DownloadNotRunningError()
'success': False,
'error': 'No download in progress'
}, status=400)
download_progress['status'] = 'paused' download_progress['status'] = 'paused'
return web.json_response({ return {
'success': True, 'success': True,
'message': 'Download paused' 'message': 'Download paused'
}) }
@staticmethod @staticmethod
async def resume_download(request): async def resume_download(request):
@@ -189,23 +192,19 @@ class DownloadManager:
global download_progress global download_progress
if not is_downloading: if not is_downloading:
return web.json_response({ raise DownloadNotRunningError()
'success': False,
'error': 'No download in progress'
}, status=400)
if download_progress['status'] == 'paused': if download_progress['status'] == 'paused':
download_progress['status'] = 'running' download_progress['status'] = 'running'
return web.json_response({ return {
'success': True, 'success': True,
'message': 'Download resumed' 'message': 'Download resumed'
}) }
else:
return web.json_response({ raise DownloadNotRunningError(
'success': False, f"Download is in '{download_progress['status']}' state, cannot resume"
'error': f"Download is in '{download_progress['status']}' state, cannot resume" )
}, status=400)
@staticmethod @staticmethod
async def _download_all_example_images(output_dir, optimize, model_types, delay): async def _download_all_example_images(output_dir, optimize, model_types, delay):
@@ -432,7 +431,7 @@ class DownloadManager:
logger.error(f"Failed to save progress file: {e}") logger.error(f"Failed to save progress file: {e}")
@staticmethod @staticmethod
async def start_force_download(request): async def start_force_download(options: dict):
""" """
Force download example images for specific models Force download example images for specific models
@@ -447,33 +446,23 @@ class DownloadManager:
global download_task, is_downloading, download_progress global download_task, is_downloading, download_progress
if is_downloading: if is_downloading:
return web.json_response({ raise DownloadInProgressError(_serialize_progress())
'success': False,
'error': 'Download already in progress'
}, status=400)
try: try:
# Parse the request body data = options or {}
data = await request.json()
model_hashes = data.get('model_hashes', []) model_hashes = data.get('model_hashes', [])
optimize = data.get('optimize', True) optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint']) model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
if not model_hashes: if not model_hashes:
return web.json_response({ raise DownloadConfigurationError('Missing model_hashes parameter')
'success': False,
'error': 'Missing model_hashes parameter'
}, status=400)
# Get output directory from settings # Get output directory from settings
output_dir = settings.get('example_images_path') output_dir = settings.get('example_images_path')
if not output_dir: if not output_dir:
return web.json_response({ raise DownloadConfigurationError('Example images path not configured in settings')
'success': False,
'error': 'Example images path not configured in settings'
}, status=400)
# Create the output directory # Create the output directory
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
@@ -506,20 +495,17 @@ class DownloadManager:
# Set download status to not downloading # Set download status to not downloading
is_downloading = False is_downloading = False
return web.json_response({ return {
'success': True, 'success': True,
'message': 'Force download completed', 'message': 'Force download completed',
'result': result 'result': result
}) }
except Exception as e: except Exception as e:
# Set download status to not downloading # Set download status to not downloading
is_downloading = False is_downloading = False
logger.error(f"Failed during forced example images download: {e}", exc_info=True) logger.error(f"Failed during forced example images download: {e}", exc_info=True)
return web.json_response({ raise ExampleImagesDownloadError(str(e)) from e
'success': False,
'error': str(e)
}, status=500)
@staticmethod @staticmethod
async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay): async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay):

View File

@@ -1,7 +1,6 @@
import logging import logging
import os import os
import re import re
import tempfile
import random import random
import string import string
from aiohttp import web from aiohttp import web
@@ -13,6 +12,14 @@ from ..utils.metadata_manager import MetadataManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ExampleImagesImportError(RuntimeError):
"""Base error for example image import operations."""
class ExampleImagesValidationError(ExampleImagesImportError):
"""Raised when input validation fails."""
class ExampleImagesProcessor: class ExampleImagesProcessor:
"""Processes and manipulates example images""" """Processes and manipulates example images"""
@@ -299,81 +306,20 @@ class ExampleImagesProcessor:
return False return False
@staticmethod @staticmethod
async def import_images(request): async def import_images(model_hash: str, files_to_import: list[str]):
""" """Import local example images for a model."""
Import local example images
Accepts:
- multipart/form-data form with model_hash and files fields
or
- JSON request with model_hash and file_paths
Returns:
- Success status and list of imported files
"""
try:
model_hash = None
files_to_import = []
temp_files_to_cleanup = []
# Check if it's a multipart form-data request (direct file upload)
if request.content_type and 'multipart/form-data' in request.content_type:
reader = await request.multipart()
# First get model_hash
field = await reader.next()
if field.name == 'model_hash':
model_hash = await field.text()
# Then process all files
while True:
field = await reader.next()
if field is None:
break
if field.name == 'files':
# Create a temporary file with appropriate suffix for type detection
file_name = field.filename
file_ext = os.path.splitext(file_name)[1].lower()
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
temp_path = tmp_file.name
temp_files_to_cleanup.append(temp_path) # Track for cleanup
# Write chunks to the temporary file
while True:
chunk = await field.read_chunk()
if not chunk:
break
tmp_file.write(chunk)
# Add to the list of files to process
files_to_import.append(temp_path)
else:
# Parse JSON request (legacy method using file paths)
data = await request.json()
model_hash = data.get('model_hash')
files_to_import = data.get('file_paths', [])
if not model_hash: if not model_hash:
return web.json_response({ raise ExampleImagesValidationError('Missing model_hash parameter')
'success': False,
'error': 'Missing model_hash parameter'
}, status=400)
if not files_to_import: if not files_to_import:
return web.json_response({ raise ExampleImagesValidationError('No files provided to import')
'success': False,
'error': 'No files provided to import'
}, status=400)
try:
# Get example images path # Get example images path
example_images_path = settings.get('example_images_path') example_images_path = settings.get('example_images_path')
if not example_images_path: if not example_images_path:
return web.json_response({ raise ExampleImagesValidationError('No example images path configured')
'success': False,
'error': 'No example images path configured'
}, status=400)
# Find the model and get current metadata # Find the model and get current metadata
lora_scanner = await ServiceRegistry.get_lora_scanner() lora_scanner = await ServiceRegistry.get_lora_scanner()
@@ -395,10 +341,9 @@ class ExampleImagesProcessor:
break break
if not model_data: if not model_data:
return web.json_response({ raise ExampleImagesImportError(
'success': False, f"Model with hash {model_hash} not found in cache"
'error': f"Model with hash {model_hash} not found in cache" )
}, status=404)
# Create model folder # Create model folder
model_folder = os.path.join(example_images_path, model_hash) model_folder = os.path.join(example_images_path, model_hash)
@@ -453,7 +398,7 @@ class ExampleImagesProcessor:
newly_imported_paths newly_imported_paths
) )
return web.json_response({ return {
'success': len(imported_files) > 0, 'success': len(imported_files) > 0,
'message': f'Successfully imported {len(imported_files)} files' + 'message': f'Successfully imported {len(imported_files)} files' +
(f' with {len(errors)} errors' if errors else ''), (f' with {len(errors)} errors' if errors else ''),
@@ -462,21 +407,13 @@ class ExampleImagesProcessor:
'regular_images': regular_images, 'regular_images': regular_images,
'custom_images': custom_images, 'custom_images': custom_images,
"model_file_path": model_data.get('file_path', ''), "model_file_path": model_data.get('file_path', ''),
}) }
except ExampleImagesImportError:
raise
except Exception as e: except Exception as e:
logger.error(f"Failed to import example images: {e}", exc_info=True) logger.error(f"Failed to import example images: {e}", exc_info=True)
return web.json_response({ raise ExampleImagesImportError(str(e)) from e
'success': False,
'error': str(e)
}, status=500)
finally:
# Clean up temporary files
for temp_file in temp_files_to_cleanup:
try:
os.remove(temp_file)
except Exception as e:
logger.error(f"Failed to remove temporary file {temp_file}: {e}")
@staticmethod @staticmethod
async def delete_custom_image(request): async def delete_custom_image(request):

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import json
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, List, Tuple from typing import Any, List, Tuple
@@ -33,37 +34,35 @@ class StubDownloadManager:
def __init__(self) -> None: def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = [] self.calls: List[Tuple[str, Any]] = []
async def start_download(self, request: web.Request) -> web.StreamResponse: async def start_download(self, payload: Any) -> dict:
payload = await request.json()
self.calls.append(("start_download", payload)) self.calls.append(("start_download", payload))
return web.json_response({"operation": "start_download", "payload": payload}) return {"operation": "start_download", "payload": payload}
async def get_status(self, request: web.Request) -> web.StreamResponse: async def get_status(self, request: web.Request) -> dict:
self.calls.append(("get_status", dict(request.query))) self.calls.append(("get_status", dict(request.query)))
return web.json_response({"operation": "get_status"}) return {"operation": "get_status"}
async def pause_download(self, request: web.Request) -> web.StreamResponse: async def pause_download(self, request: web.Request) -> dict:
self.calls.append(("pause_download", None)) self.calls.append(("pause_download", None))
return web.json_response({"operation": "pause_download"}) return {"operation": "pause_download"}
async def resume_download(self, request: web.Request) -> web.StreamResponse: async def resume_download(self, request: web.Request) -> dict:
self.calls.append(("resume_download", None)) self.calls.append(("resume_download", None))
return web.json_response({"operation": "resume_download"}) return {"operation": "resume_download"}
async def start_force_download(self, request: web.Request) -> web.StreamResponse: async def start_force_download(self, payload: Any) -> dict:
payload = await request.json()
self.calls.append(("start_force_download", payload)) self.calls.append(("start_force_download", payload))
return web.json_response({"operation": "start_force_download", "payload": payload}) return {"operation": "start_force_download", "payload": payload}
class StubExampleImagesProcessor: class StubExampleImagesProcessor:
def __init__(self) -> None: def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = [] self.calls: List[Tuple[str, Any]] = []
async def import_images(self, request: web.Request) -> web.StreamResponse: async def import_images(self, model_hash: str, files: List[str]) -> dict:
payload = await request.json() payload = {"model_hash": model_hash, "file_paths": files}
self.calls.append(("import_images", payload)) self.calls.append(("import_images", payload))
return web.json_response({"operation": "import_images", "payload": payload}) return {"operation": "import_images", "payload": payload}
async def delete_custom_image(self, request: web.Request) -> web.StreamResponse: async def delete_custom_image(self, request: web.Request) -> web.StreamResponse:
payload = await request.json() payload = await request.json()
@@ -184,7 +183,7 @@ async def test_pause_and_resume_routes_delegate():
async def test_import_route_delegates_to_processor(): async def test_import_route_delegates_to_processor():
payload = {"model_hash": "abc123", "files": ["/path/image.png"]} payload = {"model_hash": "abc123", "file_paths": ["/path/image.png"]}
async with example_images_app() as harness: async with example_images_app() as harness:
response = await harness.client.post( response = await harness.client.post(
"/api/lm/import-example-images", json=payload "/api/lm/import-example-images", json=payload
@@ -193,7 +192,8 @@ async def test_import_route_delegates_to_processor():
assert response.status == 200 assert response.status == 200
assert body == {"operation": "import_images", "payload": payload} assert body == {"operation": "import_images", "payload": payload}
assert harness.processor.calls == [("import_images", payload)] expected_call = ("import_images", payload)
assert expected_call in harness.processor.calls
async def test_delete_route_delegates_to_processor(): async def test_delete_route_delegates_to_processor():
@@ -251,70 +251,91 @@ async def test_download_handler_methods_delegate() -> None:
def __init__(self) -> None: def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = [] self.calls: List[Tuple[str, Any]] = []
async def start_download(self, request) -> str: async def get_status(self, request) -> dict:
self.calls.append(("start_download", request))
return "download"
async def get_status(self, request) -> str:
self.calls.append(("get_status", request)) self.calls.append(("get_status", request))
return "status" return {"status": "ok"}
async def pause_download(self, request) -> str: async def pause_download(self, request) -> dict:
self.calls.append(("pause_download", request)) self.calls.append(("pause_download", request))
return "pause" return {"status": "paused"}
async def resume_download(self, request) -> str: async def resume_download(self, request) -> dict:
self.calls.append(("resume_download", request)) self.calls.append(("resume_download", request))
return "resume" return {"status": "running"}
async def start_force_download(self, request) -> str: async def start_force_download(self, payload) -> dict:
self.calls.append(("start_force_download", request)) self.calls.append(("start_force_download", payload))
return "force" return {"status": "force", "payload": payload}
class StubDownloadUseCase:
def __init__(self) -> None:
self.payloads: List[Any] = []
async def execute(self, payload: dict) -> dict:
self.payloads.append(payload)
return {"status": "started", "payload": payload}
class DummyRequest:
def __init__(self, payload: dict) -> None:
self._payload = payload
self.query = {}
async def json(self) -> dict:
return self._payload
recorder = Recorder() recorder = Recorder()
handler = ExampleImagesDownloadHandler(recorder) use_case = StubDownloadUseCase()
request = object() handler = ExampleImagesDownloadHandler(use_case, recorder)
request = DummyRequest({"foo": "bar"})
assert await handler.download_example_images(request) == "download" download_response = await handler.download_example_images(request)
assert await handler.get_example_images_status(request) == "status" assert json.loads(download_response.text) == {"status": "started", "payload": {"foo": "bar"}}
assert await handler.pause_example_images(request) == "pause" status_response = await handler.get_example_images_status(request)
assert await handler.resume_example_images(request) == "resume" assert json.loads(status_response.text) == {"status": "ok"}
assert await handler.force_download_example_images(request) == "force" pause_response = await handler.pause_example_images(request)
assert json.loads(pause_response.text) == {"status": "paused"}
resume_response = await handler.resume_example_images(request)
assert json.loads(resume_response.text) == {"status": "running"}
force_response = await handler.force_download_example_images(request)
assert json.loads(force_response.text) == {"status": "force", "payload": {"foo": "bar"}}
expected = [ assert use_case.payloads == [{"foo": "bar"}]
("start_download", request), assert recorder.calls == [
("get_status", request), ("get_status", request),
("pause_download", request), ("pause_download", request),
("resume_download", request), ("resume_download", request),
("start_force_download", request), ("start_force_download", {"foo": "bar"}),
] ]
assert recorder.calls == expected
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_management_handler_methods_delegate() -> None: async def test_management_handler_methods_delegate() -> None:
class StubImportUseCase:
def __init__(self) -> None:
self.requests: List[Any] = []
async def execute(self, request: Any) -> dict:
self.requests.append(request)
return {"status": "imported"}
class Recorder: class Recorder:
def __init__(self) -> None: def __init__(self) -> None:
self.calls: List[Tuple[str, Any]] = [] self.calls: List[Tuple[str, Any]] = []
async def import_images(self, request) -> str:
self.calls.append(("import_images", request))
return "import"
async def delete_custom_image(self, request) -> str: async def delete_custom_image(self, request) -> str:
self.calls.append(("delete_custom_image", request)) self.calls.append(("delete_custom_image", request))
return "delete" return "delete"
recorder = Recorder() recorder = Recorder()
handler = ExampleImagesManagementHandler(recorder) use_case = StubImportUseCase()
handler = ExampleImagesManagementHandler(use_case, recorder)
request = object() request = object()
assert await handler.import_example_images(request) == "import" import_response = await handler.import_example_images(request)
assert json.loads(import_response.text) == {"status": "imported"}
assert await handler.delete_example_image(request) == "delete" assert await handler.delete_example_image(request) == "delete"
assert recorder.calls == [ assert use_case.requests == [request]
("import_images", request), assert recorder.calls == [("delete_custom_image", request)]
("delete_custom_image", request),
]
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -350,8 +371,29 @@ async def test_file_handler_methods_delegate() -> None:
def test_handler_set_route_mapping_includes_all_handlers() -> None: def test_handler_set_route_mapping_includes_all_handlers() -> None:
download = ExampleImagesDownloadHandler(object()) class DummyUseCase:
management = ExampleImagesManagementHandler(object()) async def execute(self, payload):
return payload
class DummyManager:
async def get_status(self, request):
return {}
async def pause_download(self, request):
return {}
async def resume_download(self, request):
return {}
async def start_force_download(self, payload):
return payload
class DummyProcessor:
async def delete_custom_image(self, request):
return {}
download = ExampleImagesDownloadHandler(DummyUseCase(), DummyManager())
management = ExampleImagesManagementHandler(DummyUseCase(), DummyProcessor())
files = ExampleImagesFileHandler(object()) files = ExampleImagesFileHandler(object())
handler_set = ExampleImagesHandlerSet( handler_set = ExampleImagesHandlerSet(
download=download, download=download,

View File

@@ -10,9 +10,23 @@ from py_local.services.use_cases import (
AutoOrganizeInProgressError, AutoOrganizeInProgressError,
AutoOrganizeUseCase, AutoOrganizeUseCase,
BulkMetadataRefreshUseCase, BulkMetadataRefreshUseCase,
DownloadExampleImagesConfigurationError,
DownloadExampleImagesInProgressError,
DownloadExampleImagesUseCase,
DownloadModelEarlyAccessError, DownloadModelEarlyAccessError,
DownloadModelUseCase, DownloadModelUseCase,
DownloadModelValidationError, DownloadModelValidationError,
ImportExampleImagesUseCase,
ImportExampleImagesValidationError,
)
from py_local.utils.example_images_download_manager import (
DownloadConfigurationError,
DownloadInProgressError,
ExampleImagesDownloadError,
)
from py_local.utils.example_images_processor import (
ExampleImagesImportError,
ExampleImagesValidationError,
) )
from tests.conftest import MockModelService, MockScanner from tests.conftest import MockModelService, MockScanner
@@ -88,6 +102,38 @@ class StubDownloadCoordinator:
return {"success": True, "download_id": "abc123"} return {"success": True, "download_id": "abc123"}
class StubExampleImagesDownloadManager:
def __init__(self) -> None:
self.payloads: List[Dict[str, Any]] = []
self.error: Optional[str] = None
self.progress_snapshot = {"status": "running"}
async def start_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
self.payloads.append(payload)
if self.error == "in_progress":
raise DownloadInProgressError(self.progress_snapshot)
if self.error == "configuration":
raise DownloadConfigurationError("path missing")
if self.error == "generic":
raise ExampleImagesDownloadError("boom")
return {"success": True, "message": "ok"}
class StubExampleImagesProcessor:
def __init__(self) -> None:
self.calls: List[Dict[str, Any]] = []
self.error: Optional[str] = None
self.response: Dict[str, Any] = {"success": True}
async def import_images(self, model_hash: str, files: List[str]) -> Dict[str, Any]:
self.calls.append({"model_hash": model_hash, "files": files})
if self.error == "validation":
raise ExampleImagesValidationError("missing")
if self.error == "generic":
raise ExampleImagesImportError("boom")
return self.response
async def test_auto_organize_use_case_executes_with_lock() -> None: async def test_auto_organize_use_case_executes_with_lock() -> None:
file_service = StubFileService() file_service = StubFileService()
lock_provider = StubLockProvider() lock_provider = StubLockProvider()
@@ -189,3 +235,83 @@ async def test_download_model_use_case_returns_result() -> None:
assert result["success"] is True assert result["success"] is True
assert result["download_id"] == "abc123" assert result["download_id"] == "abc123"
async def test_download_example_images_use_case_triggers_manager() -> None:
manager = StubExampleImagesDownloadManager()
use_case = DownloadExampleImagesUseCase(download_manager=manager)
payload = {"optimize": True}
result = await use_case.execute(payload)
assert manager.payloads == [payload]
assert result == {"success": True, "message": "ok"}
async def test_download_example_images_use_case_maps_in_progress() -> None:
manager = StubExampleImagesDownloadManager()
manager.error = "in_progress"
use_case = DownloadExampleImagesUseCase(download_manager=manager)
with pytest.raises(DownloadExampleImagesInProgressError) as exc:
await use_case.execute({})
assert exc.value.progress == manager.progress_snapshot
async def test_download_example_images_use_case_maps_configuration() -> None:
manager = StubExampleImagesDownloadManager()
manager.error = "configuration"
use_case = DownloadExampleImagesUseCase(download_manager=manager)
with pytest.raises(DownloadExampleImagesConfigurationError):
await use_case.execute({})
async def test_download_example_images_use_case_propagates_generic_error() -> None:
manager = StubExampleImagesDownloadManager()
manager.error = "generic"
use_case = DownloadExampleImagesUseCase(download_manager=manager)
with pytest.raises(ExampleImagesDownloadError):
await use_case.execute({})
class DummyJsonRequest:
def __init__(self, payload: Dict[str, Any]) -> None:
self._payload = payload
self.content_type = "application/json"
async def json(self) -> Dict[str, Any]:
return self._payload
async def test_import_example_images_use_case_delegates() -> None:
processor = StubExampleImagesProcessor()
use_case = ImportExampleImagesUseCase(processor=processor)
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
result = await use_case.execute(request)
assert processor.calls == [{"model_hash": "abc", "files": ["/tmp/file"]}]
assert result == {"success": True}
async def test_import_example_images_use_case_maps_validation_error() -> None:
processor = StubExampleImagesProcessor()
processor.error = "validation"
use_case = ImportExampleImagesUseCase(processor=processor)
request = DummyJsonRequest({"model_hash": None, "file_paths": []})
with pytest.raises(ImportExampleImagesValidationError):
await use_case.execute(request)
async def test_import_example_images_use_case_propagates_generic_error() -> None:
processor = StubExampleImagesProcessor()
processor.error = "generic"
use_case = ImportExampleImagesUseCase(processor=processor)
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
with pytest.raises(ExampleImagesImportError):
await use_case.execute(request)