Fix null-safety issues and apply code formatting

Bug fixes:
- Add null guards for base_models_roots/embeddings_roots in backup cleanup
- Fix null-safety initialization of extra_unet_roots

Formatting:
- Apply consistent code style across Python files
- Fix line wrapping, quote consistency, and trailing commas
- Add type ignore comments for dynamic/platform-specific code
This commit is contained in:
Will Miao
2026-02-28 21:38:41 +08:00
parent b005961ee5
commit c9e5ea42cb
9 changed files with 1097 additions and 760 deletions

View File

@@ -2,7 +2,7 @@ import os
import platform import platform
import threading import threading
from pathlib import Path from pathlib import Path
import folder_paths # type: ignore import folder_paths # type: ignore
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple
import logging import logging
import json import json
@@ -10,16 +10,23 @@ import urllib.parse
import time import time
from .utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths from .utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths
from .utils.settings_paths import ensure_settings_file, get_settings_dir, load_settings_template from .utils.settings_paths import (
ensure_settings_file,
get_settings_dir,
load_settings_template,
)
# Use an environment variable to control standalone mode # Use an environment variable to control standalone mode
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" standalone_mode = (
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _normalize_folder_paths_for_comparison( def _normalize_folder_paths_for_comparison(
folder_paths: Mapping[str, Iterable[str]] folder_paths: Mapping[str, Iterable[str]],
) -> Dict[str, Set[str]]: ) -> Dict[str, Set[str]]:
"""Normalize folder paths for comparison across libraries.""" """Normalize folder paths for comparison across libraries."""
@@ -49,7 +56,7 @@ def _normalize_folder_paths_for_comparison(
def _normalize_library_folder_paths( def _normalize_library_folder_paths(
library_payload: Mapping[str, Any] library_payload: Mapping[str, Any],
) -> Dict[str, Set[str]]: ) -> Dict[str, Set[str]]:
"""Return normalized folder paths extracted from a library payload.""" """Return normalized folder paths extracted from a library payload."""
@@ -76,9 +83,15 @@ class Config:
"""Global configuration for LoRA Manager""" """Global configuration for LoRA Manager"""
def __init__(self): def __init__(self):
self.templates_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'templates') self.templates_path = os.path.join(
self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static') os.path.dirname(os.path.dirname(__file__)), "templates"
self.i18n_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'locales') )
self.static_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "static"
)
self.i18n_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "locales"
)
# Path mapping dictionary, target to link mapping # Path mapping dictionary, target to link mapping
self._path_mappings: Dict[str, str] = {} self._path_mappings: Dict[str, str] = {}
# Normalized preview root directories used to validate preview access # Normalized preview root directories used to validate preview access
@@ -152,17 +165,21 @@ class Config:
default_library = libraries.get("default", {}) default_library = libraries.get("default", {})
target_folder_paths = { target_folder_paths = {
'loras': list(self.loras_roots), "loras": list(self.loras_roots),
'checkpoints': list(self.checkpoints_roots or []), "checkpoints": list(self.checkpoints_roots or []),
'unet': list(self.unet_roots or []), "unet": list(self.unet_roots or []),
'embeddings': list(self.embeddings_roots or []), "embeddings": list(self.embeddings_roots or []),
} }
normalized_target_paths = _normalize_folder_paths_for_comparison(target_folder_paths) normalized_target_paths = _normalize_folder_paths_for_comparison(
target_folder_paths
)
normalized_default_paths: Optional[Dict[str, Set[str]]] = None normalized_default_paths: Optional[Dict[str, Set[str]]] = None
if isinstance(default_library, Mapping): if isinstance(default_library, Mapping):
normalized_default_paths = _normalize_library_folder_paths(default_library) normalized_default_paths = _normalize_library_folder_paths(
default_library
)
if ( if (
not comfy_library not comfy_library
@@ -185,13 +202,19 @@ class Config:
default_lora_root = self.loras_roots[0] default_lora_root = self.loras_roots[0]
default_checkpoint_root = comfy_library.get("default_checkpoint_root", "") default_checkpoint_root = comfy_library.get("default_checkpoint_root", "")
if (not default_checkpoint_root and self.checkpoints_roots and if (
len(self.checkpoints_roots) == 1): not default_checkpoint_root
and self.checkpoints_roots
and len(self.checkpoints_roots) == 1
):
default_checkpoint_root = self.checkpoints_roots[0] default_checkpoint_root = self.checkpoints_roots[0]
default_embedding_root = comfy_library.get("default_embedding_root", "") default_embedding_root = comfy_library.get("default_embedding_root", "")
if (not default_embedding_root and self.embeddings_roots and if (
len(self.embeddings_roots) == 1): not default_embedding_root
and self.embeddings_roots
and len(self.embeddings_roots) == 1
):
default_embedding_root = self.embeddings_roots[0] default_embedding_root = self.embeddings_roots[0]
metadata = dict(comfy_library.get("metadata", {})) metadata = dict(comfy_library.get("metadata", {}))
@@ -216,11 +239,12 @@ class Config:
try: try:
if os.path.islink(path): if os.path.islink(path):
return True return True
if platform.system() == 'Windows': if platform.system() == "Windows":
try: try:
import ctypes import ctypes
FILE_ATTRIBUTE_REPARSE_POINT = 0x400 FILE_ATTRIBUTE_REPARSE_POINT = 0x400
attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path)) attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path)) # type: ignore[attr-defined]
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT) return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
except Exception as e: except Exception as e:
logger.error(f"Error checking Windows reparse point: {e}") logger.error(f"Error checking Windows reparse point: {e}")
@@ -233,18 +257,19 @@ class Config:
"""Check if a directory entry is a symlink, including Windows junctions.""" """Check if a directory entry is a symlink, including Windows junctions."""
if entry.is_symlink(): if entry.is_symlink():
return True return True
if platform.system() == 'Windows': if platform.system() == "Windows":
try: try:
import ctypes import ctypes
FILE_ATTRIBUTE_REPARSE_POINT = 0x400 FILE_ATTRIBUTE_REPARSE_POINT = 0x400
attrs = ctypes.windll.kernel32.GetFileAttributesW(entry.path) attrs = ctypes.windll.kernel32.GetFileAttributesW(entry.path) # type: ignore[attr-defined]
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT) return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
except Exception: except Exception:
pass pass
return False return False
def _normalize_path(self, path: str) -> str: def _normalize_path(self, path: str) -> str:
return os.path.normpath(path).replace(os.sep, '/') return os.path.normpath(path).replace(os.sep, "/")
def _get_symlink_cache_path(self) -> Path: def _get_symlink_cache_path(self) -> Path:
canonical_path = get_cache_file_path(CacheType.SYMLINK, create_dir=True) canonical_path = get_cache_file_path(CacheType.SYMLINK, create_dir=True)
@@ -278,19 +303,18 @@ class Config:
if self._entry_is_symlink(entry): if self._entry_is_symlink(entry):
try: try:
target = os.path.realpath(entry.path) target = os.path.realpath(entry.path)
direct_symlinks.append([ direct_symlinks.append(
self._normalize_path(entry.path), [
self._normalize_path(target) self._normalize_path(entry.path),
]) self._normalize_path(target),
]
)
except OSError: except OSError:
pass pass
except (OSError, PermissionError): except (OSError, PermissionError):
pass pass
return { return {"roots": unique_roots, "direct_symlinks": sorted(direct_symlinks)}
"roots": unique_roots,
"direct_symlinks": sorted(direct_symlinks)
}
def _initialize_symlink_mappings(self) -> None: def _initialize_symlink_mappings(self) -> None:
start = time.perf_counter() start = time.perf_counter()
@@ -307,10 +331,14 @@ class Config:
cached_fingerprint = self._cached_fingerprint cached_fingerprint = self._cached_fingerprint
# Check 1: First-level symlinks unchanged (catches new symlinks at root) # Check 1: First-level symlinks unchanged (catches new symlinks at root)
fingerprint_valid = cached_fingerprint and current_fingerprint == cached_fingerprint fingerprint_valid = (
cached_fingerprint and current_fingerprint == cached_fingerprint
)
# Check 2: All cached mappings still valid (catches changes at any depth) # Check 2: All cached mappings still valid (catches changes at any depth)
mappings_valid = self._validate_cached_mappings() if fingerprint_valid else False mappings_valid = (
self._validate_cached_mappings() if fingerprint_valid else False
)
if fingerprint_valid and mappings_valid: if fingerprint_valid and mappings_valid:
return return
@@ -370,7 +398,9 @@ class Config:
for target, link in cached_mappings.items(): for target, link in cached_mappings.items():
if not isinstance(target, str) or not isinstance(link, str): if not isinstance(target, str) or not isinstance(link, str):
continue continue
normalized_mappings[self._normalize_path(target)] = self._normalize_path(link) normalized_mappings[self._normalize_path(target)] = self._normalize_path(
link
)
self._path_mappings = normalized_mappings self._path_mappings = normalized_mappings
@@ -391,7 +421,9 @@ class Config:
parent_dir = loaded_path.parent parent_dir = loaded_path.parent
if parent_dir.name == "cache" and not any(parent_dir.iterdir()): if parent_dir.name == "cache" and not any(parent_dir.iterdir()):
parent_dir.rmdir() parent_dir.rmdir()
logger.info("Removed empty legacy cache directory: %s", parent_dir) logger.info(
"Removed empty legacy cache directory: %s", parent_dir
)
except Exception: except Exception:
pass pass
@@ -402,7 +434,9 @@ class Config:
exc, exc,
) )
else: else:
logger.info("Symlink cache loaded with %d mappings", len(self._path_mappings)) logger.info(
"Symlink cache loaded with %d mappings", len(self._path_mappings)
)
return True return True
@@ -414,7 +448,7 @@ class Config:
""" """
for target, link in self._path_mappings.items(): for target, link in self._path_mappings.items():
# Convert normalized paths back to OS paths # Convert normalized paths back to OS paths
link_path = link.replace('/', os.sep) link_path = link.replace("/", os.sep)
# Check if symlink still exists # Check if symlink still exists
if not self._is_link(link_path): if not self._is_link(link_path):
@@ -427,7 +461,9 @@ class Config:
if actual_target != target: if actual_target != target:
logger.debug( logger.debug(
"Symlink target changed: %s -> %s (cached: %s)", "Symlink target changed: %s -> %s (cached: %s)",
link_path, actual_target, target link_path,
actual_target,
target,
) )
return False return False
except OSError: except OSError:
@@ -446,7 +482,11 @@ class Config:
try: try:
with cache_path.open("w", encoding="utf-8") as handle: with cache_path.open("w", encoding="utf-8") as handle:
json.dump(payload, handle, ensure_ascii=False, indent=2) json.dump(payload, handle, ensure_ascii=False, indent=2)
logger.debug("Symlink cache saved to %s with %d mappings", cache_path, len(self._path_mappings)) logger.debug(
"Symlink cache saved to %s with %d mappings",
cache_path,
len(self._path_mappings),
)
except Exception as exc: except Exception as exc:
logger.info("Failed to write symlink cache %s: %s", cache_path, exc) logger.info("Failed to write symlink cache %s: %s", cache_path, exc)
@@ -494,13 +534,13 @@ class Config:
self.add_path_mapping(entry.path, target_path) self.add_path_mapping(entry.path, target_path)
except Exception as inner_exc: except Exception as inner_exc:
logger.debug( logger.debug(
"Error processing directory entry %s: %s", entry.path, inner_exc "Error processing directory entry %s: %s",
entry.path,
inner_exc,
) )
except Exception as e: except Exception as e:
logger.error(f"Error scanning links in {root}: {e}") logger.error(f"Error scanning links in {root}: {e}")
def add_path_mapping(self, link_path: str, target_path: str): def add_path_mapping(self, link_path: str, target_path: str):
"""Add a symbolic link path mapping """Add a symbolic link path mapping
target_path: actual target path target_path: actual target path
@@ -594,26 +634,31 @@ class Config:
preview_roots.update(self._expand_preview_root(target)) preview_roots.update(self._expand_preview_root(target))
preview_roots.update(self._expand_preview_root(link)) preview_roots.update(self._expand_preview_root(link))
self._preview_root_paths = {path for path in preview_roots if path.is_absolute()} self._preview_root_paths = {
path for path in preview_roots if path.is_absolute()
}
logger.debug( logger.debug(
"Preview roots rebuilt: %d paths from %d lora roots (%d extra), %d checkpoint roots (%d extra), %d embedding roots (%d extra), %d symlink mappings", "Preview roots rebuilt: %d paths from %d lora roots (%d extra), %d checkpoint roots (%d extra), %d embedding roots (%d extra), %d symlink mappings",
len(self._preview_root_paths), len(self._preview_root_paths),
len(self.loras_roots or []), len(self.extra_loras_roots or []), len(self.loras_roots or []),
len(self.base_models_roots or []), len(self.extra_checkpoints_roots or []), len(self.extra_loras_roots or []),
len(self.embeddings_roots or []), len(self.extra_embeddings_roots or []), len(self.base_models_roots or []),
len(self.extra_checkpoints_roots or []),
len(self.embeddings_roots or []),
len(self.extra_embeddings_roots or []),
len(self._path_mappings), len(self._path_mappings),
) )
def map_path_to_link(self, path: str) -> str: def map_path_to_link(self, path: str) -> str:
"""Map a target path back to its symbolic link path""" """Map a target path back to its symbolic link path"""
normalized_path = os.path.normpath(path).replace(os.sep, '/') normalized_path = os.path.normpath(path).replace(os.sep, "/")
# Check if the path is contained in any mapped target path # Check if the path is contained in any mapped target path
for target_path, link_path in self._path_mappings.items(): for target_path, link_path in self._path_mappings.items():
# Match whole path components to avoid prefix collisions (e.g., /a/b vs /a/bc) # Match whole path components to avoid prefix collisions (e.g., /a/b vs /a/bc)
if normalized_path == target_path: if normalized_path == target_path:
return link_path return link_path
if normalized_path.startswith(target_path + '/'): if normalized_path.startswith(target_path + "/"):
# If the path starts with the target path, replace with link path # If the path starts with the target path, replace with link path
mapped_path = normalized_path.replace(target_path, link_path, 1) mapped_path = normalized_path.replace(target_path, link_path, 1)
return mapped_path return mapped_path
@@ -621,14 +666,14 @@ class Config:
def map_link_to_path(self, link_path: str) -> str: def map_link_to_path(self, link_path: str) -> str:
"""Map a symbolic link path back to the actual path""" """Map a symbolic link path back to the actual path"""
normalized_link = os.path.normpath(link_path).replace(os.sep, '/') normalized_link = os.path.normpath(link_path).replace(os.sep, "/")
# Check if the path is contained in any mapped target path # Check if the path is contained in any mapped target path
for target_path, link_path_mapped in self._path_mappings.items(): for target_path, link_path_mapped in self._path_mappings.items():
# Match whole path components # Match whole path components
if normalized_link == link_path_mapped: if normalized_link == link_path_mapped:
return target_path return target_path
if normalized_link.startswith(link_path_mapped + '/'): if normalized_link.startswith(link_path_mapped + "/"):
# If the path starts with the link path, replace with actual path # If the path starts with the link path, replace with actual path
mapped_path = normalized_link.replace(link_path_mapped, target_path, 1) mapped_path = normalized_link.replace(link_path_mapped, target_path, 1)
return mapped_path return mapped_path
@@ -641,8 +686,8 @@ class Config:
continue continue
if not os.path.exists(path): if not os.path.exists(path):
continue continue
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, "/")
normalized = os.path.normpath(path).replace(os.sep, '/') normalized = os.path.normpath(path).replace(os.sep, "/")
if real_path not in dedup: if real_path not in dedup:
dedup[real_path] = normalized dedup[real_path] = normalized
return dedup return dedup
@@ -652,7 +697,9 @@ class Config:
unique_paths = sorted(path_map.values(), key=lambda p: p.lower()) unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
for original_path in unique_paths: for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') real_path = os.path.normpath(os.path.realpath(original_path)).replace(
os.sep, "/"
)
if real_path != original_path: if real_path != original_path:
self.add_path_mapping(original_path, real_path) self.add_path_mapping(original_path, real_path)
@@ -674,7 +721,7 @@ class Config:
"Please fix your ComfyUI path configuration to separate these folders. " "Please fix your ComfyUI path configuration to separate these folders. "
"Falling back to 'checkpoints' for backward compatibility. " "Falling back to 'checkpoints' for backward compatibility. "
"Overlapping real paths: %s", "Overlapping real paths: %s",
[checkpoint_map.get(rp, rp) for rp in overlapping_real_paths] [checkpoint_map.get(rp, rp) for rp in overlapping_real_paths],
) )
# Remove overlapping paths from unet_map to prioritize checkpoints # Remove overlapping paths from unet_map to prioritize checkpoints
for rp in overlapping_real_paths: for rp in overlapping_real_paths:
@@ -694,7 +741,9 @@ class Config:
self.unet_roots = [p for p in unique_paths if p in unet_values] self.unet_roots = [p for p in unique_paths if p in unet_values]
for original_path in unique_paths: for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') real_path = os.path.normpath(os.path.realpath(original_path)).replace(
os.sep, "/"
)
if real_path != original_path: if real_path != original_path:
self.add_path_mapping(original_path, real_path) self.add_path_mapping(original_path, real_path)
@@ -705,7 +754,9 @@ class Config:
unique_paths = sorted(path_map.values(), key=lambda p: p.lower()) unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
for original_path in unique_paths: for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') real_path = os.path.normpath(os.path.realpath(original_path)).replace(
os.sep, "/"
)
if real_path != original_path: if real_path != original_path:
self.add_path_mapping(original_path, real_path) self.add_path_mapping(original_path, real_path)
@@ -719,42 +770,66 @@ class Config:
self._path_mappings.clear() self._path_mappings.clear()
self._preview_root_paths = set() self._preview_root_paths = set()
lora_paths = folder_paths.get('loras', []) or [] lora_paths = folder_paths.get("loras", []) or []
checkpoint_paths = folder_paths.get('checkpoints', []) or [] checkpoint_paths = folder_paths.get("checkpoints", []) or []
unet_paths = folder_paths.get('unet', []) or [] unet_paths = folder_paths.get("unet", []) or []
embedding_paths = folder_paths.get('embeddings', []) or [] embedding_paths = folder_paths.get("embeddings", []) or []
self.loras_roots = self._prepare_lora_paths(lora_paths) self.loras_roots = self._prepare_lora_paths(lora_paths)
self.base_models_roots = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths) self.base_models_roots = self._prepare_checkpoint_paths(
checkpoint_paths, unet_paths
)
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths) self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
# Process extra paths (only for LoRA Manager, not shared with ComfyUI) # Process extra paths (only for LoRA Manager, not shared with ComfyUI)
extra_paths = extra_folder_paths or {} extra_paths = extra_folder_paths or {}
extra_lora_paths = extra_paths.get('loras', []) or [] extra_lora_paths = extra_paths.get("loras", []) or []
extra_checkpoint_paths = extra_paths.get('checkpoints', []) or [] extra_checkpoint_paths = extra_paths.get("checkpoints", []) or []
extra_unet_paths = extra_paths.get('unet', []) or [] extra_unet_paths = extra_paths.get("unet", []) or []
extra_embedding_paths = extra_paths.get('embeddings', []) or [] extra_embedding_paths = extra_paths.get("embeddings", []) or []
self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths) self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths)
# Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them) # Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them)
saved_checkpoints_roots = self.checkpoints_roots saved_checkpoints_roots = self.checkpoints_roots
saved_unet_roots = self.unet_roots saved_unet_roots = self.unet_roots
self.extra_checkpoints_roots = self._prepare_checkpoint_paths(extra_checkpoint_paths, extra_unet_paths) self.extra_checkpoints_roots = self._prepare_checkpoint_paths(
self.extra_unet_roots = self.unet_roots # unet_roots was set by _prepare_checkpoint_paths extra_checkpoint_paths, extra_unet_paths
)
self.extra_unet_roots = (
self.unet_roots if self.unet_roots is not None else []
) # unet_roots was set by _prepare_checkpoint_paths
# Restore main paths # Restore main paths
self.checkpoints_roots = saved_checkpoints_roots self.checkpoints_roots = saved_checkpoints_roots
self.unet_roots = saved_unet_roots self.unet_roots = saved_unet_roots
self.extra_embeddings_roots = self._prepare_embedding_paths(extra_embedding_paths) self.extra_embeddings_roots = self._prepare_embedding_paths(
extra_embedding_paths
)
# Log extra folder paths # Log extra folder paths
if self.extra_loras_roots: if self.extra_loras_roots:
logger.info("Found extra LoRA roots:" + "\n - " + "\n - ".join(self.extra_loras_roots)) logger.info(
"Found extra LoRA roots:"
+ "\n - "
+ "\n - ".join(self.extra_loras_roots)
)
if self.extra_checkpoints_roots: if self.extra_checkpoints_roots:
logger.info("Found extra checkpoint roots:" + "\n - " + "\n - ".join(self.extra_checkpoints_roots)) logger.info(
"Found extra checkpoint roots:"
+ "\n - "
+ "\n - ".join(self.extra_checkpoints_roots)
)
if self.extra_unet_roots: if self.extra_unet_roots:
logger.info("Found extra diffusion model roots:" + "\n - " + "\n - ".join(self.extra_unet_roots)) logger.info(
"Found extra diffusion model roots:"
+ "\n - "
+ "\n - ".join(self.extra_unet_roots)
)
if self.extra_embeddings_roots: if self.extra_embeddings_roots:
logger.info("Found extra embedding roots:" + "\n - " + "\n - ".join(self.extra_embeddings_roots)) logger.info(
"Found extra embedding roots:"
+ "\n - "
+ "\n - ".join(self.extra_embeddings_roots)
)
self._initialize_symlink_mappings() self._initialize_symlink_mappings()
@@ -763,7 +838,10 @@ class Config:
try: try:
raw_paths = folder_paths.get_folder_paths("loras") raw_paths = folder_paths.get_folder_paths("loras")
unique_paths = self._prepare_lora_paths(raw_paths) unique_paths = self._prepare_lora_paths(raw_paths)
logger.info("Found LoRA roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")) logger.info(
"Found LoRA roots:"
+ ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")
)
if not unique_paths: if not unique_paths:
logger.warning("No valid loras folders found in ComfyUI configuration") logger.warning("No valid loras folders found in ComfyUI configuration")
@@ -779,12 +857,19 @@ class Config:
try: try:
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints") raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
raw_unet_paths = folder_paths.get_folder_paths("unet") raw_unet_paths = folder_paths.get_folder_paths("unet")
unique_paths = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths) unique_paths = self._prepare_checkpoint_paths(
raw_checkpoint_paths, raw_unet_paths
)
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")) logger.info(
"Found checkpoint roots:"
+ ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")
)
if not unique_paths: if not unique_paths:
logger.warning("No valid checkpoint folders found in ComfyUI configuration") logger.warning(
"No valid checkpoint folders found in ComfyUI configuration"
)
return [] return []
return unique_paths return unique_paths
@@ -797,10 +882,15 @@ class Config:
try: try:
raw_paths = folder_paths.get_folder_paths("embeddings") raw_paths = folder_paths.get_folder_paths("embeddings")
unique_paths = self._prepare_embedding_paths(raw_paths) unique_paths = self._prepare_embedding_paths(raw_paths)
logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")) logger.info(
"Found embedding roots:"
+ ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")
)
if not unique_paths: if not unique_paths:
logger.warning("No valid embeddings folders found in ComfyUI configuration") logger.warning(
"No valid embeddings folders found in ComfyUI configuration"
)
return [] return []
return unique_paths return unique_paths
@@ -812,9 +902,9 @@ class Config:
if not preview_path: if not preview_path:
return "" return ""
normalized = os.path.normpath(preview_path).replace(os.sep, '/') normalized = os.path.normpath(preview_path).replace(os.sep, "/")
encoded_path = urllib.parse.quote(normalized, safe='') encoded_path = urllib.parse.quote(normalized, safe="")
return f'/api/lm/previews?path={encoded_path}' return f"/api/lm/previews?path={encoded_path}"
def is_preview_path_allowed(self, preview_path: str) -> bool: def is_preview_path_allowed(self, preview_path: str) -> bool:
"""Return ``True`` if ``preview_path`` is within an allowed directory. """Return ``True`` if ``preview_path`` is within an allowed directory.
@@ -889,14 +979,18 @@ class Config:
normalized_link = self._normalize_path(str(current)) normalized_link = self._normalize_path(str(current))
self._path_mappings[normalized_target] = normalized_link self._path_mappings[normalized_target] = normalized_link
self._preview_root_paths.update(self._expand_preview_root(normalized_target)) self._preview_root_paths.update(
self._preview_root_paths.update(self._expand_preview_root(normalized_link)) self._expand_preview_root(normalized_target)
)
self._preview_root_paths.update(
self._expand_preview_root(normalized_link)
)
logger.debug( logger.debug(
"Discovered deep symlink: %s -> %s (preview path: %s)", "Discovered deep symlink: %s -> %s (preview path: %s)",
normalized_link, normalized_link,
normalized_target, normalized_target,
preview_path preview_path,
) )
return True return True
@@ -914,8 +1008,16 @@ class Config:
def apply_library_settings(self, library_config: Mapping[str, object]) -> None: def apply_library_settings(self, library_config: Mapping[str, object]) -> None:
"""Update runtime paths to match the provided library configuration.""" """Update runtime paths to match the provided library configuration."""
folder_paths = library_config.get('folder_paths') if isinstance(library_config, Mapping) else {} folder_paths = (
extra_folder_paths = library_config.get('extra_folder_paths') if isinstance(library_config, Mapping) else None library_config.get("folder_paths")
if isinstance(library_config, Mapping)
else {}
)
extra_folder_paths = (
library_config.get("extra_folder_paths")
if isinstance(library_config, Mapping)
else None
)
if not isinstance(folder_paths, Mapping): if not isinstance(folder_paths, Mapping):
folder_paths = {} folder_paths = {}
if not isinstance(extra_folder_paths, Mapping): if not isinstance(extra_folder_paths, Mapping):
@@ -925,9 +1027,12 @@ class Config:
logger.info( logger.info(
"Applied library settings with %d lora roots (%d extra), %d checkpoint roots (%d extra), and %d embedding roots (%d extra)", "Applied library settings with %d lora roots (%d extra), %d checkpoint roots (%d extra), and %d embedding roots (%d extra)",
len(self.loras_roots or []), len(self.extra_loras_roots or []), len(self.loras_roots or []),
len(self.base_models_roots or []), len(self.extra_checkpoints_roots or []), len(self.extra_loras_roots or []),
len(self.embeddings_roots or []), len(self.extra_embeddings_roots or []), len(self.base_models_roots or []),
len(self.extra_checkpoints_roots or []),
len(self.embeddings_roots or []),
len(self.extra_embeddings_roots or []),
) )
def get_library_registry_snapshot(self) -> Dict[str, object]: def get_library_registry_snapshot(self) -> Dict[str, object]:
@@ -947,5 +1052,6 @@ class Config:
logger.debug("Failed to collect library registry snapshot: %s", exc) logger.debug("Failed to collect library registry snapshot: %s", exc)
return {"active_library": "", "libraries": {}} return {"active_library": "", "libraries": {}}
# Global config instance # Global config instance
config = Config() config = Config()

View File

@@ -5,16 +5,22 @@ import logging
from .utils.logging_config import setup_logging from .utils.logging_config import setup_logging
# Check if we're in standalone mode # Check if we're in standalone mode
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" standalone_mode = (
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
)
# Only setup logging prefix if not in standalone mode # Only setup logging prefix if not in standalone mode
if not standalone_mode: if not standalone_mode:
setup_logging() setup_logging()
from server import PromptServer # type: ignore from server import PromptServer # type: ignore
from .config import config from .config import config
from .services.model_service_factory import ModelServiceFactory, register_default_model_types from .services.model_service_factory import (
ModelServiceFactory,
register_default_model_types,
)
from .routes.recipe_routes import RecipeRoutes from .routes.recipe_routes import RecipeRoutes
from .routes.stats_routes import StatsRoutes from .routes.stats_routes import StatsRoutes
from .routes.update_routes import UpdateRoutes from .routes.update_routes import UpdateRoutes
@@ -61,6 +67,7 @@ class _SettingsProxy:
settings = _SettingsProxy() settings = _SettingsProxy()
class LoraManager: class LoraManager:
"""Main entry point for LoRA Manager plugin""" """Main entry point for LoRA Manager plugin"""
@@ -76,7 +83,8 @@ class LoraManager:
( (
idx idx
for idx, middleware in enumerate(app.middlewares) for idx, middleware in enumerate(app.middlewares)
if getattr(middleware, "__name__", "") == "block_external_middleware" if getattr(middleware, "__name__", "")
== "block_external_middleware"
), ),
None, None,
) )
@@ -84,7 +92,9 @@ class LoraManager:
if block_middleware_index is None: if block_middleware_index is None:
app.middlewares.append(relax_csp_for_remote_media) app.middlewares.append(relax_csp_for_remote_media)
else: else:
app.middlewares.insert(block_middleware_index, relax_csp_for_remote_media) app.middlewares.insert(
block_middleware_index, relax_csp_for_remote_media
)
# Increase allowed header sizes so browsers with large localhost cookie # Increase allowed header sizes so browsers with large localhost cookie
# jars (multiple UIs on 127.0.0.1) don't trip aiohttp's 8KB default # jars (multiple UIs on 127.0.0.1) don't trip aiohttp's 8KB default
@@ -105,7 +115,7 @@ class LoraManager:
app._handler_args = updated_handler_args app._handler_args = updated_handler_args
# Configure aiohttp access logger to be less verbose # Configure aiohttp access logger to be less verbose
logging.getLogger('aiohttp.access').setLevel(logging.WARNING) logging.getLogger("aiohttp.access").setLevel(logging.WARNING)
# Add specific suppression for connection reset errors # Add specific suppression for connection reset errors
class ConnectionResetFilter(logging.Filter): class ConnectionResetFilter(logging.Filter):
@@ -124,19 +134,23 @@ class LoraManager:
asyncio_logger.addFilter(ConnectionResetFilter()) asyncio_logger.addFilter(ConnectionResetFilter())
# Add static route for example images if the path exists in settings # Add static route for example images if the path exists in settings
example_images_path = settings.get('example_images_path') example_images_path = settings.get("example_images_path")
logger.info(f"Example images path: {example_images_path}") logger.info(f"Example images path: {example_images_path}")
if example_images_path and os.path.exists(example_images_path): if example_images_path and os.path.exists(example_images_path):
app.router.add_static('/example_images_static', example_images_path) app.router.add_static("/example_images_static", example_images_path)
logger.info(f"Added static route for example images: /example_images_static -> {example_images_path}") logger.info(
f"Added static route for example images: /example_images_static -> {example_images_path}"
)
# Add static route for locales JSON files # Add static route for locales JSON files
if os.path.exists(config.i18n_path): if os.path.exists(config.i18n_path):
app.router.add_static('/locales', config.i18n_path) app.router.add_static("/locales", config.i18n_path)
logger.info(f"Added static route for locales: /locales -> {config.i18n_path}") logger.info(
f"Added static route for locales: /locales -> {config.i18n_path}"
)
# Add static route for plugin assets # Add static route for plugin assets
app.router.add_static('/loras_static', config.static_path) app.router.add_static("/loras_static", config.static_path)
# Register default model types with the factory # Register default model types with the factory
register_default_model_types() register_default_model_types()
@@ -154,9 +168,11 @@ class LoraManager:
PreviewRoutes.setup_routes(app) PreviewRoutes.setup_routes(app)
# Setup WebSocket routes that are shared across all model types # Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) app.router.add_get("/ws/fetch-progress", ws_manager.handle_connection)
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) app.router.add_get(
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) "/ws/download-progress", ws_manager.handle_download_connection
)
app.router.add_get("/ws/init-progress", ws_manager.handle_init_connection)
# Schedule service initialization # Schedule service initialization
app.on_startup.append(lambda app: cls._initialize_services()) app.on_startup.append(lambda app: cls._initialize_services())
@@ -197,7 +213,9 @@ class LoraManager:
extra_paths.get("embeddings", []), extra_paths.get("embeddings", []),
) )
except Exception as exc: except Exception as exc:
logger.warning("Failed to apply library settings during initialization: %s", exc) logger.warning(
"Failed to apply library settings during initialization: %s", exc
)
# Initialize CivitaiClient first to ensure it's ready for other services # Initialize CivitaiClient first to ensure it's ready for other services
await ServiceRegistry.get_civitai_client() await ServiceRegistry.get_civitai_client()
@@ -206,6 +224,7 @@ class LoraManager:
await ServiceRegistry.get_download_manager() await ServiceRegistry.get_download_manager()
from .services.metadata_service import initialize_metadata_providers from .services.metadata_service import initialize_metadata_providers
await initialize_metadata_providers() await initialize_metadata_providers()
# Initialize WebSocket manager # Initialize WebSocket manager
@@ -221,39 +240,58 @@ class LoraManager:
# Create low-priority initialization tasks # Create low-priority initialization tasks
init_tasks = [ init_tasks = [
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init'), asyncio.create_task(
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init'), lora_scanner.initialize_in_background(), name="lora_cache_init"
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init'), ),
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init') asyncio.create_task(
checkpoint_scanner.initialize_in_background(),
name="checkpoint_cache_init",
),
asyncio.create_task(
embedding_scanner.initialize_in_background(),
name="embedding_cache_init",
),
asyncio.create_task(
recipe_scanner.initialize_in_background(), name="recipe_cache_init"
),
] ]
await ExampleImagesMigration.check_and_run_migrations() await ExampleImagesMigration.check_and_run_migrations()
# Schedule post-initialization tasks to run after scanners complete # Schedule post-initialization tasks to run after scanners complete
asyncio.create_task( asyncio.create_task(
cls._run_post_initialization_tasks(init_tasks), cls._run_post_initialization_tasks(init_tasks), name="post_init_tasks"
name='post_init_tasks'
) )
logger.debug("LoRA Manager: All services initialized and background tasks scheduled") logger.debug(
"LoRA Manager: All services initialized and background tasks scheduled"
)
except Exception as e: except Exception as e:
logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True) logger.error(
f"LoRA Manager: Error initializing services: {e}", exc_info=True
)
@classmethod @classmethod
async def _run_post_initialization_tasks(cls, init_tasks): async def _run_post_initialization_tasks(cls, init_tasks):
"""Run post-initialization tasks after all scanners complete""" """Run post-initialization tasks after all scanners complete"""
try: try:
logger.debug("LoRA Manager: Waiting for scanner initialization to complete...") logger.debug(
"LoRA Manager: Waiting for scanner initialization to complete..."
)
# Wait for all scanner initialization tasks to complete # Wait for all scanner initialization tasks to complete
await asyncio.gather(*init_tasks, return_exceptions=True) await asyncio.gather(*init_tasks, return_exceptions=True)
logger.debug("LoRA Manager: Scanner initialization completed, starting post-initialization tasks...") logger.debug(
"LoRA Manager: Scanner initialization completed, starting post-initialization tasks..."
)
# Run post-initialization tasks # Run post-initialization tasks
post_tasks = [ post_tasks = [
asyncio.create_task(cls._cleanup_backup_files(), name='cleanup_bak_files'), asyncio.create_task(
cls._cleanup_backup_files(), name="cleanup_bak_files"
),
# Add more post-initialization tasks here as needed # Add more post-initialization tasks here as needed
# asyncio.create_task(cls._another_post_task(), name='another_task'), # asyncio.create_task(cls._another_post_task(), name='another_task'),
] ]
@@ -265,14 +303,20 @@ class LoraManager:
for i, result in enumerate(results): for i, result in enumerate(results):
task_name = post_tasks[i].get_name() task_name = post_tasks[i].get_name()
if isinstance(result, Exception): if isinstance(result, Exception):
logger.error(f"Post-initialization task '{task_name}' failed: {result}") logger.error(
f"Post-initialization task '{task_name}' failed: {result}"
)
else: else:
logger.debug(f"Post-initialization task '{task_name}' completed successfully") logger.debug(
f"Post-initialization task '{task_name}' completed successfully"
)
logger.debug("LoRA Manager: All post-initialization tasks completed") logger.debug("LoRA Manager: All post-initialization tasks completed")
except Exception as e: except Exception as e:
logger.error(f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True) logger.error(
f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True
)
@classmethod @classmethod
async def _cleanup_backup_files(cls): async def _cleanup_backup_files(cls):
@@ -283,8 +327,8 @@ class LoraManager:
# Collect all model roots # Collect all model roots
all_roots = set() all_roots = set()
all_roots.update(config.loras_roots) all_roots.update(config.loras_roots)
all_roots.update(config.base_models_roots) all_roots.update(config.base_models_roots or [])
all_roots.update(config.embeddings_roots) all_roots.update(config.embeddings_roots or [])
total_deleted = 0 total_deleted = 0
total_size_freed = 0 total_size_freed = 0
@@ -294,12 +338,17 @@ class LoraManager:
continue continue
try: try:
deleted_count, size_freed = await cls._cleanup_backup_files_in_directory(root_path) (
deleted_count,
size_freed,
) = await cls._cleanup_backup_files_in_directory(root_path)
total_deleted += deleted_count total_deleted += deleted_count
total_size_freed += size_freed total_size_freed += size_freed
if deleted_count > 0: if deleted_count > 0:
logger.debug(f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024*1024):.2f} MB)") logger.debug(
f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024 * 1024):.2f} MB)"
)
except Exception as e: except Exception as e:
logger.error(f"Error cleaning up .bak files in {root_path}: {e}") logger.error(f"Error cleaning up .bak files in {root_path}: {e}")
@@ -308,7 +357,9 @@ class LoraManager:
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
if total_deleted > 0: if total_deleted > 0:
logger.debug(f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024*1024):.2f} MB total") logger.debug(
f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024 * 1024):.2f} MB total"
)
else: else:
logger.debug("Backup cleanup completed: no .bak files found") logger.debug("Backup cleanup completed: no .bak files found")
@@ -341,7 +392,9 @@ class LoraManager:
with os.scandir(path) as it: with os.scandir(path) as it:
for entry in it: for entry in it:
try: try:
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.bak'): if entry.is_file(
follow_symlinks=True
) and entry.name.endswith(".bak"):
file_size = entry.stat().st_size file_size = entry.stat().st_size
os.remove(entry.path) os.remove(entry.path)
deleted_count += 1 deleted_count += 1
@@ -352,7 +405,9 @@ class LoraManager:
cleanup_recursive(entry.path) cleanup_recursive(entry.path)
except Exception as e: except Exception as e:
logger.warning(f"Could not delete .bak file {entry.path}: {e}") logger.warning(
f"Could not delete .bak file {entry.path}: {e}"
)
except Exception as e: except Exception as e:
logger.error(f"Error scanning directory {path} for .bak files: {e}") logger.error(f"Error scanning directory {path} for .bak files: {e}")
@@ -370,21 +425,21 @@ class LoraManager:
service = ExampleImagesCleanupService() service = ExampleImagesCleanupService()
result = await service.cleanup_example_image_folders() result = await service.cleanup_example_image_folders()
if result.get('success'): if result.get("success"):
logger.debug( logger.debug(
"Manual example images cleanup completed: moved=%s", "Manual example images cleanup completed: moved=%s",
result.get('moved_total'), result.get("moved_total"),
) )
elif result.get('partial_success'): elif result.get("partial_success"):
logger.warning( logger.warning(
"Manual example images cleanup partially succeeded: moved=%s failures=%s", "Manual example images cleanup partially succeeded: moved=%s failures=%s",
result.get('moved_total'), result.get("moved_total"),
result.get('move_failures'), result.get("move_failures"),
) )
else: else:
logger.debug( logger.debug(
"Manual example images cleanup skipped or failed: %s", "Manual example images cleanup skipped or failed: %s",
result.get('error', 'no changes'), result.get("error", "no changes"),
) )
return result return result
@@ -392,9 +447,9 @@ class LoraManager:
except Exception as e: # pragma: no cover - defensive guard except Exception as e: # pragma: no cover - defensive guard
logger.error(f"Error during example images cleanup: {e}", exc_info=True) logger.error(f"Error during example images cleanup: {e}", exc_info=True)
return { return {
'success': False, "success": False,
'error': str(e), "error": str(e),
'error_code': 'unexpected_error', "error_code": "unexpected_error",
} }
@classmethod @classmethod

View File

@@ -4,7 +4,10 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Check if running in standalone mode # Check if running in standalone mode
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" standalone_mode = (
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
)
if not standalone_mode: if not standalone_mode:
from .metadata_hook import MetadataHook from .metadata_hook import MetadataHook
@@ -19,7 +22,7 @@ if not standalone_mode:
logger.info("ComfyUI Metadata Collector initialized") logger.info("ComfyUI Metadata Collector initialized")
def get_metadata(prompt_id=None): def get_metadata(prompt_id=None): # type: ignore[no-redef]
"""Helper function to get metadata from the registry""" """Helper function to get metadata from the registry"""
registry = MetadataRegistry() registry = MetadataRegistry()
return registry.get_metadata(prompt_id) return registry.get_metadata(prompt_id)
@@ -28,6 +31,6 @@ else:
def init(): def init():
logger.info("ComfyUI Metadata Collector disabled in standalone mode") logger.info("ComfyUI Metadata Collector disabled in standalone mode")
def get_metadata(prompt_id=None): def get_metadata(prompt_id=None): # type: ignore[no-redef]
"""Dummy implementation for standalone mode""" """Dummy implementation for standalone mode"""
return {} return {}

View File

@@ -1,10 +1,12 @@
import time import time
from nodes import NODE_CLASS_MAPPINGS from nodes import NODE_CLASS_MAPPINGS # type: ignore
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
from .constants import METADATA_CATEGORIES, IMAGES from .constants import METADATA_CATEGORIES, IMAGES
class MetadataRegistry: class MetadataRegistry:
"""A singleton registry to store and retrieve workflow metadata""" """A singleton registry to store and retrieve workflow metadata"""
_instance = None _instance = None
def __new__(cls): def __new__(cls):
@@ -37,11 +39,13 @@ class MetadataRegistry:
# Sort all prompt_ids by timestamp # Sort all prompt_ids by timestamp
sorted_prompts = sorted( sorted_prompts = sorted(
self.prompt_metadata.keys(), self.prompt_metadata.keys(),
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0) key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0),
) )
# Remove oldest records # Remove oldest records
prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history] prompts_to_remove = sorted_prompts[
: len(sorted_prompts) - self.max_prompt_history
]
for pid in prompts_to_remove: for pid in prompts_to_remove:
del self.prompt_metadata[pid] del self.prompt_metadata[pid]
@@ -53,11 +57,13 @@ class MetadataRegistry:
category: {} for category in METADATA_CATEGORIES category: {} for category in METADATA_CATEGORIES
} }
# Add additional metadata fields # Add additional metadata fields
self.prompt_metadata[prompt_id].update({ self.prompt_metadata[prompt_id].update(
"execution_order": [], {
"current_prompt": None, # Will store the prompt object "execution_order": [],
"timestamp": time.time() "current_prompt": None, # Will store the prompt object
}) "timestamp": time.time(),
}
)
# Clean up old prompt data # Clean up old prompt data
self._clean_old_prompts() self._clean_old_prompts()
@@ -125,7 +131,9 @@ class MetadataRegistry:
for category in self.metadata_categories: for category in self.metadata_categories:
if category in cached_data and node_id in cached_data[category]: if category in cached_data and node_id in cached_data[category]:
if node_id not in metadata[category]: if node_id not in metadata[category]:
metadata[category][node_id] = cached_data[category][node_id] metadata[category][node_id] = cached_data[category][
node_id
]
def record_node_execution(self, node_id, class_type, inputs, outputs): def record_node_execution(self, node_id, class_type, inputs, outputs):
"""Record information about a node's execution""" """Record information about a node's execution"""
@@ -135,7 +143,9 @@ class MetadataRegistry:
# Add to execution order and mark as executed # Add to execution order and mark as executed
if node_id not in self.executed_nodes: if node_id not in self.executed_nodes:
self.executed_nodes.add(node_id) self.executed_nodes.add(node_id)
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id) self.prompt_metadata[self.current_prompt_id]["execution_order"].append(
node_id
)
# Process inputs to simplify working with them # Process inputs to simplify working with them
processed_inputs = {} processed_inputs = {}
@@ -152,7 +162,7 @@ class MetadataRegistry:
node_id, node_id,
processed_inputs, processed_inputs,
outputs, outputs,
self.prompt_metadata[self.current_prompt_id] self.prompt_metadata[self.current_prompt_id],
) )
# Cache this node's metadata # Cache this node's metadata
@@ -168,11 +178,9 @@ class MetadataRegistry:
# Use the same extractor to update with outputs # Use the same extractor to update with outputs
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor) extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
if hasattr(extractor, 'update'): if hasattr(extractor, "update"):
extractor.update( extractor.update(
node_id, node_id, processed_outputs, self.prompt_metadata[self.current_prompt_id]
processed_outputs,
self.prompt_metadata[self.current_prompt_id]
) )
# Update the cached metadata for this node # Update the cached metadata for this node
@@ -214,7 +222,7 @@ class MetadataRegistry:
# Find cache keys that are no longer needed # Find cache keys that are no longer needed
keys_to_remove = [] keys_to_remove = []
for cache_key in self.node_cache: for cache_key in self.node_cache:
node_id = cache_key.split(':')[0] node_id = cache_key.split(":")[0]
if node_id not in active_node_ids: if node_id not in active_node_ids:
keys_to_remove.append(cache_key) keys_to_remove.append(cache_key)
@@ -270,7 +278,10 @@ class MetadataRegistry:
if IMAGES in cached_data and node_id in cached_data[IMAGES]: if IMAGES in cached_data and node_id in cached_data[IMAGES]:
image_data = cached_data[IMAGES][node_id]["image"] image_data = cached_data[IMAGES][node_id]["image"]
# Handle different image formats # Handle different image formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0: if (
isinstance(image_data, (list, tuple))
and len(image_data) > 0
):
return image_data[0] return image_data[0]
return image_data return image_data

View File

@@ -1,8 +1,9 @@
import json import json
import os import os
import re import re
from typing import Any, Dict, Optional
import numpy as np import numpy as np
import folder_paths # type: ignore import folder_paths # type: ignore
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..metadata_collector.metadata_processor import MetadataProcessor from ..metadata_collector.metadata_processor import MetadataProcessor
from ..metadata_collector import get_metadata from ..metadata_collector import get_metadata
@@ -12,6 +13,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SaveImageLM: class SaveImageLM:
NAME = "Save Image (LoraManager)" NAME = "Save Image (LoraManager)"
CATEGORY = "Lora Manager/utils" CATEGORY = "Lora Manager/utils"
@@ -32,33 +34,51 @@ class SaveImageLM:
return { return {
"required": { "required": {
"images": ("IMAGE",), "images": ("IMAGE",),
"filename_prefix": ("STRING", { "filename_prefix": (
"default": "ComfyUI", "STRING",
"tooltip": "Base filename for saved images. Supports format patterns like %seed%, %width%, %height%, %model%, etc." {
}), "default": "ComfyUI",
"file_format": (["png", "jpeg", "webp"], { "tooltip": "Base filename for saved images. Supports format patterns like %seed%, %width%, %height%, %model%, etc.",
"tooltip": "Image format to save as. PNG preserves quality, JPEG is smaller, WebP balances size and quality." },
}), ),
"file_format": (
["png", "jpeg", "webp"],
{
"tooltip": "Image format to save as. PNG preserves quality, JPEG is smaller, WebP balances size and quality."
},
),
}, },
"optional": { "optional": {
"lossless_webp": ("BOOLEAN", { "lossless_webp": (
"default": False, "BOOLEAN",
"tooltip": "When enabled, saves WebP images with lossless compression. Results in larger files but no quality loss." {
}), "default": False,
"quality": ("INT", { "tooltip": "When enabled, saves WebP images with lossless compression. Results in larger files but no quality loss.",
"default": 100, },
"min": 1, ),
"max": 100, "quality": (
"tooltip": "Compression quality for JPEG and lossy WebP formats (1-100). Higher values mean better quality but larger files." "INT",
}), {
"embed_workflow": ("BOOLEAN", { "default": 100,
"default": False, "min": 1,
"tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats." "max": 100,
}), "tooltip": "Compression quality for JPEG and lossy WebP formats (1-100). Higher values mean better quality but larger files.",
"add_counter_to_filename": ("BOOLEAN", { },
"default": True, ),
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images." "embed_workflow": (
}), "BOOLEAN",
{
"default": False,
"tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats.",
},
),
"add_counter_to_filename": (
"BOOLEAN",
{
"default": True,
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images.",
},
),
}, },
"hidden": { "hidden": {
"id": "UNIQUE_ID", "id": "UNIQUE_ID",
@@ -77,9 +97,10 @@ class SaveImageLM:
scanner = ServiceRegistry.get_service_sync("lora_scanner") scanner = ServiceRegistry.get_service_sync("lora_scanner")
# Use the new direct filename lookup method # Use the new direct filename lookup method
hash_value = scanner.get_hash_by_filename(lora_name) if scanner is not None:
if hash_value: hash_value = scanner.get_hash_by_filename(lora_name)
return hash_value if hash_value:
return hash_value
return None return None
@@ -95,9 +116,10 @@ class SaveImageLM:
checkpoint_name = os.path.splitext(checkpoint_name)[0] checkpoint_name = os.path.splitext(checkpoint_name)[0]
# Try direct filename lookup first # Try direct filename lookup first
hash_value = scanner.get_hash_by_filename(checkpoint_name) if scanner is not None:
if hash_value: hash_value = scanner.get_hash_by_filename(checkpoint_name)
return hash_value if hash_value:
return hash_value
return None return None
@@ -112,11 +134,11 @@ class SaveImageLM:
param_list.append(f"{label}: {value}") param_list.append(f"{label}: {value}")
# Extract the prompt and negative prompt # Extract the prompt and negative prompt
prompt = metadata_dict.get('prompt', '') prompt = metadata_dict.get("prompt", "")
negative_prompt = metadata_dict.get('negative_prompt', '') negative_prompt = metadata_dict.get("negative_prompt", "")
# Extract loras from the prompt if present # Extract loras from the prompt if present
loras_text = metadata_dict.get('loras', '') loras_text = metadata_dict.get("loras", "")
lora_hashes = {} lora_hashes = {}
# If loras are found, add them on a new line after the prompt # If loras are found, add them on a new line after the prompt
@@ -124,7 +146,7 @@ class SaveImageLM:
prompt_with_loras = f"{prompt}\n{loras_text}" prompt_with_loras = f"{prompt}\n{loras_text}"
# Extract lora names from the format <lora:name:strength> # Extract lora names from the format <lora:name:strength>
lora_matches = re.findall(r'<lora:([^:]+):([^>]+)>', loras_text) lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", loras_text)
# Get hash for each lora # Get hash for each lora
for lora_name, strength in lora_matches: for lora_name, strength in lora_matches:
@@ -145,43 +167,43 @@ class SaveImageLM:
params = [] params = []
# Add standard parameters in the correct order # Add standard parameters in the correct order
if 'steps' in metadata_dict: if "steps" in metadata_dict:
add_param_if_not_none(params, "Steps", metadata_dict.get('steps')) add_param_if_not_none(params, "Steps", metadata_dict.get("steps"))
# Combine sampler and scheduler information # Combine sampler and scheduler information
sampler_name = None sampler_name = None
scheduler_name = None scheduler_name = None
if 'sampler' in metadata_dict: if "sampler" in metadata_dict:
sampler = metadata_dict.get('sampler') sampler = metadata_dict.get("sampler")
# Convert ComfyUI sampler names to user-friendly names # Convert ComfyUI sampler names to user-friendly names
sampler_mapping = { sampler_mapping = {
'euler': 'Euler', "euler": "Euler",
'euler_ancestral': 'Euler a', "euler_ancestral": "Euler a",
'dpm_2': 'DPM2', "dpm_2": "DPM2",
'dpm_2_ancestral': 'DPM2 a', "dpm_2_ancestral": "DPM2 a",
'heun': 'Heun', "heun": "Heun",
'dpm_fast': 'DPM fast', "dpm_fast": "DPM fast",
'dpm_adaptive': 'DPM adaptive', "dpm_adaptive": "DPM adaptive",
'lms': 'LMS', "lms": "LMS",
'dpmpp_2s_ancestral': 'DPM++ 2S a', "dpmpp_2s_ancestral": "DPM++ 2S a",
'dpmpp_sde': 'DPM++ SDE', "dpmpp_sde": "DPM++ SDE",
'dpmpp_sde_gpu': 'DPM++ SDE', "dpmpp_sde_gpu": "DPM++ SDE",
'dpmpp_2m': 'DPM++ 2M', "dpmpp_2m": "DPM++ 2M",
'dpmpp_2m_sde': 'DPM++ 2M SDE', "dpmpp_2m_sde": "DPM++ 2M SDE",
'dpmpp_2m_sde_gpu': 'DPM++ 2M SDE', "dpmpp_2m_sde_gpu": "DPM++ 2M SDE",
'ddim': 'DDIM' "ddim": "DDIM",
} }
sampler_name = sampler_mapping.get(sampler, sampler) sampler_name = sampler_mapping.get(sampler, sampler)
if 'scheduler' in metadata_dict: if "scheduler" in metadata_dict:
scheduler = metadata_dict.get('scheduler') scheduler = metadata_dict.get("scheduler")
scheduler_mapping = { scheduler_mapping = {
'normal': 'Simple', "normal": "Simple",
'karras': 'Karras', "karras": "Karras",
'exponential': 'Exponential', "exponential": "Exponential",
'sgm_uniform': 'SGM Uniform', "sgm_uniform": "SGM Uniform",
'sgm_quadratic': 'SGM Quadratic' "sgm_quadratic": "SGM Quadratic",
} }
scheduler_name = scheduler_mapping.get(scheduler, scheduler) scheduler_name = scheduler_mapping.get(scheduler, scheduler)
@@ -193,25 +215,25 @@ class SaveImageLM:
params.append(f"Sampler: {sampler_name}") params.append(f"Sampler: {sampler_name}")
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg) # CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
if 'guidance' in metadata_dict: if "guidance" in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('guidance')) add_param_if_not_none(params, "CFG scale", metadata_dict.get("guidance"))
elif 'cfg_scale' in metadata_dict: elif "cfg_scale" in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg_scale')) add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg_scale"))
elif 'cfg' in metadata_dict: elif "cfg" in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg')) add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg"))
# Seed # Seed
if 'seed' in metadata_dict: if "seed" in metadata_dict:
add_param_if_not_none(params, "Seed", metadata_dict.get('seed')) add_param_if_not_none(params, "Seed", metadata_dict.get("seed"))
# Size # Size
if 'size' in metadata_dict: if "size" in metadata_dict:
add_param_if_not_none(params, "Size", metadata_dict.get('size')) add_param_if_not_none(params, "Size", metadata_dict.get("size"))
# Model info # Model info
if 'checkpoint' in metadata_dict: if "checkpoint" in metadata_dict:
# Ensure checkpoint is a string before processing # Ensure checkpoint is a string before processing
checkpoint = metadata_dict.get('checkpoint') checkpoint = metadata_dict.get("checkpoint")
if checkpoint is not None: if checkpoint is not None:
# Get model hash # Get model hash
model_hash = self.get_checkpoint_hash(checkpoint) model_hash = self.get_checkpoint_hash(checkpoint)
@@ -223,7 +245,9 @@ class SaveImageLM:
# Add model hash if available # Add model hash if available
if model_hash: if model_hash:
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}") params.append(
f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}"
)
else: else:
params.append(f"Model: {checkpoint_name}") params.append(f"Model: {checkpoint_name}")
@@ -234,7 +258,7 @@ class SaveImageLM:
lora_hash_parts.append(f"{lora_name}: {hash_value[:10]}") lora_hash_parts.append(f"{lora_name}: {hash_value[:10]}")
if lora_hash_parts: if lora_hash_parts:
params.append(f"Lora hashes: \"{', '.join(lora_hash_parts)}\"") params.append(f'Lora hashes: "{", ".join(lora_hash_parts)}"')
# Combine all parameters with commas # Combine all parameters with commas
metadata_parts.append(", ".join(params)) metadata_parts.append(", ".join(params))
@@ -254,30 +278,30 @@ class SaveImageLM:
parts = segment.replace("%", "").split(":") parts = segment.replace("%", "").split(":")
key = parts[0] key = parts[0]
if key == "seed" and 'seed' in metadata_dict: if key == "seed" and "seed" in metadata_dict:
filename = filename.replace(segment, str(metadata_dict.get('seed', ''))) filename = filename.replace(segment, str(metadata_dict.get("seed", "")))
elif key == "width" and 'size' in metadata_dict: elif key == "width" and "size" in metadata_dict:
size = metadata_dict.get('size', 'x') size = metadata_dict.get("size", "x")
w = size.split('x')[0] if isinstance(size, str) else size[0] w = size.split("x")[0] if isinstance(size, str) else size[0]
filename = filename.replace(segment, str(w)) filename = filename.replace(segment, str(w))
elif key == "height" and 'size' in metadata_dict: elif key == "height" and "size" in metadata_dict:
size = metadata_dict.get('size', 'x') size = metadata_dict.get("size", "x")
h = size.split('x')[1] if isinstance(size, str) else size[1] h = size.split("x")[1] if isinstance(size, str) else size[1]
filename = filename.replace(segment, str(h)) filename = filename.replace(segment, str(h))
elif key == "pprompt" and 'prompt' in metadata_dict: elif key == "pprompt" and "prompt" in metadata_dict:
prompt = metadata_dict.get('prompt', '').replace("\n", " ") prompt = metadata_dict.get("prompt", "").replace("\n", " ")
if len(parts) >= 2: if len(parts) >= 2:
length = int(parts[1]) length = int(parts[1])
prompt = prompt[:length] prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip()) filename = filename.replace(segment, prompt.strip())
elif key == "nprompt" and 'negative_prompt' in metadata_dict: elif key == "nprompt" and "negative_prompt" in metadata_dict:
prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ") prompt = metadata_dict.get("negative_prompt", "").replace("\n", " ")
if len(parts) >= 2: if len(parts) >= 2:
length = int(parts[1]) length = int(parts[1])
prompt = prompt[:length] prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip()) filename = filename.replace(segment, prompt.strip())
elif key == "model": elif key == "model":
model_value = metadata_dict.get('checkpoint') model_value = metadata_dict.get("checkpoint")
if isinstance(model_value, (bytes, os.PathLike)): if isinstance(model_value, (bytes, os.PathLike)):
model_value = str(model_value) model_value = str(model_value)
@@ -291,6 +315,7 @@ class SaveImageLM:
filename = filename.replace(segment, model) filename = filename.replace(segment, model)
elif key == "date": elif key == "date":
from datetime import datetime from datetime import datetime
now = datetime.now() now = datetime.now()
date_table = { date_table = {
"yyyy": f"{now.year:04d}", "yyyy": f"{now.year:04d}",
@@ -314,8 +339,19 @@ class SaveImageLM:
return filename return filename
def save_images(self, images, filename_prefix, file_format, id, prompt=None, extra_pnginfo=None, def save_images(
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): self,
images,
filename_prefix,
file_format,
id,
prompt=None,
extra_pnginfo=None,
lossless_webp=True,
quality=100,
embed_workflow=False,
add_counter_to_filename=True,
):
"""Save images with metadata""" """Save images with metadata"""
results = [] results = []
@@ -329,8 +365,10 @@ class SaveImageLM:
filename_prefix = self.format_filename(filename_prefix, metadata_dict) filename_prefix = self.format_filename(filename_prefix, metadata_dict)
# Get initial save path info once for the batch # Get initial save path info once for the batch
full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path( full_output_folder, filename, counter, subfolder, processed_prefix = (
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0] folder_paths.get_save_image_path(
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
)
) )
# Create directory if it doesn't exist # Create directory if it doesn't exist
@@ -340,7 +378,7 @@ class SaveImageLM:
# Process each image with incrementing counter # Process each image with incrementing counter
for i, image in enumerate(images): for i, image in enumerate(images):
# Convert the tensor image to numpy array # Convert the tensor image to numpy array
img = 255. * image.cpu().numpy() img = 255.0 * image.cpu().numpy()
img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8)) img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))
# Generate filename with counter if needed # Generate filename with counter if needed
@@ -351,6 +389,9 @@ class SaveImageLM:
base_filename += f"_{current_counter:05}_" base_filename += f"_{current_counter:05}_"
# Set file extension and prepare saving parameters # Set file extension and prepare saving parameters
file: str
save_kwargs: Dict[str, Any]
pnginfo: Optional[PngImagePlugin.PngInfo] = None
if file_format == "png": if file_format == "png":
file = base_filename + ".png" file = base_filename + ".png"
file_extension = ".png" file_extension = ".png"
@@ -365,7 +406,13 @@ class SaveImageLM:
file = base_filename + ".webp" file = base_filename + ".webp"
file_extension = ".webp" file_extension = ".webp"
# Add optimization param to control performance # Add optimization param to control performance
save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0} save_kwargs = {
"quality": quality,
"lossless": lossless_webp,
"method": 0,
}
else:
raise ValueError(f"Unsupported file format: {file_format}")
# Full save path # Full save path
file_path = os.path.join(full_output_folder, file) file_path = os.path.join(full_output_folder, file)
@@ -373,6 +420,7 @@ class SaveImageLM:
# Save the image with metadata # Save the image with metadata
try: try:
if file_format == "png": if file_format == "png":
assert pnginfo is not None
if metadata: if metadata:
pnginfo.add_text("parameters", metadata) pnginfo.add_text("parameters", metadata)
if embed_workflow and extra_pnginfo is not None: if embed_workflow and extra_pnginfo is not None:
@@ -384,7 +432,12 @@ class SaveImageLM:
# For JPEG, use piexif # For JPEG, use piexif
if metadata: if metadata:
try: try:
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}} exif_dict = {
"Exif": {
piexif.ExifIFD.UserComment: b"UNICODE\0"
+ metadata.encode("utf-16be")
}
}
exif_bytes = piexif.dump(exif_dict) exif_bytes = piexif.dump(exif_dict)
save_kwargs["exif"] = exif_bytes save_kwargs["exif"] = exif_bytes
except Exception as e: except Exception as e:
@@ -396,12 +449,18 @@ class SaveImageLM:
exif_dict = {} exif_dict = {}
if metadata: if metadata:
exif_dict['Exif'] = {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')} exif_dict["Exif"] = {
piexif.ExifIFD.UserComment: b"UNICODE\0"
+ metadata.encode("utf-16be")
}
# Add workflow if needed # Add workflow if needed
if embed_workflow and extra_pnginfo is not None: if embed_workflow and extra_pnginfo is not None:
workflow_json = json.dumps(extra_pnginfo["workflow"]) workflow_json = json.dumps(extra_pnginfo["workflow"])
exif_dict['0th'] = {piexif.ImageIFD.ImageDescription: "Workflow:" + workflow_json} exif_dict["0th"] = {
piexif.ImageIFD.ImageDescription: "Workflow:"
+ workflow_json
}
exif_bytes = piexif.dump(exif_dict) exif_bytes = piexif.dump(exif_dict)
save_kwargs["exif"] = exif_bytes save_kwargs["exif"] = exif_bytes
@@ -410,19 +469,28 @@ class SaveImageLM:
img.save(file_path, format="WEBP", **save_kwargs) img.save(file_path, format="WEBP", **save_kwargs)
results.append({ results.append(
"filename": file, {"filename": file, "subfolder": subfolder, "type": self.type}
"subfolder": subfolder, )
"type": self.type
})
except Exception as e: except Exception as e:
logger.error(f"Error saving image: {e}") logger.error(f"Error saving image: {e}")
return results return results
def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None, def process_image(
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): self,
images,
id,
filename_prefix="ComfyUI",
file_format="png",
prompt=None,
extra_pnginfo=None,
lossless_webp=True,
quality=100,
embed_workflow=False,
add_counter_to_filename=True,
):
"""Process and save image with metadata""" """Process and save image with metadata"""
# Make sure the output directory exists # Make sure the output directory exists
os.makedirs(self.output_dir, exist_ok=True) os.makedirs(self.output_dir, exist_ok=True)
@@ -448,7 +516,7 @@ class SaveImageLM:
lossless_webp, lossless_webp,
quality, quality,
embed_workflow, embed_workflow,
add_counter_to_filename add_counter_to_filename,
) )
return (images,) return (images,)

View File

@@ -1,33 +1,35 @@
class AnyType(str): class AnyType(str):
"""A special class that is always equal in not equal comparisons. Credit to pythongosssss""" """A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
def __ne__(self, __value: object) -> bool:
return False
def __ne__(self, __value: object) -> bool:
return False
# Credit to Regis Gaughan, III (rgthree) # Credit to Regis Gaughan, III (rgthree)
class FlexibleOptionalInputType(dict): class FlexibleOptionalInputType(dict):
"""A special class to make flexible nodes that pass data to our python handlers. """A special class to make flexible nodes that pass data to our python handlers.
Enables both flexible/dynamic input types (like for Any Switch) or a dynamic number of inputs Enables both flexible/dynamic input types (like for Any Switch) or a dynamic number of inputs
(like for Any Switch, Context Switch, Context Merge, Power Lora Loader, etc). (like for Any Switch, Context Switch, Context Merge, Power Lora Loader, etc).
Note, for ComfyUI, all that's needed is the `__contains__` override below, which tells ComfyUI Note, for ComfyUI, all that's needed is the `__contains__` override below, which tells ComfyUI
that our node will handle the input, regardless of what it is. that our node will handle the input, regardless of what it is.
However, with https://github.com/comfyanonymous/ComfyUI/pull/2666 a large change would occur However, with https://github.com/comfyanonymous/ComfyUI/pull/2666 a large change would occur
requiring more details on the input itself. There, we need to return a list/tuple where the first requiring more details on the input itself. There, we need to return a list/tuple where the first
item is the type. This can be a real type, or use the AnyType for additional flexibility. item is the type. This can be a real type, or use the AnyType for additional flexibility.
This should be forwards compatible unless more changes occur in the PR. This should be forwards compatible unless more changes occur in the PR.
""" """
def __init__(self, type):
self.type = type
def __getitem__(self, key): def __init__(self, type):
return (self.type, ) self.type = type
def __contains__(self, key): def __getitem__(self, key):
return True return (self.type,)
def __contains__(self, key):
return True
any_type = AnyType("*") any_type = AnyType("*")
@@ -37,25 +39,27 @@ import os
import logging import logging
import copy import copy
import sys import sys
import folder_paths import folder_paths # type: ignore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def extract_lora_name(lora_path): def extract_lora_name(lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension # Get the basename without extension
basename = os.path.basename(lora_path) basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0] return os.path.splitext(basename)[0]
def get_loras_list(kwargs): def get_loras_list(kwargs):
"""Helper to extract loras list from either old or new kwargs format""" """Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs: if "loras" not in kwargs:
return [] return []
loras_data = kwargs['loras'] loras_data = kwargs["loras"]
# Handle new format: {'loras': {'__value__': [...]}} # Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data: if isinstance(loras_data, dict) and "__value__" in loras_data:
return loras_data['__value__'] return loras_data["__value__"]
# Handle old format: {'loras': [...]} # Handle old format: {'loras': [...]}
elif isinstance(loras_data, list): elif isinstance(loras_data, list):
return loras_data return loras_data
@@ -64,23 +68,25 @@ def get_loras_list(kwargs):
logger.warning(f"Unexpected loras format: {type(loras_data)}") logger.warning(f"Unexpected loras format: {type(loras_data)}")
return [] return []
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""): def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path""" """Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
import safetensors.torch import safetensors.torch
state_dict = {} state_dict = {}
with safetensors.torch.safe_open(path, framework="pt", device=device) as f: with safetensors.torch.safe_open(path, framework="pt", device=device) as f: # type: ignore[attr-defined]
for k in f.keys(): for k in f.keys():
if filter_prefix and not k.startswith(filter_prefix): if filter_prefix and not k.startswith(filter_prefix):
continue continue
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k) state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
return state_dict return state_dict
def to_diffusers(input_lora): def to_diffusers(input_lora):
"""Simplified version of to_diffusers for Flux LoRA conversion""" """Simplified version of to_diffusers for Flux LoRA conversion"""
import torch import torch
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
from diffusers.loaders import FluxLoraLoaderMixin from diffusers.loaders import FluxLoraLoaderMixin # type: ignore[attr-defined]
if isinstance(input_lora, str): if isinstance(input_lora, str):
tensors = load_state_dict_in_safetensors(input_lora, device="cpu") tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
@@ -97,10 +103,15 @@ def to_diffusers(input_lora):
return new_tensors return new_tensors
def nunchaku_load_lora(model, lora_name, lora_strength): def nunchaku_load_lora(model, lora_name, lora_strength):
"""Load a Flux LoRA for Nunchaku model""" """Load a Flux LoRA for Nunchaku model"""
# Get full path to the LoRA file. Allow both direct paths and registered LoRA names. # Get full path to the LoRA file. Allow both direct paths and registered LoRA names.
lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name) lora_path = (
lora_name
if os.path.isfile(lora_name)
else folder_paths.get_full_path("loras", lora_name)
)
if not lora_path or not os.path.isfile(lora_path): if not lora_path or not os.path.isfile(lora_path):
logger.warning("Skipping LoRA '%s' because it could not be found", lora_name) logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
return model return model
@@ -118,7 +129,9 @@ def nunchaku_load_lora(model, lora_name, lora_strength):
ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)] ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)]
else: else:
# Fallback to legacy logic # Fallback to legacy logic
logger.warning("Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic.") logger.warning(
"Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic."
)
transformer = model_wrapper.model transformer = model_wrapper.model
# Save the transformer temporarily # Save the transformer temporarily

View File

@@ -6,17 +6,18 @@ from .parsers import (
ComfyMetadataParser, ComfyMetadataParser,
MetaFormatParser, MetaFormatParser,
AutomaticMetadataParser, AutomaticMetadataParser,
CivitaiApiMetadataParser CivitaiApiMetadataParser,
) )
from .base import RecipeMetadataParser from .base import RecipeMetadataParser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RecipeParserFactory: class RecipeParserFactory:
"""Factory for creating recipe metadata parsers""" """Factory for creating recipe metadata parsers"""
@staticmethod @staticmethod
def create_parser(metadata) -> RecipeMetadataParser: def create_parser(metadata) -> RecipeMetadataParser | None:
""" """
Create appropriate parser based on the metadata content Create appropriate parser based on the metadata content
@@ -38,6 +39,7 @@ class RecipeParserFactory:
# Convert dict to string for other parsers that expect string input # Convert dict to string for other parsers that expect string input
try: try:
import json import json
metadata_str = json.dumps(metadata) metadata_str = json.dumps(metadata)
except Exception as e: except Exception as e:
logger.debug(f"Failed to convert dict to JSON string: {e}") logger.debug(f"Failed to convert dict to JSON string: {e}")

View File

@@ -9,6 +9,7 @@ from ...services.metadata_service import get_default_metadata_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CivitaiApiMetadataParser(RecipeMetadataParser): class CivitaiApiMetadataParser(RecipeMetadataParser):
"""Parser for Civitai image metadata format""" """Parser for Civitai image metadata format"""
@@ -40,7 +41,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
"width", "width",
"height", "height",
"Model", "Model",
"Model hash" "Model hash",
) )
return any(key in payload for key in civitai_image_fields) return any(key in payload for key in civitai_image_fields)
@@ -50,7 +51,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Check for LoRA hash patterns # Check for LoRA hash patterns
hashes = metadata.get("hashes") hashes = metadata.get("hashes")
if isinstance(hashes, dict) and any(str(key).lower().startswith("lora:") for key in hashes): if isinstance(hashes, dict) and any(
str(key).lower().startswith("lora:") for key in hashes
):
return True return True
# Check nested meta object (common in CivitAI image responses) # Check nested meta object (common in CivitAI image responses)
@@ -61,22 +64,28 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Also check for LoRA hash patterns in nested meta # Also check for LoRA hash patterns in nested meta
hashes = nested_meta.get("hashes") hashes = nested_meta.get("hashes")
if isinstance(hashes, dict) and any(str(key).lower().startswith("lora:") for key in hashes): if isinstance(hashes, dict) and any(
str(key).lower().startswith("lora:") for key in hashes
):
return True return True
return False return False
async def parse_metadata(self, metadata, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]: async def parse_metadata( # type: ignore[override]
self, user_comment, recipe_scanner=None, civitai_client=None
) -> Dict[str, Any]:
"""Parse metadata from Civitai image format """Parse metadata from Civitai image format
Args: Args:
metadata: The metadata from the image (dict) user_comment: The metadata from the image (dict)
recipe_scanner: Optional recipe scanner service recipe_scanner: Optional recipe scanner service
civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead) civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead)
Returns: Returns:
Dict containing parsed recipe data Dict containing parsed recipe data
""" """
metadata: Dict[str, Any] = user_comment # type: ignore[assignment]
metadata = user_comment
try: try:
# Get metadata provider instead of using civitai_client directly # Get metadata provider instead of using civitai_client directly
metadata_provider = await get_default_metadata_provider() metadata_provider = await get_default_metadata_provider()
@@ -103,11 +112,11 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Initialize result structure # Initialize result structure
result = { result = {
'base_model': None, "base_model": None,
'loras': [], "loras": [],
'model': None, "model": None,
'gen_params': {}, "gen_params": {},
'from_civitai_image': True "from_civitai_image": True,
} }
# Track already added LoRAs to prevent duplicates # Track already added LoRAs to prevent duplicates
@@ -148,16 +157,25 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
result["base_model"] = metadata["baseModel"] result["base_model"] = metadata["baseModel"]
elif "Model hash" in metadata and metadata_provider: elif "Model hash" in metadata and metadata_provider:
model_hash = metadata["Model hash"] model_hash = metadata["Model hash"]
model_info, error = await metadata_provider.get_model_by_hash(model_hash) model_info, error = await metadata_provider.get_model_by_hash(
model_hash
)
if model_info: if model_info:
result["base_model"] = model_info.get("baseModel", "") result["base_model"] = model_info.get("baseModel", "")
elif "Model" in metadata and isinstance(metadata.get("resources"), list): elif "Model" in metadata and isinstance(metadata.get("resources"), list):
# Try to find base model in resources # Try to find base model in resources
for resource in metadata.get("resources", []): for resource in metadata.get("resources", []):
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"): if resource.get("type") == "model" and resource.get(
"name"
) == metadata.get("Model"):
# This is likely the checkpoint model # This is likely the checkpoint model
if metadata_provider and resource.get("hash"): if metadata_provider and resource.get("hash"):
model_info, error = await metadata_provider.get_model_by_hash(resource.get("hash")) (
model_info,
error,
) = await metadata_provider.get_model_by_hash(
resource.get("hash")
)
if model_info: if model_info:
result["base_model"] = model_info.get("baseModel", "") result["base_model"] = model_info.get("baseModel", "")
@@ -176,7 +194,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Skip LoRAs without proper identification (hash or modelVersionId) # Skip LoRAs without proper identification (hash or modelVersionId)
if not lora_hash and not resource.get("modelVersionId"): if not lora_hash and not resource.get("modelVersionId"):
logger.debug(f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId") logger.debug(
f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId"
)
continue continue
# Skip if we've already added this LoRA by hash # Skip if we've already added this LoRA by hash
@@ -184,31 +204,33 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
continue continue
lora_entry = { lora_entry = {
'name': resource.get("name", "Unknown LoRA"), "name": resource.get("name", "Unknown LoRA"),
'type': "lora", "type": "lora",
'weight': float(resource.get("weight", 1.0)), "weight": float(resource.get("weight", 1.0)),
'hash': lora_hash, "hash": lora_hash,
'existsLocally': False, "existsLocally": False,
'localPath': None, "localPath": None,
'file_name': resource.get("name", "Unknown"), "file_name": resource.get("name", "Unknown"),
'thumbnailUrl': '/loras_static/images/no-preview.png', "thumbnailUrl": "/loras_static/images/no-preview.png",
'baseModel': '', "baseModel": "",
'size': 0, "size": 0,
'downloadUrl': '', "downloadUrl": "",
'isDeleted': False "isDeleted": False,
} }
# Try to get info from Civitai if hash is available # Try to get info from Civitai if hash is available
if lora_entry['hash'] and metadata_provider: if lora_entry["hash"] and metadata_provider:
try: try:
civitai_info = await metadata_provider.get_model_by_hash(lora_hash) civitai_info = (
await metadata_provider.get_model_by_hash(lora_hash)
)
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
civitai_info, civitai_info,
recipe_scanner, recipe_scanner,
base_model_counts, base_model_counts,
lora_hash lora_hash,
) )
if populated_entry is None: if populated_entry is None:
@@ -217,10 +239,14 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
lora_entry = populated_entry lora_entry = populated_entry
# If we have a version ID from Civitai, track it for deduplication # If we have a version ID from Civitai, track it for deduplication
if 'id' in lora_entry and lora_entry['id']: if "id" in lora_entry and lora_entry["id"]:
added_loras[str(lora_entry['id'])] = len(result["loras"]) added_loras[str(lora_entry["id"])] = len(
result["loras"]
)
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}") logger.error(
f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}"
)
# Track by hash if we have it # Track by hash if we have it
if lora_hash: if lora_hash:
@@ -229,7 +255,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
result["loras"].append(lora_entry) result["loras"].append(lora_entry)
# Process civitaiResources array # Process civitaiResources array
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list): if "civitaiResources" in metadata and isinstance(
metadata["civitaiResources"], list
):
for resource in metadata["civitaiResources"]: for resource in metadata["civitaiResources"]:
# Get resource type and identifier # Get resource type and identifier
resource_type = str(resource.get("type") or "").lower() resource_type = str(resource.get("type") or "").lower()
@@ -237,32 +265,39 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
if resource_type == "checkpoint": if resource_type == "checkpoint":
checkpoint_entry = { checkpoint_entry = {
'id': resource.get("modelVersionId", 0), "id": resource.get("modelVersionId", 0),
'modelId': resource.get("modelId", 0), "modelId": resource.get("modelId", 0),
'name': resource.get("modelName", "Unknown Checkpoint"), "name": resource.get("modelName", "Unknown Checkpoint"),
'version': resource.get("modelVersionName", ""), "version": resource.get("modelVersionName", ""),
'type': resource.get("type", "checkpoint"), "type": resource.get("type", "checkpoint"),
'existsLocally': False, "existsLocally": False,
'localPath': None, "localPath": None,
'file_name': resource.get("modelName", ""), "file_name": resource.get("modelName", ""),
'hash': resource.get("hash", "") or "", "hash": resource.get("hash", "") or "",
'thumbnailUrl': '/loras_static/images/no-preview.png', "thumbnailUrl": "/loras_static/images/no-preview.png",
'baseModel': '', "baseModel": "",
'size': 0, "size": 0,
'downloadUrl': '', "downloadUrl": "",
'isDeleted': False "isDeleted": False,
} }
if version_id and metadata_provider: if version_id and metadata_provider:
try: try:
civitai_info = await metadata_provider.get_model_version_info(version_id) civitai_info = (
await metadata_provider.get_model_version_info(
version_id
)
)
checkpoint_entry = await self.populate_checkpoint_from_civitai( checkpoint_entry = (
checkpoint_entry, await self.populate_checkpoint_from_civitai(
civitai_info checkpoint_entry, civitai_info
)
) )
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for checkpoint version {version_id}: {e}") logger.error(
f"Error fetching Civitai info for checkpoint version {version_id}: {e}"
)
if result["model"] is None: if result["model"] is None:
result["model"] = checkpoint_entry result["model"] = checkpoint_entry
@@ -275,31 +310,35 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Initialize lora entry # Initialize lora entry
lora_entry = { lora_entry = {
'id': resource.get("modelVersionId", 0), "id": resource.get("modelVersionId", 0),
'modelId': resource.get("modelId", 0), "modelId": resource.get("modelId", 0),
'name': resource.get("modelName", "Unknown LoRA"), "name": resource.get("modelName", "Unknown LoRA"),
'version': resource.get("modelVersionName", ""), "version": resource.get("modelVersionName", ""),
'type': resource.get("type", "lora"), "type": resource.get("type", "lora"),
'weight': round(float(resource.get("weight", 1.0)), 2), "weight": round(float(resource.get("weight", 1.0)), 2),
'existsLocally': False, "existsLocally": False,
'thumbnailUrl': '/loras_static/images/no-preview.png', "thumbnailUrl": "/loras_static/images/no-preview.png",
'baseModel': '', "baseModel": "",
'size': 0, "size": 0,
'downloadUrl': '', "downloadUrl": "",
'isDeleted': False "isDeleted": False,
} }
# Try to get info from Civitai if modelVersionId is available # Try to get info from Civitai if modelVersionId is available
if version_id and metadata_provider: if version_id and metadata_provider:
try: try:
# Use get_model_version_info instead of get_model_version # Use get_model_version_info instead of get_model_version
civitai_info = await metadata_provider.get_model_version_info(version_id) civitai_info = (
await metadata_provider.get_model_version_info(
version_id
)
)
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
civitai_info, civitai_info,
recipe_scanner, recipe_scanner,
base_model_counts base_model_counts,
) )
if populated_entry is None: if populated_entry is None:
@@ -307,7 +346,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
lora_entry = populated_entry lora_entry = populated_entry
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for model version {version_id}: {e}") logger.error(
f"Error fetching Civitai info for model version {version_id}: {e}"
)
# Track this LoRA in our deduplication dict # Track this LoRA in our deduplication dict
if version_id: if version_id:
@@ -316,10 +357,15 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
result["loras"].append(lora_entry) result["loras"].append(lora_entry)
# Process additionalResources array # Process additionalResources array
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list): if "additionalResources" in metadata and isinstance(
metadata["additionalResources"], list
):
for resource in metadata["additionalResources"]: for resource in metadata["additionalResources"]:
# Skip resources that aren't LoRAs or LyCORIS # Skip resources that aren't LoRAs or LyCORIS
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource: if (
resource.get("type") not in ["lora", "lycoris"]
and "type" not in resource
):
continue continue
lora_type = resource.get("type", "lora") lora_type = resource.get("type", "lora")
@@ -337,31 +383,35 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
continue continue
lora_entry = { lora_entry = {
'name': name, "name": name,
'type': lora_type, "type": lora_type,
'weight': float(resource.get("strength", 1.0)), "weight": float(resource.get("strength", 1.0)),
'hash': "", "hash": "",
'existsLocally': False, "existsLocally": False,
'localPath': None, "localPath": None,
'file_name': name, "file_name": name,
'thumbnailUrl': '/loras_static/images/no-preview.png', "thumbnailUrl": "/loras_static/images/no-preview.png",
'baseModel': '', "baseModel": "",
'size': 0, "size": 0,
'downloadUrl': '', "downloadUrl": "",
'isDeleted': False "isDeleted": False,
} }
# If we have a version ID and metadata provider, try to get more info # If we have a version ID and metadata provider, try to get more info
if version_id and metadata_provider: if version_id and metadata_provider:
try: try:
# Use get_model_version_info with the version ID # Use get_model_version_info with the version ID
civitai_info = await metadata_provider.get_model_version_info(version_id) civitai_info = (
await metadata_provider.get_model_version_info(
version_id
)
)
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
civitai_info, civitai_info,
recipe_scanner, recipe_scanner,
base_model_counts base_model_counts,
) )
if populated_entry is None: if populated_entry is None:
@@ -373,7 +423,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
if version_id: if version_id:
added_loras[version_id] = len(result["loras"]) added_loras[version_id] = len(result["loras"])
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}") logger.error(
f"Error fetching Civitai info for model ID {version_id}: {e}"
)
result["loras"].append(lora_entry) result["loras"].append(lora_entry)
@@ -390,30 +442,32 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
continue continue
lora_entry = { lora_entry = {
'name': lora_name, "name": lora_name,
'type': "lora", "type": "lora",
'weight': 1.0, "weight": 1.0,
'hash': lora_hash, "hash": lora_hash,
'existsLocally': False, "existsLocally": False,
'localPath': None, "localPath": None,
'file_name': lora_name, "file_name": lora_name,
'thumbnailUrl': '/loras_static/images/no-preview.png', "thumbnailUrl": "/loras_static/images/no-preview.png",
'baseModel': '', "baseModel": "",
'size': 0, "size": 0,
'downloadUrl': '', "downloadUrl": "",
'isDeleted': False "isDeleted": False,
} }
if metadata_provider: if metadata_provider:
try: try:
civitai_info = await metadata_provider.get_model_by_hash(lora_hash) civitai_info = await metadata_provider.get_model_by_hash(
lora_hash
)
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
civitai_info, civitai_info,
recipe_scanner, recipe_scanner,
base_model_counts, base_model_counts,
lora_hash lora_hash,
) )
if populated_entry is None: if populated_entry is None:
@@ -421,20 +475,27 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
lora_entry = populated_entry lora_entry = populated_entry
if 'id' in lora_entry and lora_entry['id']: if "id" in lora_entry and lora_entry["id"]:
added_loras[str(lora_entry['id'])] = len(result["loras"]) added_loras[str(lora_entry["id"])] = len(result["loras"])
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {lora_hash}: {e}") logger.error(
f"Error fetching Civitai info for LoRA hash {lora_hash}: {e}"
)
added_loras[lora_hash] = len(result["loras"]) added_loras[lora_hash] = len(result["loras"])
result["loras"].append(lora_entry) result["loras"].append(lora_entry)
# Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc. # Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc.
lora_index = 0 lora_index = 0
while f"Lora_{lora_index} Model hash" in metadata and f"Lora_{lora_index} Model name" in metadata: while (
f"Lora_{lora_index} Model hash" in metadata
and f"Lora_{lora_index} Model name" in metadata
):
lora_hash = metadata[f"Lora_{lora_index} Model hash"] lora_hash = metadata[f"Lora_{lora_index} Model hash"]
lora_name = metadata[f"Lora_{lora_index} Model name"] lora_name = metadata[f"Lora_{lora_index} Model name"]
lora_strength_model = float(metadata.get(f"Lora_{lora_index} Strength model", 1.0)) lora_strength_model = float(
metadata.get(f"Lora_{lora_index} Strength model", 1.0)
)
# Skip if we've already added this LoRA by hash # Skip if we've already added this LoRA by hash
if lora_hash and lora_hash in added_loras: if lora_hash and lora_hash in added_loras:
@@ -442,31 +503,33 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
continue continue
lora_entry = { lora_entry = {
'name': lora_name, "name": lora_name,
'type': "lora", "type": "lora",
'weight': lora_strength_model, "weight": lora_strength_model,
'hash': lora_hash, "hash": lora_hash,
'existsLocally': False, "existsLocally": False,
'localPath': None, "localPath": None,
'file_name': lora_name, "file_name": lora_name,
'thumbnailUrl': '/loras_static/images/no-preview.png', "thumbnailUrl": "/loras_static/images/no-preview.png",
'baseModel': '', "baseModel": "",
'size': 0, "size": 0,
'downloadUrl': '', "downloadUrl": "",
'isDeleted': False "isDeleted": False,
} }
# Try to get info from Civitai if hash is available # Try to get info from Civitai if hash is available
if lora_entry['hash'] and metadata_provider: if lora_entry["hash"] and metadata_provider:
try: try:
civitai_info = await metadata_provider.get_model_by_hash(lora_hash) civitai_info = await metadata_provider.get_model_by_hash(
lora_hash
)
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
civitai_info, civitai_info,
recipe_scanner, recipe_scanner,
base_model_counts, base_model_counts,
lora_hash lora_hash,
) )
if populated_entry is None: if populated_entry is None:
@@ -476,10 +539,12 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
lora_entry = populated_entry lora_entry = populated_entry
# If we have a version ID from Civitai, track it for deduplication # If we have a version ID from Civitai, track it for deduplication
if 'id' in lora_entry and lora_entry['id']: if "id" in lora_entry and lora_entry["id"]:
added_loras[str(lora_entry['id'])] = len(result["loras"]) added_loras[str(lora_entry["id"])] = len(result["loras"])
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}") logger.error(
f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}"
)
# Track by hash if we have it # Track by hash if we have it
if lora_hash: if lora_hash:
@@ -491,7 +556,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# If base model wasn't found earlier, use the most common one from LoRAs # If base model wasn't found earlier, use the most common one from LoRAs
if not result["base_model"] and base_model_counts: if not result["base_model"] and base_model_counts:
result["base_model"] = max(base_model_counts.items(), key=lambda x: x[1])[0] result["base_model"] = max(
base_model_counts.items(), key=lambda x: x[1]
)[0]
return result return result

View File

@@ -3,13 +3,17 @@ import copy
import logging import logging
import os import os
from typing import Any, Optional, Dict, Tuple, List, Sequence from typing import Any, Optional, Dict, Tuple, List, Sequence
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager from .model_metadata_provider import (
CivitaiModelMetadataProvider,
ModelMetadataProviderManager,
)
from .downloader import get_downloader from .downloader import get_downloader
from .errors import RateLimitError, ResourceNotFoundError from .errors import RateLimitError, ResourceNotFoundError
from ..utils.civitai_utils import resolve_license_payload from ..utils.civitai_utils import resolve_license_payload
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CivitaiClient: class CivitaiClient:
_instance = None _instance = None
_lock = asyncio.Lock() _lock = asyncio.Lock()
@@ -23,13 +27,15 @@ class CivitaiClient:
# Register this client as a metadata provider # Register this client as a metadata provider
provider_manager = await ModelMetadataProviderManager.get_instance() provider_manager = await ModelMetadataProviderManager.get_instance()
provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True) provider_manager.register_provider(
"civitai", CivitaiModelMetadataProvider(cls._instance), True
)
return cls._instance return cls._instance
def __init__(self): def __init__(self):
# Check if already initialized for singleton pattern # Check if already initialized for singleton pattern
if hasattr(self, '_initialized'): if hasattr(self, "_initialized"):
return return
self._initialized = True self._initialized = True
@@ -76,7 +82,9 @@ class CivitaiClient:
if isinstance(meta, dict) and "comfy" in meta: if isinstance(meta, dict) and "comfy" in meta:
meta.pop("comfy", None) meta.pop("comfy", None)
async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: async def download_file(
self, url: str, save_dir: str, default_filename: str, progress_callback=None
) -> Tuple[bool, str]:
"""Download file with resumable downloads and retry mechanism """Download file with resumable downloads and retry mechanism
Args: Args:
@@ -97,34 +105,41 @@ class CivitaiClient:
save_path=save_path, save_path=save_path,
progress_callback=progress_callback, progress_callback=progress_callback,
use_auth=True, # Enable CivitAI authentication use_auth=True, # Enable CivitAI authentication
allow_resume=True allow_resume=True,
) )
return success, result return success, result
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]: async def get_model_by_hash(
self, model_hash: str
) -> Tuple[Optional[Dict], Optional[str]]:
try: try:
success, version = await self._make_request( success, version = await self._make_request(
'GET', "GET",
f"{self.base_url}/model-versions/by-hash/{model_hash}", f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True use_auth=True,
) )
if not success: if not success:
message = str(version) message = str(version)
if "not found" in message.lower(): if "not found" in message.lower():
return None, "Model not found" return None, "Model not found"
logger.error("Failed to fetch model info for %s: %s", model_hash[:10], message) logger.error(
"Failed to fetch model info for %s: %s", model_hash[:10], message
)
return None, message return None, message
model_id = version.get('modelId') if isinstance(version, dict):
if model_id: model_id = version.get("modelId")
model_data = await self._fetch_model_data(model_id) if model_id:
if model_data: model_data = await self._fetch_model_data(model_id)
self._enrich_version_with_model_data(version, model_data) if model_data:
self._enrich_version_with_model_data(version, model_data)
self._remove_comfy_metadata(version) self._remove_comfy_metadata(version)
return version, None return version, None
else:
return None, "Invalid response format"
except RateLimitError: except RateLimitError:
raise raise
except Exception as exc: except Exception as exc:
@@ -136,12 +151,12 @@ class CivitaiClient:
downloader = await get_downloader() downloader = await get_downloader()
success, content, headers = await downloader.download_to_memory( success, content, headers = await downloader.download_to_memory(
image_url, image_url,
use_auth=False # Preview images don't need auth use_auth=False, # Preview images don't need auth
) )
if success: if success:
# Ensure directory exists # Ensure directory exists
os.makedirs(os.path.dirname(save_path), exist_ok=True) os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, 'wb') as f: with open(save_path, "wb") as f:
f.write(content) f.write(content)
return True return True
return False return False
@@ -175,19 +190,17 @@ class CivitaiClient:
"""Get all versions of a model with local availability info""" """Get all versions of a model with local availability info"""
try: try:
success, result = await self._make_request( success, result = await self._make_request(
'GET', "GET", f"{self.base_url}/models/{model_id}", use_auth=True
f"{self.base_url}/models/{model_id}",
use_auth=True
) )
if success: if success:
# Also return model type along with versions # Also return model type along with versions
return { return {
'modelVersions': result.get('modelVersions', []), "modelVersions": result.get("modelVersions", []),
'type': result.get('type', ''), "type": result.get("type", ""),
'name': result.get('name', '') "name": result.get("name", ""),
} }
message = self._extract_error_message(result) message = self._extract_error_message(result)
if message and 'not found' in message.lower(): if message and "not found" in message.lower():
raise ResourceNotFoundError(f"Resource not found for model {model_id}") raise ResourceNotFoundError(f"Resource not found for model {model_id}")
if message: if message:
raise RuntimeError(message) raise RuntimeError(message)
@@ -221,15 +234,15 @@ class CivitaiClient:
try: try:
query = ",".join(normalized_ids) query = ",".join(normalized_ids)
success, result = await self._make_request( success, result = await self._make_request(
'GET', "GET",
f"{self.base_url}/models", f"{self.base_url}/models",
use_auth=True, use_auth=True,
params={'ids': query}, params={"ids": query},
) )
if not success: if not success:
return None return None
items = result.get('items') if isinstance(result, dict) else None items = result.get("items") if isinstance(result, dict) else None
if not isinstance(items, list): if not isinstance(items, list):
return {} return {}
@@ -237,19 +250,19 @@ class CivitaiClient:
for item in items: for item in items:
if not isinstance(item, dict): if not isinstance(item, dict):
continue continue
model_id = item.get('id') model_id = item.get("id")
try: try:
normalized_id = int(model_id) normalized_id = int(model_id)
except (TypeError, ValueError): except (TypeError, ValueError):
continue continue
payload[normalized_id] = { payload[normalized_id] = {
'modelVersions': item.get('modelVersions', []), "modelVersions": item.get("modelVersions", []),
'type': item.get('type', ''), "type": item.get("type", ""),
'name': item.get('name', ''), "name": item.get("name", ""),
'allowNoCredit': item.get('allowNoCredit'), "allowNoCredit": item.get("allowNoCredit"),
'allowCommercialUse': item.get('allowCommercialUse'), "allowCommercialUse": item.get("allowCommercialUse"),
'allowDerivatives': item.get('allowDerivatives'), "allowDerivatives": item.get("allowDerivatives"),
'allowDifferentLicense': item.get('allowDifferentLicense'), "allowDifferentLicense": item.get("allowDifferentLicense"),
} }
return payload return payload
except RateLimitError: except RateLimitError:
@@ -258,7 +271,9 @@ class CivitaiClient:
logger.error(f"Error fetching model versions in bulk: {exc}") logger.error(f"Error fetching model versions in bulk: {exc}")
return None return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: async def get_model_version(
self, model_id: int = None, version_id: int = None
) -> Optional[Dict]:
"""Get specific model version with additional metadata.""" """Get specific model version with additional metadata."""
try: try:
if model_id is None and version_id is not None: if model_id is None and version_id is not None:
@@ -281,7 +296,7 @@ class CivitaiClient:
if version is None: if version is None:
return None return None
model_id = version.get('modelId') model_id = version.get("modelId")
if not model_id: if not model_id:
logger.error(f"No modelId found in version {version_id}") logger.error(f"No modelId found in version {version_id}")
return None return None
@@ -293,7 +308,9 @@ class CivitaiClient:
self._remove_comfy_metadata(version) self._remove_comfy_metadata(version)
return version return version
async def _get_version_with_model_id(self, model_id: int, version_id: Optional[int]) -> Optional[Dict]: async def _get_version_with_model_id(
self, model_id: int, version_id: Optional[int]
) -> Optional[Dict]:
model_data = await self._fetch_model_data(model_id) model_data = await self._fetch_model_data(model_id)
if not model_data: if not model_data:
return None return None
@@ -302,8 +319,12 @@ class CivitaiClient:
if target_version is None: if target_version is None:
return None return None
target_version_id = target_version.get('id') target_version_id = target_version.get("id")
version = await self._fetch_version_by_id(target_version_id) if target_version_id else None version = (
await self._fetch_version_by_id(target_version_id)
if target_version_id
else None
)
if version is None: if version is None:
model_hash = self._extract_primary_model_hash(target_version) model_hash = self._extract_primary_model_hash(target_version)
@@ -315,7 +336,9 @@ class CivitaiClient:
) )
if version is None: if version is None:
version = self._build_version_from_model_data(target_version, model_id, model_data) version = self._build_version_from_model_data(
target_version, model_id, model_data
)
self._enrich_version_with_model_data(version, model_data) self._enrich_version_with_model_data(version, model_data)
self._remove_comfy_metadata(version) self._remove_comfy_metadata(version)
@@ -323,9 +346,7 @@ class CivitaiClient:
async def _fetch_model_data(self, model_id: int) -> Optional[Dict]: async def _fetch_model_data(self, model_id: int) -> Optional[Dict]:
success, data = await self._make_request( success, data = await self._make_request(
'GET', "GET", f"{self.base_url}/models/{model_id}", use_auth=True
f"{self.base_url}/models/{model_id}",
use_auth=True
) )
if success: if success:
return data return data
@@ -337,9 +358,7 @@ class CivitaiClient:
return None return None
success, version = await self._make_request( success, version = await self._make_request(
'GET', "GET", f"{self.base_url}/model-versions/{version_id}", use_auth=True
f"{self.base_url}/model-versions/{version_id}",
use_auth=True
) )
if success: if success:
return version return version
@@ -352,9 +371,7 @@ class CivitaiClient:
return None return None
success, version = await self._make_request( success, version = await self._make_request(
'GET', "GET", f"{self.base_url}/model-versions/by-hash/{model_hash}", use_auth=True
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
) )
if success: if success:
return version return version
@@ -362,16 +379,17 @@ class CivitaiClient:
logger.warning(f"Failed to fetch version by hash {model_hash}") logger.warning(f"Failed to fetch version by hash {model_hash}")
return None return None
def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]: def _select_target_version(
model_versions = model_data.get('modelVersions', []) self, model_data: Dict, model_id: int, version_id: Optional[int]
) -> Optional[Dict]:
model_versions = model_data.get("modelVersions", [])
if not model_versions: if not model_versions:
logger.warning(f"No model versions found for model {model_id}") logger.warning(f"No model versions found for model {model_id}")
return None return None
if version_id is not None: if version_id is not None:
target_version = next( target_version = next(
(item for item in model_versions if item.get('id') == version_id), (item for item in model_versions if item.get("id") == version_id), None
None
) )
if target_version is None: if target_version is None:
logger.warning( logger.warning(
@@ -383,41 +401,45 @@ class CivitaiClient:
return model_versions[0] return model_versions[0]
def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]: def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]:
for file_info in version_entry.get('files', []): for file_info in version_entry.get("files", []):
if file_info.get('type') == 'Model' and file_info.get('primary'): if file_info.get("type") == "Model" and file_info.get("primary"):
hashes = file_info.get('hashes', {}) hashes = file_info.get("hashes", {})
model_hash = hashes.get('SHA256') model_hash = hashes.get("SHA256")
if model_hash: if model_hash:
return model_hash return model_hash
return None return None
def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict: def _build_version_from_model_data(
self, version_entry: Dict, model_id: int, model_data: Dict
) -> Dict:
version = copy.deepcopy(version_entry) version = copy.deepcopy(version_entry)
version.pop('index', None) version.pop("index", None)
version['modelId'] = model_id version["modelId"] = model_id
version['model'] = { version["model"] = {
'name': model_data.get('name'), "name": model_data.get("name"),
'type': model_data.get('type'), "type": model_data.get("type"),
'nsfw': model_data.get('nsfw'), "nsfw": model_data.get("nsfw"),
'poi': model_data.get('poi') "poi": model_data.get("poi"),
} }
return version return version
def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None: def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None:
model_info = version.get('model') model_info = version.get("model")
if not isinstance(model_info, dict): if not isinstance(model_info, dict):
model_info = {} model_info = {}
version['model'] = model_info version["model"] = model_info
model_info['description'] = model_data.get("description") model_info["description"] = model_data.get("description")
model_info['tags'] = model_data.get("tags", []) model_info["tags"] = model_data.get("tags", [])
version['creator'] = model_data.get("creator") version["creator"] = model_data.get("creator")
license_payload = resolve_license_payload(model_data) license_payload = resolve_license_payload(model_data)
for field, value in license_payload.items(): for field, value in license_payload.items():
model_info[field] = value model_info[field] = value
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: async def get_model_version_info(
self, version_id: str
) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from Civitai """Fetch model version metadata from Civitai
Args: Args:
@@ -432,14 +454,12 @@ class CivitaiClient:
url = f"{self.base_url}/model-versions/{version_id}" url = f"{self.base_url}/model-versions/{version_id}"
logger.debug(f"Resolving DNS for model version info: {url}") logger.debug(f"Resolving DNS for model version info: {url}")
success, result = await self._make_request( success, result = await self._make_request("GET", url, use_auth=True)
'GET',
url,
use_auth=True
)
if success: if success:
logger.debug(f"Successfully fetched model version info for: {version_id}") logger.debug(
f"Successfully fetched model version info for: {version_id}"
)
self._remove_comfy_metadata(result) self._remove_comfy_metadata(result)
return result, None return result, None
@@ -472,11 +492,7 @@ class CivitaiClient:
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X" url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
logger.debug(f"Fetching image info for ID: {image_id}") logger.debug(f"Fetching image info for ID: {image_id}")
success, result = await self._make_request( success, result = await self._make_request("GET", url, use_auth=True)
'GET',
url,
use_auth=True
)
if success: if success:
if result and "items" in result and len(result["items"]) > 0: if result and "items" in result and len(result["items"]) > 0:
@@ -501,11 +517,7 @@ class CivitaiClient:
try: try:
url = f"{self.base_url}/models?username={username}" url = f"{self.base_url}/models?username={username}"
success, result = await self._make_request( success, result = await self._make_request("GET", url, use_auth=True)
'GET',
url,
use_auth=True
)
if not success: if not success:
logger.error("Failed to fetch models for %s: %s", username, result) logger.error("Failed to fetch models for %s: %s", username, result)