diff --git a/py/lora_manager.py b/py/lora_manager.py index 1a99d508..ed37f27d 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -166,7 +166,7 @@ class LoraManager: RecipeRoutes.setup_routes(app) UpdateRoutes.setup_routes(app) MiscRoutes.setup_routes(app) - ExampleImagesRoutes.setup_routes(app) + ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager) # Setup WebSocket routes that are shared across all model types app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) diff --git a/py/routes/example_images_routes.py b/py/routes/example_images_routes.py index d5d34218..5073410d 100644 --- a/py/routes/example_images_routes.py +++ b/py/routes/example_images_routes.py @@ -32,21 +32,24 @@ class ExampleImagesRoutes: def __init__( self, *, + ws_manager, download_manager: DownloadManager | None = None, processor=ExampleImagesProcessor, file_manager=ExampleImagesFileManager, ) -> None: - self._download_manager = download_manager or get_default_download_manager() + if ws_manager is None: + raise ValueError("ws_manager is required") + self._download_manager = download_manager or get_default_download_manager(ws_manager) self._processor = processor self._file_manager = file_manager self._handler_set: ExampleImagesHandlerSet | None = None self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None @classmethod - def setup_routes(cls, app: web.Application) -> None: + def setup_routes(cls, app: web.Application, *, ws_manager) -> None: """Register routes on the given aiohttp application using default wiring.""" - controller = cls() + controller = cls(ws_manager=ws_manager) controller.register(app) def register(self, app: web.Application) -> None: diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index e538f50a..9ddf03a4 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -5,11 +5,12 @@ import os import asyncio import json import time +from typing import Any, Dict + from ..services.service_registry import ServiceRegistry from ..utils.metadata_manager import MetadataManager from .example_images_processor import ExampleImagesProcessor from .example_images_metadata import MetadataUpdater -from ..services.websocket_manager import ws_manager # Add this import at the top from ..services.downloader import get_downloader from ..services.settings_manager import settings @@ -76,82 +77,90 @@ class _DownloadProgress(dict): class DownloadManager: """Manages downloading example images for models.""" - def __init__(self) -> None: + def __init__(self, *, ws_manager, state_lock: asyncio.Lock | None = None) -> None: self._download_task: asyncio.Task | None = None self._is_downloading = False self._progress = _DownloadProgress() + self._ws_manager = ws_manager + self._state_lock = state_lock or asyncio.Lock() async def start_download(self, options: dict): """Start downloading example images for models.""" - if self._is_downloading: - raise DownloadInProgressError(self._progress.snapshot()) + async with self._state_lock: + if self._is_downloading: + raise DownloadInProgressError(self._progress.snapshot()) - try: - data = options or {} - auto_mode = data.get('auto_mode', False) - optimize = data.get('optimize', True) - model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) + try: + data = options or {} + auto_mode = data.get('auto_mode', False) + optimize = data.get('optimize', True) + model_types = data.get('model_types', ['lora', 'checkpoint']) + delay = float(data.get('delay', 0.2)) - output_dir = settings.get('example_images_path') + output_dir = settings.get('example_images_path') - if not output_dir: - error_msg = 'Example images path not configured in settings' - if auto_mode: - logger.debug(error_msg) - return { - 'success': True, - 'message': 'Example images path not configured, skipping auto download' - } - raise DownloadConfigurationError(error_msg) + if not output_dir: + error_msg = 'Example images path not configured in settings' + if auto_mode: + logger.debug(error_msg) + return { + 'success': True, + 'message': 'Example images path not configured, skipping auto download' + } + raise DownloadConfigurationError(error_msg) - os.makedirs(output_dir, exist_ok=True) + os.makedirs(output_dir, exist_ok=True) - self._progress.reset() - self._progress['status'] = 'running' - self._progress['start_time'] = time.time() - self._progress['end_time'] = None + self._progress.reset() + self._progress['status'] = 'running' + self._progress['start_time'] = time.time() + self._progress['end_time'] = None - progress_file = os.path.join(output_dir, '.download_progress.json') - if os.path.exists(progress_file): - try: - with open(progress_file, 'r', encoding='utf-8') as f: - saved_progress = json.load(f) - self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) - self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) - logger.debug( - "Loaded previous progress, %s models already processed, %s models marked as failed", - len(self._progress['processed_models']), - len(self._progress['failed_models']), - ) - except Exception as e: - logger.error(f"Failed to load progress file: {e}") + progress_file = os.path.join(output_dir, '.download_progress.json') + if os.path.exists(progress_file): + try: + with open(progress_file, 'r', encoding='utf-8') as f: + saved_progress = json.load(f) + self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) + self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) + logger.debug( + "Loaded previous progress, %s models already processed, %s models marked as failed", + len(self._progress['processed_models']), + len(self._progress['failed_models']), + ) + except Exception as e: + logger.error(f"Failed to load progress file: {e}") + self._progress['processed_models'] = set() + self._progress['failed_models'] = set() + else: self._progress['processed_models'] = set() self._progress['failed_models'] = set() - else: - self._progress['processed_models'] = set() - self._progress['failed_models'] = set() - self._is_downloading = True - self._download_task = asyncio.create_task( - self._download_all_example_images( - output_dir, - optimize, - model_types, - delay + self._is_downloading = True + self._download_task = asyncio.create_task( + self._download_all_example_images( + output_dir, + optimize, + model_types, + delay + ) ) - ) - return { - 'success': True, - 'message': 'Download started', - 'status': self._progress.snapshot() - } + snapshot = self._progress.snapshot() + except Exception as e: + self._is_downloading = False + self._download_task = None + logger.error(f"Failed to start example images download: {e}", exc_info=True) + raise ExampleImagesDownloadError(str(e)) from e - except Exception as e: - logger.error(f"Failed to start example images download: {e}", exc_info=True) - raise ExampleImagesDownloadError(str(e)) from e + await self._broadcast_progress(status='running') + + return { + 'success': True, + 'message': 'Download started', + 'status': snapshot + } async def get_status(self, request): """Get the current status of example images download.""" @@ -165,10 +174,13 @@ class DownloadManager: async def pause_download(self, request): """Pause the example images download.""" - if not self._is_downloading: - raise DownloadNotRunningError() + async with self._state_lock: + if not self._is_downloading: + raise DownloadNotRunningError() - self._progress['status'] = 'paused' + self._progress['status'] = 'paused' + + await self._broadcast_progress(status='paused') return { 'success': True, @@ -178,20 +190,23 @@ class DownloadManager: async def resume_download(self, request): """Resume the example images download.""" - if not self._is_downloading: - raise DownloadNotRunningError() + async with self._state_lock: + if not self._is_downloading: + raise DownloadNotRunningError() - if self._progress['status'] == 'paused': - self._progress['status'] = 'running' + if self._progress['status'] == 'paused': + self._progress['status'] = 'running' + else: + raise DownloadNotRunningError( + f"Download is in '{self._progress['status']}' state, cannot resume" + ) - return { - 'success': True, - 'message': 'Download resumed' - } + await self._broadcast_progress(status='running') - raise DownloadNotRunningError( - f"Download is in '{self._progress['status']}' state, cannot resume" - ) + return { + 'success': True, + 'message': 'Download resumed' + } async def _download_all_example_images(self, output_dir, optimize, model_types, delay): """Download example images for all models.""" @@ -225,6 +240,7 @@ class DownloadManager: # Update total count self._progress['total'] = len(all_models) logger.debug(f"Found {self._progress['total']} models to process") + await self._broadcast_progress(status='running') # Process each model for i, (scanner_type, model, scanner) in enumerate(all_models): @@ -236,6 +252,7 @@ class DownloadManager: # Update progress self._progress['completed'] += 1 + await self._broadcast_progress(status='running') # Only add delay after remote download of models, and not after processing the last model if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running': @@ -244,8 +261,13 @@ class DownloadManager: # Mark as completed self._progress['status'] = 'completed' self._progress['end_time'] = time.time() - logger.debug(f"Example images download completed: {self._progress['completed']}/{self._progress['total']} models processed") - + logger.debug( + "Example images download completed: %s/%s models processed", + self._progress['completed'], + self._progress['total'], + ) + await self._broadcast_progress(status='completed') + except Exception as e: error_msg = f"Error during example images download: {str(e)}" logger.error(error_msg, exc_info=True) @@ -253,7 +275,8 @@ class DownloadManager: self._progress['last_error'] = error_msg self._progress['status'] = 'error' self._progress['end_time'] = time.time() - + await self._broadcast_progress(status='error', extra={'error': error_msg}) + finally: # Save final progress to file try: @@ -262,8 +285,9 @@ class DownloadManager: logger.error(f"Failed to save progress file: {e}") # Set download status to not downloading - self._is_downloading = False - self._download_task = None + async with self._state_lock: + self._is_downloading = False + self._download_task = None async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): """Process a single model download.""" @@ -285,6 +309,7 @@ class DownloadManager: try: # Update current model info self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" + await self._broadcast_progress(status='running') # Skip if already in failed models if model_hash in self._progress['failed_models']: @@ -414,10 +439,10 @@ class DownloadManager: async def start_force_download(self, options: dict): """Force download example images for specific models.""" - if self._is_downloading: - raise DownloadInProgressError(self._progress.snapshot()) + async with self._state_lock: + if self._is_downloading: + raise DownloadInProgressError(self._progress.snapshot()) - try: data = options or {} model_hashes = data.get('model_hashes', []) optimize = data.get('optimize', True) @@ -442,6 +467,9 @@ class DownloadManager: self._is_downloading = True + await self._broadcast_progress(status='running') + + try: result = await self._download_specific_models_example_images_sync( model_hashes, output_dir, @@ -450,7 +478,8 @@ class DownloadManager: delay ) - self._is_downloading = False + async with self._state_lock: + self._is_downloading = False return { 'success': True, @@ -459,8 +488,10 @@ class DownloadManager: } except Exception as e: - self._is_downloading = False + async with self._state_lock: + self._is_downloading = False logger.error(f"Failed during forced example images download: {e}", exc_info=True) + await self._broadcast_progress(status='error', extra={'error': str(e)}) raise ExampleImagesDownloadError(str(e)) from e async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay): @@ -495,15 +526,9 @@ class DownloadManager: # Update total count based on found models self._progress['total'] = len(models_to_process) logger.debug(f"Found {self._progress['total']} models to process") - + # Send initial progress via WebSocket - await ws_manager.broadcast({ - 'type': 'example_images_progress', - 'processed': 0, - 'total': self._progress['total'], - 'status': 'running', - 'current_model': '' - }) + await self._broadcast_progress(status='running') # Process each model success_count = 0 @@ -519,15 +544,9 @@ class DownloadManager: # Update progress self._progress['completed'] += 1 - + # Send progress update via WebSocket - await ws_manager.broadcast({ - 'type': 'example_images_progress', - 'processed': self._progress['completed'], - 'total': self._progress['total'], - 'status': 'running', - 'current_model': self._progress['current_model'] - }) + await self._broadcast_progress(status='running') # Only add delay after remote download, and not after processing the last model if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running': @@ -536,16 +555,14 @@ class DownloadManager: # Mark as completed self._progress['status'] = 'completed' self._progress['end_time'] = time.time() - logger.debug(f"Forced example images download completed: {self._progress['completed']}/{self._progress['total']} models processed") - + logger.debug( + "Forced example images download completed: %s/%s models processed", + self._progress['completed'], + self._progress['total'], + ) + # Send final progress via WebSocket - await ws_manager.broadcast({ - 'type': 'example_images_progress', - 'processed': self._progress['completed'], - 'total': self._progress['total'], - 'status': 'completed', - 'current_model': '' - }) + await self._broadcast_progress(status='completed') return { 'total': self._progress['total'], @@ -561,16 +578,9 @@ class DownloadManager: self._progress['last_error'] = error_msg self._progress['status'] = 'error' self._progress['end_time'] = time.time() - + # Send error status via WebSocket - await ws_manager.broadcast({ - 'type': 'example_images_progress', - 'processed': self._progress['completed'], - 'total': self._progress['total'], - 'status': 'error', - 'error': error_msg, - 'current_model': '' - }) + await self._broadcast_progress(status='error', extra={'error': error_msg}) raise @@ -598,6 +608,7 @@ class DownloadManager: try: # Update current model info self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" + await self._broadcast_progress(status='running') # Create model directory model_dir = os.path.join(output_dir, model_hash) @@ -714,11 +725,53 @@ class DownloadManager: except Exception as e: logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True) + async def _broadcast_progress( + self, + *, + status: str | None = None, + extra: Dict[str, Any] | None = None, + ) -> None: + payload = self._build_progress_payload(status=status, extra=extra) + try: + await self._ws_manager.broadcast(payload) + except Exception as exc: # pragma: no cover - defensive logging + logger.warning("Failed to broadcast example image progress: %s", exc) -default_download_manager = DownloadManager() + def _build_progress_payload( + self, + *, + status: str | None = None, + extra: Dict[str, Any] | None = None, + ) -> Dict[str, Any]: + payload: Dict[str, Any] = { + 'type': 'example_images_progress', + 'processed': self._progress['completed'], + 'total': self._progress['total'], + 'status': status or self._progress['status'], + 'current_model': self._progress['current_model'], + } + + if self._progress['errors']: + payload['errors'] = list(self._progress['errors']) + if self._progress['last_error']: + payload['last_error'] = self._progress['last_error'] + + if extra: + payload.update(extra) + + return payload -def get_default_download_manager() -> DownloadManager: +_default_download_manager: DownloadManager | None = None + + +def get_default_download_manager(ws_manager) -> DownloadManager: """Return the singleton download manager used by default routes.""" - return default_download_manager + global _default_download_manager + if ( + _default_download_manager is None + or getattr(_default_download_manager, "_ws_manager", None) is not ws_manager + ): + _default_download_manager = DownloadManager(ws_manager=ws_manager) + return _default_download_manager diff --git a/standalone.py b/standalone.py index a6259851..95c45ca7 100644 --- a/standalone.py +++ b/standalone.py @@ -421,7 +421,7 @@ class StandaloneLoraManager(LoraManager): RecipeRoutes.setup_routes(app) UpdateRoutes.setup_routes(app) MiscRoutes.setup_routes(app) - ExampleImagesRoutes.setup_routes(app) + ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager) # Setup WebSocket routes that are shared across all model types app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) diff --git a/tests/routes/test_example_images_routes.py b/tests/routes/test_example_images_routes.py index e921e744..9a316499 100644 --- a/tests/routes/test_example_images_routes.py +++ b/tests/routes/test_example_images_routes.py @@ -3,7 +3,7 @@ from __future__ import annotations import json from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple from aiohttp import web from aiohttp.test_utils import TestClient, TestServer @@ -88,6 +88,14 @@ class StubExampleImagesFileManager: return web.json_response({"operation": "has_images", "query": dict(request.query)}) +class StubWebSocketManager: + def __init__(self) -> None: + self.broadcast_calls: List[Dict[str, Any]] = [] + + async def broadcast(self, payload: Dict[str, Any]) -> None: + self.broadcast_calls.append(payload) + + @asynccontextmanager async def example_images_app() -> ExampleImagesHarness: """Yield an ExampleImagesRoutes app wired with stubbed collaborators.""" @@ -95,8 +103,10 @@ async def example_images_app() -> ExampleImagesHarness: download_manager = StubDownloadManager() processor = StubExampleImagesProcessor() file_manager = StubExampleImagesFileManager() + ws_manager = StubWebSocketManager() controller = ExampleImagesRoutes( + ws_manager=ws_manager, download_manager=download_manager, processor=processor, file_manager=file_manager,