diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index f5019847..56022eba 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -62,8 +62,8 @@ class _DownloadProgress(dict): self.update( total=0, completed=0, - current_model='', - status='idle', + current_model="", + status="idle", errors=[], last_error=None, start_time=None, @@ -78,10 +78,10 @@ class _DownloadProgress(dict): """Return a JSON-serialisable snapshot of the current progress.""" snapshot = dict(self) - snapshot['processed_models'] = list(self['processed_models']) - snapshot['refreshed_models'] = list(self['refreshed_models']) - snapshot['failed_models'] = list(self['failed_models']) - snapshot['reprocessed_models'] = list(self.get('reprocessed_models', set())) + snapshot["processed_models"] = list(self["processed_models"]) + snapshot["refreshed_models"] = list(self["refreshed_models"]) + snapshot["failed_models"] = list(self["failed_models"]) + snapshot["reprocessed_models"] = list(self.get("reprocessed_models", set())) return snapshot @@ -100,6 +100,7 @@ def _model_directory_has_files(path: str) -> bool: return False + class DownloadManager: """Manages downloading example images for models.""" @@ -112,9 +113,9 @@ class DownloadManager: self._stop_requested = False def _resolve_output_dir(self, library_name: str | None = None) -> str: - base_path = get_settings_manager().get('example_images_path') + base_path = get_settings_manager().get("example_images_path") if not base_path: - return '' + return "" return ensure_library_root_exists(library_name) async def start_download(self, options: dict): @@ -126,41 +127,54 @@ class DownloadManager: try: data = options or {} - auto_mode = data.get('auto_mode', False) - optimize = data.get('optimize', True) - model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) + auto_mode = data.get("auto_mode", False) + optimize = data.get("optimize", True) + model_types = data.get("model_types", ["lora", "checkpoint"]) + delay = float(data.get("delay", 0.2)) + force = data.get("force", False) settings_manager = get_settings_manager() - base_path = settings_manager.get('example_images_path') + base_path = settings_manager.get("example_images_path") if not base_path: - error_msg = 'Example images path not configured in settings' + error_msg = "Example images path not configured in settings" if auto_mode: logger.debug(error_msg) return { - 'success': True, - 'message': 'Example images path not configured, skipping auto download' + "success": True, + "message": "Example images path not configured, skipping auto download", } raise DownloadConfigurationError(error_msg) active_library = get_settings_manager().get_active_library_name() output_dir = self._resolve_output_dir(active_library) if not output_dir: - raise DownloadConfigurationError('Example images path not configured in settings') + raise DownloadConfigurationError( + "Example images path not configured in settings" + ) self._progress.reset() self._stop_requested = False - self._progress['status'] = 'running' - self._progress['start_time'] = time.time() - self._progress['end_time'] = None + self._progress["status"] = "running" + self._progress["start_time"] = time.time() + self._progress["end_time"] = None - progress_file = os.path.join(output_dir, '.download_progress.json') + progress_file = os.path.join(output_dir, ".download_progress.json") progress_source = progress_file if uses_library_scoped_folders(): - legacy_root = get_settings_manager().get('example_images_path') or '' - legacy_progress = os.path.join(legacy_root, '.download_progress.json') if legacy_root else '' - if legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file): + legacy_root = ( + get_settings_manager().get("example_images_path") or "" + ) + legacy_progress = ( + os.path.join(legacy_root, ".download_progress.json") + if legacy_root + else "" + ) + if ( + legacy_progress + and os.path.exists(legacy_progress) + and not os.path.exists(progress_file) + ): try: os.makedirs(output_dir, exist_ok=True) shutil.move(legacy_progress, progress_file) @@ -180,22 +194,26 @@ class DownloadManager: if os.path.exists(progress_source): try: - with open(progress_source, 'r', encoding='utf-8') as f: + with open(progress_source, "r", encoding="utf-8") as f: saved_progress = json.load(f) - self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) - self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) + self._progress["processed_models"] = set( + saved_progress.get("processed_models", []) + ) + self._progress["failed_models"] = set( + saved_progress.get("failed_models", []) + ) logger.debug( "Loaded previous progress, %s models already processed, %s models marked as failed", - len(self._progress['processed_models']), - len(self._progress['failed_models']), + len(self._progress["processed_models"]), + len(self._progress["failed_models"]), ) except Exception as e: logger.error(f"Failed to load progress file: {e}") - self._progress['processed_models'] = set() - self._progress['failed_models'] = set() + self._progress["processed_models"] = set() + self._progress["failed_models"] = set() else: - self._progress['processed_models'] = set() - self._progress['failed_models'] = set() + self._progress["processed_models"] = set() + self._progress["failed_models"] = set() self._is_downloading = True self._download_task = asyncio.create_task( @@ -205,6 +223,7 @@ class DownloadManager: model_types, delay, active_library, + force, ) ) @@ -217,24 +236,22 @@ class DownloadManager: except Exception as e: self._is_downloading = False self._download_task = None - logger.error(f"Failed to start example images download: {e}", exc_info=True) + logger.error( + f"Failed to start example images download: {e}", exc_info=True + ) raise ExampleImagesDownloadError(str(e)) from e - await self._broadcast_progress(status='running') + await self._broadcast_progress(status="running") + + return {"success": True, "message": "Download started", "status": snapshot} - return { - 'success': True, - 'message': 'Download started', - 'status': snapshot - } - async def get_status(self, request): """Get the current status of example images download.""" return { - 'success': True, - 'is_downloading': self._is_downloading, - 'status': self._progress.snapshot(), + "success": True, + "is_downloading": self._is_downloading, + "status": self._progress.snapshot(), } async def pause_download(self, request): @@ -244,14 +261,11 @@ class DownloadManager: if not self._is_downloading: raise DownloadNotRunningError() - self._progress['status'] = 'paused' + self._progress["status"] = "paused" - await self._broadcast_progress(status='paused') + await self._broadcast_progress(status="paused") - return { - 'success': True, - 'message': 'Download paused' - } + return {"success": True, "message": "Download paused"} async def resume_download(self, request): """Resume the example images download.""" @@ -260,19 +274,16 @@ class DownloadManager: if not self._is_downloading: raise DownloadNotRunningError() - if self._progress['status'] == 'paused': - self._progress['status'] = 'running' + if self._progress["status"] == "paused": + self._progress["status"] = "running" else: raise DownloadNotRunningError( f"Download is in '{self._progress['status']}' state, cannot resume" ) - await self._broadcast_progress(status='running') + await self._broadcast_progress(status="running") - return { - 'success': True, - 'message': 'Download resumed' - } + return {"success": True, "message": "Download resumed"} async def stop_download(self, request): """Stop the example images download after the current model completes.""" @@ -281,20 +292,17 @@ class DownloadManager: if not self._is_downloading: raise DownloadNotRunningError() - if self._progress['status'] in {'completed', 'error', 'stopped'}: + if self._progress["status"] in {"completed", "error", "stopped"}: raise DownloadNotRunningError() - if self._progress['status'] != 'stopping': + if self._progress["status"] != "stopping": self._stop_requested = True - self._progress['status'] = 'stopping' + self._progress["status"] = "stopping" - await self._broadcast_progress(status='stopping') + await self._broadcast_progress(status="stopping") + + return {"success": True, "message": "Download stopping"} - return { - 'success': True, - 'message': 'Download stopping' - } - async def _download_all_example_images( self, output_dir, @@ -302,46 +310,47 @@ class DownloadManager: model_types, delay, library_name, + force: bool = False, ): """Download example images for all models.""" downloader = await get_downloader() - + try: # Get scanners scanners = [] - if 'lora' in model_types: + if "lora" in model_types: lora_scanner = await ServiceRegistry.get_lora_scanner() - scanners.append(('lora', lora_scanner)) - - if 'checkpoint' in model_types: - checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() - scanners.append(('checkpoint', checkpoint_scanner)) + scanners.append(("lora", lora_scanner)) - if 'embedding' in model_types: + if "checkpoint" in model_types: + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() + scanners.append(("checkpoint", checkpoint_scanner)) + + if "embedding" in model_types: embedding_scanner = await ServiceRegistry.get_embedding_scanner() - scanners.append(('embedding', embedding_scanner)) - + scanners.append(("embedding", embedding_scanner)) + # Get all models all_models = [] for scanner_type, scanner in scanners: cache = await scanner.get_cached_data() if cache and cache.raw_data: for model in cache.raw_data: - if model.get('sha256'): + if model.get("sha256"): all_models.append((scanner_type, model, scanner)) - + # Update total count - self._progress['total'] = len(all_models) + self._progress["total"] = len(all_models) logger.debug(f"Found {self._progress['total']} models to process") - await self._broadcast_progress(status='running') - + await self._broadcast_progress(status="running") + # Process each model for i, (scanner_type, model, scanner) in enumerate(all_models): async with self._state_lock: - current_status = self._progress['status'] + current_status = self._progress["status"] - if current_status not in {'running', 'paused', 'stopping'}: + if current_status not in {"running", "paused", "stopping"}: break # Main logic for processing model is here, but actual operations are delegated to other classes @@ -353,16 +362,19 @@ class DownloadManager: optimize, downloader, library_name, + force, ) # Update progress - self._progress['completed'] += 1 + self._progress["completed"] += 1 async with self._state_lock: - current_status = self._progress['status'] - should_stop = self._stop_requested and current_status == 'stopping' + current_status = self._progress["status"] + should_stop = self._stop_requested and current_status == "stopping" - broadcast_status = 'running' if current_status == 'running' else current_status + broadcast_status = ( + "running" if current_status == "running" else current_status + ) await self._broadcast_progress(status=broadcast_status) if should_stop: @@ -372,41 +384,41 @@ class DownloadManager: if ( was_remote_download and i < len(all_models) - 1 - and current_status == 'running' + and current_status == "running" ): await asyncio.sleep(delay) async with self._state_lock: - if self._stop_requested and self._progress['status'] == 'stopping': - self._progress['status'] = 'stopped' - self._progress['end_time'] = time.time() + if self._stop_requested and self._progress["status"] == "stopping": + self._progress["status"] = "stopped" + self._progress["end_time"] = time.time() self._stop_requested = False - final_status = 'stopped' - elif self._progress['status'] not in {'error', 'stopped'}: - self._progress['status'] = 'completed' - self._progress['end_time'] = time.time() + final_status = "stopped" + elif self._progress["status"] not in {"error", "stopped"}: + self._progress["status"] = "completed" + self._progress["end_time"] = time.time() self._stop_requested = False - final_status = 'completed' + final_status = "completed" else: - final_status = self._progress['status'] + final_status = self._progress["status"] self._stop_requested = False - if self._progress['end_time'] is None: - self._progress['end_time'] = time.time() + if self._progress["end_time"] is None: + self._progress["end_time"] = time.time() - if final_status == 'completed': + if final_status == "completed": logger.debug( "Example images download completed: %s/%s models processed", - self._progress['completed'], - self._progress['total'], + self._progress["completed"], + self._progress["total"], ) - elif final_status == 'stopped': + elif final_status == "stopped": logger.debug( "Example images download stopped: %s/%s models processed", - self._progress['completed'], - self._progress['total'], + self._progress["completed"], + self._progress["total"], ) - reprocessed = self._progress.get('reprocessed_models', set()) + reprocessed = self._progress.get("reprocessed_models", set()) if reprocessed: logger.info( "Detected %s models with missing or empty example image folders; reprocessing triggered for those models", @@ -418,11 +430,11 @@ class DownloadManager: except Exception as e: error_msg = f"Error during example images download: {str(e)}" logger.error(error_msg, exc_info=True) - self._progress['errors'].append(error_msg) - self._progress['last_error'] = error_msg - self._progress['status'] = 'error' - self._progress['end_time'] = time.time() - await self._broadcast_progress(status='error', extra={'error': error_msg}) + self._progress["errors"].append(error_msg) + self._progress["last_error"] = error_msg + self._progress["status"] = "error" + self._progress["end_time"] = time.time() + await self._broadcast_progress(status="error", extra={"error": error_msg}) finally: # Save final progress to file @@ -436,7 +448,7 @@ class DownloadManager: self._is_downloading = False self._download_task = None self._stop_requested = False - + async def _process_model( self, scanner_type, @@ -446,58 +458,61 @@ class DownloadManager: optimize, downloader, library_name, + force: bool = False, ): """Process a single model download.""" # Check if download is paused - while self._progress['status'] == 'paused': + while self._progress["status"] == "paused": await asyncio.sleep(1) # Check if download should continue - if self._progress['status'] not in {'running', 'stopping'}: + if self._progress["status"] not in {"running", "stopping"}: logger.info(f"Download stopped: {self._progress['status']}") return False # Return False to indicate no remote download happened - - model_hash = model.get('sha256', '').lower() - model_name = model.get('model_name', 'Unknown') - model_file_path = model.get('file_path', '') - model_file_name = model.get('file_name', '') - + + model_hash = model.get("sha256", "").lower() + model_name = model.get("model_name", "Unknown") + model_file_path = model.get("file_path", "") + model_file_name = model.get("file_name", "") + try: # Update current model info - self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" - await self._broadcast_progress(status='running') - - # Skip if already in failed models - if model_hash in self._progress['failed_models']: + self._progress["current_model"] = f"{model_name} ({model_hash[:8]})" + await self._broadcast_progress(status="running") + + # Skip if already in failed models (unless force mode is enabled) + if not force and model_hash in self._progress["failed_models"]: logger.debug(f"Skipping known failed model: {model_name}") return False - - model_dir = ExampleImagePathResolver.get_model_folder(model_hash, library_name) + + model_dir = ExampleImagePathResolver.get_model_folder( + model_hash, library_name + ) existing_files = _model_directory_has_files(model_dir) # Skip if already processed AND directory exists with files - if model_hash in self._progress['processed_models']: + if model_hash in self._progress["processed_models"]: if existing_files: logger.debug(f"Skipping already processed model: {model_name}") return False - + logger.debug( "Model %s (%s) marked as processed but folder empty or missing, reprocessing triggered", model_name, model_hash, ) # Track that we are reprocessing this model for summary logging - self._progress['reprocessed_models'].add(model_hash) + self._progress["reprocessed_models"].add(model_hash) # Remove from processed models since we need to reprocess - self._progress['processed_models'].discard(model_hash) + self._progress["processed_models"].discard(model_hash) - if existing_files and model_hash not in self._progress['processed_models']: + if existing_files and model_hash not in self._progress["processed_models"]: logger.debug( "Model folder already populated for %s, marking as processed without download", model_name, ) - self._progress['processed_models'].add(model_hash) + self._progress["processed_models"].add(model_hash) return False if not model_dir: @@ -510,38 +525,42 @@ class DownloadManager: # Create model directory os.makedirs(model_dir, exist_ok=True) - + # First check for local example images - local processing doesn't need delay - local_images_processed = await ExampleImagesProcessor.process_local_examples( - model_file_path, model_file_name, model_name, model_dir, optimize + local_images_processed = ( + await ExampleImagesProcessor.process_local_examples( + model_file_path, model_file_name, model_name, model_dir, optimize + ) ) - + # If we processed local images, update metadata if local_images_processed: await MetadataUpdater.update_metadata_from_local_examples( model_hash, model, scanner_type, scanner, model_dir ) - self._progress['processed_models'].add(model_hash) + self._progress["processed_models"].add(model_hash) return False # Return False to indicate no remote download happened - - full_model = await MetadataUpdater.get_updated_model( - model_hash, scanner - ) - civitai_payload = (full_model or {}).get('civitai') if full_model else None + + full_model = await MetadataUpdater.get_updated_model(model_hash, scanner) + civitai_payload = (full_model or {}).get("civitai") if full_model else None civitai_payload = civitai_payload or {} # If no local images, try to download from remote - if civitai_payload.get('images'): - images = civitai_payload.get('images', []) + if civitai_payload.get("images"): + images = civitai_payload.get("images", []) - success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( + ( + success, + is_stale, + failed_images, + ) = await ExampleImagesProcessor.download_model_images_with_tracking( model_hash, model_name, images, model_dir, optimize, downloader ) failed_urls: Set[str] = set(failed_images) # If metadata is stale, try to refresh it - if is_stale and model_hash not in self._progress['refreshed_models']: + if is_stale and model_hash not in self._progress["refreshed_models"]: await MetadataUpdater.refresh_model_metadata( model_hash, model_name, scanner_type, scanner, self._progress ) @@ -550,19 +569,30 @@ class DownloadManager: updated_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) - updated_civitai = (updated_model or {}).get('civitai') if updated_model else None + updated_civitai = ( + (updated_model or {}).get("civitai") if updated_model else None + ) updated_civitai = updated_civitai or {} - if updated_civitai.get('images'): + if updated_civitai.get("images"): # Retry download with updated metadata - updated_images = updated_civitai.get('images', []) - success, _, additional_failed = await ExampleImagesProcessor.download_model_images_with_tracking( - model_hash, model_name, updated_images, model_dir, optimize, downloader + updated_images = updated_civitai.get("images", []) + ( + success, + _, + additional_failed, + ) = await ExampleImagesProcessor.download_model_images_with_tracking( + model_hash, + model_name, + updated_images, + model_dir, + optimize, + downloader, ) failed_urls.update(additional_failed) - self._progress['refreshed_models'].add(model_hash) + self._progress["refreshed_models"].add(model_hash) if failed_urls: await self._remove_failed_images_from_metadata( @@ -574,75 +604,89 @@ class DownloadManager: ) if failed_urls: - self._progress['failed_models'].add(model_hash) - self._progress['processed_models'].add(model_hash) + self._progress["failed_models"].add(model_hash) + self._progress["processed_models"].add(model_hash) logger.info( - "Removed %s failed example images for %s", len(failed_urls), model_name + "Removed %s failed example images for %s", + len(failed_urls), + model_name, ) elif success: - self._progress['processed_models'].add(model_hash) + self._progress["processed_models"].add(model_hash) + # Remove from failed_models if force mode enabled and model was previously failed + if force and model_hash in self._progress["failed_models"]: + self._progress["failed_models"].discard(model_hash) + logger.info( + f"Removed {model_name} from failed_models after successful force retry" + ) else: - self._progress['failed_models'].add(model_hash) + self._progress["failed_models"].add(model_hash) logger.info( - "Example images download failed for %s despite metadata refresh", model_name + "Example images download failed for %s despite metadata refresh", + model_name, ) return True # Return True to indicate a remote download happened else: # No civitai data or images available, mark as failed to avoid future attempts - self._progress['failed_models'].add(model_hash) - logger.debug(f"No civitai images available for model {model_name}, marking as failed") + self._progress["failed_models"].add(model_hash) + logger.debug( + f"No civitai images available for model {model_name}, marking as failed" + ) # Save progress periodically - if self._progress['completed'] % 10 == 0 or self._progress['completed'] == self._progress['total'] - 1: + if ( + self._progress["completed"] % 10 == 0 + or self._progress["completed"] == self._progress["total"] - 1 + ): self._save_progress(output_dir) - + return False # Default return if no conditions met - + except Exception as e: error_msg = f"Error processing model {model.get('model_name')} ({model_hash}): {str(e)}" logger.error(error_msg, exc_info=True) - self._progress['errors'].append(error_msg) - self._progress['last_error'] = error_msg + self._progress["errors"].append(error_msg) + self._progress["last_error"] = error_msg # Ensure model is marked as failed so we don't try again in this run - self._progress['failed_models'].add(model_hash) + self._progress["failed_models"].add(model_hash) return False - + def _save_progress(self, output_dir): """Save download progress to file.""" try: - progress_file = os.path.join(output_dir, '.download_progress.json') - + progress_file = os.path.join(output_dir, ".download_progress.json") + # Read existing progress file if it exists existing_data = {} if os.path.exists(progress_file): try: - with open(progress_file, 'r', encoding='utf-8') as f: + with open(progress_file, "r", encoding="utf-8") as f: existing_data = json.load(f) except Exception as e: logger.warning(f"Failed to read existing progress file: {e}") - + # Create new progress data progress_data = { - 'processed_models': list(self._progress['processed_models']), - 'refreshed_models': list(self._progress['refreshed_models']), - 'failed_models': list(self._progress['failed_models']), - 'completed': self._progress['completed'], - 'total': self._progress['total'], - 'last_update': time.time() + "processed_models": list(self._progress["processed_models"]), + "refreshed_models": list(self._progress["refreshed_models"]), + "failed_models": list(self._progress["failed_models"]), + "completed": self._progress["completed"], + "total": self._progress["total"], + "last_update": time.time(), } - + # Preserve existing fields (especially naming_version) for key, value in existing_data.items(): if key not in progress_data: progress_data[key] = value - + # Write updated progress data - with open(progress_file, 'w', encoding='utf-8') as f: + with open(progress_file, "w", encoding="utf-8") as f: json.dump(progress_data, f, indent=2) except Exception as e: logger.error(f"Failed to save progress file: {e}") - + async def start_force_download(self, options: dict): """Force download example images for specific models.""" @@ -651,34 +695,38 @@ class DownloadManager: raise DownloadInProgressError(self._progress.snapshot()) data = options or {} - model_hashes = data.get('model_hashes', []) - optimize = data.get('optimize', True) - model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) + model_hashes = data.get("model_hashes", []) + optimize = data.get("optimize", True) + model_types = data.get("model_types", ["lora", "checkpoint"]) + delay = float(data.get("delay", 0.2)) if not model_hashes: - raise DownloadConfigurationError('Missing model_hashes parameter') + raise DownloadConfigurationError("Missing model_hashes parameter") settings_manager = get_settings_manager() - base_path = settings_manager.get('example_images_path') + base_path = settings_manager.get("example_images_path") if not base_path: - raise DownloadConfigurationError('Example images path not configured in settings') + raise DownloadConfigurationError( + "Example images path not configured in settings" + ) active_library = settings_manager.get_active_library_name() output_dir = self._resolve_output_dir(active_library) if not output_dir: - raise DownloadConfigurationError('Example images path not configured in settings') + raise DownloadConfigurationError( + "Example images path not configured in settings" + ) self._progress.reset() self._stop_requested = False - self._progress['total'] = len(model_hashes) - self._progress['status'] = 'running' - self._progress['start_time'] = time.time() - self._progress['end_time'] = None + self._progress["total"] = len(model_hashes) + self._progress["status"] = "running" + self._progress["start_time"] = time.time() + self._progress["end_time"] = None self._is_downloading = True - await self._broadcast_progress(status='running') + await self._broadcast_progress(status="running") try: result = await self._download_specific_models_example_images_sync( @@ -692,25 +740,23 @@ class DownloadManager: async with self._state_lock: self._is_downloading = False - final_status = self._progress['status'] + final_status = self._progress["status"] - message = 'Force download completed' - if final_status == 'stopped': - message = 'Force download stopped' + message = "Force download completed" + if final_status == "stopped": + message = "Force download stopped" - return { - 'success': True, - 'message': message, - 'result': result - } + return {"success": True, "message": message, "result": result} except Exception as e: async with self._state_lock: self._is_downloading = False - logger.error(f"Failed during forced example images download: {e}", exc_info=True) - await self._broadcast_progress(status='error', extra={'error': str(e)}) + logger.error( + f"Failed during forced example images download: {e}", exc_info=True + ) + await self._broadcast_progress(status="error", extra={"error": str(e)}) raise ExampleImagesDownloadError(str(e)) from e - + async def _download_specific_models_example_images_sync( self, model_hashes, @@ -723,45 +769,45 @@ class DownloadManager: """Download example images for specific models only - synchronous version.""" downloader = await get_downloader() - + try: # Get scanners scanners = [] - if 'lora' in model_types: + if "lora" in model_types: lora_scanner = await ServiceRegistry.get_lora_scanner() - scanners.append(('lora', lora_scanner)) - - if 'checkpoint' in model_types: - checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() - scanners.append(('checkpoint', checkpoint_scanner)) + scanners.append(("lora", lora_scanner)) - if 'embedding' in model_types: + if "checkpoint" in model_types: + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() + scanners.append(("checkpoint", checkpoint_scanner)) + + if "embedding" in model_types: embedding_scanner = await ServiceRegistry.get_embedding_scanner() - scanners.append(('embedding', embedding_scanner)) - + scanners.append(("embedding", embedding_scanner)) + # Find the specified models models_to_process = [] for scanner_type, scanner in scanners: cache = await scanner.get_cached_data() if cache and cache.raw_data: for model in cache.raw_data: - if model.get('sha256') in model_hashes: + if model.get("sha256") in model_hashes: models_to_process.append((scanner_type, model, scanner)) - + # Update total count based on found models - self._progress['total'] = len(models_to_process) + self._progress["total"] = len(models_to_process) logger.debug(f"Found {self._progress['total']} models to process") # Send initial progress via WebSocket - await self._broadcast_progress(status='running') - + await self._broadcast_progress(status="running") + # Process each model success_count = 0 for i, (scanner_type, model, scanner) in enumerate(models_to_process): async with self._state_lock: - current_status = self._progress['status'] + current_status = self._progress["status"] - if current_status not in {'running', 'paused', 'stopping'}: + if current_status not in {"running", "paused", "stopping"}: break # Force process this model regardless of previous status @@ -779,13 +825,15 @@ class DownloadManager: success_count += 1 # Update progress - self._progress['completed'] += 1 + self._progress["completed"] += 1 async with self._state_lock: - current_status = self._progress['status'] - should_stop = self._stop_requested and current_status == 'stopping' + current_status = self._progress["status"] + should_stop = self._stop_requested and current_status == "stopping" - broadcast_status = 'running' if current_status == 'running' else current_status + broadcast_status = ( + "running" if current_status == "running" else current_status + ) # Send progress update via WebSocket await self._broadcast_progress(status=broadcast_status) @@ -796,67 +844,67 @@ class DownloadManager: if ( was_successful and i < len(models_to_process) - 1 - and current_status == 'running' + and current_status == "running" ): await asyncio.sleep(delay) async with self._state_lock: - if self._stop_requested and self._progress['status'] == 'stopping': - self._progress['status'] = 'stopped' - self._progress['end_time'] = time.time() + if self._stop_requested and self._progress["status"] == "stopping": + self._progress["status"] = "stopped" + self._progress["end_time"] = time.time() self._stop_requested = False - final_status = 'stopped' - elif self._progress['status'] not in {'error', 'stopped'}: - self._progress['status'] = 'completed' - self._progress['end_time'] = time.time() + final_status = "stopped" + elif self._progress["status"] not in {"error", "stopped"}: + self._progress["status"] = "completed" + self._progress["end_time"] = time.time() self._stop_requested = False - final_status = 'completed' + final_status = "completed" else: - final_status = self._progress['status'] + final_status = self._progress["status"] self._stop_requested = False - if self._progress['end_time'] is None: - self._progress['end_time'] = time.time() + if self._progress["end_time"] is None: + self._progress["end_time"] = time.time() - if final_status == 'completed': + if final_status == "completed": logger.debug( "Forced example images download completed: %s/%s models processed", - self._progress['completed'], - self._progress['total'], + self._progress["completed"], + self._progress["total"], ) - elif final_status == 'stopped': + elif final_status == "stopped": logger.debug( "Forced example images download stopped: %s/%s models processed", - self._progress['completed'], - self._progress['total'], + self._progress["completed"], + self._progress["total"], ) # Send final progress via WebSocket await self._broadcast_progress(status=final_status) return { - 'total': self._progress['total'], - 'processed': self._progress['completed'], - 'successful': success_count, - 'errors': self._progress['errors'] + "total": self._progress["total"], + "processed": self._progress["completed"], + "successful": success_count, + "errors": self._progress["errors"], } - + except Exception as e: error_msg = f"Error during forced example images download: {str(e)}" logger.error(error_msg, exc_info=True) - self._progress['errors'].append(error_msg) - self._progress['last_error'] = error_msg - self._progress['status'] = 'error' - self._progress['end_time'] = time.time() + self._progress["errors"].append(error_msg) + self._progress["last_error"] = error_msg + self._progress["status"] = "error" + self._progress["end_time"] = time.time() # Send error status via WebSocket - await self._broadcast_progress(status='error', extra={'error': error_msg}) - + await self._broadcast_progress(status="error", extra={"error": error_msg}) + raise - + finally: # No need to close any sessions since we use the global downloader pass - + async def _process_specific_model( self, scanner_type, @@ -870,25 +918,27 @@ class DownloadManager: """Process a specific model for forced download, ignoring previous download status.""" # Check if download is paused - while self._progress['status'] == 'paused': + while self._progress["status"] == "paused": await asyncio.sleep(1) - + # Check if download should continue - if self._progress['status'] not in {'running', 'stopping'}: + if self._progress["status"] not in {"running", "stopping"}: logger.info(f"Download stopped: {self._progress['status']}") return False - - model_hash = model.get('sha256', '').lower() - model_name = model.get('model_name', 'Unknown') - model_file_path = model.get('file_path', '') - model_file_name = model.get('file_name', '') - + + model_hash = model.get("sha256", "").lower() + model_name = model.get("model_name", "Unknown") + model_file_path = model.get("file_path", "") + model_file_name = model.get("file_name", "") + try: # Update current model info - self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" - await self._broadcast_progress(status='running') - - model_dir = ExampleImagePathResolver.get_model_folder(model_hash, library_name) + self._progress["current_model"] = f"{model_name} ({model_hash[:8]})" + await self._broadcast_progress(status="running") + + model_dir = ExampleImagePathResolver.get_model_folder( + model_hash, library_name + ) if not model_dir: logger.warning( "Unable to resolve example images folder for model %s (%s)", @@ -898,38 +948,42 @@ class DownloadManager: return False os.makedirs(model_dir, exist_ok=True) - + # First check for local example images - local processing doesn't need delay - local_images_processed = await ExampleImagesProcessor.process_local_examples( - model_file_path, model_file_name, model_name, model_dir, optimize + local_images_processed = ( + await ExampleImagesProcessor.process_local_examples( + model_file_path, model_file_name, model_name, model_dir, optimize + ) ) - + # If we processed local images, update metadata if local_images_processed: await MetadataUpdater.update_metadata_from_local_examples( model_hash, model, scanner_type, scanner, model_dir ) - self._progress['processed_models'].add(model_hash) + self._progress["processed_models"].add(model_hash) return False # Return False to indicate no remote download happened - - full_model = await MetadataUpdater.get_updated_model( - model_hash, scanner - ) - civitai_payload = (full_model or {}).get('civitai') if full_model else None + + full_model = await MetadataUpdater.get_updated_model(model_hash, scanner) + civitai_payload = (full_model or {}).get("civitai") if full_model else None civitai_payload = civitai_payload or {} # If no local images, try to download from remote - if civitai_payload.get('images'): - images = civitai_payload.get('images', []) + if civitai_payload.get("images"): + images = civitai_payload.get("images", []) - success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( + ( + success, + is_stale, + failed_images, + ) = await ExampleImagesProcessor.download_model_images_with_tracking( model_hash, model_name, images, model_dir, optimize, downloader ) failed_urls: Set[str] = set(failed_images) # If metadata is stale, try to refresh it - if is_stale and model_hash not in self._progress['refreshed_models']: + if is_stale and model_hash not in self._progress["refreshed_models"]: await MetadataUpdater.refresh_model_metadata( model_hash, model_name, scanner_type, scanner, self._progress ) @@ -938,20 +992,31 @@ class DownloadManager: updated_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) - updated_civitai = (updated_model or {}).get('civitai') if updated_model else None + updated_civitai = ( + (updated_model or {}).get("civitai") if updated_model else None + ) updated_civitai = updated_civitai or {} - if updated_civitai.get('images'): + if updated_civitai.get("images"): # Retry download with updated metadata - updated_images = updated_civitai.get('images', []) - success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( - model_hash, model_name, updated_images, model_dir, optimize, downloader + updated_images = updated_civitai.get("images", []) + ( + success, + _, + additional_failed_images, + ) = await ExampleImagesProcessor.download_model_images_with_tracking( + model_hash, + model_name, + updated_images, + model_dir, + optimize, + downloader, ) # Combine failed images from both attempts failed_urls.update(additional_failed_images) - self._progress['refreshed_models'].add(model_hash) + self._progress["refreshed_models"].add(model_hash) # For forced downloads, remove failed images from metadata if failed_urls: @@ -960,21 +1025,22 @@ class DownloadManager: ) # Mark as processed - if success or failed_urls: # Mark as processed if we successfully downloaded some images or removed failed ones - self._progress['processed_models'].add(model_hash) + if ( + success or failed_urls + ): # Mark as processed if we successfully downloaded some images or removed failed ones + self._progress["processed_models"].add(model_hash) return True # Return True to indicate a remote download happened else: logger.debug(f"No civitai images available for model {model_name}") - return False - + except Exception as e: error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" logger.error(error_msg, exc_info=True) - self._progress['errors'].append(error_msg) - self._progress['last_error'] = error_msg + self._progress["errors"].append(error_msg) + self._progress["last_error"] = error_msg return False # Return False on exception async def _remove_failed_images_from_metadata( @@ -995,11 +1061,13 @@ class DownloadManager: # Get current model data model_data = await MetadataUpdater.get_updated_model(model_hash, scanner) if not model_data: - logger.warning(f"Could not find model data for {model_name} to remove failed images") + logger.warning( + f"Could not find model data for {model_name} to remove failed images" + ) return - civitai_payload = model_data.get('civitai') or {} - current_images = civitai_payload.get('images') or [] + civitai_payload = model_data.get("civitai") or {} + current_images = civitai_payload.get("images") or [] if not current_images: logger.warning(f"No images in metadata for {model_name}") return @@ -1007,21 +1075,21 @@ class DownloadManager: updated = False for image in current_images: - image_url = image.get('url') + image_url = image.get("url") optimized_url = ( ExampleImagesProcessor.get_civitai_optimized_url(image_url) - if image_url and 'civitai.com' in image_url + if image_url and "civitai.com" in image_url else None ) if image_url not in failed_set and optimized_url not in failed_set: continue - if image.get('downloadFailed'): + if image.get("downloadFailed"): continue - image['downloadFailed'] = True - image.setdefault('downloadError', 'not_found') + image["downloadFailed"] = True + image.setdefault("downloadError", "not_found") logger.debug( "Marked example image %s for %s as failed due to missing remote asset", image_url, @@ -1032,27 +1100,34 @@ class DownloadManager: if not updated: return - file_path = model_data.get('file_path') + file_path = model_data.get("file_path") if file_path: model_copy = model_data.copy() - model_copy.pop('folder', None) + model_copy.pop("folder", None) await MetadataManager.save_metadata(file_path, model_copy) try: - await scanner.update_single_model_cache(file_path, file_path, model_data) + await scanner.update_single_model_cache( + file_path, file_path, model_data + ) except AttributeError: - logger.debug("Scanner does not expose cache update for %s", model_name) + logger.debug( + "Scanner does not expose cache update for %s", model_name + ) except Exception as exc: # pragma: no cover - defensive logging logger.error( - "Error removing failed images from metadata for %s: %s", model_name, exc, exc_info=True + "Error removing failed images from metadata for %s: %s", + model_name, + exc, + exc_info=True, ) def _renumber_example_image_files(self, model_dir: str) -> None: if not model_dir or not os.path.isdir(model_dir): return - pattern = re.compile(r'^image_(\d+)(\.[^.]+)$', re.IGNORECASE) + pattern = re.compile(r"^image_(\d+)(\.[^.]+)$", re.IGNORECASE) matches: List[Tuple[int, str, str]] = [] for entry in os.listdir(model_dir): @@ -1103,17 +1178,17 @@ class DownloadManager: extra: Dict[str, Any] | None = None, ) -> Dict[str, Any]: payload: Dict[str, Any] = { - 'type': 'example_images_progress', - 'processed': self._progress['completed'], - 'total': self._progress['total'], - 'status': status or self._progress['status'], - 'current_model': self._progress['current_model'], + "type": "example_images_progress", + "processed": self._progress["completed"], + "total": self._progress["total"], + "status": status or self._progress["status"], + "current_model": self._progress["current_model"], } - if self._progress['errors']: - payload['errors'] = list(self._progress['errors']) - if self._progress['last_error']: - payload['last_error'] = self._progress['last_error'] + if self._progress["errors"]: + payload["errors"] = list(self._progress["errors"]) + if self._progress["last_error"]: + payload["last_error"] = self._progress["last_error"] if extra: payload.update(extra) diff --git a/static/js/components/ContextMenu/GlobalContextMenu.js b/static/js/components/ContextMenu/GlobalContextMenu.js index 3cc7699a..7bc61be5 100644 --- a/static/js/components/ContextMenu/GlobalContextMenu.js +++ b/static/js/components/ContextMenu/GlobalContextMenu.js @@ -75,13 +75,6 @@ export class GlobalContextMenu extends BaseContextMenu { } async downloadExampleImages(menuItem) { - const exampleImagesManager = window.exampleImagesManager; - - if (!exampleImagesManager) { - showToast('globalContextMenu.downloadExampleImages.unavailable', {}, 'error'); - return; - } - const downloadPath = state?.global?.settings?.example_images_path; if (!downloadPath) { showToast('globalContextMenu.downloadExampleImages.missingPath', {}, 'warning'); @@ -91,7 +84,48 @@ export class GlobalContextMenu extends BaseContextMenu { menuItem?.classList.add('disabled'); try { - await exampleImagesManager.handleDownloadButton(); + const optimize = state.global.settings.optimize_example_images; + + const response = await fetch('/api/lm/download-example-images', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + force: true, + optimize, + model_types: ['lora', 'checkpoint', 'embedding'] + }) + }); + + const data = await response.json(); + + if (data.success) { + showToast('toast.exampleImages.downloadStarted', {}, 'success'); + + const exampleImagesManager = window.exampleImagesManager; + if (exampleImagesManager) { + exampleImagesManager.isDownloading = true; + exampleImagesManager.isPaused = false; + exampleImagesManager.isStopping = false; + exampleImagesManager.hasShownCompletionToast = false; + exampleImagesManager.startTime = new Date(); + exampleImagesManager.updateUI(data.status); + exampleImagesManager.showProgressPanel(); + exampleImagesManager.startProgressUpdates(); + exampleImagesManager.updateDownloadButtonText(); + + const stopButton = document.getElementById('stopExampleDownloadBtn'); + if (stopButton) { + stopButton.disabled = false; + } + } + } else { + showToast('toast.exampleImages.downloadStartFailed', { error: data.error }, 'error'); + } + } catch (error) { + console.error('Failed to trigger example images download:', error); + showToast('toast.exampleImages.downloadStartFailed', {}, 'error'); } finally { menuItem?.classList.remove('disabled'); } diff --git a/tests/utils/test_example_images_download_manager_unit.py b/tests/utils/test_example_images_download_manager_unit.py index f6617ba6..51ed6c9f 100644 --- a/tests/utils/test_example_images_download_manager_unit.py +++ b/tests/utils/test_example_images_download_manager_unit.py @@ -29,12 +29,14 @@ def restore_settings() -> None: manager.settings.update(original) -async def test_start_download_requires_configured_path(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_start_download_requires_configured_path( + monkeypatch: pytest.MonkeyPatch, +) -> None: manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager()) # Ensure example_images_path is not configured settings_manager = get_settings_manager() - settings_manager.settings.pop('example_images_path', None) + settings_manager.settings.pop("example_images_path", None) with pytest.raises(download_module.DownloadConfigurationError) as exc_info: await manager.start_download({}) @@ -46,7 +48,9 @@ async def test_start_download_requires_configured_path(monkeypatch: pytest.Monke assert "skipping auto download" in result["message"] -async def test_start_download_bootstraps_progress_and_task(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: +async def test_start_download_bootstraps_progress_and_task( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: settings_manager = get_settings_manager() settings_manager.settings["example_images_path"] = str(tmp_path) settings_manager.settings["libraries"] = {"default": {}} @@ -58,7 +62,9 @@ async def test_start_download_bootstraps_progress_and_task(monkeypatch: pytest.M started = asyncio.Event() release = asyncio.Event() - async def fake_download(self, output_dir, optimize, model_types, delay, library_name): + async def fake_download( + self, output_dir, optimize, model_types, delay, library_name, force=False + ): started.set() await release.wait() async with self._state_lock: @@ -129,7 +135,9 @@ async def test_pause_and_resume_flow(monkeypatch: pytest.MonkeyPatch, tmp_path) await asyncio.wait_for(task, timeout=1) -async def test_stop_download_transitions_to_stopped(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: +async def test_stop_download_transitions_to_stopped( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: settings_manager = get_settings_manager() settings_manager.settings["example_images_path"] = str(tmp_path) settings_manager.settings["libraries"] = {"default": {}} @@ -145,13 +153,13 @@ async def test_stop_download_transitions_to_stopped(monkeypatch: pytest.MonkeyPa started.set() await release.wait() async with self._state_lock: - if self._stop_requested and self._progress['status'] == 'stopping': - self._progress['status'] = 'stopped' + if self._stop_requested and self._progress["status"] == "stopping": + self._progress["status"] = "stopped" else: - self._progress['status'] = 'completed' - self._progress['end_time'] = time.time() + self._progress["status"] = "completed" + self._progress["end_time"] = time.time() self._stop_requested = False - await self._broadcast_progress(status=self._progress['status']) + await self._broadcast_progress(status=self._progress["status"]) async with self._state_lock: self._is_downloading = False self._download_task = None @@ -182,7 +190,9 @@ async def test_stop_download_transitions_to_stopped(monkeypatch: pytest.MonkeyPa assert "stopped" in statuses -async def test_pause_or_resume_without_running_download(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_pause_or_resume_without_running_download( + monkeypatch: pytest.MonkeyPatch, +) -> None: manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager()) with pytest.raises(download_module.DownloadNotRunningError):