diff --git a/py/routes/example_images_routes.py b/py/routes/example_images_routes.py index 585144f5..bb8c4cba 100644 --- a/py/routes/example_images_routes.py +++ b/py/routes/example_images_routes.py @@ -2,6 +2,7 @@ import logging import os import asyncio import json +import tempfile import time import aiohttp import re @@ -38,7 +39,7 @@ class ExampleImagesRoutes: def setup_routes(app): """Register example images routes""" app.router.add_post('/api/download-example-images', ExampleImagesRoutes.download_example_images) - app.router.add_post('/api/migrate-example-images', ExampleImagesRoutes.migrate_example_images) + app.router.add_post('/api/import-example-images', ExampleImagesRoutes.import_example_images) app.router.add_get('/api/example-images-status', ExampleImagesRoutes.get_example_images_status) app.router.add_post('/api/pause-example-images', ExampleImagesRoutes.pause_example_images) app.router.add_post('/api/resume-example-images', ExampleImagesRoutes.resume_example_images) @@ -199,7 +200,7 @@ class ExampleImagesRoutes: 'success': False, 'error': f"Download is in '{download_progress['status']}' state, cannot resume" }, status=400) - + @staticmethod async def _refresh_model_metadata(model_hash, model_name, scanner_type, scanner): """Refresh model metadata from CivitAI @@ -781,454 +782,6 @@ class ExampleImagesRoutes: # Set download status to not downloading is_downloading = False - @staticmethod - async def migrate_example_images(request): - """ - Migrate existing example images to central storage location - - Expects a JSON body with: - { - "output_dir": "path/to/output", # Base directory to save example images - "pattern": "{model}.example.{index}.{ext}", # Pattern to match example images - "optimize": true, # Whether to optimize images (default: true) - "model_types": ["lora", "checkpoint"], # Model types to process (default: both) - } - """ - global download_task, is_downloading, download_progress - - if is_downloading: - # Create a copy for JSON 
serialization - response_progress = download_progress.copy() - response_progress['processed_models'] = list(download_progress['processed_models']) - response_progress['refreshed_models'] = list(download_progress['refreshed_models']) - - return web.json_response({ - 'success': False, - 'error': 'Download or migration already in progress', - 'status': response_progress - }, status=400) - - try: - # Parse the request body - data = await request.json() - output_dir = data.get('output_dir') - pattern = data.get('pattern', '{model}.example.{index}.{ext}') - optimize = data.get('optimize', True) - model_types = data.get('model_types', ['lora', 'checkpoint']) - - if not output_dir: - return web.json_response({ - 'success': False, - 'error': 'Missing output_dir parameter' - }, status=400) - - # Create the output directory - os.makedirs(output_dir, exist_ok=True) - - # Initialize progress tracking - download_progress['total'] = 0 - download_progress['completed'] = 0 - download_progress['current_model'] = '' - download_progress['status'] = 'running' - download_progress['errors'] = [] - download_progress['last_error'] = None - download_progress['start_time'] = time.time() - download_progress['end_time'] = None - download_progress['is_migrating'] = True # Mark this as a migration task - - # Get the processed models list from a file if it exists - progress_file = os.path.join(output_dir, '.download_progress.json') - if os.path.exists(progress_file): - try: - with open(progress_file, 'r', encoding='utf-8') as f: - saved_progress = json.load(f) - download_progress['processed_models'] = set(saved_progress.get('processed_models', [])) - logger.info(f"Loaded previous progress, {len(download_progress['processed_models'])} models already processed") - except Exception as e: - logger.error(f"Failed to load progress file: {e}") - download_progress['processed_models'] = set() - else: - download_progress['processed_models'] = set() - - # Start the migration task - is_downloading = True - 
download_task = asyncio.create_task( - ExampleImagesRoutes._migrate_all_example_images( - output_dir, - pattern, - optimize, - model_types - ) - ) - - # Create a copy for JSON serialization - response_progress = download_progress.copy() - response_progress['processed_models'] = list(download_progress['processed_models']) - response_progress['refreshed_models'] = list(download_progress['refreshed_models']) - response_progress['is_migrating'] = True - - return web.json_response({ - 'success': True, - 'message': 'Migration started', - 'status': response_progress - }) - - except Exception as e: - logger.error(f"Failed to start example images migration: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - @staticmethod - async def _migrate_all_example_images(output_dir, pattern, optimize, model_types): - """Migrate example images for all models based on pattern - - Args: - output_dir: Base directory to save example images - pattern: Pattern to match example images - optimize: Whether to optimize images - model_types: List of model types to process - """ - global is_downloading, download_progress - - try: - # Get the scanners - scanners = [] - if 'lora' in model_types: - lora_scanner = await ServiceRegistry.get_lora_scanner() - scanners.append(('lora', lora_scanner)) - - if 'checkpoint' in model_types: - checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() - scanners.append(('checkpoint', checkpoint_scanner)) - - # Convert user pattern to regex - regex_pattern = ExampleImagesRoutes._convert_pattern_to_regex(pattern) - logger.info(f"Using pattern regex: {regex_pattern.pattern}") - - # Get all models from all scanners - all_models = [] - for scanner_type, scanner in scanners: - cache = await scanner.get_cached_data() - if cache and cache.raw_data: - for model in cache.raw_data: - # Only process models with a valid file path and sha256 - if model.get('file_path') and model.get('sha256'): - 
all_models.append((scanner_type, model, scanner)) - - # Update total count - download_progress['total'] = len(all_models) - logger.info(f"Found {download_progress['total']} models to check for example images") - - # Process each model - for scanner_type, model, scanner in all_models: - # Check if download is paused - while download_progress['status'] == 'paused': - await asyncio.sleep(1) - - # Check if download should continue - if download_progress['status'] != 'running': - logger.info(f"Migration stopped: {download_progress['status']}") - break - - model_hash = model.get('sha256', '').lower() - model_name = model.get('model_name', 'Unknown') - model_file_path = model.get('file_path', '') - model_file_name = os.path.basename(model_file_path) if model_file_path else '' - model_dir_path = os.path.dirname(model_file_path) if model_file_path else '' - - try: - # Update current model info - download_progress['current_model'] = f"{model_name} ({model_hash[:8]})" - - # Skip if already processed - if model_hash in download_progress['processed_models']: - logger.debug(f"Skipping already processed model: {model_name}") - download_progress['completed'] += 1 - continue - - # Find matching example files based on pattern - if model_file_name and os.path.exists(model_dir_path): - example_files = ExampleImagesRoutes._find_matching_example_files( - model_dir_path, - model_file_name, - regex_pattern - ) - - # Process found files - if example_files: - logger.info(f"Found {len(example_files)} example images for {model_name}") - - # Create model directory in output location - model_dir = os.path.join(output_dir, model_hash) - os.makedirs(model_dir, exist_ok=True) - - # Track local image paths for metadata update - local_image_paths = [] - - # Migrate each example file - for local_image_path, index in example_files: - # Get file extension - local_ext = os.path.splitext(local_image_path)[1].lower() - save_filename = f"image_{index}{local_ext}" - save_path = os.path.join(model_dir, 
save_filename) - - # Track all local image paths for potential metadata update - local_image_paths.append(local_image_path) - - # Skip if already exists in output directory - if os.path.exists(save_path): - logger.debug(f"File already exists in output: {save_path}") - continue - - try: - # Copy the file - with open(local_image_path, 'rb') as src_file: - with open(save_path, 'wb') as dst_file: - dst_file.write(src_file.read()) - logger.debug(f"Migrated {os.path.basename(local_image_path)} to {save_path}") - except Exception as e: - error_msg = f"Failed to copy file {os.path.basename(local_image_path)}: {str(e)}" - logger.error(error_msg) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg - - # Update model metadata if local images were found - if local_image_paths: - await ExampleImagesRoutes._update_model_metadata_from_local_examples( - model, - local_image_paths, - scanner_type, - scanner - ) - - # Mark this model as processed - download_progress['processed_models'].add(model_hash) - - # Save progress to file periodically - if download_progress['completed'] % 10 == 0 or download_progress['completed'] == download_progress['total'] - 1: - progress_file = os.path.join(output_dir, '.download_progress.json') - with open(progress_file, 'w', encoding='utf-8') as f: - json.dump({ - 'processed_models': list(download_progress['processed_models']), - 'refreshed_models': list(download_progress['refreshed_models']), - 'completed': download_progress['completed'], - 'total': download_progress['total'], - 'last_update': time.time() - }, f, indent=2) - - except Exception as e: - error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" - logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg - - # Update progress - download_progress['completed'] += 1 - - # Mark as completed - download_progress['status'] = 'completed' - 
download_progress['end_time'] = time.time() - download_progress['is_migrating'] = False - logger.info(f"Example images migration completed: {download_progress['completed']}/{download_progress['total']} models processed") - - except Exception as e: - error_msg = f"Error during example images migration: {str(e)}" - logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg - download_progress['status'] = 'error' - download_progress['end_time'] = time.time() - download_progress['is_migrating'] = False - - finally: - # Save final progress to file - try: - progress_file = os.path.join(output_dir, '.download_progress.json') - with open(progress_file, 'w', encoding='utf-8') as f: - json.dump({ - 'processed_models': list(download_progress['processed_models']), - 'refreshed_models': list(download_progress['refreshed_models']), - 'completed': download_progress['completed'], - 'total': download_progress['total'], - 'last_update': time.time(), - 'status': download_progress['status'], - 'is_migrating': False - }, f, indent=2) - except Exception as e: - logger.error(f"Failed to save progress file: {e}") - - # Set download status to not downloading - is_downloading = False - - @staticmethod - def _convert_pattern_to_regex(pattern): - """Convert a user-friendly template pattern to a regex pattern - - Args: - pattern: Template pattern string - - Returns: - re.Pattern: Compiled regex pattern object - """ - # Normalize path separators to forward slashes for consistent matching - pattern = pattern.replace('\\', '/') - - # Escape special regex characters - regex_safe = re.escape(pattern) - - # Handle multiple occurrences of {model} - model_count = pattern.count('{model}') - if (model_count > 1): - # Replace the first occurrence with a named capture group - regex_safe = regex_safe.replace(r'\{model\}', r'(?P.*?)', 1) - - # Replace subsequent occurrences with a back-reference - # Using (?P=model) for Python's regex 
named backreference syntax - for _ in range(model_count - 1): - regex_safe = regex_safe.replace(r'\{model\}', r'(?P=model)', 1) - else: - # Just one occurrence, handle normally - regex_safe = regex_safe.replace(r'\{model\}', r'(?P<model>.*?)') - - # {index} becomes a capture group for digits - regex_safe = regex_safe.replace(r'\{index\}', r'(?P<index>\d+)') - - # {ext} becomes a capture group for file extension WITHOUT including the dot - regex_safe = regex_safe.replace(r'\{ext\}', r'(?P<ext>\w+)') - - # Handle wildcard * character (which was escaped earlier) - regex_safe = regex_safe.replace(r'\*', r'.*?') - - logger.info(f"Converted pattern '{pattern}' to regex: '{regex_safe}'") - - # Compile the regex pattern - return re.compile(regex_safe) - - @staticmethod - def _find_matching_example_files(dir_path, model_filename, regex_pattern): - """Find example files matching the pattern in the given directory - - Args: - dir_path: Directory to search in - model_filename: Model filename (without extension) - regex_pattern: Compiled regex pattern to match against - - Returns: - list: List of tuples (file_path, index) of matching files - """ - matching_files = [] - model_name = os.path.splitext(model_filename)[0] - - # Check if pattern contains a directory separator - has_subdirs = '/' in regex_pattern.pattern or '\\\\' in regex_pattern.pattern - - # Determine search paths (keep existing logic for subdirectories) - if has_subdirs: - # Handle patterns with subdirectories - subdir_match = re.match(r'.*(?P<subdir>.*?)(/|\\\\).*', regex_pattern.pattern) - if subdir_match: - potential_subdir = os.path.join(dir_path, model_name) - if os.path.exists(potential_subdir) and os.path.isdir(potential_subdir): - search_paths = [potential_subdir] - else: - search_paths = [dir_path] - else: - search_paths = [dir_path] - else: - search_paths = [dir_path] - - for search_path in search_paths: - if not os.path.exists(search_path): - continue - - # For optimized performance: create a model name prefix check - # This
works for any pattern where the model name appears at the start - if not has_subdirs: - # Get list of all files first - all_files = os.listdir(search_path) - - # First pass: filter files that start with model name (case insensitive) - # This is much faster than regex for initial filtering - potential_matches = [] - lower_model_name = model_name.lower() - - for file in all_files: - # Quick check if file starts with model name - if file.lower().startswith(lower_model_name): - file_path = os.path.join(search_path, file) - if os.path.isfile(file_path): - potential_matches.append((file, file_path)) - - # Second pass: apply full regex only to potential matches - for file, file_path in potential_matches: - match = regex_pattern.match(file) - if match: - # Verify model name matches exactly what we're looking for - if match.group('model') != model_name: - logger.debug(f"File {file} matched pattern but model name {match.group('model')} doesn't match {model_name}") - continue - - # Check if file extension is supported - file_ext = os.path.splitext(file)[1].lower() - is_supported = (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or - file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']) - - if is_supported: - # Extract index from match - try: - index = int(match.group('index')) - except (IndexError, ValueError): - index = len(matching_files) + 1 - - matching_files.append((file_path, index)) - else: - # Original scanning logic for patterns with subdirectories - for file in os.listdir(search_path): - file_path = os.path.join(search_path, file) - if os.path.isfile(file_path): - # Try to match the filename directly first - match = regex_pattern.match(file) - - # If no match and subdirs are expected, try the relative path - if not match and has_subdirs: - # Get relative path and normalize slashes for consistent matching - rel_path = os.path.relpath(file_path, dir_path) - # Replace Windows backslashes with forward slashes for consistent regex matching - rel_path = 
rel_path.replace('\\', '/') - match = regex_pattern.match(rel_path) - - if match: - # For subdirectory patterns, model name in the match might refer to the dir name only - # so we need a different checking logic - matched_model = match.group('model') - if has_subdirs and '/' in rel_path: - # For subdirectory patterns, it's okay if just the folder name matches - folder_name = rel_path.split('/')[0] - if matched_model != model_name and matched_model != folder_name: - logger.debug(f"File {file} matched pattern but model name {matched_model} doesn't match {model_name}") - continue - elif matched_model != model_name: - logger.debug(f"File {file} matched pattern but model name {matched_model} doesn't match {model_name}") - continue - - file_ext = os.path.splitext(file)[1].lower() - is_supported = (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or - file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']) - - if is_supported: - try: - index = int(match.group('index')) - except (IndexError, ValueError): - index = len(matching_files) + 1 - - matching_files.append((file_path, index)) - - # Sort files by their index - matching_files.sort(key=lambda x: x[1]) - return matching_files - @staticmethod async def open_example_images_folder(request): """ @@ -1426,3 +979,269 @@ class ExampleImagesRoutes: 'success': False, 'error': str(e) }, status=500) + + @staticmethod + async def import_example_images(request): + """ + Import local example images for a model + + Expects: + - multipart/form-data with model_hash and files fields + OR + - JSON request with model_hash and file_paths + + Returns: + - Success status and list of imported files + """ + try: + model_hash = None + files_to_import = [] + temp_files_to_cleanup = [] + + # Check if this is a multipart form data request (direct file upload) + if request.content_type and 'multipart/form-data' in request.content_type: + reader = await request.multipart() + + # First, get the model_hash + field = await reader.next() + if field.name == 
'model_hash': + model_hash = await field.text() + + # Then process all files + while True: + field = await reader.next() + if field is None: + break + + if field.name == 'files': + # Create a temporary file with a proper suffix for type detection + file_name = field.filename + file_ext = os.path.splitext(file_name)[1].lower() + + with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file: + temp_path = tmp_file.name + temp_files_to_cleanup.append(temp_path) # Track for cleanup + + # Write chunks to the temp file + while True: + chunk = await field.read_chunk() + if not chunk: + break + tmp_file.write(chunk) + + # Add to our list of files to process + files_to_import.append(temp_path) + else: + # Parse JSON request (legacy method with file paths) + data = await request.json() + model_hash = data.get('model_hash') + files_to_import = data.get('file_paths', []) + + if not model_hash: + return web.json_response({ + 'success': False, + 'error': 'Missing model_hash parameter' + }, status=400) + + if not files_to_import: + return web.json_response({ + 'success': False, + 'error': 'No files provided to import' + }, status=400) + + # Get example images path + example_images_path = settings.get('example_images_path') + if not example_images_path: + return web.json_response({ + 'success': False, + 'error': 'No example images path configured' + }, status=400) + + # Find the model and get current metadata + lora_scanner = await ServiceRegistry.get_lora_scanner() + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() + + model_data = None + scanner = None + + # Check both scanners to find the model + for scan_obj in [lora_scanner, checkpoint_scanner]: + cache = await scan_obj.get_cached_data() + for item in cache.raw_data: + if item.get('sha256') == model_hash: + model_data = item + scanner = scan_obj + break + if model_data: + break + + if not model_data: + return web.json_response({ + 'success': False, + 'error': f"Model with hash {model_hash} 
not found in cache" + }, status=404) + + # Get current number of images in civitai.images array + civitai_data = model_data.get('civitai') + current_images = civitai_data.get('images', []) if civitai_data is not None else [] + next_index = len(current_images) + + # Create model folder + model_folder = os.path.join(example_images_path, model_hash) + os.makedirs(model_folder, exist_ok=True) + + imported_files = [] + errors = [] + newly_imported_paths = [] + + # Process each file path + for file_path in files_to_import: + try: + # Ensure file exists + if not os.path.isfile(file_path): + errors.append(f"File not found: {file_path}") + continue + + # Check if file type is supported + file_ext = os.path.splitext(file_path)[1].lower() + if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or + file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']): + errors.append(f"Unsupported file type: {file_path}") + continue + + # Generate new filename with sequential index starting from current images length + new_filename = f"image_{next_index}{file_ext}" + next_index += 1 + + dest_path = os.path.join(model_folder, new_filename) + + # Copy the file + import shutil + shutil.copy2(file_path, dest_path) + newly_imported_paths.append(dest_path) + + # Add to imported files list + imported_files.append({ + 'name': new_filename, + 'path': f'/example_images_static/{model_hash}/{new_filename}', + 'extension': file_ext, + 'is_video': file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos'] + }) + except Exception as e: + errors.append(f"Error importing {file_path}: {str(e)}") + + # Update metadata with new example images + updated_images = await ExampleImagesRoutes._update_metadata_after_import( + model_hash, + model_data, + scanner, + newly_imported_paths + ) + + return web.json_response({ + 'success': len(imported_files) > 0, + 'message': f'Successfully imported {len(imported_files)} files' + + (f' with {len(errors)} errors' if errors else ''), + 'files': imported_files, + 'errors': errors, + 
'updated_images': updated_images, + "model_file_path": model_data.get('file_path', ''), + }) + + except Exception as e: + logger.error(f"Failed to import example images: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + finally: + # Clean up temporary files if any + for temp_file in temp_files_to_cleanup: + try: + os.remove(temp_file) + except Exception as e: + logger.error(f"Failed to remove temporary file {temp_file}: {e}") + + @staticmethod + async def _update_metadata_after_import(model_hash, model_data, scanner, newly_imported_paths): + """ + Update model metadata after importing example images by appending new images to the existing array + + Args: + model_hash: SHA256 hash of the model + model_data: Model data dictionary + scanner: Scanner instance (lora or checkpoint) + newly_imported_paths: List of paths to newly imported files + + Returns: + list: Updated images array + """ + try: + # Ensure civitai field exists in model data + if not model_data.get('civitai'): + model_data['civitai'] = {} + + # Ensure images array exists + if not model_data['civitai'].get('images'): + model_data['civitai']['images'] = [] + + # Get current images array + images = model_data['civitai']['images'] + + # Add new image entries for each imported file + for path in newly_imported_paths: + # Determine if it's a video or image + file_ext = os.path.splitext(path)[1].lower() + is_video = file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos'] + + # Create image metadata entry + image_entry = { + "url": "", # Empty URL as requested + "nsfwLevel": 0, + "width": 720, # Default dimensions + "height": 1280, + "type": "video" if is_video else "image", + "meta": None, + "hasMeta": False, + "hasPositivePrompt": False + } + + # Try to get actual dimensions if it's an image + try: + from PIL import Image + if not is_video and os.path.exists(path): + with Image.open(path) as img: + image_entry["width"], image_entry["height"] = img.size + 
except: + # If PIL fails or isn't available, use default dimensions + pass + + # Append to the existing images array + images.append(image_entry) + + # Save metadata to the .metadata.json file + file_path = model_data.get('file_path') + if file_path: + base_path = os.path.splitext(file_path)[0] + metadata_path = f"{base_path}.metadata.json" + try: + # Create a copy of the model data without the 'folder' field + model_copy = model_data.copy() + model_copy.pop('folder', None) + + # Write the metadata to file + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(model_copy, f, indent=2, ensure_ascii=False) + logger.info(f"Saved metadata to {metadata_path}") + except Exception as e: + logger.error(f"Failed to save metadata to {metadata_path}: {str(e)}") + + # Save updated metadata to scanner cache + if file_path: + await scanner.update_single_model_cache(file_path, file_path, model_data) + + return images + + except Exception as e: + logger.error(f"Failed to update metadata after import: {e}", exc_info=True) + return [] \ No newline at end of file diff --git a/static/css/components/lora-modal/showcase.css b/static/css/components/lora-modal/showcase.css index d1d684e8..843da923 100644 --- a/static/css/components/lora-modal/showcase.css +++ b/static/css/components/lora-modal/showcase.css @@ -289,4 +289,95 @@ .lazy[src] { opacity: 1; +} + +/* Example Import Area */ +.example-import-area { + margin-top: var(--space-4); + padding: var(--space-2); +} + +.example-import-area.empty { + margin-top: var(--space-2); + padding: var(--space-4) var(--space-2); +} + +.import-container { + border: 2px dashed var(--border-color); + border-radius: var(--border-radius-sm); + padding: var(--space-4); + text-align: center; + transition: all 0.3s ease; + background: var(--lora-surface); + cursor: pointer; +} + +.import-container.highlight { + border-color: var(--lora-accent); + background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1); + 
transform: scale(1.01); +} + +.import-placeholder { + display: flex; + flex-direction: column; + align-items: center; + gap: var(--space-1); + padding-top: var(--space-1); +} + +.import-placeholder i { + font-size: 2.5rem; + /* color: var(--lora-accent); */ + opacity: 0.8; + margin-bottom: var(--space-1); +} + +.import-placeholder h3 { + margin: 0 0 var(--space-1); + font-size: 1.2rem; + font-weight: 500; + color: var(--text-color); +} + +.import-placeholder p { + margin: var(--space-1) 0; + color: var(--text-color); + opacity: 0.8; +} + +.import-placeholder .sub-text { + font-size: 0.9em; + opacity: 0.6; + margin: var(--space-1) 0; +} + +.import-formats { + font-size: 0.8em !important; + opacity: 0.6 !important; + margin-top: var(--space-2) !important; +} + +.select-files-btn { + background: var(--lora-accent); + color: var(--lora-text); + border: none; + border-radius: var(--border-radius-xs); + padding: var(--space-2) var(--space-3); + cursor: pointer; + font-size: 0.9em; + display: flex; + align-items: center; + gap: 8px; + transition: all 0.2s; +} + +.select-files-btn:hover { + opacity: 0.9; + transform: translateY(-1px); +} + +/* For dark theme */ +[data-theme="dark"] .import-container { + background: rgba(255, 255, 255, 0.03); } \ No newline at end of file diff --git a/static/js/components/loraModal/ShowcaseView.js b/static/js/components/loraModal/ShowcaseView.js index 5e2340b1..b7571aa5 100644 --- a/static/js/components/loraModal/ShowcaseView.js +++ b/static/js/components/loraModal/ShowcaseView.js @@ -17,7 +17,10 @@ import { NSFW_LEVELS } from '../../utils/constants.js'; * @returns {Promise} HTML内容 */ export function renderShowcaseContent(images, exampleFiles = []) { - if (!images?.length) return '
No example images available
'; + if (!images?.length) { + // Replace empty message with import interface + return renderImportInterface(true); + } // Filter images based on SFW setting const showOnlySFW = state.settings.show_only_sfw; @@ -136,10 +139,202 @@ export function renderShowcaseContent(images, exampleFiles = []) { ); }).join('')} + + + ${renderImportInterface(false)} `; } +/** + * Render the import interface for example images + * @param {boolean} isEmpty - Whether there are no existing examples + * @returns {string} HTML content for import interface + */ +function renderImportInterface(isEmpty) { + return ` +
+        <div class="example-import-area ${isEmpty ? 'empty' : ''}">
+            <div class="import-container" id="exampleImportContainer">
+                <div class="import-placeholder">
+                    <i class="fas fa-cloud-upload-alt"></i>
+                    <h3>${isEmpty ? 'No example images available' : 'Add more examples'}</h3>
+                    <p>Drag & drop images or videos here</p>
+                    <p class="sub-text">or</p>
+                    <button class="select-files-btn" id="selectExampleFilesBtn">
+                        <i class="fas fa-folder-open"></i> Select Files
+                    </button>
+                    <p class="import-formats">Supported formats: jpg, png, gif, webp, mp4, webm</p>
+                </div>
+            </div>
+            <input type="file" id="exampleFilesInput" multiple accept=".jpg,.jpeg,.png,.gif,.webp,.mp4,.webm" style="display: none;">
+        </div>
+ `; +} + +/** + * Initialize the import functionality for example images + * @param {string} modelHash - The SHA256 hash of the model + * @param {Element} container - The container element for the import area + */ +export function initExampleImport(modelHash, container) { + if (!container) return; + + const importContainer = container.querySelector('#exampleImportContainer'); + const fileInput = container.querySelector('#exampleFilesInput'); + const selectFilesBtn = container.querySelector('#selectExampleFilesBtn'); + + // Set up file selection button + if (selectFilesBtn) { + selectFilesBtn.addEventListener('click', () => { + fileInput.click(); + }); + } + + // Handle file selection + if (fileInput) { + fileInput.addEventListener('change', (e) => { + if (e.target.files.length > 0) { + handleImportFiles(Array.from(e.target.files), modelHash, importContainer); + } + }); + } + + // Set up drag and drop + if (importContainer) { + ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => { + importContainer.addEventListener(eventName, preventDefaults, false); + }); + + function preventDefaults(e) { + e.preventDefault(); + e.stopPropagation(); + } + + // Highlight drop area on drag over + ['dragenter', 'dragover'].forEach(eventName => { + importContainer.addEventListener(eventName, () => { + importContainer.classList.add('highlight'); + }, false); + }); + + // Remove highlight on drag leave + ['dragleave', 'drop'].forEach(eventName => { + importContainer.addEventListener(eventName, () => { + importContainer.classList.remove('highlight'); + }, false); + }); + + // Handle dropped files + importContainer.addEventListener('drop', (e) => { + const files = Array.from(e.dataTransfer.files); + handleImportFiles(files, modelHash, importContainer); + }, false); + } +} + +/** + * Handle the file import process + * @param {File[]} files - Array of files to import + * @param {string} modelHash - The SHA256 hash of the model + * @param {Element} importContainer - The 
container element for import UI + */ +async function handleImportFiles(files, modelHash, importContainer) { + // Filter for supported file types + const supportedImages = ['.jpg', '.jpeg', '.png', '.gif', '.webp']; + const supportedVideos = ['.mp4', '.webm']; + const supportedExtensions = [...supportedImages, ...supportedVideos]; + + const validFiles = files.filter(file => { + const ext = '.' + file.name.split('.').pop().toLowerCase(); + return supportedExtensions.includes(ext); + }); + + if (validFiles.length === 0) { + alert('No supported files selected. Please select image or video files.'); + return; + } + + try { + // Get file paths to send to backend + const filePaths = validFiles.map(file => { + // We need the full path, but we only have the filename + // For security reasons, browsers don't provide full paths + // This will only work if the backend can handle just filenames + return URL.createObjectURL(file); + }); + + // Use FileReader to get the file data for direct upload + const formData = new FormData(); + formData.append('model_hash', modelHash); + + validFiles.forEach(file => { + formData.append('files', file); + }); + + // Call API to import files + const response = await fetch('/api/import-example-images', { + method: 'POST', + body: formData + }); + + const result = await response.json(); + + if (!result.success) { + throw new Error(result.error || 'Failed to import example files'); + } + + // Get updated local files + const updatedFilesResponse = await fetch(`/api/example-image-files?model_hash=${modelHash}`); + const updatedFilesResult = await updatedFilesResponse.json(); + + if (!updatedFilesResult.success) { + throw new Error(updatedFilesResult.error || 'Failed to get updated file list'); + } + + // Re-render the showcase content + const showcaseTab = document.getElementById('showcase-tab'); + if (showcaseTab) { + // Get the updated images from the result + const updatedImages = result.updated_images || []; + showcaseTab.innerHTML = 
renderShowcaseContent(updatedImages, updatedFilesResult.files); + + // Re-initialize showcase functionality + const carousel = showcaseTab.querySelector('.carousel'); + if (carousel) { + if (!carousel.classList.contains('collapsed')) { + initLazyLoading(carousel); + initNsfwBlurHandlers(carousel); + initMetadataPanelHandlers(carousel); + } + // Initialize the import UI for the new content + initExampleImport(modelHash, showcaseTab); + } + + // Update VirtualScroller if available + if (state.virtualScroller && result.model_file_path) { + // Create an update object with only the necessary properties + const updateData = { + civitai: { + images: updatedImages + } + }; + + // Update the item in the virtual scroller + state.virtualScroller.updateSingleItem(result.model_file_path, updateData); + console.log('Updated VirtualScroller item with new example images'); + } + } + } catch (error) { + console.error('Error importing examples:', error); + } +} + /** * Generate metadata panel HTML */ diff --git a/static/js/components/loraModal/index.js b/static/js/components/loraModal/index.js index 49b498d0..13fa3406 100644 --- a/static/js/components/loraModal/index.js +++ b/static/js/components/loraModal/index.js @@ -5,7 +5,13 @@ */ import { showToast, copyToClipboard, getExampleImageFiles } from '../../utils/uiHelpers.js'; import { modalManager } from '../../managers/ModalManager.js'; -import { renderShowcaseContent, toggleShowcase, setupShowcaseScroll, scrollToTop } from './ShowcaseView.js'; +import { + renderShowcaseContent, + toggleShowcase, + setupShowcaseScroll, + scrollToTop, + initExampleImport +} from './ShowcaseView.js'; import { setupTabSwitching, loadModelDescription } from './ModelDescription.js'; import { renderTriggerWords, setupTriggerWordsEditMode } from './TriggerWords.js'; import { parsePresets, renderPresetTags } from './PresetTags.js'; @@ -207,14 +213,8 @@ async function loadExampleImages(images, modelHash, filePath) { let localFiles = []; try { - // Choose 
endpoint based on centralized examples setting - const useCentralized = state.global.settings.useCentralizedExamples !== false; - const endpoint = useCentralized ? '/api/example-image-files' : '/api/model-example-files'; - - // Use different params based on endpoint - const params = useCentralized ? - `model_hash=${modelHash}` : - `file_path=${encodeURIComponent(filePath)}`; + const endpoint = '/api/example-image-files'; + const params = `model_hash=${modelHash}`; const response = await fetch(`${endpoint}?${params}`); const result = await response.json(); @@ -239,6 +239,9 @@ async function loadExampleImages(images, modelHash, filePath) { initMetadataPanelHandlers(carousel); } } + + // Initialize the example import functionality + initExampleImport(modelHash, showcaseTab); } catch (error) { console.error('Error loading example images:', error); const showcaseTab = document.getElementById('showcase-tab');