diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index 6101ffec..f5019847 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -62,8 +62,8 @@ class _DownloadProgress(dict): self.update( total=0, completed=0, - current_model="", - status="idle", + current_model='', + status='idle', errors=[], last_error=None, start_time=None, @@ -71,7 +71,6 @@ class _DownloadProgress(dict): processed_models=set(), refreshed_models=set(), failed_models=set(), - failed_model_timestamps={}, reprocessed_models=set(), ) @@ -79,11 +78,10 @@ class _DownloadProgress(dict): """Return a JSON-serialisable snapshot of the current progress.""" snapshot = dict(self) - snapshot["processed_models"] = list(self["processed_models"]) - snapshot["refreshed_models"] = list(self["refreshed_models"]) - snapshot["failed_models"] = list(self["failed_models"]) - snapshot["reprocessed_models"] = list(self.get("reprocessed_models", set())) - snapshot.pop("failed_model_timestamps", None) + snapshot['processed_models'] = list(self['processed_models']) + snapshot['refreshed_models'] = list(self['refreshed_models']) + snapshot['failed_models'] = list(self['failed_models']) + snapshot['reprocessed_models'] = list(self.get('reprocessed_models', set())) return snapshot @@ -102,7 +100,6 @@ def _model_directory_has_files(path: str) -> bool: return False - class DownloadManager: """Manages downloading example images for models.""" @@ -115,68 +112,55 @@ class DownloadManager: self._stop_requested = False def _resolve_output_dir(self, library_name: str | None = None) -> str: - base_path = get_settings_manager().get("example_images_path") + base_path = get_settings_manager().get('example_images_path') if not base_path: - return "" + return '' return ensure_library_root_exists(library_name) async def start_download(self, options: dict): """Start downloading example images for models.""" async with self._state_lock: - snapshot = self._progress.snapshot() if self._is_downloading: - raise DownloadInProgressError(snapshot) + raise DownloadInProgressError(self._progress.snapshot()) try: data = options or {} - auto_mode = data.get("auto_mode", False) - optimize = data.get("optimize", True) - model_types = data.get("model_types", ["lora", "checkpoint"]) - delay = float(data.get("delay", 0.2)) + auto_mode = data.get('auto_mode', False) + optimize = data.get('optimize', True) + model_types = data.get('model_types', ['lora', 'checkpoint']) + delay = float(data.get('delay', 0.2)) settings_manager = get_settings_manager() - base_path = settings_manager.get("example_images_path") + base_path = settings_manager.get('example_images_path') if not base_path: - error_msg = "Example images path not configured in settings" + error_msg = 'Example images path not configured in settings' if auto_mode: logger.debug(error_msg) return { - "success": True, - "message": "Example images path not configured, skipping auto download", + 'success': True, + 'message': 'Example images path not configured, skipping auto download' } raise DownloadConfigurationError(error_msg) active_library = get_settings_manager().get_active_library_name() output_dir = self._resolve_output_dir(active_library) if not output_dir: - raise DownloadConfigurationError( - "Example images path not configured in settings" - ) + raise DownloadConfigurationError('Example images path not configured in settings') self._progress.reset() self._stop_requested = False - self._progress["status"] = "running" - self._progress["start_time"] = time.time() - self._progress["end_time"] = None + self._progress['status'] = 'running' + self._progress['start_time'] = time.time() + self._progress['end_time'] = None - progress_file = os.path.join(output_dir, ".download_progress.json") + progress_file = os.path.join(output_dir, '.download_progress.json') progress_source = progress_file if uses_library_scoped_folders(): - legacy_root = ( - get_settings_manager().get("example_images_path") or "" - ) - legacy_progress = ( - os.path.join(legacy_root, ".download_progress.json") - if legacy_root - else "" - ) - if ( - legacy_progress - and os.path.exists(legacy_progress) - and not os.path.exists(progress_file) - ): + legacy_root = get_settings_manager().get('example_images_path') or '' + legacy_progress = os.path.join(legacy_root, '.download_progress.json') if legacy_root else '' + if legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file): try: os.makedirs(output_dir, exist_ok=True) shutil.move(legacy_progress, progress_file) @@ -196,31 +180,22 @@ class DownloadManager: if os.path.exists(progress_source): try: - with open(progress_source, "r", encoding="utf-8") as f: + with open(progress_source, 'r', encoding='utf-8') as f: saved_progress = json.load(f) - self._progress["processed_models"] = set( - saved_progress.get("processed_models", []) - ) - self._progress["failed_models"] = set( - saved_progress.get("failed_models", []) - ) - self._progress["failed_model_timestamps"] = ( - saved_progress.get("failed_model_timestamps", {}) - ) + self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) + self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) logger.debug( "Loaded previous progress, %s models already processed, %s models marked as failed", - len(self._progress["processed_models"]), - len(self._progress["failed_models"]), + len(self._progress['processed_models']), + len(self._progress['failed_models']), ) except Exception as e: logger.error(f"Failed to load progress file: {e}") - self._progress["processed_models"] = set() - self._progress["failed_models"] = set() - self._progress["failed_model_timestamps"] = {} + self._progress['processed_models'] = set() + self._progress['failed_models'] = set() else: - self._progress["processed_models"] = set() - self._progress["failed_models"] = set() - self._progress["failed_model_timestamps"] = {} + self._progress['processed_models'] = set() + self._progress['failed_models'] = set() self._is_downloading = True self._download_task = asyncio.create_task( @@ -233,6 +208,7 @@ class DownloadManager: ) ) + snapshot = self._progress.snapshot() except ExampleImagesDownloadError: # Re-raise our own exception types without wrapping self._is_downloading = False @@ -241,22 +217,24 @@ class DownloadManager: except Exception as e: self._is_downloading = False self._download_task = None - logger.error( - f"Failed to start example images download: {e}", exc_info=True - ) + logger.error(f"Failed to start example images download: {e}", exc_info=True) raise ExampleImagesDownloadError(str(e)) from e - await self._broadcast_progress(status="running") - - return {"success": True, "message": "Download started", "status": snapshot} + await self._broadcast_progress(status='running') + return { + 'success': True, + 'message': 'Download started', + 'status': snapshot + } + async def get_status(self, request): """Get the current status of example images download.""" return { - "success": True, - "is_downloading": self._is_downloading, - "status": self._progress.snapshot(), + 'success': True, + 'is_downloading': self._is_downloading, + 'status': self._progress.snapshot(), } async def pause_download(self, request): @@ -266,11 +244,14 @@ class DownloadManager: if not self._is_downloading: raise DownloadNotRunningError() - self._progress["status"] = "paused" + self._progress['status'] = 'paused' - await self._broadcast_progress(status="paused") + await self._broadcast_progress(status='paused') - return {"success": True, "message": "Download paused"} + return { + 'success': True, + 'message': 'Download paused' + } async def resume_download(self, request): """Resume the example images download.""" @@ -279,16 +260,19 @@ class DownloadManager: if not self._is_downloading: raise DownloadNotRunningError() - if self._progress["status"] == "paused": - self._progress["status"] = "running" + if self._progress['status'] == 'paused': + self._progress['status'] = 'running' else: raise DownloadNotRunningError( f"Download is in '{self._progress['status']}' state, cannot resume" ) - await self._broadcast_progress(status="running") + await self._broadcast_progress(status='running') - return {"success": True, "message": "Download resumed"} + return { + 'success': True, + 'message': 'Download resumed' + } async def stop_download(self, request): """Stop the example images download after the current model completes.""" @@ -297,17 +281,20 @@ class DownloadManager: if not self._is_downloading: raise DownloadNotRunningError() - if self._progress["status"] in {"completed", "error", "stopped"}: + if self._progress['status'] in {'completed', 'error', 'stopped'}: raise DownloadNotRunningError() - if self._progress["status"] != "stopping": + if self._progress['status'] != 'stopping': self._stop_requested = True - self._progress["status"] = "stopping" + self._progress['status'] = 'stopping' - await self._broadcast_progress(status="stopping") - - return {"success": True, "message": "Download stopping"} + await self._broadcast_progress(status='stopping') + return { + 'success': True, + 'message': 'Download stopping' + } + async def _download_all_example_images( self, output_dir, @@ -319,42 +306,42 @@ class DownloadManager: """Download example images for all models.""" downloader = await get_downloader() - + try: # Get scanners scanners = [] - if "lora" in model_types: + if 'lora' in model_types: lora_scanner = await ServiceRegistry.get_lora_scanner() - scanners.append(("lora", lora_scanner)) - - if "checkpoint" in model_types: + scanners.append(('lora', lora_scanner)) + + if 'checkpoint' in model_types: checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() - scanners.append(("checkpoint", checkpoint_scanner)) + scanners.append(('checkpoint', checkpoint_scanner)) - if "embedding" in model_types: + if 'embedding' in model_types: embedding_scanner = await ServiceRegistry.get_embedding_scanner() - scanners.append(("embedding", embedding_scanner)) - + scanners.append(('embedding', embedding_scanner)) + # Get all models all_models = [] for scanner_type, scanner in scanners: cache = await scanner.get_cached_data() if cache and cache.raw_data: for model in cache.raw_data: - if model.get("sha256"): + if model.get('sha256'): all_models.append((scanner_type, model, scanner)) - + # Update total count - self._progress["total"] = len(all_models) + self._progress['total'] = len(all_models) logger.debug(f"Found {self._progress['total']} models to process") - await self._broadcast_progress(status="running") - + await self._broadcast_progress(status='running') + # Process each model for i, (scanner_type, model, scanner) in enumerate(all_models): async with self._state_lock: - current_status = self._progress["status"] + current_status = self._progress['status'] - if current_status not in {"running", "paused", "stopping"}: + if current_status not in {'running', 'paused', 'stopping'}: break # Main logic for processing model is here, but actual operations are delegated to other classes @@ -369,15 +356,13 @@ class DownloadManager: ) # Update progress - self._progress["completed"] += 1 + self._progress['completed'] += 1 async with self._state_lock: - current_status = self._progress["status"] - should_stop = self._stop_requested and current_status == "stopping" + current_status = self._progress['status'] + should_stop = self._stop_requested and current_status == 'stopping' - broadcast_status = ( - "running" if current_status == "running" else current_status - ) + broadcast_status = 'running' if current_status == 'running' else current_status await self._broadcast_progress(status=broadcast_status) if should_stop: @@ -387,64 +372,57 @@ class DownloadManager: if ( was_remote_download and i < len(all_models) - 1 - and current_status == "running" + and current_status == 'running' ): await asyncio.sleep(delay) async with self._state_lock: - if self._stop_requested and self._progress["status"] == "stopping": - self._progress["status"] = "stopped" - self._progress["end_time"] = time.time() + if self._stop_requested and self._progress['status'] == 'stopping': + self._progress['status'] = 'stopped' + self._progress['end_time'] = time.time() self._stop_requested = False - final_status = "stopped" - elif self._progress["status"] not in {"error", "stopped"}: - self._progress["status"] = "completed" - self._progress["end_time"] = time.time() + final_status = 'stopped' + elif self._progress['status'] not in {'error', 'stopped'}: + self._progress['status'] = 'completed' + self._progress['end_time'] = time.time() self._stop_requested = False - final_status = "completed" + final_status = 'completed' else: - final_status = self._progress["status"] + final_status = self._progress['status'] self._stop_requested = False - if self._progress["end_time"] is None: - self._progress["end_time"] = time.time() + if self._progress['end_time'] is None: + self._progress['end_time'] = time.time() - if final_status == "completed": + if final_status == 'completed': logger.debug( "Example images download completed: %s/%s models processed", - self._progress["completed"], - self._progress["total"], + self._progress['completed'], + self._progress['total'], ) - elif final_status == "stopped": + elif final_status == 'stopped': logger.debug( "Example images download stopped: %s/%s models processed", - self._progress["completed"], - self._progress["total"], + self._progress['completed'], + self._progress['total'], ) - reprocessed = self._progress.get("reprocessed_models", set()) + reprocessed = self._progress.get('reprocessed_models', set()) if reprocessed: logger.info( "Detected %s models with missing or empty example image folders; reprocessing triggered for those models", len(reprocessed), ) - failed_count = len(self._progress["failed_models"]) - if failed_count > 0: - logger.warning( - "%s models failed to download example images. These will be retried in 24 hours.", - failed_count, - ) - await self._broadcast_progress(status=final_status) except Exception as e: error_msg = f"Error during example images download: {str(e)}" logger.error(error_msg, exc_info=True) - self._progress["errors"].append(error_msg) - self._progress["last_error"] = error_msg - self._progress["status"] = "error" - self._progress["end_time"] = time.time() - await self._broadcast_progress(status="error", extra={"error": error_msg}) + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg + self._progress['status'] = 'error' + self._progress['end_time'] = time.time() + await self._broadcast_progress(status='error', extra={'error': error_msg}) finally: # Save final progress to file @@ -458,7 +436,7 @@ class DownloadManager: self._is_downloading = False self._download_task = None self._stop_requested = False - + async def _process_model( self, scanner_type, @@ -472,80 +450,54 @@ class DownloadManager: """Process a single model download.""" # Check if download is paused - while self._progress["status"] == "paused": + while self._progress['status'] == 'paused': await asyncio.sleep(1) # Check if download should continue - if self._progress["status"] not in {"running", "stopping"}: + if self._progress['status'] not in {'running', 'stopping'}: logger.info(f"Download stopped: {self._progress['status']}") return False # Return False to indicate no remote download happened - - model_hash = model.get("sha256", "").lower() - model_name = model.get("model_name", "Unknown") - model_file_path = model.get("file_path", "") - model_file_name = model.get("file_name", "") - + + model_hash = model.get('sha256', '').lower() + model_name = model.get('model_name', 'Unknown') + model_file_path = model.get('file_path', '') + model_file_name = model.get('file_name', '') + try: # Update current model info - self._progress["current_model"] = f"{model_name} ({model_hash[:8]})" - await self._broadcast_progress(status="running") - - model_dir = ExampleImagePathResolver.get_model_folder( - model_hash, library_name - ) + self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" + await self._broadcast_progress(status='running') + + # Skip if already in failed models + if model_hash in self._progress['failed_models']: + logger.debug(f"Skipping known failed model: {model_name}") + return False + + model_dir = ExampleImagePathResolver.get_model_folder(model_hash, library_name) existing_files = _model_directory_has_files(model_dir) - # Skip if already in failed models - if model_hash in self._progress["failed_models"]: - if existing_files: - logger.debug( - f"Skipping failed model with existing files: {model_name}" - ) - return False - - failure_time = self._progress["failed_model_timestamps"].get( - model_hash, 0 - ) - retry_interval = 24 * 60 * 60 - - if time.time() - failure_time < retry_interval: - logger.debug( - f"Skipping recently failed model: {model_name} (retry in {int((retry_interval - (time.time() - failure_time)) / 60)} minutes)" - ) - return False - - logger.info( - "Retrying previously failed model %s (%s) - %.1f hours since last failure", - model_name, - model_hash[:8], - (time.time() - failure_time) / 3600, - ) - self._progress["failed_models"].discard(model_hash) - self._progress["failed_model_timestamps"].pop(model_hash, None) - # FALL THROUGH to normal processing with variables already set - # Skip if already processed AND directory exists with files - if model_hash in self._progress["processed_models"]: + if model_hash in self._progress['processed_models']: if existing_files: logger.debug(f"Skipping already processed model: {model_name}") return False - + logger.debug( "Model %s (%s) marked as processed but folder empty or missing, reprocessing triggered", model_name, model_hash, ) # Track that we are reprocessing this model for summary logging - self._progress["reprocessed_models"].add(model_hash) + self._progress['reprocessed_models'].add(model_hash) # Remove from processed models since we need to reprocess - self._progress["processed_models"].discard(model_hash) + self._progress['processed_models'].discard(model_hash) - if existing_files and model_hash not in self._progress["processed_models"]: + if existing_files and model_hash not in self._progress['processed_models']: logger.debug( "Model folder already populated for %s, marking as processed without download", model_name, ) - self._progress["processed_models"].add(model_hash) + self._progress['processed_models'].add(model_hash) return False if not model_dir: @@ -558,42 +510,38 @@ class DownloadManager: # Create model directory os.makedirs(model_dir, exist_ok=True) - + # First check for local example images - local processing doesn't need delay - local_images_processed = ( - await ExampleImagesProcessor.process_local_examples( - model_file_path, model_file_name, model_name, model_dir, optimize - ) + local_images_processed = await ExampleImagesProcessor.process_local_examples( + model_file_path, model_file_name, model_name, model_dir, optimize ) - + # If we processed local images, update metadata if local_images_processed: await MetadataUpdater.update_metadata_from_local_examples( model_hash, model, scanner_type, scanner, model_dir ) - self._progress["processed_models"].add(model_hash) + self._progress['processed_models'].add(model_hash) return False # Return False to indicate no remote download happened - - full_model = await MetadataUpdater.get_updated_model(model_hash, scanner) - civitai_payload = (full_model or {}).get("civitai") if full_model else None + + full_model = await MetadataUpdater.get_updated_model( + model_hash, scanner + ) + civitai_payload = (full_model or {}).get('civitai') if full_model else None civitai_payload = civitai_payload or {} # If no local images, try to download from remote - if civitai_payload.get("images"): - images = civitai_payload.get("images", []) + if civitai_payload.get('images'): + images = civitai_payload.get('images', []) - ( - success, - is_stale, - failed_images, - ) = await ExampleImagesProcessor.download_model_images_with_tracking( + success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( model_hash, model_name, images, model_dir, optimize, downloader ) failed_urls: Set[str] = set(failed_images) # If metadata is stale, try to refresh it - if is_stale and model_hash not in self._progress["refreshed_models"]: + if is_stale and model_hash not in self._progress['refreshed_models']: await MetadataUpdater.refresh_model_metadata( model_hash, model_name, scanner_type, scanner, self._progress ) @@ -602,30 +550,19 @@ class DownloadManager: updated_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) - updated_civitai = ( - (updated_model or {}).get("civitai") if updated_model else None - ) + updated_civitai = (updated_model or {}).get('civitai') if updated_model else None updated_civitai = updated_civitai or {} - if updated_civitai.get("images"): + if updated_civitai.get('images'): # Retry download with updated metadata - updated_images = updated_civitai.get("images", []) - ( - success, - _, - additional_failed, - ) = await ExampleImagesProcessor.download_model_images_with_tracking( - model_hash, - model_name, - updated_images, - model_dir, - optimize, - downloader, + updated_images = updated_civitai.get('images', []) + success, _, additional_failed = await ExampleImagesProcessor.download_model_images_with_tracking( + model_hash, model_name, updated_images, model_dir, optimize, downloader ) failed_urls.update(additional_failed) - self._progress["refreshed_models"].add(model_hash) + self._progress['refreshed_models'].add(model_hash) if failed_urls: await self._remove_failed_images_from_metadata( @@ -637,90 +574,75 @@ class DownloadManager: ) if failed_urls: - self._progress["failed_models"].add(model_hash) - self._progress["failed_model_timestamps"][model_hash] = time.time() - self._progress["processed_models"].add(model_hash) + self._progress['failed_models'].add(model_hash) + self._progress['processed_models'].add(model_hash) logger.info( - "Removed %s failed example images for %s", - len(failed_urls), - model_name, + "Removed %s failed example images for %s", len(failed_urls), model_name ) elif success: - self._progress["processed_models"].add(model_hash) + self._progress['processed_models'].add(model_hash) else: - self._progress["failed_models"].add(model_hash) - self._progress["failed_model_timestamps"][model_hash] = time.time() + self._progress['failed_models'].add(model_hash) logger.info( - "Example images download failed for %s despite metadata refresh", - model_name, + "Example images download failed for %s despite metadata refresh", model_name ) return True # Return True to indicate a remote download happened else: # No civitai data or images available, mark as failed to avoid future attempts - self._progress["failed_models"].add(model_hash) - self._progress["failed_model_timestamps"][model_hash] = time.time() - logger.debug( - f"No civitai images available for model {model_name}, marking as failed" - ) + self._progress['failed_models'].add(model_hash) + logger.debug(f"No civitai images available for model {model_name}, marking as failed") # Save progress periodically - if ( - self._progress["completed"] % 10 == 0 - or self._progress["completed"] == self._progress["total"] - 1 - ): + if self._progress['completed'] % 10 == 0 or self._progress['completed'] == self._progress['total'] - 1: self._save_progress(output_dir) - + return False # Default return if no conditions met - + except Exception as e: error_msg = f"Error processing model {model.get('model_name')} ({model_hash}): {str(e)}" logger.error(error_msg, exc_info=True) - self._progress["errors"].append(error_msg) - self._progress["last_error"] = error_msg + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg # Ensure model is marked as failed so we don't try again in this run - self._progress["failed_models"].add(model_hash) - self._progress["failed_model_timestamps"][model_hash] = time.time() + self._progress['failed_models'].add(model_hash) return False - + def _save_progress(self, output_dir): """Save download progress to file.""" try: - progress_file = os.path.join(output_dir, ".download_progress.json") - + progress_file = os.path.join(output_dir, '.download_progress.json') + # Read existing progress file if it exists existing_data = {} if os.path.exists(progress_file): try: - with open(progress_file, "r", encoding="utf-8") as f: + with open(progress_file, 'r', encoding='utf-8') as f: existing_data = json.load(f) except Exception as e: logger.warning(f"Failed to read existing progress file: {e}") - + # Create new progress data progress_data = { - "processed_models": list(self._progress["processed_models"]), - "refreshed_models": list(self._progress["refreshed_models"]), - "failed_models": list(self._progress["failed_models"]), - "failed_model_timestamps": self._progress.get( - "failed_model_timestamps", {} - ), - "completed": self._progress["completed"], - "total": self._progress["total"], - "last_update": time.time(), + 'processed_models': list(self._progress['processed_models']), + 'refreshed_models': list(self._progress['refreshed_models']), + 'failed_models': list(self._progress['failed_models']), + 'completed': self._progress['completed'], + 'total': self._progress['total'], + 'last_update': time.time() } - + # Preserve existing fields (especially naming_version) for key, value in existing_data.items(): if key not in progress_data: progress_data[key] = value - + # Write updated progress data - with open(progress_file, "w", encoding="utf-8") as f: + with open(progress_file, 'w', encoding='utf-8') as f: json.dump(progress_data, f, indent=2) except Exception as e: logger.error(f"Failed to save progress file: {e}") - + async def start_force_download(self, options: dict): """Force download example images for specific models.""" @@ -729,38 +651,34 @@ class DownloadManager: raise DownloadInProgressError(self._progress.snapshot()) data = options or {} - model_hashes = data.get("model_hashes", []) - optimize = data.get("optimize", True) - model_types = data.get("model_types", ["lora", "checkpoint"]) - delay = float(data.get("delay", 0.2)) + model_hashes = data.get('model_hashes', []) + optimize = data.get('optimize', True) + model_types = data.get('model_types', ['lora', 'checkpoint']) + delay = float(data.get('delay', 0.2)) if not model_hashes: - raise DownloadConfigurationError("Missing model_hashes parameter") + raise DownloadConfigurationError('Missing model_hashes parameter') settings_manager = get_settings_manager() - base_path = settings_manager.get("example_images_path") + base_path = settings_manager.get('example_images_path') if not base_path: - raise DownloadConfigurationError( - "Example images path not configured in settings" - ) + raise DownloadConfigurationError('Example images path not configured in settings') active_library = settings_manager.get_active_library_name() output_dir = self._resolve_output_dir(active_library) if not output_dir: - raise DownloadConfigurationError( - "Example images path not configured in settings" - ) + raise DownloadConfigurationError('Example images path not configured in settings') self._progress.reset() self._stop_requested = False - self._progress["total"] = len(model_hashes) - self._progress["status"] = "running" - self._progress["start_time"] = time.time() - self._progress["end_time"] = None + self._progress['total'] = len(model_hashes) + self._progress['status'] = 'running' + self._progress['start_time'] = time.time() + self._progress['end_time'] = None self._is_downloading = True - await self._broadcast_progress(status="running") + await self._broadcast_progress(status='running') try: result = await self._download_specific_models_example_images_sync( @@ -774,23 +692,25 @@ class DownloadManager: async with self._state_lock: self._is_downloading = False - final_status = self._progress["status"] + final_status = self._progress['status'] - message = "Force download completed" - if final_status == "stopped": - message = "Force download stopped" + message = 'Force download completed' + if final_status == 'stopped': + message = 'Force download stopped' - return {"success": True, "message": message, "result": result} + return { + 'success': True, + 'message': message, + 'result': result + } except Exception as e: async with self._state_lock: self._is_downloading = False - logger.error( - f"Failed during forced example images download: {e}", exc_info=True - ) - await self._broadcast_progress(status="error", extra={"error": str(e)}) + logger.error(f"Failed during forced example images download: {e}", exc_info=True) + await self._broadcast_progress(status='error', extra={'error': str(e)}) raise ExampleImagesDownloadError(str(e)) from e - + async def _download_specific_models_example_images_sync( self, model_hashes, @@ -803,45 +723,45 @@ class DownloadManager: """Download example images for specific models only - synchronous version.""" downloader = await get_downloader() - + try: # Get scanners scanners = [] - if "lora" in model_types: + if 'lora' in model_types: lora_scanner = await ServiceRegistry.get_lora_scanner() - scanners.append(("lora", lora_scanner)) - - if "checkpoint" in model_types: + scanners.append(('lora', lora_scanner)) + + if 'checkpoint' in model_types: checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() - scanners.append(("checkpoint", checkpoint_scanner)) + scanners.append(('checkpoint', checkpoint_scanner)) - if "embedding" in model_types: + if 'embedding' in model_types: embedding_scanner = await ServiceRegistry.get_embedding_scanner() - scanners.append(("embedding", embedding_scanner)) - + scanners.append(('embedding', embedding_scanner)) + # Find the specified models models_to_process = [] for scanner_type, scanner in scanners: cache = await scanner.get_cached_data() if cache and cache.raw_data: for model in cache.raw_data: - if model.get("sha256") in model_hashes: + if model.get('sha256') in model_hashes: models_to_process.append((scanner_type, model, scanner)) - + # Update total count based on found models - self._progress["total"] = len(models_to_process) + self._progress['total'] = len(models_to_process) logger.debug(f"Found {self._progress['total']} models to process") # Send initial progress via WebSocket - await self._broadcast_progress(status="running") - + await self._broadcast_progress(status='running') + # Process each model success_count = 0 for i, (scanner_type, model, scanner) in enumerate(models_to_process): async with self._state_lock: - current_status = self._progress["status"] + current_status = self._progress['status'] - if current_status not in {"running", "paused", "stopping"}: + if current_status not in {'running', 'paused', 'stopping'}: break # Force process this model regardless of previous status @@ -859,15 +779,13 @@ class DownloadManager: success_count += 1 # Update progress - self._progress["completed"] += 1 + self._progress['completed'] += 1 async with self._state_lock: - current_status = self._progress["status"] - should_stop = self._stop_requested and current_status == "stopping" + current_status = self._progress['status'] + should_stop = self._stop_requested and current_status == 'stopping' - broadcast_status = ( - "running" if current_status == "running" else current_status - ) + broadcast_status = 'running' if current_status == 'running' else current_status # Send progress update via WebSocket await self._broadcast_progress(status=broadcast_status) @@ -878,67 +796,67 @@ class DownloadManager: if ( was_successful and i < len(models_to_process) - 1 - and current_status == "running" + and current_status == 'running' ): await asyncio.sleep(delay) async with self._state_lock: - if self._stop_requested and self._progress["status"] == "stopping": - self._progress["status"] = "stopped" - self._progress["end_time"] = time.time() + if self._stop_requested and self._progress['status'] == 'stopping': + self._progress['status'] = 'stopped' + self._progress['end_time'] = time.time() self._stop_requested = False - final_status = "stopped" - elif self._progress["status"] not in {"error", "stopped"}: - self._progress["status"] = "completed" - self._progress["end_time"] = time.time() + final_status = 'stopped' + elif self._progress['status'] not in {'error', 'stopped'}: + self._progress['status'] = 'completed' + self._progress['end_time'] = time.time() self._stop_requested = False - final_status = "completed" + final_status = 'completed' else: - final_status = self._progress["status"] + final_status = self._progress['status'] self._stop_requested = False - if self._progress["end_time"] is None: - self._progress["end_time"] = time.time() + if self._progress['end_time'] is None: + self._progress['end_time'] = time.time() - if final_status == "completed": + if final_status == 'completed': logger.debug( "Forced example images download completed: %s/%s models processed", - self._progress["completed"], - self._progress["total"], + self._progress['completed'], + self._progress['total'], ) - elif final_status == "stopped": + elif final_status == 'stopped': logger.debug( "Forced example images download stopped: %s/%s models processed", - self._progress["completed"], - self._progress["total"], + self._progress['completed'], + self._progress['total'], ) # Send final progress via WebSocket await self._broadcast_progress(status=final_status) return { - "total": self._progress["total"], - "processed": self._progress["completed"], - "successful": success_count, - "errors": self._progress["errors"], + 'total': self._progress['total'], + 'processed': self._progress['completed'], + 'successful': success_count, + 'errors': self._progress['errors'] } - + except Exception as e: error_msg = f"Error during forced example images download: {str(e)}" logger.error(error_msg, exc_info=True) - self._progress["errors"].append(error_msg) - self._progress["last_error"] = error_msg - self._progress["status"] = "error" - self._progress["end_time"] = time.time() + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg + self._progress['status'] = 'error' + self._progress['end_time'] = time.time() # Send error status via WebSocket - await self._broadcast_progress(status="error", extra={"error": error_msg}) - + await self._broadcast_progress(status='error', extra={'error': error_msg}) + raise - + finally: # No need to close any sessions since we use the global downloader pass - + async def _process_specific_model( self, scanner_type, @@ -952,27 +870,25 @@ class DownloadManager: """Process a specific model for forced download, ignoring previous download status.""" # Check if download is paused - while self._progress["status"] == "paused": + while self._progress['status'] == 'paused': await asyncio.sleep(1) - + # Check if download should continue - if self._progress["status"] not in {"running", "stopping"}: + if self._progress['status'] not in {'running', 'stopping'}: logger.info(f"Download stopped: {self._progress['status']}") return False - - model_hash = model.get("sha256", "").lower() - model_name = model.get("model_name", "Unknown") - model_file_path = model.get("file_path", "") - model_file_name = model.get("file_name", "") - + + model_hash = model.get('sha256', '').lower() + model_name = model.get('model_name', 'Unknown') + model_file_path = model.get('file_path', '') + model_file_name = model.get('file_name', '') + try: # Update current model info - self._progress["current_model"] = f"{model_name} ({model_hash[:8]})" - await self._broadcast_progress(status="running") - - model_dir = ExampleImagePathResolver.get_model_folder( - model_hash, library_name - ) + self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" + await self._broadcast_progress(status='running') + + model_dir = ExampleImagePathResolver.get_model_folder(model_hash, library_name) if not model_dir: logger.warning( "Unable to resolve example images folder for model %s (%s)", @@ -982,42 +898,38 @@ class DownloadManager: return False os.makedirs(model_dir, exist_ok=True) - + # First check for local example images - local processing doesn't need delay - local_images_processed = ( - await ExampleImagesProcessor.process_local_examples( - model_file_path, model_file_name, model_name, model_dir, optimize - ) + local_images_processed = await ExampleImagesProcessor.process_local_examples( + model_file_path, model_file_name, model_name, model_dir, optimize ) - + # If we processed local images, update metadata if local_images_processed: await MetadataUpdater.update_metadata_from_local_examples( model_hash, model, scanner_type, scanner, model_dir ) - self._progress["processed_models"].add(model_hash) + self._progress['processed_models'].add(model_hash) return False # Return False to indicate no remote download happened - - full_model = await MetadataUpdater.get_updated_model(model_hash, scanner) - civitai_payload = (full_model or {}).get("civitai") if full_model else None + + full_model = await MetadataUpdater.get_updated_model( + model_hash, scanner + ) + civitai_payload = (full_model or {}).get('civitai') if full_model else None civitai_payload = civitai_payload or {} # If no local images, try to download from remote - if civitai_payload.get("images"): - images = civitai_payload.get("images", []) + if civitai_payload.get('images'): + images = civitai_payload.get('images', []) - ( - success, - is_stale, - failed_images, - ) = await ExampleImagesProcessor.download_model_images_with_tracking( + success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( model_hash, model_name, images, model_dir, optimize, downloader ) failed_urls: Set[str] = set(failed_images) # If metadata is stale, try to refresh it - if is_stale and model_hash not in self._progress["refreshed_models"]: + if is_stale and model_hash not in self._progress['refreshed_models']: await MetadataUpdater.refresh_model_metadata( model_hash, model_name, scanner_type, scanner, self._progress ) @@ -1026,31 +938,20 @@ class DownloadManager: updated_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) - updated_civitai = ( - (updated_model or {}).get("civitai") if updated_model else None - ) + updated_civitai = (updated_model or {}).get('civitai') if updated_model else None updated_civitai = updated_civitai or {} - if updated_civitai.get("images"): + if updated_civitai.get('images'): # Retry download with updated metadata - updated_images = updated_civitai.get("images", []) - ( - success, - _, - additional_failed_images, - ) = await ExampleImagesProcessor.download_model_images_with_tracking( - model_hash, - model_name, - updated_images, - model_dir, - optimize, - downloader, + updated_images = updated_civitai.get('images', []) + success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( + model_hash, model_name, updated_images, model_dir, optimize, downloader ) # Combine failed images from both attempts failed_urls.update(additional_failed_images) - self._progress["refreshed_models"].add(model_hash) + self._progress['refreshed_models'].add(model_hash) # For forced downloads, remove failed images from metadata if failed_urls: @@ -1059,22 +960,21 @@ class DownloadManager: ) # Mark as processed - if ( - success or failed_urls - ): # Mark as processed if we successfully downloaded some images or removed failed ones - self._progress["processed_models"].add(model_hash) + if success or failed_urls: # Mark as processed if we successfully downloaded some images or removed failed ones + self._progress['processed_models'].add(model_hash) return True # Return True to indicate a remote download happened else: logger.debug(f"No civitai images available for model {model_name}") - return False + return False + except Exception as e: error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" logger.error(error_msg, exc_info=True) - self._progress["errors"].append(error_msg) - self._progress["last_error"] = error_msg + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg return False # Return False on exception async def _remove_failed_images_from_metadata( @@ -1095,13 +995,11 @@ class DownloadManager: # Get current model data model_data = await MetadataUpdater.get_updated_model(model_hash, scanner) if not model_data: - logger.warning( - f"Could not find model data for {model_name} to remove failed images" - ) + logger.warning(f"Could not find model data for {model_name} to remove failed images") return - civitai_payload = model_data.get("civitai") or {} - current_images = civitai_payload.get("images") or [] + civitai_payload = model_data.get('civitai') or {} + current_images = civitai_payload.get('images') or [] if not current_images: logger.warning(f"No images in metadata for {model_name}") return @@ -1109,21 +1007,21 @@ class DownloadManager: updated = False for image in current_images: - image_url = image.get("url") + image_url = image.get('url') optimized_url = ( ExampleImagesProcessor.get_civitai_optimized_url(image_url) - if image_url and "civitai.com" in image_url + if image_url and 'civitai.com' in image_url else None ) if image_url not in failed_set and optimized_url not in failed_set: continue - if image.get("downloadFailed"): + if image.get('downloadFailed'): continue - image["downloadFailed"] = True - image.setdefault("downloadError", "not_found") + image['downloadFailed'] = True + image.setdefault('downloadError', 'not_found') logger.debug( "Marked example image %s for %s as failed due to missing remote asset", image_url, @@ -1134,34 +1032,27 @@ class DownloadManager: if not updated: return - file_path = model_data.get("file_path") + file_path = model_data.get('file_path') if file_path: model_copy = model_data.copy() - model_copy.pop("folder", None) + model_copy.pop('folder', None) await MetadataManager.save_metadata(file_path, model_copy) try: - await scanner.update_single_model_cache( - file_path, file_path, model_data - ) + await scanner.update_single_model_cache(file_path, file_path, model_data) except AttributeError: - logger.debug( - "Scanner does not expose cache update for %s", model_name - ) + logger.debug("Scanner does not expose cache update for %s", model_name) except Exception as exc: # pragma: no cover - defensive logging logger.error( - "Error removing failed images from metadata for %s: %s", - model_name, - exc, - exc_info=True, + "Error removing failed images from metadata for %s: %s", model_name, exc, exc_info=True ) def _renumber_example_image_files(self, model_dir: str) -> None: if not model_dir or not os.path.isdir(model_dir): return - pattern = re.compile(r"^image_(\d+)(\.[^.]+)$", re.IGNORECASE) + pattern = re.compile(r'^image_(\d+)(\.[^.]+)$', re.IGNORECASE) matches: List[Tuple[int, str, str]] = [] for entry in os.listdir(model_dir): @@ -1212,17 +1103,17 @@ class DownloadManager: extra: Dict[str, Any] | None = None, ) -> Dict[str, Any]: payload: Dict[str, Any] = { - "type": "example_images_progress", - "processed": self._progress["completed"], - "total": self._progress["total"], - "status": status or self._progress["status"], - "current_model": self._progress["current_model"], + 'type': 'example_images_progress', + 'processed': self._progress['completed'], + 'total': self._progress['total'], + 'status': status or self._progress['status'], + 'current_model': self._progress['current_model'], } - if self._progress["errors"]: - payload["errors"] = list(self._progress["errors"]) - if self._progress["last_error"]: - payload["last_error"] = self._progress["last_error"] + if self._progress['errors']: + payload['errors'] = list(self._progress['errors']) + if self._progress['last_error']: + payload['last_error'] = self._progress['last_error'] if extra: payload.update(extra) diff --git a/tests/utils/test_failed_model_retry.py b/tests/utils/test_failed_model_retry.py deleted file mode 100644 index ede2ce76..00000000 --- a/tests/utils/test_failed_model_retry.py +++ /dev/null @@ -1,287 +0,0 @@ -from __future__ import annotations - -import time -from typing import Any, Dict -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from py.utils import example_images_download_manager as download_module - - -@pytest.fixture(autouse=True) -def restore_settings(): - from py.services.settings_manager import get_settings_manager - - manager = get_settings_manager() - original = manager.settings.copy() - try: - yield - finally: - manager.settings.clear() - manager.settings.update(original) - - -class RecordingWebSocketManager: - def __init__(self) -> None: - self.payloads: list[Dict[str, Any]] = [] - - async def broadcast(self, payload: Dict[str, Any]) -> None: - self.payloads.append(payload) - - -@pytest.mark.asyncio -async def test_process_model_with_old_failure_and_empty_folder_retries(): - """Test that models with old failures and empty folders are removed from failed list for retry.""" - from py.services.settings_manager import get_settings_manager - from py.services.downloader import get_downloader - - settings_manager = get_settings_manager() - settings_manager.settings["libraries"] = {"default": {}} - settings_manager.settings["active_library"] = "default" - - ws_manager = RecordingWebSocketManager() - manager = download_module.DownloadManager(ws_manager=ws_manager) - - # Create a model hash - test_hash = "test_hash_12345678" - test_name = "Test Model" - - # Mark as failed with timestamp 25 hours ago - old_timestamp = time.time() - (25 * 60 * 60) - manager._progress["failed_models"].add(test_hash) - manager._progress["failed_model_timestamps"][test_hash] = old_timestamp - - # Verify initial state - model is in failed list - assert test_hash in manager._progress["failed_models"] - assert test_hash in manager._progress["failed_model_timestamps"] - initial_timestamp = manager._progress["failed_model_timestamps"][test_hash] - - # Mock dependencies to make model successfully download - mock_scanner = MagicMock() - mock_model = { - "sha256": test_hash, - "model_name": test_name, - "file_path": "/fake/path/model.safetensors", - "file_name": "model.safetensors", - "civitai": {"images": [{"url": "http://example.com/image.jpg"}]}, - } - - mock_downloader = await get_downloader() - - # Mock path resolver to return directory with existing files - # This will make the code skip retry but verify the logic works - with ( - patch.object( - download_module.ExampleImagePathResolver, - "get_model_folder", - return_value="/fake/dir/with/files", - ), - patch.object( - download_module, - "_model_directory_has_files", - return_value=True, # Files exist - ), - ): - result = await manager._process_model( - "lora", - mock_model, - mock_scanner, - "/fake/output", - False, - mock_downloader, - "default", - ) - - # When files exist, model should remain in failed list (not retried) - assert test_hash in manager._progress["failed_models"] - assert test_hash in manager._progress["failed_model_timestamps"] - # Result should be False because no remote download happened (skipped due to existing files) - assert result is False - - # Now test the actual retry path by mocking empty directory - manager._progress["processed_models"].clear() - manager._progress["failed_models"].clear() - manager._progress["failed_model_timestamps"].clear() - - # Re-add to failed with old timestamp - manager._progress["failed_models"].add(test_hash) - manager._progress["failed_model_timestamps"][test_hash] = old_timestamp - - with ( - patch.object( - download_module.ExampleImagePathResolver, - "get_model_folder", - return_value="/fake/empty/dir", - ), - patch.object( - download_module, - "_model_directory_has_files", - return_value=False, # No files - ), - patch.object( - download_module.ExampleImagesProcessor, - "process_local_examples", - new_callable=AsyncMock, - return_value=False, - ), - # Note: We don't mock download_model_images_with_tracking here - # because it's complex. The key thing is that the model is - # removed from failed list so it can be retried. - ): - # Just check that the model is removed from failed list before processing - # This proves the retry logic is triggered - result = await manager._process_model( - "lora", - mock_model, - mock_scanner, - "/fake/output", - False, - mock_downloader, - "default", - ) - - # The model should have been removed from failed_models for retry - # (even if it gets re-added later due to download failure) - assert ( - result is False - or test_hash not in manager._progress["failed_models"] - or test_hash in manager._progress["processed_models"] - ) - - -@pytest.mark.asyncio -async def test_process_model_with_old_failure_and_existing_files_skips(): - """Test that models with old failures but existing files are not retried.""" - from py.services.settings_manager import get_settings_manager - from py.services.downloader import get_downloader - - settings_manager = get_settings_manager() - settings_manager.settings["libraries"] = {"default": {}} - settings_manager.settings["active_library"] = "default" - - ws_manager = RecordingWebSocketManager() - manager = download_module.DownloadManager(ws_manager=ws_manager) - - # Create a model hash - test_hash = "test_hash_12345678" - test_name = "Test Model" - - # Mark as failed with timestamp 25 hours ago - old_timestamp = time.time() - (25 * 60 * 60) - manager._progress["failed_models"].add(test_hash) - manager._progress["failed_model_timestamps"][test_hash] = old_timestamp - - mock_scanner = MagicMock() - mock_model = { - "sha256": test_hash, - "model_name": test_name, - "file_path": "/fake/path/model.safetensors", - "file_name": "model.safetensors", - } - - # Mock path resolver to return directory with files - with ( - patch.object( - download_module.ExampleImagePathResolver, - "get_model_folder", - return_value="/fake/dir/with/files", - ), - patch.object( - download_module, - "_model_directory_has_files", - return_value=True, - ), - ): - result = await manager._process_model( - "lora", - mock_model, - mock_scanner, - "/fake/output", - False, - await get_downloader(), - "default", - ) - - # Verify model is still in failed list (not retried because files exist) - assert test_hash in manager._progress["failed_models"] - assert test_hash in manager._progress["failed_model_timestamps"] - assert result is False # No remote download happened - - -@pytest.mark.asyncio -async def test_process_model_with_recent_failure_skips(): - """Test that models with recent failures are not retried.""" - from py.services.settings_manager import get_settings_manager - from py.services.downloader import get_downloader - - settings_manager = get_settings_manager() - settings_manager.settings["libraries"] = {"default": {}} - settings_manager.settings["active_library"] = "default" - - ws_manager = RecordingWebSocketManager() - manager = download_module.DownloadManager(ws_manager=ws_manager) - - # Create a model hash - test_hash = "test_hash_12345678" - test_name = "Test Model" - - # Mark as failed with timestamp 2 hours ago (recent) - recent_timestamp = time.time() - (2 * 60 * 60) - manager._progress["failed_models"].add(test_hash) - manager._progress["failed_model_timestamps"][test_hash] = recent_timestamp - - mock_scanner = MagicMock() - mock_model = { - "sha256": test_hash, - "model_name": test_name, - "file_path": "/fake/path/model.safetensors", - "file_name": "model.safetensors", - } - - # Mock path resolver to return empty directory - with ( - patch.object( - download_module.ExampleImagePathResolver, - "get_model_folder", - return_value="/fake/empty/dir", - ), - patch.object( - download_module, - "_model_directory_has_files", - return_value=False, - ), - ): - result = await manager._process_model( - "lora", - mock_model, - mock_scanner, - "/fake/output", - False, - await get_downloader(), - "default", - ) - - # Verify model is still in failed list (not retried because too recent) - assert test_hash in manager._progress["failed_models"] - assert test_hash in manager._progress["failed_model_timestamps"] - assert result is False # No remote download happened - - -def test_progress_includes_failed_timestamps(): - """Test that _DownloadProgress includes failed_model_timestamps.""" - progress = download_module._DownloadProgress() - progress.reset() - - assert "failed_model_timestamps" in progress - assert isinstance(progress["failed_model_timestamps"], dict) - - -def test_progress_snapshot_excludes_failed_timestamps(): - """Test that snapshot() excludes failed_model_timestamps.""" - progress = download_module._DownloadProgress() - progress.reset() - - snapshot = progress.snapshot() - - assert "failed_model_timestamps" not in snapshot