feat(example-images): add use case orchestration

This commit is contained in:
pixelpaws
2025-09-23 11:47:12 +08:00
parent bd10280736
commit aaad270822
10 changed files with 582 additions and 262 deletions

View File

@@ -12,6 +12,10 @@ from .handlers.example_images_handlers import (
ExampleImagesHandlerSet,
ExampleImagesManagementHandler,
)
from ..services.use_cases.example_images import (
DownloadExampleImagesUseCase,
ImportExampleImagesUseCase,
)
from ..utils.example_images_download_manager import DownloadManager
from ..utils.example_images_file_manager import ExampleImagesFileManager
from ..utils.example_images_processor import ExampleImagesProcessor
@@ -59,8 +63,10 @@ class ExampleImagesRoutes:
def _build_handler_set(self) -> ExampleImagesHandlerSet:
logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager)
download_handler = ExampleImagesDownloadHandler(self._download_manager)
management_handler = ExampleImagesManagementHandler(self._processor)
download_use_case = DownloadExampleImagesUseCase(download_manager=self._download_manager)
download_handler = ExampleImagesDownloadHandler(download_use_case, self._download_manager)
import_use_case = ImportExampleImagesUseCase(processor=self._processor)
management_handler = ExampleImagesManagementHandler(import_use_case, self._processor)
file_handler = ExampleImagesFileHandler(self._file_manager)
return ExampleImagesHandlerSet(
download=download_handler,

View File

@@ -6,37 +6,101 @@ from typing import Callable, Mapping
from aiohttp import web
from ...services.use_cases.example_images import (
DownloadExampleImagesConfigurationError,
DownloadExampleImagesInProgressError,
DownloadExampleImagesUseCase,
ImportExampleImagesUseCase,
ImportExampleImagesValidationError,
)
from ...utils.example_images_download_manager import (
DownloadConfigurationError,
DownloadInProgressError,
DownloadNotRunningError,
ExampleImagesDownloadError,
)
from ...utils.example_images_processor import ExampleImagesImportError
class ExampleImagesDownloadHandler:
"""HTTP adapters for download-related example image endpoints."""
def __init__(self, download_manager) -> None:
def __init__(
self,
download_use_case: DownloadExampleImagesUseCase,
download_manager,
) -> None:
self._download_use_case = download_use_case
self._download_manager = download_manager
async def download_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.start_download(request)
try:
payload = await request.json()
result = await self._download_use_case.execute(payload)
return web.json_response(result)
except DownloadExampleImagesInProgressError as exc:
response = {
'success': False,
'error': str(exc),
'status': exc.progress,
}
return web.json_response(response, status=400)
except DownloadExampleImagesConfigurationError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
except ExampleImagesDownloadError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=500)
async def get_example_images_status(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.get_status(request)
result = await self._download_manager.get_status(request)
return web.json_response(result)
async def pause_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.pause_download(request)
try:
result = await self._download_manager.pause_download(request)
return web.json_response(result)
except DownloadNotRunningError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
async def resume_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.resume_download(request)
try:
result = await self._download_manager.resume_download(request)
return web.json_response(result)
except DownloadNotRunningError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
async def force_download_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._download_manager.start_force_download(request)
try:
payload = await request.json()
result = await self._download_manager.start_force_download(payload)
return web.json_response(result)
except DownloadInProgressError as exc:
response = {
'success': False,
'error': str(exc),
'status': exc.progress_snapshot,
}
return web.json_response(response, status=400)
except DownloadConfigurationError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
except ExampleImagesDownloadError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=500)
class ExampleImagesManagementHandler:
"""HTTP adapters for import/delete endpoints."""
def __init__(self, processor) -> None:
def __init__(self, import_use_case: ImportExampleImagesUseCase, processor) -> None:
self._import_use_case = import_use_case
self._processor = processor
async def import_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._processor.import_images(request)
try:
result = await self._import_use_case.execute(request)
return web.json_response(result)
except ImportExampleImagesValidationError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
except ExampleImagesImportError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=500)
async def delete_example_image(self, request: web.Request) -> web.StreamResponse:
return await self._processor.delete_custom_image(request)

View File

@@ -13,6 +13,13 @@ from .download_model_use_case import (
DownloadModelUseCase,
DownloadModelValidationError,
)
from .example_images import (
DownloadExampleImagesConfigurationError,
DownloadExampleImagesInProgressError,
DownloadExampleImagesUseCase,
ImportExampleImagesUseCase,
ImportExampleImagesValidationError,
)
__all__ = [
"AutoOrganizeInProgressError",
@@ -22,4 +29,9 @@ __all__ = [
"DownloadModelEarlyAccessError",
"DownloadModelUseCase",
"DownloadModelValidationError",
"DownloadExampleImagesConfigurationError",
"DownloadExampleImagesInProgressError",
"DownloadExampleImagesUseCase",
"ImportExampleImagesUseCase",
"ImportExampleImagesValidationError",
]

View File

@@ -0,0 +1,19 @@
"""Example image specific use case exports."""
from .download_example_images_use_case import (
DownloadExampleImagesUseCase,
DownloadExampleImagesInProgressError,
DownloadExampleImagesConfigurationError,
)
from .import_example_images_use_case import (
ImportExampleImagesUseCase,
ImportExampleImagesValidationError,
)
__all__ = [
"DownloadExampleImagesUseCase",
"DownloadExampleImagesInProgressError",
"DownloadExampleImagesConfigurationError",
"ImportExampleImagesUseCase",
"ImportExampleImagesValidationError",
]

View File

@@ -0,0 +1,42 @@
"""Use case coordinating example image downloads."""
from __future__ import annotations
from typing import Any, Dict
from ....utils.example_images_download_manager import (
DownloadConfigurationError,
DownloadInProgressError,
ExampleImagesDownloadError,
)
class DownloadExampleImagesInProgressError(RuntimeError):
"""Raised when a download is already running."""
def __init__(self, progress: Dict[str, Any]) -> None:
super().__init__("Download already in progress")
self.progress = progress
class DownloadExampleImagesConfigurationError(ValueError):
"""Raised when settings prevent downloads from starting."""
class DownloadExampleImagesUseCase:
"""Validate payloads and trigger the download manager."""
def __init__(self, *, download_manager) -> None:
self._download_manager = download_manager
async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""Start a download and translate manager errors."""
try:
return await self._download_manager.start_download(payload)
except DownloadInProgressError as exc:
raise DownloadExampleImagesInProgressError(exc.progress_snapshot) from exc
except DownloadConfigurationError as exc:
raise DownloadExampleImagesConfigurationError(str(exc)) from exc
except ExampleImagesDownloadError:
raise

View File

@@ -0,0 +1,86 @@
"""Use case for importing example images."""
from __future__ import annotations
import os
import tempfile
from contextlib import suppress
from typing import Any, Dict, List
from aiohttp import web
from ....utils.example_images_processor import (
ExampleImagesImportError,
ExampleImagesProcessor,
ExampleImagesValidationError,
)
class ImportExampleImagesValidationError(ValueError):
"""Raised when request validation fails."""
class ImportExampleImagesUseCase:
"""Parse upload payloads and delegate to the processor service."""
def __init__(self, *, processor: ExampleImagesProcessor) -> None:
self._processor = processor
async def execute(self, request: web.Request) -> Dict[str, Any]:
model_hash: str | None = None
files_to_import: List[str] = []
temp_files: List[str] = []
try:
if request.content_type and "multipart/form-data" in request.content_type:
reader = await request.multipart()
first_field = await reader.next()
if first_field and first_field.name == "model_hash":
model_hash = await first_field.text()
else:
# Support clients that send files first and hash later
if first_field is not None:
await self._collect_upload_file(first_field, files_to_import, temp_files)
async for field in reader:
if field.name == "model_hash" and not model_hash:
model_hash = await field.text()
elif field.name == "files":
await self._collect_upload_file(field, files_to_import, temp_files)
else:
data = await request.json()
model_hash = data.get("model_hash")
files_to_import = list(data.get("file_paths", []))
result = await self._processor.import_images(model_hash, files_to_import)
return result
except ExampleImagesValidationError as exc:
raise ImportExampleImagesValidationError(str(exc)) from exc
except ExampleImagesImportError:
raise
finally:
for path in temp_files:
with suppress(Exception):
os.remove(path)
async def _collect_upload_file(
self,
field: Any,
files_to_import: List[str],
temp_files: List[str],
) -> None:
"""Persist an uploaded file to disk and add it to the import list."""
filename = field.filename or "upload"
file_ext = os.path.splitext(filename)[1].lower()
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
temp_files.append(tmp_file.name)
while True:
chunk = await field.read_chunk()
if not chunk:
break
tmp_file.write(chunk)
files_to_import.append(tmp_file.name)

View File

@@ -3,7 +3,6 @@ import os
import asyncio
import json
import time
from aiohttp import web
from ..services.service_registry import ServiceRegistry
from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor
@@ -12,6 +11,30 @@ from ..services.websocket_manager import ws_manager # Add this import at the to
from ..services.downloader import get_downloader
from ..services.settings_manager import settings
class ExampleImagesDownloadError(RuntimeError):
"""Base error for example image download operations."""
class DownloadInProgressError(ExampleImagesDownloadError):
"""Raised when a download is already running."""
def __init__(self, progress_snapshot: dict) -> None:
super().__init__("Download already in progress")
self.progress_snapshot = progress_snapshot
class DownloadNotRunningError(ExampleImagesDownloadError):
"""Raised when pause/resume is requested without an active download."""
def __init__(self, message: str = "No download in progress") -> None:
super().__init__(message)
class DownloadConfigurationError(ExampleImagesDownloadError):
"""Raised when configuration prevents starting a download."""
logger = logging.getLogger(__name__)
# Download status tracking
@@ -31,11 +54,21 @@ download_progress = {
'failed_models': set() # Track models that failed to download after metadata refresh
}
def _serialize_progress() -> dict:
"""Return a JSON-serialisable snapshot of the current progress."""
snapshot = download_progress.copy()
snapshot['processed_models'] = list(download_progress['processed_models'])
snapshot['refreshed_models'] = list(download_progress['refreshed_models'])
snapshot['failed_models'] = list(download_progress['failed_models'])
return snapshot
class DownloadManager:
"""Manages downloading example images for models"""
@staticmethod
async def start_download(request):
async def start_download(options: dict):
"""
Start downloading example images for models
@@ -50,25 +83,14 @@ class DownloadManager:
global download_task, is_downloading, download_progress
if is_downloading:
# Create a copy for JSON serialization
response_progress = download_progress.copy()
response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
response_progress['failed_models'] = list(download_progress['failed_models'])
return web.json_response({
'success': False,
'error': 'Download already in progress',
'status': response_progress
}, status=400)
raise DownloadInProgressError(_serialize_progress())
try:
# Parse the request body
data = await request.json()
data = options or {}
auto_mode = data.get('auto_mode', False)
optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
# Get output directory from settings
output_dir = settings.get('example_images_path')
@@ -78,15 +100,11 @@ class DownloadManager:
if auto_mode:
# For auto mode, just log and return success to avoid showing error toasts
logger.debug(error_msg)
return web.json_response({
return {
'success': True,
'message': 'Example images path not configured, skipping auto download'
})
else:
return web.json_response({
'success': False,
'error': error_msg
}, status=400)
}
raise DownloadConfigurationError(error_msg)
# Create the output directory
os.makedirs(output_dir, exist_ok=True)
@@ -129,41 +147,29 @@ class DownloadManager:
)
)
# Create a copy for JSON serialization
response_progress = download_progress.copy()
response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
response_progress['failed_models'] = list(download_progress['failed_models'])
return web.json_response({
return {
'success': True,
'message': 'Download started',
'status': response_progress
})
'status': _serialize_progress()
}
except Exception as e:
logger.error(f"Failed to start example images download: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
raise ExampleImagesDownloadError(str(e)) from e
@staticmethod
async def get_status(request):
"""Get the current status of example images download"""
global download_progress
# Create a copy of the progress dict with the set converted to a list for JSON serialization
response_progress = download_progress.copy()
response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
response_progress['failed_models'] = list(download_progress['failed_models'])
return web.json_response({
response_progress = _serialize_progress()
return {
'success': True,
'is_downloading': is_downloading,
'status': response_progress
})
}
@staticmethod
async def pause_download(request):
@@ -171,17 +177,14 @@ class DownloadManager:
global download_progress
if not is_downloading:
return web.json_response({
'success': False,
'error': 'No download in progress'
}, status=400)
raise DownloadNotRunningError()
download_progress['status'] = 'paused'
return web.json_response({
return {
'success': True,
'message': 'Download paused'
})
}
@staticmethod
async def resume_download(request):
@@ -189,23 +192,19 @@ class DownloadManager:
global download_progress
if not is_downloading:
return web.json_response({
'success': False,
'error': 'No download in progress'
}, status=400)
raise DownloadNotRunningError()
if download_progress['status'] == 'paused':
download_progress['status'] = 'running'
return web.json_response({
return {
'success': True,
'message': 'Download resumed'
})
else:
return web.json_response({
'success': False,
'error': f"Download is in '{download_progress['status']}' state, cannot resume"
}, status=400)
}
raise DownloadNotRunningError(
f"Download is in '{download_progress['status']}' state, cannot resume"
)
@staticmethod
async def _download_all_example_images(output_dir, optimize, model_types, delay):
@@ -432,7 +431,7 @@ class DownloadManager:
logger.error(f"Failed to save progress file: {e}")
@staticmethod
async def start_force_download(request):
async def start_force_download(options: dict):
"""
Force download example images for specific models
@@ -447,33 +446,23 @@ class DownloadManager:
global download_task, is_downloading, download_progress
if is_downloading:
return web.json_response({
'success': False,
'error': 'Download already in progress'
}, status=400)
raise DownloadInProgressError(_serialize_progress())
try:
# Parse the request body
data = await request.json()
data = options or {}
model_hashes = data.get('model_hashes', [])
optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
if not model_hashes:
return web.json_response({
'success': False,
'error': 'Missing model_hashes parameter'
}, status=400)
raise DownloadConfigurationError('Missing model_hashes parameter')
# Get output directory from settings
output_dir = settings.get('example_images_path')
if not output_dir:
return web.json_response({
'success': False,
'error': 'Example images path not configured in settings'
}, status=400)
raise DownloadConfigurationError('Example images path not configured in settings')
# Create the output directory
os.makedirs(output_dir, exist_ok=True)
@@ -506,20 +495,17 @@ class DownloadManager:
# Set download status to not downloading
is_downloading = False
return web.json_response({
return {
'success': True,
'message': 'Force download completed',
'result': result
})
}
except Exception as e:
# Set download status to not downloading
is_downloading = False
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
raise ExampleImagesDownloadError(str(e)) from e
@staticmethod
async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay):

View File

@@ -1,7 +1,6 @@
import logging
import os
import re
import tempfile
import random
import string
from aiohttp import web
@@ -13,6 +12,14 @@ from ..utils.metadata_manager import MetadataManager
logger = logging.getLogger(__name__)
class ExampleImagesImportError(RuntimeError):
"""Base error for example image import operations."""
class ExampleImagesValidationError(ExampleImagesImportError):
"""Raised when input validation fails."""
class ExampleImagesProcessor:
"""Processes and manipulates example images"""
@@ -299,90 +306,29 @@ class ExampleImagesProcessor:
return False
@staticmethod
async def import_images(request):
"""
Import local example images
Accepts:
- multipart/form-data form with model_hash and files fields
or
- JSON request with model_hash and file_paths
Returns:
- Success status and list of imported files
"""
async def import_images(model_hash: str, files_to_import: list[str]):
"""Import local example images for a model."""
if not model_hash:
raise ExampleImagesValidationError('Missing model_hash parameter')
if not files_to_import:
raise ExampleImagesValidationError('No files provided to import')
try:
model_hash = None
files_to_import = []
temp_files_to_cleanup = []
# Check if it's a multipart form-data request (direct file upload)
if request.content_type and 'multipart/form-data' in request.content_type:
reader = await request.multipart()
# First get model_hash
field = await reader.next()
if field.name == 'model_hash':
model_hash = await field.text()
# Then process all files
while True:
field = await reader.next()
if field is None:
break
if field.name == 'files':
# Create a temporary file with appropriate suffix for type detection
file_name = field.filename
file_ext = os.path.splitext(file_name)[1].lower()
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
temp_path = tmp_file.name
temp_files_to_cleanup.append(temp_path) # Track for cleanup
# Write chunks to the temporary file
while True:
chunk = await field.read_chunk()
if not chunk:
break
tmp_file.write(chunk)
# Add to the list of files to process
files_to_import.append(temp_path)
else:
# Parse JSON request (legacy method using file paths)
data = await request.json()
model_hash = data.get('model_hash')
files_to_import = data.get('file_paths', [])
if not model_hash:
return web.json_response({
'success': False,
'error': 'Missing model_hash parameter'
}, status=400)
if not files_to_import:
return web.json_response({
'success': False,
'error': 'No files provided to import'
}, status=400)
# Get example images path
example_images_path = settings.get('example_images_path')
if not example_images_path:
return web.json_response({
'success': False,
'error': 'No example images path configured'
}, status=400)
raise ExampleImagesValidationError('No example images path configured')
# Find the model and get current metadata
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
model_data = None
scanner = None
# Check both scanners to find the model
for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]:
cache = await scan_obj.get_cached_data()
@@ -393,21 +339,20 @@ class ExampleImagesProcessor:
break
if model_data:
break
if not model_data:
return web.json_response({
'success': False,
'error': f"Model with hash {model_hash} not found in cache"
}, status=404)
raise ExampleImagesImportError(
f"Model with hash {model_hash} not found in cache"
)
# Create model folder
model_folder = os.path.join(example_images_path, model_hash)
os.makedirs(model_folder, exist_ok=True)
imported_files = []
errors = []
newly_imported_paths = []
# Process each file path
for file_path in files_to_import:
try:
@@ -415,26 +360,26 @@ class ExampleImagesProcessor:
if not os.path.isfile(file_path):
errors.append(f"File not found: {file_path}")
continue
# Check if file type is supported
file_ext = os.path.splitext(file_path)[1].lower()
if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']):
errors.append(f"Unsupported file type: {file_path}")
continue
# Generate new filename using short ID instead of UUID
short_id = ExampleImagesProcessor.generate_short_id()
new_filename = f"custom_{short_id}{file_ext}"
dest_path = os.path.join(model_folder, new_filename)
# Copy the file
import shutil
shutil.copy2(file_path, dest_path)
# Store both the dest_path and the short_id
newly_imported_paths.append((dest_path, short_id))
# Add to imported files list
imported_files.append({
'name': new_filename,
@@ -444,39 +389,31 @@ class ExampleImagesProcessor:
})
except Exception as e:
errors.append(f"Error importing {file_path}: {str(e)}")
# Update metadata with new example images
regular_images, custom_images = await MetadataUpdater.update_metadata_after_import(
model_hash,
model_hash,
model_data,
scanner,
newly_imported_paths
)
return web.json_response({
return {
'success': len(imported_files) > 0,
'message': f'Successfully imported {len(imported_files)} files' +
'message': f'Successfully imported {len(imported_files)} files' +
(f' with {len(errors)} errors' if errors else ''),
'files': imported_files,
'errors': errors,
'regular_images': regular_images,
'custom_images': custom_images,
"model_file_path": model_data.get('file_path', ''),
})
}
except ExampleImagesImportError:
raise
except Exception as e:
logger.error(f"Failed to import example images: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
finally:
# Clean up temporary files
for temp_file in temp_files_to_cleanup:
try:
os.remove(temp_file)
except Exception as e:
logger.error(f"Failed to remove temporary file {temp_file}: {e}")
raise ExampleImagesImportError(str(e)) from e
@staticmethod
async def delete_custom_image(request):