feat: add model metadata refresh functionality and enhance download progress tracking. https://github.com/willmiao/ComfyUI-Lora-Manager/issues/151

This commit is contained in:
Will Miao
2025-05-01 18:57:29 +08:00
parent 5cd5a82ddc
commit 9dbcc105e7

View File

@@ -10,6 +10,8 @@ from ..utils.usage_stats import UsageStats
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from ..utils.constants import EXAMPLE_IMAGE_WIDTH, SUPPORTED_MEDIA_EXTENSIONS from ..utils.constants import EXAMPLE_IMAGE_WIDTH, SUPPORTED_MEDIA_EXTENSIONS
from ..services.civitai_client import CivitaiClient
from ..utils.routes_common import ModelRouteUtils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,7 +27,8 @@ download_progress = {
'last_error': None, 'last_error': None,
'start_time': None, 'start_time': None,
'end_time': None, 'end_time': None,
'processed_models': set() # Track models that have been processed 'processed_models': set(), # Track models that have been processed
'refreshed_models': set() # Track models that had metadata refreshed
} }
class MiscRoutes: class MiscRoutes:
@@ -149,6 +152,7 @@ class MiscRoutes:
# Create a copy for JSON serialization # Create a copy for JSON serialization
response_progress = download_progress.copy() response_progress = download_progress.copy()
response_progress['processed_models'] = list(download_progress['processed_models']) response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
return web.json_response({ return web.json_response({
'success': False, 'success': False,
@@ -211,6 +215,7 @@ class MiscRoutes:
# Create a copy for JSON serialization # Create a copy for JSON serialization
response_progress = download_progress.copy() response_progress = download_progress.copy()
response_progress['processed_models'] = list(download_progress['processed_models']) response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
return web.json_response({ return web.json_response({
'success': True, 'success': True,
@@ -233,6 +238,7 @@ class MiscRoutes:
# Create a copy of the progress dict with the set converted to a list for JSON serialization # Create a copy of the progress dict with the set converted to a list for JSON serialization
response_progress = download_progress.copy() response_progress = download_progress.copy()
response_progress['processed_models'] = list(download_progress['processed_models']) response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
return web.json_response({ return web.json_response({
'success': True, 'success': True,
@@ -282,6 +288,259 @@ class MiscRoutes:
'error': f"Download is in '{download_progress['status']}' state, cannot resume" 'error': f"Download is in '{download_progress['status']}' state, cannot resume"
}, status=400) }, status=400)
@staticmethod
async def _refresh_model_metadata(model_hash, model_name, scanner_type, scanner):
    """Refresh a single model's metadata from CivitAI.

    The model is looked up by hash in the scanner cache. Once found (and it
    has a file path), its hash is recorded in
    ``download_progress['refreshed_models']`` *before* the fetch is attempted,
    and the fetch itself is delegated to
    ``ModelRouteUtils.fetch_and_update_model``.

    Args:
        model_hash: SHA256 hash of the model
        model_name: Name of the model (only used in log/error messages)
        scanner_type: Type of scanner ('lora' or 'checkpoint'); not used by
            this function, kept for interface compatibility
        scanner: Scanner instance for this model type

    Returns:
        bool: True if metadata was successfully refreshed, False otherwise
        (model not in cache, no file path, fetch failed, or an exception
        was raised; exceptions are logged and appended to
        ``download_progress['errors']``).
    """
    global download_progress

    def _hash_matches(cached_hash):
        # Exact match first (original behaviour), then a case-insensitive
        # match: callers lowercase the hash before passing it in, while the
        # cache may hold it in a different letter case.
        if cached_hash == model_hash:
            return True
        return (
            isinstance(cached_hash, str)
            and isinstance(model_hash, str)
            and cached_hash.lower() == model_hash.lower()
        )

    try:
        # Find the model in the scanner cache
        cache = await scanner.get_cached_data()
        model_data = next(
            (item for item in cache.raw_data if _hash_matches(item.get('sha256'))),
            None,
        )

        if not model_data:
            logger.warning(f"Model {model_name} with hash {model_hash} not found in cache")
            return False

        file_path = model_data.get('file_path')
        if not file_path:
            logger.warning(f"Model {model_name} has no file path")
            return False

        # Track that we're refreshing this model
        download_progress['refreshed_models'].add(model_hash)

        # Adapter handed to ModelRouteUtils so it can push the refreshed
        # metadata back into this scanner's cache.
        async def update_cache_func(old_path, new_path, metadata):
            return await scanner.update_single_model_cache(old_path, new_path, metadata)

        success = await ModelRouteUtils.fetch_and_update_model(
            model_hash,
            file_path,
            model_data,
            update_cache_func
        )

        if success:
            logger.info(f"Successfully refreshed metadata for {model_name}")
            return True

        logger.warning(f"Failed to refresh metadata for {model_name}")
        return False

    except Exception as e:
        # Best-effort operation: record the failure and report False rather
        # than aborting the caller.
        error_msg = f"Error refreshing metadata for {model_name}: {str(e)}"
        logger.error(error_msg, exc_info=True)
        download_progress['errors'].append(error_msg)
        download_progress['last_error'] = error_msg
        return False
@staticmethod
async def _process_model_images(model_hash, model_name, model_images, model_dir, optimize, independent_session, delay):
    """Process and download images for a single model

    Args:
        model_hash: SHA256 hash of the model (not used by this function)
        model_name: Name of the model (used in log messages)
        model_images: List of image objects from CivitAI; each is read via
            ``image.get('url')``
        model_dir: Directory to save images to
        optimize: Whether to optimize images (re-encode to webp)
        independent_session: aiohttp session for downloads
        delay: Delay in seconds between downloads

    Returns:
        tuple: ``(success, is_stale_metadata)``. ``success`` is True when no
        download failed. ``is_stale_metadata`` is True only when a download
        returned HTTP 404, in which case processing stops immediately so the
        caller can refresh the model metadata and retry.
    """
    global download_progress

    model_success = True

    for i, image in enumerate(model_images, 1):
        image_url = image.get('url')
        if not image_url:
            continue

        # Get image filename from URL (query string stripped)
        image_filename = os.path.basename(image_url.split('?')[0])
        image_ext = os.path.splitext(image_filename)[1].lower()

        # Handle both images and videos
        is_image = image_ext in SUPPORTED_MEDIA_EXTENSIONS['images']
        is_video = image_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']

        if not (is_image or is_video):
            logger.debug(f"Skipping unsupported file type: {image_filename}")
            continue

        save_filename = f"image_{i}{image_ext}"
        save_path = os.path.join(model_dir, save_filename)

        # Check if already downloaded. Optimized images are stored with a
        # .webp extension, so that name has to be checked as well; otherwise
        # every retry would download them again.
        candidate_paths = [save_path]
        if is_image and optimize:
            candidate_paths.append(os.path.join(model_dir, f"image_{i}.webp"))
        existing_path = next((p for p in candidate_paths if os.path.exists(p)), None)
        if existing_path:
            logger.debug(f"File already exists: {existing_path}")
            continue

        # Download the file
        try:
            logger.debug(f"Downloading {save_filename} for {model_name}")

            # Direct download using the independent session
            async with independent_session.get(image_url, timeout=60) as response:
                if response.status == 200:
                    if is_image and optimize:
                        # For images, optimize if requested
                        image_data = await response.read()
                        optimized_data, ext = ExifUtils.optimize_image(
                            image_data,
                            target_width=EXAMPLE_IMAGE_WIDTH,
                            format='webp',
                            quality=85,
                            preserve_metadata=False
                        )

                        # Update save filename if format changed
                        if ext == '.webp':
                            save_filename = os.path.splitext(save_filename)[0] + '.webp'
                            save_path = os.path.join(model_dir, save_filename)

                        # Save the optimized image
                        with open(save_path, 'wb') as f:
                            f.write(optimized_data)
                    else:
                        # For videos or unoptimized images, stream to a
                        # temporary file and move it into place only once
                        # complete. An interrupted download must not leave a
                        # truncated file that later runs treat as finished.
                        partial_path = save_path + '.part'
                        finished = False
                        try:
                            with open(partial_path, 'wb') as f:
                                async for chunk in response.content.iter_chunked(8192):
                                    if chunk:
                                        f.write(chunk)
                            os.replace(partial_path, save_path)
                            finished = True
                        finally:
                            if not finished:
                                try:
                                    os.remove(partial_path)
                                except OSError:
                                    # Nothing to clean up (or not removable);
                                    # the original error still propagates.
                                    pass
                elif response.status == 404:
                    error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale"
                    logger.warning(error_msg)
                    download_progress['errors'].append(error_msg)
                    download_progress['last_error'] = error_msg
                    # Return early to trigger metadata refresh attempt
                    return False, True  # (success, is_stale_metadata)
                else:
                    error_msg = f"Failed to download file: {image_url}, status code: {response.status}"
                    logger.warning(error_msg)
                    download_progress['errors'].append(error_msg)
                    download_progress['last_error'] = error_msg
                    model_success = False  # Mark model as failed

            # Add a delay between downloads for remote files only
            await asyncio.sleep(delay)
        except Exception as e:
            # One failed file must not abort the remaining downloads
            error_msg = f"Error downloading file {image_url}: {str(e)}"
            logger.error(error_msg)
            download_progress['errors'].append(error_msg)
            download_progress['last_error'] = error_msg
            model_success = False  # Mark model as failed

    return model_success, False  # (success, is_stale_metadata)
@staticmethod
async def _process_local_example_images(model_file_path, model_file_name, model_name, model_dir, optimize):
    """Process local example images for a model

    Scans the directory containing the model file for files named
    ``<model_file_name>.example.*`` with a supported image/video extension
    and writes them to ``model_dir`` as ``image_<n><ext>`` (optimized images
    are re-encoded and stored as ``.webp``).

    Args:
        model_file_path: Path to the model file
        model_file_name: Filename of the model (prefix of the example files)
        model_name: Name of the model (used in log/error messages)
        model_dir: Directory to save processed images to
        optimize: Whether to optimize images

    Returns:
        bool: True if at least one local example file was found and the
        whole batch was processed without an exception, False otherwise
        (no local examples, or an error which is logged and appended to
        ``download_progress['errors']``).
    """
    global download_progress

    try:
        model_dir_path = os.path.dirname(model_file_path)
        local_images = []

        # Look for files with pattern: filename.example.*.ext
        if model_file_name:
            example_prefix = f"{model_file_name}.example.".lower()

            if os.path.exists(model_dir_path):
                # os.listdir returns entries in arbitrary order; sort so the
                # image_<n> numbering (and the skip-if-exists check that
                # depends on it) is stable from one run to the next.
                for file in sorted(os.listdir(model_dir_path)):
                    file_lower = file.lower()
                    if not file_lower.startswith(example_prefix):
                        continue
                    file_ext = os.path.splitext(file_lower)[1]
                    is_supported = (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
                                    file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos'])
                    if is_supported:
                        local_images.append(os.path.join(model_dir_path, file))

        # Process local images if found
        if local_images:
            logger.info(f"Found {len(local_images)} local example images for {model_name}")

            for i, local_image_path in enumerate(local_images, 1):
                local_ext = os.path.splitext(local_image_path)[1].lower()
                save_filename = f"image_{i}{local_ext}"
                save_path = os.path.join(model_dir, save_filename)

                # Handle image processing based on file type and optimize setting
                is_image = local_ext in SUPPORTED_MEDIA_EXTENSIONS['images']

                # Skip if already exists in output directory. Optimized
                # images are stored as .webp, so check that name too.
                candidate_paths = [save_path]
                if is_image and optimize:
                    candidate_paths.append(os.path.join(model_dir, f"image_{i}.webp"))
                existing_path = next((p for p in candidate_paths if os.path.exists(p)), None)
                if existing_path:
                    logger.debug(f"File already exists in output: {existing_path}")
                    continue

                if is_image and optimize:
                    # Optimize the image
                    with open(local_image_path, 'rb') as img_file:
                        image_data = img_file.read()

                    optimized_data, ext = ExifUtils.optimize_image(
                        image_data,
                        target_width=EXAMPLE_IMAGE_WIDTH,
                        format='webp',
                        quality=85,
                        preserve_metadata=False
                    )

                    # Update save filename if format changed
                    if ext == '.webp':
                        save_filename = os.path.splitext(save_filename)[0] + '.webp'
                        save_path = os.path.join(model_dir, save_filename)

                    # Save the optimized image
                    with open(save_path, 'wb') as f:
                        f.write(optimized_data)
                else:
                    # For videos or unoptimized images, copy directly
                    with open(local_image_path, 'rb') as src_file:
                        with open(save_path, 'wb') as dst_file:
                            dst_file.write(src_file.read())

            return True

        return False

    except Exception as e:
        # Best-effort: report the failure so the caller can fall back to
        # downloading the examples remotely.
        error_msg = f"Error processing local examples for {model_name}: {str(e)}"
        logger.error(error_msg)
        download_progress['errors'].append(error_msg)
        download_progress['last_error'] = error_msg
        return False
@staticmethod @staticmethod
async def _download_all_example_images(output_dir, optimize, model_types, delay): async def _download_all_example_images(output_dir, optimize, model_types, delay):
"""Download example images for all models """Download example images for all models
@@ -330,14 +589,14 @@ class MiscRoutes:
for model in cache.raw_data: for model in cache.raw_data:
# Only process models with images and a valid sha256 # Only process models with images and a valid sha256
if model.get('civitai') and model.get('civitai', {}).get('images') and model.get('sha256'): if model.get('civitai') and model.get('civitai', {}).get('images') and model.get('sha256'):
all_models.append((scanner_type, model)) all_models.append((scanner_type, model, scanner))
# Update total count # Update total count
download_progress['total'] = len(all_models) download_progress['total'] = len(all_models)
logger.info(f"Found {download_progress['total']} models with example images") logger.info(f"Found {download_progress['total']} models with example images")
# Process each model # Process each model
for scanner_type, model in all_models: for scanner_type, model, scanner in all_models:
# Check if download is paused # Check if download is paused
while download_progress['status'] == 'paused': while download_progress['status'] == 'paused':
await asyncio.sleep(1) await asyncio.sleep(1)
@@ -347,14 +606,13 @@ class MiscRoutes:
logger.info(f"Download stopped: {download_progress['status']}") logger.info(f"Download stopped: {download_progress['status']}")
break break
model_success = True # Track if all images for this model download successfully model_hash = model.get('sha256', '').lower()
model_name = model.get('model_name', 'Unknown')
model_file_path = model.get('file_path', '')
model_file_name = model.get('file_name', '')
try: try:
# Update current model info # Update current model info
model_hash = model.get('sha256', '').lower()
model_name = model.get('model_name', 'Unknown')
model_file_path = model.get('file_path', '')
model_file_name = model.get('file_name', '')
download_progress['current_model'] = f"{model_name} ({model_hash[:8]})" download_progress['current_model'] = f"{model_name} ({model_hash[:8]})"
# Skip if already processed # Skip if already processed
@@ -379,156 +637,69 @@ class MiscRoutes:
# First check if we have local example images for this model # First check if we have local example images for this model
local_images_processed = False local_images_processed = False
if model_file_path: if model_file_path:
try: local_images_processed = await MiscRoutes._process_local_example_images(
model_dir_path = os.path.dirname(model_file_path) model_file_path,
local_images = [] model_file_name,
model_name,
# Look for files with pattern: filename.example.*.ext model_dir,
if model_file_name: optimize
example_prefix = f"{model_file_name}.example." )
if os.path.exists(model_dir_path):
for file in os.listdir(model_dir_path):
file_lower = file.lower()
if file_lower.startswith(example_prefix.lower()):
file_ext = os.path.splitext(file_lower)[1]
is_supported = (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos'])
if is_supported:
local_images.append(os.path.join(model_dir_path, file))
# Process local images if found
if local_images:
logger.info(f"Found {len(local_images)} local example images for {model_name}")
for i, local_image_path in enumerate(local_images, 1):
local_ext = os.path.splitext(local_image_path)[1].lower()
save_filename = f"image_{i}{local_ext}"
save_path = os.path.join(model_dir, save_filename)
# Skip if already exists in output directory
if os.path.exists(save_path):
logger.debug(f"File already exists in output: {save_path}")
continue
# Handle image processing based on file type and optimize setting
is_image = local_ext in SUPPORTED_MEDIA_EXTENSIONS['images']
if is_image and optimize:
# Optimize the image
with open(local_image_path, 'rb') as img_file:
image_data = img_file.read()
optimized_data, ext = ExifUtils.optimize_image(
image_data,
target_width=EXAMPLE_IMAGE_WIDTH,
format='webp',
quality=85,
preserve_metadata=False
)
# Update save filename if format changed
if ext == '.webp':
save_filename = os.path.splitext(save_filename)[0] + '.webp'
save_path = os.path.join(model_dir, save_filename)
# Save the optimized image
with open(save_path, 'wb') as f:
f.write(optimized_data)
else:
# For videos or unoptimized images, copy directly
with open(local_image_path, 'rb') as src_file:
with open(save_path, 'wb') as dst_file:
dst_file.write(src_file.read())
# Mark as successfully processed if all local images were processed
download_progress['processed_models'].add(model_hash)
local_images_processed = True
logger.info(f"Successfully processed local examples for {model_name}")
except Exception as e: if local_images_processed:
error_msg = f"Error processing local examples for {model_name}: {str(e)}" # Mark as successfully processed if all local images were processed
logger.error(error_msg) download_progress['processed_models'].add(model_hash)
download_progress['errors'].append(error_msg) logger.info(f"Successfully processed local examples for {model_name}")
download_progress['last_error'] = error_msg
# Continue to remote download if local processing fails
# If we didn't process local images, download from remote # If we didn't process local images, download from remote
if not local_images_processed: if not local_images_processed:
# Download example images # Try to download images
for i, image in enumerate(images, 1): model_success, is_stale_metadata = await MiscRoutes._process_model_images(
image_url = image.get('url') model_hash,
if not image_url: model_name,
continue images,
model_dir,
optimize,
independent_session,
delay
)
# If metadata is stale (404 error), try to refresh it and download again
if is_stale_metadata and model_hash not in download_progress['refreshed_models']:
logger.info(f"Metadata seems stale for {model_name}, attempting to refresh...")
# Get image filename from URL # Refresh metadata from CivitAI
image_filename = os.path.basename(image_url.split('?')[0]) refresh_success = await MiscRoutes._refresh_model_metadata(
image_ext = os.path.splitext(image_filename)[1].lower() model_hash,
model_name,
scanner_type,
scanner
)
# Handle both images and videos if refresh_success:
is_image = image_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] # Get updated model data
is_video = image_ext in SUPPORTED_MEDIA_EXTENSIONS['videos'] updated_cache = await scanner.get_cached_data()
updated_model = None
if not (is_image or is_video):
logger.debug(f"Skipping unsupported file type: {image_filename}")
continue
save_filename = f"image_{i}{image_ext}"
# Check if already downloaded
save_path = os.path.join(model_dir, save_filename)
if os.path.exists(save_path):
logger.debug(f"File already exists: {save_path}")
continue
# Download the file
try:
logger.debug(f"Downloading {save_filename} for {model_name}")
# Direct download using the independent session for item in updated_cache.raw_data:
async with independent_session.get(image_url, timeout=60) as response: if item.get('sha256') == model_hash:
if response.status == 200: updated_model = item
if is_image and optimize: break
# For images, optimize if requested
image_data = await response.read()
optimized_data, ext = ExifUtils.optimize_image(
image_data,
target_width=EXAMPLE_IMAGE_WIDTH,
format='webp',
quality=85,
preserve_metadata=False
)
# Update save filename if format changed
if ext == '.webp':
save_filename = os.path.splitext(save_filename)[0] + '.webp'
save_path = os.path.join(model_dir, save_filename)
# Save the optimized image
with open(save_path, 'wb') as f:
f.write(optimized_data)
else:
# For videos or unoptimized images, save directly
with open(save_path, 'wb') as f:
async for chunk in response.content.iter_chunked(8192):
if chunk:
f.write(chunk)
else:
error_msg = f"Failed to download file: {image_url}, status code: {response.status}"
logger.warning(error_msg)
download_progress['errors'].append(error_msg)
download_progress['last_error'] = error_msg
model_success = False # Mark model as failed
# Add a delay between downloads for remote files only if updated_model and updated_model.get('civitai', {}).get('images'):
await asyncio.sleep(delay) # Try downloading with updated metadata
except Exception as e: logger.info(f"Retrying download with refreshed metadata for {model_name}")
error_msg = f"Error downloading file {image_url}: {str(e)}" updated_images = updated_model.get('civitai', {}).get('images', [])
logger.error(error_msg)
download_progress['errors'].append(error_msg) # Retry download with new images
download_progress['last_error'] = error_msg model_success, _ = await MiscRoutes._process_model_images(
model_success = False # Mark model as failed model_hash,
model_name,
updated_images,
model_dir,
optimize,
independent_session,
delay
)
# Only mark model as processed if all images downloaded successfully # Only mark model as processed if all images downloaded successfully
if model_success: if model_success:
@@ -542,6 +713,7 @@ class MiscRoutes:
with open(progress_file, 'w', encoding='utf-8') as f: with open(progress_file, 'w', encoding='utf-8') as f:
json.dump({ json.dump({
'processed_models': list(download_progress['processed_models']), 'processed_models': list(download_progress['processed_models']),
'refreshed_models': list(download_progress['refreshed_models']),
'completed': download_progress['completed'], 'completed': download_progress['completed'],
'total': download_progress['total'], 'total': download_progress['total'],
'last_update': time.time() 'last_update': time.time()
@@ -582,6 +754,7 @@ class MiscRoutes:
with open(progress_file, 'w', encoding='utf-8') as f: with open(progress_file, 'w', encoding='utf-8') as f:
json.dump({ json.dump({
'processed_models': list(download_progress['processed_models']), 'processed_models': list(download_progress['processed_models']),
'refreshed_models': list(download_progress['refreshed_models']),
'completed': download_progress['completed'], 'completed': download_progress['completed'],
'total': download_progress['total'], 'total': download_progress['total'],
'last_update': time.time(), 'last_update': time.time(),