From aaad270822c16b9acac418d5f0161b15d49a7d29 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 23 Sep 2025 11:47:12 +0800 Subject: [PATCH] feat(example-images): add use case orchestration --- py/routes/example_images_routes.py | 10 +- py/routes/handlers/example_images_handlers.py | 80 ++++++++- py/services/use_cases/__init__.py | 12 ++ .../use_cases/example_images/__init__.py | 19 ++ .../download_example_images_use_case.py | 42 +++++ .../import_example_images_use_case.py | 86 +++++++++ py/utils/example_images_download_manager.py | 170 ++++++++---------- py/utils/example_images_processor.py | 151 +++++----------- tests/routes/test_example_images_routes.py | 148 +++++++++------ tests/services/test_use_cases.py | 126 +++++++++++++ 10 files changed, 582 insertions(+), 262 deletions(-) create mode 100644 py/services/use_cases/example_images/__init__.py create mode 100644 py/services/use_cases/example_images/download_example_images_use_case.py create mode 100644 py/services/use_cases/example_images/import_example_images_use_case.py diff --git a/py/routes/example_images_routes.py b/py/routes/example_images_routes.py index 829760c2..44effa3b 100644 --- a/py/routes/example_images_routes.py +++ b/py/routes/example_images_routes.py @@ -12,6 +12,10 @@ from .handlers.example_images_handlers import ( ExampleImagesHandlerSet, ExampleImagesManagementHandler, ) +from ..services.use_cases.example_images import ( + DownloadExampleImagesUseCase, + ImportExampleImagesUseCase, +) from ..utils.example_images_download_manager import DownloadManager from ..utils.example_images_file_manager import ExampleImagesFileManager from ..utils.example_images_processor import ExampleImagesProcessor @@ -59,8 +63,10 @@ class ExampleImagesRoutes: def _build_handler_set(self) -> ExampleImagesHandlerSet: logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager) - download_handler = ExampleImagesDownloadHandler(self._download_manager) - management_handler = ExampleImagesManagementHandler(self._processor) + download_use_case = DownloadExampleImagesUseCase(download_manager=self._download_manager) + download_handler = ExampleImagesDownloadHandler(download_use_case, self._download_manager) + import_use_case = ImportExampleImagesUseCase(processor=self._processor) + management_handler = ExampleImagesManagementHandler(import_use_case, self._processor) file_handler = ExampleImagesFileHandler(self._file_manager) return ExampleImagesHandlerSet( download=download_handler, diff --git a/py/routes/handlers/example_images_handlers.py b/py/routes/handlers/example_images_handlers.py index 3d960338..fd39de04 100644 --- a/py/routes/handlers/example_images_handlers.py +++ b/py/routes/handlers/example_images_handlers.py @@ -6,37 +6,101 @@ from typing import Callable, Mapping from aiohttp import web +from ...services.use_cases.example_images import ( + DownloadExampleImagesConfigurationError, + DownloadExampleImagesInProgressError, + DownloadExampleImagesUseCase, + ImportExampleImagesUseCase, + ImportExampleImagesValidationError, +) +from ...utils.example_images_download_manager import ( + DownloadConfigurationError, + DownloadInProgressError, + DownloadNotRunningError, + ExampleImagesDownloadError, +) +from ...utils.example_images_processor import ExampleImagesImportError + class ExampleImagesDownloadHandler: """HTTP adapters for download-related example image endpoints.""" - def __init__(self, download_manager) -> None: + def __init__( + self, + download_use_case: DownloadExampleImagesUseCase, + download_manager, + ) -> None: + self._download_use_case = download_use_case self._download_manager = download_manager async def download_example_images(self, request: web.Request) -> web.StreamResponse: - return await self._download_manager.start_download(request) + try: + payload = await request.json() + result = await self._download_use_case.execute(payload) + return web.json_response(result) + except DownloadExampleImagesInProgressError as exc: + response = { + 'success': False, + 'error': str(exc), + 'status': exc.progress, + } + return web.json_response(response, status=400) + except DownloadExampleImagesConfigurationError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) + except ExampleImagesDownloadError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=500) async def get_example_images_status(self, request: web.Request) -> web.StreamResponse: - return await self._download_manager.get_status(request) + result = await self._download_manager.get_status(request) + return web.json_response(result) async def pause_example_images(self, request: web.Request) -> web.StreamResponse: - return await self._download_manager.pause_download(request) + try: + result = await self._download_manager.pause_download(request) + return web.json_response(result) + except DownloadNotRunningError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) async def resume_example_images(self, request: web.Request) -> web.StreamResponse: - return await self._download_manager.resume_download(request) + try: + result = await self._download_manager.resume_download(request) + return web.json_response(result) + except DownloadNotRunningError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) async def force_download_example_images(self, request: web.Request) -> web.StreamResponse: - return await self._download_manager.start_force_download(request) + try: + payload = await request.json() + result = await self._download_manager.start_force_download(payload) + return web.json_response(result) + except DownloadInProgressError as exc: + response = { + 'success': False, + 'error': str(exc), + 'status': exc.progress_snapshot, + } + return web.json_response(response, status=400) + except DownloadConfigurationError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) + except ExampleImagesDownloadError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=500) class ExampleImagesManagementHandler: """HTTP adapters for import/delete endpoints.""" - def __init__(self, processor) -> None: + def __init__(self, import_use_case: ImportExampleImagesUseCase, processor) -> None: + self._import_use_case = import_use_case self._processor = processor async def import_example_images(self, request: web.Request) -> web.StreamResponse: - return await self._processor.import_images(request) + try: + result = await self._import_use_case.execute(request) + return web.json_response(result) + except ImportExampleImagesValidationError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) + except ExampleImagesImportError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=500) async def delete_example_image(self, request: web.Request) -> web.StreamResponse: return await self._processor.delete_custom_image(request) diff --git a/py/services/use_cases/__init__.py b/py/services/use_cases/__init__.py index 986f0f57..8a43318c 100644 --- a/py/services/use_cases/__init__.py +++ b/py/services/use_cases/__init__.py @@ -13,6 +13,13 @@ from .download_model_use_case import ( DownloadModelUseCase, DownloadModelValidationError, ) +from .example_images import ( + DownloadExampleImagesConfigurationError, + DownloadExampleImagesInProgressError, + DownloadExampleImagesUseCase, + ImportExampleImagesUseCase, + ImportExampleImagesValidationError, +) __all__ = [ "AutoOrganizeInProgressError", @@ -22,4 +29,9 @@ __all__ = [ "DownloadModelEarlyAccessError", "DownloadModelUseCase", "DownloadModelValidationError", + "DownloadExampleImagesConfigurationError", + "DownloadExampleImagesInProgressError", + "DownloadExampleImagesUseCase", + "ImportExampleImagesUseCase", + "ImportExampleImagesValidationError", ] diff --git a/py/services/use_cases/example_images/__init__.py b/py/services/use_cases/example_images/__init__.py new file mode 100644 index 00000000..820de618 --- /dev/null +++ b/py/services/use_cases/example_images/__init__.py @@ -0,0 +1,19 @@ +"""Example image specific use case exports.""" + +from .download_example_images_use_case import ( + DownloadExampleImagesUseCase, + DownloadExampleImagesInProgressError, + DownloadExampleImagesConfigurationError, +) +from .import_example_images_use_case import ( + ImportExampleImagesUseCase, + ImportExampleImagesValidationError, +) + +__all__ = [ + "DownloadExampleImagesUseCase", + "DownloadExampleImagesInProgressError", + "DownloadExampleImagesConfigurationError", + "ImportExampleImagesUseCase", + "ImportExampleImagesValidationError", +] diff --git a/py/services/use_cases/example_images/download_example_images_use_case.py b/py/services/use_cases/example_images/download_example_images_use_case.py new file mode 100644 index 00000000..e9a51e13 --- /dev/null +++ b/py/services/use_cases/example_images/download_example_images_use_case.py @@ -0,0 +1,42 @@ +"""Use case coordinating example image downloads.""" + +from __future__ import annotations + +from typing import Any, Dict + +from ....utils.example_images_download_manager import ( + DownloadConfigurationError, + DownloadInProgressError, + ExampleImagesDownloadError, +) + + +class DownloadExampleImagesInProgressError(RuntimeError): + """Raised when a download is already running.""" + + def __init__(self, progress: Dict[str, Any]) -> None: + super().__init__("Download already in progress") + self.progress = progress + + +class DownloadExampleImagesConfigurationError(ValueError): + """Raised when settings prevent downloads from starting.""" + + +class DownloadExampleImagesUseCase: + """Validate payloads and trigger the download manager.""" + + def __init__(self, *, download_manager) -> None: + self._download_manager = download_manager + + async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """Start a download and translate manager errors.""" + + try: + return await self._download_manager.start_download(payload) + except DownloadInProgressError as exc: + raise DownloadExampleImagesInProgressError(exc.progress_snapshot) from exc + except DownloadConfigurationError as exc: + raise DownloadExampleImagesConfigurationError(str(exc)) from exc + except ExampleImagesDownloadError: + raise diff --git a/py/services/use_cases/example_images/import_example_images_use_case.py b/py/services/use_cases/example_images/import_example_images_use_case.py new file mode 100644 index 00000000..547b2f4e --- /dev/null +++ b/py/services/use_cases/example_images/import_example_images_use_case.py @@ -0,0 +1,86 @@ +"""Use case for importing example images.""" + +from __future__ import annotations + +import os +import tempfile +from contextlib import suppress +from typing import Any, Dict, List + +from aiohttp import web + +from ....utils.example_images_processor import ( + ExampleImagesImportError, + ExampleImagesProcessor, + ExampleImagesValidationError, +) + + +class ImportExampleImagesValidationError(ValueError): + """Raised when request validation fails.""" + + +class ImportExampleImagesUseCase: + """Parse upload payloads and delegate to the processor service.""" + + def __init__(self, *, processor: ExampleImagesProcessor) -> None: + self._processor = processor + + async def execute(self, request: web.Request) -> Dict[str, Any]: + model_hash: str | None = None + files_to_import: List[str] = [] + temp_files: List[str] = [] + + try: + if request.content_type and "multipart/form-data" in request.content_type: + reader = await request.multipart() + + first_field = await reader.next() + if first_field and first_field.name == "model_hash": + model_hash = await first_field.text() + else: + # Support clients that send files first and hash later + if first_field is not None: + await self._collect_upload_file(first_field, files_to_import, temp_files) + + async for field in reader: + if field.name == "model_hash" and not model_hash: + model_hash = await field.text() + elif field.name == "files": + await self._collect_upload_file(field, files_to_import, temp_files) + else: + data = await request.json() + model_hash = data.get("model_hash") + files_to_import = list(data.get("file_paths", [])) + + result = await self._processor.import_images(model_hash, files_to_import) + return result + except ExampleImagesValidationError as exc: + raise ImportExampleImagesValidationError(str(exc)) from exc + except ExampleImagesImportError: + raise + finally: + for path in temp_files: + with suppress(Exception): + os.remove(path) + + async def _collect_upload_file( + self, + field: Any, + files_to_import: List[str], + temp_files: List[str], + ) -> None: + """Persist an uploaded file to disk and add it to the import list.""" + + filename = field.filename or "upload" + file_ext = os.path.splitext(filename)[1].lower() + + with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file: + temp_files.append(tmp_file.name) + while True: + chunk = await field.read_chunk() + if not chunk: + break + tmp_file.write(chunk) + + files_to_import.append(tmp_file.name) diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index 842192f2..7df0c6fb 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -3,7 +3,6 @@ import os import asyncio import json import time -from aiohttp import web from ..services.service_registry import ServiceRegistry from ..utils.metadata_manager import MetadataManager from .example_images_processor import ExampleImagesProcessor @@ -12,6 +11,30 @@ from ..services.websocket_manager import ws_manager # Add this import at the to from ..services.downloader import get_downloader from ..services.settings_manager import settings + +class ExampleImagesDownloadError(RuntimeError): + """Base error for example image download operations.""" + + +class DownloadInProgressError(ExampleImagesDownloadError): + """Raised when a download is already running.""" + + def __init__(self, progress_snapshot: dict) -> None: + super().__init__("Download already in progress") + self.progress_snapshot = progress_snapshot + + +class DownloadNotRunningError(ExampleImagesDownloadError): + """Raised when pause/resume is requested without an active download.""" + + def __init__(self, message: str = "No download in progress") -> None: + super().__init__(message) + + +class DownloadConfigurationError(ExampleImagesDownloadError): + """Raised when configuration prevents starting a download.""" + + logger = logging.getLogger(__name__) # Download status tracking @@ -31,11 +54,21 @@ download_progress = { 'failed_models': set() # Track models that failed to download after metadata refresh } + +def _serialize_progress() -> dict: + """Return a JSON-serialisable snapshot of the current progress.""" + + snapshot = download_progress.copy() + snapshot['processed_models'] = list(download_progress['processed_models']) + snapshot['refreshed_models'] = list(download_progress['refreshed_models']) + snapshot['failed_models'] = list(download_progress['failed_models']) + return snapshot + class DownloadManager: """Manages downloading example images for models""" @staticmethod - async def start_download(request): + async def start_download(options: dict): """ Start downloading example images for models @@ -50,25 +83,14 @@ class DownloadManager: global download_task, is_downloading, download_progress if is_downloading: - # Create a copy for JSON serialization - response_progress = download_progress.copy() - response_progress['processed_models'] = list(download_progress['processed_models']) - response_progress['refreshed_models'] = list(download_progress['refreshed_models']) - response_progress['failed_models'] = list(download_progress['failed_models']) - - return web.json_response({ - 'success': False, - 'error': 'Download already in progress', - 'status': response_progress - }, status=400) - + raise DownloadInProgressError(_serialize_progress()) + try: - # Parse the request body - data = await request.json() + data = options or {} auto_mode = data.get('auto_mode', False) optimize = data.get('optimize', True) model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds + delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds # Get output directory from settings output_dir = settings.get('example_images_path') @@ -78,15 +100,11 @@ class DownloadManager: if auto_mode: # For auto mode, just log and return success to avoid showing error toasts logger.debug(error_msg) - return web.json_response({ + return { 'success': True, 'message': 'Example images path not configured, skipping auto download' - }) - else: - return web.json_response({ - 'success': False, - 'error': error_msg - }, status=400) + } + raise DownloadConfigurationError(error_msg) # Create the output directory os.makedirs(output_dir, exist_ok=True) @@ -129,41 +147,29 @@ class DownloadManager: ) ) - # Create a copy for JSON serialization - response_progress = download_progress.copy() - response_progress['processed_models'] = list(download_progress['processed_models']) - response_progress['refreshed_models'] = list(download_progress['refreshed_models']) - response_progress['failed_models'] = list(download_progress['failed_models']) - - return web.json_response({ + return { 'success': True, 'message': 'Download started', - 'status': response_progress - }) - + 'status': _serialize_progress() + } + except Exception as e: logger.error(f"Failed to start example images download: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) + raise ExampleImagesDownloadError(str(e)) from e @staticmethod async def get_status(request): """Get the current status of example images download""" global download_progress - + # Create a copy of the progress dict with the set converted to a list for JSON serialization - response_progress = download_progress.copy() - response_progress['processed_models'] = list(download_progress['processed_models']) - response_progress['refreshed_models'] = list(download_progress['refreshed_models']) - response_progress['failed_models'] = list(download_progress['failed_models']) - - return web.json_response({ + response_progress = _serialize_progress() + + return { 'success': True, 'is_downloading': is_downloading, 'status': response_progress - }) + } @staticmethod async def pause_download(request): @@ -171,17 +177,14 @@ class DownloadManager: global download_progress if not is_downloading: - return web.json_response({ - 'success': False, - 'error': 'No download in progress' - }, status=400) - + raise DownloadNotRunningError() + download_progress['status'] = 'paused' - - return web.json_response({ + + return { 'success': True, 'message': 'Download paused' - }) + } @staticmethod async def resume_download(request): @@ -189,23 +192,19 @@ class DownloadManager: global download_progress if not is_downloading: - return web.json_response({ - 'success': False, - 'error': 'No download in progress' - }, status=400) - + raise DownloadNotRunningError() + if download_progress['status'] == 'paused': download_progress['status'] = 'running' - - return web.json_response({ + + return { 'success': True, 'message': 'Download resumed' - }) - else: - return web.json_response({ - 'success': False, - 'error': f"Download is in '{download_progress['status']}' state, cannot resume" - }, status=400) + } + + raise DownloadNotRunningError( + f"Download is in '{download_progress['status']}' state, cannot resume" + ) @staticmethod async def _download_all_example_images(output_dir, optimize, model_types, delay): @@ -432,7 +431,7 @@ class DownloadManager: logger.error(f"Failed to save progress file: {e}") @staticmethod - async def start_force_download(request): + async def start_force_download(options: dict): """ Force download example images for specific models @@ -447,33 +446,23 @@ class DownloadManager: global download_task, is_downloading, download_progress if is_downloading: - return web.json_response({ - 'success': False, - 'error': 'Download already in progress' - }, status=400) + raise DownloadInProgressError(_serialize_progress()) try: - # Parse the request body - data = await request.json() + data = options or {} model_hashes = data.get('model_hashes', []) optimize = data.get('optimize', True) model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds - + delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds + if not model_hashes: - return web.json_response({ - 'success': False, - 'error': 'Missing model_hashes parameter' - }, status=400) - + raise DownloadConfigurationError('Missing model_hashes parameter') + # Get output directory from settings output_dir = settings.get('example_images_path') - + if not output_dir: - return web.json_response({ - 'success': False, - 'error': 'Example images path not configured in settings' - }, status=400) + raise DownloadConfigurationError('Example images path not configured in settings') # Create the output directory os.makedirs(output_dir, exist_ok=True) @@ -506,20 +495,17 @@ class DownloadManager: # Set download status to not downloading is_downloading = False - return web.json_response({ + return { 'success': True, 'message': 'Force download completed', 'result': result - }) + } except Exception as e: # Set download status to not downloading is_downloading = False logger.error(f"Failed during forced example images download: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) + raise ExampleImagesDownloadError(str(e)) from e @staticmethod async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay): diff --git a/py/utils/example_images_processor.py b/py/utils/example_images_processor.py index f1cfd2bf..7f108ef9 100644 --- a/py/utils/example_images_processor.py +++ b/py/utils/example_images_processor.py @@ -1,7 +1,6 @@ import logging import os import re -import tempfile import random import string from aiohttp import web @@ -13,6 +12,14 @@ from ..utils.metadata_manager import MetadataManager logger = logging.getLogger(__name__) + +class ExampleImagesImportError(RuntimeError): + """Base error for example image import operations.""" + + +class ExampleImagesValidationError(ExampleImagesImportError): + """Raised when input validation fails.""" + class ExampleImagesProcessor: """Processes and manipulates example images""" @@ -299,90 +306,29 @@ class ExampleImagesProcessor: return False @staticmethod - async def import_images(request): - """ - Import local example images - - Accepts: - - multipart/form-data form with model_hash and files fields - or - - JSON request with model_hash and file_paths - - Returns: - - Success status and list of imported files - """ + async def import_images(model_hash: str, files_to_import: list[str]): + """Import local example images for a model.""" + + if not model_hash: + raise ExampleImagesValidationError('Missing model_hash parameter') + + if not files_to_import: + raise ExampleImagesValidationError('No files provided to import') + try: - model_hash = None - files_to_import = [] - temp_files_to_cleanup = [] - - # Check if it's a multipart form-data request (direct file upload) - if request.content_type and 'multipart/form-data' in request.content_type: - reader = await request.multipart() - - # First get model_hash - field = await reader.next() - if field.name == 'model_hash': - model_hash = await field.text() - - # Then process all files - while True: - field = await reader.next() - if field is None: - break - - if field.name == 'files': - # Create a temporary file with appropriate suffix for type detection - file_name = field.filename - file_ext = os.path.splitext(file_name)[1].lower() - - with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file: - temp_path = tmp_file.name - temp_files_to_cleanup.append(temp_path) # Track for cleanup - - # Write chunks to the temporary file - while True: - chunk = await field.read_chunk() - if not chunk: - break - tmp_file.write(chunk) - - # Add to the list of files to process - files_to_import.append(temp_path) - else: - # Parse JSON request (legacy method using file paths) - data = await request.json() - model_hash = data.get('model_hash') - files_to_import = data.get('file_paths', []) - - if not model_hash: - return web.json_response({ - 'success': False, - 'error': 'Missing model_hash parameter' - }, status=400) - - if not files_to_import: - return web.json_response({ - 'success': False, - 'error': 'No files provided to import' - }, status=400) - # Get example images path example_images_path = settings.get('example_images_path') if not example_images_path: - return web.json_response({ - 'success': False, - 'error': 'No example images path configured' - }, status=400) - + raise ExampleImagesValidationError('No example images path configured') + # Find the model and get current metadata lora_scanner = await ServiceRegistry.get_lora_scanner() checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() embedding_scanner = await ServiceRegistry.get_embedding_scanner() - + model_data = None scanner = None - + # Check both scanners to find the model for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]: cache = await scan_obj.get_cached_data() @@ -393,21 +339,20 @@ class ExampleImagesProcessor: break if model_data: break - + if not model_data: - return web.json_response({ - 'success': False, - 'error': f"Model with hash {model_hash} not found in cache" - }, status=404) - + raise ExampleImagesImportError( + f"Model with hash {model_hash} not found in cache" + ) + # Create model folder model_folder = os.path.join(example_images_path, model_hash) os.makedirs(model_folder, exist_ok=True) - + imported_files = [] errors = [] newly_imported_paths = [] - + # Process each file path for file_path in files_to_import: try: @@ -415,26 +360,26 @@ class ExampleImagesProcessor: if not os.path.isfile(file_path): errors.append(f"File not found: {file_path}") continue - + # Check if file type is supported file_ext = os.path.splitext(file_path)[1].lower() - if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or + if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']): errors.append(f"Unsupported file type: {file_path}") continue - + # Generate new filename using short ID instead of UUID short_id = ExampleImagesProcessor.generate_short_id() new_filename = f"custom_{short_id}{file_ext}" - + dest_path = os.path.join(model_folder, new_filename) - + # Copy the file import shutil shutil.copy2(file_path, dest_path) # Store both the dest_path and the short_id newly_imported_paths.append((dest_path, short_id)) - + # Add to imported files list imported_files.append({ 'name': new_filename, @@ -444,39 +389,31 @@ class ExampleImagesProcessor: }) except Exception as e: errors.append(f"Error importing {file_path}: {str(e)}") - + # Update metadata with new example images regular_images, custom_images = await MetadataUpdater.update_metadata_after_import( - model_hash, + model_hash, model_data, scanner, newly_imported_paths ) - - return web.json_response({ + + return { 'success': len(imported_files) > 0, - 'message': f'Successfully imported {len(imported_files)} files' + + 'message': f'Successfully imported {len(imported_files)} files' + (f' with {len(errors)} errors' if errors else ''), 'files': imported_files, 'errors': errors, 'regular_images': regular_images, 'custom_images': custom_images, "model_file_path": model_data.get('file_path', ''), - }) - + } + + except ExampleImagesImportError: + raise except Exception as e: logger.error(f"Failed to import example images: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - finally: - # Clean up temporary files - for temp_file in temp_files_to_cleanup: - try: - os.remove(temp_file) - except Exception as e: - logger.error(f"Failed to remove temporary file {temp_file}: {e}") + raise ExampleImagesImportError(str(e)) from e @staticmethod async def delete_custom_image(request): diff --git a/tests/routes/test_example_images_routes.py b/tests/routes/test_example_images_routes.py index b9806dae..e921e744 100644 --- a/tests/routes/test_example_images_routes.py +++ b/tests/routes/test_example_images_routes.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any, List, Tuple @@ -33,37 +34,35 @@ class StubDownloadManager: def __init__(self) -> None: self.calls: List[Tuple[str, Any]] = [] - async def start_download(self, request: web.Request) -> web.StreamResponse: - payload = await request.json() + async def start_download(self, payload: Any) -> dict: self.calls.append(("start_download", payload)) - return web.json_response({"operation": "start_download", "payload": payload}) + return {"operation": "start_download", "payload": payload} - async def get_status(self, request: web.Request) -> web.StreamResponse: + async def get_status(self, request: web.Request) -> dict: self.calls.append(("get_status", dict(request.query))) - return web.json_response({"operation": "get_status"}) + return {"operation": "get_status"} - async def pause_download(self, request: web.Request) -> web.StreamResponse: + async def pause_download(self, request: web.Request) -> dict: self.calls.append(("pause_download", None)) - return web.json_response({"operation": "pause_download"}) + return {"operation": "pause_download"} - async def resume_download(self, request: web.Request) -> web.StreamResponse: + async def resume_download(self, request: web.Request) -> dict: self.calls.append(("resume_download", None)) - return web.json_response({"operation": "resume_download"}) + return {"operation": "resume_download"} - async def start_force_download(self, request: web.Request) -> web.StreamResponse: - payload = await request.json() + async def start_force_download(self, payload: Any) -> dict: self.calls.append(("start_force_download", payload)) - return web.json_response({"operation": "start_force_download", "payload": payload}) + return {"operation": "start_force_download", "payload": payload} class StubExampleImagesProcessor: def __init__(self) -> None: self.calls: List[Tuple[str, Any]] = [] - async def import_images(self, request: web.Request) -> web.StreamResponse: - payload = await request.json() + async def import_images(self, model_hash: str, files: List[str]) -> dict: + payload = {"model_hash": model_hash, "file_paths": files} self.calls.append(("import_images", payload)) - return web.json_response({"operation": "import_images", "payload": payload}) + return {"operation": "import_images", "payload": payload} async def delete_custom_image(self, request: web.Request) -> web.StreamResponse: payload = await request.json() @@ -184,7 +183,7 @@ async def test_pause_and_resume_routes_delegate(): async def test_import_route_delegates_to_processor(): - payload = {"model_hash": "abc123", "files": ["/path/image.png"]} + payload = {"model_hash": "abc123", "file_paths": ["/path/image.png"]} async with example_images_app() as harness: response = await harness.client.post( "/api/lm/import-example-images", json=payload @@ -193,7 +192,8 @@ async def test_import_route_delegates_to_processor(): assert response.status == 200 assert body == {"operation": "import_images", "payload": payload} - assert harness.processor.calls == [("import_images", payload)] + expected_call = ("import_images", payload) + assert expected_call in harness.processor.calls async def test_delete_route_delegates_to_processor(): @@ -251,70 +251,91 @@ async def test_download_handler_methods_delegate() -> None: def __init__(self) -> None: self.calls: List[Tuple[str, Any]] = [] - async def start_download(self, request) -> str: - self.calls.append(("start_download", request)) - return "download" - - async def get_status(self, request) -> str: + async def get_status(self, request) -> dict: self.calls.append(("get_status", request)) - return "status" + return {"status": "ok"} - async def pause_download(self, request) -> str: + async def pause_download(self, request) -> dict: self.calls.append(("pause_download", request)) - return "pause" + return {"status": "paused"} - async def resume_download(self, request) -> str: + async def resume_download(self, request) -> dict: self.calls.append(("resume_download", request)) - return "resume" + return {"status": "running"} - async def start_force_download(self, request) -> str: - self.calls.append(("start_force_download", request)) - return "force" + async def start_force_download(self, payload) -> dict: + self.calls.append(("start_force_download", payload)) + return {"status": "force", "payload": payload} + + class StubDownloadUseCase: + def __init__(self) -> None: + self.payloads: List[Any] = [] + + async def execute(self, payload: dict) -> dict: + self.payloads.append(payload) + return {"status": "started", "payload": payload} + + class DummyRequest: + def __init__(self, payload: dict) -> None: + self._payload = payload + self.query = {} + + async def json(self) -> dict: + return self._payload recorder = Recorder() - handler = ExampleImagesDownloadHandler(recorder) - request = object() + use_case = StubDownloadUseCase() + handler = ExampleImagesDownloadHandler(use_case, recorder) + request = DummyRequest({"foo": "bar"}) - assert await handler.download_example_images(request) == "download" - assert await handler.get_example_images_status(request) == "status" - assert await handler.pause_example_images(request) == "pause" - assert await handler.resume_example_images(request) == "resume" - assert await handler.force_download_example_images(request) == "force" + download_response = await handler.download_example_images(request) + assert json.loads(download_response.text) == {"status": "started", "payload": {"foo": "bar"}} + status_response = await handler.get_example_images_status(request) + assert json.loads(status_response.text) == {"status": "ok"} + pause_response = await handler.pause_example_images(request) + assert json.loads(pause_response.text) == {"status": "paused"} + resume_response = await handler.resume_example_images(request) + assert json.loads(resume_response.text) == {"status": "running"} + force_response = await handler.force_download_example_images(request) + assert json.loads(force_response.text) == {"status": "force", "payload": {"foo": "bar"}} - expected = [ - ("start_download", request), + assert use_case.payloads == [{"foo": "bar"}] + assert recorder.calls == [ ("get_status", request), ("pause_download", request), ("resume_download", request), - ("start_force_download", request), + ("start_force_download", {"foo": "bar"}), ] - assert recorder.calls == expected @pytest.mark.asyncio async def test_management_handler_methods_delegate() -> None: + class StubImportUseCase: + def __init__(self) -> None: + self.requests: List[Any] = [] + + async def execute(self, request: Any) -> dict: + self.requests.append(request) + return {"status": "imported"} + class Recorder: def __init__(self) -> None: self.calls: List[Tuple[str, Any]] = [] - async def import_images(self, request) -> str: - self.calls.append(("import_images", request)) - return "import" - async def delete_custom_image(self, request) -> str: self.calls.append(("delete_custom_image", request)) return "delete" recorder = Recorder() - handler = ExampleImagesManagementHandler(recorder) + use_case = StubImportUseCase() + handler = ExampleImagesManagementHandler(use_case, recorder) request = object() - assert await handler.import_example_images(request) == "import" + import_response = await handler.import_example_images(request) + assert json.loads(import_response.text) == {"status": "imported"} assert await handler.delete_example_image(request) == "delete" - assert recorder.calls == [ - ("import_images", request), - ("delete_custom_image", request), - ] + assert use_case.requests == [request] + assert recorder.calls == [("delete_custom_image", request)] @pytest.mark.asyncio @@ -350,8 +371,29 @@ async def test_file_handler_methods_delegate() -> None: def test_handler_set_route_mapping_includes_all_handlers() -> None: - download = ExampleImagesDownloadHandler(object()) - management = ExampleImagesManagementHandler(object()) + class DummyUseCase: + async def execute(self, payload): + return payload + + class DummyManager: + async def get_status(self, request): + return {} + + async def pause_download(self, request): + return {} + + async def resume_download(self, request): + return {} + + async def start_force_download(self, payload): + return payload + + class DummyProcessor: + async def delete_custom_image(self, request): + return {} + + download = ExampleImagesDownloadHandler(DummyUseCase(), DummyManager()) + management = ExampleImagesManagementHandler(DummyUseCase(), DummyProcessor()) files = ExampleImagesFileHandler(object()) handler_set = ExampleImagesHandlerSet( download=download, diff --git a/tests/services/test_use_cases.py b/tests/services/test_use_cases.py index 64057fc6..cfd0f10c 100644 --- a/tests/services/test_use_cases.py +++ b/tests/services/test_use_cases.py @@ -10,9 +10,23 @@ from py_local.services.use_cases import ( AutoOrganizeInProgressError, AutoOrganizeUseCase, BulkMetadataRefreshUseCase, + DownloadExampleImagesConfigurationError, + DownloadExampleImagesInProgressError, + DownloadExampleImagesUseCase, DownloadModelEarlyAccessError, DownloadModelUseCase, DownloadModelValidationError, + ImportExampleImagesUseCase, + ImportExampleImagesValidationError, +) +from py_local.utils.example_images_download_manager import ( + DownloadConfigurationError, + DownloadInProgressError, + ExampleImagesDownloadError, +) +from py_local.utils.example_images_processor import ( + ExampleImagesImportError, + ExampleImagesValidationError, ) from tests.conftest import MockModelService, MockScanner @@ -88,6 +102,38 @@ class StubDownloadCoordinator: return {"success": True, "download_id": "abc123"} +class StubExampleImagesDownloadManager: + def __init__(self) -> None: + self.payloads: List[Dict[str, Any]] = [] + self.error: Optional[str] = None + self.progress_snapshot = {"status": "running"} + + async def start_download(self, payload: Dict[str, Any]) -> Dict[str, Any]: + self.payloads.append(payload) + if self.error == "in_progress": + raise DownloadInProgressError(self.progress_snapshot) + if self.error == "configuration": + raise DownloadConfigurationError("path missing") + if self.error == "generic": + raise ExampleImagesDownloadError("boom") + return {"success": True, "message": "ok"} + + +class StubExampleImagesProcessor: + def __init__(self) -> None: + self.calls: List[Dict[str, Any]] = [] + self.error: Optional[str] = None + self.response: Dict[str, Any] = {"success": True} + + async def import_images(self, model_hash: str, files: List[str]) -> Dict[str, Any]: + self.calls.append({"model_hash": model_hash, "files": files}) + if self.error == "validation": + raise ExampleImagesValidationError("missing") + if self.error == "generic": + raise ExampleImagesImportError("boom") + return self.response + + async def test_auto_organize_use_case_executes_with_lock() -> None: file_service = StubFileService() lock_provider = StubLockProvider() @@ -189,3 +235,83 @@ async def test_download_model_use_case_returns_result() -> None: assert result["success"] is True assert result["download_id"] == "abc123" + + +async def test_download_example_images_use_case_triggers_manager() -> None: + manager = StubExampleImagesDownloadManager() + use_case = DownloadExampleImagesUseCase(download_manager=manager) + + payload = {"optimize": True} + result = await use_case.execute(payload) + + assert manager.payloads == [payload] + assert result == {"success": True, "message": "ok"} + + +async def test_download_example_images_use_case_maps_in_progress() -> None: + manager = StubExampleImagesDownloadManager() + manager.error = "in_progress" + use_case = DownloadExampleImagesUseCase(download_manager=manager) + + with pytest.raises(DownloadExampleImagesInProgressError) as exc: + await use_case.execute({}) + + assert exc.value.progress == manager.progress_snapshot + + +async def test_download_example_images_use_case_maps_configuration() -> None: + manager = StubExampleImagesDownloadManager() + manager.error = "configuration" + use_case = DownloadExampleImagesUseCase(download_manager=manager) + + with pytest.raises(DownloadExampleImagesConfigurationError): + await use_case.execute({}) + + +async def test_download_example_images_use_case_propagates_generic_error() -> None: + manager = StubExampleImagesDownloadManager() + manager.error = "generic" + use_case = DownloadExampleImagesUseCase(download_manager=manager) + + with pytest.raises(ExampleImagesDownloadError): + await use_case.execute({}) + + +class DummyJsonRequest: + def __init__(self, payload: Dict[str, Any]) -> None: + self._payload = payload + self.content_type = "application/json" + + async def json(self) -> Dict[str, Any]: + return self._payload + + +async def test_import_example_images_use_case_delegates() -> None: + processor = StubExampleImagesProcessor() + use_case = ImportExampleImagesUseCase(processor=processor) + + request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]}) + result = await use_case.execute(request) + + assert processor.calls == [{"model_hash": "abc", "files": ["/tmp/file"]}] + assert result == {"success": True} + + +async def test_import_example_images_use_case_maps_validation_error() -> None: + processor = StubExampleImagesProcessor() + processor.error = "validation" + use_case = ImportExampleImagesUseCase(processor=processor) + request = DummyJsonRequest({"model_hash": None, "file_paths": []}) + + with pytest.raises(ImportExampleImagesValidationError): + await use_case.execute(request) + + +async def test_import_example_images_use_case_propagates_generic_error() -> None: + processor = StubExampleImagesProcessor() + processor.error = "generic" + use_case = ImportExampleImagesUseCase(processor=processor) + request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]}) + + with pytest.raises(ExampleImagesImportError): + await use_case.execute(request)