diff --git a/py/config.py b/py/config.py index eb5eaad9..bd2d3453 100644 --- a/py/config.py +++ b/py/config.py @@ -2,7 +2,7 @@ import os import platform import threading from pathlib import Path -import folder_paths # type: ignore +import folder_paths # type: ignore from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple import logging import json @@ -10,16 +10,23 @@ import urllib.parse import time from .utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths -from .utils.settings_paths import ensure_settings_file, get_settings_dir, load_settings_template +from .utils.settings_paths import ( + ensure_settings_file, + get_settings_dir, + load_settings_template, +) # Use an environment variable to control standalone mode -standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" +standalone_mode = ( + os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" + or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" +) logger = logging.getLogger(__name__) def _normalize_folder_paths_for_comparison( - folder_paths: Mapping[str, Iterable[str]] + folder_paths: Mapping[str, Iterable[str]], ) -> Dict[str, Set[str]]: """Normalize folder paths for comparison across libraries.""" @@ -49,7 +56,7 @@ def _normalize_folder_paths_for_comparison( def _normalize_library_folder_paths( - library_payload: Mapping[str, Any] + library_payload: Mapping[str, Any], ) -> Dict[str, Set[str]]: """Return normalized folder paths extracted from a library payload.""" @@ -74,11 +81,17 @@ def _get_template_folder_paths() -> Dict[str, Set[str]]: class Config: """Global configuration for LoRA Manager""" - + def __init__(self): - self.templates_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'templates') - self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static') - self.i18n_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'locales') + self.templates_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "templates" + ) + self.static_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "static" + ) + self.i18n_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "locales" + ) # Path mapping dictionary, target to link mapping self._path_mappings: Dict[str, str] = {} # Normalized preview root directories used to validate preview access @@ -98,7 +111,7 @@ class Config: self.extra_embeddings_roots: List[str] = [] # Scan symbolic links during initialization self._initialize_symlink_mappings() - + if not standalone_mode: # Save the paths to settings.json when running in ComfyUI mode self.save_folder_paths_to_settings() @@ -152,17 +165,21 @@ class Config: default_library = libraries.get("default", {}) target_folder_paths = { - 'loras': list(self.loras_roots), - 'checkpoints': list(self.checkpoints_roots or []), - 'unet': list(self.unet_roots or []), - 'embeddings': list(self.embeddings_roots or []), + "loras": list(self.loras_roots), + "checkpoints": list(self.checkpoints_roots or []), + "unet": list(self.unet_roots or []), + "embeddings": list(self.embeddings_roots or []), } - normalized_target_paths = _normalize_folder_paths_for_comparison(target_folder_paths) + normalized_target_paths = _normalize_folder_paths_for_comparison( + target_folder_paths + ) normalized_default_paths: Optional[Dict[str, Set[str]]] = None if isinstance(default_library, Mapping): - normalized_default_paths = _normalize_library_folder_paths(default_library) + normalized_default_paths = _normalize_library_folder_paths( + default_library + ) if ( not comfy_library @@ -185,13 +202,19 @@ class Config: default_lora_root = self.loras_roots[0] default_checkpoint_root = comfy_library.get("default_checkpoint_root", "") - if (not default_checkpoint_root and self.checkpoints_roots and - len(self.checkpoints_roots) == 1): + if ( + not default_checkpoint_root + and self.checkpoints_roots + and len(self.checkpoints_roots) == 1 + ): default_checkpoint_root = self.checkpoints_roots[0] default_embedding_root = comfy_library.get("default_embedding_root", "") - if (not default_embedding_root and self.embeddings_roots and - len(self.embeddings_roots) == 1): + if ( + not default_embedding_root + and self.embeddings_roots + and len(self.embeddings_roots) == 1 + ): default_embedding_root = self.embeddings_roots[0] metadata = dict(comfy_library.get("metadata", {})) @@ -216,11 +239,12 @@ class Config: try: if os.path.islink(path): return True - if platform.system() == 'Windows': + if platform.system() == "Windows": try: import ctypes + FILE_ATTRIBUTE_REPARSE_POINT = 0x400 - attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path)) + attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path)) # type: ignore[attr-defined] return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT) except Exception as e: logger.error(f"Error checking Windows reparse point: {e}") @@ -233,18 +257,19 @@ class Config: """Check if a directory entry is a symlink, including Windows junctions.""" if entry.is_symlink(): return True - if platform.system() == 'Windows': + if platform.system() == "Windows": try: import ctypes + FILE_ATTRIBUTE_REPARSE_POINT = 0x400 - attrs = ctypes.windll.kernel32.GetFileAttributesW(entry.path) + attrs = ctypes.windll.kernel32.GetFileAttributesW(entry.path) # type: ignore[attr-defined] return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT) except Exception: pass return False def _normalize_path(self, path: str) -> str: - return os.path.normpath(path).replace(os.sep, '/') + return os.path.normpath(path).replace(os.sep, "/") def _get_symlink_cache_path(self) -> Path: canonical_path = get_cache_file_path(CacheType.SYMLINK, create_dir=True) @@ -278,19 +303,18 @@ class Config: if self._entry_is_symlink(entry): try: target = os.path.realpath(entry.path) - direct_symlinks.append([ - self._normalize_path(entry.path), - self._normalize_path(target) - ]) + direct_symlinks.append( + [ + self._normalize_path(entry.path), + self._normalize_path(target), + ] + ) except OSError: pass except (OSError, PermissionError): pass - return { - "roots": unique_roots, - "direct_symlinks": sorted(direct_symlinks) - } + return {"roots": unique_roots, "direct_symlinks": sorted(direct_symlinks)} def _initialize_symlink_mappings(self) -> None: start = time.perf_counter() @@ -307,10 +331,14 @@ class Config: cached_fingerprint = self._cached_fingerprint # Check 1: First-level symlinks unchanged (catches new symlinks at root) - fingerprint_valid = cached_fingerprint and current_fingerprint == cached_fingerprint + fingerprint_valid = ( + cached_fingerprint and current_fingerprint == cached_fingerprint + ) # Check 2: All cached mappings still valid (catches changes at any depth) - mappings_valid = self._validate_cached_mappings() if fingerprint_valid else False + mappings_valid = ( + self._validate_cached_mappings() if fingerprint_valid else False + ) if fingerprint_valid and mappings_valid: return @@ -370,7 +398,9 @@ class Config: for target, link in cached_mappings.items(): if not isinstance(target, str) or not isinstance(link, str): continue - normalized_mappings[self._normalize_path(target)] = self._normalize_path(link) + normalized_mappings[self._normalize_path(target)] = self._normalize_path( + link + ) self._path_mappings = normalized_mappings @@ -391,7 +421,9 @@ class Config: parent_dir = loaded_path.parent if parent_dir.name == "cache" and not any(parent_dir.iterdir()): parent_dir.rmdir() - logger.info("Removed empty legacy cache directory: %s", parent_dir) + logger.info( + "Removed empty legacy cache directory: %s", parent_dir + ) except Exception: pass @@ -402,7 +434,9 @@ class Config: exc, ) else: - logger.info("Symlink cache loaded with %d mappings", len(self._path_mappings)) + logger.info( + "Symlink cache loaded with %d mappings", len(self._path_mappings) + ) return True @@ -414,7 +448,7 @@ class Config: """ for target, link in self._path_mappings.items(): # Convert normalized paths back to OS paths - link_path = link.replace('/', os.sep) + link_path = link.replace("/", os.sep) # Check if symlink still exists if not self._is_link(link_path): @@ -427,7 +461,9 @@ class Config: if actual_target != target: logger.debug( "Symlink target changed: %s -> %s (cached: %s)", - link_path, actual_target, target + link_path, + actual_target, + target, ) return False except OSError: @@ -446,7 +482,11 @@ class Config: try: with cache_path.open("w", encoding="utf-8") as handle: json.dump(payload, handle, ensure_ascii=False, indent=2) - logger.debug("Symlink cache saved to %s with %d mappings", cache_path, len(self._path_mappings)) + logger.debug( + "Symlink cache saved to %s with %d mappings", + cache_path, + len(self._path_mappings), + ) except Exception as exc: logger.info("Failed to write symlink cache %s: %s", cache_path, exc) @@ -458,7 +498,7 @@ class Config: at the root level only (not nested symlinks in subdirectories). """ start = time.perf_counter() - + # Reset mappings before rescanning to avoid stale entries self._path_mappings.clear() self._seed_root_symlink_mappings() @@ -472,7 +512,7 @@ class Config: def _scan_first_level_symlinks(self, root: str): """Scan only the first level of a directory for symlinks. - + This avoids traversing the entire directory tree which can be extremely slow for large model collections. Only symlinks directly under the root are detected. @@ -494,13 +534,13 @@ class Config: self.add_path_mapping(entry.path, target_path) except Exception as inner_exc: logger.debug( - "Error processing directory entry %s: %s", entry.path, inner_exc + "Error processing directory entry %s: %s", + entry.path, + inner_exc, ) except Exception as e: logger.error(f"Error scanning links in {root}: {e}") - - def add_path_mapping(self, link_path: str, target_path: str): """Add a symbolic link path mapping target_path: actual target path @@ -594,41 +634,46 @@ class Config: preview_roots.update(self._expand_preview_root(target)) preview_roots.update(self._expand_preview_root(link)) - self._preview_root_paths = {path for path in preview_roots if path.is_absolute()} + self._preview_root_paths = { + path for path in preview_roots if path.is_absolute() + } logger.debug( "Preview roots rebuilt: %d paths from %d lora roots (%d extra), %d checkpoint roots (%d extra), %d embedding roots (%d extra), %d symlink mappings", len(self._preview_root_paths), - len(self.loras_roots or []), len(self.extra_loras_roots or []), - len(self.base_models_roots or []), len(self.extra_checkpoints_roots or []), - len(self.embeddings_roots or []), len(self.extra_embeddings_roots or []), + len(self.loras_roots or []), + len(self.extra_loras_roots or []), + len(self.base_models_roots or []), + len(self.extra_checkpoints_roots or []), + len(self.embeddings_roots or []), + len(self.extra_embeddings_roots or []), len(self._path_mappings), ) def map_path_to_link(self, path: str) -> str: """Map a target path back to its symbolic link path""" - normalized_path = os.path.normpath(path).replace(os.sep, '/') + normalized_path = os.path.normpath(path).replace(os.sep, "/") # Check if the path is contained in any mapped target path for target_path, link_path in self._path_mappings.items(): # Match whole path components to avoid prefix collisions (e.g., /a/b vs /a/bc) if normalized_path == target_path: return link_path - - if normalized_path.startswith(target_path + '/'): + + if normalized_path.startswith(target_path + "/"): # If the path starts with the target path, replace with link path mapped_path = normalized_path.replace(target_path, link_path, 1) return mapped_path return normalized_path - + def map_link_to_path(self, link_path: str) -> str: """Map a symbolic link path back to the actual path""" - normalized_link = os.path.normpath(link_path).replace(os.sep, '/') + normalized_link = os.path.normpath(link_path).replace(os.sep, "/") # Check if the path is contained in any mapped target path for target_path, link_path_mapped in self._path_mappings.items(): # Match whole path components if normalized_link == link_path_mapped: return target_path - if normalized_link.startswith(link_path_mapped + '/'): + if normalized_link.startswith(link_path_mapped + "/"): # If the path starts with the link path, replace with actual path mapped_path = normalized_link.replace(link_path_mapped, target_path, 1) return mapped_path @@ -641,8 +686,8 @@ class Config: continue if not os.path.exists(path): continue - real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') - normalized = os.path.normpath(path).replace(os.sep, '/') + real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, "/") + normalized = os.path.normpath(path).replace(os.sep, "/") if real_path not in dedup: dedup[real_path] = normalized return dedup @@ -652,7 +697,9 @@ class Config: unique_paths = sorted(path_map.values(), key=lambda p: p.lower()) for original_path in unique_paths: - real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') + real_path = os.path.normpath(os.path.realpath(original_path)).replace( + os.sep, "/" + ) if real_path != original_path: self.add_path_mapping(original_path, real_path) @@ -674,7 +721,7 @@ class Config: "Please fix your ComfyUI path configuration to separate these folders. " "Falling back to 'checkpoints' for backward compatibility. " "Overlapping real paths: %s", - [checkpoint_map.get(rp, rp) for rp in overlapping_real_paths] + [checkpoint_map.get(rp, rp) for rp in overlapping_real_paths], ) # Remove overlapping paths from unet_map to prioritize checkpoints for rp in overlapping_real_paths: @@ -694,7 +741,9 @@ class Config: self.unet_roots = [p for p in unique_paths if p in unet_values] for original_path in unique_paths: - real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') + real_path = os.path.normpath(os.path.realpath(original_path)).replace( + os.sep, "/" + ) if real_path != original_path: self.add_path_mapping(original_path, real_path) @@ -705,7 +754,9 @@ class Config: unique_paths = sorted(path_map.values(), key=lambda p: p.lower()) for original_path in unique_paths: - real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') + real_path = os.path.normpath(os.path.realpath(original_path)).replace( + os.sep, "/" + ) if real_path != original_path: self.add_path_mapping(original_path, real_path) @@ -719,42 +770,66 @@ class Config: self._path_mappings.clear() self._preview_root_paths = set() - lora_paths = folder_paths.get('loras', []) or [] - checkpoint_paths = folder_paths.get('checkpoints', []) or [] - unet_paths = folder_paths.get('unet', []) or [] - embedding_paths = folder_paths.get('embeddings', []) or [] + lora_paths = folder_paths.get("loras", []) or [] + checkpoint_paths = folder_paths.get("checkpoints", []) or [] + unet_paths = folder_paths.get("unet", []) or [] + embedding_paths = folder_paths.get("embeddings", []) or [] self.loras_roots = self._prepare_lora_paths(lora_paths) - self.base_models_roots = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths) + self.base_models_roots = self._prepare_checkpoint_paths( + checkpoint_paths, unet_paths + ) self.embeddings_roots = self._prepare_embedding_paths(embedding_paths) # Process extra paths (only for LoRA Manager, not shared with ComfyUI) extra_paths = extra_folder_paths or {} - extra_lora_paths = extra_paths.get('loras', []) or [] - extra_checkpoint_paths = extra_paths.get('checkpoints', []) or [] - extra_unet_paths = extra_paths.get('unet', []) or [] - extra_embedding_paths = extra_paths.get('embeddings', []) or [] + extra_lora_paths = extra_paths.get("loras", []) or [] + extra_checkpoint_paths = extra_paths.get("checkpoints", []) or [] + extra_unet_paths = extra_paths.get("unet", []) or [] + extra_embedding_paths = extra_paths.get("embeddings", []) or [] self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths) # Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them) saved_checkpoints_roots = self.checkpoints_roots saved_unet_roots = self.unet_roots - self.extra_checkpoints_roots = self._prepare_checkpoint_paths(extra_checkpoint_paths, extra_unet_paths) - self.extra_unet_roots = self.unet_roots # unet_roots was set by _prepare_checkpoint_paths + self.extra_checkpoints_roots = self._prepare_checkpoint_paths( + extra_checkpoint_paths, extra_unet_paths + ) + self.extra_unet_roots = ( + self.unet_roots if self.unet_roots is not None else [] + ) # unet_roots was set by _prepare_checkpoint_paths # Restore main paths self.checkpoints_roots = saved_checkpoints_roots self.unet_roots = saved_unet_roots - self.extra_embeddings_roots = self._prepare_embedding_paths(extra_embedding_paths) + self.extra_embeddings_roots = self._prepare_embedding_paths( + extra_embedding_paths + ) # Log extra folder paths if self.extra_loras_roots: - logger.info("Found extra LoRA roots:" + "\n - " + "\n - ".join(self.extra_loras_roots)) + logger.info( + "Found extra LoRA roots:" + + "\n - " + + "\n - ".join(self.extra_loras_roots) + ) if self.extra_checkpoints_roots: - logger.info("Found extra checkpoint roots:" + "\n - " + "\n - ".join(self.extra_checkpoints_roots)) + logger.info( + "Found extra checkpoint roots:" + + "\n - " + + "\n - ".join(self.extra_checkpoints_roots) + ) if self.extra_unet_roots: - logger.info("Found extra diffusion model roots:" + "\n - " + "\n - ".join(self.extra_unet_roots)) + logger.info( + "Found extra diffusion model roots:" + + "\n - " + + "\n - ".join(self.extra_unet_roots) + ) if self.extra_embeddings_roots: - logger.info("Found extra embedding roots:" + "\n - " + "\n - ".join(self.extra_embeddings_roots)) + logger.info( + "Found extra embedding roots:" + + "\n - " + + "\n - ".join(self.extra_embeddings_roots) + ) self._initialize_symlink_mappings() @@ -763,7 +838,10 @@ class Config: try: raw_paths = folder_paths.get_folder_paths("loras") unique_paths = self._prepare_lora_paths(raw_paths) - logger.info("Found LoRA roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")) + logger.info( + "Found LoRA roots:" + + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]") + ) if not unique_paths: logger.warning("No valid loras folders found in ComfyUI configuration") @@ -779,12 +857,19 @@ class Config: try: raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints") raw_unet_paths = folder_paths.get_folder_paths("unet") - unique_paths = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths) + unique_paths = self._prepare_checkpoint_paths( + raw_checkpoint_paths, raw_unet_paths + ) - logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")) + logger.info( + "Found checkpoint roots:" + + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]") + ) if not unique_paths: - logger.warning("No valid checkpoint folders found in ComfyUI configuration") + logger.warning( + "No valid checkpoint folders found in ComfyUI configuration" + ) return [] return unique_paths @@ -797,10 +882,15 @@ class Config: try: raw_paths = folder_paths.get_folder_paths("embeddings") unique_paths = self._prepare_embedding_paths(raw_paths) - logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")) + logger.info( + "Found embedding roots:" + + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]") + ) if not unique_paths: - logger.warning("No valid embeddings folders found in ComfyUI configuration") + logger.warning( + "No valid embeddings folders found in ComfyUI configuration" + ) return [] return unique_paths @@ -812,13 +902,13 @@ class Config: if not preview_path: return "" - normalized = os.path.normpath(preview_path).replace(os.sep, '/') - encoded_path = urllib.parse.quote(normalized, safe='') - return f'/api/lm/previews?path={encoded_path}' + normalized = os.path.normpath(preview_path).replace(os.sep, "/") + encoded_path = urllib.parse.quote(normalized, safe="") + return f"/api/lm/previews?path={encoded_path}" def is_preview_path_allowed(self, preview_path: str) -> bool: """Return ``True`` if ``preview_path`` is within an allowed directory. - + If the path is initially rejected, attempts to discover deep symlinks that were not scanned during initialization. If a symlink is found, updates the in-memory path mappings and retries the check. @@ -889,14 +979,18 @@ class Config: normalized_link = self._normalize_path(str(current)) self._path_mappings[normalized_target] = normalized_link - self._preview_root_paths.update(self._expand_preview_root(normalized_target)) - self._preview_root_paths.update(self._expand_preview_root(normalized_link)) + self._preview_root_paths.update( + self._expand_preview_root(normalized_target) + ) + self._preview_root_paths.update( + self._expand_preview_root(normalized_link) + ) logger.debug( "Discovered deep symlink: %s -> %s (preview path: %s)", normalized_link, normalized_target, - preview_path + preview_path, ) return True @@ -914,8 +1008,16 @@ class Config: def apply_library_settings(self, library_config: Mapping[str, object]) -> None: """Update runtime paths to match the provided library configuration.""" - folder_paths = library_config.get('folder_paths') if isinstance(library_config, Mapping) else {} - extra_folder_paths = library_config.get('extra_folder_paths') if isinstance(library_config, Mapping) else None + folder_paths = ( + library_config.get("folder_paths") + if isinstance(library_config, Mapping) + else {} + ) + extra_folder_paths = ( + library_config.get("extra_folder_paths") + if isinstance(library_config, Mapping) + else None + ) if not isinstance(folder_paths, Mapping): folder_paths = {} if not isinstance(extra_folder_paths, Mapping): @@ -925,9 +1027,12 @@ class Config: logger.info( "Applied library settings with %d lora roots (%d extra), %d checkpoint roots (%d extra), and %d embedding roots (%d extra)", - len(self.loras_roots or []), len(self.extra_loras_roots or []), - len(self.base_models_roots or []), len(self.extra_checkpoints_roots or []), - len(self.embeddings_roots or []), len(self.extra_embeddings_roots or []), + len(self.loras_roots or []), + len(self.extra_loras_roots or []), + len(self.base_models_roots or []), + len(self.extra_checkpoints_roots or []), + len(self.embeddings_roots or []), + len(self.extra_embeddings_roots or []), ) def get_library_registry_snapshot(self) -> Dict[str, object]: @@ -947,5 +1052,6 @@ class Config: logger.debug("Failed to collect library registry snapshot: %s", exc) return {"active_library": "", "libraries": {}} + # Global config instance config = Config() diff --git a/py/lora_manager.py b/py/lora_manager.py index 0b5ed614..f2038040 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -5,16 +5,22 @@ import logging from .utils.logging_config import setup_logging # Check if we're in standalone mode -standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" +standalone_mode = ( + os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" + or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" +) # Only setup logging prefix if not in standalone mode if not standalone_mode: setup_logging() -from server import PromptServer # type: ignore +from server import PromptServer # type: ignore from .config import config -from .services.model_service_factory import ModelServiceFactory, register_default_model_types +from .services.model_service_factory import ( + ModelServiceFactory, + register_default_model_types, +) from .routes.recipe_routes import RecipeRoutes from .routes.stats_routes import StatsRoutes from .routes.update_routes import UpdateRoutes @@ -61,9 +67,10 @@ class _SettingsProxy: settings = _SettingsProxy() + class LoraManager: """Main entry point for LoRA Manager plugin""" - + @classmethod def add_routes(cls): """Initialize and register all routes using the new refactored architecture""" @@ -76,7 +83,8 @@ class LoraManager: ( idx for idx, middleware in enumerate(app.middlewares) - if getattr(middleware, "__name__", "") == "block_external_middleware" + if getattr(middleware, "__name__", "") + == "block_external_middleware" ), None, ) @@ -84,7 +92,9 @@ class LoraManager: if block_middleware_index is None: app.middlewares.append(relax_csp_for_remote_media) else: - app.middlewares.insert(block_middleware_index, relax_csp_for_remote_media) + app.middlewares.insert( + block_middleware_index, relax_csp_for_remote_media + ) # Increase allowed header sizes so browsers with large localhost cookie # jars (multiple UIs on 127.0.0.1) don't trip aiohttp's 8KB default @@ -105,7 +115,7 @@ class LoraManager: app._handler_args = updated_handler_args # Configure aiohttp access logger to be less verbose - logging.getLogger('aiohttp.access').setLevel(logging.WARNING) + logging.getLogger("aiohttp.access").setLevel(logging.WARNING) # Add specific suppression for connection reset errors class ConnectionResetFilter(logging.Filter): @@ -124,46 +134,52 @@ class LoraManager: asyncio_logger.addFilter(ConnectionResetFilter()) # Add static route for example images if the path exists in settings - example_images_path = settings.get('example_images_path') + example_images_path = settings.get("example_images_path") logger.info(f"Example images path: {example_images_path}") if example_images_path and os.path.exists(example_images_path): - app.router.add_static('/example_images_static', example_images_path) - logger.info(f"Added static route for example images: /example_images_static -> {example_images_path}") + app.router.add_static("/example_images_static", example_images_path) + logger.info( + f"Added static route for example images: /example_images_static -> {example_images_path}" + ) # Add static route for locales JSON files if os.path.exists(config.i18n_path): - app.router.add_static('/locales', config.i18n_path) - logger.info(f"Added static route for locales: /locales -> {config.i18n_path}") + app.router.add_static("/locales", config.i18n_path) + logger.info( + f"Added static route for locales: /locales -> {config.i18n_path}" + ) # Add static route for plugin assets - app.router.add_static('/loras_static', config.static_path) - + app.router.add_static("/loras_static", config.static_path) + # Register default model types with the factory register_default_model_types() - + # Setup all model routes using the factory ModelServiceFactory.setup_all_routes(app) - + # Setup non-model-specific routes stats_routes = StatsRoutes() stats_routes.setup_routes(app) RecipeRoutes.setup_routes(app) - UpdateRoutes.setup_routes(app) + UpdateRoutes.setup_routes(app) MiscRoutes.setup_routes(app) ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager) PreviewRoutes.setup_routes(app) - + # Setup WebSocket routes that are shared across all model types - app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) - app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) - app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) - - # Schedule service initialization + app.router.add_get("/ws/fetch-progress", ws_manager.handle_connection) + app.router.add_get( + "/ws/download-progress", ws_manager.handle_download_connection + ) + app.router.add_get("/ws/init-progress", ws_manager.handle_init_connection) + + # Schedule service initialization app.on_startup.append(lambda app: cls._initialize_services()) - + # Add cleanup app.on_shutdown.append(cls._cleanup) - + @classmethod async def _initialize_services(cls): """Initialize all services using the ServiceRegistry""" @@ -197,7 +213,9 @@ class LoraManager: extra_paths.get("embeddings", []), ) except Exception as exc: - logger.warning("Failed to apply library settings during initialization: %s", exc) + logger.warning( + "Failed to apply library settings during initialization: %s", exc + ) # Initialize CivitaiClient first to ensure it's ready for other services await ServiceRegistry.get_civitai_client() @@ -206,163 +224,200 @@ class LoraManager: await ServiceRegistry.get_download_manager() from .services.metadata_service import initialize_metadata_providers + await initialize_metadata_providers() - + # Initialize WebSocket manager await ServiceRegistry.get_websocket_manager() - + # Initialize scanners in background lora_scanner = await ServiceRegistry.get_lora_scanner() checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() embedding_scanner = await ServiceRegistry.get_embedding_scanner() - + # Initialize recipe scanner if needed recipe_scanner = await ServiceRegistry.get_recipe_scanner() - + # Create low-priority initialization tasks init_tasks = [ - asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init'), - asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init'), - asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init'), - asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init') + asyncio.create_task( + lora_scanner.initialize_in_background(), name="lora_cache_init" + ), + asyncio.create_task( + checkpoint_scanner.initialize_in_background(), + name="checkpoint_cache_init", + ), + asyncio.create_task( + embedding_scanner.initialize_in_background(), + name="embedding_cache_init", + ), + asyncio.create_task( + recipe_scanner.initialize_in_background(), name="recipe_cache_init" + ), ] await ExampleImagesMigration.check_and_run_migrations() - + # Schedule post-initialization tasks to run after scanners complete asyncio.create_task( - cls._run_post_initialization_tasks(init_tasks), - name='post_init_tasks' + cls._run_post_initialization_tasks(init_tasks), name="post_init_tasks" ) - - logger.debug("LoRA Manager: All services initialized and background tasks scheduled") - + + logger.debug( + "LoRA Manager: All services initialized and background tasks scheduled" + ) + except Exception as e: - logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True) - + logger.error( + f"LoRA Manager: Error initializing services: {e}", exc_info=True + ) + @classmethod async def _run_post_initialization_tasks(cls, init_tasks): """Run post-initialization tasks after all scanners complete""" try: - logger.debug("LoRA Manager: Waiting for scanner initialization to complete...") - + logger.debug( + "LoRA Manager: Waiting for scanner initialization to complete..." + ) + # Wait for all scanner initialization tasks to complete await asyncio.gather(*init_tasks, return_exceptions=True) - logger.debug("LoRA Manager: Scanner initialization completed, starting post-initialization tasks...") - + logger.debug( + "LoRA Manager: Scanner initialization completed, starting post-initialization tasks..." + ) + # Run post-initialization tasks post_tasks = [ - asyncio.create_task(cls._cleanup_backup_files(), name='cleanup_bak_files'), + asyncio.create_task( + cls._cleanup_backup_files(), name="cleanup_bak_files" + ), # Add more post-initialization tasks here as needed # asyncio.create_task(cls._another_post_task(), name='another_task'), ] - + # Run all post-initialization tasks results = await asyncio.gather(*post_tasks, return_exceptions=True) - + # Log results for i, result in enumerate(results): task_name = post_tasks[i].get_name() if isinstance(result, Exception): - logger.error(f"Post-initialization task '{task_name}' failed: {result}") + logger.error( + f"Post-initialization task '{task_name}' failed: {result}" + ) else: - logger.debug(f"Post-initialization task '{task_name}' completed successfully") - + logger.debug( + f"Post-initialization task '{task_name}' completed successfully" + ) + logger.debug("LoRA Manager: All post-initialization tasks completed") - + except Exception as e: - logger.error(f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True) - + logger.error( + f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True + ) + @classmethod async def _cleanup_backup_files(cls): """Clean up .bak files in all model roots""" try: logger.debug("Starting cleanup of .bak files in model directories...") - + # Collect all model roots all_roots = set() all_roots.update(config.loras_roots) - all_roots.update(config.base_models_roots) - all_roots.update(config.embeddings_roots) - + all_roots.update(config.base_models_roots or []) + all_roots.update(config.embeddings_roots or []) + total_deleted = 0 total_size_freed = 0 - + for root_path in all_roots: if not os.path.exists(root_path): continue - + try: - deleted_count, size_freed = await cls._cleanup_backup_files_in_directory(root_path) + ( + deleted_count, + size_freed, + ) = await cls._cleanup_backup_files_in_directory(root_path) total_deleted += deleted_count total_size_freed += size_freed - + if deleted_count > 0: - logger.debug(f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024*1024):.2f} MB)") - + logger.debug( + f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024 * 1024):.2f} MB)" + ) + except Exception as e: logger.error(f"Error cleaning up .bak files in {root_path}: {e}") - + # Yield control periodically await asyncio.sleep(0.01) - + if total_deleted > 0: - logger.debug(f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024*1024):.2f} MB total") + logger.debug( + f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024 * 1024):.2f} MB total" + ) else: logger.debug("Backup cleanup completed: no .bak files found") except Exception as e: logger.error(f"Error during backup file cleanup: {e}", exc_info=True) - + @classmethod async def _cleanup_backup_files_in_directory(cls, directory_path: str): """Clean up .bak files in a specific directory recursively - + Args: directory_path: Path to the directory to clean - + Returns: Tuple[int, int]: (number of files deleted, total size freed in bytes) """ deleted_count = 0 size_freed = 0 visited_paths = set() - + def cleanup_recursive(path): nonlocal deleted_count, size_freed - + try: real_path = os.path.realpath(path) if real_path in visited_paths: return visited_paths.add(real_path) - + with os.scandir(path) as it: for entry in it: try: - if entry.is_file(follow_symlinks=True) and entry.name.endswith('.bak'): + if entry.is_file( + follow_symlinks=True + ) and entry.name.endswith(".bak"): file_size = entry.stat().st_size os.remove(entry.path) deleted_count += 1 size_freed += file_size logger.debug(f"Deleted .bak file: {entry.path}") - + elif entry.is_dir(follow_symlinks=True): cleanup_recursive(entry.path) - + except Exception as e: - logger.warning(f"Could not delete .bak file {entry.path}: {e}") - + logger.warning( + f"Could not delete .bak file {entry.path}: {e}" + ) + except Exception as e: logger.error(f"Error scanning directory {path} for .bak files: {e}") - + # Run the recursive cleanup in a thread pool to avoid blocking loop = asyncio.get_event_loop() await loop.run_in_executor(None, cleanup_recursive, directory_path) - + return deleted_count, size_freed - + @classmethod async def _cleanup_example_images_folders(cls): """Invoke the example images cleanup service for manual execution.""" @@ -370,21 +425,21 @@ class LoraManager: service = ExampleImagesCleanupService() result = await service.cleanup_example_image_folders() - if result.get('success'): + if result.get("success"): logger.debug( "Manual example images cleanup completed: moved=%s", - result.get('moved_total'), + result.get("moved_total"), ) - elif result.get('partial_success'): + elif result.get("partial_success"): logger.warning( "Manual example images cleanup partially succeeded: moved=%s failures=%s", - result.get('moved_total'), - result.get('move_failures'), + result.get("moved_total"), + result.get("move_failures"), ) else: logger.debug( "Manual example images cleanup skipped or failed: %s", - result.get('error', 'no changes'), + result.get("error", "no changes"), ) return result @@ -392,9 +447,9 @@ class LoraManager: except Exception as e: # pragma: no cover - defensive guard logger.error(f"Error during example images cleanup: {e}", exc_info=True) return { - 'success': False, - 'error': str(e), - 'error_code': 'unexpected_error', + "success": False, + "error": str(e), + "error_code": "unexpected_error", } @classmethod @@ -402,6 +457,6 @@ class LoraManager: """Cleanup resources using ServiceRegistry""" try: logger.info("LoRA Manager: Cleaning up services") - + except Exception as e: logger.error(f"Error during cleanup: {e}", exc_info=True) diff --git a/py/metadata_collector/__init__.py b/py/metadata_collector/__init__.py index 2f9cc5b1..f3dde058 100644 --- a/py/metadata_collector/__init__.py +++ b/py/metadata_collector/__init__.py @@ -4,7 +4,10 @@ import logging logger = logging.getLogger(__name__) # Check if running in standalone mode -standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" +standalone_mode = ( + os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" + or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" +) if not standalone_mode: from .metadata_hook import MetadataHook @@ -13,13 +16,13 @@ if not standalone_mode: def init(): # Install hooks to collect metadata during execution MetadataHook.install() - + # Initialize registry registry = MetadataRegistry() - + logger.info("ComfyUI Metadata Collector initialized") - - def get_metadata(prompt_id=None): + + def get_metadata(prompt_id=None): # type: ignore[no-redef] """Helper function to get metadata from the registry""" registry = MetadataRegistry() return registry.get_metadata(prompt_id) @@ -27,7 +30,7 @@ else: # Standalone mode - provide dummy implementations def init(): logger.info("ComfyUI Metadata Collector disabled in standalone mode") - - def get_metadata(prompt_id=None): + + def get_metadata(prompt_id=None): # type: ignore[no-redef] """Dummy implementation for standalone mode""" return {} diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py index 528f996a..2a084325 100644 --- a/py/metadata_collector/metadata_registry.py +++ b/py/metadata_collector/metadata_registry.py @@ -1,50 +1,54 @@ import time -from nodes import NODE_CLASS_MAPPINGS +from nodes import NODE_CLASS_MAPPINGS # type: ignore from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor from .constants import METADATA_CATEGORIES, IMAGES + class MetadataRegistry: """A singleton registry to store and retrieve workflow metadata""" + _instance = None - + def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._reset() return cls._instance - + def _reset(self): self.current_prompt_id = None self.current_prompt = None self.metadata = {} self.prompt_metadata = {} self.executed_nodes = set() - + # Node-level cache for metadata self.node_cache = {} - + # Limit the number of stored prompts self.max_prompt_history = 3 - + # Categories we want to track and retrieve from cache self.metadata_categories = METADATA_CATEGORIES - + def _clean_old_prompts(self): """Clean up old prompt metadata, keeping only recent ones""" if len(self.prompt_metadata) <= self.max_prompt_history: return - + # Sort all prompt_ids by timestamp sorted_prompts = sorted( self.prompt_metadata.keys(), - key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0) + key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0), ) - + # Remove oldest records - prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history] + prompts_to_remove = sorted_prompts[ + : len(sorted_prompts) - self.max_prompt_history + ] for pid in prompts_to_remove: del self.prompt_metadata[pid] - + def start_collection(self, prompt_id): """Begin metadata collection for a new prompt""" self.current_prompt_id = prompt_id @@ -53,90 +57,96 @@ class MetadataRegistry: category: {} for category in METADATA_CATEGORIES } # Add additional metadata fields - self.prompt_metadata[prompt_id].update({ - "execution_order": [], - "current_prompt": None, # Will store the prompt object - "timestamp": time.time() - }) - + self.prompt_metadata[prompt_id].update( + { + "execution_order": [], + "current_prompt": None, # Will store the prompt object + "timestamp": time.time(), + } + ) + # Clean up old prompt data self._clean_old_prompts() - + def set_current_prompt(self, prompt): """Set the current prompt object reference""" self.current_prompt = prompt if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata: # Store the prompt in the metadata for later relationship tracing self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt - + def get_metadata(self, prompt_id=None): """Get collected metadata for a prompt""" key = prompt_id if prompt_id is not None else self.current_prompt_id if key not in self.prompt_metadata: return {} - + metadata = self.prompt_metadata[key] - + # If we have a current prompt object, check for non-executed nodes prompt_obj = metadata.get("current_prompt") if prompt_obj and hasattr(prompt_obj, "original_prompt"): original_prompt = prompt_obj.original_prompt - + # Fill in missing metadata from cache for nodes that weren't executed self._fill_missing_metadata(key, original_prompt) - + return self.prompt_metadata.get(key, {}) - + def _fill_missing_metadata(self, prompt_id, original_prompt): """Fill missing metadata from cache for non-executed nodes""" if not original_prompt: return - + executed_nodes = self.executed_nodes metadata = self.prompt_metadata[prompt_id] - + # Iterate through nodes in the original prompt for node_id, node_data in original_prompt.items(): # Skip if already executed in this run if node_id in executed_nodes: continue - + # Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS) prompt_class_type = node_data.get("class_type") if not prompt_class_type: continue - + # Convert to actual class name (which is what we use in our cache) class_type = prompt_class_type if prompt_class_type in NODE_CLASS_MAPPINGS: class_obj = NODE_CLASS_MAPPINGS[prompt_class_type] class_type = class_obj.__name__ - + # Create cache key using the actual class name cache_key = f"{node_id}:{class_type}" - + # Check if this node type is relevant for metadata collection if class_type in NODE_EXTRACTORS: # Check if we have cached metadata for this node if cache_key in self.node_cache: cached_data = self.node_cache[cache_key] - + # Apply cached metadata to the current metadata for category in self.metadata_categories: if category in cached_data and node_id in cached_data[category]: if node_id not in metadata[category]: - metadata[category][node_id] = cached_data[category][node_id] - + metadata[category][node_id] = cached_data[category][ + node_id + ] + def record_node_execution(self, node_id, class_type, inputs, outputs): """Record information about a node's execution""" if not self.current_prompt_id: return - + # Add to execution order and mark as executed if node_id not in self.executed_nodes: self.executed_nodes.add(node_id) - self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id) - + self.prompt_metadata[self.current_prompt_id]["execution_order"].append( + node_id + ) + # Process inputs to simplify working with them processed_inputs = {} for input_name, input_values in inputs.items(): @@ -145,63 +155,61 @@ class MetadataRegistry: processed_inputs[input_name] = input_values[0] else: processed_inputs[input_name] = input_values - + # Extract node-specific metadata extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor) extractor.extract( - node_id, - processed_inputs, - outputs, - self.prompt_metadata[self.current_prompt_id] + node_id, + processed_inputs, + outputs, + self.prompt_metadata[self.current_prompt_id], ) - + # Cache this node's metadata self._cache_node_metadata(node_id, class_type) - + def update_node_execution(self, node_id, class_type, outputs): """Update node metadata with output information""" if not self.current_prompt_id: return - + # Process outputs to make them more usable processed_outputs = outputs - + # Use the same extractor to update with outputs extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor) - if hasattr(extractor, 'update'): + if hasattr(extractor, "update"): extractor.update( - node_id, - processed_outputs, - self.prompt_metadata[self.current_prompt_id] + node_id, processed_outputs, self.prompt_metadata[self.current_prompt_id] ) - + # Update the cached metadata for this node self._cache_node_metadata(node_id, class_type) - + def _cache_node_metadata(self, node_id, class_type): """Cache the metadata for a specific node""" if not self.current_prompt_id or not node_id or not class_type: return - + # Create a cache key combining node_id and class_type cache_key = f"{node_id}:{class_type}" - + # Create a shallow copy of the node's metadata node_metadata = {} current_metadata = self.prompt_metadata[self.current_prompt_id] - + for category in self.metadata_categories: if category in current_metadata and node_id in current_metadata[category]: if category not in node_metadata: node_metadata[category] = {} node_metadata[category][node_id] = current_metadata[category][node_id] - + # Save new metadata or clear stale cache entries when metadata is empty if any(node_metadata.values()): self.node_cache[cache_key] = node_metadata else: self.node_cache.pop(cache_key, None) - + def clear_unused_cache(self): """Clean up node_cache entries that are no longer in use""" # Collect all node_ids currently in prompt_metadata @@ -210,18 +218,18 @@ class MetadataRegistry: for category in self.metadata_categories: if category in prompt_data: active_node_ids.update(prompt_data[category].keys()) - + # Find cache keys that are no longer needed keys_to_remove = [] for cache_key in self.node_cache: - node_id = cache_key.split(':')[0] + node_id = cache_key.split(":")[0] if node_id not in active_node_ids: keys_to_remove.append(cache_key) - + # Remove cache entries that are no longer needed for key in keys_to_remove: del self.node_cache[key] - + def clear_metadata(self, prompt_id=None): """Clear metadata for a specific prompt or reset all data""" if prompt_id is not None: @@ -232,25 +240,25 @@ class MetadataRegistry: else: # Reset all data self._reset() - + def get_first_decoded_image(self, prompt_id=None): """Get the first decoded image result""" key = prompt_id if prompt_id is not None else self.current_prompt_id if key not in self.prompt_metadata: return None - + metadata = self.prompt_metadata[key] if IMAGES in metadata and "first_decode" in metadata[IMAGES]: image_data = metadata[IMAGES]["first_decode"]["image"] - + # If it's an image batch or tuple, handle various formats if isinstance(image_data, (list, tuple)) and len(image_data) > 0: # Return first element of list/tuple return image_data[0] - + # If it's a tensor, return as is for processing in the route handler return image_data - + # If no image is found in the current metadata, try to find it in the cache # This handles the case where VAEDecode was cached by ComfyUI and not executed prompt_obj = metadata.get("current_prompt") @@ -270,8 +278,11 @@ class MetadataRegistry: if IMAGES in cached_data and node_id in cached_data[IMAGES]: image_data = cached_data[IMAGES][node_id]["image"] # Handle different image formats - if isinstance(image_data, (list, tuple)) and len(image_data) > 0: + if ( + isinstance(image_data, (list, tuple)) + and len(image_data) > 0 + ): return image_data[0] return image_data - + return None diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index de9d5649..78b70394 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -1,8 +1,9 @@ import json import os import re +from typing import Any, Dict, Optional import numpy as np -import folder_paths # type: ignore +import folder_paths # type: ignore from ..services.service_registry import ServiceRegistry from ..metadata_collector.metadata_processor import MetadataProcessor from ..metadata_collector import get_metadata @@ -12,6 +13,7 @@ import logging logger = logging.getLogger(__name__) + class SaveImageLM: NAME = "Save Image (LoraManager)" CATEGORY = "Lora Manager/utils" @@ -23,42 +25,60 @@ class SaveImageLM: self.prefix_append = "" self.compress_level = 4 self.counter = 0 - + # Add pattern format regex for filename substitution pattern_format = re.compile(r"(%[^%]+%)") - + @classmethod def INPUT_TYPES(cls): return { "required": { "images": ("IMAGE",), - "filename_prefix": ("STRING", { - "default": "ComfyUI", - "tooltip": "Base filename for saved images. Supports format patterns like %seed%, %width%, %height%, %model%, etc." - }), - "file_format": (["png", "jpeg", "webp"], { - "tooltip": "Image format to save as. PNG preserves quality, JPEG is smaller, WebP balances size and quality." - }), + "filename_prefix": ( + "STRING", + { + "default": "ComfyUI", + "tooltip": "Base filename for saved images. Supports format patterns like %seed%, %width%, %height%, %model%, etc.", + }, + ), + "file_format": ( + ["png", "jpeg", "webp"], + { + "tooltip": "Image format to save as. PNG preserves quality, JPEG is smaller, WebP balances size and quality." + }, + ), }, "optional": { - "lossless_webp": ("BOOLEAN", { - "default": False, - "tooltip": "When enabled, saves WebP images with lossless compression. Results in larger files but no quality loss." - }), - "quality": ("INT", { - "default": 100, - "min": 1, - "max": 100, - "tooltip": "Compression quality for JPEG and lossy WebP formats (1-100). Higher values mean better quality but larger files." - }), - "embed_workflow": ("BOOLEAN", { - "default": False, - "tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats." - }), - "add_counter_to_filename": ("BOOLEAN", { - "default": True, - "tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images." - }), + "lossless_webp": ( + "BOOLEAN", + { + "default": False, + "tooltip": "When enabled, saves WebP images with lossless compression. Results in larger files but no quality loss.", + }, + ), + "quality": ( + "INT", + { + "default": 100, + "min": 1, + "max": 100, + "tooltip": "Compression quality for JPEG and lossy WebP formats (1-100). Higher values mean better quality but larger files.", + }, + ), + "embed_workflow": ( + "BOOLEAN", + { + "default": False, + "tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats.", + }, + ), + "add_counter_to_filename": ( + "BOOLEAN", + { + "default": True, + "tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images.", + }, + ), }, "hidden": { "id": "UNIQUE_ID", @@ -75,57 +95,59 @@ class SaveImageLM: def get_lora_hash(self, lora_name): """Get the lora hash from cache""" scanner = ServiceRegistry.get_service_sync("lora_scanner") - + # Use the new direct filename lookup method - hash_value = scanner.get_hash_by_filename(lora_name) - if hash_value: - return hash_value - + if scanner is not None: + hash_value = scanner.get_hash_by_filename(lora_name) + if hash_value: + return hash_value + return None def get_checkpoint_hash(self, checkpoint_path): """Get the checkpoint hash from cache""" scanner = ServiceRegistry.get_service_sync("checkpoint_scanner") - + if not checkpoint_path: return None - + # Extract basename without extension checkpoint_name = os.path.basename(checkpoint_path) checkpoint_name = os.path.splitext(checkpoint_name)[0] - + # Try direct filename lookup first - hash_value = scanner.get_hash_by_filename(checkpoint_name) - if hash_value: - return hash_value - + if scanner is not None: + hash_value = scanner.get_hash_by_filename(checkpoint_name) + if hash_value: + return hash_value + return None def format_metadata(self, metadata_dict): """Format metadata in the requested format similar to userComment example""" if not metadata_dict: return "" - + # Helper function to only add parameter if value is not None def add_param_if_not_none(param_list, label, value): if value is not None: param_list.append(f"{label}: {value}") - + # Extract the prompt and negative prompt - prompt = metadata_dict.get('prompt', '') - negative_prompt = metadata_dict.get('negative_prompt', '') - + prompt = metadata_dict.get("prompt", "") + negative_prompt = metadata_dict.get("negative_prompt", "") + # Extract loras from the prompt if present - loras_text = metadata_dict.get('loras', '') + loras_text = metadata_dict.get("loras", "") lora_hashes = {} - + # If loras are found, add them on a new line after the prompt if loras_text: prompt_with_loras = f"{prompt}\n{loras_text}" - + # Extract lora names from the format - lora_matches = re.findall(r']+)>', loras_text) - + lora_matches = re.findall(r"]+)>", loras_text) + # Get hash for each lora for lora_name, strength in lora_matches: hash_value = self.get_lora_hash(lora_name) @@ -133,112 +155,114 @@ class SaveImageLM: lora_hashes[lora_name] = hash_value else: prompt_with_loras = prompt - + # Format the first part (prompt and loras) metadata_parts = [prompt_with_loras] - + # Add negative prompt if negative_prompt: metadata_parts.append(f"Negative prompt: {negative_prompt}") - + # Format the second part (generation parameters) params = [] - + # Add standard parameters in the correct order - if 'steps' in metadata_dict: - add_param_if_not_none(params, "Steps", metadata_dict.get('steps')) - + if "steps" in metadata_dict: + add_param_if_not_none(params, "Steps", metadata_dict.get("steps")) + # Combine sampler and scheduler information sampler_name = None scheduler_name = None - - if 'sampler' in metadata_dict: - sampler = metadata_dict.get('sampler') + + if "sampler" in metadata_dict: + sampler = metadata_dict.get("sampler") # Convert ComfyUI sampler names to user-friendly names sampler_mapping = { - 'euler': 'Euler', - 'euler_ancestral': 'Euler a', - 'dpm_2': 'DPM2', - 'dpm_2_ancestral': 'DPM2 a', - 'heun': 'Heun', - 'dpm_fast': 'DPM fast', - 'dpm_adaptive': 'DPM adaptive', - 'lms': 'LMS', - 'dpmpp_2s_ancestral': 'DPM++ 2S a', - 'dpmpp_sde': 'DPM++ SDE', - 'dpmpp_sde_gpu': 'DPM++ SDE', - 'dpmpp_2m': 'DPM++ 2M', - 'dpmpp_2m_sde': 'DPM++ 2M SDE', - 'dpmpp_2m_sde_gpu': 'DPM++ 2M SDE', - 'ddim': 'DDIM' + "euler": "Euler", + "euler_ancestral": "Euler a", + "dpm_2": "DPM2", + "dpm_2_ancestral": "DPM2 a", + "heun": "Heun", + "dpm_fast": "DPM fast", + "dpm_adaptive": "DPM adaptive", + "lms": "LMS", + "dpmpp_2s_ancestral": "DPM++ 2S a", + "dpmpp_sde": "DPM++ SDE", + "dpmpp_sde_gpu": "DPM++ SDE", + "dpmpp_2m": "DPM++ 2M", + "dpmpp_2m_sde": "DPM++ 2M SDE", + "dpmpp_2m_sde_gpu": "DPM++ 2M SDE", + "ddim": "DDIM", } sampler_name = sampler_mapping.get(sampler, sampler) - - if 'scheduler' in metadata_dict: - scheduler = metadata_dict.get('scheduler') + + if "scheduler" in metadata_dict: + scheduler = metadata_dict.get("scheduler") scheduler_mapping = { - 'normal': 'Simple', - 'karras': 'Karras', - 'exponential': 'Exponential', - 'sgm_uniform': 'SGM Uniform', - 'sgm_quadratic': 'SGM Quadratic' + "normal": "Simple", + "karras": "Karras", + "exponential": "Exponential", + "sgm_uniform": "SGM Uniform", + "sgm_quadratic": "SGM Quadratic", } scheduler_name = scheduler_mapping.get(scheduler, scheduler) - + # Add combined sampler and scheduler information if sampler_name: if scheduler_name: params.append(f"Sampler: {sampler_name} {scheduler_name}") else: params.append(f"Sampler: {sampler_name}") - + # CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg) - if 'guidance' in metadata_dict: - add_param_if_not_none(params, "CFG scale", metadata_dict.get('guidance')) - elif 'cfg_scale' in metadata_dict: - add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg_scale')) - elif 'cfg' in metadata_dict: - add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg')) - + if "guidance" in metadata_dict: + add_param_if_not_none(params, "CFG scale", metadata_dict.get("guidance")) + elif "cfg_scale" in metadata_dict: + add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg_scale")) + elif "cfg" in metadata_dict: + add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg")) + # Seed - if 'seed' in metadata_dict: - add_param_if_not_none(params, "Seed", metadata_dict.get('seed')) - + if "seed" in metadata_dict: + add_param_if_not_none(params, "Seed", metadata_dict.get("seed")) + # Size - if 'size' in metadata_dict: - add_param_if_not_none(params, "Size", metadata_dict.get('size')) - + if "size" in metadata_dict: + add_param_if_not_none(params, "Size", metadata_dict.get("size")) + # Model info - if 'checkpoint' in metadata_dict: + if "checkpoint" in metadata_dict: # Ensure checkpoint is a string before processing - checkpoint = metadata_dict.get('checkpoint') + checkpoint = metadata_dict.get("checkpoint") if checkpoint is not None: # Get model hash model_hash = self.get_checkpoint_hash(checkpoint) - + # Extract basename without path checkpoint_name = os.path.basename(checkpoint) # Remove extension if present checkpoint_name = os.path.splitext(checkpoint_name)[0] - + # Add model hash if available if model_hash: - params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}") + params.append( + f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}" + ) else: params.append(f"Model: {checkpoint_name}") - + # Add LoRA hashes if available if lora_hashes: lora_hash_parts = [] for lora_name, hash_value in lora_hashes.items(): lora_hash_parts.append(f"{lora_name}: {hash_value[:10]}") - + if lora_hash_parts: - params.append(f"Lora hashes: \"{', '.join(lora_hash_parts)}\"") - + params.append(f'Lora hashes: "{", ".join(lora_hash_parts)}"') + # Combine all parameters with commas metadata_parts.append(", ".join(params)) - + # Join all parts with a new line return "\n".join(metadata_parts) @@ -248,36 +272,36 @@ class SaveImageLM: """Format filename with metadata values""" if not metadata_dict: return filename - + result = re.findall(self.pattern_format, filename) for segment in result: parts = segment.replace("%", "").split(":") key = parts[0] - - if key == "seed" and 'seed' in metadata_dict: - filename = filename.replace(segment, str(metadata_dict.get('seed', ''))) - elif key == "width" and 'size' in metadata_dict: - size = metadata_dict.get('size', 'x') - w = size.split('x')[0] if isinstance(size, str) else size[0] + + if key == "seed" and "seed" in metadata_dict: + filename = filename.replace(segment, str(metadata_dict.get("seed", ""))) + elif key == "width" and "size" in metadata_dict: + size = metadata_dict.get("size", "x") + w = size.split("x")[0] if isinstance(size, str) else size[0] filename = filename.replace(segment, str(w)) - elif key == "height" and 'size' in metadata_dict: - size = metadata_dict.get('size', 'x') - h = size.split('x')[1] if isinstance(size, str) else size[1] + elif key == "height" and "size" in metadata_dict: + size = metadata_dict.get("size", "x") + h = size.split("x")[1] if isinstance(size, str) else size[1] filename = filename.replace(segment, str(h)) - elif key == "pprompt" and 'prompt' in metadata_dict: - prompt = metadata_dict.get('prompt', '').replace("\n", " ") + elif key == "pprompt" and "prompt" in metadata_dict: + prompt = metadata_dict.get("prompt", "").replace("\n", " ") if len(parts) >= 2: length = int(parts[1]) prompt = prompt[:length] filename = filename.replace(segment, prompt.strip()) - elif key == "nprompt" and 'negative_prompt' in metadata_dict: - prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ") + elif key == "nprompt" and "negative_prompt" in metadata_dict: + prompt = metadata_dict.get("negative_prompt", "").replace("\n", " ") if len(parts) >= 2: length = int(parts[1]) prompt = prompt[:length] filename = filename.replace(segment, prompt.strip()) elif key == "model": - model_value = metadata_dict.get('checkpoint') + model_value = metadata_dict.get("checkpoint") if isinstance(model_value, (bytes, os.PathLike)): model_value = str(model_value) @@ -291,6 +315,7 @@ class SaveImageLM: filename = filename.replace(segment, model) elif key == "date": from datetime import datetime + now = datetime.now() date_table = { "yyyy": f"{now.year:04d}", @@ -311,46 +336,62 @@ class SaveImageLM: for k, v in date_table.items(): date_format = date_format.replace(k, v) filename = filename.replace(segment, date_format) - + return filename - def save_images(self, images, filename_prefix, file_format, id, prompt=None, extra_pnginfo=None, - lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): + def save_images( + self, + images, + filename_prefix, + file_format, + id, + prompt=None, + extra_pnginfo=None, + lossless_webp=True, + quality=100, + embed_workflow=False, + add_counter_to_filename=True, + ): """Save images with metadata""" results = [] # Get metadata using the metadata collector raw_metadata = get_metadata() metadata_dict = MetadataProcessor.to_dict(raw_metadata, id) - + metadata = self.format_metadata(metadata_dict) - + # Process filename_prefix with pattern substitution filename_prefix = self.format_filename(filename_prefix, metadata_dict) - + # Get initial save path info once for the batch - full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path( - filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0] + full_output_folder, filename, counter, subfolder, processed_prefix = ( + folder_paths.get_save_image_path( + filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0] + ) ) - + # Create directory if it doesn't exist if not os.path.exists(full_output_folder): os.makedirs(full_output_folder, exist_ok=True) - + # Process each image with incrementing counter for i, image in enumerate(images): # Convert the tensor image to numpy array - img = 255. * image.cpu().numpy() + img = 255.0 * image.cpu().numpy() img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8)) - + # Generate filename with counter if needed base_filename = filename if add_counter_to_filename: # Use counter + i to ensure unique filenames for all images in batch current_counter = counter + i base_filename += f"_{current_counter:05}_" - + # Set file extension and prepare saving parameters + file: str + save_kwargs: Dict[str, Any] + pnginfo: Optional[PngImagePlugin.PngInfo] = None if file_format == "png": file = base_filename + ".png" file_extension = ".png" @@ -362,17 +403,24 @@ class SaveImageLM: file_extension = ".jpg" save_kwargs = {"quality": quality, "optimize": True} elif file_format == "webp": - file = base_filename + ".webp" + file = base_filename + ".webp" file_extension = ".webp" # Add optimization param to control performance - save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0} - + save_kwargs = { + "quality": quality, + "lossless": lossless_webp, + "method": 0, + } + else: + raise ValueError(f"Unsupported file format: {file_format}") + # Full save path file_path = os.path.join(full_output_folder, file) - + # Save the image with metadata try: if file_format == "png": + assert pnginfo is not None if metadata: pnginfo.add_text("parameters", metadata) if embed_workflow and extra_pnginfo is not None: @@ -384,7 +432,12 @@ class SaveImageLM: # For JPEG, use piexif if metadata: try: - exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}} + exif_dict = { + "Exif": { + piexif.ExifIFD.UserComment: b"UNICODE\0" + + metadata.encode("utf-16be") + } + } exif_bytes = piexif.dump(exif_dict) save_kwargs["exif"] = exif_bytes except Exception as e: @@ -396,37 +449,52 @@ class SaveImageLM: exif_dict = {} if metadata: - exif_dict['Exif'] = {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')} - + exif_dict["Exif"] = { + piexif.ExifIFD.UserComment: b"UNICODE\0" + + metadata.encode("utf-16be") + } + # Add workflow if needed if embed_workflow and extra_pnginfo is not None: - workflow_json = json.dumps(extra_pnginfo["workflow"]) - exif_dict['0th'] = {piexif.ImageIFD.ImageDescription: "Workflow:" + workflow_json} - + workflow_json = json.dumps(extra_pnginfo["workflow"]) + exif_dict["0th"] = { + piexif.ImageIFD.ImageDescription: "Workflow:" + + workflow_json + } + exif_bytes = piexif.dump(exif_dict) save_kwargs["exif"] = exif_bytes except Exception as e: logger.error(f"Error adding EXIF data: {e}") - + img.save(file_path, format="WEBP", **save_kwargs) - - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - + + results.append( + {"filename": file, "subfolder": subfolder, "type": self.type} + ) + except Exception as e: logger.error(f"Error saving image: {e}") - + return results - def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None, - lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): + def process_image( + self, + images, + id, + filename_prefix="ComfyUI", + file_format="png", + prompt=None, + extra_pnginfo=None, + lossless_webp=True, + quality=100, + embed_workflow=False, + add_counter_to_filename=True, + ): """Process and save image with metadata""" # Make sure the output directory exists os.makedirs(self.output_dir, exist_ok=True) - + # If images is already a list or array of images, do nothing; otherwise, convert to list if isinstance(images, (list, np.ndarray)): pass @@ -436,19 +504,19 @@ class SaveImageLM: images = [images] else: # Multiple images (batch, height, width, channels) images = [img for img in images] - + # Save all images results = self.save_images( - images, - filename_prefix, - file_format, + images, + filename_prefix, + file_format, id, - prompt, + prompt, extra_pnginfo, lossless_webp, quality, embed_workflow, - add_counter_to_filename + add_counter_to_filename, ) - + return (images,) diff --git a/py/nodes/utils.py b/py/nodes/utils.py index 41127e4f..89183508 100644 --- a/py/nodes/utils.py +++ b/py/nodes/utils.py @@ -1,33 +1,35 @@ class AnyType(str): - """A special class that is always equal in not equal comparisons. Credit to pythongosssss""" + """A special class that is always equal in not equal comparisons. Credit to pythongosssss""" + + def __ne__(self, __value: object) -> bool: + return False - def __ne__(self, __value: object) -> bool: - return False # Credit to Regis Gaughan, III (rgthree) class FlexibleOptionalInputType(dict): - """A special class to make flexible nodes that pass data to our python handlers. + """A special class to make flexible nodes that pass data to our python handlers. - Enables both flexible/dynamic input types (like for Any Switch) or a dynamic number of inputs - (like for Any Switch, Context Switch, Context Merge, Power Lora Loader, etc). + Enables both flexible/dynamic input types (like for Any Switch) or a dynamic number of inputs + (like for Any Switch, Context Switch, Context Merge, Power Lora Loader, etc). - Note, for ComfyUI, all that's needed is the `__contains__` override below, which tells ComfyUI - that our node will handle the input, regardless of what it is. + Note, for ComfyUI, all that's needed is the `__contains__` override below, which tells ComfyUI + that our node will handle the input, regardless of what it is. - However, with https://github.com/comfyanonymous/ComfyUI/pull/2666 a large change would occur - requiring more details on the input itself. There, we need to return a list/tuple where the first - item is the type. This can be a real type, or use the AnyType for additional flexibility. + However, with https://github.com/comfyanonymous/ComfyUI/pull/2666 a large change would occur + requiring more details on the input itself. There, we need to return a list/tuple where the first + item is the type. This can be a real type, or use the AnyType for additional flexibility. - This should be forwards compatible unless more changes occur in the PR. - """ - def __init__(self, type): - self.type = type + This should be forwards compatible unless more changes occur in the PR. + """ - def __getitem__(self, key): - return (self.type, ) + def __init__(self, type): + self.type = type - def __contains__(self, key): - return True + def __getitem__(self, key): + return (self.type,) + + def __contains__(self, key): + return True any_type = AnyType("*") @@ -37,25 +39,27 @@ import os import logging import copy import sys -import folder_paths +import folder_paths # type: ignore logger = logging.getLogger(__name__) + def extract_lora_name(lora_path): """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" # Get the basename without extension basename = os.path.basename(lora_path) return os.path.splitext(basename)[0] + def get_loras_list(kwargs): """Helper to extract loras list from either old or new kwargs format""" - if 'loras' not in kwargs: + if "loras" not in kwargs: return [] - - loras_data = kwargs['loras'] + + loras_data = kwargs["loras"] # Handle new format: {'loras': {'__value__': [...]}} - if isinstance(loras_data, dict) and '__value__' in loras_data: - return loras_data['__value__'] + if isinstance(loras_data, dict) and "__value__" in loras_data: + return loras_data["__value__"] # Handle old format: {'loras': [...]} elif isinstance(loras_data, list): return loras_data @@ -64,24 +68,26 @@ def get_loras_list(kwargs): logger.warning(f"Unexpected loras format: {type(loras_data)}") return [] + def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""): - """Simplified version of load_state_dict_in_safetensors that just loads from a local path""" + """Simplified version of load_state_dict_in_safetensors that just loads from a local path""" import safetensors.torch - + state_dict = {} - with safetensors.torch.safe_open(path, framework="pt", device=device) as f: + with safetensors.torch.safe_open(path, framework="pt", device=device) as f: # type: ignore[attr-defined] for k in f.keys(): if filter_prefix and not k.startswith(filter_prefix): continue state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k) return state_dict + def to_diffusers(input_lora): """Simplified version of to_diffusers for Flux LoRA conversion""" import torch from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft - from diffusers.loaders import FluxLoraLoaderMixin - + from diffusers.loaders import FluxLoraLoaderMixin # type: ignore[attr-defined] + if isinstance(input_lora, str): tensors = load_state_dict_in_safetensors(input_lora, device="cpu") else: @@ -91,22 +97,27 @@ def to_diffusers(input_lora): for k, v in tensors.items(): if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]: tensors[k] = v.to(torch.bfloat16) - + new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors) new_tensors = convert_unet_state_dict_to_peft(new_tensors) return new_tensors + def nunchaku_load_lora(model, lora_name, lora_strength): - """Load a Flux LoRA for Nunchaku model""" + """Load a Flux LoRA for Nunchaku model""" # Get full path to the LoRA file. Allow both direct paths and registered LoRA names. - lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name) + lora_path = ( + lora_name + if os.path.isfile(lora_name) + else folder_paths.get_full_path("loras", lora_name) + ) if not lora_path or not os.path.isfile(lora_path): logger.warning("Skipping LoRA '%s' because it could not be found", lora_name) return model model_wrapper = model.model.diffusion_model - + # Try to find copy_with_ctx in the same module as ComfyFluxWrapper module_name = model_wrapper.__class__.__module__ module = sys.modules.get(module_name) @@ -118,14 +129,16 @@ def nunchaku_load_lora(model, lora_name, lora_strength): ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)] else: # Fallback to legacy logic - logger.warning("Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic.") + logger.warning( + "Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic." + ) transformer = model_wrapper.model - + # Save the transformer temporarily model_wrapper.model = None ret_model = copy.deepcopy(model) # copy everything except the model ret_model_wrapper = ret_model.model.diffusion_model - + # Restore the model and set it for the copy model_wrapper.model = transformer ret_model_wrapper.model = transformer @@ -133,15 +146,15 @@ def nunchaku_load_lora(model, lora_name, lora_strength): # Convert the LoRA to diffusers format sd = to_diffusers(lora_path) - + # Handle embedding adjustment if needed if "transformer.x_embedder.lora_A.weight" in sd: new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1] assert new_in_channels % 4 == 0 new_in_channels = new_in_channels // 4 - + old_in_channels = ret_model.model.model_config.unet_config["in_channels"] if old_in_channels < new_in_channels: ret_model.model.model_config.unet_config["in_channels"] = new_in_channels - - return ret_model \ No newline at end of file + + return ret_model diff --git a/py/recipes/factory.py b/py/recipes/factory.py index cab1a6be..6dbcee2b 100644 --- a/py/recipes/factory.py +++ b/py/recipes/factory.py @@ -6,23 +6,24 @@ from .parsers import ( ComfyMetadataParser, MetaFormatParser, AutomaticMetadataParser, - CivitaiApiMetadataParser + CivitaiApiMetadataParser, ) from .base import RecipeMetadataParser logger = logging.getLogger(__name__) + class RecipeParserFactory: """Factory for creating recipe metadata parsers""" - + @staticmethod - def create_parser(metadata) -> RecipeMetadataParser: + def create_parser(metadata) -> RecipeMetadataParser | None: """ Create appropriate parser based on the metadata content - + Args: metadata: The metadata from the image (dict or str) - + Returns: Appropriate RecipeMetadataParser implementation """ @@ -34,17 +35,18 @@ class RecipeParserFactory: except Exception as e: logger.debug(f"CivitaiApiMetadataParser check failed: {e}") pass - + # Convert dict to string for other parsers that expect string input try: import json + metadata_str = json.dumps(metadata) except Exception as e: logger.debug(f"Failed to convert dict to JSON string: {e}") return None else: metadata_str = metadata - + # Try ComfyMetadataParser which requires valid JSON try: if ComfyMetadataParser().is_metadata_matching(metadata_str): @@ -52,7 +54,7 @@ class RecipeParserFactory: except Exception: # If JSON parsing fails, move on to other parsers pass - + # Check other parsers that expect string input if RecipeFormatParser().is_metadata_matching(metadata_str): return RecipeFormatParser() diff --git a/py/recipes/parsers/civitai_image.py b/py/recipes/parsers/civitai_image.py index f3b4f8ba..87910c01 100644 --- a/py/recipes/parsers/civitai_image.py +++ b/py/recipes/parsers/civitai_image.py @@ -9,15 +9,16 @@ from ...services.metadata_service import get_default_metadata_provider logger = logging.getLogger(__name__) + class CivitaiApiMetadataParser(RecipeMetadataParser): """Parser for Civitai image metadata format""" - + def is_metadata_matching(self, metadata) -> bool: """Check if the metadata matches the Civitai image metadata format - + Args: metadata: The metadata from the image (dict) - + Returns: bool: True if this parser can handle the metadata """ @@ -28,7 +29,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): # Check for common CivitAI image metadata fields civitai_image_fields = ( "resources", - "civitaiResources", + "civitaiResources", "additionalResources", "hashes", "prompt", @@ -40,7 +41,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): "width", "height", "Model", - "Model hash" + "Model hash", ) return any(key in payload for key in civitai_image_fields) @@ -50,7 +51,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): # Check for LoRA hash patterns hashes = metadata.get("hashes") - if isinstance(hashes, dict) and any(str(key).lower().startswith("lora:") for key in hashes): + if isinstance(hashes, dict) and any( + str(key).lower().startswith("lora:") for key in hashes + ): return True # Check nested meta object (common in CivitAI image responses) @@ -61,22 +64,28 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): # Also check for LoRA hash patterns in nested meta hashes = nested_meta.get("hashes") - if isinstance(hashes, dict) and any(str(key).lower().startswith("lora:") for key in hashes): + if isinstance(hashes, dict) and any( + str(key).lower().startswith("lora:") for key in hashes + ): return True return False - - async def parse_metadata(self, metadata, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]: + + async def parse_metadata( # type: ignore[override] + self, user_comment, recipe_scanner=None, civitai_client=None + ) -> Dict[str, Any]: """Parse metadata from Civitai image format - + Args: - metadata: The metadata from the image (dict) + user_comment: The metadata from the image (dict) recipe_scanner: Optional recipe scanner service civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead) - + Returns: Dict containing parsed recipe data """ + metadata: Dict[str, Any] = user_comment # type: ignore[assignment] + metadata = user_comment try: # Get metadata provider instead of using civitai_client directly metadata_provider = await get_default_metadata_provider() @@ -100,19 +109,19 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): ) ): metadata = inner_meta - + # Initialize result structure result = { - 'base_model': None, - 'loras': [], - 'model': None, - 'gen_params': {}, - 'from_civitai_image': True + "base_model": None, + "loras": [], + "model": None, + "gen_params": {}, + "from_civitai_image": True, } - + # Track already added LoRAs to prevent duplicates added_loras = {} # key: model_version_id or hash, value: index in result["loras"] - + # Extract hash information from hashes field for LoRA matching lora_hashes = {} if "hashes" in metadata and isinstance(metadata["hashes"], dict): @@ -121,14 +130,14 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): if key_str.lower().startswith("lora:"): lora_name = key_str.split(":", 1)[1] lora_hashes[lora_name] = hash_value - + # Extract prompt and negative prompt if "prompt" in metadata: result["gen_params"]["prompt"] = metadata["prompt"] - + if "negativePrompt" in metadata: result["gen_params"]["negative_prompt"] = metadata["negativePrompt"] - + # Extract other generation parameters param_mapping = { "steps": "steps", @@ -138,98 +147,117 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): "Size": "size", "clipSkip": "clip_skip", } - + for civitai_key, our_key in param_mapping.items(): if civitai_key in metadata and our_key in GEN_PARAM_KEYS: result["gen_params"][our_key] = metadata[civitai_key] - + # Extract base model information - directly if available if "baseModel" in metadata: result["base_model"] = metadata["baseModel"] elif "Model hash" in metadata and metadata_provider: model_hash = metadata["Model hash"] - model_info, error = await metadata_provider.get_model_by_hash(model_hash) + model_info, error = await metadata_provider.get_model_by_hash( + model_hash + ) if model_info: result["base_model"] = model_info.get("baseModel", "") elif "Model" in metadata and isinstance(metadata.get("resources"), list): # Try to find base model in resources for resource in metadata.get("resources", []): - if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"): + if resource.get("type") == "model" and resource.get( + "name" + ) == metadata.get("Model"): # This is likely the checkpoint model if metadata_provider and resource.get("hash"): - model_info, error = await metadata_provider.get_model_by_hash(resource.get("hash")) + ( + model_info, + error, + ) = await metadata_provider.get_model_by_hash( + resource.get("hash") + ) if model_info: result["base_model"] = model_info.get("baseModel", "") - + base_model_counts = {} - + # Process standard resources array if "resources" in metadata and isinstance(metadata["resources"], list): for resource in metadata["resources"]: # Modified to process resources without a type field as potential LoRAs if resource.get("type", "lora") == "lora": lora_hash = resource.get("hash", "") - + # Try to get hash from the hashes field if not present in resource if not lora_hash and resource.get("name"): lora_hash = lora_hashes.get(resource["name"], "") - + # Skip LoRAs without proper identification (hash or modelVersionId) if not lora_hash and not resource.get("modelVersionId"): - logger.debug(f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId") + logger.debug( + f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId" + ) continue - + # Skip if we've already added this LoRA by hash if lora_hash and lora_hash in added_loras: continue - + lora_entry = { - 'name': resource.get("name", "Unknown LoRA"), - 'type': "lora", - 'weight': float(resource.get("weight", 1.0)), - 'hash': lora_hash, - 'existsLocally': False, - 'localPath': None, - 'file_name': resource.get("name", "Unknown"), - 'thumbnailUrl': '/loras_static/images/no-preview.png', - 'baseModel': '', - 'size': 0, - 'downloadUrl': '', - 'isDeleted': False + "name": resource.get("name", "Unknown LoRA"), + "type": "lora", + "weight": float(resource.get("weight", 1.0)), + "hash": lora_hash, + "existsLocally": False, + "localPath": None, + "file_name": resource.get("name", "Unknown"), + "thumbnailUrl": "/loras_static/images/no-preview.png", + "baseModel": "", + "size": 0, + "downloadUrl": "", + "isDeleted": False, } - + # Try to get info from Civitai if hash is available - if lora_entry['hash'] and metadata_provider: + if lora_entry["hash"] and metadata_provider: try: - civitai_info = await metadata_provider.get_model_by_hash(lora_hash) - + civitai_info = ( + await metadata_provider.get_model_by_hash(lora_hash) + ) + populated_entry = await self.populate_lora_from_civitai( lora_entry, civitai_info, recipe_scanner, base_model_counts, - lora_hash + lora_hash, ) - + if populated_entry is None: continue # Skip invalid LoRA types - + lora_entry = populated_entry - + # If we have a version ID from Civitai, track it for deduplication - if 'id' in lora_entry and lora_entry['id']: - added_loras[str(lora_entry['id'])] = len(result["loras"]) + if "id" in lora_entry and lora_entry["id"]: + added_loras[str(lora_entry["id"])] = len( + result["loras"] + ) except Exception as e: - logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}") - + logger.error( + f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}" + ) + # Track by hash if we have it if lora_hash: added_loras[lora_hash] = len(result["loras"]) - + result["loras"].append(lora_entry) - + # Process civitaiResources array - if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list): + if "civitaiResources" in metadata and isinstance( + metadata["civitaiResources"], list + ): for resource in metadata["civitaiResources"]: # Get resource type and identifier resource_type = str(resource.get("type") or "").lower() @@ -237,32 +265,39 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): if resource_type == "checkpoint": checkpoint_entry = { - 'id': resource.get("modelVersionId", 0), - 'modelId': resource.get("modelId", 0), - 'name': resource.get("modelName", "Unknown Checkpoint"), - 'version': resource.get("modelVersionName", ""), - 'type': resource.get("type", "checkpoint"), - 'existsLocally': False, - 'localPath': None, - 'file_name': resource.get("modelName", ""), - 'hash': resource.get("hash", "") or "", - 'thumbnailUrl': '/loras_static/images/no-preview.png', - 'baseModel': '', - 'size': 0, - 'downloadUrl': '', - 'isDeleted': False + "id": resource.get("modelVersionId", 0), + "modelId": resource.get("modelId", 0), + "name": resource.get("modelName", "Unknown Checkpoint"), + "version": resource.get("modelVersionName", ""), + "type": resource.get("type", "checkpoint"), + "existsLocally": False, + "localPath": None, + "file_name": resource.get("modelName", ""), + "hash": resource.get("hash", "") or "", + "thumbnailUrl": "/loras_static/images/no-preview.png", + "baseModel": "", + "size": 0, + "downloadUrl": "", + "isDeleted": False, } if version_id and metadata_provider: try: - civitai_info = await metadata_provider.get_model_version_info(version_id) + civitai_info = ( + await metadata_provider.get_model_version_info( + version_id + ) + ) - checkpoint_entry = await self.populate_checkpoint_from_civitai( - checkpoint_entry, - civitai_info + checkpoint_entry = ( + await self.populate_checkpoint_from_civitai( + checkpoint_entry, civitai_info + ) ) except Exception as e: - logger.error(f"Error fetching Civitai info for checkpoint version {version_id}: {e}") + logger.error( + f"Error fetching Civitai info for checkpoint version {version_id}: {e}" + ) if result["model"] is None: result["model"] = checkpoint_entry @@ -275,31 +310,35 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): # Initialize lora entry lora_entry = { - 'id': resource.get("modelVersionId", 0), - 'modelId': resource.get("modelId", 0), - 'name': resource.get("modelName", "Unknown LoRA"), - 'version': resource.get("modelVersionName", ""), - 'type': resource.get("type", "lora"), - 'weight': round(float(resource.get("weight", 1.0)), 2), - 'existsLocally': False, - 'thumbnailUrl': '/loras_static/images/no-preview.png', - 'baseModel': '', - 'size': 0, - 'downloadUrl': '', - 'isDeleted': False + "id": resource.get("modelVersionId", 0), + "modelId": resource.get("modelId", 0), + "name": resource.get("modelName", "Unknown LoRA"), + "version": resource.get("modelVersionName", ""), + "type": resource.get("type", "lora"), + "weight": round(float(resource.get("weight", 1.0)), 2), + "existsLocally": False, + "thumbnailUrl": "/loras_static/images/no-preview.png", + "baseModel": "", + "size": 0, + "downloadUrl": "", + "isDeleted": False, } # Try to get info from Civitai if modelVersionId is available if version_id and metadata_provider: try: # Use get_model_version_info instead of get_model_version - civitai_info = await metadata_provider.get_model_version_info(version_id) + civitai_info = ( + await metadata_provider.get_model_version_info( + version_id + ) + ) populated_entry = await self.populate_lora_from_civitai( lora_entry, civitai_info, recipe_scanner, - base_model_counts + base_model_counts, ) if populated_entry is None: @@ -307,74 +346,87 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): lora_entry = populated_entry except Exception as e: - logger.error(f"Error fetching Civitai info for model version {version_id}: {e}") + logger.error( + f"Error fetching Civitai info for model version {version_id}: {e}" + ) # Track this LoRA in our deduplication dict if version_id: added_loras[version_id] = len(result["loras"]) result["loras"].append(lora_entry) - + # Process additionalResources array - if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list): + if "additionalResources" in metadata and isinstance( + metadata["additionalResources"], list + ): for resource in metadata["additionalResources"]: # Skip resources that aren't LoRAs or LyCORIS - if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource: + if ( + resource.get("type") not in ["lora", "lycoris"] + and "type" not in resource + ): continue - + lora_type = resource.get("type", "lora") name = resource.get("name", "") - + # Extract ID from URN format if available version_id = None if name and "civitai:" in name: parts = name.split("@") if len(parts) > 1: version_id = parts[1] - + # Skip if we've already added this LoRA if version_id in added_loras: continue - + lora_entry = { - 'name': name, - 'type': lora_type, - 'weight': float(resource.get("strength", 1.0)), - 'hash': "", - 'existsLocally': False, - 'localPath': None, - 'file_name': name, - 'thumbnailUrl': '/loras_static/images/no-preview.png', - 'baseModel': '', - 'size': 0, - 'downloadUrl': '', - 'isDeleted': False + "name": name, + "type": lora_type, + "weight": float(resource.get("strength", 1.0)), + "hash": "", + "existsLocally": False, + "localPath": None, + "file_name": name, + "thumbnailUrl": "/loras_static/images/no-preview.png", + "baseModel": "", + "size": 0, + "downloadUrl": "", + "isDeleted": False, } - + # If we have a version ID and metadata provider, try to get more info if version_id and metadata_provider: try: # Use get_model_version_info with the version ID - civitai_info = await metadata_provider.get_model_version_info(version_id) - + civitai_info = ( + await metadata_provider.get_model_version_info( + version_id + ) + ) + populated_entry = await self.populate_lora_from_civitai( lora_entry, civitai_info, recipe_scanner, - base_model_counts + base_model_counts, ) - + if populated_entry is None: continue # Skip invalid LoRA types - + lora_entry = populated_entry - + # Track this LoRA for deduplication if version_id: added_loras[version_id] = len(result["loras"]) except Exception as e: - logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}") - + logger.error( + f"Error fetching Civitai info for model ID {version_id}: {e}" + ) + result["loras"].append(lora_entry) # If we found LoRA hashes in the metadata but haven't already @@ -390,30 +442,32 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): continue lora_entry = { - 'name': lora_name, - 'type': "lora", - 'weight': 1.0, - 'hash': lora_hash, - 'existsLocally': False, - 'localPath': None, - 'file_name': lora_name, - 'thumbnailUrl': '/loras_static/images/no-preview.png', - 'baseModel': '', - 'size': 0, - 'downloadUrl': '', - 'isDeleted': False + "name": lora_name, + "type": "lora", + "weight": 1.0, + "hash": lora_hash, + "existsLocally": False, + "localPath": None, + "file_name": lora_name, + "thumbnailUrl": "/loras_static/images/no-preview.png", + "baseModel": "", + "size": 0, + "downloadUrl": "", + "isDeleted": False, } if metadata_provider: try: - civitai_info = await metadata_provider.get_model_by_hash(lora_hash) + civitai_info = await metadata_provider.get_model_by_hash( + lora_hash + ) populated_entry = await self.populate_lora_from_civitai( lora_entry, civitai_info, recipe_scanner, base_model_counts, - lora_hash + lora_hash, ) if populated_entry is None: @@ -421,80 +475,93 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): lora_entry = populated_entry - if 'id' in lora_entry and lora_entry['id']: - added_loras[str(lora_entry['id'])] = len(result["loras"]) + if "id" in lora_entry and lora_entry["id"]: + added_loras[str(lora_entry["id"])] = len(result["loras"]) except Exception as e: - logger.error(f"Error fetching Civitai info for LoRA hash {lora_hash}: {e}") + logger.error( + f"Error fetching Civitai info for LoRA hash {lora_hash}: {e}" + ) added_loras[lora_hash] = len(result["loras"]) result["loras"].append(lora_entry) # Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc. lora_index = 0 - while f"Lora_{lora_index} Model hash" in metadata and f"Lora_{lora_index} Model name" in metadata: + while ( + f"Lora_{lora_index} Model hash" in metadata + and f"Lora_{lora_index} Model name" in metadata + ): lora_hash = metadata[f"Lora_{lora_index} Model hash"] lora_name = metadata[f"Lora_{lora_index} Model name"] - lora_strength_model = float(metadata.get(f"Lora_{lora_index} Strength model", 1.0)) - + lora_strength_model = float( + metadata.get(f"Lora_{lora_index} Strength model", 1.0) + ) + # Skip if we've already added this LoRA by hash if lora_hash and lora_hash in added_loras: lora_index += 1 continue - + lora_entry = { - 'name': lora_name, - 'type': "lora", - 'weight': lora_strength_model, - 'hash': lora_hash, - 'existsLocally': False, - 'localPath': None, - 'file_name': lora_name, - 'thumbnailUrl': '/loras_static/images/no-preview.png', - 'baseModel': '', - 'size': 0, - 'downloadUrl': '', - 'isDeleted': False + "name": lora_name, + "type": "lora", + "weight": lora_strength_model, + "hash": lora_hash, + "existsLocally": False, + "localPath": None, + "file_name": lora_name, + "thumbnailUrl": "/loras_static/images/no-preview.png", + "baseModel": "", + "size": 0, + "downloadUrl": "", + "isDeleted": False, } - + # Try to get info from Civitai if hash is available - if lora_entry['hash'] and metadata_provider: + if lora_entry["hash"] and metadata_provider: try: - civitai_info = await metadata_provider.get_model_by_hash(lora_hash) - + civitai_info = await metadata_provider.get_model_by_hash( + lora_hash + ) + populated_entry = await self.populate_lora_from_civitai( lora_entry, civitai_info, recipe_scanner, base_model_counts, - lora_hash + lora_hash, ) - + if populated_entry is None: lora_index += 1 continue # Skip invalid LoRA types - + lora_entry = populated_entry - + # If we have a version ID from Civitai, track it for deduplication - if 'id' in lora_entry and lora_entry['id']: - added_loras[str(lora_entry['id'])] = len(result["loras"]) + if "id" in lora_entry and lora_entry["id"]: + added_loras[str(lora_entry["id"])] = len(result["loras"]) except Exception as e: - logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}") - + logger.error( + f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}" + ) + # Track by hash if we have it if lora_hash: added_loras[lora_hash] = len(result["loras"]) - + result["loras"].append(lora_entry) - + lora_index += 1 - + # If base model wasn't found earlier, use the most common one from LoRAs if not result["base_model"] and base_model_counts: - result["base_model"] = max(base_model_counts.items(), key=lambda x: x[1])[0] - + result["base_model"] = max( + base_model_counts.items(), key=lambda x: x[1] + )[0] + return result - + except Exception as e: logger.error(f"Error parsing Civitai image metadata: {e}", exc_info=True) return {"error": str(e), "loras": []} diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 503988b7..31a51aed 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -3,36 +3,42 @@ import copy import logging import os from typing import Any, Optional, Dict, Tuple, List, Sequence -from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager +from .model_metadata_provider import ( + CivitaiModelMetadataProvider, + ModelMetadataProviderManager, +) from .downloader import get_downloader from .errors import RateLimitError, ResourceNotFoundError from ..utils.civitai_utils import resolve_license_payload logger = logging.getLogger(__name__) + class CivitaiClient: _instance = None _lock = asyncio.Lock() - + @classmethod async def get_instance(cls): """Get singleton instance of CivitaiClient""" async with cls._lock: if cls._instance is None: cls._instance = cls() - + # Register this client as a metadata provider provider_manager = await ModelMetadataProviderManager.get_instance() - provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True) - + provider_manager.register_provider( + "civitai", CivitaiModelMetadataProvider(cls._instance), True + ) + return cls._instance def __init__(self): # Check if already initialized for singleton pattern - if hasattr(self, '_initialized'): + if hasattr(self, "_initialized"): return self._initialized = True - + self.base_url = "https://civitai.com/api/v1" async def _make_request( @@ -75,8 +81,10 @@ class CivitaiClient: meta = image.get("meta") if isinstance(meta, dict) and "comfy" in meta: meta.pop("comfy", None) - - async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: + + async def download_file( + self, url: str, save_dir: str, default_filename: str, progress_callback=None + ) -> Tuple[bool, str]: """Download file with resumable downloads and retry mechanism Args: @@ -90,41 +98,48 @@ class CivitaiClient: """ downloader = await get_downloader() save_path = os.path.join(save_dir, default_filename) - + # Use unified downloader with CivitAI authentication success, result = await downloader.download_file( url=url, save_path=save_path, progress_callback=progress_callback, use_auth=True, # Enable CivitAI authentication - allow_resume=True + allow_resume=True, ) - + return success, result - async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]: + async def get_model_by_hash( + self, model_hash: str + ) -> Tuple[Optional[Dict], Optional[str]]: try: success, version = await self._make_request( - 'GET', + "GET", f"{self.base_url}/model-versions/by-hash/{model_hash}", - use_auth=True + use_auth=True, ) if not success: message = str(version) if "not found" in message.lower(): return None, "Model not found" - logger.error("Failed to fetch model info for %s: %s", model_hash[:10], message) + logger.error( + "Failed to fetch model info for %s: %s", model_hash[:10], message + ) return None, message - model_id = version.get('modelId') - if model_id: - model_data = await self._fetch_model_data(model_id) - if model_data: - self._enrich_version_with_model_data(version, model_data) + if isinstance(version, dict): + model_id = version.get("modelId") + if model_id: + model_data = await self._fetch_model_data(model_id) + if model_data: + self._enrich_version_with_model_data(version, model_data) - self._remove_comfy_metadata(version) - return version, None + self._remove_comfy_metadata(version) + return version, None + else: + return None, "Invalid response format" except RateLimitError: raise except Exception as exc: @@ -136,19 +151,19 @@ class CivitaiClient: downloader = await get_downloader() success, content, headers = await downloader.download_to_memory( image_url, - use_auth=False # Preview images don't need auth + use_auth=False, # Preview images don't need auth ) if success: # Ensure directory exists os.makedirs(os.path.dirname(save_path), exist_ok=True) - with open(save_path, 'wb') as f: + with open(save_path, "wb") as f: f.write(content) return True return False except Exception as e: logger.error(f"Download Error: {str(e)}") return False - + @staticmethod def _extract_error_message(payload: Any) -> str: """Return a human-readable error message from an API payload.""" @@ -175,19 +190,17 @@ class CivitaiClient: """Get all versions of a model with local availability info""" try: success, result = await self._make_request( - 'GET', - f"{self.base_url}/models/{model_id}", - use_auth=True + "GET", f"{self.base_url}/models/{model_id}", use_auth=True ) if success: # Also return model type along with versions return { - 'modelVersions': result.get('modelVersions', []), - 'type': result.get('type', ''), - 'name': result.get('name', '') + "modelVersions": result.get("modelVersions", []), + "type": result.get("type", ""), + "name": result.get("name", ""), } message = self._extract_error_message(result) - if message and 'not found' in message.lower(): + if message and "not found" in message.lower(): raise ResourceNotFoundError(f"Resource not found for model {model_id}") if message: raise RuntimeError(message) @@ -221,15 +234,15 @@ class CivitaiClient: try: query = ",".join(normalized_ids) success, result = await self._make_request( - 'GET', + "GET", f"{self.base_url}/models", use_auth=True, - params={'ids': query}, + params={"ids": query}, ) if not success: return None - items = result.get('items') if isinstance(result, dict) else None + items = result.get("items") if isinstance(result, dict) else None if not isinstance(items, list): return {} @@ -237,19 +250,19 @@ class CivitaiClient: for item in items: if not isinstance(item, dict): continue - model_id = item.get('id') + model_id = item.get("id") try: normalized_id = int(model_id) except (TypeError, ValueError): continue payload[normalized_id] = { - 'modelVersions': item.get('modelVersions', []), - 'type': item.get('type', ''), - 'name': item.get('name', ''), - 'allowNoCredit': item.get('allowNoCredit'), - 'allowCommercialUse': item.get('allowCommercialUse'), - 'allowDerivatives': item.get('allowDerivatives'), - 'allowDifferentLicense': item.get('allowDifferentLicense'), + "modelVersions": item.get("modelVersions", []), + "type": item.get("type", ""), + "name": item.get("name", ""), + "allowNoCredit": item.get("allowNoCredit"), + "allowCommercialUse": item.get("allowCommercialUse"), + "allowDerivatives": item.get("allowDerivatives"), + "allowDifferentLicense": item.get("allowDifferentLicense"), } return payload except RateLimitError: @@ -257,8 +270,10 @@ class CivitaiClient: except Exception as exc: logger.error(f"Error fetching model versions in bulk: {exc}") return None - - async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: + + async def get_model_version( + self, model_id: int = None, version_id: int = None + ) -> Optional[Dict]: """Get specific model version with additional metadata.""" try: if model_id is None and version_id is not None: @@ -281,7 +296,7 @@ class CivitaiClient: if version is None: return None - model_id = version.get('modelId') + model_id = version.get("modelId") if not model_id: logger.error(f"No modelId found in version {version_id}") return None @@ -293,7 +308,9 @@ class CivitaiClient: self._remove_comfy_metadata(version) return version - async def _get_version_with_model_id(self, model_id: int, version_id: Optional[int]) -> Optional[Dict]: + async def _get_version_with_model_id( + self, model_id: int, version_id: Optional[int] + ) -> Optional[Dict]: model_data = await self._fetch_model_data(model_id) if not model_data: return None @@ -302,8 +319,12 @@ class CivitaiClient: if target_version is None: return None - target_version_id = target_version.get('id') - version = await self._fetch_version_by_id(target_version_id) if target_version_id else None + target_version_id = target_version.get("id") + version = ( + await self._fetch_version_by_id(target_version_id) + if target_version_id + else None + ) if version is None: model_hash = self._extract_primary_model_hash(target_version) @@ -315,7 +336,9 @@ class CivitaiClient: ) if version is None: - version = self._build_version_from_model_data(target_version, model_id, model_data) + version = self._build_version_from_model_data( + target_version, model_id, model_data + ) self._enrich_version_with_model_data(version, model_data) self._remove_comfy_metadata(version) @@ -323,9 +346,7 @@ class CivitaiClient: async def _fetch_model_data(self, model_id: int) -> Optional[Dict]: success, data = await self._make_request( - 'GET', - f"{self.base_url}/models/{model_id}", - use_auth=True + "GET", f"{self.base_url}/models/{model_id}", use_auth=True ) if success: return data @@ -337,9 +358,7 @@ class CivitaiClient: return None success, version = await self._make_request( - 'GET', - f"{self.base_url}/model-versions/{version_id}", - use_auth=True + "GET", f"{self.base_url}/model-versions/{version_id}", use_auth=True ) if success: return version @@ -352,9 +371,7 @@ class CivitaiClient: return None success, version = await self._make_request( - 'GET', - f"{self.base_url}/model-versions/by-hash/{model_hash}", - use_auth=True + "GET", f"{self.base_url}/model-versions/by-hash/{model_hash}", use_auth=True ) if success: return version @@ -362,16 +379,17 @@ class CivitaiClient: logger.warning(f"Failed to fetch version by hash {model_hash}") return None - def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]: - model_versions = model_data.get('modelVersions', []) + def _select_target_version( + self, model_data: Dict, model_id: int, version_id: Optional[int] + ) -> Optional[Dict]: + model_versions = model_data.get("modelVersions", []) if not model_versions: logger.warning(f"No model versions found for model {model_id}") return None if version_id is not None: target_version = next( - (item for item in model_versions if item.get('id') == version_id), - None + (item for item in model_versions if item.get("id") == version_id), None ) if target_version is None: logger.warning( @@ -383,46 +401,50 @@ class CivitaiClient: return model_versions[0] def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]: - for file_info in version_entry.get('files', []): - if file_info.get('type') == 'Model' and file_info.get('primary'): - hashes = file_info.get('hashes', {}) - model_hash = hashes.get('SHA256') + for file_info in version_entry.get("files", []): + if file_info.get("type") == "Model" and file_info.get("primary"): + hashes = file_info.get("hashes", {}) + model_hash = hashes.get("SHA256") if model_hash: return model_hash return None - def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict: + def _build_version_from_model_data( + self, version_entry: Dict, model_id: int, model_data: Dict + ) -> Dict: version = copy.deepcopy(version_entry) - version.pop('index', None) - version['modelId'] = model_id - version['model'] = { - 'name': model_data.get('name'), - 'type': model_data.get('type'), - 'nsfw': model_data.get('nsfw'), - 'poi': model_data.get('poi') + version.pop("index", None) + version["modelId"] = model_id + version["model"] = { + "name": model_data.get("name"), + "type": model_data.get("type"), + "nsfw": model_data.get("nsfw"), + "poi": model_data.get("poi"), } return version def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None: - model_info = version.get('model') + model_info = version.get("model") if not isinstance(model_info, dict): model_info = {} - version['model'] = model_info + version["model"] = model_info - model_info['description'] = model_data.get("description") - model_info['tags'] = model_data.get("tags", []) - version['creator'] = model_data.get("creator") + model_info["description"] = model_data.get("description") + model_info["tags"] = model_data.get("tags", []) + version["creator"] = model_data.get("creator") license_payload = resolve_license_payload(model_data) for field, value in license_payload.items(): model_info[field] = value - async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: + async def get_model_version_info( + self, version_id: str + ) -> Tuple[Optional[Dict], Optional[str]]: """Fetch model version metadata from Civitai - + Args: version_id: The Civitai model version ID - + Returns: Tuple[Optional[Dict], Optional[str]]: A tuple containing: - The model version data or None if not found @@ -430,25 +452,23 @@ class CivitaiClient: """ try: url = f"{self.base_url}/model-versions/{version_id}" - + logger.debug(f"Resolving DNS for model version info: {url}") - success, result = await self._make_request( - 'GET', - url, - use_auth=True - ) - + success, result = await self._make_request("GET", url, use_auth=True) + if success: - logger.debug(f"Successfully fetched model version info for: {version_id}") + logger.debug( + f"Successfully fetched model version info for: {version_id}" + ) self._remove_comfy_metadata(result) return result, None - + # Handle specific error cases if "not found" in str(result): error_msg = f"Model not found" logger.warning(f"Model version not found: {version_id} - {error_msg}") return None, error_msg - + # Other error cases logger.error(f"Failed to fetch model info for {version_id}: {result}") return None, str(result) @@ -464,27 +484,23 @@ class CivitaiClient: Args: image_id: The Civitai image ID - + Returns: Optional[Dict]: The image data or None if not found """ try: url = f"{self.base_url}/images?imageId={image_id}&nsfw=X" - + logger.debug(f"Fetching image info for ID: {image_id}") - success, result = await self._make_request( - 'GET', - url, - use_auth=True - ) - + success, result = await self._make_request("GET", url, use_auth=True) + if success: if result and "items" in result and len(result["items"]) > 0: logger.debug(f"Successfully fetched image info for ID: {image_id}") return result["items"][0] logger.warning(f"No image found with ID: {image_id}") return None - + logger.error(f"Failed to fetch image info for ID: {image_id}: {result}") return None except RateLimitError: @@ -501,11 +517,7 @@ class CivitaiClient: try: url = f"{self.base_url}/models?username={username}" - success, result = await self._make_request( - 'GET', - url, - use_auth=True - ) + success, result = await self._make_request("GET", url, use_auth=True) if not success: logger.error("Failed to fetch models for %s: %s", username, result)