mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-14 17:17:37 -03:00
563 lines
23 KiB
Python
563 lines
23 KiB
Python
import logging
|
||
import os
|
||
import re
|
||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||
|
||
from ..recipes.constants import GEN_PARAM_KEYS
|
||
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||
from ..services.metadata_sync_service import MetadataSyncService
|
||
from ..services.preview_asset_service import PreviewAssetService
|
||
from ..services.settings_manager import get_settings_manager
|
||
from ..services.downloader import get_downloader
|
||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
||
from ..utils.exif_utils import ExifUtils
|
||
from ..utils.metadata_manager import MetadataManager
|
||
|
||
logger = logging.getLogger(__name__)

# Shared preview-asset service; injected into every MetadataSyncService built
# below so preview handling is consistent across service instances.
_preview_service = PreviewAssetService(
    metadata_manager=MetadataManager,
    downloader_factory=get_downloader,
    exif_utils=ExifUtils,
)

# Lazily-created singleton sync service, plus the settings manager it was
# built against (used to detect when the active settings object changes).
_metadata_sync_service: MetadataSyncService | None = None
_metadata_sync_service_settings: Optional["SettingsManager"] = None

if TYPE_CHECKING: # pragma: no cover - import for type checkers only
    from ..services.settings_manager import SettingsManager
||
def _build_metadata_sync_service(settings_manager: "SettingsManager") -> MetadataSyncService:
    """Construct a metadata sync service bound to the provided settings.

    The new service shares the module-level preview service and metadata
    manager, and resolves metadata providers via the service-layer factory
    helpers.

    Args:
        settings_manager: Settings object the service will read from.

    Returns:
        A freshly constructed ``MetadataSyncService``.
    """

    return MetadataSyncService(
        metadata_manager=MetadataManager,
        preview_service=_preview_service,
        settings=settings_manager,
        default_metadata_provider_factory=get_default_metadata_provider,
        metadata_provider_selector=get_metadata_provider,
    )
|
||
|
||
|
||
def _get_metadata_sync_service() -> MetadataSyncService:
    """Return the shared metadata sync service, initialising it lazily.

    The cached service is rebuilt whenever the active settings manager
    object changes; injected test doubles are left untouched.
    """

    global _metadata_sync_service, _metadata_sync_service_settings

    current_settings = get_settings_manager()

    if _metadata_sync_service is None:
        # First use: build and remember which settings it was built with.
        _metadata_sync_service = _build_metadata_sync_service(current_settings)
        _metadata_sync_service_settings = current_settings
    elif isinstance(_metadata_sync_service, MetadataSyncService):
        # Rebuild only when the settings object itself has been swapped out.
        if _metadata_sync_service_settings is not current_settings:
            _metadata_sync_service = _build_metadata_sync_service(current_settings)
            _metadata_sync_service_settings = current_settings
    else:
        # Tests may inject stand-ins that do not match the sync service type.
        # Preserve those injections while still updating our cached settings
        # reference so the next real instantiation uses the current config.
        _metadata_sync_service_settings = current_settings

    return _metadata_sync_service
|
||
|
||
|
||
class MetadataUpdater:
    """Handles updating model metadata related to example images.

    All methods are static; the class is a namespace for metadata refresh,
    local-example import, and stale-entry pruning helpers.
    """
|
||
|
||
@staticmethod
|
||
async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner, progress: dict | None = None):
|
||
"""Refresh model metadata from CivitAI
|
||
|
||
Args:
|
||
model_hash: SHA256 hash of the model
|
||
model_name: Model name (for logging)
|
||
scanner_type: Scanner type ('lora' or 'checkpoint')
|
||
scanner: Scanner instance for this model type
|
||
|
||
Returns:
|
||
bool: True if metadata was successfully refreshed, False otherwise
|
||
"""
|
||
try:
|
||
# Find the model in the scanner cache
|
||
cache = await scanner.get_cached_data()
|
||
model_data = None
|
||
|
||
for item in cache.raw_data:
|
||
if item.get('sha256') == model_hash:
|
||
model_data = item
|
||
break
|
||
|
||
if not model_data:
|
||
logger.warning(f"Model {model_name} with hash {model_hash} not found in cache")
|
||
return False
|
||
|
||
file_path = model_data.get('file_path')
|
||
if not file_path:
|
||
logger.warning(f"Model {model_name} has no file path")
|
||
return False
|
||
|
||
# Track that we're refreshing this model
|
||
if progress is not None:
|
||
progress['refreshed_models'].add(model_hash)
|
||
|
||
async def update_cache_func(old_path, new_path, metadata):
|
||
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
||
|
||
await MetadataManager.hydrate_model_data(model_data)
|
||
success, error = await _get_metadata_sync_service().fetch_and_update_model(
|
||
sha256=model_hash,
|
||
file_path=file_path,
|
||
model_data=model_data,
|
||
update_cache_func=update_cache_func,
|
||
)
|
||
|
||
if success:
|
||
logger.info(f"Successfully refreshed metadata for {model_name}")
|
||
return True
|
||
else:
|
||
logger.warning(f"Failed to refresh metadata for {model_name}, {error}")
|
||
return False
|
||
|
||
except Exception as e:
|
||
error_msg = f"Error refreshing metadata for {model_name}: {str(e)}"
|
||
logger.error(error_msg, exc_info=True)
|
||
if progress is not None:
|
||
progress['errors'].append(error_msg)
|
||
progress['last_error'] = error_msg
|
||
return False
|
||
|
||
@staticmethod
|
||
async def get_updated_model(model_hash, scanner):
|
||
"""Load the most recent metadata for a model identified by hash."""
|
||
cache = await scanner.get_cached_data()
|
||
target = None
|
||
for item in cache.raw_data:
|
||
if item.get('sha256') == model_hash:
|
||
target = item
|
||
break
|
||
|
||
if not target:
|
||
return None
|
||
|
||
file_path = target.get('file_path')
|
||
if not file_path:
|
||
return target
|
||
|
||
model_cls = getattr(scanner, 'model_class', None)
|
||
if model_cls is None:
|
||
metadata, should_skip = await MetadataManager.load_metadata(file_path)
|
||
else:
|
||
metadata, should_skip = await MetadataManager.load_metadata(file_path, model_cls)
|
||
|
||
if should_skip or metadata is None:
|
||
return target
|
||
|
||
rich_metadata = metadata.to_dict()
|
||
rich_metadata.setdefault('folder', target.get('folder', ''))
|
||
return rich_metadata
|
||
|
||
|
||
@staticmethod
|
||
async def update_metadata_from_local_examples(model_hash, model, scanner_type, scanner, model_dir):
|
||
"""Update model metadata with local example image information
|
||
|
||
Args:
|
||
model_hash: SHA256 hash of the model
|
||
model: Model data dictionary
|
||
scanner_type: Scanner type ('lora' or 'checkpoint')
|
||
scanner: Scanner instance for this model type
|
||
model_dir: Model images directory
|
||
|
||
Returns:
|
||
bool: True if metadata was successfully updated, False otherwise
|
||
"""
|
||
try:
|
||
# Collect local image paths
|
||
local_images_paths = []
|
||
if os.path.exists(model_dir):
|
||
for file in os.listdir(model_dir):
|
||
file_path = os.path.join(model_dir, file)
|
||
if os.path.isfile(file_path):
|
||
file_ext = os.path.splitext(file)[1].lower()
|
||
is_supported = (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
|
||
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos'])
|
||
if is_supported:
|
||
local_images_paths.append(file_path)
|
||
|
||
await MetadataManager.hydrate_model_data(model)
|
||
civitai_data = model.setdefault('civitai', {})
|
||
|
||
# Check if metadata update is needed (no civitai field or empty images)
|
||
needs_update = not civitai_data or not civitai_data.get('images')
|
||
|
||
if needs_update and local_images_paths:
|
||
logger.debug(f"Found {len(local_images_paths)} local example images for {model.get('model_name')}, updating metadata")
|
||
|
||
# Create or get civitai field
|
||
# Create images array
|
||
images = []
|
||
|
||
# Generate metadata for each local image/video
|
||
for path in local_images_paths:
|
||
# Determine if video or image
|
||
file_ext = os.path.splitext(path)[1].lower()
|
||
is_video = file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
|
||
|
||
# Create image metadata entry
|
||
image_entry = {
|
||
"url": "", # Empty URL as required
|
||
"nsfwLevel": 0,
|
||
"width": 720, # Default dimensions
|
||
"height": 1280,
|
||
"type": "video" if is_video else "image",
|
||
"meta": None,
|
||
"hasMeta": False,
|
||
"hasPositivePrompt": False
|
||
}
|
||
|
||
# If it's an image, try to get actual dimensions (optional enhancement)
|
||
try:
|
||
from PIL import Image
|
||
if not is_video and os.path.exists(path):
|
||
with Image.open(path) as img:
|
||
image_entry["width"], image_entry["height"] = img.size
|
||
except:
|
||
# If PIL fails or is unavailable, use default dimensions
|
||
pass
|
||
|
||
images.append(image_entry)
|
||
|
||
# Update the model's civitai.images field
|
||
civitai_data['images'] = images
|
||
|
||
# Save metadata to .metadata.json file
|
||
file_path = model.get('file_path')
|
||
try:
|
||
model_copy = model.copy()
|
||
model_copy.pop('folder', None)
|
||
await MetadataManager.save_metadata(file_path, model_copy)
|
||
logger.info(f"Saved metadata for {model.get('model_name')}")
|
||
except Exception as e:
|
||
logger.error(f"Failed to save metadata for {model.get('model_name')}: {str(e)}")
|
||
|
||
# Save updated metadata to scanner cache
|
||
success = await scanner.update_single_model_cache(file_path, file_path, model)
|
||
if success:
|
||
logger.info(f"Successfully updated metadata for {model.get('model_name')} with {len(images)} local examples")
|
||
return True
|
||
else:
|
||
logger.warning(f"Failed to update metadata for {model.get('model_name')}")
|
||
|
||
return False
|
||
except Exception as e:
|
||
logger.error(f"Error updating metadata from local examples: {str(e)}", exc_info=True)
|
||
return False
|
||
|
||
@staticmethod
|
||
async def update_metadata_after_import(model_hash, model_data, scanner, newly_imported_paths):
|
||
"""Update model metadata after importing example images
|
||
|
||
Args:
|
||
model_hash: SHA256 hash of the model
|
||
model_data: Model data dictionary
|
||
scanner: Scanner instance (lora or checkpoint)
|
||
newly_imported_paths: List of paths to newly imported files
|
||
|
||
Returns:
|
||
tuple: (regular_images, custom_images) - Both image arrays
|
||
"""
|
||
try:
|
||
await MetadataManager.hydrate_model_data(model_data)
|
||
civitai_data = model_data.get('civitai')
|
||
|
||
if not isinstance(civitai_data, dict):
|
||
civitai_data = {}
|
||
model_data['civitai'] = civitai_data
|
||
|
||
custom_images = civitai_data.get('customImages')
|
||
|
||
if not isinstance(custom_images, list):
|
||
custom_images = []
|
||
civitai_data['customImages'] = custom_images
|
||
|
||
# Add new image entry for each imported file
|
||
for path_tuple in newly_imported_paths:
|
||
path, short_id = path_tuple
|
||
|
||
# Determine if video or image
|
||
file_ext = os.path.splitext(path)[1].lower()
|
||
is_video = file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
|
||
|
||
# Create image metadata entry
|
||
image_entry = {
|
||
"url": "", # Empty URL as requested
|
||
"id": short_id,
|
||
"nsfwLevel": 0,
|
||
"width": 720, # Default dimensions
|
||
"height": 1280,
|
||
"type": "video" if is_video else "image",
|
||
"meta": None,
|
||
"hasMeta": False,
|
||
"hasPositivePrompt": False
|
||
}
|
||
|
||
# Extract and parse metadata if this is an image
|
||
if not is_video:
|
||
try:
|
||
# Extract metadata from image
|
||
extracted_metadata = ExifUtils.extract_image_metadata(path)
|
||
|
||
if extracted_metadata:
|
||
# Parse the extracted metadata to get generation parameters
|
||
parsed_meta = MetadataUpdater._parse_image_metadata(extracted_metadata)
|
||
|
||
if parsed_meta:
|
||
image_entry["meta"] = parsed_meta
|
||
image_entry["hasMeta"] = True
|
||
image_entry["hasPositivePrompt"] = bool(parsed_meta.get("prompt", ""))
|
||
logger.debug(f"Extracted metadata from {os.path.basename(path)}")
|
||
except Exception as e:
|
||
logger.warning(f"Failed to extract metadata from {os.path.basename(path)}: {e}")
|
||
|
||
# If it's an image, try to get actual dimensions
|
||
try:
|
||
from PIL import Image
|
||
if not is_video and os.path.exists(path):
|
||
with Image.open(path) as img:
|
||
image_entry["width"], image_entry["height"] = img.size
|
||
except:
|
||
# If PIL fails or is unavailable, use default dimensions
|
||
pass
|
||
|
||
# Append to existing customImages array
|
||
custom_images.append(image_entry)
|
||
|
||
# Save metadata to .metadata.json file
|
||
file_path = model_data.get('file_path')
|
||
if file_path:
|
||
try:
|
||
model_copy = model_data.copy()
|
||
model_copy.pop('folder', None)
|
||
await MetadataManager.save_metadata(file_path, model_copy)
|
||
logger.info(f"Saved metadata for {model_data.get('model_name')}")
|
||
except Exception as e:
|
||
logger.error(f"Failed to save metadata: {str(e)}")
|
||
|
||
# Save updated metadata to scanner cache
|
||
if file_path:
|
||
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
||
|
||
# Get regular images array (might be None)
|
||
regular_images = civitai_data.get('images', [])
|
||
|
||
# Return both image arrays
|
||
return regular_images, custom_images
|
||
|
||
except Exception as e:
|
||
logger.error(f"Failed to update metadata after import: {e}", exc_info=True)
|
||
return [], []
|
||
|
||
    @staticmethod
    def _parse_image_metadata(user_comment):
        """Parse metadata from image to extract generation parameters

        Parses an A1111-style parameter string: a positive prompt, an
        optional "Negative prompt:" section, then a comma-separated
        "Key: value" parameter tail beginning at "Steps:". Only keys
        listed in GEN_PARAM_KEYS are kept.

        Args:
            user_comment: Metadata string extracted from image

        Returns:
            dict: Parsed metadata with generation parameters, or None if
                nothing recognisable was found
        """
        if not user_comment:
            return None

        try:
            # Initialize metadata dictionary
            metadata = {}

            # Split on Negative prompt if it exists
            if "Negative prompt:" in user_comment:
                parts = user_comment.split('Negative prompt:', 1)
                prompt = parts[0].strip()
                negative_and_params = parts[1] if len(parts) > 1 else ""
            else:
                # No negative prompt section; parameters (if any) start at
                # the first "Steps: <n>" marker.
                param_start = re.search(r'Steps: \d+', user_comment)
                if param_start:
                    prompt = user_comment[:param_start.start()].strip()
                    negative_and_params = user_comment[param_start.start():]
                else:
                    prompt = user_comment.strip()
                    negative_and_params = ""

            # Add prompt if it's in GEN_PARAM_KEYS
            if 'prompt' in GEN_PARAM_KEYS:
                metadata['prompt'] = prompt

            # Extract negative prompt and parameters
            if negative_and_params:
                # If we split on "Negative prompt:", check for params section
                if "Negative prompt:" in user_comment:
                    param_start = re.search(r'Steps: ', negative_and_params)
                    if param_start:
                        neg_prompt = negative_and_params[:param_start.start()].strip()
                        if 'negative_prompt' in GEN_PARAM_KEYS:
                            metadata['negative_prompt'] = neg_prompt
                        params_section = negative_and_params[param_start.start():]
                    else:
                        # No "Steps:" marker: the whole remainder is the
                        # negative prompt.
                        if 'negative_prompt' in GEN_PARAM_KEYS:
                            metadata['negative_prompt'] = negative_and_params.strip()
                        params_section = ""
                else:
                    # No negative prompt, entire section is params
                    params_section = negative_and_params

                # Extract generation parameters
                if params_section:
                    # Extract basic parameters as comma-separated
                    # "Key: value" pairs.
                    param_pattern = r'([A-Za-z\s]+): ([^,]+)'
                    params = re.findall(param_pattern, params_section)

                    for key, value in params:
                        clean_key = key.strip().lower().replace(' ', '_')

                        # Skip if not in recognized gen param keys
                        if clean_key not in GEN_PARAM_KEYS:
                            continue

                        # Convert numeric values; fall back to the raw
                        # string when conversion fails.
                        if clean_key in ['steps', 'seed']:
                            try:
                                metadata[clean_key] = int(value.strip())
                            except ValueError:
                                metadata[clean_key] = value.strip()
                        elif clean_key in ['cfg_scale']:
                            try:
                                metadata[clean_key] = float(value.strip())
                            except ValueError:
                                metadata[clean_key] = value.strip()
                        else:
                            metadata[clean_key] = value.strip()

                    # Extract size if available and add if a recognized key
                    size_match = re.search(r'Size: (\d+)x(\d+)', params_section)
                    if size_match and 'size' in GEN_PARAM_KEYS:
                        width, height = size_match.groups()
                        metadata['size'] = f"{width}x{height}"

            # Return metadata if we have any entries
            return metadata if metadata else None

        except Exception as e:
            logger.error(f"Error parsing image metadata: {e}", exc_info=True)
            return None
|
||
|
||
    @staticmethod
    async def prune_stale_example_images(metadata) -> bool:
        """Remove example-image metadata entries whose files no longer exist on disk.

        Checks ``civitai.customImages`` (by ``id``) and ``civitai.images`` entries
        that have an empty ``url`` (no remote fallback) against actual files in
        the model's example-image folder. Stale entries are removed in-place so
        the caller can persist the cleaned metadata afterwards.

        Args:
            metadata: A ``BaseModelMetadata`` instance (modified in place).

        Returns:
            True if at least one entry was removed.
        """
        from ..utils.example_images_paths import get_model_folder

        # Without a hash we cannot resolve the example-image folder.
        model_hash = getattr(metadata, "sha256", None)
        if not model_hash:
            return False

        model_folder = get_model_folder(model_hash)
        if not model_folder:
            return False

        civitai = getattr(metadata, "civitai", None)
        if not isinstance(civitai, dict):
            return False

        has_changes = False

        # customImages entries are matched on disk by "custom_<id>*" names.
        custom_images = civitai.get("customImages")
        if isinstance(custom_images, list) and custom_images:
            stale: list[int] = []

            for idx, img in enumerate(custom_images):
                img_id = img.get("id", "")
                if not img_id:
                    # No id to match against; keep the entry.
                    continue

                if not os.path.isdir(model_folder):
                    stale.append(idx)
                else:
                    found = False
                    try:
                        prefix = f"custom_{img_id}"
                        for fname in os.listdir(model_folder):
                            if fname.startswith(prefix) and os.path.isfile(
                                os.path.join(model_folder, fname)
                            ):
                                found = True
                                break
                    except OSError:
                        # Folder unreadable: treat the entry as stale.
                        stale.append(idx)
                        continue

                    if not found:
                        stale.append(idx)

            if stale:
                # Pop from the end so earlier indices stay valid.
                for idx in reversed(stale):
                    custom_images.pop(idx)
                has_changes = True
                logger.info(
                    "Pruned %d stale custom image(s) for %s",
                    len(stale),
                    getattr(metadata, "model_name", model_hash),
                )

        # Regular images are matched on disk by "image_<index>." names.
        images = civitai.get("images")
        if isinstance(images, list) and images:
            stale: list[int] = []

            for idx, img in enumerate(images):
                if img.get("url", ""):
                    # Has a remote fallback – keep it even if the local copy
                    # is gone.
                    continue

                if not os.path.isdir(model_folder):
                    stale.append(idx)
                else:
                    found = False
                    try:
                        prefix = f"image_{idx}."
                        for fname in os.listdir(model_folder):
                            if fname.startswith(prefix):
                                found = True
                                break
                    except OSError:
                        # Folder unreadable: treat the entry as stale.
                        stale.append(idx)
                        continue

                    if not found:
                        stale.append(idx)

            if stale:
                # Pop from the end so earlier indices stay valid.
                for idx in reversed(stale):
                    images.pop(idx)
                has_changes = True
                logger.info(
                    "Pruned %d stale image entry(ies) for %s",
                    len(stale),
                    getattr(metadata, "model_name", model_hash),
                )

        return has_changes
|