From 3cd57a582c4cd228af939d8404e24779b88a40da Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 15 Aug 2025 15:16:12 +0800 Subject: [PATCH] feat: add force download functionality for example images with progress tracking --- py/routes/example_images_routes.py | 9 +- py/utils/example_images_download_manager.py | 364 +++++++++++++++++++- py/utils/example_images_processor.py | 72 ++++ static/js/api/apiConfig.js | 3 +- static/js/api/baseModelApi.js | 100 +++++- static/js/components/shared/ModelCard.js | 39 ++- 6 files changed, 572 insertions(+), 15 deletions(-) diff --git a/py/routes/example_images_routes.py b/py/routes/example_images_routes.py index c26e2db3..9f20b470 100644 --- a/py/routes/example_images_routes.py +++ b/py/routes/example_images_routes.py @@ -2,6 +2,7 @@ import logging from ..utils.example_images_download_manager import DownloadManager from ..utils.example_images_processor import ExampleImagesProcessor from ..utils.example_images_file_manager import ExampleImagesFileManager +from ..services.websocket_manager import ws_manager logger = logging.getLogger(__name__) @@ -20,6 +21,7 @@ class ExampleImagesRoutes: app.router.add_get('/api/example-image-files', ExampleImagesRoutes.get_example_image_files) app.router.add_get('/api/has-example-images', ExampleImagesRoutes.has_example_images) app.router.add_post('/api/delete-example-image', ExampleImagesRoutes.delete_example_image) + app.router.add_post('/api/force-download-example-images', ExampleImagesRoutes.force_download_example_images) @staticmethod async def download_example_images(request): @@ -64,4 +66,9 @@ class ExampleImagesRoutes: @staticmethod async def delete_example_image(request): """Delete a custom example image for a model""" - return await ExampleImagesProcessor.delete_custom_image(request) \ No newline at end of file + return await ExampleImagesProcessor.delete_custom_image(request) + + @staticmethod + async def force_download_example_images(request): + """Force download example images for specific models""" + return await DownloadManager.start_force_download(request) \ No newline at end of file diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index 545d4662..db1b93f0 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -6,8 +6,10 @@ import time import aiohttp from aiohttp import web from ..services.service_registry import ServiceRegistry +from ..utils.metadata_manager import MetadataManager from .example_images_processor import ExampleImagesProcessor from .example_images_metadata import MetadataUpdater +from ..services.websocket_manager import ws_manager # Add this import at the top logger = logging.getLogger(__name__) @@ -431,4 +433,364 @@ class DownloadManager: with open(progress_file, 'w', encoding='utf-8') as f: json.dump(progress_data, f, indent=2) except Exception as e: - logger.error(f"Failed to save progress file: {e}") \ No newline at end of file + logger.error(f"Failed to save progress file: {e}") + + @staticmethod + async def start_force_download(request): + """ + Force download example images for specific models + + Expects a JSON body with: + { + "model_hashes": ["hash1", "hash2", ...], # List of model hashes to download + "output_dir": "path/to/output", # Base directory to save example images + "optimize": true, # Whether to optimize images (default: true) + "model_types": ["lora", "checkpoint"], # Model types to process (default: both) + "delay": 1.0 # Delay between downloads (default: 1.0) + } + """ + global download_task, is_downloading, download_progress + + if is_downloading: + return web.json_response({ + 'success': False, + 'error': 'Download already in progress' + }, status=400) + + try: + # Parse the request body + data = await request.json() + model_hashes = data.get('model_hashes', []) + output_dir = data.get('output_dir') + optimize = data.get('optimize', True) + model_types = data.get('model_types', ['lora', 'checkpoint']) + delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds + + if not model_hashes: + return web.json_response({ + 'success': False, + 'error': 'Missing model_hashes parameter' + }, status=400) + + if not output_dir: + return web.json_response({ + 'success': False, + 'error': 'Missing output_dir parameter' + }, status=400) + + # Create the output directory + os.makedirs(output_dir, exist_ok=True) + + # Initialize progress tracking + download_progress['total'] = len(model_hashes) + download_progress['completed'] = 0 + download_progress['current_model'] = '' + download_progress['status'] = 'running' + download_progress['errors'] = [] + download_progress['last_error'] = None + download_progress['start_time'] = time.time() + download_progress['end_time'] = None + download_progress['processed_models'] = set() + download_progress['refreshed_models'] = set() + download_progress['failed_models'] = set() + + # Set download status to downloading + is_downloading = True + + # Execute the download function directly instead of creating a background task + result = await DownloadManager._download_specific_models_example_images_sync( + model_hashes, + output_dir, + optimize, + model_types, + delay + ) + + # Set download status to not downloading + is_downloading = False + + return web.json_response({ + 'success': True, + 'message': 'Force download completed', + 'result': result + }) + + except Exception as e: + # Set download status to not downloading + is_downloading = False + logger.error(f"Failed during forced example images download: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + @staticmethod + async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay): + """Download example images for specific models only - synchronous version""" + global download_progress + + # Create independent download session + connector = aiohttp.TCPConnector( + ssl=True, + limit=3, + force_close=False, + enable_cleanup_closed=True + ) + timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60) + independent_session = aiohttp.ClientSession( + connector=connector, + trust_env=True, + timeout=timeout + ) + + try: + # Get scanners + scanners = [] + if 'lora' in model_types: + lora_scanner = await ServiceRegistry.get_lora_scanner() + scanners.append(('lora', lora_scanner)) + + if 'checkpoint' in model_types: + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() + scanners.append(('checkpoint', checkpoint_scanner)) + + if 'embedding' in model_types: + embedding_scanner = await ServiceRegistry.get_embedding_scanner() + scanners.append(('embedding', embedding_scanner)) + + # Find the specified models + models_to_process = [] + for scanner_type, scanner in scanners: + cache = await scanner.get_cached_data() + if cache and cache.raw_data: + for model in cache.raw_data: + if model.get('sha256') in model_hashes: + models_to_process.append((scanner_type, model, scanner)) + + # Update total count based on found models + download_progress['total'] = len(models_to_process) + logger.debug(f"Found {download_progress['total']} models to process") + + # Send initial progress via WebSocket + await ws_manager.broadcast({ + 'type': 'example_images_progress', + 'processed': 0, + 'total': download_progress['total'], + 'status': 'running', + 'current_model': '' + }) + + # Process each model + success_count = 0 + for i, (scanner_type, model, scanner) in enumerate(models_to_process): + # Force process this model regardless of previous status + was_successful = await DownloadManager._process_specific_model( + scanner_type, model, scanner, + output_dir, optimize, independent_session + ) + + if was_successful: + success_count += 1 + + # Update progress + download_progress['completed'] += 1 + + # Send progress update via WebSocket + await ws_manager.broadcast({ + 'type': 'example_images_progress', + 'processed': download_progress['completed'], + 'total': download_progress['total'], + 'status': 'running', + 'current_model': download_progress['current_model'] + }) + + # Only add delay after remote download, and not after processing the last model + if was_successful and i < len(models_to_process) - 1 and download_progress['status'] == 'running': + await asyncio.sleep(delay) + + # Mark as completed + download_progress['status'] = 'completed' + download_progress['end_time'] = time.time() + logger.debug(f"Forced example images download completed: {download_progress['completed']}/{download_progress['total']} models processed") + + # Send final progress via WebSocket + await ws_manager.broadcast({ + 'type': 'example_images_progress', + 'processed': download_progress['completed'], + 'total': download_progress['total'], + 'status': 'completed', + 'current_model': '' + }) + + return { + 'total': download_progress['total'], + 'processed': download_progress['completed'], + 'successful': success_count, + 'errors': download_progress['errors'] + } + + except Exception as e: + error_msg = f"Error during forced example images download: {str(e)}" + logger.error(error_msg, exc_info=True) + download_progress['errors'].append(error_msg) + download_progress['last_error'] = error_msg + download_progress['status'] = 'error' + download_progress['end_time'] = time.time() + + # Send error status via WebSocket + await ws_manager.broadcast({ + 'type': 'example_images_progress', + 'processed': download_progress['completed'], + 'total': download_progress['total'], + 'status': 'error', + 'error': error_msg, + 'current_model': '' + }) + + raise + + finally: + # Close the independent session + try: + await independent_session.close() + except Exception as e: + logger.error(f"Error closing download session: {e}") + + @staticmethod + async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, independent_session): + """Process a specific model for forced download, ignoring previous download status""" + global download_progress + + # Check if download is paused + while download_progress['status'] == 'paused': + await asyncio.sleep(1) + + # Check if download should continue + if download_progress['status'] != 'running': + logger.info(f"Download stopped: {download_progress['status']}") + return False + + model_hash = model.get('sha256', '').lower() + model_name = model.get('model_name', 'Unknown') + model_file_path = model.get('file_path', '') + model_file_name = model.get('file_name', '') + + try: + # Update current model info + download_progress['current_model'] = f"{model_name} ({model_hash[:8]})" + + # Create model directory + model_dir = os.path.join(output_dir, model_hash) + os.makedirs(model_dir, exist_ok=True) + + # First check for local example images - local processing doesn't need delay + local_images_processed = await ExampleImagesProcessor.process_local_examples( + model_file_path, model_file_name, model_name, model_dir, optimize + ) + + # If we processed local images, update metadata + if local_images_processed: + await MetadataUpdater.update_metadata_from_local_examples( + model_hash, model, scanner_type, scanner, model_dir + ) + download_progress['processed_models'].add(model_hash) + return False # Return False to indicate no remote download happened + + # If no local images, try to download from remote + elif model.get('civitai') and model.get('civitai', {}).get('images'): + images = model.get('civitai', {}).get('images', []) + + success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( + model_hash, model_name, images, model_dir, optimize, independent_session + ) + + # If metadata is stale, try to refresh it + if is_stale and model_hash not in download_progress['refreshed_models']: + await MetadataUpdater.refresh_model_metadata( + model_hash, model_name, scanner_type, scanner + ) + + # Get the updated model data + updated_model = await MetadataUpdater.get_updated_model( + model_hash, scanner + ) + + if updated_model and updated_model.get('civitai', {}).get('images'): + # Retry download with updated metadata + updated_images = updated_model.get('civitai', {}).get('images', []) + success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( + model_hash, model_name, updated_images, model_dir, optimize, independent_session + ) + + # Combine failed images from both attempts + failed_images.extend(additional_failed_images) + + download_progress['refreshed_models'].add(model_hash) + + # For forced downloads, remove failed images from metadata + if failed_images: + # Create a copy of images excluding failed ones + await DownloadManager._remove_failed_images_from_metadata( + model_hash, model_name, failed_images, scanner + ) + + # Mark as processed + if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones + download_progress['processed_models'].add(model_hash) + + return True # Return True to indicate a remote download happened + else: + logger.debug(f"No civitai images available for model {model_name}") + return False + + except Exception as e: + error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" + logger.error(error_msg, exc_info=True) + download_progress['errors'].append(error_msg) + download_progress['last_error'] = error_msg + return False # Return False on exception + + @staticmethod + async def _remove_failed_images_from_metadata(model_hash, model_name, failed_images, scanner): + """Remove failed images from model metadata""" + try: + # Get current model data + model_data = await MetadataUpdater.get_updated_model(model_hash, scanner) + if not model_data: + logger.warning(f"Could not find model data for {model_name} to remove failed images") + return + + if not model_data.get('civitai', {}).get('images'): + logger.warning(f"No images in metadata for {model_name}") + return + + # Get current images + current_images = model_data['civitai']['images'] + + # Filter out failed images + updated_images = [img for img in current_images if img.get('url') not in failed_images] + + # If images were removed, update metadata + if len(updated_images) < len(current_images): + removed_count = len(current_images) - len(updated_images) + logger.info(f"Removing {removed_count} failed images from metadata for {model_name}") + + # Update the images list + model_data['civitai']['images'] = updated_images + + # Save metadata to file + file_path = model_data.get('file_path') + if file_path: + # Create a copy of model data without 'folder' field + model_copy = model_data.copy() + model_copy.pop('folder', None) + + # Write metadata to file + await MetadataManager.save_metadata(file_path, model_copy) + logger.info(f"Saved updated metadata for {model_name} after removing failed images") + + # Update the scanner cache + await scanner.update_single_model_cache(file_path, file_path, model_data) + + except Exception as e: + logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True) \ No newline at end of file diff --git a/py/utils/example_images_processor.py b/py/utils/example_images_processor.py index cf4852ab..6d14e621 100644 --- a/py/utils/example_images_processor.py +++ b/py/utils/example_images_processor.py @@ -102,6 +102,78 @@ class ExampleImagesProcessor: return model_success, False # (success, is_metadata_stale) + @staticmethod + async def download_model_images_with_tracking(model_hash, model_name, model_images, model_dir, optimize, independent_session): + """Download images for a single model with tracking of failed image URLs + + Returns: + tuple: (success, is_stale_metadata, failed_images) - whether download was successful, whether metadata is stale, list of failed image URLs + """ + model_success = True + failed_images = [] + + for i, image in enumerate(model_images): + image_url = image.get('url') + if not image_url: + continue + + # Get image filename from URL + image_filename = os.path.basename(image_url.split('?')[0]) + image_ext = os.path.splitext(image_filename)[1].lower() + + # Handle images and videos + is_image = image_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] + is_video = image_ext in SUPPORTED_MEDIA_EXTENSIONS['videos'] + + if not (is_image or is_video): + logger.debug(f"Skipping unsupported file type: {image_filename}") + continue + + # Use 0-based indexing instead of 1-based indexing + save_filename = f"image_{i}{image_ext}" + + # If optimizing images and this is a Civitai image, use their pre-optimized WebP version + if is_image and optimize and 'civitai.com' in image_url: + image_url = ExampleImagesProcessor.get_civitai_optimized_url(image_url) + save_filename = f"image_{i}.webp" + + # Check if already downloaded + save_path = os.path.join(model_dir, save_filename) + if os.path.exists(save_path): + logger.debug(f"File already exists: {save_path}") + continue + + # Download the file + try: + logger.debug(f"Downloading {save_filename} for {model_name}") + + # Download directly using the independent session + async with independent_session.get(image_url, timeout=60) as response: + if response.status == 200: + with open(save_path, 'wb') as f: + async for chunk in response.content.iter_chunked(8192): + if chunk: + f.write(chunk) + elif response.status == 404: + error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale" + logger.warning(error_msg) + model_success = False # Mark the model as failed due to 404 error + failed_images.append(image_url) # Track failed URL + # Return early to trigger metadata refresh attempt + return False, True, failed_images # (success, is_metadata_stale, failed_images) + else: + error_msg = f"Failed to download file: {image_url}, status code: {response.status}" + logger.warning(error_msg) + model_success = False # Mark the model as failed + failed_images.append(image_url) # Track failed URL + except Exception as e: + error_msg = f"Error downloading file {image_url}: {str(e)}" + logger.error(error_msg) + model_success = False # Mark the model as failed + failed_images.append(image_url) # Track failed URL + + return model_success, False, failed_images # (success, is_metadata_stale, failed_images) + @staticmethod async def process_local_examples(model_file_path, model_file_name, model_name, model_dir, optimize): """Process local example images diff --git a/static/js/api/apiConfig.js b/static/js/api/apiConfig.js index bd8c9104..b0fcc30a 100644 --- a/static/js/api/apiConfig.js +++ b/static/js/api/apiConfig.js @@ -165,7 +165,8 @@ export const DOWNLOAD_ENDPOINTS = { download: '/api/download-model', downloadGet: '/api/download-model-get', cancelGet: '/api/cancel-download-get', - progress: '/api/download-progress' + progress: '/api/download-progress', + exampleImages: '/api/force-download-example-images' // New endpoint for downloading example images }; // WebSocket endpoints diff --git a/static/js/api/baseModelApi.js b/static/js/api/baseModelApi.js index 465975de..03475ce8 100644 --- a/static/js/api/baseModelApi.js +++ b/static/js/api/baseModelApi.js @@ -1,6 +1,6 @@ import { state, getCurrentPageState } from '../state/index.js'; import { showToast, updateFolderTags } from '../utils/uiHelpers.js'; -import { getSessionItem, saveMapToStorage } from '../utils/storageHelpers.js'; +import { getStorageItem, getSessionItem, saveMapToStorage } from '../utils/storageHelpers.js'; import { getCompleteApiConfig, getCurrentModelType, @@ -855,4 +855,102 @@ export class BaseModelApiClient { state.loadingManager.hide(); } } + + async downloadExampleImages(modelHashes, modelTypes) { + let ws = null; + + await state.loadingManager.showWithProgress(async (loading) => { + try { + // Connect to WebSocket for progress updates + const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; + ws = new WebSocket(`${wsProtocol}${window.location.host}${WS_ENDPOINTS.fetchProgress}`); + + const operationComplete = new Promise((resolve, reject) => { + ws.onmessage = (event) => { + const data = JSON.parse(event.data); + + if (data.type !== 'example_images_progress') return; + + switch(data.status) { + case 'running': + const percent = ((data.processed / data.total) * 100).toFixed(1); + loading.setProgress(percent); + loading.setStatus( + `Processing (${data.processed}/${data.total}) ${data.current_model || ''}` + ); + break; + + case 'completed': + loading.setProgress(100); + loading.setStatus( + `Completed: Downloaded example images for ${data.processed} models` + ); + resolve(); + break; + + case 'error': + reject(new Error(data.error)); + break; + } + }; + + ws.onerror = (error) => { + reject(new Error('WebSocket error: ' + error.message)); + }; + }); + + // Wait for WebSocket connection to establish + await new Promise((resolve, reject) => { + ws.onopen = resolve; + ws.onerror = reject; + }); + + // Get the output directory from storage + const outputDir = getStorageItem('example_images_path', ''); + if (!outputDir) { + throw new Error('Please set the example images path in the settings first.'); + } + + // Determine optimize setting + const optimize = state.global?.settings?.optimizeExampleImages ?? true; + + // Make the API request to start the download process + const response = await fetch(DOWNLOAD_ENDPOINTS.exampleImages, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + model_hashes: modelHashes, + output_dir: outputDir, + optimize: optimize, + model_types: modelTypes || [this.apiConfig.config.singularName] + }) + }); + + if (!response.ok) { + const errorData = await response.json().catch(() => ({})); + throw new Error(errorData.error || 'Failed to download example images'); + } + + // Wait for the operation to complete via WebSocket + await operationComplete; + + showToast('Successfully downloaded example images!', 'success'); + return true; + + } catch (error) { + console.error('Error downloading example images:', error); + showToast(`Failed to download example images: ${error.message}`, 'error'); + throw error; + } finally { + if (ws) { + ws.close(); + } + } + }, { + initialMessage: 'Starting example images download...', + completionMessage: 'Example images download complete' + }); + } } \ No newline at end of file diff --git a/static/js/components/shared/ModelCard.js b/static/js/components/shared/ModelCard.js index b0865830..49ae8952 100644 --- a/static/js/components/shared/ModelCard.js +++ b/static/js/components/shared/ModelCard.js @@ -273,18 +273,35 @@ function showExampleAccessModal(card, modelType) { if (hasRemoteExamples) { downloadBtn.classList.remove('disabled'); downloadBtn.removeAttribute('title'); - downloadBtn.onclick = () => { + downloadBtn.onclick = async () => { + // Get the model hash + const modelHash = card.dataset.sha256; + if (!modelHash) { + showToast('Missing model hash information.', 'error'); + return; + } + + // Determine model type (singular form) + let modelTypeSingular = 'lora'; + if (modelType === MODEL_TYPES.CHECKPOINT) { + modelTypeSingular = 'checkpoint'; + } else if (modelType === MODEL_TYPES.EMBEDDING) { + modelTypeSingular = 'embedding'; + } + + // Close the modal modalManager.closeModal('exampleAccessModal'); - // Open settings modal and scroll to example images section - const settingsModal = document.getElementById('settingsModal'); - if (settingsModal) { - modalManager.showModal('settingsModal'); - setTimeout(() => { - const exampleSection = settingsModal.querySelector('.settings-section:nth-child(7)'); - if (exampleSection) { - exampleSection.scrollIntoView({ behavior: 'smooth' }); - } - }, 300); + + try { + // Use the appropriate model API client to download examples + const apiClient = getModelApiClient(modelType); + await apiClient.downloadExampleImages([modelHash], [modelTypeSingular]); + + // Open the example images folder if successful + openExampleImagesFolder(modelHash); + } catch (error) { + console.error('Error downloading example images:', error); + // Error already shown by the API client } }; } else {