mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Implement example image import functionality with UI and backend integration
This commit is contained in:
@@ -2,6 +2,7 @@ import logging
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import tempfile
|
||||
import time
|
||||
import aiohttp
|
||||
import re
|
||||
@@ -38,7 +39,7 @@ class ExampleImagesRoutes:
|
||||
def setup_routes(app):
|
||||
"""Register example images routes"""
|
||||
app.router.add_post('/api/download-example-images', ExampleImagesRoutes.download_example_images)
|
||||
app.router.add_post('/api/migrate-example-images', ExampleImagesRoutes.migrate_example_images)
|
||||
app.router.add_post('/api/import-example-images', ExampleImagesRoutes.import_example_images)
|
||||
app.router.add_get('/api/example-images-status', ExampleImagesRoutes.get_example_images_status)
|
||||
app.router.add_post('/api/pause-example-images', ExampleImagesRoutes.pause_example_images)
|
||||
app.router.add_post('/api/resume-example-images', ExampleImagesRoutes.resume_example_images)
|
||||
@@ -199,7 +200,7 @@ class ExampleImagesRoutes:
|
||||
'success': False,
|
||||
'error': f"Download is in '{download_progress['status']}' state, cannot resume"
|
||||
}, status=400)
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def _refresh_model_metadata(model_hash, model_name, scanner_type, scanner):
|
||||
"""Refresh model metadata from CivitAI
|
||||
@@ -781,454 +782,6 @@ class ExampleImagesRoutes:
|
||||
# Set download status to not downloading
|
||||
is_downloading = False
|
||||
|
||||
@staticmethod
|
||||
async def migrate_example_images(request):
|
||||
"""
|
||||
Migrate existing example images to central storage location
|
||||
|
||||
Expects a JSON body with:
|
||||
{
|
||||
"output_dir": "path/to/output", # Base directory to save example images
|
||||
"pattern": "{model}.example.{index}.{ext}", # Pattern to match example images
|
||||
"optimize": true, # Whether to optimize images (default: true)
|
||||
"model_types": ["lora", "checkpoint"], # Model types to process (default: both)
|
||||
}
|
||||
"""
|
||||
global download_task, is_downloading, download_progress
|
||||
|
||||
if is_downloading:
|
||||
# Create a copy for JSON serialization
|
||||
response_progress = download_progress.copy()
|
||||
response_progress['processed_models'] = list(download_progress['processed_models'])
|
||||
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
|
||||
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download or migration already in progress',
|
||||
'status': response_progress
|
||||
}, status=400)
|
||||
|
||||
try:
|
||||
# Parse the request body
|
||||
data = await request.json()
|
||||
output_dir = data.get('output_dir')
|
||||
pattern = data.get('pattern', '{model}.example.{index}.{ext}')
|
||||
optimize = data.get('optimize', True)
|
||||
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
||||
|
||||
if not output_dir:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing output_dir parameter'
|
||||
}, status=400)
|
||||
|
||||
# Create the output directory
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Initialize progress tracking
|
||||
download_progress['total'] = 0
|
||||
download_progress['completed'] = 0
|
||||
download_progress['current_model'] = ''
|
||||
download_progress['status'] = 'running'
|
||||
download_progress['errors'] = []
|
||||
download_progress['last_error'] = None
|
||||
download_progress['start_time'] = time.time()
|
||||
download_progress['end_time'] = None
|
||||
download_progress['is_migrating'] = True # Mark this as a migration task
|
||||
|
||||
# Get the processed models list from a file if it exists
|
||||
progress_file = os.path.join(output_dir, '.download_progress.json')
|
||||
if os.path.exists(progress_file):
|
||||
try:
|
||||
with open(progress_file, 'r', encoding='utf-8') as f:
|
||||
saved_progress = json.load(f)
|
||||
download_progress['processed_models'] = set(saved_progress.get('processed_models', []))
|
||||
logger.info(f"Loaded previous progress, {len(download_progress['processed_models'])} models already processed")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load progress file: {e}")
|
||||
download_progress['processed_models'] = set()
|
||||
else:
|
||||
download_progress['processed_models'] = set()
|
||||
|
||||
# Start the migration task
|
||||
is_downloading = True
|
||||
download_task = asyncio.create_task(
|
||||
ExampleImagesRoutes._migrate_all_example_images(
|
||||
output_dir,
|
||||
pattern,
|
||||
optimize,
|
||||
model_types
|
||||
)
|
||||
)
|
||||
|
||||
# Create a copy for JSON serialization
|
||||
response_progress = download_progress.copy()
|
||||
response_progress['processed_models'] = list(download_progress['processed_models'])
|
||||
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
|
||||
response_progress['is_migrating'] = True
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': 'Migration started',
|
||||
'status': response_progress
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start example images migration: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def _migrate_all_example_images(output_dir, pattern, optimize, model_types):
|
||||
"""Migrate example images for all models based on pattern
|
||||
|
||||
Args:
|
||||
output_dir: Base directory to save example images
|
||||
pattern: Pattern to match example images
|
||||
optimize: Whether to optimize images
|
||||
model_types: List of model types to process
|
||||
"""
|
||||
global is_downloading, download_progress
|
||||
|
||||
try:
|
||||
# Get the scanners
|
||||
scanners = []
|
||||
if 'lora' in model_types:
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
scanners.append(('lora', lora_scanner))
|
||||
|
||||
if 'checkpoint' in model_types:
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
scanners.append(('checkpoint', checkpoint_scanner))
|
||||
|
||||
# Convert user pattern to regex
|
||||
regex_pattern = ExampleImagesRoutes._convert_pattern_to_regex(pattern)
|
||||
logger.info(f"Using pattern regex: {regex_pattern.pattern}")
|
||||
|
||||
# Get all models from all scanners
|
||||
all_models = []
|
||||
for scanner_type, scanner in scanners:
|
||||
cache = await scanner.get_cached_data()
|
||||
if cache and cache.raw_data:
|
||||
for model in cache.raw_data:
|
||||
# Only process models with a valid file path and sha256
|
||||
if model.get('file_path') and model.get('sha256'):
|
||||
all_models.append((scanner_type, model, scanner))
|
||||
|
||||
# Update total count
|
||||
download_progress['total'] = len(all_models)
|
||||
logger.info(f"Found {download_progress['total']} models to check for example images")
|
||||
|
||||
# Process each model
|
||||
for scanner_type, model, scanner in all_models:
|
||||
# Check if download is paused
|
||||
while download_progress['status'] == 'paused':
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check if download should continue
|
||||
if download_progress['status'] != 'running':
|
||||
logger.info(f"Migration stopped: {download_progress['status']}")
|
||||
break
|
||||
|
||||
model_hash = model.get('sha256', '').lower()
|
||||
model_name = model.get('model_name', 'Unknown')
|
||||
model_file_path = model.get('file_path', '')
|
||||
model_file_name = os.path.basename(model_file_path) if model_file_path else ''
|
||||
model_dir_path = os.path.dirname(model_file_path) if model_file_path else ''
|
||||
|
||||
try:
|
||||
# Update current model info
|
||||
download_progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
||||
|
||||
# Skip if already processed
|
||||
if model_hash in download_progress['processed_models']:
|
||||
logger.debug(f"Skipping already processed model: {model_name}")
|
||||
download_progress['completed'] += 1
|
||||
continue
|
||||
|
||||
# Find matching example files based on pattern
|
||||
if model_file_name and os.path.exists(model_dir_path):
|
||||
example_files = ExampleImagesRoutes._find_matching_example_files(
|
||||
model_dir_path,
|
||||
model_file_name,
|
||||
regex_pattern
|
||||
)
|
||||
|
||||
# Process found files
|
||||
if example_files:
|
||||
logger.info(f"Found {len(example_files)} example images for {model_name}")
|
||||
|
||||
# Create model directory in output location
|
||||
model_dir = os.path.join(output_dir, model_hash)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
# Track local image paths for metadata update
|
||||
local_image_paths = []
|
||||
|
||||
# Migrate each example file
|
||||
for local_image_path, index in example_files:
|
||||
# Get file extension
|
||||
local_ext = os.path.splitext(local_image_path)[1].lower()
|
||||
save_filename = f"image_{index}{local_ext}"
|
||||
save_path = os.path.join(model_dir, save_filename)
|
||||
|
||||
# Track all local image paths for potential metadata update
|
||||
local_image_paths.append(local_image_path)
|
||||
|
||||
# Skip if already exists in output directory
|
||||
if os.path.exists(save_path):
|
||||
logger.debug(f"File already exists in output: {save_path}")
|
||||
continue
|
||||
|
||||
try:
|
||||
# Copy the file
|
||||
with open(local_image_path, 'rb') as src_file:
|
||||
with open(save_path, 'wb') as dst_file:
|
||||
dst_file.write(src_file.read())
|
||||
logger.debug(f"Migrated {os.path.basename(local_image_path)} to {save_path}")
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to copy file {os.path.basename(local_image_path)}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
download_progress['errors'].append(error_msg)
|
||||
download_progress['last_error'] = error_msg
|
||||
|
||||
# Update model metadata if local images were found
|
||||
if local_image_paths:
|
||||
await ExampleImagesRoutes._update_model_metadata_from_local_examples(
|
||||
model,
|
||||
local_image_paths,
|
||||
scanner_type,
|
||||
scanner
|
||||
)
|
||||
|
||||
# Mark this model as processed
|
||||
download_progress['processed_models'].add(model_hash)
|
||||
|
||||
# Save progress to file periodically
|
||||
if download_progress['completed'] % 10 == 0 or download_progress['completed'] == download_progress['total'] - 1:
|
||||
progress_file = os.path.join(output_dir, '.download_progress.json')
|
||||
with open(progress_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
'processed_models': list(download_progress['processed_models']),
|
||||
'refreshed_models': list(download_progress['refreshed_models']),
|
||||
'completed': download_progress['completed'],
|
||||
'total': download_progress['total'],
|
||||
'last_update': time.time()
|
||||
}, f, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
download_progress['errors'].append(error_msg)
|
||||
download_progress['last_error'] = error_msg
|
||||
|
||||
# Update progress
|
||||
download_progress['completed'] += 1
|
||||
|
||||
# Mark as completed
|
||||
download_progress['status'] = 'completed'
|
||||
download_progress['end_time'] = time.time()
|
||||
download_progress['is_migrating'] = False
|
||||
logger.info(f"Example images migration completed: {download_progress['completed']}/{download_progress['total']} models processed")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error during example images migration: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
download_progress['errors'].append(error_msg)
|
||||
download_progress['last_error'] = error_msg
|
||||
download_progress['status'] = 'error'
|
||||
download_progress['end_time'] = time.time()
|
||||
download_progress['is_migrating'] = False
|
||||
|
||||
finally:
|
||||
# Save final progress to file
|
||||
try:
|
||||
progress_file = os.path.join(output_dir, '.download_progress.json')
|
||||
with open(progress_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
'processed_models': list(download_progress['processed_models']),
|
||||
'refreshed_models': list(download_progress['refreshed_models']),
|
||||
'completed': download_progress['completed'],
|
||||
'total': download_progress['total'],
|
||||
'last_update': time.time(),
|
||||
'status': download_progress['status'],
|
||||
'is_migrating': False
|
||||
}, f, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save progress file: {e}")
|
||||
|
||||
# Set download status to not downloading
|
||||
is_downloading = False
|
||||
|
||||
@staticmethod
|
||||
def _convert_pattern_to_regex(pattern):
|
||||
"""Convert a user-friendly template pattern to a regex pattern
|
||||
|
||||
Args:
|
||||
pattern: Template pattern string
|
||||
|
||||
Returns:
|
||||
re.Pattern: Compiled regex pattern object
|
||||
"""
|
||||
# Normalize path separators to forward slashes for consistent matching
|
||||
pattern = pattern.replace('\\', '/')
|
||||
|
||||
# Escape special regex characters
|
||||
regex_safe = re.escape(pattern)
|
||||
|
||||
# Handle multiple occurrences of {model}
|
||||
model_count = pattern.count('{model}')
|
||||
if (model_count > 1):
|
||||
# Replace the first occurrence with a named capture group
|
||||
regex_safe = regex_safe.replace(r'\{model\}', r'(?P<model>.*?)', 1)
|
||||
|
||||
# Replace subsequent occurrences with a back-reference
|
||||
# Using (?P=model) for Python's regex named backreference syntax
|
||||
for _ in range(model_count - 1):
|
||||
regex_safe = regex_safe.replace(r'\{model\}', r'(?P=model)', 1)
|
||||
else:
|
||||
# Just one occurrence, handle normally
|
||||
regex_safe = regex_safe.replace(r'\{model\}', r'(?P<model>.*?)')
|
||||
|
||||
# {index} becomes a capture group for digits
|
||||
regex_safe = regex_safe.replace(r'\{index\}', r'(?P<index>\d+)')
|
||||
|
||||
# {ext} becomes a capture group for file extension WITHOUT including the dot
|
||||
regex_safe = regex_safe.replace(r'\{ext\}', r'(?P<ext>\w+)')
|
||||
|
||||
# Handle wildcard * character (which was escaped earlier)
|
||||
regex_safe = regex_safe.replace(r'\*', r'.*?')
|
||||
|
||||
logger.info(f"Converted pattern '{pattern}' to regex: '{regex_safe}'")
|
||||
|
||||
# Compile the regex pattern
|
||||
return re.compile(regex_safe)
|
||||
|
||||
@staticmethod
|
||||
def _find_matching_example_files(dir_path, model_filename, regex_pattern):
|
||||
"""Find example files matching the pattern in the given directory
|
||||
|
||||
Args:
|
||||
dir_path: Directory to search in
|
||||
model_filename: Model filename (without extension)
|
||||
regex_pattern: Compiled regex pattern to match against
|
||||
|
||||
Returns:
|
||||
list: List of tuples (file_path, index) of matching files
|
||||
"""
|
||||
matching_files = []
|
||||
model_name = os.path.splitext(model_filename)[0]
|
||||
|
||||
# Check if pattern contains a directory separator
|
||||
has_subdirs = '/' in regex_pattern.pattern or '\\\\' in regex_pattern.pattern
|
||||
|
||||
# Determine search paths (keep existing logic for subdirectories)
|
||||
if has_subdirs:
|
||||
# Handle patterns with subdirectories
|
||||
subdir_match = re.match(r'.*(?P<model>.*?)(/|\\\\).*', regex_pattern.pattern)
|
||||
if subdir_match:
|
||||
potential_subdir = os.path.join(dir_path, model_name)
|
||||
if os.path.exists(potential_subdir) and os.path.isdir(potential_subdir):
|
||||
search_paths = [potential_subdir]
|
||||
else:
|
||||
search_paths = [dir_path]
|
||||
else:
|
||||
search_paths = [dir_path]
|
||||
else:
|
||||
search_paths = [dir_path]
|
||||
|
||||
for search_path in search_paths:
|
||||
if not os.path.exists(search_path):
|
||||
continue
|
||||
|
||||
# For optimized performance: create a model name prefix check
|
||||
# This works for any pattern where the model name appears at the start
|
||||
if not has_subdirs:
|
||||
# Get list of all files first
|
||||
all_files = os.listdir(search_path)
|
||||
|
||||
# First pass: filter files that start with model name (case insensitive)
|
||||
# This is much faster than regex for initial filtering
|
||||
potential_matches = []
|
||||
lower_model_name = model_name.lower()
|
||||
|
||||
for file in all_files:
|
||||
# Quick check if file starts with model name
|
||||
if file.lower().startswith(lower_model_name):
|
||||
file_path = os.path.join(search_path, file)
|
||||
if os.path.isfile(file_path):
|
||||
potential_matches.append((file, file_path))
|
||||
|
||||
# Second pass: apply full regex only to potential matches
|
||||
for file, file_path in potential_matches:
|
||||
match = regex_pattern.match(file)
|
||||
if match:
|
||||
# Verify model name matches exactly what we're looking for
|
||||
if match.group('model') != model_name:
|
||||
logger.debug(f"File {file} matched pattern but model name {match.group('model')} doesn't match {model_name}")
|
||||
continue
|
||||
|
||||
# Check if file extension is supported
|
||||
file_ext = os.path.splitext(file)[1].lower()
|
||||
is_supported = (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
|
||||
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos'])
|
||||
|
||||
if is_supported:
|
||||
# Extract index from match
|
||||
try:
|
||||
index = int(match.group('index'))
|
||||
except (IndexError, ValueError):
|
||||
index = len(matching_files) + 1
|
||||
|
||||
matching_files.append((file_path, index))
|
||||
else:
|
||||
# Original scanning logic for patterns with subdirectories
|
||||
for file in os.listdir(search_path):
|
||||
file_path = os.path.join(search_path, file)
|
||||
if os.path.isfile(file_path):
|
||||
# Try to match the filename directly first
|
||||
match = regex_pattern.match(file)
|
||||
|
||||
# If no match and subdirs are expected, try the relative path
|
||||
if not match and has_subdirs:
|
||||
# Get relative path and normalize slashes for consistent matching
|
||||
rel_path = os.path.relpath(file_path, dir_path)
|
||||
# Replace Windows backslashes with forward slashes for consistent regex matching
|
||||
rel_path = rel_path.replace('\\', '/')
|
||||
match = regex_pattern.match(rel_path)
|
||||
|
||||
if match:
|
||||
# For subdirectory patterns, model name in the match might refer to the dir name only
|
||||
# so we need a different checking logic
|
||||
matched_model = match.group('model')
|
||||
if has_subdirs and '/' in rel_path:
|
||||
# For subdirectory patterns, it's okay if just the folder name matches
|
||||
folder_name = rel_path.split('/')[0]
|
||||
if matched_model != model_name and matched_model != folder_name:
|
||||
logger.debug(f"File {file} matched pattern but model name {matched_model} doesn't match {model_name}")
|
||||
continue
|
||||
elif matched_model != model_name:
|
||||
logger.debug(f"File {file} matched pattern but model name {matched_model} doesn't match {model_name}")
|
||||
continue
|
||||
|
||||
file_ext = os.path.splitext(file)[1].lower()
|
||||
is_supported = (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
|
||||
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos'])
|
||||
|
||||
if is_supported:
|
||||
try:
|
||||
index = int(match.group('index'))
|
||||
except (IndexError, ValueError):
|
||||
index = len(matching_files) + 1
|
||||
|
||||
matching_files.append((file_path, index))
|
||||
|
||||
# Sort files by their index
|
||||
matching_files.sort(key=lambda x: x[1])
|
||||
return matching_files
|
||||
|
||||
@staticmethod
|
||||
async def open_example_images_folder(request):
|
||||
"""
|
||||
@@ -1426,3 +979,269 @@ class ExampleImagesRoutes:
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def import_example_images(request):
|
||||
"""
|
||||
Import local example images for a model
|
||||
|
||||
Expects:
|
||||
- multipart/form-data with model_hash and files fields
|
||||
OR
|
||||
- JSON request with model_hash and file_paths
|
||||
|
||||
Returns:
|
||||
- Success status and list of imported files
|
||||
"""
|
||||
try:
|
||||
model_hash = None
|
||||
files_to_import = []
|
||||
temp_files_to_cleanup = []
|
||||
|
||||
# Check if this is a multipart form data request (direct file upload)
|
||||
if request.content_type and 'multipart/form-data' in request.content_type:
|
||||
reader = await request.multipart()
|
||||
|
||||
# First, get the model_hash
|
||||
field = await reader.next()
|
||||
if field.name == 'model_hash':
|
||||
model_hash = await field.text()
|
||||
|
||||
# Then process all files
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
if field.name == 'files':
|
||||
# Create a temporary file with a proper suffix for type detection
|
||||
file_name = field.filename
|
||||
file_ext = os.path.splitext(file_name)[1].lower()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
|
||||
temp_path = tmp_file.name
|
||||
temp_files_to_cleanup.append(temp_path) # Track for cleanup
|
||||
|
||||
# Write chunks to the temp file
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
tmp_file.write(chunk)
|
||||
|
||||
# Add to our list of files to process
|
||||
files_to_import.append(temp_path)
|
||||
else:
|
||||
# Parse JSON request (legacy method with file paths)
|
||||
data = await request.json()
|
||||
model_hash = data.get('model_hash')
|
||||
files_to_import = data.get('file_paths', [])
|
||||
|
||||
if not model_hash:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing model_hash parameter'
|
||||
}, status=400)
|
||||
|
||||
if not files_to_import:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No files provided to import'
|
||||
}, status=400)
|
||||
|
||||
# Get example images path
|
||||
example_images_path = settings.get('example_images_path')
|
||||
if not example_images_path:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No example images path configured'
|
||||
}, status=400)
|
||||
|
||||
# Find the model and get current metadata
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
model_data = None
|
||||
scanner = None
|
||||
|
||||
# Check both scanners to find the model
|
||||
for scan_obj in [lora_scanner, checkpoint_scanner]:
|
||||
cache = await scan_obj.get_cached_data()
|
||||
for item in cache.raw_data:
|
||||
if item.get('sha256') == model_hash:
|
||||
model_data = item
|
||||
scanner = scan_obj
|
||||
break
|
||||
if model_data:
|
||||
break
|
||||
|
||||
if not model_data:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"Model with hash {model_hash} not found in cache"
|
||||
}, status=404)
|
||||
|
||||
# Get current number of images in civitai.images array
|
||||
civitai_data = model_data.get('civitai')
|
||||
current_images = civitai_data.get('images', []) if civitai_data is not None else []
|
||||
next_index = len(current_images)
|
||||
|
||||
# Create model folder
|
||||
model_folder = os.path.join(example_images_path, model_hash)
|
||||
os.makedirs(model_folder, exist_ok=True)
|
||||
|
||||
imported_files = []
|
||||
errors = []
|
||||
newly_imported_paths = []
|
||||
|
||||
# Process each file path
|
||||
for file_path in files_to_import:
|
||||
try:
|
||||
# Ensure file exists
|
||||
if not os.path.isfile(file_path):
|
||||
errors.append(f"File not found: {file_path}")
|
||||
continue
|
||||
|
||||
# Check if file type is supported
|
||||
file_ext = os.path.splitext(file_path)[1].lower()
|
||||
if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
|
||||
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']):
|
||||
errors.append(f"Unsupported file type: {file_path}")
|
||||
continue
|
||||
|
||||
# Generate new filename with sequential index starting from current images length
|
||||
new_filename = f"image_{next_index}{file_ext}"
|
||||
next_index += 1
|
||||
|
||||
dest_path = os.path.join(model_folder, new_filename)
|
||||
|
||||
# Copy the file
|
||||
import shutil
|
||||
shutil.copy2(file_path, dest_path)
|
||||
newly_imported_paths.append(dest_path)
|
||||
|
||||
# Add to imported files list
|
||||
imported_files.append({
|
||||
'name': new_filename,
|
||||
'path': f'/example_images_static/{model_hash}/{new_filename}',
|
||||
'extension': file_ext,
|
||||
'is_video': file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
|
||||
})
|
||||
except Exception as e:
|
||||
errors.append(f"Error importing {file_path}: {str(e)}")
|
||||
|
||||
# Update metadata with new example images
|
||||
updated_images = await ExampleImagesRoutes._update_metadata_after_import(
|
||||
model_hash,
|
||||
model_data,
|
||||
scanner,
|
||||
newly_imported_paths
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
'success': len(imported_files) > 0,
|
||||
'message': f'Successfully imported {len(imported_files)} files' +
|
||||
(f' with {len(errors)} errors' if errors else ''),
|
||||
'files': imported_files,
|
||||
'errors': errors,
|
||||
'updated_images': updated_images,
|
||||
"model_file_path": model_data.get('file_path', ''),
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to import example images: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
finally:
|
||||
# Clean up temporary files if any
|
||||
for temp_file in temp_files_to_cleanup:
|
||||
try:
|
||||
os.remove(temp_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove temporary file {temp_file}: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def _update_metadata_after_import(model_hash, model_data, scanner, newly_imported_paths):
|
||||
"""
|
||||
Update model metadata after importing example images by appending new images to the existing array
|
||||
|
||||
Args:
|
||||
model_hash: SHA256 hash of the model
|
||||
model_data: Model data dictionary
|
||||
scanner: Scanner instance (lora or checkpoint)
|
||||
newly_imported_paths: List of paths to newly imported files
|
||||
|
||||
Returns:
|
||||
list: Updated images array
|
||||
"""
|
||||
try:
|
||||
# Ensure civitai field exists in model data
|
||||
if not model_data.get('civitai'):
|
||||
model_data['civitai'] = {}
|
||||
|
||||
# Ensure images array exists
|
||||
if not model_data['civitai'].get('images'):
|
||||
model_data['civitai']['images'] = []
|
||||
|
||||
# Get current images array
|
||||
images = model_data['civitai']['images']
|
||||
|
||||
# Add new image entries for each imported file
|
||||
for path in newly_imported_paths:
|
||||
# Determine if it's a video or image
|
||||
file_ext = os.path.splitext(path)[1].lower()
|
||||
is_video = file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
|
||||
|
||||
# Create image metadata entry
|
||||
image_entry = {
|
||||
"url": "", # Empty URL as requested
|
||||
"nsfwLevel": 0,
|
||||
"width": 720, # Default dimensions
|
||||
"height": 1280,
|
||||
"type": "video" if is_video else "image",
|
||||
"meta": None,
|
||||
"hasMeta": False,
|
||||
"hasPositivePrompt": False
|
||||
}
|
||||
|
||||
# Try to get actual dimensions if it's an image
|
||||
try:
|
||||
from PIL import Image
|
||||
if not is_video and os.path.exists(path):
|
||||
with Image.open(path) as img:
|
||||
image_entry["width"], image_entry["height"] = img.size
|
||||
except:
|
||||
# If PIL fails or isn't available, use default dimensions
|
||||
pass
|
||||
|
||||
# Append to the existing images array
|
||||
images.append(image_entry)
|
||||
|
||||
# Save metadata to the .metadata.json file
|
||||
file_path = model_data.get('file_path')
|
||||
if file_path:
|
||||
base_path = os.path.splitext(file_path)[0]
|
||||
metadata_path = f"{base_path}.metadata.json"
|
||||
try:
|
||||
# Create a copy of the model data without the 'folder' field
|
||||
model_copy = model_data.copy()
|
||||
model_copy.pop('folder', None)
|
||||
|
||||
# Write the metadata to file
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(model_copy, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved metadata to {metadata_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save metadata to {metadata_path}: {str(e)}")
|
||||
|
||||
# Save updated metadata to scanner cache
|
||||
if file_path:
|
||||
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
||||
|
||||
return images
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update metadata after import: {e}", exc_info=True)
|
||||
return []
|
||||
@@ -289,4 +289,95 @@
|
||||
|
||||
.lazy[src] {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Example Import Area */
|
||||
.example-import-area {
|
||||
margin-top: var(--space-4);
|
||||
padding: var(--space-2);
|
||||
}
|
||||
|
||||
.example-import-area.empty {
|
||||
margin-top: var(--space-2);
|
||||
padding: var(--space-4) var(--space-2);
|
||||
}
|
||||
|
||||
.import-container {
|
||||
border: 2px dashed var(--border-color);
|
||||
border-radius: var(--border-radius-sm);
|
||||
padding: var(--space-4);
|
||||
text-align: center;
|
||||
transition: all 0.3s ease;
|
||||
background: var(--lora-surface);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.import-container.highlight {
|
||||
border-color: var(--lora-accent);
|
||||
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1);
|
||||
transform: scale(1.01);
|
||||
}
|
||||
|
||||
.import-placeholder {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: var(--space-1);
|
||||
padding-top: var(--space-1);
|
||||
}
|
||||
|
||||
.import-placeholder i {
|
||||
font-size: 2.5rem;
|
||||
/* color: var(--lora-accent); */
|
||||
opacity: 0.8;
|
||||
margin-bottom: var(--space-1);
|
||||
}
|
||||
|
||||
.import-placeholder h3 {
|
||||
margin: 0 0 var(--space-1);
|
||||
font-size: 1.2rem;
|
||||
font-weight: 500;
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.import-placeholder p {
|
||||
margin: var(--space-1) 0;
|
||||
color: var(--text-color);
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.import-placeholder .sub-text {
|
||||
font-size: 0.9em;
|
||||
opacity: 0.6;
|
||||
margin: var(--space-1) 0;
|
||||
}
|
||||
|
||||
.import-formats {
|
||||
font-size: 0.8em !important;
|
||||
opacity: 0.6 !important;
|
||||
margin-top: var(--space-2) !important;
|
||||
}
|
||||
|
||||
.select-files-btn {
|
||||
background: var(--lora-accent);
|
||||
color: var(--lora-text);
|
||||
border: none;
|
||||
border-radius: var(--border-radius-xs);
|
||||
padding: var(--space-2) var(--space-3);
|
||||
cursor: pointer;
|
||||
font-size: 0.9em;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.select-files-btn:hover {
|
||||
opacity: 0.9;
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
/* For dark theme */
|
||||
[data-theme="dark"] .import-container {
|
||||
background: rgba(255, 255, 255, 0.03);
|
||||
}
|
||||
@@ -17,7 +17,10 @@ import { NSFW_LEVELS } from '../../utils/constants.js';
|
||||
* @returns {Promise<string>} HTML内容
|
||||
*/
|
||||
export function renderShowcaseContent(images, exampleFiles = []) {
|
||||
if (!images?.length) return '<div class="no-examples">No example images available</div>';
|
||||
if (!images?.length) {
|
||||
// Replace empty message with import interface
|
||||
return renderImportInterface(true);
|
||||
}
|
||||
|
||||
// Filter images based on SFW setting
|
||||
const showOnlySFW = state.settings.show_only_sfw;
|
||||
@@ -136,10 +139,202 @@ export function renderShowcaseContent(images, exampleFiles = []) {
|
||||
);
|
||||
}).join('')}
|
||||
</div>
|
||||
|
||||
<!-- Add import interface at the bottom of existing examples -->
|
||||
${renderImportInterface(false)}
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Render the import interface for example images
|
||||
* @param {boolean} isEmpty - Whether there are no existing examples
|
||||
* @returns {string} HTML content for import interface
|
||||
*/
|
||||
function renderImportInterface(isEmpty) {
|
||||
return `
|
||||
<div class="example-import-area ${isEmpty ? 'empty' : ''}">
|
||||
<div class="import-container" id="exampleImportContainer">
|
||||
<div class="import-placeholder">
|
||||
<i class="fas fa-cloud-upload-alt"></i>
|
||||
<h3>${isEmpty ? 'No example images available' : 'Add more examples'}</h3>
|
||||
<p>Drag & drop images or videos here</p>
|
||||
<p class="sub-text">or</p>
|
||||
<button class="select-files-btn" id="selectExampleFilesBtn">
|
||||
<i class="fas fa-folder-open"></i> Select Files
|
||||
</button>
|
||||
<p class="import-formats">Supported formats: jpg, png, gif, webp, mp4, webm</p>
|
||||
</div>
|
||||
<input type="file" id="exampleFilesInput" multiple accept="image/*,video/mp4,video/webm" style="display: none;">
|
||||
<div class="import-progress-container" style="display: none;">
|
||||
<div class="import-progress">
|
||||
<div class="progress-bar"></div>
|
||||
</div>
|
||||
<span class="progress-text">Importing files...</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the import functionality for example images
|
||||
* @param {string} modelHash - The SHA256 hash of the model
|
||||
* @param {Element} container - The container element for the import area
|
||||
*/
|
||||
export function initExampleImport(modelHash, container) {
|
||||
if (!container) return;
|
||||
|
||||
const importContainer = container.querySelector('#exampleImportContainer');
|
||||
const fileInput = container.querySelector('#exampleFilesInput');
|
||||
const selectFilesBtn = container.querySelector('#selectExampleFilesBtn');
|
||||
|
||||
// Set up file selection button
|
||||
if (selectFilesBtn) {
|
||||
selectFilesBtn.addEventListener('click', () => {
|
||||
fileInput.click();
|
||||
});
|
||||
}
|
||||
|
||||
// Handle file selection
|
||||
if (fileInput) {
|
||||
fileInput.addEventListener('change', (e) => {
|
||||
if (e.target.files.length > 0) {
|
||||
handleImportFiles(Array.from(e.target.files), modelHash, importContainer);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Set up drag and drop
|
||||
if (importContainer) {
|
||||
['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => {
|
||||
importContainer.addEventListener(eventName, preventDefaults, false);
|
||||
});
|
||||
|
||||
function preventDefaults(e) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
}
|
||||
|
||||
// Highlight drop area on drag over
|
||||
['dragenter', 'dragover'].forEach(eventName => {
|
||||
importContainer.addEventListener(eventName, () => {
|
||||
importContainer.classList.add('highlight');
|
||||
}, false);
|
||||
});
|
||||
|
||||
// Remove highlight on drag leave
|
||||
['dragleave', 'drop'].forEach(eventName => {
|
||||
importContainer.addEventListener(eventName, () => {
|
||||
importContainer.classList.remove('highlight');
|
||||
}, false);
|
||||
});
|
||||
|
||||
// Handle dropped files
|
||||
importContainer.addEventListener('drop', (e) => {
|
||||
const files = Array.from(e.dataTransfer.files);
|
||||
handleImportFiles(files, modelHash, importContainer);
|
||||
}, false);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle the file import process
|
||||
* @param {File[]} files - Array of files to import
|
||||
* @param {string} modelHash - The SHA256 hash of the model
|
||||
* @param {Element} importContainer - The container element for import UI
|
||||
*/
|
||||
async function handleImportFiles(files, modelHash, importContainer) {
|
||||
// Filter for supported file types
|
||||
const supportedImages = ['.jpg', '.jpeg', '.png', '.gif', '.webp'];
|
||||
const supportedVideos = ['.mp4', '.webm'];
|
||||
const supportedExtensions = [...supportedImages, ...supportedVideos];
|
||||
|
||||
const validFiles = files.filter(file => {
|
||||
const ext = '.' + file.name.split('.').pop().toLowerCase();
|
||||
return supportedExtensions.includes(ext);
|
||||
});
|
||||
|
||||
if (validFiles.length === 0) {
|
||||
alert('No supported files selected. Please select image or video files.');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Get file paths to send to backend
|
||||
const filePaths = validFiles.map(file => {
|
||||
// We need the full path, but we only have the filename
|
||||
// For security reasons, browsers don't provide full paths
|
||||
// This will only work if the backend can handle just filenames
|
||||
return URL.createObjectURL(file);
|
||||
});
|
||||
|
||||
// Use FileReader to get the file data for direct upload
|
||||
const formData = new FormData();
|
||||
formData.append('model_hash', modelHash);
|
||||
|
||||
validFiles.forEach(file => {
|
||||
formData.append('files', file);
|
||||
});
|
||||
|
||||
// Call API to import files
|
||||
const response = await fetch('/api/import-example-images', {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
});
|
||||
|
||||
const result = await response.json();
|
||||
|
||||
if (!result.success) {
|
||||
throw new Error(result.error || 'Failed to import example files');
|
||||
}
|
||||
|
||||
// Get updated local files
|
||||
const updatedFilesResponse = await fetch(`/api/example-image-files?model_hash=${modelHash}`);
|
||||
const updatedFilesResult = await updatedFilesResponse.json();
|
||||
|
||||
if (!updatedFilesResult.success) {
|
||||
throw new Error(updatedFilesResult.error || 'Failed to get updated file list');
|
||||
}
|
||||
|
||||
// Re-render the showcase content
|
||||
const showcaseTab = document.getElementById('showcase-tab');
|
||||
if (showcaseTab) {
|
||||
// Get the updated images from the result
|
||||
const updatedImages = result.updated_images || [];
|
||||
showcaseTab.innerHTML = renderShowcaseContent(updatedImages, updatedFilesResult.files);
|
||||
|
||||
// Re-initialize showcase functionality
|
||||
const carousel = showcaseTab.querySelector('.carousel');
|
||||
if (carousel) {
|
||||
if (!carousel.classList.contains('collapsed')) {
|
||||
initLazyLoading(carousel);
|
||||
initNsfwBlurHandlers(carousel);
|
||||
initMetadataPanelHandlers(carousel);
|
||||
}
|
||||
// Initialize the import UI for the new content
|
||||
initExampleImport(modelHash, showcaseTab);
|
||||
}
|
||||
|
||||
// Update VirtualScroller if available
|
||||
if (state.virtualScroller && result.model_file_path) {
|
||||
// Create an update object with only the necessary properties
|
||||
const updateData = {
|
||||
civitai: {
|
||||
images: updatedImages
|
||||
}
|
||||
};
|
||||
|
||||
// Update the item in the virtual scroller
|
||||
state.virtualScroller.updateSingleItem(result.model_file_path, updateData);
|
||||
console.log('Updated VirtualScroller item with new example images');
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error importing examples:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate metadata panel HTML
|
||||
*/
|
||||
|
||||
@@ -5,7 +5,13 @@
|
||||
*/
|
||||
import { showToast, copyToClipboard, getExampleImageFiles } from '../../utils/uiHelpers.js';
|
||||
import { modalManager } from '../../managers/ModalManager.js';
|
||||
import { renderShowcaseContent, toggleShowcase, setupShowcaseScroll, scrollToTop } from './ShowcaseView.js';
|
||||
import {
|
||||
renderShowcaseContent,
|
||||
toggleShowcase,
|
||||
setupShowcaseScroll,
|
||||
scrollToTop,
|
||||
initExampleImport
|
||||
} from './ShowcaseView.js';
|
||||
import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
|
||||
import { renderTriggerWords, setupTriggerWordsEditMode } from './TriggerWords.js';
|
||||
import { parsePresets, renderPresetTags } from './PresetTags.js';
|
||||
@@ -207,14 +213,8 @@ async function loadExampleImages(images, modelHash, filePath) {
|
||||
let localFiles = [];
|
||||
|
||||
try {
|
||||
// Choose endpoint based on centralized examples setting
|
||||
const useCentralized = state.global.settings.useCentralizedExamples !== false;
|
||||
const endpoint = useCentralized ? '/api/example-image-files' : '/api/model-example-files';
|
||||
|
||||
// Use different params based on endpoint
|
||||
const params = useCentralized ?
|
||||
`model_hash=${modelHash}` :
|
||||
`file_path=${encodeURIComponent(filePath)}`;
|
||||
const endpoint = '/api/example-image-files';
|
||||
const params = `model_hash=${modelHash}`;
|
||||
|
||||
const response = await fetch(`${endpoint}?${params}`);
|
||||
const result = await response.json();
|
||||
@@ -239,6 +239,9 @@ async function loadExampleImages(images, modelHash, filePath) {
|
||||
initMetadataPanelHandlers(carousel);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the example import functionality
|
||||
initExampleImport(modelHash, showcaseTab);
|
||||
} catch (error) {
|
||||
console.error('Error loading example images:', error);
|
||||
const showcaseTab = document.getElementById('showcase-tab');
|
||||
|
||||
Reference in New Issue
Block a user