# ComfyUI-Lora-Manager/py/utils/example_images_metadata.py
import logging
import os
import re
from typing import TYPE_CHECKING, Any, Dict, Optional
from ..recipes.constants import GEN_PARAM_KEYS
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
from ..services.metadata_sync_service import MetadataSyncService
from ..services.preview_asset_service import PreviewAssetService
from ..services.settings_manager import get_settings_manager
from ..services.downloader import get_downloader
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager
logger = logging.getLogger(__name__)

# Shared preview-asset service; injected into every MetadataSyncService built
# by _build_metadata_sync_service below.
_preview_service = PreviewAssetService(
    metadata_manager=MetadataManager,
    downloader_factory=get_downloader,
    exif_utils=ExifUtils,
)

# Lazily-initialised module-level sync service, plus the settings manager it
# was built against (used by _get_metadata_sync_service to detect when the
# active settings object has been swapped and the service must be rebuilt).
_metadata_sync_service: MetadataSyncService | None = None
_metadata_sync_service_settings: Optional["SettingsManager"] = None

if TYPE_CHECKING:  # pragma: no cover - import for type checkers only
    from ..services.settings_manager import SettingsManager
def _build_metadata_sync_service(settings_manager: "SettingsManager") -> MetadataSyncService:
    """Create a MetadataSyncService wired to *settings_manager*.

    All other collaborators (metadata manager, preview service, provider
    factories) come from this module's shared instances.
    """
    dependencies = {
        "metadata_manager": MetadataManager,
        "preview_service": _preview_service,
        "settings": settings_manager,
        "default_metadata_provider_factory": get_default_metadata_provider,
        "metadata_provider_selector": get_metadata_provider,
    }
    return MetadataSyncService(**dependencies)
def _get_metadata_sync_service() -> MetadataSyncService:
    """Return the shared metadata sync service, initialising it lazily.

    The cached service is rebuilt when the active settings manager changes.
    Non-``MetadataSyncService`` stand-ins (test injections) are returned
    untouched; only the cached settings reference is refreshed so the next
    real instantiation picks up the current configuration.
    """
    global _metadata_sync_service, _metadata_sync_service_settings

    current_settings = get_settings_manager()
    service = _metadata_sync_service

    if service is None:
        # First use: build against the current settings.
        service = _build_metadata_sync_service(current_settings)
    elif isinstance(service, MetadataSyncService) and _metadata_sync_service_settings is not current_settings:
        # Settings object was swapped out from under us: rebuild.
        service = _build_metadata_sync_service(current_settings)
    # else: keep the existing service (or injected stand-in) as-is.

    _metadata_sync_service = service
    _metadata_sync_service_settings = current_settings
    return service
class MetadataUpdater:
"""Handles updating model metadata related to example images"""
@staticmethod
async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner, progress: dict | None = None):
"""Refresh model metadata from CivitAI
Args:
model_hash: SHA256 hash of the model
model_name: Model name (for logging)
scanner_type: Scanner type ('lora' or 'checkpoint')
scanner: Scanner instance for this model type
Returns:
bool: True if metadata was successfully refreshed, False otherwise
"""
try:
# Find the model in the scanner cache
cache = await scanner.get_cached_data()
model_data = None
for item in cache.raw_data:
if item.get('sha256') == model_hash:
model_data = item
break
if not model_data:
logger.warning(f"Model {model_name} with hash {model_hash} not found in cache")
return False
file_path = model_data.get('file_path')
if not file_path:
logger.warning(f"Model {model_name} has no file path")
return False
# Track that we're refreshing this model
if progress is not None:
progress['refreshed_models'].add(model_hash)
async def update_cache_func(old_path, new_path, metadata):
return await scanner.update_single_model_cache(old_path, new_path, metadata)
await MetadataManager.hydrate_model_data(model_data)
success, error = await _get_metadata_sync_service().fetch_and_update_model(
sha256=model_hash,
file_path=file_path,
model_data=model_data,
update_cache_func=update_cache_func,
)
if success:
logger.info(f"Successfully refreshed metadata for {model_name}")
return True
else:
logger.warning(f"Failed to refresh metadata for {model_name}, {error}")
return False
except Exception as e:
error_msg = f"Error refreshing metadata for {model_name}: {str(e)}"
logger.error(error_msg, exc_info=True)
if progress is not None:
progress['errors'].append(error_msg)
progress['last_error'] = error_msg
return False
@staticmethod
async def get_updated_model(model_hash, scanner):
"""Load the most recent metadata for a model identified by hash."""
cache = await scanner.get_cached_data()
target = None
for item in cache.raw_data:
if item.get('sha256') == model_hash:
target = item
break
if not target:
return None
file_path = target.get('file_path')
if not file_path:
return target
model_cls = getattr(scanner, 'model_class', None)
if model_cls is None:
metadata, should_skip = await MetadataManager.load_metadata(file_path)
else:
metadata, should_skip = await MetadataManager.load_metadata(file_path, model_cls)
if should_skip or metadata is None:
return target
rich_metadata = metadata.to_dict()
rich_metadata.setdefault('folder', target.get('folder', ''))
return rich_metadata
@staticmethod
async def update_metadata_from_local_examples(model_hash, model, scanner_type, scanner, model_dir):
"""Update model metadata with local example image information
Args:
model_hash: SHA256 hash of the model
model: Model data dictionary
scanner_type: Scanner type ('lora' or 'checkpoint')
scanner: Scanner instance for this model type
model_dir: Model images directory
Returns:
bool: True if metadata was successfully updated, False otherwise
"""
try:
# Collect local image paths
local_images_paths = []
if os.path.exists(model_dir):
for file in os.listdir(model_dir):
file_path = os.path.join(model_dir, file)
if os.path.isfile(file_path):
file_ext = os.path.splitext(file)[1].lower()
is_supported = (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos'])
if is_supported:
local_images_paths.append(file_path)
await MetadataManager.hydrate_model_data(model)
civitai_data = model.setdefault('civitai', {})
# Check if metadata update is needed (no civitai field or empty images)
needs_update = not civitai_data or not civitai_data.get('images')
if needs_update and local_images_paths:
logger.debug(f"Found {len(local_images_paths)} local example images for {model.get('model_name')}, updating metadata")
# Create or get civitai field
# Create images array
images = []
# Generate metadata for each local image/video
for path in local_images_paths:
# Determine if video or image
file_ext = os.path.splitext(path)[1].lower()
is_video = file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
# Create image metadata entry
image_entry = {
"url": "", # Empty URL as required
"nsfwLevel": 0,
"width": 720, # Default dimensions
"height": 1280,
"type": "video" if is_video else "image",
"meta": None,
"hasMeta": False,
"hasPositivePrompt": False
}
# If it's an image, try to get actual dimensions (optional enhancement)
try:
from PIL import Image
if not is_video and os.path.exists(path):
with Image.open(path) as img:
image_entry["width"], image_entry["height"] = img.size
except:
# If PIL fails or is unavailable, use default dimensions
pass
images.append(image_entry)
# Update the model's civitai.images field
civitai_data['images'] = images
# Save metadata to .metadata.json file
file_path = model.get('file_path')
try:
model_copy = model.copy()
model_copy.pop('folder', None)
await MetadataManager.save_metadata(file_path, model_copy)
logger.info(f"Saved metadata for {model.get('model_name')}")
except Exception as e:
logger.error(f"Failed to save metadata for {model.get('model_name')}: {str(e)}")
# Save updated metadata to scanner cache
success = await scanner.update_single_model_cache(file_path, file_path, model)
if success:
logger.info(f"Successfully updated metadata for {model.get('model_name')} with {len(images)} local examples")
return True
else:
logger.warning(f"Failed to update metadata for {model.get('model_name')}")
return False
except Exception as e:
logger.error(f"Error updating metadata from local examples: {str(e)}", exc_info=True)
return False
@staticmethod
async def update_metadata_after_import(model_hash, model_data, scanner, newly_imported_paths):
"""Update model metadata after importing example images
Args:
model_hash: SHA256 hash of the model
model_data: Model data dictionary
scanner: Scanner instance (lora or checkpoint)
newly_imported_paths: List of paths to newly imported files
Returns:
tuple: (regular_images, custom_images) - Both image arrays
"""
try:
await MetadataManager.hydrate_model_data(model_data)
civitai_data = model_data.get('civitai')
if not isinstance(civitai_data, dict):
civitai_data = {}
model_data['civitai'] = civitai_data
custom_images = civitai_data.get('customImages')
if not isinstance(custom_images, list):
custom_images = []
civitai_data['customImages'] = custom_images
# Add new image entry for each imported file
for path_tuple in newly_imported_paths:
path, short_id = path_tuple
# Determine if video or image
file_ext = os.path.splitext(path)[1].lower()
is_video = file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
# Create image metadata entry
image_entry = {
"url": "", # Empty URL as requested
"id": short_id,
"nsfwLevel": 0,
"width": 720, # Default dimensions
"height": 1280,
"type": "video" if is_video else "image",
"meta": None,
"hasMeta": False,
"hasPositivePrompt": False
}
# Extract and parse metadata if this is an image
if not is_video:
try:
# Extract metadata from image
extracted_metadata = ExifUtils.extract_image_metadata(path)
if extracted_metadata:
# Parse the extracted metadata to get generation parameters
parsed_meta = MetadataUpdater._parse_image_metadata(extracted_metadata)
if parsed_meta:
image_entry["meta"] = parsed_meta
image_entry["hasMeta"] = True
image_entry["hasPositivePrompt"] = bool(parsed_meta.get("prompt", ""))
logger.debug(f"Extracted metadata from {os.path.basename(path)}")
except Exception as e:
logger.warning(f"Failed to extract metadata from {os.path.basename(path)}: {e}")
# If it's an image, try to get actual dimensions
try:
from PIL import Image
if not is_video and os.path.exists(path):
with Image.open(path) as img:
image_entry["width"], image_entry["height"] = img.size
except:
# If PIL fails or is unavailable, use default dimensions
pass
# Append to existing customImages array
custom_images.append(image_entry)
# Save metadata to .metadata.json file
file_path = model_data.get('file_path')
if file_path:
try:
model_copy = model_data.copy()
model_copy.pop('folder', None)
await MetadataManager.save_metadata(file_path, model_copy)
logger.info(f"Saved metadata for {model_data.get('model_name')}")
except Exception as e:
logger.error(f"Failed to save metadata: {str(e)}")
# Save updated metadata to scanner cache
if file_path:
await scanner.update_single_model_cache(file_path, file_path, model_data)
# Get regular images array (might be None)
regular_images = civitai_data.get('images', [])
# Return both image arrays
return regular_images, custom_images
except Exception as e:
logger.error(f"Failed to update metadata after import: {e}", exc_info=True)
return [], []
@staticmethod
def _parse_image_metadata(user_comment):
"""Parse metadata from image to extract generation parameters
Args:
user_comment: Metadata string extracted from image
Returns:
dict: Parsed metadata with generation parameters
"""
if not user_comment:
return None
try:
# Initialize metadata dictionary
metadata = {}
# Split on Negative prompt if it exists
if "Negative prompt:" in user_comment:
parts = user_comment.split('Negative prompt:', 1)
prompt = parts[0].strip()
negative_and_params = parts[1] if len(parts) > 1 else ""
else:
# No negative prompt section
param_start = re.search(r'Steps: \d+', user_comment)
if param_start:
prompt = user_comment[:param_start.start()].strip()
negative_and_params = user_comment[param_start.start():]
else:
prompt = user_comment.strip()
negative_and_params = ""
# Add prompt if it's in GEN_PARAM_KEYS
if 'prompt' in GEN_PARAM_KEYS:
metadata['prompt'] = prompt
# Extract negative prompt and parameters
if negative_and_params:
# If we split on "Negative prompt:", check for params section
if "Negative prompt:" in user_comment:
param_start = re.search(r'Steps: ', negative_and_params)
if param_start:
neg_prompt = negative_and_params[:param_start.start()].strip()
if 'negative_prompt' in GEN_PARAM_KEYS:
metadata['negative_prompt'] = neg_prompt
params_section = negative_and_params[param_start.start():]
else:
if 'negative_prompt' in GEN_PARAM_KEYS:
metadata['negative_prompt'] = negative_and_params.strip()
params_section = ""
else:
# No negative prompt, entire section is params
params_section = negative_and_params
# Extract generation parameters
if params_section:
# Extract basic parameters
param_pattern = r'([A-Za-z\s]+): ([^,]+)'
params = re.findall(param_pattern, params_section)
for key, value in params:
clean_key = key.strip().lower().replace(' ', '_')
# Skip if not in recognized gen param keys
if clean_key not in GEN_PARAM_KEYS:
continue
# Convert numeric values
if clean_key in ['steps', 'seed']:
try:
metadata[clean_key] = int(value.strip())
except ValueError:
metadata[clean_key] = value.strip()
elif clean_key in ['cfg_scale']:
try:
metadata[clean_key] = float(value.strip())
except ValueError:
metadata[clean_key] = value.strip()
else:
metadata[clean_key] = value.strip()
# Extract size if available and add if a recognized key
size_match = re.search(r'Size: (\d+)x(\d+)', params_section)
if size_match and 'size' in GEN_PARAM_KEYS:
width, height = size_match.groups()
metadata['size'] = f"{width}x{height}"
# Return metadata if we have any entries
return metadata if metadata else None
except Exception as e:
logger.error(f"Error parsing image metadata: {e}", exc_info=True)
return None
@staticmethod
async def prune_stale_example_images(metadata) -> bool:
"""Remove example-image metadata entries whose files no longer exist on disk.
Checks ``civitai.customImages`` (by ``id``) and ``civitai.images`` entries
that have an empty ``url`` (no remote fallback) against actual files in
the model's example-image folder. Stale entries are removed in-place so
the caller can persist the cleaned metadata afterwards.
Args:
metadata: A ``BaseModelMetadata`` instance (modified in place).
Returns:
True if at least one entry was removed.
"""
from ..utils.example_images_paths import get_model_folder
model_hash = getattr(metadata, "sha256", None)
if not model_hash:
return False
model_folder = get_model_folder(model_hash)
if not model_folder:
return False
civitai = getattr(metadata, "civitai", None)
if not isinstance(civitai, dict):
return False
has_changes = False
custom_images = civitai.get("customImages")
if isinstance(custom_images, list) and custom_images:
stale: list[int] = []
for idx, img in enumerate(custom_images):
img_id = img.get("id", "")
if not img_id:
continue
if not os.path.isdir(model_folder):
stale.append(idx)
else:
found = False
try:
prefix = f"custom_{img_id}"
for fname in os.listdir(model_folder):
if fname.startswith(prefix) and os.path.isfile(
os.path.join(model_folder, fname)
):
found = True
break
except OSError:
stale.append(idx)
continue
if not found:
stale.append(idx)
if stale:
for idx in reversed(stale):
custom_images.pop(idx)
has_changes = True
logger.info(
"Pruned %d stale custom image(s) for %s",
len(stale),
getattr(metadata, "model_name", model_hash),
)
images = civitai.get("images")
if isinstance(images, list) and images:
stale: list[int] = []
for idx, img in enumerate(images):
if img.get("url", ""):
# Has a remote fallback keep it even if the local copy
# is gone.
continue
if not os.path.isdir(model_folder):
stale.append(idx)
else:
found = False
try:
prefix = f"image_{idx}."
for fname in os.listdir(model_folder):
if fname.startswith(prefix):
found = True
break
except OSError:
stale.append(idx)
continue
if not found:
stale.append(idx)
if stale:
for idx in reversed(stale):
images.pop(idx)
has_changes = True
logger.info(
"Pruned %d stale image entry(ies) for %s",
len(stale),
getattr(metadata, "model_name", model_hash),
)
return has_changes