From f1b928a037cc45721b7d50377a5b45b0c1c350c4 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sat, 7 Jun 2025 09:34:07 +0800 Subject: [PATCH] Add migration functionality for example images: implement API endpoint and UI controls --- py/routes/misc_routes.py | 436 +++++++++++++++++++++ static/css/components/modal.css | 12 + static/js/managers/ExampleImagesManager.js | 103 ++++- templates/components/modals.html | 21 + 4 files changed, 569 insertions(+), 3 deletions(-) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 76f3a44d..43b2ddfa 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -52,6 +52,7 @@ class MiscRoutes: # Example images download routes app.router.add_post('/api/download-example-images', MiscRoutes.download_example_images) + app.router.add_post('/api/migrate-example-images', MiscRoutes.migrate_example_images) app.router.add_get('/api/example-images-status', MiscRoutes.get_example_images_status) app.router.add_post('/api/pause-example-images', MiscRoutes.pause_example_images) app.router.add_post('/api/resume-example-images', MiscRoutes.resume_example_images) @@ -832,6 +833,441 @@ class MiscRoutes: # Set download status to not downloading is_downloading = False + + @staticmethod + async def migrate_example_images(request): + """ + Migrate existing example images to central storage location + + Expects a JSON body with: + { + "output_dir": "path/to/output", # Base directory to save example images + "pattern": "{model}.example.{index}.{ext}", # Pattern to match example images + "optimize": true, # Whether to optimize images (default: true) + "model_types": ["lora", "checkpoint"], # Model types to process (default: both) + } + """ + global download_task, is_downloading, download_progress + + if is_downloading: + # Create a copy for JSON serialization + response_progress = download_progress.copy() + response_progress['processed_models'] = list(download_progress['processed_models']) + response_progress['refreshed_models'] = list(download_progress['refreshed_models']) + + return web.json_response({ + 'success': False, + 'error': 'Download or migration already in progress', + 'status': response_progress + }, status=400) + + try: + # Parse the request body + data = await request.json() + output_dir = data.get('output_dir') + pattern = data.get('pattern', '{model}.example.{index}.{ext}') + optimize = data.get('optimize', True) + model_types = data.get('model_types', ['lora', 'checkpoint']) + + if not output_dir: + return web.json_response({ + 'success': False, + 'error': 'Missing output_dir parameter' + }, status=400) + + # Create the output directory + os.makedirs(output_dir, exist_ok=True) + + # Initialize progress tracking + download_progress['total'] = 0 + download_progress['completed'] = 0 + download_progress['current_model'] = '' + download_progress['status'] = 'running' + download_progress['errors'] = [] + download_progress['last_error'] = None + download_progress['start_time'] = time.time() + download_progress['end_time'] = None + download_progress['is_migrating'] = True # Mark this as a migration task + + # Get the processed models list from a file if it exists + progress_file = os.path.join(output_dir, '.download_progress.json') + if os.path.exists(progress_file): + try: + with open(progress_file, 'r', encoding='utf-8') as f: + saved_progress = json.load(f) + download_progress['processed_models'] = set(saved_progress.get('processed_models', [])) + logger.info(f"Loaded previous progress, {len(download_progress['processed_models'])} models already processed") + except Exception as e: + logger.error(f"Failed to load progress file: {e}") + download_progress['processed_models'] = set() + else: + download_progress['processed_models'] = set() + + # Start the migration task + is_downloading = True + download_task = asyncio.create_task( + MiscRoutes._migrate_all_example_images( + output_dir, + pattern, + optimize, + model_types + ) + ) + + # Create a copy for JSON serialization + response_progress = download_progress.copy() + response_progress['processed_models'] = list(download_progress['processed_models']) + response_progress['refreshed_models'] = list(download_progress['refreshed_models']) + response_progress['is_migrating'] = True + + return web.json_response({ + 'success': True, + 'message': 'Migration started', + 'status': response_progress + }) + + except Exception as e: + logger.error(f"Failed to start example images migration: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + @staticmethod + async def _migrate_all_example_images(output_dir, pattern, optimize, model_types): + """Migrate example images for all models based on pattern + + Args: + output_dir: Base directory to save example images + pattern: Pattern to match example images + optimize: Whether to optimize images + model_types: List of model types to process + """ + global is_downloading, download_progress + + try: + # Get the scanners + scanners = [] + if 'lora' in model_types: + lora_scanner = await ServiceRegistry.get_lora_scanner() + scanners.append(('lora', lora_scanner)) + + if 'checkpoint' in model_types: + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() + scanners.append(('checkpoint', checkpoint_scanner)) + + # Convert user pattern to regex + regex_pattern = MiscRoutes._convert_pattern_to_regex(pattern) + logger.info(f"Using pattern regex: {regex_pattern.pattern}") + + # Get all models from all scanners + all_models = [] + for scanner_type, scanner in scanners: + cache = await scanner.get_cached_data() + if cache and cache.raw_data: + for model in cache.raw_data: + # Only process models with a valid file path and sha256 + if model.get('file_path') and model.get('sha256'): + all_models.append((scanner_type, model, scanner)) + + # Update total count + download_progress['total'] = len(all_models) + logger.info(f"Found {download_progress['total']} models to check for example images") + + # Process each model + for scanner_type, model, scanner in all_models: + # Check if download is paused + while download_progress['status'] == 'paused': + await asyncio.sleep(1) + + # Check if download should continue + if download_progress['status'] != 'running': + logger.info(f"Migration stopped: {download_progress['status']}") + break + + model_hash = model.get('sha256', '').lower() + model_name = model.get('model_name', 'Unknown') + model_file_path = model.get('file_path', '') + model_file_name = os.path.basename(model_file_path) if model_file_path else '' + model_dir_path = os.path.dirname(model_file_path) if model_file_path else '' + + try: + # Update current model info + download_progress['current_model'] = f"{model_name} ({model_hash[:8]})" + + # Skip if already processed + if model_hash in download_progress['processed_models']: + logger.debug(f"Skipping already processed model: {model_name}") + download_progress['completed'] += 1 + continue + + # Find matching example files based on pattern + if model_file_name and os.path.exists(model_dir_path): + example_files = MiscRoutes._find_matching_example_files( + model_dir_path, + model_file_name, + regex_pattern + ) + + # Process found files + if example_files: + logger.info(f"Found {len(example_files)} example images for {model_name}") + + # Create model directory in output location + model_dir = os.path.join(output_dir, model_hash) + os.makedirs(model_dir, exist_ok=True) + + # Migrate each example file + for local_image_path, index in example_files: + # Get file extension + local_ext = os.path.splitext(local_image_path)[1].lower() + save_filename = f"image_{index}{local_ext}" + save_path = os.path.join(model_dir, save_filename) + + # Skip if already exists in output directory + if os.path.exists(save_path): + logger.debug(f"File already exists in output: {save_path}") + continue + + try: + # Copy the file + with open(local_image_path, 'rb') as src_file: + with open(save_path, 'wb') as dst_file: + dst_file.write(src_file.read()) + logger.debug(f"Migrated {os.path.basename(local_image_path)} to {save_path}") + except Exception as e: + error_msg = f"Failed to copy file {os.path.basename(local_image_path)}: {str(e)}" + logger.error(error_msg) + download_progress['errors'].append(error_msg) + download_progress['last_error'] = error_msg + + # Mark this model as processed + download_progress['processed_models'].add(model_hash) + + # Save progress to file periodically + if download_progress['completed'] % 10 == 0 or download_progress['completed'] == download_progress['total'] - 1: + progress_file = os.path.join(output_dir, '.download_progress.json') + with open(progress_file, 'w', encoding='utf-8') as f: + json.dump({ + 'processed_models': list(download_progress['processed_models']), + 'refreshed_models': list(download_progress['refreshed_models']), + 'completed': download_progress['completed'], + 'total': download_progress['total'], + 'last_update': time.time() + }, f, indent=2) + + except Exception as e: + error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" + logger.error(error_msg, exc_info=True) + download_progress['errors'].append(error_msg) + download_progress['last_error'] = error_msg + + # Update progress + download_progress['completed'] += 1 + + # Mark as completed + download_progress['status'] = 'completed' + download_progress['end_time'] = time.time() + download_progress['is_migrating'] = False + logger.info(f"Example images migration completed: {download_progress['completed']}/{download_progress['total']} models processed") + + except Exception as e: + error_msg = f"Error during example images migration: {str(e)}" + logger.error(error_msg, exc_info=True) + download_progress['errors'].append(error_msg) + download_progress['last_error'] = error_msg + download_progress['status'] = 'error' + download_progress['end_time'] = time.time() + download_progress['is_migrating'] = False + + finally: + # Save final progress to file + try: + progress_file = os.path.join(output_dir, '.download_progress.json') + with open(progress_file, 'w', encoding='utf-8') as f: + json.dump({ + 'processed_models': list(download_progress['processed_models']), + 'refreshed_models': list(download_progress['refreshed_models']), + 'completed': download_progress['completed'], + 'total': download_progress['total'], + 'last_update': time.time(), + 'status': download_progress['status'], + 'is_migrating': False + }, f, indent=2) + except Exception as e: + logger.error(f"Failed to save progress file: {e}") + + # Set download status to not downloading + is_downloading = False + + @staticmethod + def _convert_pattern_to_regex(pattern): + """Convert a user-friendly template pattern to a regex pattern + + Args: + pattern: Template pattern string + + Returns: + re.Pattern: Compiled regex pattern object + """ + # Normalize path separators to forward slashes for consistent matching + pattern = pattern.replace('\\', '/') + + # Escape special regex characters + regex_safe = re.escape(pattern) + + # Handle multiple occurrences of {model} + model_count = pattern.count('{model}') + if model_count > 1: + # Replace the first occurrence with a named capture group + regex_safe = regex_safe.replace(r'\{model\}', r'(?P.*?)', 1) + + # Replace subsequent occurrences with a back-reference + # Using (?P=model) for Python's regex named backreference syntax + for _ in range(model_count - 1): + regex_safe = regex_safe.replace(r'\{model\}', r'(?P=model)', 1) + else: + # Just one occurrence, handle normally + regex_safe = regex_safe.replace(r'\{model\}', r'(?P.*?)') + + # {index} becomes a capture group for digits + regex_safe = regex_safe.replace(r'\{index\}', r'(?P\d+)') + + # {ext} becomes a capture group for file extension WITHOUT including the dot + regex_safe = regex_safe.replace(r'\{ext\}', r'(?P\w+)') + + # Handle wildcard * character (which was escaped earlier) + regex_safe = regex_safe.replace(r'\*', r'.*?') + + logger.info(f"Converted pattern '{pattern}' to regex: '{regex_safe}'") + + # Compile the regex pattern + return re.compile(regex_safe) + + @staticmethod + def _find_matching_example_files(dir_path, model_filename, regex_pattern): + """Find example files matching the pattern in the given directory + + Args: + dir_path: Directory to search in + model_filename: Model filename (without extension) + regex_pattern: Compiled regex pattern to match against + + Returns: + list: List of tuples (file_path, index) of matching files + """ + matching_files = [] + model_name = os.path.splitext(model_filename)[0] + if model_name == "FluxMechaKnights": + logger.info(f"Processing model: {model_name}") + + # Check if pattern contains a directory separator + has_subdirs = '/' in regex_pattern.pattern or '\\\\' in regex_pattern.pattern + + # Determine search paths (keep existing logic for subdirectories) + if has_subdirs: + # Handle patterns with subdirectories + subdir_match = re.match(r'.*(?P.*?)(/|\\\\).*', regex_pattern.pattern) + if subdir_match: + potential_subdir = os.path.join(dir_path, model_name) + if os.path.exists(potential_subdir) and os.path.isdir(potential_subdir): + search_paths = [potential_subdir] + else: + search_paths = [dir_path] + else: + search_paths = [dir_path] + else: + search_paths = [dir_path] + + for search_path in search_paths: + if not os.path.exists(search_path): + continue + + # For optimized performance: create a model name prefix check + # This works for any pattern where the model name appears at the start + if not has_subdirs: + # Get list of all files first + all_files = os.listdir(search_path) + + # First pass: filter files that start with model name (case insensitive) + # This is much faster than regex for initial filtering + potential_matches = [] + lower_model_name = model_name.lower() + + for file in all_files: + # Quick check if file starts with model name + if file.lower().startswith(lower_model_name): + file_path = os.path.join(search_path, file) + if os.path.isfile(file_path): + potential_matches.append((file, file_path)) + + # Second pass: apply full regex only to potential matches + for file, file_path in potential_matches: + match = regex_pattern.match(file) + if match: + # Verify model name matches exactly what we're looking for + if match.group('model') != model_name: + logger.debug(f"File {file} matched pattern but model name {match.group('model')} doesn't match {model_name}") + continue + + # Check if file extension is supported + file_ext = os.path.splitext(file)[1].lower() + is_supported = (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or + file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']) + + if is_supported: + # Extract index from match + try: + index = int(match.group('index')) + except (IndexError, ValueError): + index = len(matching_files) + 1 + + matching_files.append((file_path, index)) + else: + # Original scanning logic for patterns with subdirectories + for file in os.listdir(search_path): + file_path = os.path.join(search_path, file) + if os.path.isfile(file_path): + # Try to match the filename directly first + match = regex_pattern.match(file) + + # If no match and subdirs are expected, try the relative path + if not match and has_subdirs: + # Get relative path and normalize slashes for consistent matching + rel_path = os.path.relpath(file_path, dir_path) + # Replace Windows backslashes with forward slashes for consistent regex matching + rel_path = rel_path.replace('\\', '/') + match = regex_pattern.match(rel_path) + + if match: + # For subdirectory patterns, model name in the match might refer to the dir name only + # so we need a different checking logic + matched_model = match.group('model') + if has_subdirs and '/' in rel_path: + # For subdirectory patterns, it's okay if just the folder name matches + folder_name = rel_path.split('/')[0] + if matched_model != model_name and matched_model != folder_name: + logger.debug(f"File {file} matched pattern but model name {matched_model} doesn't match {model_name}") + continue + elif matched_model != model_name: + logger.debug(f"File {file} matched pattern but model name {matched_model} doesn't match {model_name}") + continue + + file_ext = os.path.splitext(file)[1].lower() + is_supported = (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or + file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']) + + if is_supported: + try: + index = int(match.group('index')) + except (IndexError, ValueError): + index = len(matching_files) + 1 + + matching_files.append((file_path, index)) + + # Sort files by their index + matching_files.sort(key=lambda x: x[1]) + return matching_files @staticmethod async def update_lora_code(request): diff --git a/static/css/components/modal.css b/static/css/components/modal.css index 21e7914e..ad725409 100644 --- a/static/css/components/modal.css +++ b/static/css/components/modal.css @@ -306,6 +306,18 @@ body.modal-open { width: 100%; /* Full width */ } +/* Migrate control styling */ +.migrate-control { + display: flex; + align-items: center; + gap: 8px; +} + +.migrate-control input { + flex: 1; + min-width: 0; +} + /* 统一各个 section 的样式 */ .support-section, .changelog-section, diff --git a/static/js/managers/ExampleImagesManager.js b/static/js/managers/ExampleImagesManager.js index 3f5ed909..dca79ea9 100644 --- a/static/js/managers/ExampleImagesManager.js +++ b/static/js/managers/ExampleImagesManager.js @@ -11,6 +11,7 @@ class ExampleImagesManager { this.progressPanel = null; this.isProgressPanelCollapsed = false; this.pauseButton = null; // Store reference to the pause button + this.isMigrating = false; // Track migration state separately from downloading // Initialize download path field and check download status this.initializePathOptions(); @@ -46,6 +47,12 @@ class ExampleImagesManager { if (collapseBtn) { collapseBtn.onclick = () => this.toggleProgressPanel(); } + + // Initialize migration button handler + const migrateBtn = document.getElementById('exampleImagesMigrateBtn'); + if (migrateBtn) { + migrateBtn.onclick = () => this.handleMigrateButton(); + } } // Initialize event listeners for buttons @@ -141,6 +148,75 @@ class ExampleImagesManager { } } + // Method to handle migrate button click + async handleMigrateButton() { + if (this.isDownloading || this.isMigrating) { + if (this.isPaused) { + // If paused, resume + this.resumeDownload(); + } else { + showToast('Migration or download already in progress', 'info'); + } + return; + } + + // Start migration + this.startMigrate(); + } + + async startMigrate() { + try { + const outputDir = document.getElementById('exampleImagesPath').value || ''; + + if (!outputDir) { + showToast('Please enter a download location first', 'warning'); + return; + } + + const pattern = document.getElementById('exampleImagesMigratePattern').value || '{model}.example.{index}.{ext}'; + const optimize = document.getElementById('optimizeExampleImages').checked; + + const response = await fetch('/api/migrate-example-images', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + output_dir: outputDir, + pattern: pattern, + optimize: optimize, + model_types: ['lora', 'checkpoint'] + }) + }); + + const data = await response.json(); + + if (data.success) { + this.isDownloading = true; + this.isMigrating = true; + this.isPaused = false; + this.startTime = new Date(); + this.updateUI(data.status); + this.showProgressPanel(); + this.startProgressUpdates(); + // Update button text + const btnTextElement = document.getElementById('exampleDownloadBtnText'); + if (btnTextElement) { + btnTextElement.textContent = "Resume"; + } + showToast('Example images migration started', 'success'); + + // Close settings modal + modalManager.closeModal('settingsModal'); + } else { + showToast(data.error || 'Failed to start migration', 'error'); + } + } catch (error) { + console.error('Failed to start migration:', error); + showToast('Failed to start migration', 'error'); + } + } + async checkDownloadStatus() { try { const response = await fetch('/api/example-images-status'); @@ -334,6 +410,7 @@ class ExampleImagesManager { if (data.success) { this.isDownloading = data.is_downloading; this.isPaused = data.status.status === 'paused'; + this.isMigrating = data.is_migrating || false; // Update download button text this.updateDownloadButtonText(); @@ -346,11 +423,16 @@ class ExampleImagesManager { this.progressUpdateInterval = null; if (data.status.status === 'completed') { - showToast('Example images download completed', 'success'); + const actionType = this.isMigrating ? 'migration' : 'download'; + showToast(`Example images ${actionType} completed`, 'success'); + // Reset migration flag + this.isMigrating = false; // Hide the panel after a delay setTimeout(() => this.hideProgressPanel(), 5000); } else if (data.status.status === 'error') { - showToast('Example images download failed', 'error'); + const actionType = this.isMigrating ? 'migration' : 'download'; + showToast(`Example images ${actionType} failed`, 'error'); + this.isMigrating = false; } } } @@ -441,6 +523,19 @@ class ExampleImagesManager { this.updateMiniProgress(progressPercent); } } + + // Update title text + const titleElement = document.querySelector('.progress-panel-title'); + if (titleElement) { + const titleIcon = titleElement.querySelector('i'); + if (titleIcon) { + titleIcon.className = this.isMigrating ? 'fas fa-file-import' : 'fas fa-images'; + } + + titleElement.innerHTML = + ` ` + + `${this.isMigrating ? 'Example Images Migration' : 'Example Images Download'}`; + } } // Update the mini progress circle in the pause button @@ -536,8 +631,10 @@ class ExampleImagesManager { } getStatusText(status) { + const prefix = this.isMigrating ? 'Migrating' : 'Downloading'; + switch (status) { - case 'running': return 'Downloading'; + case 'running': return this.isMigrating ? 'Migrating' : 'Downloading'; case 'paused': return 'Paused'; case 'completed': return 'Completed'; case 'error': return 'Error'; diff --git a/templates/components/modals.html b/templates/components/modals.html index 5807f132..fd22cce0 100644 --- a/templates/components/modals.html +++ b/templates/components/modals.html @@ -271,6 +271,27 @@ Enter the folder path where example images from Civitai will be saved + + +
+
+
+ +
+
+ + +
+
+
+ Pattern to find existing example images. Use {model} for model filename, {index} for numbering, and {ext} for file extension.
+ Example patterns: "{model}.example.{index}.{ext}", "{model}_{index}.{ext}", "{model}/{model}.example.{index}.{ext}" +
+