from __future__ import annotations import logging import os import asyncio import json import time from typing import Any, Dict from ..services.service_registry import ServiceRegistry from ..utils.example_images_paths import ensure_library_root_exists from ..utils.metadata_manager import MetadataManager from .example_images_processor import ExampleImagesProcessor from .example_images_metadata import MetadataUpdater from ..services.downloader import get_downloader from ..services.settings_manager import settings class ExampleImagesDownloadError(RuntimeError): """Base error for example image download operations.""" class DownloadInProgressError(ExampleImagesDownloadError): """Raised when a download is already running.""" def __init__(self, progress_snapshot: dict) -> None: super().__init__("Download already in progress") self.progress_snapshot = progress_snapshot class DownloadNotRunningError(ExampleImagesDownloadError): """Raised when pause/resume is requested without an active download.""" def __init__(self, message: str = "No download in progress") -> None: super().__init__(message) class DownloadConfigurationError(ExampleImagesDownloadError): """Raised when configuration prevents starting a download.""" logger = logging.getLogger(__name__) class _DownloadProgress(dict): """Mutable mapping maintaining download progress with set-aware serialisation.""" def __init__(self) -> None: super().__init__() self.reset() def reset(self) -> None: """Reset the progress dictionary to its initial state.""" self.update( total=0, completed=0, current_model='', status='idle', errors=[], last_error=None, start_time=None, end_time=None, processed_models=set(), refreshed_models=set(), failed_models=set(), ) def snapshot(self) -> dict: """Return a JSON-serialisable snapshot of the current progress.""" snapshot = dict(self) snapshot['processed_models'] = list(self['processed_models']) snapshot['refreshed_models'] = list(self['refreshed_models']) snapshot['failed_models'] = list(self['failed_models']) return snapshot class DownloadManager: """Manages downloading example images for models.""" def __init__(self, *, ws_manager, state_lock: asyncio.Lock | None = None) -> None: self._download_task: asyncio.Task | None = None self._is_downloading = False self._progress = _DownloadProgress() self._ws_manager = ws_manager self._state_lock = state_lock or asyncio.Lock() def _resolve_output_dir(self) -> str: base_path = settings.get('example_images_path') if not base_path: return '' library_name = settings.get_active_library_name() return ensure_library_root_exists(library_name) async def start_download(self, options: dict): """Start downloading example images for models.""" async with self._state_lock: if self._is_downloading: raise DownloadInProgressError(self._progress.snapshot()) try: data = options or {} auto_mode = data.get('auto_mode', False) optimize = data.get('optimize', True) model_types = data.get('model_types', ['lora', 'checkpoint']) delay = float(data.get('delay', 0.2)) base_path = settings.get('example_images_path') if not base_path: error_msg = 'Example images path not configured in settings' if auto_mode: logger.debug(error_msg) return { 'success': True, 'message': 'Example images path not configured, skipping auto download' } raise DownloadConfigurationError(error_msg) output_dir = self._resolve_output_dir() if not output_dir: raise DownloadConfigurationError('Example images path not configured in settings') self._progress.reset() self._progress['status'] = 'running' self._progress['start_time'] = time.time() self._progress['end_time'] = None progress_file = os.path.join(output_dir, '.download_progress.json') if os.path.exists(progress_file): try: with open(progress_file, 'r', encoding='utf-8') as f: saved_progress = json.load(f) self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) logger.debug( "Loaded previous progress, %s models already processed, %s models marked as failed", len(self._progress['processed_models']), len(self._progress['failed_models']), ) except Exception as e: logger.error(f"Failed to load progress file: {e}") self._progress['processed_models'] = set() self._progress['failed_models'] = set() else: self._progress['processed_models'] = set() self._progress['failed_models'] = set() self._is_downloading = True self._download_task = asyncio.create_task( self._download_all_example_images( output_dir, optimize, model_types, delay ) ) snapshot = self._progress.snapshot() except Exception as e: self._is_downloading = False self._download_task = None logger.error(f"Failed to start example images download: {e}", exc_info=True) raise ExampleImagesDownloadError(str(e)) from e await self._broadcast_progress(status='running') return { 'success': True, 'message': 'Download started', 'status': snapshot } async def get_status(self, request): """Get the current status of example images download.""" return { 'success': True, 'is_downloading': self._is_downloading, 'status': self._progress.snapshot(), } async def pause_download(self, request): """Pause the example images download.""" async with self._state_lock: if not self._is_downloading: raise DownloadNotRunningError() self._progress['status'] = 'paused' await self._broadcast_progress(status='paused') return { 'success': True, 'message': 'Download paused' } async def resume_download(self, request): """Resume the example images download.""" async with self._state_lock: if not self._is_downloading: raise DownloadNotRunningError() if self._progress['status'] == 'paused': self._progress['status'] = 'running' else: raise DownloadNotRunningError( f"Download is in '{self._progress['status']}' state, cannot resume" ) await self._broadcast_progress(status='running') return { 'success': True, 'message': 'Download resumed' } async def _download_all_example_images(self, output_dir, optimize, model_types, delay): """Download example images for all models.""" downloader = await get_downloader() try: # Get scanners scanners = [] if 'lora' in model_types: lora_scanner = await ServiceRegistry.get_lora_scanner() scanners.append(('lora', lora_scanner)) if 'checkpoint' in model_types: checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() scanners.append(('checkpoint', checkpoint_scanner)) if 'embedding' in model_types: embedding_scanner = await ServiceRegistry.get_embedding_scanner() scanners.append(('embedding', embedding_scanner)) # Get all models all_models = [] for scanner_type, scanner in scanners: cache = await scanner.get_cached_data() if cache and cache.raw_data: for model in cache.raw_data: if model.get('sha256'): all_models.append((scanner_type, model, scanner)) # Update total count self._progress['total'] = len(all_models) logger.debug(f"Found {self._progress['total']} models to process") await self._broadcast_progress(status='running') # Process each model for i, (scanner_type, model, scanner) in enumerate(all_models): # Main logic for processing model is here, but actual operations are delegated to other classes was_remote_download = await self._process_model( scanner_type, model, scanner, output_dir, optimize, downloader ) # Update progress self._progress['completed'] += 1 await self._broadcast_progress(status='running') # Only add delay after remote download of models, and not after processing the last model if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running': await asyncio.sleep(delay) # Mark as completed self._progress['status'] = 'completed' self._progress['end_time'] = time.time() logger.debug( "Example images download completed: %s/%s models processed", self._progress['completed'], self._progress['total'], ) await self._broadcast_progress(status='completed') except Exception as e: error_msg = f"Error during example images download: {str(e)}" logger.error(error_msg, exc_info=True) self._progress['errors'].append(error_msg) self._progress['last_error'] = error_msg self._progress['status'] = 'error' self._progress['end_time'] = time.time() await self._broadcast_progress(status='error', extra={'error': error_msg}) finally: # Save final progress to file try: self._save_progress(output_dir) except Exception as e: logger.error(f"Failed to save progress file: {e}") # Set download status to not downloading async with self._state_lock: self._is_downloading = False self._download_task = None async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): """Process a single model download.""" # Check if download is paused while self._progress['status'] == 'paused': await asyncio.sleep(1) # Check if download should continue if self._progress['status'] != 'running': logger.info(f"Download stopped: {self._progress['status']}") return False # Return False to indicate no remote download happened model_hash = model.get('sha256', '').lower() model_name = model.get('model_name', 'Unknown') model_file_path = model.get('file_path', '') model_file_name = model.get('file_name', '') try: # Update current model info self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" await self._broadcast_progress(status='running') # Skip if already in failed models if model_hash in self._progress['failed_models']: logger.debug(f"Skipping known failed model: {model_name}") return False # Skip if already processed AND directory exists with files if model_hash in self._progress['processed_models']: model_dir = os.path.join(output_dir, model_hash) has_files = os.path.exists(model_dir) and any(os.listdir(model_dir)) if has_files: logger.debug(f"Skipping already processed model: {model_name}") return False else: logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing") # Remove from processed models since we need to reprocess self._progress['processed_models'].discard(model_hash) # Create model directory model_dir = os.path.join(output_dir, model_hash) os.makedirs(model_dir, exist_ok=True) # First check for local example images - local processing doesn't need delay local_images_processed = await ExampleImagesProcessor.process_local_examples( model_file_path, model_file_name, model_name, model_dir, optimize ) # If we processed local images, update metadata if local_images_processed: await MetadataUpdater.update_metadata_from_local_examples( model_hash, model, scanner_type, scanner, model_dir ) self._progress['processed_models'].add(model_hash) return False # Return False to indicate no remote download happened full_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) civitai_payload = (full_model or {}).get('civitai', {}) if full_model else {} # If no local images, try to download from remote if civitai_payload.get('images'): images = civitai_payload.get('images', []) success, is_stale = await ExampleImagesProcessor.download_model_images( model_hash, model_name, images, model_dir, optimize, downloader ) # If metadata is stale, try to refresh it if is_stale and model_hash not in self._progress['refreshed_models']: await MetadataUpdater.refresh_model_metadata( model_hash, model_name, scanner_type, scanner, self._progress ) # Get the updated model data updated_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) updated_civitai = (updated_model or {}).get('civitai', {}) if updated_model else {} if updated_civitai.get('images'): # Retry download with updated metadata updated_images = updated_civitai.get('images', []) success, _ = await ExampleImagesProcessor.download_model_images( model_hash, model_name, updated_images, model_dir, optimize, downloader ) self._progress['refreshed_models'].add(model_hash) # Mark as processed if successful, or as failed if unsuccessful after refresh if success: self._progress['processed_models'].add(model_hash) else: # If we refreshed metadata and still failed, mark as permanently failed if model_hash in self._progress['refreshed_models']: self._progress['failed_models'].add(model_hash) logger.info(f"Marking model {model_name} as failed after metadata refresh") return True # Return True to indicate a remote download happened else: # No civitai data or images available, mark as failed to avoid future attempts self._progress['failed_models'].add(model_hash) logger.debug(f"No civitai images available for model {model_name}, marking as failed") # Save progress periodically if self._progress['completed'] % 10 == 0 or self._progress['completed'] == self._progress['total'] - 1: self._save_progress(output_dir) return False # Default return if no conditions met except Exception as e: error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" logger.error(error_msg, exc_info=True) self._progress['errors'].append(error_msg) self._progress['last_error'] = error_msg return False # Return False on exception def _save_progress(self, output_dir): """Save download progress to file.""" try: progress_file = os.path.join(output_dir, '.download_progress.json') # Read existing progress file if it exists existing_data = {} if os.path.exists(progress_file): try: with open(progress_file, 'r', encoding='utf-8') as f: existing_data = json.load(f) except Exception as e: logger.warning(f"Failed to read existing progress file: {e}") # Create new progress data progress_data = { 'processed_models': list(self._progress['processed_models']), 'refreshed_models': list(self._progress['refreshed_models']), 'failed_models': list(self._progress['failed_models']), 'completed': self._progress['completed'], 'total': self._progress['total'], 'last_update': time.time() } # Preserve existing fields (especially naming_version) for key, value in existing_data.items(): if key not in progress_data: progress_data[key] = value # Write updated progress data with open(progress_file, 'w', encoding='utf-8') as f: json.dump(progress_data, f, indent=2) except Exception as e: logger.error(f"Failed to save progress file: {e}") async def start_force_download(self, options: dict): """Force download example images for specific models.""" async with self._state_lock: if self._is_downloading: raise DownloadInProgressError(self._progress.snapshot()) data = options or {} model_hashes = data.get('model_hashes', []) optimize = data.get('optimize', True) model_types = data.get('model_types', ['lora', 'checkpoint']) delay = float(data.get('delay', 0.2)) if not model_hashes: raise DownloadConfigurationError('Missing model_hashes parameter') base_path = settings.get('example_images_path') if not base_path: raise DownloadConfigurationError('Example images path not configured in settings') output_dir = self._resolve_output_dir() if not output_dir: raise DownloadConfigurationError('Example images path not configured in settings') self._progress.reset() self._progress['total'] = len(model_hashes) self._progress['status'] = 'running' self._progress['start_time'] = time.time() self._progress['end_time'] = None self._is_downloading = True await self._broadcast_progress(status='running') try: result = await self._download_specific_models_example_images_sync( model_hashes, output_dir, optimize, model_types, delay ) async with self._state_lock: self._is_downloading = False return { 'success': True, 'message': 'Force download completed', 'result': result } except Exception as e: async with self._state_lock: self._is_downloading = False logger.error(f"Failed during forced example images download: {e}", exc_info=True) await self._broadcast_progress(status='error', extra={'error': str(e)}) raise ExampleImagesDownloadError(str(e)) from e async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay): """Download example images for specific models only - synchronous version.""" downloader = await get_downloader() try: # Get scanners scanners = [] if 'lora' in model_types: lora_scanner = await ServiceRegistry.get_lora_scanner() scanners.append(('lora', lora_scanner)) if 'checkpoint' in model_types: checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() scanners.append(('checkpoint', checkpoint_scanner)) if 'embedding' in model_types: embedding_scanner = await ServiceRegistry.get_embedding_scanner() scanners.append(('embedding', embedding_scanner)) # Find the specified models models_to_process = [] for scanner_type, scanner in scanners: cache = await scanner.get_cached_data() if cache and cache.raw_data: for model in cache.raw_data: if model.get('sha256') in model_hashes: models_to_process.append((scanner_type, model, scanner)) # Update total count based on found models self._progress['total'] = len(models_to_process) logger.debug(f"Found {self._progress['total']} models to process") # Send initial progress via WebSocket await self._broadcast_progress(status='running') # Process each model success_count = 0 for i, (scanner_type, model, scanner) in enumerate(models_to_process): # Force process this model regardless of previous status was_successful = await self._process_specific_model( scanner_type, model, scanner, output_dir, optimize, downloader ) if was_successful: success_count += 1 # Update progress self._progress['completed'] += 1 # Send progress update via WebSocket await self._broadcast_progress(status='running') # Only add delay after remote download, and not after processing the last model if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running': await asyncio.sleep(delay) # Mark as completed self._progress['status'] = 'completed' self._progress['end_time'] = time.time() logger.debug( "Forced example images download completed: %s/%s models processed", self._progress['completed'], self._progress['total'], ) # Send final progress via WebSocket await self._broadcast_progress(status='completed') return { 'total': self._progress['total'], 'processed': self._progress['completed'], 'successful': success_count, 'errors': self._progress['errors'] } except Exception as e: error_msg = f"Error during forced example images download: {str(e)}" logger.error(error_msg, exc_info=True) self._progress['errors'].append(error_msg) self._progress['last_error'] = error_msg self._progress['status'] = 'error' self._progress['end_time'] = time.time() # Send error status via WebSocket await self._broadcast_progress(status='error', extra={'error': error_msg}) raise finally: # No need to close any sessions since we use the global downloader pass async def _process_specific_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): """Process a specific model for forced download, ignoring previous download status.""" # Check if download is paused while self._progress['status'] == 'paused': await asyncio.sleep(1) # Check if download should continue if self._progress['status'] != 'running': logger.info(f"Download stopped: {self._progress['status']}") return False model_hash = model.get('sha256', '').lower() model_name = model.get('model_name', 'Unknown') model_file_path = model.get('file_path', '') model_file_name = model.get('file_name', '') try: # Update current model info self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" await self._broadcast_progress(status='running') # Create model directory model_dir = os.path.join(output_dir, model_hash) os.makedirs(model_dir, exist_ok=True) # First check for local example images - local processing doesn't need delay local_images_processed = await ExampleImagesProcessor.process_local_examples( model_file_path, model_file_name, model_name, model_dir, optimize ) # If we processed local images, update metadata if local_images_processed: await MetadataUpdater.update_metadata_from_local_examples( model_hash, model, scanner_type, scanner, model_dir ) self._progress['processed_models'].add(model_hash) return False # Return False to indicate no remote download happened full_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) civitai_payload = (full_model or {}).get('civitai', {}) if full_model else {} # If no local images, try to download from remote if civitai_payload.get('images'): images = civitai_payload.get('images', []) success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( model_hash, model_name, images, model_dir, optimize, downloader ) # If metadata is stale, try to refresh it if is_stale and model_hash not in self._progress['refreshed_models']: await MetadataUpdater.refresh_model_metadata( model_hash, model_name, scanner_type, scanner, self._progress ) # Get the updated model data updated_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) updated_civitai = (updated_model or {}).get('civitai', {}) if updated_model else {} if updated_civitai.get('images'): # Retry download with updated metadata updated_images = updated_civitai.get('images', []) success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( model_hash, model_name, updated_images, model_dir, optimize, downloader ) # Combine failed images from both attempts failed_images.extend(additional_failed_images) self._progress['refreshed_models'].add(model_hash) # For forced downloads, remove failed images from metadata if failed_images: # Create a copy of images excluding failed ones await self._remove_failed_images_from_metadata( model_hash, model_name, failed_images, scanner ) # Mark as processed if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones self._progress['processed_models'].add(model_hash) return True # Return True to indicate a remote download happened else: logger.debug(f"No civitai images available for model {model_name}") return False except Exception as e: error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" logger.error(error_msg, exc_info=True) self._progress['errors'].append(error_msg) self._progress['last_error'] = error_msg return False # Return False on exception async def _remove_failed_images_from_metadata(self, model_hash, model_name, failed_images, scanner): """Remove failed images from model metadata""" try: # Get current model data model_data = await MetadataUpdater.get_updated_model(model_hash, scanner) if not model_data: logger.warning(f"Could not find model data for {model_name} to remove failed images") return if not model_data.get('civitai', {}).get('images'): logger.warning(f"No images in metadata for {model_name}") return # Get current images current_images = model_data['civitai']['images'] # Filter out failed images updated_images = [img for img in current_images if img.get('url') not in failed_images] # If images were removed, update metadata if len(updated_images) < len(current_images): removed_count = len(current_images) - len(updated_images) logger.info(f"Removing {removed_count} failed images from metadata for {model_name}") # Update the images list model_data['civitai']['images'] = updated_images # Save metadata to file file_path = model_data.get('file_path') if file_path: # Create a copy of model data without 'folder' field model_copy = model_data.copy() model_copy.pop('folder', None) # Write metadata to file await MetadataManager.save_metadata(file_path, model_copy) logger.info(f"Saved updated metadata for {model_name} after removing failed images") # Update the scanner cache await scanner.update_single_model_cache(file_path, file_path, model_data) except Exception as e: logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True) async def _broadcast_progress( self, *, status: str | None = None, extra: Dict[str, Any] | None = None, ) -> None: payload = self._build_progress_payload(status=status, extra=extra) try: await self._ws_manager.broadcast(payload) except Exception as exc: # pragma: no cover - defensive logging logger.warning("Failed to broadcast example image progress: %s", exc) def _build_progress_payload( self, *, status: str | None = None, extra: Dict[str, Any] | None = None, ) -> Dict[str, Any]: payload: Dict[str, Any] = { 'type': 'example_images_progress', 'processed': self._progress['completed'], 'total': self._progress['total'], 'status': status or self._progress['status'], 'current_model': self._progress['current_model'], } if self._progress['errors']: payload['errors'] = list(self._progress['errors']) if self._progress['last_error']: payload['last_error'] = self._progress['last_error'] if extra: payload.update(extra) return payload _default_download_manager: DownloadManager | None = None def get_default_download_manager(ws_manager) -> DownloadManager: """Return the singleton download manager used by default routes.""" global _default_download_manager if ( _default_download_manager is None or getattr(_default_download_manager, "_ws_manager", None) is not ws_manager ): _default_download_manager = DownloadManager(ws_manager=ws_manager) return _default_download_manager