from __future__ import annotations import asyncio import json import time import logging import os import re import shutil import uuid from typing import Any, Dict, Iterable, List, Set, Tuple from ..services.service_registry import ServiceRegistry from ..utils.example_images_paths import ( ExampleImagePathResolver, ensure_library_root_exists, uses_library_scoped_folders, ) from ..utils.metadata_manager import MetadataManager from .example_images_processor import ExampleImagesProcessor from .example_images_metadata import MetadataUpdater from ..services.downloader import get_downloader from ..services.settings_manager import get_settings_manager class ExampleImagesDownloadError(RuntimeError): """Base error for example image download operations.""" class DownloadInProgressError(ExampleImagesDownloadError): """Raised when a download is already running.""" def __init__(self, progress_snapshot: dict) -> None: super().__init__("Download already in progress") self.progress_snapshot = progress_snapshot class DownloadNotRunningError(ExampleImagesDownloadError): """Raised when pause/resume is requested without an active download.""" def __init__(self, message: str = "No download in progress") -> None: super().__init__(message) class DownloadConfigurationError(ExampleImagesDownloadError): """Raised when configuration prevents starting a download.""" logger = logging.getLogger(__name__) class _DownloadProgress(dict): """Mutable mapping maintaining download progress with set-aware serialisation.""" def __init__(self) -> None: super().__init__() self.reset() def reset(self) -> None: """Reset the progress dictionary to its initial state.""" self.update( total=0, completed=0, current_model="", status="idle", errors=[], last_error=None, start_time=None, end_time=None, processed_models=set(), refreshed_models=set(), failed_models=set(), reprocessed_models=set(), ) def snapshot(self) -> dict: """Return a JSON-serialisable snapshot of the current progress.""" snapshot = dict(self) 
snapshot["processed_models"] = list(self["processed_models"]) snapshot["refreshed_models"] = list(self["refreshed_models"]) snapshot["failed_models"] = list(self["failed_models"]) snapshot["reprocessed_models"] = list(self.get("reprocessed_models", set())) return snapshot def _model_directory_has_files(path: str) -> bool: """Return True when the provided directory exists and contains entries.""" if not path or not os.path.isdir(path): return False try: with os.scandir(path) as entries: for _ in entries: return True except OSError: return False return False class DownloadManager: """Manages downloading example images for models.""" def __init__(self, *, ws_manager, state_lock: asyncio.Lock | None = None) -> None: self._download_task: asyncio.Task | None = None self._is_downloading = False self._progress = _DownloadProgress() self._ws_manager = ws_manager self._state_lock = state_lock or asyncio.Lock() self._stop_requested = False def _resolve_output_dir(self, library_name: str | None = None) -> str: base_path = get_settings_manager().get("example_images_path") if not base_path: return "" return ensure_library_root_exists(library_name) async def start_download(self, options: dict): """Start downloading example images for models.""" # Step 1: Parse options (fast, non-blocking) data = options or {} auto_mode = data.get("auto_mode", False) optimize = data.get("optimize", True) model_types = data.get("model_types", ["lora", "checkpoint"]) delay = float(data.get("delay", 0.2)) force = data.get("force", False) # Step 2: Validate configuration (fast lookup) settings_manager = get_settings_manager() base_path = settings_manager.get("example_images_path") if not base_path: error_msg = "Example images path not configured in settings" if auto_mode: logger.debug(error_msg) return { "success": True, "message": "Example images path not configured, skipping auto download", } raise DownloadConfigurationError(error_msg) active_library = settings_manager.get_active_library_name() 
output_dir = self._resolve_output_dir(active_library) if not output_dir: raise DownloadConfigurationError( "Example images path not configured in settings" ) # Step 3: Load progress file (I/O operation, done outside lock) processed_models = set() failed_models = set() try: progress_file, processed_models, failed_models = await self._load_progress_file(output_dir) logger.debug( "Loaded previous progress, %s models already processed, %s models marked as failed", len(processed_models), len(failed_models), ) except Exception as e: logger.error(f"Failed to load progress file: {e}") # Continue with empty sets # Step 4: Quick state check and update (minimal lock time) async with self._state_lock: if self._is_downloading: raise DownloadInProgressError(self._progress.snapshot()) try: # Reset progress with loaded data self._progress.reset() self._progress["processed_models"] = processed_models self._progress["failed_models"] = failed_models self._stop_requested = False self._progress["status"] = "running" self._progress["start_time"] = time.time() self._progress["end_time"] = None self._is_downloading = True snapshot = self._progress.snapshot() # Create the download task without awaiting it # This ensures the HTTP response is returned immediately # while the actual processing happens in the background self._download_task = asyncio.create_task( self._download_all_example_images( output_dir, optimize, model_types, delay, active_library, force, ) ) # Add a callback to handle task completion/errors self._download_task.add_done_callback( lambda t: self._handle_download_task_done(t, output_dir) ) except ExampleImagesDownloadError: # Re-raise our own exception types without wrapping self._is_downloading = False self._download_task = None raise except Exception as e: self._is_downloading = False self._download_task = None logger.error( f"Failed to start example images download: {e}", exc_info=True ) raise ExampleImagesDownloadError(str(e)) from e # Broadcast progress in the 
background without blocking the response # This ensures the HTTP response is returned immediately asyncio.create_task(self._broadcast_progress(status="running")) return {"success": True, "message": "Download started", "status": snapshot} def _handle_download_task_done(self, task: asyncio.Task, output_dir: str) -> None: """Handle download task completion, including saving progress on error.""" try: # This will re-raise any exception from the task task.result() except Exception as e: logger.error(f"Download task failed with error: {e}", exc_info=True) # Ensure progress is saved even on failure try: self._save_progress(output_dir) except Exception as save_error: logger.error(f"Failed to save progress after task failure: {save_error}") async def get_status(self, request) -> dict: """Get the current status of example images download.""" return { "success": True, "is_downloading": self._is_downloading, "status": self._progress.snapshot(), } async def _load_progress_file(self, output_dir: str) -> tuple[str, set, set]: """Load progress file from disk. Returns (progress_file_path, processed_models, failed_models). This is a separate async method to allow running in executor to avoid blocking event loop. 
""" loop = asyncio.get_event_loop() return await loop.run_in_executor( None, self._load_progress_file_sync, output_dir ) def _load_progress_file_sync(self, output_dir: str) -> tuple[str, set, set]: """Synchronous implementation of progress file loading.""" progress_file = os.path.join(output_dir, ".download_progress.json") progress_source = progress_file # Handle legacy migration if needed if uses_library_scoped_folders(): legacy_root = get_settings_manager().get("example_images_path") or "" legacy_progress = ( os.path.join(legacy_root, ".download_progress.json") if legacy_root else "" ) if ( legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file) ): try: os.makedirs(output_dir, exist_ok=True) shutil.move(legacy_progress, progress_file) logger.info( "Migrated legacy download progress file '%s' to '%s'", legacy_progress, progress_file, ) except OSError as exc: logger.warning( "Failed to migrate download progress file from '%s' to '%s': %s", legacy_progress, progress_file, exc, ) progress_source = legacy_progress processed_models = set() failed_models = set() if os.path.exists(progress_source): try: with open(progress_source, "r", encoding="utf-8") as f: saved_progress = json.load(f) processed_models = set(saved_progress.get("processed_models", [])) failed_models = set(saved_progress.get("failed_models", [])) except Exception: # Return empty sets on error pass return progress_file, processed_models, failed_models def _load_progress_sets_sync(self, progress_file: str) -> tuple[set, set]: """Load only the processed and failed model sets from progress file. This is a lighter version for quick checks without legacy migration. Returns (processed_models, failed_models). 
""" processed_models = set() failed_models = set() if os.path.exists(progress_file): try: with open(progress_file, "r", encoding="utf-8") as f: saved_progress = json.load(f) processed_models = set(saved_progress.get("processed_models", [])) failed_models = set(saved_progress.get("failed_models", [])) except Exception: # Return empty sets on error pass return processed_models, failed_models async def check_pending_models(self, model_types: list[str]) -> dict: """Quickly check how many models need example images downloaded. This is a lightweight check that avoids the overhead of starting a full download task when no work is needed. Returns: dict with keys: - total_models: Total number of models across specified types - pending_count: Number of models needing example images - processed_count: Number of already processed models - failed_count: Number of models marked as failed - needs_download: True if there are pending models to process """ from ..services.service_registry import ServiceRegistry if self._is_downloading: return { "success": True, "is_downloading": True, "total_models": 0, "pending_count": 0, "processed_count": 0, "failed_count": 0, "needs_download": False, "message": "Download already in progress", } try: # Get scanners scanners = [] if "lora" in model_types: lora_scanner = await ServiceRegistry.get_lora_scanner() scanners.append(("lora", lora_scanner)) if "checkpoint" in model_types: checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() scanners.append(("checkpoint", checkpoint_scanner)) if "embedding" in model_types: embedding_scanner = await ServiceRegistry.get_embedding_scanner() scanners.append(("embedding", embedding_scanner)) # Load progress file to check processed models (async to avoid blocking) settings_manager = get_settings_manager() active_library = settings_manager.get_active_library_name() output_dir = self._resolve_output_dir(active_library) processed_models: set[str] = set() failed_models: set[str] = set() if output_dir: 
progress_file = os.path.join(output_dir, ".download_progress.json") loop = asyncio.get_event_loop() processed_models, failed_models = await loop.run_in_executor( None, self._load_progress_sets_sync, progress_file ) # Collect all models and count in a single pass per scanner total_models = 0 all_models_with_hash: list[tuple[str, str]] = [] # (hash, name) pairs for scanner_type, scanner in scanners: cache = await scanner.get_cached_data() if cache and cache.raw_data: for model in cache.raw_data: total_models += 1 raw_hash = model.get("sha256") if raw_hash: model_hash = raw_hash.lower() all_models_with_hash.append((model_hash, model.get("model_name", "Unknown"))) models_with_hash = len(all_models_with_hash) # Calculate pending count: check which models actually need processing # A model is pending if it has a hash, is not in processed_models, # and its folder doesn't exist or is empty pending_hashes = set() for model_hash, model_name in all_models_with_hash: if model_hash not in processed_models: # Check if model folder exists with files model_dir = ExampleImagePathResolver.get_model_folder( model_hash, active_library ) if not _model_directory_has_files(model_dir): pending_hashes.add(model_hash) pending_count = len(pending_hashes) return { "success": True, "is_downloading": False, "total_models": total_models, "pending_count": pending_count, "processed_count": len(processed_models), "failed_count": len(failed_models), "needs_download": pending_count > 0, } except Exception as e: logger.error(f"Error checking pending models: {e}", exc_info=True) return { "success": False, "error": str(e), "total_models": 0, "pending_count": 0, "processed_count": 0, "failed_count": 0, "needs_download": False, } async def pause_download(self, request): """Pause the example images download.""" async with self._state_lock: if not self._is_downloading: raise DownloadNotRunningError() self._progress["status"] = "paused" await self._broadcast_progress(status="paused") return {"success": True, 
"message": "Download paused"} async def resume_download(self, request): """Resume the example images download.""" async with self._state_lock: if not self._is_downloading: raise DownloadNotRunningError() if self._progress["status"] == "paused": self._progress["status"] = "running" else: raise DownloadNotRunningError( f"Download is in '{self._progress['status']}' state, cannot resume" ) await self._broadcast_progress(status="running") return {"success": True, "message": "Download resumed"} async def stop_download(self, request): """Stop the example images download after the current model completes.""" async with self._state_lock: if not self._is_downloading: raise DownloadNotRunningError() if self._progress["status"] in {"completed", "error", "stopped"}: raise DownloadNotRunningError() if self._progress["status"] != "stopping": self._stop_requested = True self._progress["status"] = "stopping" await self._broadcast_progress(status="stopping") return {"success": True, "message": "Download stopping"} async def _download_all_example_images( self, output_dir, optimize, model_types, delay, library_name, force: bool = False, ): """Download example images for all models.""" downloader = await get_downloader() try: # Get scanners scanners = [] if "lora" in model_types: lora_scanner = await ServiceRegistry.get_lora_scanner() scanners.append(("lora", lora_scanner)) if "checkpoint" in model_types: checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() scanners.append(("checkpoint", checkpoint_scanner)) if "embedding" in model_types: embedding_scanner = await ServiceRegistry.get_embedding_scanner() scanners.append(("embedding", embedding_scanner)) # Get all models all_models = [] for scanner_type, scanner in scanners: cache = await scanner.get_cached_data() if cache and cache.raw_data: for model in cache.raw_data: if model.get("sha256"): all_models.append((scanner_type, model, scanner)) # Update total count self._progress["total"] = len(all_models) 
logger.debug(f"Found {self._progress['total']} models to process") await self._broadcast_progress(status="running") # Process each model for i, (scanner_type, model, scanner) in enumerate(all_models): async with self._state_lock: current_status = self._progress["status"] if current_status not in {"running", "paused", "stopping"}: break # Main logic for processing model is here, but actual operations are delegated to other classes was_remote_download = await self._process_model( scanner_type, model, scanner, output_dir, optimize, downloader, library_name, force, ) # Update progress self._progress["completed"] += 1 async with self._state_lock: current_status = self._progress["status"] should_stop = self._stop_requested and current_status == "stopping" broadcast_status = ( "running" if current_status == "running" else current_status ) await self._broadcast_progress(status=broadcast_status) if should_stop: break # Only add delay after remote download of models, and not after processing the last model if ( was_remote_download and i < len(all_models) - 1 and current_status == "running" ): await asyncio.sleep(delay) async with self._state_lock: if self._stop_requested and self._progress["status"] == "stopping": self._progress["status"] = "stopped" self._progress["end_time"] = time.time() self._stop_requested = False final_status = "stopped" elif self._progress["status"] not in {"error", "stopped"}: self._progress["status"] = "completed" self._progress["end_time"] = time.time() self._stop_requested = False final_status = "completed" else: final_status = self._progress["status"] self._stop_requested = False if self._progress["end_time"] is None: self._progress["end_time"] = time.time() if final_status == "completed": logger.debug( "Example images download completed: %s/%s models processed", self._progress["completed"], self._progress["total"], ) elif final_status == "stopped": logger.debug( "Example images download stopped: %s/%s models processed", 
self._progress["completed"], self._progress["total"], ) reprocessed = self._progress.get("reprocessed_models", set()) if reprocessed: logger.info( "Detected %s models with missing or empty example image folders; reprocessing triggered for those models", len(reprocessed), ) await self._broadcast_progress(status=final_status) except Exception as e: error_msg = f"Error during example images download: {str(e)}" logger.error(error_msg, exc_info=True) self._progress["errors"].append(error_msg) self._progress["last_error"] = error_msg self._progress["status"] = "error" self._progress["end_time"] = time.time() await self._broadcast_progress(status="error", extra={"error": error_msg}) finally: # Save final progress to file try: self._save_progress(output_dir) except Exception as e: logger.error(f"Failed to save progress file: {e}") # Set download status to not downloading async with self._state_lock: self._is_downloading = False self._download_task = None self._stop_requested = False async def _process_model( self, scanner_type, model, scanner, output_dir, optimize, downloader, library_name, force: bool = False, ): """Process a single model download.""" # Check if download is paused while self._progress["status"] == "paused": await asyncio.sleep(1) # Check if download should continue if self._progress["status"] not in {"running", "stopping"}: logger.info(f"Download stopped: {self._progress['status']}") return False # Return False to indicate no remote download happened model_hash = model.get("sha256", "").lower() model_name = model.get("model_name", "Unknown") model_file_path = model.get("file_path", "") model_file_name = model.get("file_name", "") try: # Update current model info self._progress["current_model"] = f"{model_name} ({model_hash[:8]})" await self._broadcast_progress(status="running") # Skip if already in failed models (unless force mode is enabled) if not force and model_hash in self._progress["failed_models"]: logger.debug(f"Skipping known failed model: 
{model_name}") return False model_dir = ExampleImagePathResolver.get_model_folder( model_hash, library_name ) existing_files = _model_directory_has_files(model_dir) # Skip if already processed AND directory exists with files if model_hash in self._progress["processed_models"]: if existing_files: logger.debug(f"Skipping already processed model: {model_name}") return False logger.debug( "Model %s (%s) marked as processed but folder empty or missing, reprocessing triggered", model_name, model_hash, ) # Track that we are reprocessing this model for summary logging self._progress["reprocessed_models"].add(model_hash) # Remove from processed models since we need to reprocess self._progress["processed_models"].discard(model_hash) if existing_files and model_hash not in self._progress["processed_models"]: logger.debug( "Model folder already populated for %s, marking as processed without download", model_name, ) self._progress["processed_models"].add(model_hash) return False if not model_dir: logger.warning( "Unable to resolve example images folder for model %s (%s)", model_name, model_hash, ) return False # Create model directory os.makedirs(model_dir, exist_ok=True) # First check for local example images - local processing doesn't need delay local_images_processed = ( await ExampleImagesProcessor.process_local_examples( model_file_path, model_file_name, model_name, model_dir, optimize ) ) # If we processed local images, update metadata if local_images_processed: await MetadataUpdater.update_metadata_from_local_examples( model_hash, model, scanner_type, scanner, model_dir ) self._progress["processed_models"].add(model_hash) return False # Return False to indicate no remote download happened full_model = await MetadataUpdater.get_updated_model(model_hash, scanner) civitai_payload = (full_model or {}).get("civitai") if full_model else None civitai_payload = civitai_payload or {} # If no local images, try to download from remote if civitai_payload.get("images"): images = 
civitai_payload.get("images", []) ( success, is_stale, failed_images, ) = await ExampleImagesProcessor.download_model_images_with_tracking( model_hash, model_name, images, model_dir, optimize, downloader ) failed_urls: Set[str] = set(failed_images) # If metadata is stale, try to refresh it if is_stale and model_hash not in self._progress["refreshed_models"]: await MetadataUpdater.refresh_model_metadata( model_hash, model_name, scanner_type, scanner, self._progress ) # Get the updated model data updated_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) updated_civitai = ( (updated_model or {}).get("civitai") if updated_model else None ) updated_civitai = updated_civitai or {} if updated_civitai.get("images"): # Retry download with updated metadata updated_images = updated_civitai.get("images", []) ( success, _, additional_failed, ) = await ExampleImagesProcessor.download_model_images_with_tracking( model_hash, model_name, updated_images, model_dir, optimize, downloader, ) failed_urls.update(additional_failed) self._progress["refreshed_models"].add(model_hash) if failed_urls: await self._remove_failed_images_from_metadata( model_hash, model_name, model_dir, failed_urls, scanner, ) if failed_urls: self._progress["failed_models"].add(model_hash) self._progress["processed_models"].add(model_hash) logger.info( "Removed %s failed example images for %s", len(failed_urls), model_name, ) elif success: self._progress["processed_models"].add(model_hash) # Remove from failed_models if force mode enabled and model was previously failed if force and model_hash in self._progress["failed_models"]: self._progress["failed_models"].discard(model_hash) logger.info( f"Removed {model_name} from failed_models after successful force retry" ) else: self._progress["failed_models"].add(model_hash) logger.info( "Example images download failed for %s despite metadata refresh", model_name, ) return True # Return True to indicate a remote download happened else: # No civitai 
data or images available, mark as failed to avoid future attempts self._progress["failed_models"].add(model_hash) logger.debug( f"No civitai images available for model {model_name}, marking as failed" ) # Save progress periodically if ( self._progress["completed"] % 10 == 0 or self._progress["completed"] == self._progress["total"] - 1 ): self._save_progress(output_dir) return False # Default return if no conditions met except Exception as e: error_msg = f"Error processing model {model.get('model_name')} ({model_hash}): {str(e)}" logger.error(error_msg, exc_info=True) self._progress["errors"].append(error_msg) self._progress["last_error"] = error_msg # Ensure model is marked as failed so we don't try again in this run self._progress["failed_models"].add(model_hash) return False def _save_progress(self, output_dir): """Save download progress to file.""" try: progress_file = os.path.join(output_dir, ".download_progress.json") # Read existing progress file if it exists existing_data = {} if os.path.exists(progress_file): try: with open(progress_file, "r", encoding="utf-8") as f: existing_data = json.load(f) except Exception as e: logger.warning(f"Failed to read existing progress file: {e}") # Create new progress data progress_data = { "processed_models": list(self._progress["processed_models"]), "refreshed_models": list(self._progress["refreshed_models"]), "failed_models": list(self._progress["failed_models"]), "completed": self._progress["completed"], "total": self._progress["total"], "last_update": time.time(), } # Preserve existing fields (especially naming_version) for key, value in existing_data.items(): if key not in progress_data: progress_data[key] = value # Write updated progress data with open(progress_file, "w", encoding="utf-8") as f: json.dump(progress_data, f, indent=2) except Exception as e: logger.error(f"Failed to save progress file: {e}") async def start_force_download(self, options: dict): """Force download example images for specific models.""" 
async with self._state_lock: if self._is_downloading: raise DownloadInProgressError(self._progress.snapshot()) data = options or {} model_hashes = data.get("model_hashes", []) optimize = data.get("optimize", True) model_types = data.get("model_types", ["lora", "checkpoint"]) delay = float(data.get("delay", 0.2)) if not model_hashes: raise DownloadConfigurationError("Missing model_hashes parameter") settings_manager = get_settings_manager() base_path = settings_manager.get("example_images_path") if not base_path: raise DownloadConfigurationError( "Example images path not configured in settings" ) active_library = settings_manager.get_active_library_name() output_dir = self._resolve_output_dir(active_library) if not output_dir: raise DownloadConfigurationError( "Example images path not configured in settings" ) self._progress.reset() self._stop_requested = False self._progress["total"] = len(model_hashes) self._progress["status"] = "running" self._progress["start_time"] = time.time() self._progress["end_time"] = None self._is_downloading = True await self._broadcast_progress(status="running") try: result = await self._download_specific_models_example_images_sync( model_hashes, output_dir, optimize, model_types, delay, active_library, ) async with self._state_lock: self._is_downloading = False final_status = self._progress["status"] message = "Force download completed" if final_status == "stopped": message = "Force download stopped" return {"success": True, "message": message, "result": result} except Exception as e: async with self._state_lock: self._is_downloading = False logger.error( f"Failed during forced example images download: {e}", exc_info=True ) await self._broadcast_progress(status="error", extra={"error": str(e)}) raise ExampleImagesDownloadError(str(e)) from e async def _download_specific_models_example_images_sync( self, model_hashes, output_dir, optimize, model_types, delay, library_name, ): """Download example images for specific models only - 
synchronous version.""" downloader = await get_downloader() try: # Get scanners scanners = [] if "lora" in model_types: lora_scanner = await ServiceRegistry.get_lora_scanner() scanners.append(("lora", lora_scanner)) if "checkpoint" in model_types: checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() scanners.append(("checkpoint", checkpoint_scanner)) if "embedding" in model_types: embedding_scanner = await ServiceRegistry.get_embedding_scanner() scanners.append(("embedding", embedding_scanner)) # Find the specified models models_to_process = [] for scanner_type, scanner in scanners: cache = await scanner.get_cached_data() if cache and cache.raw_data: for model in cache.raw_data: if model.get("sha256") in model_hashes: models_to_process.append((scanner_type, model, scanner)) # Update total count based on found models self._progress["total"] = len(models_to_process) logger.debug(f"Found {self._progress['total']} models to process") # Send initial progress via WebSocket await self._broadcast_progress(status="running") # Process each model success_count = 0 for i, (scanner_type, model, scanner) in enumerate(models_to_process): async with self._state_lock: current_status = self._progress["status"] if current_status not in {"running", "paused", "stopping"}: break # Force process this model regardless of previous status was_successful = await self._process_specific_model( scanner_type, model, scanner, output_dir, optimize, downloader, library_name, ) if was_successful: success_count += 1 # Update progress self._progress["completed"] += 1 async with self._state_lock: current_status = self._progress["status"] should_stop = self._stop_requested and current_status == "stopping" broadcast_status = ( "running" if current_status == "running" else current_status ) # Send progress update via WebSocket await self._broadcast_progress(status=broadcast_status) if should_stop: break # Only add delay after remote download, and not after processing the last model if ( 
was_successful and i < len(models_to_process) - 1 and current_status == "running" ): await asyncio.sleep(delay) async with self._state_lock: if self._stop_requested and self._progress["status"] == "stopping": self._progress["status"] = "stopped" self._progress["end_time"] = time.time() self._stop_requested = False final_status = "stopped" elif self._progress["status"] not in {"error", "stopped"}: self._progress["status"] = "completed" self._progress["end_time"] = time.time() self._stop_requested = False final_status = "completed" else: final_status = self._progress["status"] self._stop_requested = False if self._progress["end_time"] is None: self._progress["end_time"] = time.time() if final_status == "completed": logger.debug( "Forced example images download completed: %s/%s models processed", self._progress["completed"], self._progress["total"], ) elif final_status == "stopped": logger.debug( "Forced example images download stopped: %s/%s models processed", self._progress["completed"], self._progress["total"], ) # Send final progress via WebSocket await self._broadcast_progress(status=final_status) return { "total": self._progress["total"], "processed": self._progress["completed"], "successful": success_count, "errors": self._progress["errors"], } except Exception as e: error_msg = f"Error during forced example images download: {str(e)}" logger.error(error_msg, exc_info=True) self._progress["errors"].append(error_msg) self._progress["last_error"] = error_msg self._progress["status"] = "error" self._progress["end_time"] = time.time() # Send error status via WebSocket await self._broadcast_progress(status="error", extra={"error": error_msg}) raise finally: # No need to close any sessions since we use the global downloader pass async def _process_specific_model( self, scanner_type, model, scanner, output_dir, optimize, downloader, library_name, ): """Process a specific model for forced download, ignoring previous download status.""" # Check if download is paused 
        # Poll while paused; a simple sleep loop keeps pause handling trivial.
        while self._progress["status"] == "paused":
            await asyncio.sleep(1)

        # Check if download should continue
        if self._progress["status"] not in {"running", "stopping"}:
            logger.info(f"Download stopped: {self._progress['status']}")
            return False

        model_hash = model.get("sha256", "").lower()
        model_name = model.get("model_name", "Unknown")
        model_file_path = model.get("file_path", "")
        model_file_name = model.get("file_name", "")

        try:
            # Update current model info
            self._progress["current_model"] = f"{model_name} ({model_hash[:8]})"
            await self._broadcast_progress(status="running")

            model_dir = ExampleImagePathResolver.get_model_folder(
                model_hash, library_name
            )
            if not model_dir:
                logger.warning(
                    "Unable to resolve example images folder for model %s (%s)",
                    model_name,
                    model_hash,
                )
                return False
            os.makedirs(model_dir, exist_ok=True)

            # First check for local example images - local processing doesn't need delay
            local_images_processed = (
                await ExampleImagesProcessor.process_local_examples(
                    model_file_path, model_file_name, model_name, model_dir, optimize
                )
            )

            # If we processed local images, update metadata
            if local_images_processed:
                await MetadataUpdater.update_metadata_from_local_examples(
                    model_hash, model, scanner_type, scanner, model_dir
                )
                self._progress["processed_models"].add(model_hash)
                return False  # Return False to indicate no remote download happened

            full_model = await MetadataUpdater.get_updated_model(model_hash, scanner)
            civitai_payload = (full_model or {}).get("civitai") if full_model else None
            civitai_payload = civitai_payload or {}

            # If no local images, try to download from remote
            if civitai_payload.get("images"):
                images = civitai_payload.get("images", [])
                (
                    success,
                    is_stale,
                    failed_images,
                ) = await ExampleImagesProcessor.download_model_images_with_tracking(
                    model_hash, model_name, images, model_dir, optimize, downloader
                )
                failed_urls: Set[str] = set(failed_images)

                # If metadata is stale, try to refresh it
                # (only once per model per run — refreshed_models tracks this).
                if is_stale and model_hash not in self._progress["refreshed_models"]:
                    await MetadataUpdater.refresh_model_metadata(
                        model_hash, model_name, scanner_type, scanner, self._progress
                    )

                    # Get the updated model data
                    updated_model = await MetadataUpdater.get_updated_model(
                        model_hash, scanner
                    )
                    updated_civitai = (
                        (updated_model or {}).get("civitai") if updated_model else None
                    )
                    updated_civitai = updated_civitai or {}

                    if updated_civitai.get("images"):
                        # Retry download with updated metadata
                        updated_images = updated_civitai.get("images", [])
                        (
                            success,
                            _,
                            additional_failed_images,
                        ) = await ExampleImagesProcessor.download_model_images_with_tracking(
                            model_hash,
                            model_name,
                            updated_images,
                            model_dir,
                            optimize,
                            downloader,
                        )
                        # Combine failed images from both attempts
                        failed_urls.update(additional_failed_images)

                    self._progress["refreshed_models"].add(model_hash)

                # For forced downloads, remove failed images from metadata
                if failed_urls:
                    await self._remove_failed_images_from_metadata(
                        model_hash, model_name, model_dir, failed_urls, scanner
                    )

                # Mark as processed
                if (
                    success or failed_urls
                ):  # Mark as processed if we successfully downloaded some images or removed failed ones
                    self._progress["processed_models"].add(model_hash)
                return True  # Return True to indicate a remote download happened
            else:
                logger.debug(f"No civitai images available for model {model_name}")
                return False

        except Exception as e:
            error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            self._progress["errors"].append(error_msg)
            self._progress["last_error"] = error_msg
            return False  # Return False on exception

    async def _remove_failed_images_from_metadata(
        self,
        model_hash: str,
        model_name: str,
        model_dir: str,
        failed_images: Iterable[str],
        scanner,
    ) -> None:
        """Mark failed images in model metadata so they won't be retried."""
        # Drop falsy URLs up front; nothing to record when every entry is empty.
        failed_set: Set[str] = {url for url in failed_images if url}
        if not failed_set:
            return

        try:
            # Get current model data
            model_data = await MetadataUpdater.get_updated_model(model_hash, scanner)
            if not model_data:
                logger.warning(
                    f"Could not find model data for {model_name} to remove failed images"
                )
                return

            civitai_payload = model_data.get("civitai") or {}
            current_images = civitai_payload.get("images") or []
            if not current_images:
                logger.warning(f"No images in metadata for {model_name}")
                return

            updated = False
            for image in current_images:
                image_url = image.get("url")
                # failed_set may contain the optimized civitai URL variant
                # rather than the raw one, so match against both forms.
                optimized_url = (
                    ExampleImagesProcessor.get_civitai_optimized_url(image_url)
                    if image_url and "civitai.com" in image_url
                    else None
                )
                if image_url not in failed_set and optimized_url not in failed_set:
                    continue
                if image.get("downloadFailed"):
                    # Already flagged on a previous pass; skip re-marking.
                    continue
                image["downloadFailed"] = True
                image.setdefault("downloadError", "not_found")
                logger.debug(
                    "Marked example image %s for %s as failed due to missing remote asset",
                    image_url,
                    model_name,
                )
                updated = True

            if not updated:
                return

            file_path = model_data.get("file_path")
            if file_path:
                # "folder" is stripped before persisting — presumably a
                # runtime-only key; confirm against MetadataManager's schema.
                model_copy = model_data.copy()
                model_copy.pop("folder", None)
                await MetadataManager.save_metadata(file_path, model_copy)
                try:
                    await scanner.update_single_model_cache(
                        file_path, file_path, model_data
                    )
                except AttributeError:
                    logger.debug(
                        "Scanner does not expose cache update for %s", model_name
                    )
        except Exception as exc:  # pragma: no cover - defensive logging
            logger.error(
                "Error removing failed images from metadata for %s: %s",
                model_name,
                exc,
                exc_info=True,
            )

    def _renumber_example_image_files(self, model_dir: str) -> None:
        """Renumber image_<n>.<ext> files in model_dir into a gapless 0-based sequence."""
        if not model_dir or not os.path.isdir(model_dir):
            return

        pattern = re.compile(r"^image_(\d+)(\.[^.]+)$", re.IGNORECASE)
        matches: List[Tuple[int, str, str]] = []  # (index, filename, extension)
        for entry in os.listdir(model_dir):
            match = pattern.match(entry)
            if match:
                matches.append((int(match.group(1)), entry, match.group(2)))
        if not matches:
            return

        # Preserve the existing numeric ordering when renumbering.
        matches.sort(key=lambda item: item[0])

        # Two-phase rename: stage every file under a unique temp name first so
        # renumbering can never clobber a file that still holds a target name.
        staged_paths: List[Tuple[str, str]] = []
        for _, original_name, extension in matches:
            source_path = os.path.join(model_dir, original_name)
            temp_name = f"tmp_{uuid.uuid4().hex}_{original_name}"
            temp_path = os.path.join(model_dir, temp_name)
            try:
                os.rename(source_path, temp_path)
                staged_paths.append((temp_path, extension))
            except OSError as exc:
                # A file that cannot be staged is left out of the renumbered
                # sequence; the rename pass continues with the rest.
                logger.warning("Failed to stage rename for %s: %s", source_path, exc)

        # Second phase: assign the final, gapless 0-based indices.
        for new_index, (temp_path, extension) in enumerate(staged_paths):
            final_name = f"image_{new_index}{extension}"
            final_path = os.path.join(model_dir, final_name)
            try:
                os.rename(temp_path, final_path)
            except OSError as exc:
                logger.warning("Failed to finalise rename for %s: %s", final_path, exc)

    async def _broadcast_progress(
        self,
        *,
        status: str | None = None,
        extra: Dict[str, Any] | None = None,
    ) -> None:
        """Broadcast the current progress payload via the WebSocket manager.

        Broadcast failures are logged and swallowed so progress reporting can
        never abort a download run.
        """
        payload = self._build_progress_payload(status=status, extra=extra)
        try:
            await self._ws_manager.broadcast(payload)
        except Exception as exc:  # pragma: no cover - defensive logging
            logger.warning("Failed to broadcast example image progress: %s", exc)

    def _build_progress_payload(
        self,
        *,
        status: str | None = None,
        extra: Dict[str, Any] | None = None,
    ) -> Dict[str, Any]:
        """Build the JSON-serialisable progress message sent to clients.

        ``status`` overrides the stored progress status when given; ``extra``
        entries are merged last and may override any computed field.
        """
        payload: Dict[str, Any] = {
            "type": "example_images_progress",
            "processed": self._progress["completed"],
            "total": self._progress["total"],
            "status": status or self._progress["status"],
            "current_model": self._progress["current_model"],
        }
        # Only attach error fields when there is something to report.
        if self._progress["errors"]:
            payload["errors"] = list(self._progress["errors"])
        if self._progress["last_error"]:
            payload["last_error"] = self._progress["last_error"]
        if extra:
            payload.update(extra)
        return payload


# Module-level singleton shared by the default routes.
_default_download_manager: DownloadManager | None = None


def get_default_download_manager(ws_manager) -> DownloadManager:
    """Return the singleton download manager used by default routes."""
    global _default_download_manager
    # Recreate the manager when none exists yet or when the caller supplies a
    # different WebSocket manager instance than the cached one.
    if (
        _default_download_manager is None
        or getattr(_default_download_manager, "_ws_manager", None) is not ws_manager
    ):
        _default_download_manager = DownloadManager(ws_manager=ws_manager)
    return _default_download_manager