Fix null-safety issues and apply code formatting

Bug fixes:
- Add null guards for base_models_roots/embeddings_roots in backup cleanup
- Fix null-safety initialization of extra_unet_roots

Formatting:
- Apply consistent code style across Python files
- Fix line wrapping, quote consistency, and trailing commas
- Add type ignore comments for dynamic/platform-specific code
This commit is contained in:
Will Miao
2026-02-28 21:38:41 +08:00
parent b005961ee5
commit c9e5ea42cb
9 changed files with 1097 additions and 760 deletions

View File

@@ -2,7 +2,7 @@ import os
import platform import platform
import threading import threading
from pathlib import Path from pathlib import Path
import folder_paths # type: ignore import folder_paths # type: ignore
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple
import logging import logging
import json import json
@@ -10,16 +10,23 @@ import urllib.parse
import time import time
from .utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths from .utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths
from .utils.settings_paths import ensure_settings_file, get_settings_dir, load_settings_template from .utils.settings_paths import (
ensure_settings_file,
get_settings_dir,
load_settings_template,
)
# Use an environment variable to control standalone mode # Use an environment variable to control standalone mode
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" standalone_mode = (
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _normalize_folder_paths_for_comparison( def _normalize_folder_paths_for_comparison(
folder_paths: Mapping[str, Iterable[str]] folder_paths: Mapping[str, Iterable[str]],
) -> Dict[str, Set[str]]: ) -> Dict[str, Set[str]]:
"""Normalize folder paths for comparison across libraries.""" """Normalize folder paths for comparison across libraries."""
@@ -49,7 +56,7 @@ def _normalize_folder_paths_for_comparison(
def _normalize_library_folder_paths( def _normalize_library_folder_paths(
library_payload: Mapping[str, Any] library_payload: Mapping[str, Any],
) -> Dict[str, Set[str]]: ) -> Dict[str, Set[str]]:
"""Return normalized folder paths extracted from a library payload.""" """Return normalized folder paths extracted from a library payload."""
@@ -74,11 +81,17 @@ def _get_template_folder_paths() -> Dict[str, Set[str]]:
class Config: class Config:
"""Global configuration for LoRA Manager""" """Global configuration for LoRA Manager"""
def __init__(self): def __init__(self):
self.templates_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'templates') self.templates_path = os.path.join(
self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static') os.path.dirname(os.path.dirname(__file__)), "templates"
self.i18n_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'locales') )
self.static_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "static"
)
self.i18n_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "locales"
)
# Path mapping dictionary, target to link mapping # Path mapping dictionary, target to link mapping
self._path_mappings: Dict[str, str] = {} self._path_mappings: Dict[str, str] = {}
# Normalized preview root directories used to validate preview access # Normalized preview root directories used to validate preview access
@@ -98,7 +111,7 @@ class Config:
self.extra_embeddings_roots: List[str] = [] self.extra_embeddings_roots: List[str] = []
# Scan symbolic links during initialization # Scan symbolic links during initialization
self._initialize_symlink_mappings() self._initialize_symlink_mappings()
if not standalone_mode: if not standalone_mode:
# Save the paths to settings.json when running in ComfyUI mode # Save the paths to settings.json when running in ComfyUI mode
self.save_folder_paths_to_settings() self.save_folder_paths_to_settings()
@@ -152,17 +165,21 @@ class Config:
default_library = libraries.get("default", {}) default_library = libraries.get("default", {})
target_folder_paths = { target_folder_paths = {
'loras': list(self.loras_roots), "loras": list(self.loras_roots),
'checkpoints': list(self.checkpoints_roots or []), "checkpoints": list(self.checkpoints_roots or []),
'unet': list(self.unet_roots or []), "unet": list(self.unet_roots or []),
'embeddings': list(self.embeddings_roots or []), "embeddings": list(self.embeddings_roots or []),
} }
normalized_target_paths = _normalize_folder_paths_for_comparison(target_folder_paths) normalized_target_paths = _normalize_folder_paths_for_comparison(
target_folder_paths
)
normalized_default_paths: Optional[Dict[str, Set[str]]] = None normalized_default_paths: Optional[Dict[str, Set[str]]] = None
if isinstance(default_library, Mapping): if isinstance(default_library, Mapping):
normalized_default_paths = _normalize_library_folder_paths(default_library) normalized_default_paths = _normalize_library_folder_paths(
default_library
)
if ( if (
not comfy_library not comfy_library
@@ -185,13 +202,19 @@ class Config:
default_lora_root = self.loras_roots[0] default_lora_root = self.loras_roots[0]
default_checkpoint_root = comfy_library.get("default_checkpoint_root", "") default_checkpoint_root = comfy_library.get("default_checkpoint_root", "")
if (not default_checkpoint_root and self.checkpoints_roots and if (
len(self.checkpoints_roots) == 1): not default_checkpoint_root
and self.checkpoints_roots
and len(self.checkpoints_roots) == 1
):
default_checkpoint_root = self.checkpoints_roots[0] default_checkpoint_root = self.checkpoints_roots[0]
default_embedding_root = comfy_library.get("default_embedding_root", "") default_embedding_root = comfy_library.get("default_embedding_root", "")
if (not default_embedding_root and self.embeddings_roots and if (
len(self.embeddings_roots) == 1): not default_embedding_root
and self.embeddings_roots
and len(self.embeddings_roots) == 1
):
default_embedding_root = self.embeddings_roots[0] default_embedding_root = self.embeddings_roots[0]
metadata = dict(comfy_library.get("metadata", {})) metadata = dict(comfy_library.get("metadata", {}))
@@ -216,11 +239,12 @@ class Config:
try: try:
if os.path.islink(path): if os.path.islink(path):
return True return True
if platform.system() == 'Windows': if platform.system() == "Windows":
try: try:
import ctypes import ctypes
FILE_ATTRIBUTE_REPARSE_POINT = 0x400 FILE_ATTRIBUTE_REPARSE_POINT = 0x400
attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path)) attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path)) # type: ignore[attr-defined]
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT) return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
except Exception as e: except Exception as e:
logger.error(f"Error checking Windows reparse point: {e}") logger.error(f"Error checking Windows reparse point: {e}")
@@ -233,18 +257,19 @@ class Config:
"""Check if a directory entry is a symlink, including Windows junctions.""" """Check if a directory entry is a symlink, including Windows junctions."""
if entry.is_symlink(): if entry.is_symlink():
return True return True
if platform.system() == 'Windows': if platform.system() == "Windows":
try: try:
import ctypes import ctypes
FILE_ATTRIBUTE_REPARSE_POINT = 0x400 FILE_ATTRIBUTE_REPARSE_POINT = 0x400
attrs = ctypes.windll.kernel32.GetFileAttributesW(entry.path) attrs = ctypes.windll.kernel32.GetFileAttributesW(entry.path) # type: ignore[attr-defined]
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT) return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
except Exception: except Exception:
pass pass
return False return False
def _normalize_path(self, path: str) -> str: def _normalize_path(self, path: str) -> str:
return os.path.normpath(path).replace(os.sep, '/') return os.path.normpath(path).replace(os.sep, "/")
def _get_symlink_cache_path(self) -> Path: def _get_symlink_cache_path(self) -> Path:
canonical_path = get_cache_file_path(CacheType.SYMLINK, create_dir=True) canonical_path = get_cache_file_path(CacheType.SYMLINK, create_dir=True)
@@ -278,19 +303,18 @@ class Config:
if self._entry_is_symlink(entry): if self._entry_is_symlink(entry):
try: try:
target = os.path.realpath(entry.path) target = os.path.realpath(entry.path)
direct_symlinks.append([ direct_symlinks.append(
self._normalize_path(entry.path), [
self._normalize_path(target) self._normalize_path(entry.path),
]) self._normalize_path(target),
]
)
except OSError: except OSError:
pass pass
except (OSError, PermissionError): except (OSError, PermissionError):
pass pass
return { return {"roots": unique_roots, "direct_symlinks": sorted(direct_symlinks)}
"roots": unique_roots,
"direct_symlinks": sorted(direct_symlinks)
}
def _initialize_symlink_mappings(self) -> None: def _initialize_symlink_mappings(self) -> None:
start = time.perf_counter() start = time.perf_counter()
@@ -307,10 +331,14 @@ class Config:
cached_fingerprint = self._cached_fingerprint cached_fingerprint = self._cached_fingerprint
# Check 1: First-level symlinks unchanged (catches new symlinks at root) # Check 1: First-level symlinks unchanged (catches new symlinks at root)
fingerprint_valid = cached_fingerprint and current_fingerprint == cached_fingerprint fingerprint_valid = (
cached_fingerprint and current_fingerprint == cached_fingerprint
)
# Check 2: All cached mappings still valid (catches changes at any depth) # Check 2: All cached mappings still valid (catches changes at any depth)
mappings_valid = self._validate_cached_mappings() if fingerprint_valid else False mappings_valid = (
self._validate_cached_mappings() if fingerprint_valid else False
)
if fingerprint_valid and mappings_valid: if fingerprint_valid and mappings_valid:
return return
@@ -370,7 +398,9 @@ class Config:
for target, link in cached_mappings.items(): for target, link in cached_mappings.items():
if not isinstance(target, str) or not isinstance(link, str): if not isinstance(target, str) or not isinstance(link, str):
continue continue
normalized_mappings[self._normalize_path(target)] = self._normalize_path(link) normalized_mappings[self._normalize_path(target)] = self._normalize_path(
link
)
self._path_mappings = normalized_mappings self._path_mappings = normalized_mappings
@@ -391,7 +421,9 @@ class Config:
parent_dir = loaded_path.parent parent_dir = loaded_path.parent
if parent_dir.name == "cache" and not any(parent_dir.iterdir()): if parent_dir.name == "cache" and not any(parent_dir.iterdir()):
parent_dir.rmdir() parent_dir.rmdir()
logger.info("Removed empty legacy cache directory: %s", parent_dir) logger.info(
"Removed empty legacy cache directory: %s", parent_dir
)
except Exception: except Exception:
pass pass
@@ -402,7 +434,9 @@ class Config:
exc, exc,
) )
else: else:
logger.info("Symlink cache loaded with %d mappings", len(self._path_mappings)) logger.info(
"Symlink cache loaded with %d mappings", len(self._path_mappings)
)
return True return True
@@ -414,7 +448,7 @@ class Config:
""" """
for target, link in self._path_mappings.items(): for target, link in self._path_mappings.items():
# Convert normalized paths back to OS paths # Convert normalized paths back to OS paths
link_path = link.replace('/', os.sep) link_path = link.replace("/", os.sep)
# Check if symlink still exists # Check if symlink still exists
if not self._is_link(link_path): if not self._is_link(link_path):
@@ -427,7 +461,9 @@ class Config:
if actual_target != target: if actual_target != target:
logger.debug( logger.debug(
"Symlink target changed: %s -> %s (cached: %s)", "Symlink target changed: %s -> %s (cached: %s)",
link_path, actual_target, target link_path,
actual_target,
target,
) )
return False return False
except OSError: except OSError:
@@ -446,7 +482,11 @@ class Config:
try: try:
with cache_path.open("w", encoding="utf-8") as handle: with cache_path.open("w", encoding="utf-8") as handle:
json.dump(payload, handle, ensure_ascii=False, indent=2) json.dump(payload, handle, ensure_ascii=False, indent=2)
logger.debug("Symlink cache saved to %s with %d mappings", cache_path, len(self._path_mappings)) logger.debug(
"Symlink cache saved to %s with %d mappings",
cache_path,
len(self._path_mappings),
)
except Exception as exc: except Exception as exc:
logger.info("Failed to write symlink cache %s: %s", cache_path, exc) logger.info("Failed to write symlink cache %s: %s", cache_path, exc)
@@ -458,7 +498,7 @@ class Config:
at the root level only (not nested symlinks in subdirectories). at the root level only (not nested symlinks in subdirectories).
""" """
start = time.perf_counter() start = time.perf_counter()
# Reset mappings before rescanning to avoid stale entries # Reset mappings before rescanning to avoid stale entries
self._path_mappings.clear() self._path_mappings.clear()
self._seed_root_symlink_mappings() self._seed_root_symlink_mappings()
@@ -472,7 +512,7 @@ class Config:
def _scan_first_level_symlinks(self, root: str): def _scan_first_level_symlinks(self, root: str):
"""Scan only the first level of a directory for symlinks. """Scan only the first level of a directory for symlinks.
This avoids traversing the entire directory tree which can be extremely This avoids traversing the entire directory tree which can be extremely
slow for large model collections. Only symlinks directly under the root slow for large model collections. Only symlinks directly under the root
are detected. are detected.
@@ -494,13 +534,13 @@ class Config:
self.add_path_mapping(entry.path, target_path) self.add_path_mapping(entry.path, target_path)
except Exception as inner_exc: except Exception as inner_exc:
logger.debug( logger.debug(
"Error processing directory entry %s: %s", entry.path, inner_exc "Error processing directory entry %s: %s",
entry.path,
inner_exc,
) )
except Exception as e: except Exception as e:
logger.error(f"Error scanning links in {root}: {e}") logger.error(f"Error scanning links in {root}: {e}")
def add_path_mapping(self, link_path: str, target_path: str): def add_path_mapping(self, link_path: str, target_path: str):
"""Add a symbolic link path mapping """Add a symbolic link path mapping
target_path: actual target path target_path: actual target path
@@ -594,41 +634,46 @@ class Config:
preview_roots.update(self._expand_preview_root(target)) preview_roots.update(self._expand_preview_root(target))
preview_roots.update(self._expand_preview_root(link)) preview_roots.update(self._expand_preview_root(link))
self._preview_root_paths = {path for path in preview_roots if path.is_absolute()} self._preview_root_paths = {
path for path in preview_roots if path.is_absolute()
}
logger.debug( logger.debug(
"Preview roots rebuilt: %d paths from %d lora roots (%d extra), %d checkpoint roots (%d extra), %d embedding roots (%d extra), %d symlink mappings", "Preview roots rebuilt: %d paths from %d lora roots (%d extra), %d checkpoint roots (%d extra), %d embedding roots (%d extra), %d symlink mappings",
len(self._preview_root_paths), len(self._preview_root_paths),
len(self.loras_roots or []), len(self.extra_loras_roots or []), len(self.loras_roots or []),
len(self.base_models_roots or []), len(self.extra_checkpoints_roots or []), len(self.extra_loras_roots or []),
len(self.embeddings_roots or []), len(self.extra_embeddings_roots or []), len(self.base_models_roots or []),
len(self.extra_checkpoints_roots or []),
len(self.embeddings_roots or []),
len(self.extra_embeddings_roots or []),
len(self._path_mappings), len(self._path_mappings),
) )
def map_path_to_link(self, path: str) -> str: def map_path_to_link(self, path: str) -> str:
"""Map a target path back to its symbolic link path""" """Map a target path back to its symbolic link path"""
normalized_path = os.path.normpath(path).replace(os.sep, '/') normalized_path = os.path.normpath(path).replace(os.sep, "/")
# Check if the path is contained in any mapped target path # Check if the path is contained in any mapped target path
for target_path, link_path in self._path_mappings.items(): for target_path, link_path in self._path_mappings.items():
# Match whole path components to avoid prefix collisions (e.g., /a/b vs /a/bc) # Match whole path components to avoid prefix collisions (e.g., /a/b vs /a/bc)
if normalized_path == target_path: if normalized_path == target_path:
return link_path return link_path
if normalized_path.startswith(target_path + '/'): if normalized_path.startswith(target_path + "/"):
# If the path starts with the target path, replace with link path # If the path starts with the target path, replace with link path
mapped_path = normalized_path.replace(target_path, link_path, 1) mapped_path = normalized_path.replace(target_path, link_path, 1)
return mapped_path return mapped_path
return normalized_path return normalized_path
def map_link_to_path(self, link_path: str) -> str: def map_link_to_path(self, link_path: str) -> str:
"""Map a symbolic link path back to the actual path""" """Map a symbolic link path back to the actual path"""
normalized_link = os.path.normpath(link_path).replace(os.sep, '/') normalized_link = os.path.normpath(link_path).replace(os.sep, "/")
# Check if the path is contained in any mapped target path # Check if the path is contained in any mapped target path
for target_path, link_path_mapped in self._path_mappings.items(): for target_path, link_path_mapped in self._path_mappings.items():
# Match whole path components # Match whole path components
if normalized_link == link_path_mapped: if normalized_link == link_path_mapped:
return target_path return target_path
if normalized_link.startswith(link_path_mapped + '/'): if normalized_link.startswith(link_path_mapped + "/"):
# If the path starts with the link path, replace with actual path # If the path starts with the link path, replace with actual path
mapped_path = normalized_link.replace(link_path_mapped, target_path, 1) mapped_path = normalized_link.replace(link_path_mapped, target_path, 1)
return mapped_path return mapped_path
@@ -641,8 +686,8 @@ class Config:
continue continue
if not os.path.exists(path): if not os.path.exists(path):
continue continue
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, "/")
normalized = os.path.normpath(path).replace(os.sep, '/') normalized = os.path.normpath(path).replace(os.sep, "/")
if real_path not in dedup: if real_path not in dedup:
dedup[real_path] = normalized dedup[real_path] = normalized
return dedup return dedup
@@ -652,7 +697,9 @@ class Config:
unique_paths = sorted(path_map.values(), key=lambda p: p.lower()) unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
for original_path in unique_paths: for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') real_path = os.path.normpath(os.path.realpath(original_path)).replace(
os.sep, "/"
)
if real_path != original_path: if real_path != original_path:
self.add_path_mapping(original_path, real_path) self.add_path_mapping(original_path, real_path)
@@ -674,7 +721,7 @@ class Config:
"Please fix your ComfyUI path configuration to separate these folders. " "Please fix your ComfyUI path configuration to separate these folders. "
"Falling back to 'checkpoints' for backward compatibility. " "Falling back to 'checkpoints' for backward compatibility. "
"Overlapping real paths: %s", "Overlapping real paths: %s",
[checkpoint_map.get(rp, rp) for rp in overlapping_real_paths] [checkpoint_map.get(rp, rp) for rp in overlapping_real_paths],
) )
# Remove overlapping paths from unet_map to prioritize checkpoints # Remove overlapping paths from unet_map to prioritize checkpoints
for rp in overlapping_real_paths: for rp in overlapping_real_paths:
@@ -694,7 +741,9 @@ class Config:
self.unet_roots = [p for p in unique_paths if p in unet_values] self.unet_roots = [p for p in unique_paths if p in unet_values]
for original_path in unique_paths: for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') real_path = os.path.normpath(os.path.realpath(original_path)).replace(
os.sep, "/"
)
if real_path != original_path: if real_path != original_path:
self.add_path_mapping(original_path, real_path) self.add_path_mapping(original_path, real_path)
@@ -705,7 +754,9 @@ class Config:
unique_paths = sorted(path_map.values(), key=lambda p: p.lower()) unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
for original_path in unique_paths: for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') real_path = os.path.normpath(os.path.realpath(original_path)).replace(
os.sep, "/"
)
if real_path != original_path: if real_path != original_path:
self.add_path_mapping(original_path, real_path) self.add_path_mapping(original_path, real_path)
@@ -719,42 +770,66 @@ class Config:
self._path_mappings.clear() self._path_mappings.clear()
self._preview_root_paths = set() self._preview_root_paths = set()
lora_paths = folder_paths.get('loras', []) or [] lora_paths = folder_paths.get("loras", []) or []
checkpoint_paths = folder_paths.get('checkpoints', []) or [] checkpoint_paths = folder_paths.get("checkpoints", []) or []
unet_paths = folder_paths.get('unet', []) or [] unet_paths = folder_paths.get("unet", []) or []
embedding_paths = folder_paths.get('embeddings', []) or [] embedding_paths = folder_paths.get("embeddings", []) or []
self.loras_roots = self._prepare_lora_paths(lora_paths) self.loras_roots = self._prepare_lora_paths(lora_paths)
self.base_models_roots = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths) self.base_models_roots = self._prepare_checkpoint_paths(
checkpoint_paths, unet_paths
)
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths) self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
# Process extra paths (only for LoRA Manager, not shared with ComfyUI) # Process extra paths (only for LoRA Manager, not shared with ComfyUI)
extra_paths = extra_folder_paths or {} extra_paths = extra_folder_paths or {}
extra_lora_paths = extra_paths.get('loras', []) or [] extra_lora_paths = extra_paths.get("loras", []) or []
extra_checkpoint_paths = extra_paths.get('checkpoints', []) or [] extra_checkpoint_paths = extra_paths.get("checkpoints", []) or []
extra_unet_paths = extra_paths.get('unet', []) or [] extra_unet_paths = extra_paths.get("unet", []) or []
extra_embedding_paths = extra_paths.get('embeddings', []) or [] extra_embedding_paths = extra_paths.get("embeddings", []) or []
self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths) self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths)
# Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them) # Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them)
saved_checkpoints_roots = self.checkpoints_roots saved_checkpoints_roots = self.checkpoints_roots
saved_unet_roots = self.unet_roots saved_unet_roots = self.unet_roots
self.extra_checkpoints_roots = self._prepare_checkpoint_paths(extra_checkpoint_paths, extra_unet_paths) self.extra_checkpoints_roots = self._prepare_checkpoint_paths(
self.extra_unet_roots = self.unet_roots # unet_roots was set by _prepare_checkpoint_paths extra_checkpoint_paths, extra_unet_paths
)
self.extra_unet_roots = (
self.unet_roots if self.unet_roots is not None else []
) # unet_roots was set by _prepare_checkpoint_paths
# Restore main paths # Restore main paths
self.checkpoints_roots = saved_checkpoints_roots self.checkpoints_roots = saved_checkpoints_roots
self.unet_roots = saved_unet_roots self.unet_roots = saved_unet_roots
self.extra_embeddings_roots = self._prepare_embedding_paths(extra_embedding_paths) self.extra_embeddings_roots = self._prepare_embedding_paths(
extra_embedding_paths
)
# Log extra folder paths # Log extra folder paths
if self.extra_loras_roots: if self.extra_loras_roots:
logger.info("Found extra LoRA roots:" + "\n - " + "\n - ".join(self.extra_loras_roots)) logger.info(
"Found extra LoRA roots:"
+ "\n - "
+ "\n - ".join(self.extra_loras_roots)
)
if self.extra_checkpoints_roots: if self.extra_checkpoints_roots:
logger.info("Found extra checkpoint roots:" + "\n - " + "\n - ".join(self.extra_checkpoints_roots)) logger.info(
"Found extra checkpoint roots:"
+ "\n - "
+ "\n - ".join(self.extra_checkpoints_roots)
)
if self.extra_unet_roots: if self.extra_unet_roots:
logger.info("Found extra diffusion model roots:" + "\n - " + "\n - ".join(self.extra_unet_roots)) logger.info(
"Found extra diffusion model roots:"
+ "\n - "
+ "\n - ".join(self.extra_unet_roots)
)
if self.extra_embeddings_roots: if self.extra_embeddings_roots:
logger.info("Found extra embedding roots:" + "\n - " + "\n - ".join(self.extra_embeddings_roots)) logger.info(
"Found extra embedding roots:"
+ "\n - "
+ "\n - ".join(self.extra_embeddings_roots)
)
self._initialize_symlink_mappings() self._initialize_symlink_mappings()
@@ -763,7 +838,10 @@ class Config:
try: try:
raw_paths = folder_paths.get_folder_paths("loras") raw_paths = folder_paths.get_folder_paths("loras")
unique_paths = self._prepare_lora_paths(raw_paths) unique_paths = self._prepare_lora_paths(raw_paths)
logger.info("Found LoRA roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")) logger.info(
"Found LoRA roots:"
+ ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")
)
if not unique_paths: if not unique_paths:
logger.warning("No valid loras folders found in ComfyUI configuration") logger.warning("No valid loras folders found in ComfyUI configuration")
@@ -779,12 +857,19 @@ class Config:
try: try:
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints") raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
raw_unet_paths = folder_paths.get_folder_paths("unet") raw_unet_paths = folder_paths.get_folder_paths("unet")
unique_paths = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths) unique_paths = self._prepare_checkpoint_paths(
raw_checkpoint_paths, raw_unet_paths
)
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")) logger.info(
"Found checkpoint roots:"
+ ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")
)
if not unique_paths: if not unique_paths:
logger.warning("No valid checkpoint folders found in ComfyUI configuration") logger.warning(
"No valid checkpoint folders found in ComfyUI configuration"
)
return [] return []
return unique_paths return unique_paths
@@ -797,10 +882,15 @@ class Config:
try: try:
raw_paths = folder_paths.get_folder_paths("embeddings") raw_paths = folder_paths.get_folder_paths("embeddings")
unique_paths = self._prepare_embedding_paths(raw_paths) unique_paths = self._prepare_embedding_paths(raw_paths)
logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")) logger.info(
"Found embedding roots:"
+ ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")
)
if not unique_paths: if not unique_paths:
logger.warning("No valid embeddings folders found in ComfyUI configuration") logger.warning(
"No valid embeddings folders found in ComfyUI configuration"
)
return [] return []
return unique_paths return unique_paths
@@ -812,13 +902,13 @@ class Config:
if not preview_path: if not preview_path:
return "" return ""
normalized = os.path.normpath(preview_path).replace(os.sep, '/') normalized = os.path.normpath(preview_path).replace(os.sep, "/")
encoded_path = urllib.parse.quote(normalized, safe='') encoded_path = urllib.parse.quote(normalized, safe="")
return f'/api/lm/previews?path={encoded_path}' return f"/api/lm/previews?path={encoded_path}"
def is_preview_path_allowed(self, preview_path: str) -> bool: def is_preview_path_allowed(self, preview_path: str) -> bool:
"""Return ``True`` if ``preview_path`` is within an allowed directory. """Return ``True`` if ``preview_path`` is within an allowed directory.
If the path is initially rejected, attempts to discover deep symlinks If the path is initially rejected, attempts to discover deep symlinks
that were not scanned during initialization. If a symlink is found, that were not scanned during initialization. If a symlink is found,
updates the in-memory path mappings and retries the check. updates the in-memory path mappings and retries the check.
@@ -889,14 +979,18 @@ class Config:
normalized_link = self._normalize_path(str(current)) normalized_link = self._normalize_path(str(current))
self._path_mappings[normalized_target] = normalized_link self._path_mappings[normalized_target] = normalized_link
self._preview_root_paths.update(self._expand_preview_root(normalized_target)) self._preview_root_paths.update(
self._preview_root_paths.update(self._expand_preview_root(normalized_link)) self._expand_preview_root(normalized_target)
)
self._preview_root_paths.update(
self._expand_preview_root(normalized_link)
)
logger.debug( logger.debug(
"Discovered deep symlink: %s -> %s (preview path: %s)", "Discovered deep symlink: %s -> %s (preview path: %s)",
normalized_link, normalized_link,
normalized_target, normalized_target,
preview_path preview_path,
) )
return True return True
@@ -914,8 +1008,16 @@ class Config:
def apply_library_settings(self, library_config: Mapping[str, object]) -> None: def apply_library_settings(self, library_config: Mapping[str, object]) -> None:
"""Update runtime paths to match the provided library configuration.""" """Update runtime paths to match the provided library configuration."""
folder_paths = library_config.get('folder_paths') if isinstance(library_config, Mapping) else {} folder_paths = (
extra_folder_paths = library_config.get('extra_folder_paths') if isinstance(library_config, Mapping) else None library_config.get("folder_paths")
if isinstance(library_config, Mapping)
else {}
)
extra_folder_paths = (
library_config.get("extra_folder_paths")
if isinstance(library_config, Mapping)
else None
)
if not isinstance(folder_paths, Mapping): if not isinstance(folder_paths, Mapping):
folder_paths = {} folder_paths = {}
if not isinstance(extra_folder_paths, Mapping): if not isinstance(extra_folder_paths, Mapping):
@@ -925,9 +1027,12 @@ class Config:
logger.info( logger.info(
"Applied library settings with %d lora roots (%d extra), %d checkpoint roots (%d extra), and %d embedding roots (%d extra)", "Applied library settings with %d lora roots (%d extra), %d checkpoint roots (%d extra), and %d embedding roots (%d extra)",
len(self.loras_roots or []), len(self.extra_loras_roots or []), len(self.loras_roots or []),
len(self.base_models_roots or []), len(self.extra_checkpoints_roots or []), len(self.extra_loras_roots or []),
len(self.embeddings_roots or []), len(self.extra_embeddings_roots or []), len(self.base_models_roots or []),
len(self.extra_checkpoints_roots or []),
len(self.embeddings_roots or []),
len(self.extra_embeddings_roots or []),
) )
def get_library_registry_snapshot(self) -> Dict[str, object]: def get_library_registry_snapshot(self) -> Dict[str, object]:
@@ -947,5 +1052,6 @@ class Config:
logger.debug("Failed to collect library registry snapshot: %s", exc) logger.debug("Failed to collect library registry snapshot: %s", exc)
return {"active_library": "", "libraries": {}} return {"active_library": "", "libraries": {}}
# Global config instance # Global config instance
config = Config() config = Config()

View File

@@ -5,16 +5,22 @@ import logging
from .utils.logging_config import setup_logging from .utils.logging_config import setup_logging
# Check if we're in standalone mode # Check if we're in standalone mode
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" standalone_mode = (
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
)
# Only setup logging prefix if not in standalone mode # Only setup logging prefix if not in standalone mode
if not standalone_mode: if not standalone_mode:
setup_logging() setup_logging()
from server import PromptServer # type: ignore from server import PromptServer # type: ignore
from .config import config from .config import config
from .services.model_service_factory import ModelServiceFactory, register_default_model_types from .services.model_service_factory import (
ModelServiceFactory,
register_default_model_types,
)
from .routes.recipe_routes import RecipeRoutes from .routes.recipe_routes import RecipeRoutes
from .routes.stats_routes import StatsRoutes from .routes.stats_routes import StatsRoutes
from .routes.update_routes import UpdateRoutes from .routes.update_routes import UpdateRoutes
@@ -61,9 +67,10 @@ class _SettingsProxy:
settings = _SettingsProxy() settings = _SettingsProxy()
class LoraManager: class LoraManager:
"""Main entry point for LoRA Manager plugin""" """Main entry point for LoRA Manager plugin"""
@classmethod @classmethod
def add_routes(cls): def add_routes(cls):
"""Initialize and register all routes using the new refactored architecture""" """Initialize and register all routes using the new refactored architecture"""
@@ -76,7 +83,8 @@ class LoraManager:
( (
idx idx
for idx, middleware in enumerate(app.middlewares) for idx, middleware in enumerate(app.middlewares)
if getattr(middleware, "__name__", "") == "block_external_middleware" if getattr(middleware, "__name__", "")
== "block_external_middleware"
), ),
None, None,
) )
@@ -84,7 +92,9 @@ class LoraManager:
if block_middleware_index is None: if block_middleware_index is None:
app.middlewares.append(relax_csp_for_remote_media) app.middlewares.append(relax_csp_for_remote_media)
else: else:
app.middlewares.insert(block_middleware_index, relax_csp_for_remote_media) app.middlewares.insert(
block_middleware_index, relax_csp_for_remote_media
)
# Increase allowed header sizes so browsers with large localhost cookie # Increase allowed header sizes so browsers with large localhost cookie
# jars (multiple UIs on 127.0.0.1) don't trip aiohttp's 8KB default # jars (multiple UIs on 127.0.0.1) don't trip aiohttp's 8KB default
@@ -105,7 +115,7 @@ class LoraManager:
app._handler_args = updated_handler_args app._handler_args = updated_handler_args
# Configure aiohttp access logger to be less verbose # Configure aiohttp access logger to be less verbose
logging.getLogger('aiohttp.access').setLevel(logging.WARNING) logging.getLogger("aiohttp.access").setLevel(logging.WARNING)
# Add specific suppression for connection reset errors # Add specific suppression for connection reset errors
class ConnectionResetFilter(logging.Filter): class ConnectionResetFilter(logging.Filter):
@@ -124,46 +134,52 @@ class LoraManager:
asyncio_logger.addFilter(ConnectionResetFilter()) asyncio_logger.addFilter(ConnectionResetFilter())
# Add static route for example images if the path exists in settings # Add static route for example images if the path exists in settings
example_images_path = settings.get('example_images_path') example_images_path = settings.get("example_images_path")
logger.info(f"Example images path: {example_images_path}") logger.info(f"Example images path: {example_images_path}")
if example_images_path and os.path.exists(example_images_path): if example_images_path and os.path.exists(example_images_path):
app.router.add_static('/example_images_static', example_images_path) app.router.add_static("/example_images_static", example_images_path)
logger.info(f"Added static route for example images: /example_images_static -> {example_images_path}") logger.info(
f"Added static route for example images: /example_images_static -> {example_images_path}"
)
# Add static route for locales JSON files # Add static route for locales JSON files
if os.path.exists(config.i18n_path): if os.path.exists(config.i18n_path):
app.router.add_static('/locales', config.i18n_path) app.router.add_static("/locales", config.i18n_path)
logger.info(f"Added static route for locales: /locales -> {config.i18n_path}") logger.info(
f"Added static route for locales: /locales -> {config.i18n_path}"
)
# Add static route for plugin assets # Add static route for plugin assets
app.router.add_static('/loras_static', config.static_path) app.router.add_static("/loras_static", config.static_path)
# Register default model types with the factory # Register default model types with the factory
register_default_model_types() register_default_model_types()
# Setup all model routes using the factory # Setup all model routes using the factory
ModelServiceFactory.setup_all_routes(app) ModelServiceFactory.setup_all_routes(app)
# Setup non-model-specific routes # Setup non-model-specific routes
stats_routes = StatsRoutes() stats_routes = StatsRoutes()
stats_routes.setup_routes(app) stats_routes.setup_routes(app)
RecipeRoutes.setup_routes(app) RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app) UpdateRoutes.setup_routes(app)
MiscRoutes.setup_routes(app) MiscRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager) ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
PreviewRoutes.setup_routes(app) PreviewRoutes.setup_routes(app)
# Setup WebSocket routes that are shared across all model types # Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) app.router.add_get("/ws/fetch-progress", ws_manager.handle_connection)
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) app.router.add_get(
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) "/ws/download-progress", ws_manager.handle_download_connection
)
# Schedule service initialization app.router.add_get("/ws/init-progress", ws_manager.handle_init_connection)
# Schedule service initialization
app.on_startup.append(lambda app: cls._initialize_services()) app.on_startup.append(lambda app: cls._initialize_services())
# Add cleanup # Add cleanup
app.on_shutdown.append(cls._cleanup) app.on_shutdown.append(cls._cleanup)
@classmethod @classmethod
async def _initialize_services(cls): async def _initialize_services(cls):
"""Initialize all services using the ServiceRegistry""" """Initialize all services using the ServiceRegistry"""
@@ -197,7 +213,9 @@ class LoraManager:
extra_paths.get("embeddings", []), extra_paths.get("embeddings", []),
) )
except Exception as exc: except Exception as exc:
logger.warning("Failed to apply library settings during initialization: %s", exc) logger.warning(
"Failed to apply library settings during initialization: %s", exc
)
# Initialize CivitaiClient first to ensure it's ready for other services # Initialize CivitaiClient first to ensure it's ready for other services
await ServiceRegistry.get_civitai_client() await ServiceRegistry.get_civitai_client()
@@ -206,163 +224,200 @@ class LoraManager:
await ServiceRegistry.get_download_manager() await ServiceRegistry.get_download_manager()
from .services.metadata_service import initialize_metadata_providers from .services.metadata_service import initialize_metadata_providers
await initialize_metadata_providers() await initialize_metadata_providers()
# Initialize WebSocket manager # Initialize WebSocket manager
await ServiceRegistry.get_websocket_manager() await ServiceRegistry.get_websocket_manager()
# Initialize scanners in background # Initialize scanners in background
lora_scanner = await ServiceRegistry.get_lora_scanner() lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner() embedding_scanner = await ServiceRegistry.get_embedding_scanner()
# Initialize recipe scanner if needed # Initialize recipe scanner if needed
recipe_scanner = await ServiceRegistry.get_recipe_scanner() recipe_scanner = await ServiceRegistry.get_recipe_scanner()
# Create low-priority initialization tasks # Create low-priority initialization tasks
init_tasks = [ init_tasks = [
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init'), asyncio.create_task(
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init'), lora_scanner.initialize_in_background(), name="lora_cache_init"
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init'), ),
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init') asyncio.create_task(
checkpoint_scanner.initialize_in_background(),
name="checkpoint_cache_init",
),
asyncio.create_task(
embedding_scanner.initialize_in_background(),
name="embedding_cache_init",
),
asyncio.create_task(
recipe_scanner.initialize_in_background(), name="recipe_cache_init"
),
] ]
await ExampleImagesMigration.check_and_run_migrations() await ExampleImagesMigration.check_and_run_migrations()
# Schedule post-initialization tasks to run after scanners complete # Schedule post-initialization tasks to run after scanners complete
asyncio.create_task( asyncio.create_task(
cls._run_post_initialization_tasks(init_tasks), cls._run_post_initialization_tasks(init_tasks), name="post_init_tasks"
name='post_init_tasks'
) )
logger.debug("LoRA Manager: All services initialized and background tasks scheduled") logger.debug(
"LoRA Manager: All services initialized and background tasks scheduled"
)
except Exception as e: except Exception as e:
logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True) logger.error(
f"LoRA Manager: Error initializing services: {e}", exc_info=True
)
@classmethod @classmethod
async def _run_post_initialization_tasks(cls, init_tasks): async def _run_post_initialization_tasks(cls, init_tasks):
"""Run post-initialization tasks after all scanners complete""" """Run post-initialization tasks after all scanners complete"""
try: try:
logger.debug("LoRA Manager: Waiting for scanner initialization to complete...") logger.debug(
"LoRA Manager: Waiting for scanner initialization to complete..."
)
# Wait for all scanner initialization tasks to complete # Wait for all scanner initialization tasks to complete
await asyncio.gather(*init_tasks, return_exceptions=True) await asyncio.gather(*init_tasks, return_exceptions=True)
logger.debug("LoRA Manager: Scanner initialization completed, starting post-initialization tasks...") logger.debug(
"LoRA Manager: Scanner initialization completed, starting post-initialization tasks..."
)
# Run post-initialization tasks # Run post-initialization tasks
post_tasks = [ post_tasks = [
asyncio.create_task(cls._cleanup_backup_files(), name='cleanup_bak_files'), asyncio.create_task(
cls._cleanup_backup_files(), name="cleanup_bak_files"
),
# Add more post-initialization tasks here as needed # Add more post-initialization tasks here as needed
# asyncio.create_task(cls._another_post_task(), name='another_task'), # asyncio.create_task(cls._another_post_task(), name='another_task'),
] ]
# Run all post-initialization tasks # Run all post-initialization tasks
results = await asyncio.gather(*post_tasks, return_exceptions=True) results = await asyncio.gather(*post_tasks, return_exceptions=True)
# Log results # Log results
for i, result in enumerate(results): for i, result in enumerate(results):
task_name = post_tasks[i].get_name() task_name = post_tasks[i].get_name()
if isinstance(result, Exception): if isinstance(result, Exception):
logger.error(f"Post-initialization task '{task_name}' failed: {result}") logger.error(
f"Post-initialization task '{task_name}' failed: {result}"
)
else: else:
logger.debug(f"Post-initialization task '{task_name}' completed successfully") logger.debug(
f"Post-initialization task '{task_name}' completed successfully"
)
logger.debug("LoRA Manager: All post-initialization tasks completed") logger.debug("LoRA Manager: All post-initialization tasks completed")
except Exception as e: except Exception as e:
logger.error(f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True) logger.error(
f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True
)
@classmethod @classmethod
async def _cleanup_backup_files(cls): async def _cleanup_backup_files(cls):
"""Clean up .bak files in all model roots""" """Clean up .bak files in all model roots"""
try: try:
logger.debug("Starting cleanup of .bak files in model directories...") logger.debug("Starting cleanup of .bak files in model directories...")
# Collect all model roots # Collect all model roots
all_roots = set() all_roots = set()
all_roots.update(config.loras_roots) all_roots.update(config.loras_roots)
all_roots.update(config.base_models_roots) all_roots.update(config.base_models_roots or [])
all_roots.update(config.embeddings_roots) all_roots.update(config.embeddings_roots or [])
total_deleted = 0 total_deleted = 0
total_size_freed = 0 total_size_freed = 0
for root_path in all_roots: for root_path in all_roots:
if not os.path.exists(root_path): if not os.path.exists(root_path):
continue continue
try: try:
deleted_count, size_freed = await cls._cleanup_backup_files_in_directory(root_path) (
deleted_count,
size_freed,
) = await cls._cleanup_backup_files_in_directory(root_path)
total_deleted += deleted_count total_deleted += deleted_count
total_size_freed += size_freed total_size_freed += size_freed
if deleted_count > 0: if deleted_count > 0:
logger.debug(f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024*1024):.2f} MB)") logger.debug(
f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024 * 1024):.2f} MB)"
)
except Exception as e: except Exception as e:
logger.error(f"Error cleaning up .bak files in {root_path}: {e}") logger.error(f"Error cleaning up .bak files in {root_path}: {e}")
# Yield control periodically # Yield control periodically
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
if total_deleted > 0: if total_deleted > 0:
logger.debug(f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024*1024):.2f} MB total") logger.debug(
f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024 * 1024):.2f} MB total"
)
else: else:
logger.debug("Backup cleanup completed: no .bak files found") logger.debug("Backup cleanup completed: no .bak files found")
except Exception as e: except Exception as e:
logger.error(f"Error during backup file cleanup: {e}", exc_info=True) logger.error(f"Error during backup file cleanup: {e}", exc_info=True)
@classmethod @classmethod
async def _cleanup_backup_files_in_directory(cls, directory_path: str): async def _cleanup_backup_files_in_directory(cls, directory_path: str):
"""Clean up .bak files in a specific directory recursively """Clean up .bak files in a specific directory recursively
Args: Args:
directory_path: Path to the directory to clean directory_path: Path to the directory to clean
Returns: Returns:
Tuple[int, int]: (number of files deleted, total size freed in bytes) Tuple[int, int]: (number of files deleted, total size freed in bytes)
""" """
deleted_count = 0 deleted_count = 0
size_freed = 0 size_freed = 0
visited_paths = set() visited_paths = set()
def cleanup_recursive(path): def cleanup_recursive(path):
nonlocal deleted_count, size_freed nonlocal deleted_count, size_freed
try: try:
real_path = os.path.realpath(path) real_path = os.path.realpath(path)
if real_path in visited_paths: if real_path in visited_paths:
return return
visited_paths.add(real_path) visited_paths.add(real_path)
with os.scandir(path) as it: with os.scandir(path) as it:
for entry in it: for entry in it:
try: try:
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.bak'): if entry.is_file(
follow_symlinks=True
) and entry.name.endswith(".bak"):
file_size = entry.stat().st_size file_size = entry.stat().st_size
os.remove(entry.path) os.remove(entry.path)
deleted_count += 1 deleted_count += 1
size_freed += file_size size_freed += file_size
logger.debug(f"Deleted .bak file: {entry.path}") logger.debug(f"Deleted .bak file: {entry.path}")
elif entry.is_dir(follow_symlinks=True): elif entry.is_dir(follow_symlinks=True):
cleanup_recursive(entry.path) cleanup_recursive(entry.path)
except Exception as e: except Exception as e:
logger.warning(f"Could not delete .bak file {entry.path}: {e}") logger.warning(
f"Could not delete .bak file {entry.path}: {e}"
)
except Exception as e: except Exception as e:
logger.error(f"Error scanning directory {path} for .bak files: {e}") logger.error(f"Error scanning directory {path} for .bak files: {e}")
# Run the recursive cleanup in a thread pool to avoid blocking # Run the recursive cleanup in a thread pool to avoid blocking
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor(None, cleanup_recursive, directory_path) await loop.run_in_executor(None, cleanup_recursive, directory_path)
return deleted_count, size_freed return deleted_count, size_freed
@classmethod @classmethod
async def _cleanup_example_images_folders(cls): async def _cleanup_example_images_folders(cls):
"""Invoke the example images cleanup service for manual execution.""" """Invoke the example images cleanup service for manual execution."""
@@ -370,21 +425,21 @@ class LoraManager:
service = ExampleImagesCleanupService() service = ExampleImagesCleanupService()
result = await service.cleanup_example_image_folders() result = await service.cleanup_example_image_folders()
if result.get('success'): if result.get("success"):
logger.debug( logger.debug(
"Manual example images cleanup completed: moved=%s", "Manual example images cleanup completed: moved=%s",
result.get('moved_total'), result.get("moved_total"),
) )
elif result.get('partial_success'): elif result.get("partial_success"):
logger.warning( logger.warning(
"Manual example images cleanup partially succeeded: moved=%s failures=%s", "Manual example images cleanup partially succeeded: moved=%s failures=%s",
result.get('moved_total'), result.get("moved_total"),
result.get('move_failures'), result.get("move_failures"),
) )
else: else:
logger.debug( logger.debug(
"Manual example images cleanup skipped or failed: %s", "Manual example images cleanup skipped or failed: %s",
result.get('error', 'no changes'), result.get("error", "no changes"),
) )
return result return result
@@ -392,9 +447,9 @@ class LoraManager:
except Exception as e: # pragma: no cover - defensive guard except Exception as e: # pragma: no cover - defensive guard
logger.error(f"Error during example images cleanup: {e}", exc_info=True) logger.error(f"Error during example images cleanup: {e}", exc_info=True)
return { return {
'success': False, "success": False,
'error': str(e), "error": str(e),
'error_code': 'unexpected_error', "error_code": "unexpected_error",
} }
@classmethod @classmethod
@@ -402,6 +457,6 @@ class LoraManager:
"""Cleanup resources using ServiceRegistry""" """Cleanup resources using ServiceRegistry"""
try: try:
logger.info("LoRA Manager: Cleaning up services") logger.info("LoRA Manager: Cleaning up services")
except Exception as e: except Exception as e:
logger.error(f"Error during cleanup: {e}", exc_info=True) logger.error(f"Error during cleanup: {e}", exc_info=True)

View File

@@ -4,7 +4,10 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Check if running in standalone mode # Check if running in standalone mode
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" standalone_mode = (
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
)
if not standalone_mode: if not standalone_mode:
from .metadata_hook import MetadataHook from .metadata_hook import MetadataHook
@@ -13,13 +16,13 @@ if not standalone_mode:
def init(): def init():
# Install hooks to collect metadata during execution # Install hooks to collect metadata during execution
MetadataHook.install() MetadataHook.install()
# Initialize registry # Initialize registry
registry = MetadataRegistry() registry = MetadataRegistry()
logger.info("ComfyUI Metadata Collector initialized") logger.info("ComfyUI Metadata Collector initialized")
def get_metadata(prompt_id=None): def get_metadata(prompt_id=None): # type: ignore[no-redef]
"""Helper function to get metadata from the registry""" """Helper function to get metadata from the registry"""
registry = MetadataRegistry() registry = MetadataRegistry()
return registry.get_metadata(prompt_id) return registry.get_metadata(prompt_id)
@@ -27,7 +30,7 @@ else:
# Standalone mode - provide dummy implementations # Standalone mode - provide dummy implementations
def init(): def init():
logger.info("ComfyUI Metadata Collector disabled in standalone mode") logger.info("ComfyUI Metadata Collector disabled in standalone mode")
def get_metadata(prompt_id=None): def get_metadata(prompt_id=None): # type: ignore[no-redef]
"""Dummy implementation for standalone mode""" """Dummy implementation for standalone mode"""
return {} return {}

View File

@@ -1,50 +1,54 @@
import time import time
from nodes import NODE_CLASS_MAPPINGS from nodes import NODE_CLASS_MAPPINGS # type: ignore
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
from .constants import METADATA_CATEGORIES, IMAGES from .constants import METADATA_CATEGORIES, IMAGES
class MetadataRegistry: class MetadataRegistry:
"""A singleton registry to store and retrieve workflow metadata""" """A singleton registry to store and retrieve workflow metadata"""
_instance = None _instance = None
def __new__(cls): def __new__(cls):
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance._reset() cls._instance._reset()
return cls._instance return cls._instance
def _reset(self): def _reset(self):
self.current_prompt_id = None self.current_prompt_id = None
self.current_prompt = None self.current_prompt = None
self.metadata = {} self.metadata = {}
self.prompt_metadata = {} self.prompt_metadata = {}
self.executed_nodes = set() self.executed_nodes = set()
# Node-level cache for metadata # Node-level cache for metadata
self.node_cache = {} self.node_cache = {}
# Limit the number of stored prompts # Limit the number of stored prompts
self.max_prompt_history = 3 self.max_prompt_history = 3
# Categories we want to track and retrieve from cache # Categories we want to track and retrieve from cache
self.metadata_categories = METADATA_CATEGORIES self.metadata_categories = METADATA_CATEGORIES
def _clean_old_prompts(self): def _clean_old_prompts(self):
"""Clean up old prompt metadata, keeping only recent ones""" """Clean up old prompt metadata, keeping only recent ones"""
if len(self.prompt_metadata) <= self.max_prompt_history: if len(self.prompt_metadata) <= self.max_prompt_history:
return return
# Sort all prompt_ids by timestamp # Sort all prompt_ids by timestamp
sorted_prompts = sorted( sorted_prompts = sorted(
self.prompt_metadata.keys(), self.prompt_metadata.keys(),
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0) key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0),
) )
# Remove oldest records # Remove oldest records
prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history] prompts_to_remove = sorted_prompts[
: len(sorted_prompts) - self.max_prompt_history
]
for pid in prompts_to_remove: for pid in prompts_to_remove:
del self.prompt_metadata[pid] del self.prompt_metadata[pid]
def start_collection(self, prompt_id): def start_collection(self, prompt_id):
"""Begin metadata collection for a new prompt""" """Begin metadata collection for a new prompt"""
self.current_prompt_id = prompt_id self.current_prompt_id = prompt_id
@@ -53,90 +57,96 @@ class MetadataRegistry:
category: {} for category in METADATA_CATEGORIES category: {} for category in METADATA_CATEGORIES
} }
# Add additional metadata fields # Add additional metadata fields
self.prompt_metadata[prompt_id].update({ self.prompt_metadata[prompt_id].update(
"execution_order": [], {
"current_prompt": None, # Will store the prompt object "execution_order": [],
"timestamp": time.time() "current_prompt": None, # Will store the prompt object
}) "timestamp": time.time(),
}
)
# Clean up old prompt data # Clean up old prompt data
self._clean_old_prompts() self._clean_old_prompts()
def set_current_prompt(self, prompt): def set_current_prompt(self, prompt):
"""Set the current prompt object reference""" """Set the current prompt object reference"""
self.current_prompt = prompt self.current_prompt = prompt
if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata: if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata:
# Store the prompt in the metadata for later relationship tracing # Store the prompt in the metadata for later relationship tracing
self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt
def get_metadata(self, prompt_id=None): def get_metadata(self, prompt_id=None):
"""Get collected metadata for a prompt""" """Get collected metadata for a prompt"""
key = prompt_id if prompt_id is not None else self.current_prompt_id key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata: if key not in self.prompt_metadata:
return {} return {}
metadata = self.prompt_metadata[key] metadata = self.prompt_metadata[key]
# If we have a current prompt object, check for non-executed nodes # If we have a current prompt object, check for non-executed nodes
prompt_obj = metadata.get("current_prompt") prompt_obj = metadata.get("current_prompt")
if prompt_obj and hasattr(prompt_obj, "original_prompt"): if prompt_obj and hasattr(prompt_obj, "original_prompt"):
original_prompt = prompt_obj.original_prompt original_prompt = prompt_obj.original_prompt
# Fill in missing metadata from cache for nodes that weren't executed # Fill in missing metadata from cache for nodes that weren't executed
self._fill_missing_metadata(key, original_prompt) self._fill_missing_metadata(key, original_prompt)
return self.prompt_metadata.get(key, {}) return self.prompt_metadata.get(key, {})
def _fill_missing_metadata(self, prompt_id, original_prompt): def _fill_missing_metadata(self, prompt_id, original_prompt):
"""Fill missing metadata from cache for non-executed nodes""" """Fill missing metadata from cache for non-executed nodes"""
if not original_prompt: if not original_prompt:
return return
executed_nodes = self.executed_nodes executed_nodes = self.executed_nodes
metadata = self.prompt_metadata[prompt_id] metadata = self.prompt_metadata[prompt_id]
# Iterate through nodes in the original prompt # Iterate through nodes in the original prompt
for node_id, node_data in original_prompt.items(): for node_id, node_data in original_prompt.items():
# Skip if already executed in this run # Skip if already executed in this run
if node_id in executed_nodes: if node_id in executed_nodes:
continue continue
# Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS) # Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS)
prompt_class_type = node_data.get("class_type") prompt_class_type = node_data.get("class_type")
if not prompt_class_type: if not prompt_class_type:
continue continue
# Convert to actual class name (which is what we use in our cache) # Convert to actual class name (which is what we use in our cache)
class_type = prompt_class_type class_type = prompt_class_type
if prompt_class_type in NODE_CLASS_MAPPINGS: if prompt_class_type in NODE_CLASS_MAPPINGS:
class_obj = NODE_CLASS_MAPPINGS[prompt_class_type] class_obj = NODE_CLASS_MAPPINGS[prompt_class_type]
class_type = class_obj.__name__ class_type = class_obj.__name__
# Create cache key using the actual class name # Create cache key using the actual class name
cache_key = f"{node_id}:{class_type}" cache_key = f"{node_id}:{class_type}"
# Check if this node type is relevant for metadata collection # Check if this node type is relevant for metadata collection
if class_type in NODE_EXTRACTORS: if class_type in NODE_EXTRACTORS:
# Check if we have cached metadata for this node # Check if we have cached metadata for this node
if cache_key in self.node_cache: if cache_key in self.node_cache:
cached_data = self.node_cache[cache_key] cached_data = self.node_cache[cache_key]
# Apply cached metadata to the current metadata # Apply cached metadata to the current metadata
for category in self.metadata_categories: for category in self.metadata_categories:
if category in cached_data and node_id in cached_data[category]: if category in cached_data and node_id in cached_data[category]:
if node_id not in metadata[category]: if node_id not in metadata[category]:
metadata[category][node_id] = cached_data[category][node_id] metadata[category][node_id] = cached_data[category][
node_id
]
def record_node_execution(self, node_id, class_type, inputs, outputs): def record_node_execution(self, node_id, class_type, inputs, outputs):
"""Record information about a node's execution""" """Record information about a node's execution"""
if not self.current_prompt_id: if not self.current_prompt_id:
return return
# Add to execution order and mark as executed # Add to execution order and mark as executed
if node_id not in self.executed_nodes: if node_id not in self.executed_nodes:
self.executed_nodes.add(node_id) self.executed_nodes.add(node_id)
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id) self.prompt_metadata[self.current_prompt_id]["execution_order"].append(
node_id
)
# Process inputs to simplify working with them # Process inputs to simplify working with them
processed_inputs = {} processed_inputs = {}
for input_name, input_values in inputs.items(): for input_name, input_values in inputs.items():
@@ -145,63 +155,61 @@ class MetadataRegistry:
processed_inputs[input_name] = input_values[0] processed_inputs[input_name] = input_values[0]
else: else:
processed_inputs[input_name] = input_values processed_inputs[input_name] = input_values
# Extract node-specific metadata # Extract node-specific metadata
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor) extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
extractor.extract( extractor.extract(
node_id, node_id,
processed_inputs, processed_inputs,
outputs, outputs,
self.prompt_metadata[self.current_prompt_id] self.prompt_metadata[self.current_prompt_id],
) )
# Cache this node's metadata # Cache this node's metadata
self._cache_node_metadata(node_id, class_type) self._cache_node_metadata(node_id, class_type)
def update_node_execution(self, node_id, class_type, outputs): def update_node_execution(self, node_id, class_type, outputs):
"""Update node metadata with output information""" """Update node metadata with output information"""
if not self.current_prompt_id: if not self.current_prompt_id:
return return
# Process outputs to make them more usable # Process outputs to make them more usable
processed_outputs = outputs processed_outputs = outputs
# Use the same extractor to update with outputs # Use the same extractor to update with outputs
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor) extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
if hasattr(extractor, 'update'): if hasattr(extractor, "update"):
extractor.update( extractor.update(
node_id, node_id, processed_outputs, self.prompt_metadata[self.current_prompt_id]
processed_outputs,
self.prompt_metadata[self.current_prompt_id]
) )
# Update the cached metadata for this node # Update the cached metadata for this node
self._cache_node_metadata(node_id, class_type) self._cache_node_metadata(node_id, class_type)
def _cache_node_metadata(self, node_id, class_type): def _cache_node_metadata(self, node_id, class_type):
"""Cache the metadata for a specific node""" """Cache the metadata for a specific node"""
if not self.current_prompt_id or not node_id or not class_type: if not self.current_prompt_id or not node_id or not class_type:
return return
# Create a cache key combining node_id and class_type # Create a cache key combining node_id and class_type
cache_key = f"{node_id}:{class_type}" cache_key = f"{node_id}:{class_type}"
# Create a shallow copy of the node's metadata # Create a shallow copy of the node's metadata
node_metadata = {} node_metadata = {}
current_metadata = self.prompt_metadata[self.current_prompt_id] current_metadata = self.prompt_metadata[self.current_prompt_id]
for category in self.metadata_categories: for category in self.metadata_categories:
if category in current_metadata and node_id in current_metadata[category]: if category in current_metadata and node_id in current_metadata[category]:
if category not in node_metadata: if category not in node_metadata:
node_metadata[category] = {} node_metadata[category] = {}
node_metadata[category][node_id] = current_metadata[category][node_id] node_metadata[category][node_id] = current_metadata[category][node_id]
# Save new metadata or clear stale cache entries when metadata is empty # Save new metadata or clear stale cache entries when metadata is empty
if any(node_metadata.values()): if any(node_metadata.values()):
self.node_cache[cache_key] = node_metadata self.node_cache[cache_key] = node_metadata
else: else:
self.node_cache.pop(cache_key, None) self.node_cache.pop(cache_key, None)
def clear_unused_cache(self): def clear_unused_cache(self):
"""Clean up node_cache entries that are no longer in use""" """Clean up node_cache entries that are no longer in use"""
# Collect all node_ids currently in prompt_metadata # Collect all node_ids currently in prompt_metadata
@@ -210,18 +218,18 @@ class MetadataRegistry:
for category in self.metadata_categories: for category in self.metadata_categories:
if category in prompt_data: if category in prompt_data:
active_node_ids.update(prompt_data[category].keys()) active_node_ids.update(prompt_data[category].keys())
# Find cache keys that are no longer needed # Find cache keys that are no longer needed
keys_to_remove = [] keys_to_remove = []
for cache_key in self.node_cache: for cache_key in self.node_cache:
node_id = cache_key.split(':')[0] node_id = cache_key.split(":")[0]
if node_id not in active_node_ids: if node_id not in active_node_ids:
keys_to_remove.append(cache_key) keys_to_remove.append(cache_key)
# Remove cache entries that are no longer needed # Remove cache entries that are no longer needed
for key in keys_to_remove: for key in keys_to_remove:
del self.node_cache[key] del self.node_cache[key]
def clear_metadata(self, prompt_id=None): def clear_metadata(self, prompt_id=None):
"""Clear metadata for a specific prompt or reset all data""" """Clear metadata for a specific prompt or reset all data"""
if prompt_id is not None: if prompt_id is not None:
@@ -232,25 +240,25 @@ class MetadataRegistry:
else: else:
# Reset all data # Reset all data
self._reset() self._reset()
def get_first_decoded_image(self, prompt_id=None): def get_first_decoded_image(self, prompt_id=None):
"""Get the first decoded image result""" """Get the first decoded image result"""
key = prompt_id if prompt_id is not None else self.current_prompt_id key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata: if key not in self.prompt_metadata:
return None return None
metadata = self.prompt_metadata[key] metadata = self.prompt_metadata[key]
if IMAGES in metadata and "first_decode" in metadata[IMAGES]: if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
image_data = metadata[IMAGES]["first_decode"]["image"] image_data = metadata[IMAGES]["first_decode"]["image"]
# If it's an image batch or tuple, handle various formats # If it's an image batch or tuple, handle various formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0: if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
# Return first element of list/tuple # Return first element of list/tuple
return image_data[0] return image_data[0]
# If it's a tensor, return as is for processing in the route handler # If it's a tensor, return as is for processing in the route handler
return image_data return image_data
# If no image is found in the current metadata, try to find it in the cache # If no image is found in the current metadata, try to find it in the cache
# This handles the case where VAEDecode was cached by ComfyUI and not executed # This handles the case where VAEDecode was cached by ComfyUI and not executed
prompt_obj = metadata.get("current_prompt") prompt_obj = metadata.get("current_prompt")
@@ -270,8 +278,11 @@ class MetadataRegistry:
if IMAGES in cached_data and node_id in cached_data[IMAGES]: if IMAGES in cached_data and node_id in cached_data[IMAGES]:
image_data = cached_data[IMAGES][node_id]["image"] image_data = cached_data[IMAGES][node_id]["image"]
# Handle different image formats # Handle different image formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0: if (
isinstance(image_data, (list, tuple))
and len(image_data) > 0
):
return image_data[0] return image_data[0]
return image_data return image_data
return None return None

View File

@@ -1,8 +1,9 @@
import json import json
import os import os
import re import re
from typing import Any, Dict, Optional
import numpy as np import numpy as np
import folder_paths # type: ignore import folder_paths # type: ignore
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..metadata_collector.metadata_processor import MetadataProcessor from ..metadata_collector.metadata_processor import MetadataProcessor
from ..metadata_collector import get_metadata from ..metadata_collector import get_metadata
@@ -12,6 +13,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SaveImageLM: class SaveImageLM:
NAME = "Save Image (LoraManager)" NAME = "Save Image (LoraManager)"
CATEGORY = "Lora Manager/utils" CATEGORY = "Lora Manager/utils"
@@ -23,42 +25,60 @@ class SaveImageLM:
self.prefix_append = "" self.prefix_append = ""
self.compress_level = 4 self.compress_level = 4
self.counter = 0 self.counter = 0
# Add pattern format regex for filename substitution # Add pattern format regex for filename substitution
pattern_format = re.compile(r"(%[^%]+%)") pattern_format = re.compile(r"(%[^%]+%)")
@classmethod @classmethod
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
"required": { "required": {
"images": ("IMAGE",), "images": ("IMAGE",),
"filename_prefix": ("STRING", { "filename_prefix": (
"default": "ComfyUI", "STRING",
"tooltip": "Base filename for saved images. Supports format patterns like %seed%, %width%, %height%, %model%, etc." {
}), "default": "ComfyUI",
"file_format": (["png", "jpeg", "webp"], { "tooltip": "Base filename for saved images. Supports format patterns like %seed%, %width%, %height%, %model%, etc.",
"tooltip": "Image format to save as. PNG preserves quality, JPEG is smaller, WebP balances size and quality." },
}), ),
"file_format": (
["png", "jpeg", "webp"],
{
"tooltip": "Image format to save as. PNG preserves quality, JPEG is smaller, WebP balances size and quality."
},
),
}, },
"optional": { "optional": {
"lossless_webp": ("BOOLEAN", { "lossless_webp": (
"default": False, "BOOLEAN",
"tooltip": "When enabled, saves WebP images with lossless compression. Results in larger files but no quality loss." {
}), "default": False,
"quality": ("INT", { "tooltip": "When enabled, saves WebP images with lossless compression. Results in larger files but no quality loss.",
"default": 100, },
"min": 1, ),
"max": 100, "quality": (
"tooltip": "Compression quality for JPEG and lossy WebP formats (1-100). Higher values mean better quality but larger files." "INT",
}), {
"embed_workflow": ("BOOLEAN", { "default": 100,
"default": False, "min": 1,
"tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats." "max": 100,
}), "tooltip": "Compression quality for JPEG and lossy WebP formats (1-100). Higher values mean better quality but larger files.",
"add_counter_to_filename": ("BOOLEAN", { },
"default": True, ),
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images." "embed_workflow": (
}), "BOOLEAN",
{
"default": False,
"tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats.",
},
),
"add_counter_to_filename": (
"BOOLEAN",
{
"default": True,
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images.",
},
),
}, },
"hidden": { "hidden": {
"id": "UNIQUE_ID", "id": "UNIQUE_ID",
@@ -75,57 +95,59 @@ class SaveImageLM:
def get_lora_hash(self, lora_name): def get_lora_hash(self, lora_name):
"""Get the lora hash from cache""" """Get the lora hash from cache"""
scanner = ServiceRegistry.get_service_sync("lora_scanner") scanner = ServiceRegistry.get_service_sync("lora_scanner")
# Use the new direct filename lookup method # Use the new direct filename lookup method
hash_value = scanner.get_hash_by_filename(lora_name) if scanner is not None:
if hash_value: hash_value = scanner.get_hash_by_filename(lora_name)
return hash_value if hash_value:
return hash_value
return None return None
def get_checkpoint_hash(self, checkpoint_path): def get_checkpoint_hash(self, checkpoint_path):
"""Get the checkpoint hash from cache""" """Get the checkpoint hash from cache"""
scanner = ServiceRegistry.get_service_sync("checkpoint_scanner") scanner = ServiceRegistry.get_service_sync("checkpoint_scanner")
if not checkpoint_path: if not checkpoint_path:
return None return None
# Extract basename without extension # Extract basename without extension
checkpoint_name = os.path.basename(checkpoint_path) checkpoint_name = os.path.basename(checkpoint_path)
checkpoint_name = os.path.splitext(checkpoint_name)[0] checkpoint_name = os.path.splitext(checkpoint_name)[0]
# Try direct filename lookup first # Try direct filename lookup first
hash_value = scanner.get_hash_by_filename(checkpoint_name) if scanner is not None:
if hash_value: hash_value = scanner.get_hash_by_filename(checkpoint_name)
return hash_value if hash_value:
return hash_value
return None return None
def format_metadata(self, metadata_dict): def format_metadata(self, metadata_dict):
"""Format metadata in the requested format similar to userComment example""" """Format metadata in the requested format similar to userComment example"""
if not metadata_dict: if not metadata_dict:
return "" return ""
# Helper function to only add parameter if value is not None # Helper function to only add parameter if value is not None
def add_param_if_not_none(param_list, label, value): def add_param_if_not_none(param_list, label, value):
if value is not None: if value is not None:
param_list.append(f"{label}: {value}") param_list.append(f"{label}: {value}")
# Extract the prompt and negative prompt # Extract the prompt and negative prompt
prompt = metadata_dict.get('prompt', '') prompt = metadata_dict.get("prompt", "")
negative_prompt = metadata_dict.get('negative_prompt', '') negative_prompt = metadata_dict.get("negative_prompt", "")
# Extract loras from the prompt if present # Extract loras from the prompt if present
loras_text = metadata_dict.get('loras', '') loras_text = metadata_dict.get("loras", "")
lora_hashes = {} lora_hashes = {}
# If loras are found, add them on a new line after the prompt # If loras are found, add them on a new line after the prompt
if loras_text: if loras_text:
prompt_with_loras = f"{prompt}\n{loras_text}" prompt_with_loras = f"{prompt}\n{loras_text}"
# Extract lora names from the format <lora:name:strength> # Extract lora names from the format <lora:name:strength>
lora_matches = re.findall(r'<lora:([^:]+):([^>]+)>', loras_text) lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", loras_text)
# Get hash for each lora # Get hash for each lora
for lora_name, strength in lora_matches: for lora_name, strength in lora_matches:
hash_value = self.get_lora_hash(lora_name) hash_value = self.get_lora_hash(lora_name)
@@ -133,112 +155,114 @@ class SaveImageLM:
lora_hashes[lora_name] = hash_value lora_hashes[lora_name] = hash_value
else: else:
prompt_with_loras = prompt prompt_with_loras = prompt
# Format the first part (prompt and loras) # Format the first part (prompt and loras)
metadata_parts = [prompt_with_loras] metadata_parts = [prompt_with_loras]
# Add negative prompt # Add negative prompt
if negative_prompt: if negative_prompt:
metadata_parts.append(f"Negative prompt: {negative_prompt}") metadata_parts.append(f"Negative prompt: {negative_prompt}")
# Format the second part (generation parameters) # Format the second part (generation parameters)
params = [] params = []
# Add standard parameters in the correct order # Add standard parameters in the correct order
if 'steps' in metadata_dict: if "steps" in metadata_dict:
add_param_if_not_none(params, "Steps", metadata_dict.get('steps')) add_param_if_not_none(params, "Steps", metadata_dict.get("steps"))
# Combine sampler and scheduler information # Combine sampler and scheduler information
sampler_name = None sampler_name = None
scheduler_name = None scheduler_name = None
if 'sampler' in metadata_dict: if "sampler" in metadata_dict:
sampler = metadata_dict.get('sampler') sampler = metadata_dict.get("sampler")
# Convert ComfyUI sampler names to user-friendly names # Convert ComfyUI sampler names to user-friendly names
sampler_mapping = { sampler_mapping = {
'euler': 'Euler', "euler": "Euler",
'euler_ancestral': 'Euler a', "euler_ancestral": "Euler a",
'dpm_2': 'DPM2', "dpm_2": "DPM2",
'dpm_2_ancestral': 'DPM2 a', "dpm_2_ancestral": "DPM2 a",
'heun': 'Heun', "heun": "Heun",
'dpm_fast': 'DPM fast', "dpm_fast": "DPM fast",
'dpm_adaptive': 'DPM adaptive', "dpm_adaptive": "DPM adaptive",
'lms': 'LMS', "lms": "LMS",
'dpmpp_2s_ancestral': 'DPM++ 2S a', "dpmpp_2s_ancestral": "DPM++ 2S a",
'dpmpp_sde': 'DPM++ SDE', "dpmpp_sde": "DPM++ SDE",
'dpmpp_sde_gpu': 'DPM++ SDE', "dpmpp_sde_gpu": "DPM++ SDE",
'dpmpp_2m': 'DPM++ 2M', "dpmpp_2m": "DPM++ 2M",
'dpmpp_2m_sde': 'DPM++ 2M SDE', "dpmpp_2m_sde": "DPM++ 2M SDE",
'dpmpp_2m_sde_gpu': 'DPM++ 2M SDE', "dpmpp_2m_sde_gpu": "DPM++ 2M SDE",
'ddim': 'DDIM' "ddim": "DDIM",
} }
sampler_name = sampler_mapping.get(sampler, sampler) sampler_name = sampler_mapping.get(sampler, sampler)
if 'scheduler' in metadata_dict: if "scheduler" in metadata_dict:
scheduler = metadata_dict.get('scheduler') scheduler = metadata_dict.get("scheduler")
scheduler_mapping = { scheduler_mapping = {
'normal': 'Simple', "normal": "Simple",
'karras': 'Karras', "karras": "Karras",
'exponential': 'Exponential', "exponential": "Exponential",
'sgm_uniform': 'SGM Uniform', "sgm_uniform": "SGM Uniform",
'sgm_quadratic': 'SGM Quadratic' "sgm_quadratic": "SGM Quadratic",
} }
scheduler_name = scheduler_mapping.get(scheduler, scheduler) scheduler_name = scheduler_mapping.get(scheduler, scheduler)
# Add combined sampler and scheduler information # Add combined sampler and scheduler information
if sampler_name: if sampler_name:
if scheduler_name: if scheduler_name:
params.append(f"Sampler: {sampler_name} {scheduler_name}") params.append(f"Sampler: {sampler_name} {scheduler_name}")
else: else:
params.append(f"Sampler: {sampler_name}") params.append(f"Sampler: {sampler_name}")
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg) # CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
if 'guidance' in metadata_dict: if "guidance" in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('guidance')) add_param_if_not_none(params, "CFG scale", metadata_dict.get("guidance"))
elif 'cfg_scale' in metadata_dict: elif "cfg_scale" in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg_scale')) add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg_scale"))
elif 'cfg' in metadata_dict: elif "cfg" in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg')) add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg"))
# Seed # Seed
if 'seed' in metadata_dict: if "seed" in metadata_dict:
add_param_if_not_none(params, "Seed", metadata_dict.get('seed')) add_param_if_not_none(params, "Seed", metadata_dict.get("seed"))
# Size # Size
if 'size' in metadata_dict: if "size" in metadata_dict:
add_param_if_not_none(params, "Size", metadata_dict.get('size')) add_param_if_not_none(params, "Size", metadata_dict.get("size"))
# Model info # Model info
if 'checkpoint' in metadata_dict: if "checkpoint" in metadata_dict:
# Ensure checkpoint is a string before processing # Ensure checkpoint is a string before processing
checkpoint = metadata_dict.get('checkpoint') checkpoint = metadata_dict.get("checkpoint")
if checkpoint is not None: if checkpoint is not None:
# Get model hash # Get model hash
model_hash = self.get_checkpoint_hash(checkpoint) model_hash = self.get_checkpoint_hash(checkpoint)
# Extract basename without path # Extract basename without path
checkpoint_name = os.path.basename(checkpoint) checkpoint_name = os.path.basename(checkpoint)
# Remove extension if present # Remove extension if present
checkpoint_name = os.path.splitext(checkpoint_name)[0] checkpoint_name = os.path.splitext(checkpoint_name)[0]
# Add model hash if available # Add model hash if available
if model_hash: if model_hash:
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}") params.append(
f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}"
)
else: else:
params.append(f"Model: {checkpoint_name}") params.append(f"Model: {checkpoint_name}")
# Add LoRA hashes if available # Add LoRA hashes if available
if lora_hashes: if lora_hashes:
lora_hash_parts = [] lora_hash_parts = []
for lora_name, hash_value in lora_hashes.items(): for lora_name, hash_value in lora_hashes.items():
lora_hash_parts.append(f"{lora_name}: {hash_value[:10]}") lora_hash_parts.append(f"{lora_name}: {hash_value[:10]}")
if lora_hash_parts: if lora_hash_parts:
params.append(f"Lora hashes: \"{', '.join(lora_hash_parts)}\"") params.append(f'Lora hashes: "{", ".join(lora_hash_parts)}"')
# Combine all parameters with commas # Combine all parameters with commas
metadata_parts.append(", ".join(params)) metadata_parts.append(", ".join(params))
# Join all parts with a new line # Join all parts with a new line
return "\n".join(metadata_parts) return "\n".join(metadata_parts)
@@ -248,36 +272,36 @@ class SaveImageLM:
"""Format filename with metadata values""" """Format filename with metadata values"""
if not metadata_dict: if not metadata_dict:
return filename return filename
result = re.findall(self.pattern_format, filename) result = re.findall(self.pattern_format, filename)
for segment in result: for segment in result:
parts = segment.replace("%", "").split(":") parts = segment.replace("%", "").split(":")
key = parts[0] key = parts[0]
if key == "seed" and 'seed' in metadata_dict: if key == "seed" and "seed" in metadata_dict:
filename = filename.replace(segment, str(metadata_dict.get('seed', ''))) filename = filename.replace(segment, str(metadata_dict.get("seed", "")))
elif key == "width" and 'size' in metadata_dict: elif key == "width" and "size" in metadata_dict:
size = metadata_dict.get('size', 'x') size = metadata_dict.get("size", "x")
w = size.split('x')[0] if isinstance(size, str) else size[0] w = size.split("x")[0] if isinstance(size, str) else size[0]
filename = filename.replace(segment, str(w)) filename = filename.replace(segment, str(w))
elif key == "height" and 'size' in metadata_dict: elif key == "height" and "size" in metadata_dict:
size = metadata_dict.get('size', 'x') size = metadata_dict.get("size", "x")
h = size.split('x')[1] if isinstance(size, str) else size[1] h = size.split("x")[1] if isinstance(size, str) else size[1]
filename = filename.replace(segment, str(h)) filename = filename.replace(segment, str(h))
elif key == "pprompt" and 'prompt' in metadata_dict: elif key == "pprompt" and "prompt" in metadata_dict:
prompt = metadata_dict.get('prompt', '').replace("\n", " ") prompt = metadata_dict.get("prompt", "").replace("\n", " ")
if len(parts) >= 2: if len(parts) >= 2:
length = int(parts[1]) length = int(parts[1])
prompt = prompt[:length] prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip()) filename = filename.replace(segment, prompt.strip())
elif key == "nprompt" and 'negative_prompt' in metadata_dict: elif key == "nprompt" and "negative_prompt" in metadata_dict:
prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ") prompt = metadata_dict.get("negative_prompt", "").replace("\n", " ")
if len(parts) >= 2: if len(parts) >= 2:
length = int(parts[1]) length = int(parts[1])
prompt = prompt[:length] prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip()) filename = filename.replace(segment, prompt.strip())
elif key == "model": elif key == "model":
model_value = metadata_dict.get('checkpoint') model_value = metadata_dict.get("checkpoint")
if isinstance(model_value, (bytes, os.PathLike)): if isinstance(model_value, (bytes, os.PathLike)):
model_value = str(model_value) model_value = str(model_value)
@@ -291,6 +315,7 @@ class SaveImageLM:
filename = filename.replace(segment, model) filename = filename.replace(segment, model)
elif key == "date": elif key == "date":
from datetime import datetime from datetime import datetime
now = datetime.now() now = datetime.now()
date_table = { date_table = {
"yyyy": f"{now.year:04d}", "yyyy": f"{now.year:04d}",
@@ -311,46 +336,62 @@ class SaveImageLM:
for k, v in date_table.items(): for k, v in date_table.items():
date_format = date_format.replace(k, v) date_format = date_format.replace(k, v)
filename = filename.replace(segment, date_format) filename = filename.replace(segment, date_format)
return filename return filename
def save_images(self, images, filename_prefix, file_format, id, prompt=None, extra_pnginfo=None, def save_images(
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): self,
images,
filename_prefix,
file_format,
id,
prompt=None,
extra_pnginfo=None,
lossless_webp=True,
quality=100,
embed_workflow=False,
add_counter_to_filename=True,
):
"""Save images with metadata""" """Save images with metadata"""
results = [] results = []
# Get metadata using the metadata collector # Get metadata using the metadata collector
raw_metadata = get_metadata() raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata, id) metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
metadata = self.format_metadata(metadata_dict) metadata = self.format_metadata(metadata_dict)
# Process filename_prefix with pattern substitution # Process filename_prefix with pattern substitution
filename_prefix = self.format_filename(filename_prefix, metadata_dict) filename_prefix = self.format_filename(filename_prefix, metadata_dict)
# Get initial save path info once for the batch # Get initial save path info once for the batch
full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path( full_output_folder, filename, counter, subfolder, processed_prefix = (
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0] folder_paths.get_save_image_path(
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
)
) )
# Create directory if it doesn't exist # Create directory if it doesn't exist
if not os.path.exists(full_output_folder): if not os.path.exists(full_output_folder):
os.makedirs(full_output_folder, exist_ok=True) os.makedirs(full_output_folder, exist_ok=True)
# Process each image with incrementing counter # Process each image with incrementing counter
for i, image in enumerate(images): for i, image in enumerate(images):
# Convert the tensor image to numpy array # Convert the tensor image to numpy array
img = 255. * image.cpu().numpy() img = 255.0 * image.cpu().numpy()
img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8)) img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))
# Generate filename with counter if needed # Generate filename with counter if needed
base_filename = filename base_filename = filename
if add_counter_to_filename: if add_counter_to_filename:
# Use counter + i to ensure unique filenames for all images in batch # Use counter + i to ensure unique filenames for all images in batch
current_counter = counter + i current_counter = counter + i
base_filename += f"_{current_counter:05}_" base_filename += f"_{current_counter:05}_"
# Set file extension and prepare saving parameters # Set file extension and prepare saving parameters
file: str
save_kwargs: Dict[str, Any]
pnginfo: Optional[PngImagePlugin.PngInfo] = None
if file_format == "png": if file_format == "png":
file = base_filename + ".png" file = base_filename + ".png"
file_extension = ".png" file_extension = ".png"
@@ -362,17 +403,24 @@ class SaveImageLM:
file_extension = ".jpg" file_extension = ".jpg"
save_kwargs = {"quality": quality, "optimize": True} save_kwargs = {"quality": quality, "optimize": True}
elif file_format == "webp": elif file_format == "webp":
file = base_filename + ".webp" file = base_filename + ".webp"
file_extension = ".webp" file_extension = ".webp"
# Add optimization param to control performance # Add optimization param to control performance
save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0} save_kwargs = {
"quality": quality,
"lossless": lossless_webp,
"method": 0,
}
else:
raise ValueError(f"Unsupported file format: {file_format}")
# Full save path # Full save path
file_path = os.path.join(full_output_folder, file) file_path = os.path.join(full_output_folder, file)
# Save the image with metadata # Save the image with metadata
try: try:
if file_format == "png": if file_format == "png":
assert pnginfo is not None
if metadata: if metadata:
pnginfo.add_text("parameters", metadata) pnginfo.add_text("parameters", metadata)
if embed_workflow and extra_pnginfo is not None: if embed_workflow and extra_pnginfo is not None:
@@ -384,7 +432,12 @@ class SaveImageLM:
# For JPEG, use piexif # For JPEG, use piexif
if metadata: if metadata:
try: try:
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}} exif_dict = {
"Exif": {
piexif.ExifIFD.UserComment: b"UNICODE\0"
+ metadata.encode("utf-16be")
}
}
exif_bytes = piexif.dump(exif_dict) exif_bytes = piexif.dump(exif_dict)
save_kwargs["exif"] = exif_bytes save_kwargs["exif"] = exif_bytes
except Exception as e: except Exception as e:
@@ -396,37 +449,52 @@ class SaveImageLM:
exif_dict = {} exif_dict = {}
if metadata: if metadata:
exif_dict['Exif'] = {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')} exif_dict["Exif"] = {
piexif.ExifIFD.UserComment: b"UNICODE\0"
+ metadata.encode("utf-16be")
}
# Add workflow if needed # Add workflow if needed
if embed_workflow and extra_pnginfo is not None: if embed_workflow and extra_pnginfo is not None:
workflow_json = json.dumps(extra_pnginfo["workflow"]) workflow_json = json.dumps(extra_pnginfo["workflow"])
exif_dict['0th'] = {piexif.ImageIFD.ImageDescription: "Workflow:" + workflow_json} exif_dict["0th"] = {
piexif.ImageIFD.ImageDescription: "Workflow:"
+ workflow_json
}
exif_bytes = piexif.dump(exif_dict) exif_bytes = piexif.dump(exif_dict)
save_kwargs["exif"] = exif_bytes save_kwargs["exif"] = exif_bytes
except Exception as e: except Exception as e:
logger.error(f"Error adding EXIF data: {e}") logger.error(f"Error adding EXIF data: {e}")
img.save(file_path, format="WEBP", **save_kwargs) img.save(file_path, format="WEBP", **save_kwargs)
results.append({ results.append(
"filename": file, {"filename": file, "subfolder": subfolder, "type": self.type}
"subfolder": subfolder, )
"type": self.type
})
except Exception as e: except Exception as e:
logger.error(f"Error saving image: {e}") logger.error(f"Error saving image: {e}")
return results return results
def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None, def process_image(
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): self,
images,
id,
filename_prefix="ComfyUI",
file_format="png",
prompt=None,
extra_pnginfo=None,
lossless_webp=True,
quality=100,
embed_workflow=False,
add_counter_to_filename=True,
):
"""Process and save image with metadata""" """Process and save image with metadata"""
# Make sure the output directory exists # Make sure the output directory exists
os.makedirs(self.output_dir, exist_ok=True) os.makedirs(self.output_dir, exist_ok=True)
# If images is already a list or array of images, do nothing; otherwise, convert to list # If images is already a list or array of images, do nothing; otherwise, convert to list
if isinstance(images, (list, np.ndarray)): if isinstance(images, (list, np.ndarray)):
pass pass
@@ -436,19 +504,19 @@ class SaveImageLM:
images = [images] images = [images]
else: # Multiple images (batch, height, width, channels) else: # Multiple images (batch, height, width, channels)
images = [img for img in images] images = [img for img in images]
# Save all images # Save all images
results = self.save_images( results = self.save_images(
images, images,
filename_prefix, filename_prefix,
file_format, file_format,
id, id,
prompt, prompt,
extra_pnginfo, extra_pnginfo,
lossless_webp, lossless_webp,
quality, quality,
embed_workflow, embed_workflow,
add_counter_to_filename add_counter_to_filename,
) )
return (images,) return (images,)

View File

@@ -1,33 +1,35 @@
class AnyType(str): class AnyType(str):
"""A special class that is always equal in not equal comparisons. Credit to pythongosssss""" """A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
def __ne__(self, __value: object) -> bool:
return False
def __ne__(self, __value: object) -> bool:
return False
# Credit to Regis Gaughan, III (rgthree) # Credit to Regis Gaughan, III (rgthree)
class FlexibleOptionalInputType(dict): class FlexibleOptionalInputType(dict):
"""A special class to make flexible nodes that pass data to our python handlers. """A special class to make flexible nodes that pass data to our python handlers.
Enables both flexible/dynamic input types (like for Any Switch) or a dynamic number of inputs Enables both flexible/dynamic input types (like for Any Switch) or a dynamic number of inputs
(like for Any Switch, Context Switch, Context Merge, Power Lora Loader, etc). (like for Any Switch, Context Switch, Context Merge, Power Lora Loader, etc).
Note, for ComfyUI, all that's needed is the `__contains__` override below, which tells ComfyUI Note, for ComfyUI, all that's needed is the `__contains__` override below, which tells ComfyUI
that our node will handle the input, regardless of what it is. that our node will handle the input, regardless of what it is.
However, with https://github.com/comfyanonymous/ComfyUI/pull/2666 a large change would occur However, with https://github.com/comfyanonymous/ComfyUI/pull/2666 a large change would occur
requiring more details on the input itself. There, we need to return a list/tuple where the first requiring more details on the input itself. There, we need to return a list/tuple where the first
item is the type. This can be a real type, or use the AnyType for additional flexibility. item is the type. This can be a real type, or use the AnyType for additional flexibility.
This should be forwards compatible unless more changes occur in the PR. This should be forwards compatible unless more changes occur in the PR.
""" """
def __init__(self, type):
self.type = type
def __getitem__(self, key): def __init__(self, type):
return (self.type, ) self.type = type
def __contains__(self, key): def __getitem__(self, key):
return True return (self.type,)
def __contains__(self, key):
return True
any_type = AnyType("*") any_type = AnyType("*")
@@ -37,25 +39,27 @@ import os
import logging import logging
import copy import copy
import sys import sys
import folder_paths import folder_paths # type: ignore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def extract_lora_name(lora_path): def extract_lora_name(lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension # Get the basename without extension
basename = os.path.basename(lora_path) basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0] return os.path.splitext(basename)[0]
def get_loras_list(kwargs): def get_loras_list(kwargs):
"""Helper to extract loras list from either old or new kwargs format""" """Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs: if "loras" not in kwargs:
return [] return []
loras_data = kwargs['loras'] loras_data = kwargs["loras"]
# Handle new format: {'loras': {'__value__': [...]}} # Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data: if isinstance(loras_data, dict) and "__value__" in loras_data:
return loras_data['__value__'] return loras_data["__value__"]
# Handle old format: {'loras': [...]} # Handle old format: {'loras': [...]}
elif isinstance(loras_data, list): elif isinstance(loras_data, list):
return loras_data return loras_data
@@ -64,24 +68,26 @@ def get_loras_list(kwargs):
logger.warning(f"Unexpected loras format: {type(loras_data)}") logger.warning(f"Unexpected loras format: {type(loras_data)}")
return [] return []
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""): def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path""" """Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
import safetensors.torch import safetensors.torch
state_dict = {} state_dict = {}
with safetensors.torch.safe_open(path, framework="pt", device=device) as f: with safetensors.torch.safe_open(path, framework="pt", device=device) as f: # type: ignore[attr-defined]
for k in f.keys(): for k in f.keys():
if filter_prefix and not k.startswith(filter_prefix): if filter_prefix and not k.startswith(filter_prefix):
continue continue
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k) state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
return state_dict return state_dict
def to_diffusers(input_lora): def to_diffusers(input_lora):
"""Simplified version of to_diffusers for Flux LoRA conversion""" """Simplified version of to_diffusers for Flux LoRA conversion"""
import torch import torch
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
from diffusers.loaders import FluxLoraLoaderMixin from diffusers.loaders import FluxLoraLoaderMixin # type: ignore[attr-defined]
if isinstance(input_lora, str): if isinstance(input_lora, str):
tensors = load_state_dict_in_safetensors(input_lora, device="cpu") tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
else: else:
@@ -91,22 +97,27 @@ def to_diffusers(input_lora):
for k, v in tensors.items(): for k, v in tensors.items():
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]: if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
tensors[k] = v.to(torch.bfloat16) tensors[k] = v.to(torch.bfloat16)
new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors) new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors)
new_tensors = convert_unet_state_dict_to_peft(new_tensors) new_tensors = convert_unet_state_dict_to_peft(new_tensors)
return new_tensors return new_tensors
def nunchaku_load_lora(model, lora_name, lora_strength): def nunchaku_load_lora(model, lora_name, lora_strength):
"""Load a Flux LoRA for Nunchaku model""" """Load a Flux LoRA for Nunchaku model"""
# Get full path to the LoRA file. Allow both direct paths and registered LoRA names. # Get full path to the LoRA file. Allow both direct paths and registered LoRA names.
lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name) lora_path = (
lora_name
if os.path.isfile(lora_name)
else folder_paths.get_full_path("loras", lora_name)
)
if not lora_path or not os.path.isfile(lora_path): if not lora_path or not os.path.isfile(lora_path):
logger.warning("Skipping LoRA '%s' because it could not be found", lora_name) logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
return model return model
model_wrapper = model.model.diffusion_model model_wrapper = model.model.diffusion_model
# Try to find copy_with_ctx in the same module as ComfyFluxWrapper # Try to find copy_with_ctx in the same module as ComfyFluxWrapper
module_name = model_wrapper.__class__.__module__ module_name = model_wrapper.__class__.__module__
module = sys.modules.get(module_name) module = sys.modules.get(module_name)
@@ -118,14 +129,16 @@ def nunchaku_load_lora(model, lora_name, lora_strength):
ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)] ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)]
else: else:
# Fallback to legacy logic # Fallback to legacy logic
logger.warning("Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic.") logger.warning(
"Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic."
)
transformer = model_wrapper.model transformer = model_wrapper.model
# Save the transformer temporarily # Save the transformer temporarily
model_wrapper.model = None model_wrapper.model = None
ret_model = copy.deepcopy(model) # copy everything except the model ret_model = copy.deepcopy(model) # copy everything except the model
ret_model_wrapper = ret_model.model.diffusion_model ret_model_wrapper = ret_model.model.diffusion_model
# Restore the model and set it for the copy # Restore the model and set it for the copy
model_wrapper.model = transformer model_wrapper.model = transformer
ret_model_wrapper.model = transformer ret_model_wrapper.model = transformer
@@ -133,15 +146,15 @@ def nunchaku_load_lora(model, lora_name, lora_strength):
# Convert the LoRA to diffusers format # Convert the LoRA to diffusers format
sd = to_diffusers(lora_path) sd = to_diffusers(lora_path)
# Handle embedding adjustment if needed # Handle embedding adjustment if needed
if "transformer.x_embedder.lora_A.weight" in sd: if "transformer.x_embedder.lora_A.weight" in sd:
new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1] new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
assert new_in_channels % 4 == 0 assert new_in_channels % 4 == 0
new_in_channels = new_in_channels // 4 new_in_channels = new_in_channels // 4
old_in_channels = ret_model.model.model_config.unet_config["in_channels"] old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
if old_in_channels < new_in_channels: if old_in_channels < new_in_channels:
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
return ret_model return ret_model

View File

@@ -6,23 +6,24 @@ from .parsers import (
ComfyMetadataParser, ComfyMetadataParser,
MetaFormatParser, MetaFormatParser,
AutomaticMetadataParser, AutomaticMetadataParser,
CivitaiApiMetadataParser CivitaiApiMetadataParser,
) )
from .base import RecipeMetadataParser from .base import RecipeMetadataParser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RecipeParserFactory: class RecipeParserFactory:
"""Factory for creating recipe metadata parsers""" """Factory for creating recipe metadata parsers"""
@staticmethod @staticmethod
def create_parser(metadata) -> RecipeMetadataParser: def create_parser(metadata) -> RecipeMetadataParser | None:
""" """
Create appropriate parser based on the metadata content Create appropriate parser based on the metadata content
Args: Args:
metadata: The metadata from the image (dict or str) metadata: The metadata from the image (dict or str)
Returns: Returns:
Appropriate RecipeMetadataParser implementation Appropriate RecipeMetadataParser implementation
""" """
@@ -34,17 +35,18 @@ class RecipeParserFactory:
except Exception as e: except Exception as e:
logger.debug(f"CivitaiApiMetadataParser check failed: {e}") logger.debug(f"CivitaiApiMetadataParser check failed: {e}")
pass pass
# Convert dict to string for other parsers that expect string input # Convert dict to string for other parsers that expect string input
try: try:
import json import json
metadata_str = json.dumps(metadata) metadata_str = json.dumps(metadata)
except Exception as e: except Exception as e:
logger.debug(f"Failed to convert dict to JSON string: {e}") logger.debug(f"Failed to convert dict to JSON string: {e}")
return None return None
else: else:
metadata_str = metadata metadata_str = metadata
# Try ComfyMetadataParser which requires valid JSON # Try ComfyMetadataParser which requires valid JSON
try: try:
if ComfyMetadataParser().is_metadata_matching(metadata_str): if ComfyMetadataParser().is_metadata_matching(metadata_str):
@@ -52,7 +54,7 @@ class RecipeParserFactory:
except Exception: except Exception:
# If JSON parsing fails, move on to other parsers # If JSON parsing fails, move on to other parsers
pass pass
# Check other parsers that expect string input # Check other parsers that expect string input
if RecipeFormatParser().is_metadata_matching(metadata_str): if RecipeFormatParser().is_metadata_matching(metadata_str):
return RecipeFormatParser() return RecipeFormatParser()

View File

@@ -9,15 +9,16 @@ from ...services.metadata_service import get_default_metadata_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CivitaiApiMetadataParser(RecipeMetadataParser): class CivitaiApiMetadataParser(RecipeMetadataParser):
"""Parser for Civitai image metadata format""" """Parser for Civitai image metadata format"""
def is_metadata_matching(self, metadata) -> bool: def is_metadata_matching(self, metadata) -> bool:
"""Check if the metadata matches the Civitai image metadata format """Check if the metadata matches the Civitai image metadata format
Args: Args:
metadata: The metadata from the image (dict) metadata: The metadata from the image (dict)
Returns: Returns:
bool: True if this parser can handle the metadata bool: True if this parser can handle the metadata
""" """
@@ -28,7 +29,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Check for common CivitAI image metadata fields # Check for common CivitAI image metadata fields
civitai_image_fields = ( civitai_image_fields = (
"resources", "resources",
"civitaiResources", "civitaiResources",
"additionalResources", "additionalResources",
"hashes", "hashes",
"prompt", "prompt",
@@ -40,7 +41,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
"width", "width",
"height", "height",
"Model", "Model",
"Model hash" "Model hash",
) )
return any(key in payload for key in civitai_image_fields) return any(key in payload for key in civitai_image_fields)
@@ -50,7 +51,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Check for LoRA hash patterns # Check for LoRA hash patterns
hashes = metadata.get("hashes") hashes = metadata.get("hashes")
if isinstance(hashes, dict) and any(str(key).lower().startswith("lora:") for key in hashes): if isinstance(hashes, dict) and any(
str(key).lower().startswith("lora:") for key in hashes
):
return True return True
# Check nested meta object (common in CivitAI image responses) # Check nested meta object (common in CivitAI image responses)
@@ -61,22 +64,28 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Also check for LoRA hash patterns in nested meta # Also check for LoRA hash patterns in nested meta
hashes = nested_meta.get("hashes") hashes = nested_meta.get("hashes")
if isinstance(hashes, dict) and any(str(key).lower().startswith("lora:") for key in hashes): if isinstance(hashes, dict) and any(
str(key).lower().startswith("lora:") for key in hashes
):
return True return True
return False return False
async def parse_metadata(self, metadata, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]: async def parse_metadata( # type: ignore[override]
self, user_comment, recipe_scanner=None, civitai_client=None
) -> Dict[str, Any]:
"""Parse metadata from Civitai image format """Parse metadata from Civitai image format
Args: Args:
metadata: The metadata from the image (dict) user_comment: The metadata from the image (dict)
recipe_scanner: Optional recipe scanner service recipe_scanner: Optional recipe scanner service
civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead) civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead)
Returns: Returns:
Dict containing parsed recipe data Dict containing parsed recipe data
""" """
metadata: Dict[str, Any] = user_comment # type: ignore[assignment]
metadata = user_comment
try: try:
# Get metadata provider instead of using civitai_client directly # Get metadata provider instead of using civitai_client directly
metadata_provider = await get_default_metadata_provider() metadata_provider = await get_default_metadata_provider()
@@ -100,19 +109,19 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
) )
): ):
metadata = inner_meta metadata = inner_meta
# Initialize result structure # Initialize result structure
result = { result = {
'base_model': None, "base_model": None,
'loras': [], "loras": [],
'model': None, "model": None,
'gen_params': {}, "gen_params": {},
'from_civitai_image': True "from_civitai_image": True,
} }
# Track already added LoRAs to prevent duplicates # Track already added LoRAs to prevent duplicates
added_loras = {} # key: model_version_id or hash, value: index in result["loras"] added_loras = {} # key: model_version_id or hash, value: index in result["loras"]
# Extract hash information from hashes field for LoRA matching # Extract hash information from hashes field for LoRA matching
lora_hashes = {} lora_hashes = {}
if "hashes" in metadata and isinstance(metadata["hashes"], dict): if "hashes" in metadata and isinstance(metadata["hashes"], dict):
@@ -121,14 +130,14 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
if key_str.lower().startswith("lora:"): if key_str.lower().startswith("lora:"):
lora_name = key_str.split(":", 1)[1] lora_name = key_str.split(":", 1)[1]
lora_hashes[lora_name] = hash_value lora_hashes[lora_name] = hash_value
# Extract prompt and negative prompt # Extract prompt and negative prompt
if "prompt" in metadata: if "prompt" in metadata:
result["gen_params"]["prompt"] = metadata["prompt"] result["gen_params"]["prompt"] = metadata["prompt"]
if "negativePrompt" in metadata: if "negativePrompt" in metadata:
result["gen_params"]["negative_prompt"] = metadata["negativePrompt"] result["gen_params"]["negative_prompt"] = metadata["negativePrompt"]
# Extract other generation parameters # Extract other generation parameters
param_mapping = { param_mapping = {
"steps": "steps", "steps": "steps",
@@ -138,98 +147,117 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
"Size": "size", "Size": "size",
"clipSkip": "clip_skip", "clipSkip": "clip_skip",
} }
for civitai_key, our_key in param_mapping.items(): for civitai_key, our_key in param_mapping.items():
if civitai_key in metadata and our_key in GEN_PARAM_KEYS: if civitai_key in metadata and our_key in GEN_PARAM_KEYS:
result["gen_params"][our_key] = metadata[civitai_key] result["gen_params"][our_key] = metadata[civitai_key]
# Extract base model information - directly if available # Extract base model information - directly if available
if "baseModel" in metadata: if "baseModel" in metadata:
result["base_model"] = metadata["baseModel"] result["base_model"] = metadata["baseModel"]
elif "Model hash" in metadata and metadata_provider: elif "Model hash" in metadata and metadata_provider:
model_hash = metadata["Model hash"] model_hash = metadata["Model hash"]
model_info, error = await metadata_provider.get_model_by_hash(model_hash) model_info, error = await metadata_provider.get_model_by_hash(
model_hash
)
if model_info: if model_info:
result["base_model"] = model_info.get("baseModel", "") result["base_model"] = model_info.get("baseModel", "")
elif "Model" in metadata and isinstance(metadata.get("resources"), list): elif "Model" in metadata and isinstance(metadata.get("resources"), list):
# Try to find base model in resources # Try to find base model in resources
for resource in metadata.get("resources", []): for resource in metadata.get("resources", []):
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"): if resource.get("type") == "model" and resource.get(
"name"
) == metadata.get("Model"):
# This is likely the checkpoint model # This is likely the checkpoint model
if metadata_provider and resource.get("hash"): if metadata_provider and resource.get("hash"):
model_info, error = await metadata_provider.get_model_by_hash(resource.get("hash")) (
model_info,
error,
) = await metadata_provider.get_model_by_hash(
resource.get("hash")
)
if model_info: if model_info:
result["base_model"] = model_info.get("baseModel", "") result["base_model"] = model_info.get("baseModel", "")
base_model_counts = {} base_model_counts = {}
# Process standard resources array # Process standard resources array
if "resources" in metadata and isinstance(metadata["resources"], list): if "resources" in metadata and isinstance(metadata["resources"], list):
for resource in metadata["resources"]: for resource in metadata["resources"]:
# Modified to process resources without a type field as potential LoRAs # Modified to process resources without a type field as potential LoRAs
if resource.get("type", "lora") == "lora": if resource.get("type", "lora") == "lora":
lora_hash = resource.get("hash", "") lora_hash = resource.get("hash", "")
# Try to get hash from the hashes field if not present in resource # Try to get hash from the hashes field if not present in resource
if not lora_hash and resource.get("name"): if not lora_hash and resource.get("name"):
lora_hash = lora_hashes.get(resource["name"], "") lora_hash = lora_hashes.get(resource["name"], "")
# Skip LoRAs without proper identification (hash or modelVersionId) # Skip LoRAs without proper identification (hash or modelVersionId)
if not lora_hash and not resource.get("modelVersionId"): if not lora_hash and not resource.get("modelVersionId"):
logger.debug(f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId") logger.debug(
f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId"
)
continue continue
# Skip if we've already added this LoRA by hash # Skip if we've already added this LoRA by hash
if lora_hash and lora_hash in added_loras: if lora_hash and lora_hash in added_loras:
continue continue
lora_entry = { lora_entry = {
'name': resource.get("name", "Unknown LoRA"), "name": resource.get("name", "Unknown LoRA"),
'type': "lora", "type": "lora",
'weight': float(resource.get("weight", 1.0)), "weight": float(resource.get("weight", 1.0)),
'hash': lora_hash, "hash": lora_hash,
'existsLocally': False, "existsLocally": False,
'localPath': None, "localPath": None,
'file_name': resource.get("name", "Unknown"), "file_name": resource.get("name", "Unknown"),
'thumbnailUrl': '/loras_static/images/no-preview.png', "thumbnailUrl": "/loras_static/images/no-preview.png",
'baseModel': '', "baseModel": "",
'size': 0, "size": 0,
'downloadUrl': '', "downloadUrl": "",
'isDeleted': False "isDeleted": False,
} }
# Try to get info from Civitai if hash is available # Try to get info from Civitai if hash is available
if lora_entry['hash'] and metadata_provider: if lora_entry["hash"] and metadata_provider:
try: try:
civitai_info = await metadata_provider.get_model_by_hash(lora_hash) civitai_info = (
await metadata_provider.get_model_by_hash(lora_hash)
)
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
civitai_info, civitai_info,
recipe_scanner, recipe_scanner,
base_model_counts, base_model_counts,
lora_hash lora_hash,
) )
if populated_entry is None: if populated_entry is None:
continue # Skip invalid LoRA types continue # Skip invalid LoRA types
lora_entry = populated_entry lora_entry = populated_entry
# If we have a version ID from Civitai, track it for deduplication # If we have a version ID from Civitai, track it for deduplication
if 'id' in lora_entry and lora_entry['id']: if "id" in lora_entry and lora_entry["id"]:
added_loras[str(lora_entry['id'])] = len(result["loras"]) added_loras[str(lora_entry["id"])] = len(
result["loras"]
)
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}") logger.error(
f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}"
)
# Track by hash if we have it # Track by hash if we have it
if lora_hash: if lora_hash:
added_loras[lora_hash] = len(result["loras"]) added_loras[lora_hash] = len(result["loras"])
result["loras"].append(lora_entry) result["loras"].append(lora_entry)
# Process civitaiResources array # Process civitaiResources array
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list): if "civitaiResources" in metadata and isinstance(
metadata["civitaiResources"], list
):
for resource in metadata["civitaiResources"]: for resource in metadata["civitaiResources"]:
# Get resource type and identifier # Get resource type and identifier
resource_type = str(resource.get("type") or "").lower() resource_type = str(resource.get("type") or "").lower()
@@ -237,32 +265,39 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
if resource_type == "checkpoint": if resource_type == "checkpoint":
checkpoint_entry = { checkpoint_entry = {
'id': resource.get("modelVersionId", 0), "id": resource.get("modelVersionId", 0),
'modelId': resource.get("modelId", 0), "modelId": resource.get("modelId", 0),
'name': resource.get("modelName", "Unknown Checkpoint"), "name": resource.get("modelName", "Unknown Checkpoint"),
'version': resource.get("modelVersionName", ""), "version": resource.get("modelVersionName", ""),
'type': resource.get("type", "checkpoint"), "type": resource.get("type", "checkpoint"),
'existsLocally': False, "existsLocally": False,
'localPath': None, "localPath": None,
'file_name': resource.get("modelName", ""), "file_name": resource.get("modelName", ""),
'hash': resource.get("hash", "") or "", "hash": resource.get("hash", "") or "",
'thumbnailUrl': '/loras_static/images/no-preview.png', "thumbnailUrl": "/loras_static/images/no-preview.png",
'baseModel': '', "baseModel": "",
'size': 0, "size": 0,
'downloadUrl': '', "downloadUrl": "",
'isDeleted': False "isDeleted": False,
} }
if version_id and metadata_provider: if version_id and metadata_provider:
try: try:
civitai_info = await metadata_provider.get_model_version_info(version_id) civitai_info = (
await metadata_provider.get_model_version_info(
version_id
)
)
checkpoint_entry = await self.populate_checkpoint_from_civitai( checkpoint_entry = (
checkpoint_entry, await self.populate_checkpoint_from_civitai(
civitai_info checkpoint_entry, civitai_info
)
) )
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for checkpoint version {version_id}: {e}") logger.error(
f"Error fetching Civitai info for checkpoint version {version_id}: {e}"
)
if result["model"] is None: if result["model"] is None:
result["model"] = checkpoint_entry result["model"] = checkpoint_entry
@@ -275,31 +310,35 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Initialize lora entry # Initialize lora entry
lora_entry = { lora_entry = {
'id': resource.get("modelVersionId", 0), "id": resource.get("modelVersionId", 0),
'modelId': resource.get("modelId", 0), "modelId": resource.get("modelId", 0),
'name': resource.get("modelName", "Unknown LoRA"), "name": resource.get("modelName", "Unknown LoRA"),
'version': resource.get("modelVersionName", ""), "version": resource.get("modelVersionName", ""),
'type': resource.get("type", "lora"), "type": resource.get("type", "lora"),
'weight': round(float(resource.get("weight", 1.0)), 2), "weight": round(float(resource.get("weight", 1.0)), 2),
'existsLocally': False, "existsLocally": False,
'thumbnailUrl': '/loras_static/images/no-preview.png', "thumbnailUrl": "/loras_static/images/no-preview.png",
'baseModel': '', "baseModel": "",
'size': 0, "size": 0,
'downloadUrl': '', "downloadUrl": "",
'isDeleted': False "isDeleted": False,
} }
# Try to get info from Civitai if modelVersionId is available # Try to get info from Civitai if modelVersionId is available
if version_id and metadata_provider: if version_id and metadata_provider:
try: try:
# Use get_model_version_info instead of get_model_version # Use get_model_version_info instead of get_model_version
civitai_info = await metadata_provider.get_model_version_info(version_id) civitai_info = (
await metadata_provider.get_model_version_info(
version_id
)
)
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
civitai_info, civitai_info,
recipe_scanner, recipe_scanner,
base_model_counts base_model_counts,
) )
if populated_entry is None: if populated_entry is None:
@@ -307,74 +346,87 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
lora_entry = populated_entry lora_entry = populated_entry
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for model version {version_id}: {e}") logger.error(
f"Error fetching Civitai info for model version {version_id}: {e}"
)
# Track this LoRA in our deduplication dict # Track this LoRA in our deduplication dict
if version_id: if version_id:
added_loras[version_id] = len(result["loras"]) added_loras[version_id] = len(result["loras"])
result["loras"].append(lora_entry) result["loras"].append(lora_entry)
# Process additionalResources array # Process additionalResources array
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list): if "additionalResources" in metadata and isinstance(
metadata["additionalResources"], list
):
for resource in metadata["additionalResources"]: for resource in metadata["additionalResources"]:
# Skip resources that aren't LoRAs or LyCORIS # Skip resources that aren't LoRAs or LyCORIS
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource: if (
resource.get("type") not in ["lora", "lycoris"]
and "type" not in resource
):
continue continue
lora_type = resource.get("type", "lora") lora_type = resource.get("type", "lora")
name = resource.get("name", "") name = resource.get("name", "")
# Extract ID from URN format if available # Extract ID from URN format if available
version_id = None version_id = None
if name and "civitai:" in name: if name and "civitai:" in name:
parts = name.split("@") parts = name.split("@")
if len(parts) > 1: if len(parts) > 1:
version_id = parts[1] version_id = parts[1]
# Skip if we've already added this LoRA # Skip if we've already added this LoRA
if version_id in added_loras: if version_id in added_loras:
continue continue
lora_entry = { lora_entry = {
'name': name, "name": name,
'type': lora_type, "type": lora_type,
'weight': float(resource.get("strength", 1.0)), "weight": float(resource.get("strength", 1.0)),
'hash': "", "hash": "",
'existsLocally': False, "existsLocally": False,
'localPath': None, "localPath": None,
'file_name': name, "file_name": name,
'thumbnailUrl': '/loras_static/images/no-preview.png', "thumbnailUrl": "/loras_static/images/no-preview.png",
'baseModel': '', "baseModel": "",
'size': 0, "size": 0,
'downloadUrl': '', "downloadUrl": "",
'isDeleted': False "isDeleted": False,
} }
# If we have a version ID and metadata provider, try to get more info # If we have a version ID and metadata provider, try to get more info
if version_id and metadata_provider: if version_id and metadata_provider:
try: try:
# Use get_model_version_info with the version ID # Use get_model_version_info with the version ID
civitai_info = await metadata_provider.get_model_version_info(version_id) civitai_info = (
await metadata_provider.get_model_version_info(
version_id
)
)
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
civitai_info, civitai_info,
recipe_scanner, recipe_scanner,
base_model_counts base_model_counts,
) )
if populated_entry is None: if populated_entry is None:
continue # Skip invalid LoRA types continue # Skip invalid LoRA types
lora_entry = populated_entry lora_entry = populated_entry
# Track this LoRA for deduplication # Track this LoRA for deduplication
if version_id: if version_id:
added_loras[version_id] = len(result["loras"]) added_loras[version_id] = len(result["loras"])
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}") logger.error(
f"Error fetching Civitai info for model ID {version_id}: {e}"
)
result["loras"].append(lora_entry) result["loras"].append(lora_entry)
# If we found LoRA hashes in the metadata but haven't already # If we found LoRA hashes in the metadata but haven't already
@@ -390,30 +442,32 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
continue continue
lora_entry = { lora_entry = {
'name': lora_name, "name": lora_name,
'type': "lora", "type": "lora",
'weight': 1.0, "weight": 1.0,
'hash': lora_hash, "hash": lora_hash,
'existsLocally': False, "existsLocally": False,
'localPath': None, "localPath": None,
'file_name': lora_name, "file_name": lora_name,
'thumbnailUrl': '/loras_static/images/no-preview.png', "thumbnailUrl": "/loras_static/images/no-preview.png",
'baseModel': '', "baseModel": "",
'size': 0, "size": 0,
'downloadUrl': '', "downloadUrl": "",
'isDeleted': False "isDeleted": False,
} }
if metadata_provider: if metadata_provider:
try: try:
civitai_info = await metadata_provider.get_model_by_hash(lora_hash) civitai_info = await metadata_provider.get_model_by_hash(
lora_hash
)
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
civitai_info, civitai_info,
recipe_scanner, recipe_scanner,
base_model_counts, base_model_counts,
lora_hash lora_hash,
) )
if populated_entry is None: if populated_entry is None:
@@ -421,80 +475,93 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
lora_entry = populated_entry lora_entry = populated_entry
if 'id' in lora_entry and lora_entry['id']: if "id" in lora_entry and lora_entry["id"]:
added_loras[str(lora_entry['id'])] = len(result["loras"]) added_loras[str(lora_entry["id"])] = len(result["loras"])
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {lora_hash}: {e}") logger.error(
f"Error fetching Civitai info for LoRA hash {lora_hash}: {e}"
)
added_loras[lora_hash] = len(result["loras"]) added_loras[lora_hash] = len(result["loras"])
result["loras"].append(lora_entry) result["loras"].append(lora_entry)
# Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc. # Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc.
lora_index = 0 lora_index = 0
while f"Lora_{lora_index} Model hash" in metadata and f"Lora_{lora_index} Model name" in metadata: while (
f"Lora_{lora_index} Model hash" in metadata
and f"Lora_{lora_index} Model name" in metadata
):
lora_hash = metadata[f"Lora_{lora_index} Model hash"] lora_hash = metadata[f"Lora_{lora_index} Model hash"]
lora_name = metadata[f"Lora_{lora_index} Model name"] lora_name = metadata[f"Lora_{lora_index} Model name"]
lora_strength_model = float(metadata.get(f"Lora_{lora_index} Strength model", 1.0)) lora_strength_model = float(
metadata.get(f"Lora_{lora_index} Strength model", 1.0)
)
# Skip if we've already added this LoRA by hash # Skip if we've already added this LoRA by hash
if lora_hash and lora_hash in added_loras: if lora_hash and lora_hash in added_loras:
lora_index += 1 lora_index += 1
continue continue
lora_entry = { lora_entry = {
'name': lora_name, "name": lora_name,
'type': "lora", "type": "lora",
'weight': lora_strength_model, "weight": lora_strength_model,
'hash': lora_hash, "hash": lora_hash,
'existsLocally': False, "existsLocally": False,
'localPath': None, "localPath": None,
'file_name': lora_name, "file_name": lora_name,
'thumbnailUrl': '/loras_static/images/no-preview.png', "thumbnailUrl": "/loras_static/images/no-preview.png",
'baseModel': '', "baseModel": "",
'size': 0, "size": 0,
'downloadUrl': '', "downloadUrl": "",
'isDeleted': False "isDeleted": False,
} }
# Try to get info from Civitai if hash is available # Try to get info from Civitai if hash is available
if lora_entry['hash'] and metadata_provider: if lora_entry["hash"] and metadata_provider:
try: try:
civitai_info = await metadata_provider.get_model_by_hash(lora_hash) civitai_info = await metadata_provider.get_model_by_hash(
lora_hash
)
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
civitai_info, civitai_info,
recipe_scanner, recipe_scanner,
base_model_counts, base_model_counts,
lora_hash lora_hash,
) )
if populated_entry is None: if populated_entry is None:
lora_index += 1 lora_index += 1
continue # Skip invalid LoRA types continue # Skip invalid LoRA types
lora_entry = populated_entry lora_entry = populated_entry
# If we have a version ID from Civitai, track it for deduplication # If we have a version ID from Civitai, track it for deduplication
if 'id' in lora_entry and lora_entry['id']: if "id" in lora_entry and lora_entry["id"]:
added_loras[str(lora_entry['id'])] = len(result["loras"]) added_loras[str(lora_entry["id"])] = len(result["loras"])
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}") logger.error(
f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}"
)
# Track by hash if we have it # Track by hash if we have it
if lora_hash: if lora_hash:
added_loras[lora_hash] = len(result["loras"]) added_loras[lora_hash] = len(result["loras"])
result["loras"].append(lora_entry) result["loras"].append(lora_entry)
lora_index += 1 lora_index += 1
# If base model wasn't found earlier, use the most common one from LoRAs # If base model wasn't found earlier, use the most common one from LoRAs
if not result["base_model"] and base_model_counts: if not result["base_model"] and base_model_counts:
result["base_model"] = max(base_model_counts.items(), key=lambda x: x[1])[0] result["base_model"] = max(
base_model_counts.items(), key=lambda x: x[1]
)[0]
return result return result
except Exception as e: except Exception as e:
logger.error(f"Error parsing Civitai image metadata: {e}", exc_info=True) logger.error(f"Error parsing Civitai image metadata: {e}", exc_info=True)
return {"error": str(e), "loras": []} return {"error": str(e), "loras": []}

View File

@@ -3,36 +3,42 @@ import copy
import logging import logging
import os import os
from typing import Any, Optional, Dict, Tuple, List, Sequence from typing import Any, Optional, Dict, Tuple, List, Sequence
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager from .model_metadata_provider import (
CivitaiModelMetadataProvider,
ModelMetadataProviderManager,
)
from .downloader import get_downloader from .downloader import get_downloader
from .errors import RateLimitError, ResourceNotFoundError from .errors import RateLimitError, ResourceNotFoundError
from ..utils.civitai_utils import resolve_license_payload from ..utils.civitai_utils import resolve_license_payload
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CivitaiClient: class CivitaiClient:
_instance = None _instance = None
_lock = asyncio.Lock() _lock = asyncio.Lock()
@classmethod @classmethod
async def get_instance(cls): async def get_instance(cls):
"""Get singleton instance of CivitaiClient""" """Get singleton instance of CivitaiClient"""
async with cls._lock: async with cls._lock:
if cls._instance is None: if cls._instance is None:
cls._instance = cls() cls._instance = cls()
# Register this client as a metadata provider # Register this client as a metadata provider
provider_manager = await ModelMetadataProviderManager.get_instance() provider_manager = await ModelMetadataProviderManager.get_instance()
provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True) provider_manager.register_provider(
"civitai", CivitaiModelMetadataProvider(cls._instance), True
)
return cls._instance return cls._instance
def __init__(self): def __init__(self):
# Check if already initialized for singleton pattern # Check if already initialized for singleton pattern
if hasattr(self, '_initialized'): if hasattr(self, "_initialized"):
return return
self._initialized = True self._initialized = True
self.base_url = "https://civitai.com/api/v1" self.base_url = "https://civitai.com/api/v1"
async def _make_request( async def _make_request(
@@ -75,8 +81,10 @@ class CivitaiClient:
meta = image.get("meta") meta = image.get("meta")
if isinstance(meta, dict) and "comfy" in meta: if isinstance(meta, dict) and "comfy" in meta:
meta.pop("comfy", None) meta.pop("comfy", None)
async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: async def download_file(
self, url: str, save_dir: str, default_filename: str, progress_callback=None
) -> Tuple[bool, str]:
"""Download file with resumable downloads and retry mechanism """Download file with resumable downloads and retry mechanism
Args: Args:
@@ -90,41 +98,48 @@ class CivitaiClient:
""" """
downloader = await get_downloader() downloader = await get_downloader()
save_path = os.path.join(save_dir, default_filename) save_path = os.path.join(save_dir, default_filename)
# Use unified downloader with CivitAI authentication # Use unified downloader with CivitAI authentication
success, result = await downloader.download_file( success, result = await downloader.download_file(
url=url, url=url,
save_path=save_path, save_path=save_path,
progress_callback=progress_callback, progress_callback=progress_callback,
use_auth=True, # Enable CivitAI authentication use_auth=True, # Enable CivitAI authentication
allow_resume=True allow_resume=True,
) )
return success, result return success, result
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]: async def get_model_by_hash(
self, model_hash: str
) -> Tuple[Optional[Dict], Optional[str]]:
try: try:
success, version = await self._make_request( success, version = await self._make_request(
'GET', "GET",
f"{self.base_url}/model-versions/by-hash/{model_hash}", f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True use_auth=True,
) )
if not success: if not success:
message = str(version) message = str(version)
if "not found" in message.lower(): if "not found" in message.lower():
return None, "Model not found" return None, "Model not found"
logger.error("Failed to fetch model info for %s: %s", model_hash[:10], message) logger.error(
"Failed to fetch model info for %s: %s", model_hash[:10], message
)
return None, message return None, message
model_id = version.get('modelId') if isinstance(version, dict):
if model_id: model_id = version.get("modelId")
model_data = await self._fetch_model_data(model_id) if model_id:
if model_data: model_data = await self._fetch_model_data(model_id)
self._enrich_version_with_model_data(version, model_data) if model_data:
self._enrich_version_with_model_data(version, model_data)
self._remove_comfy_metadata(version) self._remove_comfy_metadata(version)
return version, None return version, None
else:
return None, "Invalid response format"
except RateLimitError: except RateLimitError:
raise raise
except Exception as exc: except Exception as exc:
@@ -136,19 +151,19 @@ class CivitaiClient:
downloader = await get_downloader() downloader = await get_downloader()
success, content, headers = await downloader.download_to_memory( success, content, headers = await downloader.download_to_memory(
image_url, image_url,
use_auth=False # Preview images don't need auth use_auth=False, # Preview images don't need auth
) )
if success: if success:
# Ensure directory exists # Ensure directory exists
os.makedirs(os.path.dirname(save_path), exist_ok=True) os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, 'wb') as f: with open(save_path, "wb") as f:
f.write(content) f.write(content)
return True return True
return False return False
except Exception as e: except Exception as e:
logger.error(f"Download Error: {str(e)}") logger.error(f"Download Error: {str(e)}")
return False return False
@staticmethod @staticmethod
def _extract_error_message(payload: Any) -> str: def _extract_error_message(payload: Any) -> str:
"""Return a human-readable error message from an API payload.""" """Return a human-readable error message from an API payload."""
@@ -175,19 +190,17 @@ class CivitaiClient:
"""Get all versions of a model with local availability info""" """Get all versions of a model with local availability info"""
try: try:
success, result = await self._make_request( success, result = await self._make_request(
'GET', "GET", f"{self.base_url}/models/{model_id}", use_auth=True
f"{self.base_url}/models/{model_id}",
use_auth=True
) )
if success: if success:
# Also return model type along with versions # Also return model type along with versions
return { return {
'modelVersions': result.get('modelVersions', []), "modelVersions": result.get("modelVersions", []),
'type': result.get('type', ''), "type": result.get("type", ""),
'name': result.get('name', '') "name": result.get("name", ""),
} }
message = self._extract_error_message(result) message = self._extract_error_message(result)
if message and 'not found' in message.lower(): if message and "not found" in message.lower():
raise ResourceNotFoundError(f"Resource not found for model {model_id}") raise ResourceNotFoundError(f"Resource not found for model {model_id}")
if message: if message:
raise RuntimeError(message) raise RuntimeError(message)
@@ -221,15 +234,15 @@ class CivitaiClient:
try: try:
query = ",".join(normalized_ids) query = ",".join(normalized_ids)
success, result = await self._make_request( success, result = await self._make_request(
'GET', "GET",
f"{self.base_url}/models", f"{self.base_url}/models",
use_auth=True, use_auth=True,
params={'ids': query}, params={"ids": query},
) )
if not success: if not success:
return None return None
items = result.get('items') if isinstance(result, dict) else None items = result.get("items") if isinstance(result, dict) else None
if not isinstance(items, list): if not isinstance(items, list):
return {} return {}
@@ -237,19 +250,19 @@ class CivitaiClient:
for item in items: for item in items:
if not isinstance(item, dict): if not isinstance(item, dict):
continue continue
model_id = item.get('id') model_id = item.get("id")
try: try:
normalized_id = int(model_id) normalized_id = int(model_id)
except (TypeError, ValueError): except (TypeError, ValueError):
continue continue
payload[normalized_id] = { payload[normalized_id] = {
'modelVersions': item.get('modelVersions', []), "modelVersions": item.get("modelVersions", []),
'type': item.get('type', ''), "type": item.get("type", ""),
'name': item.get('name', ''), "name": item.get("name", ""),
'allowNoCredit': item.get('allowNoCredit'), "allowNoCredit": item.get("allowNoCredit"),
'allowCommercialUse': item.get('allowCommercialUse'), "allowCommercialUse": item.get("allowCommercialUse"),
'allowDerivatives': item.get('allowDerivatives'), "allowDerivatives": item.get("allowDerivatives"),
'allowDifferentLicense': item.get('allowDifferentLicense'), "allowDifferentLicense": item.get("allowDifferentLicense"),
} }
return payload return payload
except RateLimitError: except RateLimitError:
@@ -257,8 +270,10 @@ class CivitaiClient:
except Exception as exc: except Exception as exc:
logger.error(f"Error fetching model versions in bulk: {exc}") logger.error(f"Error fetching model versions in bulk: {exc}")
return None return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: async def get_model_version(
self, model_id: int = None, version_id: int = None
) -> Optional[Dict]:
"""Get specific model version with additional metadata.""" """Get specific model version with additional metadata."""
try: try:
if model_id is None and version_id is not None: if model_id is None and version_id is not None:
@@ -281,7 +296,7 @@ class CivitaiClient:
if version is None: if version is None:
return None return None
model_id = version.get('modelId') model_id = version.get("modelId")
if not model_id: if not model_id:
logger.error(f"No modelId found in version {version_id}") logger.error(f"No modelId found in version {version_id}")
return None return None
@@ -293,7 +308,9 @@ class CivitaiClient:
self._remove_comfy_metadata(version) self._remove_comfy_metadata(version)
return version return version
async def _get_version_with_model_id(self, model_id: int, version_id: Optional[int]) -> Optional[Dict]: async def _get_version_with_model_id(
self, model_id: int, version_id: Optional[int]
) -> Optional[Dict]:
model_data = await self._fetch_model_data(model_id) model_data = await self._fetch_model_data(model_id)
if not model_data: if not model_data:
return None return None
@@ -302,8 +319,12 @@ class CivitaiClient:
if target_version is None: if target_version is None:
return None return None
target_version_id = target_version.get('id') target_version_id = target_version.get("id")
version = await self._fetch_version_by_id(target_version_id) if target_version_id else None version = (
await self._fetch_version_by_id(target_version_id)
if target_version_id
else None
)
if version is None: if version is None:
model_hash = self._extract_primary_model_hash(target_version) model_hash = self._extract_primary_model_hash(target_version)
@@ -315,7 +336,9 @@ class CivitaiClient:
) )
if version is None: if version is None:
version = self._build_version_from_model_data(target_version, model_id, model_data) version = self._build_version_from_model_data(
target_version, model_id, model_data
)
self._enrich_version_with_model_data(version, model_data) self._enrich_version_with_model_data(version, model_data)
self._remove_comfy_metadata(version) self._remove_comfy_metadata(version)
@@ -323,9 +346,7 @@ class CivitaiClient:
async def _fetch_model_data(self, model_id: int) -> Optional[Dict]: async def _fetch_model_data(self, model_id: int) -> Optional[Dict]:
success, data = await self._make_request( success, data = await self._make_request(
'GET', "GET", f"{self.base_url}/models/{model_id}", use_auth=True
f"{self.base_url}/models/{model_id}",
use_auth=True
) )
if success: if success:
return data return data
@@ -337,9 +358,7 @@ class CivitaiClient:
return None return None
success, version = await self._make_request( success, version = await self._make_request(
'GET', "GET", f"{self.base_url}/model-versions/{version_id}", use_auth=True
f"{self.base_url}/model-versions/{version_id}",
use_auth=True
) )
if success: if success:
return version return version
@@ -352,9 +371,7 @@ class CivitaiClient:
return None return None
success, version = await self._make_request( success, version = await self._make_request(
'GET', "GET", f"{self.base_url}/model-versions/by-hash/{model_hash}", use_auth=True
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
) )
if success: if success:
return version return version
@@ -362,16 +379,17 @@ class CivitaiClient:
logger.warning(f"Failed to fetch version by hash {model_hash}") logger.warning(f"Failed to fetch version by hash {model_hash}")
return None return None
def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]: def _select_target_version(
model_versions = model_data.get('modelVersions', []) self, model_data: Dict, model_id: int, version_id: Optional[int]
) -> Optional[Dict]:
model_versions = model_data.get("modelVersions", [])
if not model_versions: if not model_versions:
logger.warning(f"No model versions found for model {model_id}") logger.warning(f"No model versions found for model {model_id}")
return None return None
if version_id is not None: if version_id is not None:
target_version = next( target_version = next(
(item for item in model_versions if item.get('id') == version_id), (item for item in model_versions if item.get("id") == version_id), None
None
) )
if target_version is None: if target_version is None:
logger.warning( logger.warning(
@@ -383,46 +401,50 @@ class CivitaiClient:
return model_versions[0] return model_versions[0]
def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]: def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]:
for file_info in version_entry.get('files', []): for file_info in version_entry.get("files", []):
if file_info.get('type') == 'Model' and file_info.get('primary'): if file_info.get("type") == "Model" and file_info.get("primary"):
hashes = file_info.get('hashes', {}) hashes = file_info.get("hashes", {})
model_hash = hashes.get('SHA256') model_hash = hashes.get("SHA256")
if model_hash: if model_hash:
return model_hash return model_hash
return None return None
def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict: def _build_version_from_model_data(
self, version_entry: Dict, model_id: int, model_data: Dict
) -> Dict:
version = copy.deepcopy(version_entry) version = copy.deepcopy(version_entry)
version.pop('index', None) version.pop("index", None)
version['modelId'] = model_id version["modelId"] = model_id
version['model'] = { version["model"] = {
'name': model_data.get('name'), "name": model_data.get("name"),
'type': model_data.get('type'), "type": model_data.get("type"),
'nsfw': model_data.get('nsfw'), "nsfw": model_data.get("nsfw"),
'poi': model_data.get('poi') "poi": model_data.get("poi"),
} }
return version return version
def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None: def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None:
model_info = version.get('model') model_info = version.get("model")
if not isinstance(model_info, dict): if not isinstance(model_info, dict):
model_info = {} model_info = {}
version['model'] = model_info version["model"] = model_info
model_info['description'] = model_data.get("description") model_info["description"] = model_data.get("description")
model_info['tags'] = model_data.get("tags", []) model_info["tags"] = model_data.get("tags", [])
version['creator'] = model_data.get("creator") version["creator"] = model_data.get("creator")
license_payload = resolve_license_payload(model_data) license_payload = resolve_license_payload(model_data)
for field, value in license_payload.items(): for field, value in license_payload.items():
model_info[field] = value model_info[field] = value
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: async def get_model_version_info(
self, version_id: str
) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from Civitai """Fetch model version metadata from Civitai
Args: Args:
version_id: The Civitai model version ID version_id: The Civitai model version ID
Returns: Returns:
Tuple[Optional[Dict], Optional[str]]: A tuple containing: Tuple[Optional[Dict], Optional[str]]: A tuple containing:
- The model version data or None if not found - The model version data or None if not found
@@ -430,25 +452,23 @@ class CivitaiClient:
""" """
try: try:
url = f"{self.base_url}/model-versions/{version_id}" url = f"{self.base_url}/model-versions/{version_id}"
logger.debug(f"Resolving DNS for model version info: {url}") logger.debug(f"Resolving DNS for model version info: {url}")
success, result = await self._make_request( success, result = await self._make_request("GET", url, use_auth=True)
'GET',
url,
use_auth=True
)
if success: if success:
logger.debug(f"Successfully fetched model version info for: {version_id}") logger.debug(
f"Successfully fetched model version info for: {version_id}"
)
self._remove_comfy_metadata(result) self._remove_comfy_metadata(result)
return result, None return result, None
# Handle specific error cases # Handle specific error cases
if "not found" in str(result): if "not found" in str(result):
error_msg = f"Model not found" error_msg = f"Model not found"
logger.warning(f"Model version not found: {version_id} - {error_msg}") logger.warning(f"Model version not found: {version_id} - {error_msg}")
return None, error_msg return None, error_msg
# Other error cases # Other error cases
logger.error(f"Failed to fetch model info for {version_id}: {result}") logger.error(f"Failed to fetch model info for {version_id}: {result}")
return None, str(result) return None, str(result)
@@ -464,27 +484,23 @@ class CivitaiClient:
Args: Args:
image_id: The Civitai image ID image_id: The Civitai image ID
Returns: Returns:
Optional[Dict]: The image data or None if not found Optional[Dict]: The image data or None if not found
""" """
try: try:
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X" url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
logger.debug(f"Fetching image info for ID: {image_id}") logger.debug(f"Fetching image info for ID: {image_id}")
success, result = await self._make_request( success, result = await self._make_request("GET", url, use_auth=True)
'GET',
url,
use_auth=True
)
if success: if success:
if result and "items" in result and len(result["items"]) > 0: if result and "items" in result and len(result["items"]) > 0:
logger.debug(f"Successfully fetched image info for ID: {image_id}") logger.debug(f"Successfully fetched image info for ID: {image_id}")
return result["items"][0] return result["items"][0]
logger.warning(f"No image found with ID: {image_id}") logger.warning(f"No image found with ID: {image_id}")
return None return None
logger.error(f"Failed to fetch image info for ID: {image_id}: {result}") logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
return None return None
except RateLimitError: except RateLimitError:
@@ -501,11 +517,7 @@ class CivitaiClient:
try: try:
url = f"{self.base_url}/models?username={username}" url = f"{self.base_url}/models?username={username}"
success, result = await self._make_request( success, result = await self._make_request("GET", url, use_auth=True)
'GET',
url,
use_auth=True
)
if not success: if not success:
logger.error("Failed to fetch models for %s: %s", username, result) logger.error("Failed to fetch models for %s: %s", username, result)