diff --git a/py/routes/example_images_routes.py b/py/routes/example_images_routes.py index 44effa3b..d5d34218 100644 --- a/py/routes/example_images_routes.py +++ b/py/routes/example_images_routes.py @@ -16,7 +16,10 @@ from ..services.use_cases.example_images import ( DownloadExampleImagesUseCase, ImportExampleImagesUseCase, ) -from ..utils.example_images_download_manager import DownloadManager +from ..utils.example_images_download_manager import ( + DownloadManager, + get_default_download_manager, +) from ..utils.example_images_file_manager import ExampleImagesFileManager from ..utils.example_images_processor import ExampleImagesProcessor @@ -29,11 +32,11 @@ class ExampleImagesRoutes: def __init__( self, *, - download_manager=DownloadManager, + download_manager: DownloadManager | None = None, processor=ExampleImagesProcessor, file_manager=ExampleImagesFileManager, ) -> None: - self._download_manager = download_manager + self._download_manager = download_manager or get_default_download_manager() self._processor = processor self._file_manager = file_manager self._handler_set: ExampleImagesHandlerSet | None = None diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index 7df0c6fb..e538f50a 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import asyncio @@ -37,165 +39,150 @@ class DownloadConfigurationError(ExampleImagesDownloadError): logger = logging.getLogger(__name__) -# Download status tracking -download_task = None -is_downloading = False -download_progress = { - 'total': 0, - 'completed': 0, - 'current_model': '', - 'status': 'idle', # idle, running, paused, completed, error - 'errors': [], - 'last_error': None, - 'start_time': None, - 'end_time': None, - 'processed_models': set(), # Track models that have been processed - 'refreshed_models': set(), # Track models that had metadata refreshed - 'failed_models': set() # Track models that failed to download after metadata refresh -} +class _DownloadProgress(dict): + """Mutable mapping maintaining download progress with set-aware serialisation.""" -def _serialize_progress() -> dict: - """Return a JSON-serialisable snapshot of the current progress.""" + def __init__(self) -> None: + super().__init__() + self.reset() - snapshot = download_progress.copy() - snapshot['processed_models'] = list(download_progress['processed_models']) - snapshot['refreshed_models'] = list(download_progress['refreshed_models']) - snapshot['failed_models'] = list(download_progress['failed_models']) - return snapshot + def reset(self) -> None: + """Reset the progress dictionary to its initial state.""" + + self.update( + total=0, + completed=0, + current_model='', + status='idle', + errors=[], + last_error=None, + start_time=None, + end_time=None, + processed_models=set(), + refreshed_models=set(), + failed_models=set(), + ) + + def snapshot(self) -> dict: + """Return a JSON-serialisable snapshot of the current progress.""" + + snapshot = dict(self) + snapshot['processed_models'] = list(self['processed_models']) + snapshot['refreshed_models'] = list(self['refreshed_models']) + snapshot['failed_models'] = list(self['failed_models']) + return snapshot class DownloadManager: - """Manages downloading example images for models""" - - @staticmethod - async def start_download(options: dict): - """ - Start downloading example images for models - - Expects a JSON body with: - { - "optimize": true, # Whether to optimize images (default: true) - "model_types": ["lora", "checkpoint"], # Model types to process (default: both) - "delay": 1.0, # Delay between downloads to avoid rate limiting (default: 1.0) - "auto_mode": false # Flag to indicate automatic download (default: false) - } - """ - global download_task, is_downloading, download_progress - - if is_downloading: - raise DownloadInProgressError(_serialize_progress()) + """Manages downloading example images for models.""" + + def __init__(self) -> None: + self._download_task: asyncio.Task | None = None + self._is_downloading = False + self._progress = _DownloadProgress() + + async def start_download(self, options: dict): + """Start downloading example images for models.""" + + if self._is_downloading: + raise DownloadInProgressError(self._progress.snapshot()) try: data = options or {} auto_mode = data.get('auto_mode', False) optimize = data.get('optimize', True) model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds - - # Get output directory from settings + delay = float(data.get('delay', 0.2)) + output_dir = settings.get('example_images_path') if not output_dir: error_msg = 'Example images path not configured in settings' if auto_mode: - # For auto mode, just log and return success to avoid showing error toasts logger.debug(error_msg) return { 'success': True, 'message': 'Example images path not configured, skipping auto download' } raise DownloadConfigurationError(error_msg) - - # Create the output directory + os.makedirs(output_dir, exist_ok=True) - - # Initialize progress tracking - download_progress['total'] = 0 - download_progress['completed'] = 0 - download_progress['current_model'] = '' - download_progress['status'] = 'running' - download_progress['errors'] = [] - download_progress['last_error'] = None - download_progress['start_time'] = time.time() - download_progress['end_time'] = None - - # Get the processed models list from a file if it exists + + self._progress.reset() + self._progress['status'] = 'running' + self._progress['start_time'] = time.time() + self._progress['end_time'] = None + progress_file = os.path.join(output_dir, '.download_progress.json') if os.path.exists(progress_file): try: with open(progress_file, 'r', encoding='utf-8') as f: saved_progress = json.load(f) - download_progress['processed_models'] = set(saved_progress.get('processed_models', [])) - download_progress['failed_models'] = set(saved_progress.get('failed_models', [])) - logger.debug(f"Loaded previous progress, {len(download_progress['processed_models'])} models already processed, {len(download_progress['failed_models'])} models marked as failed") + self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) + self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) + logger.debug( + "Loaded previous progress, %s models already processed, %s models marked as failed", + len(self._progress['processed_models']), + len(self._progress['failed_models']), + ) except Exception as e: logger.error(f"Failed to load progress file: {e}") - download_progress['processed_models'] = set() - download_progress['failed_models'] = set() + self._progress['processed_models'] = set() + self._progress['failed_models'] = set() else: - download_progress['processed_models'] = set() - download_progress['failed_models'] = set() - - # Start the download task - is_downloading = True - download_task = asyncio.create_task( - DownloadManager._download_all_example_images( - output_dir, - optimize, + self._progress['processed_models'] = set() + self._progress['failed_models'] = set() + + self._is_downloading = True + self._download_task = asyncio.create_task( + self._download_all_example_images( + output_dir, + optimize, model_types, delay ) ) - + return { 'success': True, 'message': 'Download started', - 'status': _serialize_progress() + 'status': self._progress.snapshot() } except Exception as e: logger.error(f"Failed to start example images download: {e}", exc_info=True) raise ExampleImagesDownloadError(str(e)) from e - @staticmethod - async def get_status(request): - """Get the current status of example images download""" - global download_progress - - # Create a copy of the progress dict with the set converted to a list for JSON serialization - response_progress = _serialize_progress() + async def get_status(self, request): + """Get the current status of example images download.""" return { 'success': True, - 'is_downloading': is_downloading, - 'status': response_progress + 'is_downloading': self._is_downloading, + 'status': self._progress.snapshot(), } - @staticmethod - async def pause_download(request): - """Pause the example images download""" - global download_progress - - if not is_downloading: + async def pause_download(self, request): + """Pause the example images download.""" + + if not self._is_downloading: raise DownloadNotRunningError() - download_progress['status'] = 'paused' + self._progress['status'] = 'paused' return { 'success': True, 'message': 'Download paused' } - @staticmethod - async def resume_download(request): - """Resume the example images download""" - global download_progress - - if not is_downloading: + async def resume_download(self, request): + """Resume the example images download.""" + + if not self._is_downloading: raise DownloadNotRunningError() - if download_progress['status'] == 'paused': - download_progress['status'] = 'running' + if self._progress['status'] == 'paused': + self._progress['status'] = 'running' return { 'success': True, @@ -203,15 +190,12 @@ class DownloadManager: } raise DownloadNotRunningError( - f"Download is in '{download_progress['status']}' state, cannot resume" + f"Download is in '{self._progress['status']}' state, cannot resume" ) - @staticmethod - async def _download_all_example_images(output_dir, optimize, model_types, delay): - """Download example images for all models""" - global is_downloading, download_progress - - # Get unified downloader + async def _download_all_example_images(self, output_dir, optimize, model_types, delay): + """Download example images for all models.""" + downloader = await get_downloader() try: @@ -239,59 +223,58 @@ class DownloadManager: all_models.append((scanner_type, model, scanner)) # Update total count - download_progress['total'] = len(all_models) - logger.debug(f"Found {download_progress['total']} models to process") + self._progress['total'] = len(all_models) + logger.debug(f"Found {self._progress['total']} models to process") # Process each model for i, (scanner_type, model, scanner) in enumerate(all_models): # Main logic for processing model is here, but actual operations are delegated to other classes - was_remote_download = await DownloadManager._process_model( - scanner_type, model, scanner, + was_remote_download = await self._process_model( + scanner_type, model, scanner, output_dir, optimize, downloader ) # Update progress - download_progress['completed'] += 1 + self._progress['completed'] += 1 # Only add delay after remote download of models, and not after processing the last model - if was_remote_download and i < len(all_models) - 1 and download_progress['status'] == 'running': + if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running': await asyncio.sleep(delay) # Mark as completed - download_progress['status'] = 'completed' - download_progress['end_time'] = time.time() - logger.debug(f"Example images download completed: {download_progress['completed']}/{download_progress['total']} models processed") + self._progress['status'] = 'completed' + self._progress['end_time'] = time.time() + logger.debug(f"Example images download completed: {self._progress['completed']}/{self._progress['total']} models processed") except Exception as e: error_msg = f"Error during example images download: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg - download_progress['status'] = 'error' - download_progress['end_time'] = time.time() + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg + self._progress['status'] = 'error' + self._progress['end_time'] = time.time() finally: # Save final progress to file try: - DownloadManager._save_progress(output_dir) + self._save_progress(output_dir) except Exception as e: logger.error(f"Failed to save progress file: {e}") - + # Set download status to not downloading - is_downloading = False + self._is_downloading = False + self._download_task = None - @staticmethod - async def _process_model(scanner_type, model, scanner, output_dir, optimize, downloader): - """Process a single model download""" - global download_progress - + async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): + """Process a single model download.""" + # Check if download is paused - while download_progress['status'] == 'paused': + while self._progress['status'] == 'paused': await asyncio.sleep(1) - + # Check if download should continue - if download_progress['status'] != 'running': - logger.info(f"Download stopped: {download_progress['status']}") + if self._progress['status'] != 'running': + logger.info(f"Download stopped: {self._progress['status']}") return False # Return False to indicate no remote download happened model_hash = model.get('sha256', '').lower() @@ -301,15 +284,15 @@ class DownloadManager: try: # Update current model info - download_progress['current_model'] = f"{model_name} ({model_hash[:8]})" + self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" # Skip if already in failed models - if model_hash in download_progress['failed_models']: + if model_hash in self._progress['failed_models']: logger.debug(f"Skipping known failed model: {model_name}") return False # Skip if already processed AND directory exists with files - if model_hash in download_progress['processed_models']: + if model_hash in self._progress['processed_models']: model_dir = os.path.join(output_dir, model_hash) has_files = os.path.exists(model_dir) and any(os.listdir(model_dir)) if has_files: @@ -318,7 +301,7 @@ class DownloadManager: else: logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing") # Remove from processed models since we need to reprocess - download_progress['processed_models'].discard(model_hash) + self._progress['processed_models'].discard(model_hash) # Create model directory model_dir = os.path.join(output_dir, model_hash) @@ -334,7 +317,7 @@ class DownloadManager: await MetadataUpdater.update_metadata_from_local_examples( model_hash, model, scanner_type, scanner, model_dir ) - download_progress['processed_models'].add(model_hash) + self._progress['processed_models'].add(model_hash) return False # Return False to indicate no remote download happened # If no local images, try to download from remote @@ -346,57 +329,55 @@ class DownloadManager: ) # If metadata is stale, try to refresh it - if is_stale and model_hash not in download_progress['refreshed_models']: + if is_stale and model_hash not in self._progress['refreshed_models']: await MetadataUpdater.refresh_model_metadata( - model_hash, model_name, scanner_type, scanner + model_hash, model_name, scanner_type, scanner, self._progress ) - + # Get the updated model data updated_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) - + if updated_model and updated_model.get('civitai', {}).get('images'): # Retry download with updated metadata updated_images = updated_model.get('civitai', {}).get('images', []) success, _ = await ExampleImagesProcessor.download_model_images( model_hash, model_name, updated_images, model_dir, optimize, downloader ) - - download_progress['refreshed_models'].add(model_hash) + + self._progress['refreshed_models'].add(model_hash) # Mark as processed if successful, or as failed if unsuccessful after refresh if success: - download_progress['processed_models'].add(model_hash) + self._progress['processed_models'].add(model_hash) else: # If we refreshed metadata and still failed, mark as permanently failed - if model_hash in download_progress['refreshed_models']: - download_progress['failed_models'].add(model_hash) + if model_hash in self._progress['refreshed_models']: + self._progress['failed_models'].add(model_hash) logger.info(f"Marking model {model_name} as failed after metadata refresh") return True # Return True to indicate a remote download happened else: # No civitai data or images available, mark as failed to avoid future attempts - download_progress['failed_models'].add(model_hash) + self._progress['failed_models'].add(model_hash) logger.debug(f"No civitai images available for model {model_name}, marking as failed") # Save progress periodically - if download_progress['completed'] % 10 == 0 or download_progress['completed'] == download_progress['total'] - 1: - DownloadManager._save_progress(output_dir) + if self._progress['completed'] % 10 == 0 or self._progress['completed'] == self._progress['total'] - 1: + self._save_progress(output_dir) return False # Default return if no conditions met except Exception as e: error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg return False # Return False on exception - @staticmethod - def _save_progress(output_dir): - """Save download progress to file""" - global download_progress + def _save_progress(self, output_dir): + """Save download progress to file.""" try: progress_file = os.path.join(output_dir, '.download_progress.json') @@ -411,11 +392,11 @@ class DownloadManager: # Create new progress data progress_data = { - 'processed_models': list(download_progress['processed_models']), - 'refreshed_models': list(download_progress['refreshed_models']), - 'failed_models': list(download_progress['failed_models']), - 'completed': download_progress['completed'], - 'total': download_progress['total'], + 'processed_models': list(self._progress['processed_models']), + 'refreshed_models': list(self._progress['refreshed_models']), + 'failed_models': list(self._progress['failed_models']), + 'completed': self._progress['completed'], + 'total': self._progress['total'], 'last_update': time.time() } @@ -430,70 +411,46 @@ class DownloadManager: except Exception as e: logger.error(f"Failed to save progress file: {e}") - @staticmethod - async def start_force_download(options: dict): - """ - Force download example images for specific models - - Expects a JSON body with: - { - "model_hashes": ["hash1", "hash2", ...], # List of model hashes to download - "optimize": true, # Whether to optimize images (default: true) - "model_types": ["lora", "checkpoint"], # Model types to process (default: both) - "delay": 1.0 # Delay between downloads (default: 1.0) - } - """ - global download_task, is_downloading, download_progress + async def start_force_download(self, options: dict): + """Force download example images for specific models.""" - if is_downloading: - raise DownloadInProgressError(_serialize_progress()) + if self._is_downloading: + raise DownloadInProgressError(self._progress.snapshot()) try: data = options or {} model_hashes = data.get('model_hashes', []) optimize = data.get('optimize', True) model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds + delay = float(data.get('delay', 0.2)) if not model_hashes: raise DownloadConfigurationError('Missing model_hashes parameter') - # Get output directory from settings output_dir = settings.get('example_images_path') if not output_dir: raise DownloadConfigurationError('Example images path not configured in settings') - - # Create the output directory + os.makedirs(output_dir, exist_ok=True) - - # Initialize progress tracking - download_progress['total'] = len(model_hashes) - download_progress['completed'] = 0 - download_progress['current_model'] = '' - download_progress['status'] = 'running' - download_progress['errors'] = [] - download_progress['last_error'] = None - download_progress['start_time'] = time.time() - download_progress['end_time'] = None - download_progress['processed_models'] = set() - download_progress['refreshed_models'] = set() - download_progress['failed_models'] = set() - # Set download status to downloading - is_downloading = True + self._progress.reset() + self._progress['total'] = len(model_hashes) + self._progress['status'] = 'running' + self._progress['start_time'] = time.time() + self._progress['end_time'] = None - # Execute the download function directly instead of creating a background task - result = await DownloadManager._download_specific_models_example_images_sync( + self._is_downloading = True + + result = await self._download_specific_models_example_images_sync( model_hashes, - output_dir, - optimize, + output_dir, + optimize, model_types, delay ) - # Set download status to not downloading - is_downloading = False + self._is_downloading = False return { 'success': True, @@ -502,17 +459,13 @@ class DownloadManager: } except Exception as e: - # Set download status to not downloading - is_downloading = False + self._is_downloading = False logger.error(f"Failed during forced example images download: {e}", exc_info=True) raise ExampleImagesDownloadError(str(e)) from e - @staticmethod - async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay): - """Download example images for specific models only - synchronous version""" - global download_progress - - # Get unified downloader + async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay): + """Download example images for specific models only - synchronous version.""" + downloader = await get_downloader() try: @@ -540,14 +493,14 @@ class DownloadManager: models_to_process.append((scanner_type, model, scanner)) # Update total count based on found models - download_progress['total'] = len(models_to_process) - logger.debug(f"Found {download_progress['total']} models to process") + self._progress['total'] = len(models_to_process) + logger.debug(f"Found {self._progress['total']} models to process") # Send initial progress via WebSocket await ws_manager.broadcast({ 'type': 'example_images_progress', 'processed': 0, - 'total': download_progress['total'], + 'total': self._progress['total'], 'status': 'running', 'current_model': '' }) @@ -556,8 +509,8 @@ class DownloadManager: success_count = 0 for i, (scanner_type, model, scanner) in enumerate(models_to_process): # Force process this model regardless of previous status - was_successful = await DownloadManager._process_specific_model( - scanner_type, model, scanner, + was_successful = await self._process_specific_model( + scanner_type, model, scanner, output_dir, optimize, downloader ) @@ -565,55 +518,55 @@ class DownloadManager: success_count += 1 # Update progress - download_progress['completed'] += 1 + self._progress['completed'] += 1 # Send progress update via WebSocket await ws_manager.broadcast({ 'type': 'example_images_progress', - 'processed': download_progress['completed'], - 'total': download_progress['total'], + 'processed': self._progress['completed'], + 'total': self._progress['total'], 'status': 'running', - 'current_model': download_progress['current_model'] + 'current_model': self._progress['current_model'] }) # Only add delay after remote download, and not after processing the last model - if was_successful and i < len(models_to_process) - 1 and download_progress['status'] == 'running': + if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running': await asyncio.sleep(delay) # Mark as completed - download_progress['status'] = 'completed' - download_progress['end_time'] = time.time() - logger.debug(f"Forced example images download completed: {download_progress['completed']}/{download_progress['total']} models processed") + self._progress['status'] = 'completed' + self._progress['end_time'] = time.time() + logger.debug(f"Forced example images download completed: {self._progress['completed']}/{self._progress['total']} models processed") # Send final progress via WebSocket await ws_manager.broadcast({ 'type': 'example_images_progress', - 'processed': download_progress['completed'], - 'total': download_progress['total'], + 'processed': self._progress['completed'], + 'total': self._progress['total'], 'status': 'completed', 'current_model': '' }) return { - 'total': download_progress['total'], - 'processed': download_progress['completed'], + 'total': self._progress['total'], + 'processed': self._progress['completed'], 'successful': success_count, - 'errors': download_progress['errors'] + 'errors': self._progress['errors'] } except Exception as e: error_msg = f"Error during forced example images download: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg - download_progress['status'] = 'error' - download_progress['end_time'] = time.time() + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg + self._progress['status'] = 'error' + self._progress['end_time'] = time.time() # Send error status via WebSocket await ws_manager.broadcast({ 'type': 'example_images_progress', - 'processed': download_progress['completed'], - 'total': download_progress['total'], + 'processed': self._progress['completed'], + 'total': self._progress['total'], 'status': 'error', 'error': error_msg, 'current_model': '' @@ -625,18 +578,16 @@ class DownloadManager: # No need to close any sessions since we use the global downloader pass - @staticmethod - async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, downloader): - """Process a specific model for forced download, ignoring previous download status""" - global download_progress - + async def _process_specific_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): + """Process a specific model for forced download, ignoring previous download status.""" + # Check if download is paused - while download_progress['status'] == 'paused': + while self._progress['status'] == 'paused': await asyncio.sleep(1) # Check if download should continue - if download_progress['status'] != 'running': - logger.info(f"Download stopped: {download_progress['status']}") + if self._progress['status'] != 'running': + logger.info(f"Download stopped: {self._progress['status']}") return False model_hash = model.get('sha256', '').lower() @@ -646,7 +597,7 @@ class DownloadManager: try: # Update current model info - download_progress['current_model'] = f"{model_name} ({model_hash[:8]})" + self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" # Create model directory model_dir = os.path.join(output_dir, model_hash) @@ -662,7 +613,7 @@ class DownloadManager: await MetadataUpdater.update_metadata_from_local_examples( model_hash, model, scanner_type, scanner, model_dir ) - download_progress['processed_models'].add(model_hash) + self._progress['processed_models'].add(model_hash) return False # Return False to indicate no remote download happened # If no local images, try to download from remote @@ -674,9 +625,9 @@ class DownloadManager: ) # If metadata is stale, try to refresh it - if is_stale and model_hash not in download_progress['refreshed_models']: + if is_stale and model_hash not in self._progress['refreshed_models']: await MetadataUpdater.refresh_model_metadata( - model_hash, model_name, scanner_type, scanner + model_hash, model_name, scanner_type, scanner, self._progress ) # Get the updated model data @@ -694,18 +645,18 @@ class DownloadManager: # Combine failed images from both attempts failed_images.extend(additional_failed_images) - download_progress['refreshed_models'].add(model_hash) + self._progress['refreshed_models'].add(model_hash) # For forced downloads, remove failed images from metadata if failed_images: # Create a copy of images excluding failed ones - await DownloadManager._remove_failed_images_from_metadata( + await self._remove_failed_images_from_metadata( model_hash, model_name, failed_images, scanner ) # Mark as processed if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones - download_progress['processed_models'].add(model_hash) + self._progress['processed_models'].add(model_hash) return True # Return True to indicate a remote download happened else: @@ -715,12 +666,11 @@ class DownloadManager: except Exception as e: error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg return False # Return False on exception - @staticmethod - async def _remove_failed_images_from_metadata(model_hash, model_name, failed_images, scanner): + async def _remove_failed_images_from_metadata(self, model_hash, model_name, failed_images, scanner): """Remove failed images from model metadata""" try: # Get current model data @@ -762,4 +712,13 @@ class DownloadManager: await scanner.update_single_model_cache(file_path, file_path, model_data) except Exception as e: - logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True) \ No newline at end of file + logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True) + + +default_download_manager = DownloadManager() + + +def get_default_download_manager() -> DownloadManager: + """Return the singleton download manager used by default routes.""" + + return default_download_manager diff --git a/py/utils/example_images_metadata.py b/py/utils/example_images_metadata.py index 8820b49b..780eb43b 100644 --- a/py/utils/example_images_metadata.py +++ b/py/utils/example_images_metadata.py @@ -33,7 +33,7 @@ class MetadataUpdater: """Handles updating model metadata related to example images""" @staticmethod - async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner): + async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner, progress: dict | None = None): """Refresh model metadata from CivitAI Args: @@ -45,8 +45,6 @@ class MetadataUpdater: Returns: bool: True if metadata was successfully refreshed, False otherwise """ - from ..utils.example_images_download_manager import download_progress - try: # Find the model in the scanner cache cache = await scanner.get_cached_data() @@ -67,7 +65,8 @@ class MetadataUpdater: return False # Track that we're refreshing this model - download_progress['refreshed_models'].add(model_hash) + if progress is not None: + progress['refreshed_models'].add(model_hash) async def update_cache_func(old_path, new_path, metadata): return await scanner.update_single_model_cache(old_path, new_path, metadata) @@ -85,12 +84,13 @@ class MetadataUpdater: else: logger.warning(f"Failed to refresh metadata for {model_name}, {error}") return False - + except Exception as e: error_msg = f"Error refreshing metadata for {model_name}: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg + if progress is not None: + progress['errors'].append(error_msg) + progress['last_error'] = error_msg return False @staticmethod