mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 14:12:11 -03:00
feat(example-images): add use case orchestration
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import random
|
||||
import string
|
||||
from aiohttp import web
|
||||
@@ -13,6 +12,14 @@ from ..utils.metadata_manager import MetadataManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExampleImagesImportError(RuntimeError):
|
||||
"""Base error for example image import operations."""
|
||||
|
||||
|
||||
class ExampleImagesValidationError(ExampleImagesImportError):
|
||||
"""Raised when input validation fails."""
|
||||
|
||||
class ExampleImagesProcessor:
|
||||
"""Processes and manipulates example images"""
|
||||
|
||||
@@ -299,90 +306,29 @@ class ExampleImagesProcessor:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def import_images(request):
|
||||
"""
|
||||
Import local example images
|
||||
|
||||
Accepts:
|
||||
- multipart/form-data form with model_hash and files fields
|
||||
or
|
||||
- JSON request with model_hash and file_paths
|
||||
|
||||
Returns:
|
||||
- Success status and list of imported files
|
||||
"""
|
||||
async def import_images(model_hash: str, files_to_import: list[str]):
|
||||
"""Import local example images for a model."""
|
||||
|
||||
if not model_hash:
|
||||
raise ExampleImagesValidationError('Missing model_hash parameter')
|
||||
|
||||
if not files_to_import:
|
||||
raise ExampleImagesValidationError('No files provided to import')
|
||||
|
||||
try:
|
||||
model_hash = None
|
||||
files_to_import = []
|
||||
temp_files_to_cleanup = []
|
||||
|
||||
# Check if it's a multipart form-data request (direct file upload)
|
||||
if request.content_type and 'multipart/form-data' in request.content_type:
|
||||
reader = await request.multipart()
|
||||
|
||||
# First get model_hash
|
||||
field = await reader.next()
|
||||
if field.name == 'model_hash':
|
||||
model_hash = await field.text()
|
||||
|
||||
# Then process all files
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
if field.name == 'files':
|
||||
# Create a temporary file with appropriate suffix for type detection
|
||||
file_name = field.filename
|
||||
file_ext = os.path.splitext(file_name)[1].lower()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
|
||||
temp_path = tmp_file.name
|
||||
temp_files_to_cleanup.append(temp_path) # Track for cleanup
|
||||
|
||||
# Write chunks to the temporary file
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
tmp_file.write(chunk)
|
||||
|
||||
# Add to the list of files to process
|
||||
files_to_import.append(temp_path)
|
||||
else:
|
||||
# Parse JSON request (legacy method using file paths)
|
||||
data = await request.json()
|
||||
model_hash = data.get('model_hash')
|
||||
files_to_import = data.get('file_paths', [])
|
||||
|
||||
if not model_hash:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing model_hash parameter'
|
||||
}, status=400)
|
||||
|
||||
if not files_to_import:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No files provided to import'
|
||||
}, status=400)
|
||||
|
||||
# Get example images path
|
||||
example_images_path = settings.get('example_images_path')
|
||||
if not example_images_path:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No example images path configured'
|
||||
}, status=400)
|
||||
|
||||
raise ExampleImagesValidationError('No example images path configured')
|
||||
|
||||
# Find the model and get current metadata
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
|
||||
model_data = None
|
||||
scanner = None
|
||||
|
||||
|
||||
# Check both scanners to find the model
|
||||
for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]:
|
||||
cache = await scan_obj.get_cached_data()
|
||||
@@ -393,21 +339,20 @@ class ExampleImagesProcessor:
|
||||
break
|
||||
if model_data:
|
||||
break
|
||||
|
||||
|
||||
if not model_data:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"Model with hash {model_hash} not found in cache"
|
||||
}, status=404)
|
||||
|
||||
raise ExampleImagesImportError(
|
||||
f"Model with hash {model_hash} not found in cache"
|
||||
)
|
||||
|
||||
# Create model folder
|
||||
model_folder = os.path.join(example_images_path, model_hash)
|
||||
os.makedirs(model_folder, exist_ok=True)
|
||||
|
||||
|
||||
imported_files = []
|
||||
errors = []
|
||||
newly_imported_paths = []
|
||||
|
||||
|
||||
# Process each file path
|
||||
for file_path in files_to_import:
|
||||
try:
|
||||
@@ -415,26 +360,26 @@ class ExampleImagesProcessor:
|
||||
if not os.path.isfile(file_path):
|
||||
errors.append(f"File not found: {file_path}")
|
||||
continue
|
||||
|
||||
|
||||
# Check if file type is supported
|
||||
file_ext = os.path.splitext(file_path)[1].lower()
|
||||
if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
|
||||
if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
|
||||
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']):
|
||||
errors.append(f"Unsupported file type: {file_path}")
|
||||
continue
|
||||
|
||||
|
||||
# Generate new filename using short ID instead of UUID
|
||||
short_id = ExampleImagesProcessor.generate_short_id()
|
||||
new_filename = f"custom_{short_id}{file_ext}"
|
||||
|
||||
|
||||
dest_path = os.path.join(model_folder, new_filename)
|
||||
|
||||
|
||||
# Copy the file
|
||||
import shutil
|
||||
shutil.copy2(file_path, dest_path)
|
||||
# Store both the dest_path and the short_id
|
||||
newly_imported_paths.append((dest_path, short_id))
|
||||
|
||||
|
||||
# Add to imported files list
|
||||
imported_files.append({
|
||||
'name': new_filename,
|
||||
@@ -444,39 +389,31 @@ class ExampleImagesProcessor:
|
||||
})
|
||||
except Exception as e:
|
||||
errors.append(f"Error importing {file_path}: {str(e)}")
|
||||
|
||||
|
||||
# Update metadata with new example images
|
||||
regular_images, custom_images = await MetadataUpdater.update_metadata_after_import(
|
||||
model_hash,
|
||||
model_hash,
|
||||
model_data,
|
||||
scanner,
|
||||
newly_imported_paths
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
|
||||
return {
|
||||
'success': len(imported_files) > 0,
|
||||
'message': f'Successfully imported {len(imported_files)} files' +
|
||||
'message': f'Successfully imported {len(imported_files)} files' +
|
||||
(f' with {len(errors)} errors' if errors else ''),
|
||||
'files': imported_files,
|
||||
'errors': errors,
|
||||
'regular_images': regular_images,
|
||||
'custom_images': custom_images,
|
||||
"model_file_path": model_data.get('file_path', ''),
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
except ExampleImagesImportError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to import example images: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
finally:
|
||||
# Clean up temporary files
|
||||
for temp_file in temp_files_to_cleanup:
|
||||
try:
|
||||
os.remove(temp_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove temporary file {temp_file}: {e}")
|
||||
raise ExampleImagesImportError(str(e)) from e
|
||||
|
||||
@staticmethod
|
||||
async def delete_custom_image(request):
|
||||
|
||||
Reference in New Issue
Block a user