Fix null-safety issues and apply code formatting

Bug fixes:
- Add null guards for base_models_roots/embeddings_roots in backup cleanup
- Fix null-safety initialization of extra_unet_roots

Formatting:
- Apply consistent code style across Python files
- Fix line wrapping, quote consistency, and trailing commas
- Add type ignore comments for dynamic/platform-specific code
This commit is contained in:
Will Miao
2026-02-28 21:38:41 +08:00
parent b005961ee5
commit c9e5ea42cb
9 changed files with 1097 additions and 760 deletions

View File

@@ -2,7 +2,7 @@ import os
import platform
import threading
from pathlib import Path
import folder_paths # type: ignore
import folder_paths # type: ignore
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple
import logging
import json
@@ -10,16 +10,23 @@ import urllib.parse
import time
from .utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths
from .utils.settings_paths import ensure_settings_file, get_settings_dir, load_settings_template
from .utils.settings_paths import (
ensure_settings_file,
get_settings_dir,
load_settings_template,
)
# Use an environment variable to control standalone mode
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
standalone_mode = (
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
)
logger = logging.getLogger(__name__)
def _normalize_folder_paths_for_comparison(
folder_paths: Mapping[str, Iterable[str]]
folder_paths: Mapping[str, Iterable[str]],
) -> Dict[str, Set[str]]:
"""Normalize folder paths for comparison across libraries."""
@@ -49,7 +56,7 @@ def _normalize_folder_paths_for_comparison(
def _normalize_library_folder_paths(
library_payload: Mapping[str, Any]
library_payload: Mapping[str, Any],
) -> Dict[str, Set[str]]:
"""Return normalized folder paths extracted from a library payload."""
@@ -74,11 +81,17 @@ def _get_template_folder_paths() -> Dict[str, Set[str]]:
class Config:
"""Global configuration for LoRA Manager"""
def __init__(self):
self.templates_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'templates')
self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static')
self.i18n_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'locales')
self.templates_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "templates"
)
self.static_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "static"
)
self.i18n_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "locales"
)
# Path mapping dictionary, target to link mapping
self._path_mappings: Dict[str, str] = {}
# Normalized preview root directories used to validate preview access
@@ -98,7 +111,7 @@ class Config:
self.extra_embeddings_roots: List[str] = []
# Scan symbolic links during initialization
self._initialize_symlink_mappings()
if not standalone_mode:
# Save the paths to settings.json when running in ComfyUI mode
self.save_folder_paths_to_settings()
@@ -152,17 +165,21 @@ class Config:
default_library = libraries.get("default", {})
target_folder_paths = {
'loras': list(self.loras_roots),
'checkpoints': list(self.checkpoints_roots or []),
'unet': list(self.unet_roots or []),
'embeddings': list(self.embeddings_roots or []),
"loras": list(self.loras_roots),
"checkpoints": list(self.checkpoints_roots or []),
"unet": list(self.unet_roots or []),
"embeddings": list(self.embeddings_roots or []),
}
normalized_target_paths = _normalize_folder_paths_for_comparison(target_folder_paths)
normalized_target_paths = _normalize_folder_paths_for_comparison(
target_folder_paths
)
normalized_default_paths: Optional[Dict[str, Set[str]]] = None
if isinstance(default_library, Mapping):
normalized_default_paths = _normalize_library_folder_paths(default_library)
normalized_default_paths = _normalize_library_folder_paths(
default_library
)
if (
not comfy_library
@@ -185,13 +202,19 @@ class Config:
default_lora_root = self.loras_roots[0]
default_checkpoint_root = comfy_library.get("default_checkpoint_root", "")
if (not default_checkpoint_root and self.checkpoints_roots and
len(self.checkpoints_roots) == 1):
if (
not default_checkpoint_root
and self.checkpoints_roots
and len(self.checkpoints_roots) == 1
):
default_checkpoint_root = self.checkpoints_roots[0]
default_embedding_root = comfy_library.get("default_embedding_root", "")
if (not default_embedding_root and self.embeddings_roots and
len(self.embeddings_roots) == 1):
if (
not default_embedding_root
and self.embeddings_roots
and len(self.embeddings_roots) == 1
):
default_embedding_root = self.embeddings_roots[0]
metadata = dict(comfy_library.get("metadata", {}))
@@ -216,11 +239,12 @@ class Config:
try:
if os.path.islink(path):
return True
if platform.system() == 'Windows':
if platform.system() == "Windows":
try:
import ctypes
FILE_ATTRIBUTE_REPARSE_POINT = 0x400
attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path))
attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path)) # type: ignore[attr-defined]
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
except Exception as e:
logger.error(f"Error checking Windows reparse point: {e}")
@@ -233,18 +257,19 @@ class Config:
"""Check if a directory entry is a symlink, including Windows junctions."""
if entry.is_symlink():
return True
if platform.system() == 'Windows':
if platform.system() == "Windows":
try:
import ctypes
FILE_ATTRIBUTE_REPARSE_POINT = 0x400
attrs = ctypes.windll.kernel32.GetFileAttributesW(entry.path)
attrs = ctypes.windll.kernel32.GetFileAttributesW(entry.path) # type: ignore[attr-defined]
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
except Exception:
pass
return False
def _normalize_path(self, path: str) -> str:
return os.path.normpath(path).replace(os.sep, '/')
return os.path.normpath(path).replace(os.sep, "/")
def _get_symlink_cache_path(self) -> Path:
canonical_path = get_cache_file_path(CacheType.SYMLINK, create_dir=True)
@@ -278,19 +303,18 @@ class Config:
if self._entry_is_symlink(entry):
try:
target = os.path.realpath(entry.path)
direct_symlinks.append([
self._normalize_path(entry.path),
self._normalize_path(target)
])
direct_symlinks.append(
[
self._normalize_path(entry.path),
self._normalize_path(target),
]
)
except OSError:
pass
except (OSError, PermissionError):
pass
return {
"roots": unique_roots,
"direct_symlinks": sorted(direct_symlinks)
}
return {"roots": unique_roots, "direct_symlinks": sorted(direct_symlinks)}
def _initialize_symlink_mappings(self) -> None:
start = time.perf_counter()
@@ -307,10 +331,14 @@ class Config:
cached_fingerprint = self._cached_fingerprint
# Check 1: First-level symlinks unchanged (catches new symlinks at root)
fingerprint_valid = cached_fingerprint and current_fingerprint == cached_fingerprint
fingerprint_valid = (
cached_fingerprint and current_fingerprint == cached_fingerprint
)
# Check 2: All cached mappings still valid (catches changes at any depth)
mappings_valid = self._validate_cached_mappings() if fingerprint_valid else False
mappings_valid = (
self._validate_cached_mappings() if fingerprint_valid else False
)
if fingerprint_valid and mappings_valid:
return
@@ -370,7 +398,9 @@ class Config:
for target, link in cached_mappings.items():
if not isinstance(target, str) or not isinstance(link, str):
continue
normalized_mappings[self._normalize_path(target)] = self._normalize_path(link)
normalized_mappings[self._normalize_path(target)] = self._normalize_path(
link
)
self._path_mappings = normalized_mappings
@@ -391,7 +421,9 @@ class Config:
parent_dir = loaded_path.parent
if parent_dir.name == "cache" and not any(parent_dir.iterdir()):
parent_dir.rmdir()
logger.info("Removed empty legacy cache directory: %s", parent_dir)
logger.info(
"Removed empty legacy cache directory: %s", parent_dir
)
except Exception:
pass
@@ -402,7 +434,9 @@ class Config:
exc,
)
else:
logger.info("Symlink cache loaded with %d mappings", len(self._path_mappings))
logger.info(
"Symlink cache loaded with %d mappings", len(self._path_mappings)
)
return True
@@ -414,7 +448,7 @@ class Config:
"""
for target, link in self._path_mappings.items():
# Convert normalized paths back to OS paths
link_path = link.replace('/', os.sep)
link_path = link.replace("/", os.sep)
# Check if symlink still exists
if not self._is_link(link_path):
@@ -427,7 +461,9 @@ class Config:
if actual_target != target:
logger.debug(
"Symlink target changed: %s -> %s (cached: %s)",
link_path, actual_target, target
link_path,
actual_target,
target,
)
return False
except OSError:
@@ -446,7 +482,11 @@ class Config:
try:
with cache_path.open("w", encoding="utf-8") as handle:
json.dump(payload, handle, ensure_ascii=False, indent=2)
logger.debug("Symlink cache saved to %s with %d mappings", cache_path, len(self._path_mappings))
logger.debug(
"Symlink cache saved to %s with %d mappings",
cache_path,
len(self._path_mappings),
)
except Exception as exc:
logger.info("Failed to write symlink cache %s: %s", cache_path, exc)
@@ -458,7 +498,7 @@ class Config:
at the root level only (not nested symlinks in subdirectories).
"""
start = time.perf_counter()
# Reset mappings before rescanning to avoid stale entries
self._path_mappings.clear()
self._seed_root_symlink_mappings()
@@ -472,7 +512,7 @@ class Config:
def _scan_first_level_symlinks(self, root: str):
"""Scan only the first level of a directory for symlinks.
This avoids traversing the entire directory tree which can be extremely
slow for large model collections. Only symlinks directly under the root
are detected.
@@ -494,13 +534,13 @@ class Config:
self.add_path_mapping(entry.path, target_path)
except Exception as inner_exc:
logger.debug(
"Error processing directory entry %s: %s", entry.path, inner_exc
"Error processing directory entry %s: %s",
entry.path,
inner_exc,
)
except Exception as e:
logger.error(f"Error scanning links in {root}: {e}")
def add_path_mapping(self, link_path: str, target_path: str):
"""Add a symbolic link path mapping
target_path: actual target path
@@ -594,41 +634,46 @@ class Config:
preview_roots.update(self._expand_preview_root(target))
preview_roots.update(self._expand_preview_root(link))
self._preview_root_paths = {path for path in preview_roots if path.is_absolute()}
self._preview_root_paths = {
path for path in preview_roots if path.is_absolute()
}
logger.debug(
"Preview roots rebuilt: %d paths from %d lora roots (%d extra), %d checkpoint roots (%d extra), %d embedding roots (%d extra), %d symlink mappings",
len(self._preview_root_paths),
len(self.loras_roots or []), len(self.extra_loras_roots or []),
len(self.base_models_roots or []), len(self.extra_checkpoints_roots or []),
len(self.embeddings_roots or []), len(self.extra_embeddings_roots or []),
len(self.loras_roots or []),
len(self.extra_loras_roots or []),
len(self.base_models_roots or []),
len(self.extra_checkpoints_roots or []),
len(self.embeddings_roots or []),
len(self.extra_embeddings_roots or []),
len(self._path_mappings),
)
def map_path_to_link(self, path: str) -> str:
"""Map a target path back to its symbolic link path"""
normalized_path = os.path.normpath(path).replace(os.sep, '/')
normalized_path = os.path.normpath(path).replace(os.sep, "/")
# Check if the path is contained in any mapped target path
for target_path, link_path in self._path_mappings.items():
# Match whole path components to avoid prefix collisions (e.g., /a/b vs /a/bc)
if normalized_path == target_path:
return link_path
if normalized_path.startswith(target_path + '/'):
if normalized_path.startswith(target_path + "/"):
# If the path starts with the target path, replace with link path
mapped_path = normalized_path.replace(target_path, link_path, 1)
return mapped_path
return normalized_path
def map_link_to_path(self, link_path: str) -> str:
"""Map a symbolic link path back to the actual path"""
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
normalized_link = os.path.normpath(link_path).replace(os.sep, "/")
# Check if the path is contained in any mapped target path
for target_path, link_path_mapped in self._path_mappings.items():
# Match whole path components
if normalized_link == link_path_mapped:
return target_path
if normalized_link.startswith(link_path_mapped + '/'):
if normalized_link.startswith(link_path_mapped + "/"):
# If the path starts with the link path, replace with actual path
mapped_path = normalized_link.replace(link_path_mapped, target_path, 1)
return mapped_path
@@ -641,8 +686,8 @@ class Config:
continue
if not os.path.exists(path):
continue
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
normalized = os.path.normpath(path).replace(os.sep, '/')
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, "/")
normalized = os.path.normpath(path).replace(os.sep, "/")
if real_path not in dedup:
dedup[real_path] = normalized
return dedup
@@ -652,7 +697,9 @@ class Config:
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
real_path = os.path.normpath(os.path.realpath(original_path)).replace(
os.sep, "/"
)
if real_path != original_path:
self.add_path_mapping(original_path, real_path)
@@ -674,7 +721,7 @@ class Config:
"Please fix your ComfyUI path configuration to separate these folders. "
"Falling back to 'checkpoints' for backward compatibility. "
"Overlapping real paths: %s",
[checkpoint_map.get(rp, rp) for rp in overlapping_real_paths]
[checkpoint_map.get(rp, rp) for rp in overlapping_real_paths],
)
# Remove overlapping paths from unet_map to prioritize checkpoints
for rp in overlapping_real_paths:
@@ -694,7 +741,9 @@ class Config:
self.unet_roots = [p for p in unique_paths if p in unet_values]
for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
real_path = os.path.normpath(os.path.realpath(original_path)).replace(
os.sep, "/"
)
if real_path != original_path:
self.add_path_mapping(original_path, real_path)
@@ -705,7 +754,9 @@ class Config:
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
real_path = os.path.normpath(os.path.realpath(original_path)).replace(
os.sep, "/"
)
if real_path != original_path:
self.add_path_mapping(original_path, real_path)
@@ -719,42 +770,66 @@ class Config:
self._path_mappings.clear()
self._preview_root_paths = set()
lora_paths = folder_paths.get('loras', []) or []
checkpoint_paths = folder_paths.get('checkpoints', []) or []
unet_paths = folder_paths.get('unet', []) or []
embedding_paths = folder_paths.get('embeddings', []) or []
lora_paths = folder_paths.get("loras", []) or []
checkpoint_paths = folder_paths.get("checkpoints", []) or []
unet_paths = folder_paths.get("unet", []) or []
embedding_paths = folder_paths.get("embeddings", []) or []
self.loras_roots = self._prepare_lora_paths(lora_paths)
self.base_models_roots = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths)
self.base_models_roots = self._prepare_checkpoint_paths(
checkpoint_paths, unet_paths
)
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
# Process extra paths (only for LoRA Manager, not shared with ComfyUI)
extra_paths = extra_folder_paths or {}
extra_lora_paths = extra_paths.get('loras', []) or []
extra_checkpoint_paths = extra_paths.get('checkpoints', []) or []
extra_unet_paths = extra_paths.get('unet', []) or []
extra_embedding_paths = extra_paths.get('embeddings', []) or []
extra_lora_paths = extra_paths.get("loras", []) or []
extra_checkpoint_paths = extra_paths.get("checkpoints", []) or []
extra_unet_paths = extra_paths.get("unet", []) or []
extra_embedding_paths = extra_paths.get("embeddings", []) or []
self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths)
# Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them)
saved_checkpoints_roots = self.checkpoints_roots
saved_unet_roots = self.unet_roots
self.extra_checkpoints_roots = self._prepare_checkpoint_paths(extra_checkpoint_paths, extra_unet_paths)
self.extra_unet_roots = self.unet_roots # unet_roots was set by _prepare_checkpoint_paths
self.extra_checkpoints_roots = self._prepare_checkpoint_paths(
extra_checkpoint_paths, extra_unet_paths
)
self.extra_unet_roots = (
self.unet_roots if self.unet_roots is not None else []
) # unet_roots was set by _prepare_checkpoint_paths
# Restore main paths
self.checkpoints_roots = saved_checkpoints_roots
self.unet_roots = saved_unet_roots
self.extra_embeddings_roots = self._prepare_embedding_paths(extra_embedding_paths)
self.extra_embeddings_roots = self._prepare_embedding_paths(
extra_embedding_paths
)
# Log extra folder paths
if self.extra_loras_roots:
logger.info("Found extra LoRA roots:" + "\n - " + "\n - ".join(self.extra_loras_roots))
logger.info(
"Found extra LoRA roots:"
+ "\n - "
+ "\n - ".join(self.extra_loras_roots)
)
if self.extra_checkpoints_roots:
logger.info("Found extra checkpoint roots:" + "\n - " + "\n - ".join(self.extra_checkpoints_roots))
logger.info(
"Found extra checkpoint roots:"
+ "\n - "
+ "\n - ".join(self.extra_checkpoints_roots)
)
if self.extra_unet_roots:
logger.info("Found extra diffusion model roots:" + "\n - " + "\n - ".join(self.extra_unet_roots))
logger.info(
"Found extra diffusion model roots:"
+ "\n - "
+ "\n - ".join(self.extra_unet_roots)
)
if self.extra_embeddings_roots:
logger.info("Found extra embedding roots:" + "\n - " + "\n - ".join(self.extra_embeddings_roots))
logger.info(
"Found extra embedding roots:"
+ "\n - "
+ "\n - ".join(self.extra_embeddings_roots)
)
self._initialize_symlink_mappings()
@@ -763,7 +838,10 @@ class Config:
try:
raw_paths = folder_paths.get_folder_paths("loras")
unique_paths = self._prepare_lora_paths(raw_paths)
logger.info("Found LoRA roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
logger.info(
"Found LoRA roots:"
+ ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")
)
if not unique_paths:
logger.warning("No valid loras folders found in ComfyUI configuration")
@@ -779,12 +857,19 @@ class Config:
try:
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
raw_unet_paths = folder_paths.get_folder_paths("unet")
unique_paths = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths)
unique_paths = self._prepare_checkpoint_paths(
raw_checkpoint_paths, raw_unet_paths
)
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
logger.info(
"Found checkpoint roots:"
+ ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")
)
if not unique_paths:
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
logger.warning(
"No valid checkpoint folders found in ComfyUI configuration"
)
return []
return unique_paths
@@ -797,10 +882,15 @@ class Config:
try:
raw_paths = folder_paths.get_folder_paths("embeddings")
unique_paths = self._prepare_embedding_paths(raw_paths)
logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
logger.info(
"Found embedding roots:"
+ ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")
)
if not unique_paths:
logger.warning("No valid embeddings folders found in ComfyUI configuration")
logger.warning(
"No valid embeddings folders found in ComfyUI configuration"
)
return []
return unique_paths
@@ -812,13 +902,13 @@ class Config:
if not preview_path:
return ""
normalized = os.path.normpath(preview_path).replace(os.sep, '/')
encoded_path = urllib.parse.quote(normalized, safe='')
return f'/api/lm/previews?path={encoded_path}'
normalized = os.path.normpath(preview_path).replace(os.sep, "/")
encoded_path = urllib.parse.quote(normalized, safe="")
return f"/api/lm/previews?path={encoded_path}"
def is_preview_path_allowed(self, preview_path: str) -> bool:
"""Return ``True`` if ``preview_path`` is within an allowed directory.
If the path is initially rejected, attempts to discover deep symlinks
that were not scanned during initialization. If a symlink is found,
updates the in-memory path mappings and retries the check.
@@ -889,14 +979,18 @@ class Config:
normalized_link = self._normalize_path(str(current))
self._path_mappings[normalized_target] = normalized_link
self._preview_root_paths.update(self._expand_preview_root(normalized_target))
self._preview_root_paths.update(self._expand_preview_root(normalized_link))
self._preview_root_paths.update(
self._expand_preview_root(normalized_target)
)
self._preview_root_paths.update(
self._expand_preview_root(normalized_link)
)
logger.debug(
"Discovered deep symlink: %s -> %s (preview path: %s)",
normalized_link,
normalized_target,
preview_path
preview_path,
)
return True
@@ -914,8 +1008,16 @@ class Config:
def apply_library_settings(self, library_config: Mapping[str, object]) -> None:
"""Update runtime paths to match the provided library configuration."""
folder_paths = library_config.get('folder_paths') if isinstance(library_config, Mapping) else {}
extra_folder_paths = library_config.get('extra_folder_paths') if isinstance(library_config, Mapping) else None
folder_paths = (
library_config.get("folder_paths")
if isinstance(library_config, Mapping)
else {}
)
extra_folder_paths = (
library_config.get("extra_folder_paths")
if isinstance(library_config, Mapping)
else None
)
if not isinstance(folder_paths, Mapping):
folder_paths = {}
if not isinstance(extra_folder_paths, Mapping):
@@ -925,9 +1027,12 @@ class Config:
logger.info(
"Applied library settings with %d lora roots (%d extra), %d checkpoint roots (%d extra), and %d embedding roots (%d extra)",
len(self.loras_roots or []), len(self.extra_loras_roots or []),
len(self.base_models_roots or []), len(self.extra_checkpoints_roots or []),
len(self.embeddings_roots or []), len(self.extra_embeddings_roots or []),
len(self.loras_roots or []),
len(self.extra_loras_roots or []),
len(self.base_models_roots or []),
len(self.extra_checkpoints_roots or []),
len(self.embeddings_roots or []),
len(self.extra_embeddings_roots or []),
)
def get_library_registry_snapshot(self) -> Dict[str, object]:
@@ -947,5 +1052,6 @@ class Config:
logger.debug("Failed to collect library registry snapshot: %s", exc)
return {"active_library": "", "libraries": {}}
# Global config instance
config = Config()

View File

@@ -5,16 +5,22 @@ import logging
from .utils.logging_config import setup_logging
# Check if we're in standalone mode
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
standalone_mode = (
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
)
# Only setup logging prefix if not in standalone mode
if not standalone_mode:
setup_logging()
from server import PromptServer # type: ignore
from server import PromptServer # type: ignore
from .config import config
from .services.model_service_factory import ModelServiceFactory, register_default_model_types
from .services.model_service_factory import (
ModelServiceFactory,
register_default_model_types,
)
from .routes.recipe_routes import RecipeRoutes
from .routes.stats_routes import StatsRoutes
from .routes.update_routes import UpdateRoutes
@@ -61,9 +67,10 @@ class _SettingsProxy:
settings = _SettingsProxy()
class LoraManager:
"""Main entry point for LoRA Manager plugin"""
@classmethod
def add_routes(cls):
"""Initialize and register all routes using the new refactored architecture"""
@@ -76,7 +83,8 @@ class LoraManager:
(
idx
for idx, middleware in enumerate(app.middlewares)
if getattr(middleware, "__name__", "") == "block_external_middleware"
if getattr(middleware, "__name__", "")
== "block_external_middleware"
),
None,
)
@@ -84,7 +92,9 @@ class LoraManager:
if block_middleware_index is None:
app.middlewares.append(relax_csp_for_remote_media)
else:
app.middlewares.insert(block_middleware_index, relax_csp_for_remote_media)
app.middlewares.insert(
block_middleware_index, relax_csp_for_remote_media
)
# Increase allowed header sizes so browsers with large localhost cookie
# jars (multiple UIs on 127.0.0.1) don't trip aiohttp's 8KB default
@@ -105,7 +115,7 @@ class LoraManager:
app._handler_args = updated_handler_args
# Configure aiohttp access logger to be less verbose
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
logging.getLogger("aiohttp.access").setLevel(logging.WARNING)
# Add specific suppression for connection reset errors
class ConnectionResetFilter(logging.Filter):
@@ -124,46 +134,52 @@ class LoraManager:
asyncio_logger.addFilter(ConnectionResetFilter())
# Add static route for example images if the path exists in settings
example_images_path = settings.get('example_images_path')
example_images_path = settings.get("example_images_path")
logger.info(f"Example images path: {example_images_path}")
if example_images_path and os.path.exists(example_images_path):
app.router.add_static('/example_images_static', example_images_path)
logger.info(f"Added static route for example images: /example_images_static -> {example_images_path}")
app.router.add_static("/example_images_static", example_images_path)
logger.info(
f"Added static route for example images: /example_images_static -> {example_images_path}"
)
# Add static route for locales JSON files
if os.path.exists(config.i18n_path):
app.router.add_static('/locales', config.i18n_path)
logger.info(f"Added static route for locales: /locales -> {config.i18n_path}")
app.router.add_static("/locales", config.i18n_path)
logger.info(
f"Added static route for locales: /locales -> {config.i18n_path}"
)
# Add static route for plugin assets
app.router.add_static('/loras_static', config.static_path)
app.router.add_static("/loras_static", config.static_path)
# Register default model types with the factory
register_default_model_types()
# Setup all model routes using the factory
ModelServiceFactory.setup_all_routes(app)
# Setup non-model-specific routes
stats_routes = StatsRoutes()
stats_routes.setup_routes(app)
RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
MiscRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
PreviewRoutes.setup_routes(app)
# Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
# Schedule service initialization
app.router.add_get("/ws/fetch-progress", ws_manager.handle_connection)
app.router.add_get(
"/ws/download-progress", ws_manager.handle_download_connection
)
app.router.add_get("/ws/init-progress", ws_manager.handle_init_connection)
# Schedule service initialization
app.on_startup.append(lambda app: cls._initialize_services())
# Add cleanup
app.on_shutdown.append(cls._cleanup)
@classmethod
async def _initialize_services(cls):
"""Initialize all services using the ServiceRegistry"""
@@ -197,7 +213,9 @@ class LoraManager:
extra_paths.get("embeddings", []),
)
except Exception as exc:
logger.warning("Failed to apply library settings during initialization: %s", exc)
logger.warning(
"Failed to apply library settings during initialization: %s", exc
)
# Initialize CivitaiClient first to ensure it's ready for other services
await ServiceRegistry.get_civitai_client()
@@ -206,163 +224,200 @@ class LoraManager:
await ServiceRegistry.get_download_manager()
from .services.metadata_service import initialize_metadata_providers
await initialize_metadata_providers()
# Initialize WebSocket manager
await ServiceRegistry.get_websocket_manager()
# Initialize scanners in background
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
# Initialize recipe scanner if needed
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
# Create low-priority initialization tasks
init_tasks = [
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init'),
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init'),
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init'),
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
asyncio.create_task(
lora_scanner.initialize_in_background(), name="lora_cache_init"
),
asyncio.create_task(
checkpoint_scanner.initialize_in_background(),
name="checkpoint_cache_init",
),
asyncio.create_task(
embedding_scanner.initialize_in_background(),
name="embedding_cache_init",
),
asyncio.create_task(
recipe_scanner.initialize_in_background(), name="recipe_cache_init"
),
]
await ExampleImagesMigration.check_and_run_migrations()
# Schedule post-initialization tasks to run after scanners complete
asyncio.create_task(
cls._run_post_initialization_tasks(init_tasks),
name='post_init_tasks'
cls._run_post_initialization_tasks(init_tasks), name="post_init_tasks"
)
logger.debug("LoRA Manager: All services initialized and background tasks scheduled")
logger.debug(
"LoRA Manager: All services initialized and background tasks scheduled"
)
except Exception as e:
logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True)
logger.error(
f"LoRA Manager: Error initializing services: {e}", exc_info=True
)
@classmethod
async def _run_post_initialization_tasks(cls, init_tasks):
"""Run post-initialization tasks after all scanners complete"""
try:
logger.debug("LoRA Manager: Waiting for scanner initialization to complete...")
logger.debug(
"LoRA Manager: Waiting for scanner initialization to complete..."
)
# Wait for all scanner initialization tasks to complete
await asyncio.gather(*init_tasks, return_exceptions=True)
logger.debug("LoRA Manager: Scanner initialization completed, starting post-initialization tasks...")
logger.debug(
"LoRA Manager: Scanner initialization completed, starting post-initialization tasks..."
)
# Run post-initialization tasks
post_tasks = [
asyncio.create_task(cls._cleanup_backup_files(), name='cleanup_bak_files'),
asyncio.create_task(
cls._cleanup_backup_files(), name="cleanup_bak_files"
),
# Add more post-initialization tasks here as needed
# asyncio.create_task(cls._another_post_task(), name='another_task'),
]
# Run all post-initialization tasks
results = await asyncio.gather(*post_tasks, return_exceptions=True)
# Log results
for i, result in enumerate(results):
task_name = post_tasks[i].get_name()
if isinstance(result, Exception):
logger.error(f"Post-initialization task '{task_name}' failed: {result}")
logger.error(
f"Post-initialization task '{task_name}' failed: {result}"
)
else:
logger.debug(f"Post-initialization task '{task_name}' completed successfully")
logger.debug(
f"Post-initialization task '{task_name}' completed successfully"
)
logger.debug("LoRA Manager: All post-initialization tasks completed")
except Exception as e:
logger.error(f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True)
logger.error(
f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True
)
@classmethod
async def _cleanup_backup_files(cls):
"""Clean up .bak files in all model roots"""
try:
logger.debug("Starting cleanup of .bak files in model directories...")
# Collect all model roots
all_roots = set()
all_roots.update(config.loras_roots)
all_roots.update(config.base_models_roots)
all_roots.update(config.embeddings_roots)
all_roots.update(config.base_models_roots or [])
all_roots.update(config.embeddings_roots or [])
total_deleted = 0
total_size_freed = 0
for root_path in all_roots:
if not os.path.exists(root_path):
continue
try:
deleted_count, size_freed = await cls._cleanup_backup_files_in_directory(root_path)
(
deleted_count,
size_freed,
) = await cls._cleanup_backup_files_in_directory(root_path)
total_deleted += deleted_count
total_size_freed += size_freed
if deleted_count > 0:
logger.debug(f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024*1024):.2f} MB)")
logger.debug(
f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024 * 1024):.2f} MB)"
)
except Exception as e:
logger.error(f"Error cleaning up .bak files in {root_path}: {e}")
# Yield control periodically
await asyncio.sleep(0.01)
if total_deleted > 0:
logger.debug(f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024*1024):.2f} MB total")
logger.debug(
f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024 * 1024):.2f} MB total"
)
else:
logger.debug("Backup cleanup completed: no .bak files found")
except Exception as e:
logger.error(f"Error during backup file cleanup: {e}", exc_info=True)
@classmethod
async def _cleanup_backup_files_in_directory(cls, directory_path: str):
"""Clean up .bak files in a specific directory recursively
Args:
directory_path: Path to the directory to clean
Returns:
Tuple[int, int]: (number of files deleted, total size freed in bytes)
"""
deleted_count = 0
size_freed = 0
visited_paths = set()
def cleanup_recursive(path):
nonlocal deleted_count, size_freed
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
return
visited_paths.add(real_path)
with os.scandir(path) as it:
for entry in it:
try:
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.bak'):
if entry.is_file(
follow_symlinks=True
) and entry.name.endswith(".bak"):
file_size = entry.stat().st_size
os.remove(entry.path)
deleted_count += 1
size_freed += file_size
logger.debug(f"Deleted .bak file: {entry.path}")
elif entry.is_dir(follow_symlinks=True):
cleanup_recursive(entry.path)
except Exception as e:
logger.warning(f"Could not delete .bak file {entry.path}: {e}")
logger.warning(
f"Could not delete .bak file {entry.path}: {e}"
)
except Exception as e:
logger.error(f"Error scanning directory {path} for .bak files: {e}")
# Run the recursive cleanup in a thread pool to avoid blocking
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, cleanup_recursive, directory_path)
return deleted_count, size_freed
@classmethod
async def _cleanup_example_images_folders(cls):
"""Invoke the example images cleanup service for manual execution."""
@@ -370,21 +425,21 @@ class LoraManager:
service = ExampleImagesCleanupService()
result = await service.cleanup_example_image_folders()
if result.get('success'):
if result.get("success"):
logger.debug(
"Manual example images cleanup completed: moved=%s",
result.get('moved_total'),
result.get("moved_total"),
)
elif result.get('partial_success'):
elif result.get("partial_success"):
logger.warning(
"Manual example images cleanup partially succeeded: moved=%s failures=%s",
result.get('moved_total'),
result.get('move_failures'),
result.get("moved_total"),
result.get("move_failures"),
)
else:
logger.debug(
"Manual example images cleanup skipped or failed: %s",
result.get('error', 'no changes'),
result.get("error", "no changes"),
)
return result
@@ -392,9 +447,9 @@ class LoraManager:
except Exception as e: # pragma: no cover - defensive guard
logger.error(f"Error during example images cleanup: {e}", exc_info=True)
return {
'success': False,
'error': str(e),
'error_code': 'unexpected_error',
"success": False,
"error": str(e),
"error_code": "unexpected_error",
}
@classmethod
@@ -402,6 +457,6 @@ class LoraManager:
"""Cleanup resources using ServiceRegistry"""
try:
logger.info("LoRA Manager: Cleaning up services")
except Exception as e:
logger.error(f"Error during cleanup: {e}", exc_info=True)

View File

@@ -4,7 +4,10 @@ import logging
logger = logging.getLogger(__name__)
# Check if running in standalone mode
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
standalone_mode = (
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
)
if not standalone_mode:
from .metadata_hook import MetadataHook
@@ -13,13 +16,13 @@ if not standalone_mode:
def init():
# Install hooks to collect metadata during execution
MetadataHook.install()
# Initialize registry
registry = MetadataRegistry()
logger.info("ComfyUI Metadata Collector initialized")
def get_metadata(prompt_id=None):
def get_metadata(prompt_id=None): # type: ignore[no-redef]
"""Helper function to get metadata from the registry"""
registry = MetadataRegistry()
return registry.get_metadata(prompt_id)
@@ -27,7 +30,7 @@ else:
# Standalone mode - provide dummy implementations
def init():
logger.info("ComfyUI Metadata Collector disabled in standalone mode")
def get_metadata(prompt_id=None):
def get_metadata(prompt_id=None): # type: ignore[no-redef]
"""Dummy implementation for standalone mode"""
return {}

View File

@@ -1,50 +1,54 @@
import time
from nodes import NODE_CLASS_MAPPINGS
from nodes import NODE_CLASS_MAPPINGS # type: ignore
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
from .constants import METADATA_CATEGORIES, IMAGES
class MetadataRegistry:
"""A singleton registry to store and retrieve workflow metadata"""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._reset()
return cls._instance
def _reset(self):
self.current_prompt_id = None
self.current_prompt = None
self.metadata = {}
self.prompt_metadata = {}
self.executed_nodes = set()
# Node-level cache for metadata
self.node_cache = {}
# Limit the number of stored prompts
self.max_prompt_history = 3
# Categories we want to track and retrieve from cache
self.metadata_categories = METADATA_CATEGORIES
def _clean_old_prompts(self):
"""Clean up old prompt metadata, keeping only recent ones"""
if len(self.prompt_metadata) <= self.max_prompt_history:
return
# Sort all prompt_ids by timestamp
sorted_prompts = sorted(
self.prompt_metadata.keys(),
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0)
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0),
)
# Remove oldest records
prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history]
prompts_to_remove = sorted_prompts[
: len(sorted_prompts) - self.max_prompt_history
]
for pid in prompts_to_remove:
del self.prompt_metadata[pid]
def start_collection(self, prompt_id):
"""Begin metadata collection for a new prompt"""
self.current_prompt_id = prompt_id
@@ -53,90 +57,96 @@ class MetadataRegistry:
category: {} for category in METADATA_CATEGORIES
}
# Add additional metadata fields
self.prompt_metadata[prompt_id].update({
"execution_order": [],
"current_prompt": None, # Will store the prompt object
"timestamp": time.time()
})
self.prompt_metadata[prompt_id].update(
{
"execution_order": [],
"current_prompt": None, # Will store the prompt object
"timestamp": time.time(),
}
)
# Clean up old prompt data
self._clean_old_prompts()
def set_current_prompt(self, prompt):
"""Set the current prompt object reference"""
self.current_prompt = prompt
if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata:
# Store the prompt in the metadata for later relationship tracing
self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt
def get_metadata(self, prompt_id=None):
"""Get collected metadata for a prompt"""
key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata:
return {}
metadata = self.prompt_metadata[key]
# If we have a current prompt object, check for non-executed nodes
prompt_obj = metadata.get("current_prompt")
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
original_prompt = prompt_obj.original_prompt
# Fill in missing metadata from cache for nodes that weren't executed
self._fill_missing_metadata(key, original_prompt)
return self.prompt_metadata.get(key, {})
def _fill_missing_metadata(self, prompt_id, original_prompt):
"""Fill missing metadata from cache for non-executed nodes"""
if not original_prompt:
return
executed_nodes = self.executed_nodes
metadata = self.prompt_metadata[prompt_id]
# Iterate through nodes in the original prompt
for node_id, node_data in original_prompt.items():
# Skip if already executed in this run
if node_id in executed_nodes:
continue
# Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS)
prompt_class_type = node_data.get("class_type")
if not prompt_class_type:
continue
# Convert to actual class name (which is what we use in our cache)
class_type = prompt_class_type
if prompt_class_type in NODE_CLASS_MAPPINGS:
class_obj = NODE_CLASS_MAPPINGS[prompt_class_type]
class_type = class_obj.__name__
# Create cache key using the actual class name
cache_key = f"{node_id}:{class_type}"
# Check if this node type is relevant for metadata collection
if class_type in NODE_EXTRACTORS:
# Check if we have cached metadata for this node
if cache_key in self.node_cache:
cached_data = self.node_cache[cache_key]
# Apply cached metadata to the current metadata
for category in self.metadata_categories:
if category in cached_data and node_id in cached_data[category]:
if node_id not in metadata[category]:
metadata[category][node_id] = cached_data[category][node_id]
metadata[category][node_id] = cached_data[category][
node_id
]
def record_node_execution(self, node_id, class_type, inputs, outputs):
"""Record information about a node's execution"""
if not self.current_prompt_id:
return
# Add to execution order and mark as executed
if node_id not in self.executed_nodes:
self.executed_nodes.add(node_id)
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id)
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(
node_id
)
# Process inputs to simplify working with them
processed_inputs = {}
for input_name, input_values in inputs.items():
@@ -145,63 +155,61 @@ class MetadataRegistry:
processed_inputs[input_name] = input_values[0]
else:
processed_inputs[input_name] = input_values
# Extract node-specific metadata
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
extractor.extract(
node_id,
processed_inputs,
outputs,
self.prompt_metadata[self.current_prompt_id]
node_id,
processed_inputs,
outputs,
self.prompt_metadata[self.current_prompt_id],
)
# Cache this node's metadata
self._cache_node_metadata(node_id, class_type)
def update_node_execution(self, node_id, class_type, outputs):
"""Update node metadata with output information"""
if not self.current_prompt_id:
return
# Process outputs to make them more usable
processed_outputs = outputs
# Use the same extractor to update with outputs
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
if hasattr(extractor, 'update'):
if hasattr(extractor, "update"):
extractor.update(
node_id,
processed_outputs,
self.prompt_metadata[self.current_prompt_id]
node_id, processed_outputs, self.prompt_metadata[self.current_prompt_id]
)
# Update the cached metadata for this node
self._cache_node_metadata(node_id, class_type)
def _cache_node_metadata(self, node_id, class_type):
"""Cache the metadata for a specific node"""
if not self.current_prompt_id or not node_id or not class_type:
return
# Create a cache key combining node_id and class_type
cache_key = f"{node_id}:{class_type}"
# Create a shallow copy of the node's metadata
node_metadata = {}
current_metadata = self.prompt_metadata[self.current_prompt_id]
for category in self.metadata_categories:
if category in current_metadata and node_id in current_metadata[category]:
if category not in node_metadata:
node_metadata[category] = {}
node_metadata[category][node_id] = current_metadata[category][node_id]
# Save new metadata or clear stale cache entries when metadata is empty
if any(node_metadata.values()):
self.node_cache[cache_key] = node_metadata
else:
self.node_cache.pop(cache_key, None)
def clear_unused_cache(self):
"""Clean up node_cache entries that are no longer in use"""
# Collect all node_ids currently in prompt_metadata
@@ -210,18 +218,18 @@ class MetadataRegistry:
for category in self.metadata_categories:
if category in prompt_data:
active_node_ids.update(prompt_data[category].keys())
# Find cache keys that are no longer needed
keys_to_remove = []
for cache_key in self.node_cache:
node_id = cache_key.split(':')[0]
node_id = cache_key.split(":")[0]
if node_id not in active_node_ids:
keys_to_remove.append(cache_key)
# Remove cache entries that are no longer needed
for key in keys_to_remove:
del self.node_cache[key]
def clear_metadata(self, prompt_id=None):
"""Clear metadata for a specific prompt or reset all data"""
if prompt_id is not None:
@@ -232,25 +240,25 @@ class MetadataRegistry:
else:
# Reset all data
self._reset()
def get_first_decoded_image(self, prompt_id=None):
"""Get the first decoded image result"""
key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata:
return None
metadata = self.prompt_metadata[key]
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
image_data = metadata[IMAGES]["first_decode"]["image"]
# If it's an image batch or tuple, handle various formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
# Return first element of list/tuple
return image_data[0]
# If it's a tensor, return as is for processing in the route handler
return image_data
# If no image is found in the current metadata, try to find it in the cache
# This handles the case where VAEDecode was cached by ComfyUI and not executed
prompt_obj = metadata.get("current_prompt")
@@ -270,8 +278,11 @@ class MetadataRegistry:
if IMAGES in cached_data and node_id in cached_data[IMAGES]:
image_data = cached_data[IMAGES][node_id]["image"]
# Handle different image formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
if (
isinstance(image_data, (list, tuple))
and len(image_data) > 0
):
return image_data[0]
return image_data
return None

View File

@@ -1,8 +1,9 @@
import json
import os
import re
from typing import Any, Dict, Optional
import numpy as np
import folder_paths # type: ignore
import folder_paths # type: ignore
from ..services.service_registry import ServiceRegistry
from ..metadata_collector.metadata_processor import MetadataProcessor
from ..metadata_collector import get_metadata
@@ -12,6 +13,7 @@ import logging
logger = logging.getLogger(__name__)
class SaveImageLM:
NAME = "Save Image (LoraManager)"
CATEGORY = "Lora Manager/utils"
@@ -23,42 +25,60 @@ class SaveImageLM:
self.prefix_append = ""
self.compress_level = 4
self.counter = 0
# Add pattern format regex for filename substitution
pattern_format = re.compile(r"(%[^%]+%)")
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"filename_prefix": ("STRING", {
"default": "ComfyUI",
"tooltip": "Base filename for saved images. Supports format patterns like %seed%, %width%, %height%, %model%, etc."
}),
"file_format": (["png", "jpeg", "webp"], {
"tooltip": "Image format to save as. PNG preserves quality, JPEG is smaller, WebP balances size and quality."
}),
"filename_prefix": (
"STRING",
{
"default": "ComfyUI",
"tooltip": "Base filename for saved images. Supports format patterns like %seed%, %width%, %height%, %model%, etc.",
},
),
"file_format": (
["png", "jpeg", "webp"],
{
"tooltip": "Image format to save as. PNG preserves quality, JPEG is smaller, WebP balances size and quality."
},
),
},
"optional": {
"lossless_webp": ("BOOLEAN", {
"default": False,
"tooltip": "When enabled, saves WebP images with lossless compression. Results in larger files but no quality loss."
}),
"quality": ("INT", {
"default": 100,
"min": 1,
"max": 100,
"tooltip": "Compression quality for JPEG and lossy WebP formats (1-100). Higher values mean better quality but larger files."
}),
"embed_workflow": ("BOOLEAN", {
"default": False,
"tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats."
}),
"add_counter_to_filename": ("BOOLEAN", {
"default": True,
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images."
}),
"lossless_webp": (
"BOOLEAN",
{
"default": False,
"tooltip": "When enabled, saves WebP images with lossless compression. Results in larger files but no quality loss.",
},
),
"quality": (
"INT",
{
"default": 100,
"min": 1,
"max": 100,
"tooltip": "Compression quality for JPEG and lossy WebP formats (1-100). Higher values mean better quality but larger files.",
},
),
"embed_workflow": (
"BOOLEAN",
{
"default": False,
"tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats.",
},
),
"add_counter_to_filename": (
"BOOLEAN",
{
"default": True,
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images.",
},
),
},
"hidden": {
"id": "UNIQUE_ID",
@@ -75,57 +95,59 @@ class SaveImageLM:
def get_lora_hash(self, lora_name):
"""Get the lora hash from cache"""
scanner = ServiceRegistry.get_service_sync("lora_scanner")
# Use the new direct filename lookup method
hash_value = scanner.get_hash_by_filename(lora_name)
if hash_value:
return hash_value
if scanner is not None:
hash_value = scanner.get_hash_by_filename(lora_name)
if hash_value:
return hash_value
return None
def get_checkpoint_hash(self, checkpoint_path):
"""Get the checkpoint hash from cache"""
scanner = ServiceRegistry.get_service_sync("checkpoint_scanner")
if not checkpoint_path:
return None
# Extract basename without extension
checkpoint_name = os.path.basename(checkpoint_path)
checkpoint_name = os.path.splitext(checkpoint_name)[0]
# Try direct filename lookup first
hash_value = scanner.get_hash_by_filename(checkpoint_name)
if hash_value:
return hash_value
if scanner is not None:
hash_value = scanner.get_hash_by_filename(checkpoint_name)
if hash_value:
return hash_value
return None
def format_metadata(self, metadata_dict):
"""Format metadata in the requested format similar to userComment example"""
if not metadata_dict:
return ""
# Helper function to only add parameter if value is not None
def add_param_if_not_none(param_list, label, value):
if value is not None:
param_list.append(f"{label}: {value}")
# Extract the prompt and negative prompt
prompt = metadata_dict.get('prompt', '')
negative_prompt = metadata_dict.get('negative_prompt', '')
prompt = metadata_dict.get("prompt", "")
negative_prompt = metadata_dict.get("negative_prompt", "")
# Extract loras from the prompt if present
loras_text = metadata_dict.get('loras', '')
loras_text = metadata_dict.get("loras", "")
lora_hashes = {}
# If loras are found, add them on a new line after the prompt
if loras_text:
prompt_with_loras = f"{prompt}\n{loras_text}"
# Extract lora names from the format <lora:name:strength>
lora_matches = re.findall(r'<lora:([^:]+):([^>]+)>', loras_text)
lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", loras_text)
# Get hash for each lora
for lora_name, strength in lora_matches:
hash_value = self.get_lora_hash(lora_name)
@@ -133,112 +155,114 @@ class SaveImageLM:
lora_hashes[lora_name] = hash_value
else:
prompt_with_loras = prompt
# Format the first part (prompt and loras)
metadata_parts = [prompt_with_loras]
# Add negative prompt
if negative_prompt:
metadata_parts.append(f"Negative prompt: {negative_prompt}")
# Format the second part (generation parameters)
params = []
# Add standard parameters in the correct order
if 'steps' in metadata_dict:
add_param_if_not_none(params, "Steps", metadata_dict.get('steps'))
if "steps" in metadata_dict:
add_param_if_not_none(params, "Steps", metadata_dict.get("steps"))
# Combine sampler and scheduler information
sampler_name = None
scheduler_name = None
if 'sampler' in metadata_dict:
sampler = metadata_dict.get('sampler')
if "sampler" in metadata_dict:
sampler = metadata_dict.get("sampler")
# Convert ComfyUI sampler names to user-friendly names
sampler_mapping = {
'euler': 'Euler',
'euler_ancestral': 'Euler a',
'dpm_2': 'DPM2',
'dpm_2_ancestral': 'DPM2 a',
'heun': 'Heun',
'dpm_fast': 'DPM fast',
'dpm_adaptive': 'DPM adaptive',
'lms': 'LMS',
'dpmpp_2s_ancestral': 'DPM++ 2S a',
'dpmpp_sde': 'DPM++ SDE',
'dpmpp_sde_gpu': 'DPM++ SDE',
'dpmpp_2m': 'DPM++ 2M',
'dpmpp_2m_sde': 'DPM++ 2M SDE',
'dpmpp_2m_sde_gpu': 'DPM++ 2M SDE',
'ddim': 'DDIM'
"euler": "Euler",
"euler_ancestral": "Euler a",
"dpm_2": "DPM2",
"dpm_2_ancestral": "DPM2 a",
"heun": "Heun",
"dpm_fast": "DPM fast",
"dpm_adaptive": "DPM adaptive",
"lms": "LMS",
"dpmpp_2s_ancestral": "DPM++ 2S a",
"dpmpp_sde": "DPM++ SDE",
"dpmpp_sde_gpu": "DPM++ SDE",
"dpmpp_2m": "DPM++ 2M",
"dpmpp_2m_sde": "DPM++ 2M SDE",
"dpmpp_2m_sde_gpu": "DPM++ 2M SDE",
"ddim": "DDIM",
}
sampler_name = sampler_mapping.get(sampler, sampler)
if 'scheduler' in metadata_dict:
scheduler = metadata_dict.get('scheduler')
if "scheduler" in metadata_dict:
scheduler = metadata_dict.get("scheduler")
scheduler_mapping = {
'normal': 'Simple',
'karras': 'Karras',
'exponential': 'Exponential',
'sgm_uniform': 'SGM Uniform',
'sgm_quadratic': 'SGM Quadratic'
"normal": "Simple",
"karras": "Karras",
"exponential": "Exponential",
"sgm_uniform": "SGM Uniform",
"sgm_quadratic": "SGM Quadratic",
}
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
# Add combined sampler and scheduler information
if sampler_name:
if scheduler_name:
params.append(f"Sampler: {sampler_name} {scheduler_name}")
else:
params.append(f"Sampler: {sampler_name}")
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
if 'guidance' in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('guidance'))
elif 'cfg_scale' in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg_scale'))
elif 'cfg' in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg'))
if "guidance" in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get("guidance"))
elif "cfg_scale" in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg_scale"))
elif "cfg" in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg"))
# Seed
if 'seed' in metadata_dict:
add_param_if_not_none(params, "Seed", metadata_dict.get('seed'))
if "seed" in metadata_dict:
add_param_if_not_none(params, "Seed", metadata_dict.get("seed"))
# Size
if 'size' in metadata_dict:
add_param_if_not_none(params, "Size", metadata_dict.get('size'))
if "size" in metadata_dict:
add_param_if_not_none(params, "Size", metadata_dict.get("size"))
# Model info
if 'checkpoint' in metadata_dict:
if "checkpoint" in metadata_dict:
# Ensure checkpoint is a string before processing
checkpoint = metadata_dict.get('checkpoint')
checkpoint = metadata_dict.get("checkpoint")
if checkpoint is not None:
# Get model hash
model_hash = self.get_checkpoint_hash(checkpoint)
# Extract basename without path
checkpoint_name = os.path.basename(checkpoint)
# Remove extension if present
checkpoint_name = os.path.splitext(checkpoint_name)[0]
# Add model hash if available
if model_hash:
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
params.append(
f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}"
)
else:
params.append(f"Model: {checkpoint_name}")
# Add LoRA hashes if available
if lora_hashes:
lora_hash_parts = []
for lora_name, hash_value in lora_hashes.items():
lora_hash_parts.append(f"{lora_name}: {hash_value[:10]}")
if lora_hash_parts:
params.append(f"Lora hashes: \"{', '.join(lora_hash_parts)}\"")
params.append(f'Lora hashes: "{", ".join(lora_hash_parts)}"')
# Combine all parameters with commas
metadata_parts.append(", ".join(params))
# Join all parts with a new line
return "\n".join(metadata_parts)
@@ -248,36 +272,36 @@ class SaveImageLM:
"""Format filename with metadata values"""
if not metadata_dict:
return filename
result = re.findall(self.pattern_format, filename)
for segment in result:
parts = segment.replace("%", "").split(":")
key = parts[0]
if key == "seed" and 'seed' in metadata_dict:
filename = filename.replace(segment, str(metadata_dict.get('seed', '')))
elif key == "width" and 'size' in metadata_dict:
size = metadata_dict.get('size', 'x')
w = size.split('x')[0] if isinstance(size, str) else size[0]
if key == "seed" and "seed" in metadata_dict:
filename = filename.replace(segment, str(metadata_dict.get("seed", "")))
elif key == "width" and "size" in metadata_dict:
size = metadata_dict.get("size", "x")
w = size.split("x")[0] if isinstance(size, str) else size[0]
filename = filename.replace(segment, str(w))
elif key == "height" and 'size' in metadata_dict:
size = metadata_dict.get('size', 'x')
h = size.split('x')[1] if isinstance(size, str) else size[1]
elif key == "height" and "size" in metadata_dict:
size = metadata_dict.get("size", "x")
h = size.split("x")[1] if isinstance(size, str) else size[1]
filename = filename.replace(segment, str(h))
elif key == "pprompt" and 'prompt' in metadata_dict:
prompt = metadata_dict.get('prompt', '').replace("\n", " ")
elif key == "pprompt" and "prompt" in metadata_dict:
prompt = metadata_dict.get("prompt", "").replace("\n", " ")
if len(parts) >= 2:
length = int(parts[1])
prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip())
elif key == "nprompt" and 'negative_prompt' in metadata_dict:
prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ")
elif key == "nprompt" and "negative_prompt" in metadata_dict:
prompt = metadata_dict.get("negative_prompt", "").replace("\n", " ")
if len(parts) >= 2:
length = int(parts[1])
prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip())
elif key == "model":
model_value = metadata_dict.get('checkpoint')
model_value = metadata_dict.get("checkpoint")
if isinstance(model_value, (bytes, os.PathLike)):
model_value = str(model_value)
@@ -291,6 +315,7 @@ class SaveImageLM:
filename = filename.replace(segment, model)
elif key == "date":
from datetime import datetime
now = datetime.now()
date_table = {
"yyyy": f"{now.year:04d}",
@@ -311,46 +336,62 @@ class SaveImageLM:
for k, v in date_table.items():
date_format = date_format.replace(k, v)
filename = filename.replace(segment, date_format)
return filename
def save_images(self, images, filename_prefix, file_format, id, prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
def save_images(
self,
images,
filename_prefix,
file_format,
id,
prompt=None,
extra_pnginfo=None,
lossless_webp=True,
quality=100,
embed_workflow=False,
add_counter_to_filename=True,
):
"""Save images with metadata"""
results = []
# Get metadata using the metadata collector
raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
metadata = self.format_metadata(metadata_dict)
# Process filename_prefix with pattern substitution
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
# Get initial save path info once for the batch
full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path(
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
full_output_folder, filename, counter, subfolder, processed_prefix = (
folder_paths.get_save_image_path(
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
)
)
# Create directory if it doesn't exist
if not os.path.exists(full_output_folder):
os.makedirs(full_output_folder, exist_ok=True)
# Process each image with incrementing counter
for i, image in enumerate(images):
# Convert the tensor image to numpy array
img = 255. * image.cpu().numpy()
img = 255.0 * image.cpu().numpy()
img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))
# Generate filename with counter if needed
base_filename = filename
if add_counter_to_filename:
# Use counter + i to ensure unique filenames for all images in batch
current_counter = counter + i
base_filename += f"_{current_counter:05}_"
# Set file extension and prepare saving parameters
file: str
save_kwargs: Dict[str, Any]
pnginfo: Optional[PngImagePlugin.PngInfo] = None
if file_format == "png":
file = base_filename + ".png"
file_extension = ".png"
@@ -362,17 +403,24 @@ class SaveImageLM:
file_extension = ".jpg"
save_kwargs = {"quality": quality, "optimize": True}
elif file_format == "webp":
file = base_filename + ".webp"
file = base_filename + ".webp"
file_extension = ".webp"
# Add optimization param to control performance
save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0}
save_kwargs = {
"quality": quality,
"lossless": lossless_webp,
"method": 0,
}
else:
raise ValueError(f"Unsupported file format: {file_format}")
# Full save path
file_path = os.path.join(full_output_folder, file)
# Save the image with metadata
try:
if file_format == "png":
assert pnginfo is not None
if metadata:
pnginfo.add_text("parameters", metadata)
if embed_workflow and extra_pnginfo is not None:
@@ -384,7 +432,12 @@ class SaveImageLM:
# For JPEG, use piexif
if metadata:
try:
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}}
exif_dict = {
"Exif": {
piexif.ExifIFD.UserComment: b"UNICODE\0"
+ metadata.encode("utf-16be")
}
}
exif_bytes = piexif.dump(exif_dict)
save_kwargs["exif"] = exif_bytes
except Exception as e:
@@ -396,37 +449,52 @@ class SaveImageLM:
exif_dict = {}
if metadata:
exif_dict['Exif'] = {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}
exif_dict["Exif"] = {
piexif.ExifIFD.UserComment: b"UNICODE\0"
+ metadata.encode("utf-16be")
}
# Add workflow if needed
if embed_workflow and extra_pnginfo is not None:
workflow_json = json.dumps(extra_pnginfo["workflow"])
exif_dict['0th'] = {piexif.ImageIFD.ImageDescription: "Workflow:" + workflow_json}
workflow_json = json.dumps(extra_pnginfo["workflow"])
exif_dict["0th"] = {
piexif.ImageIFD.ImageDescription: "Workflow:"
+ workflow_json
}
exif_bytes = piexif.dump(exif_dict)
save_kwargs["exif"] = exif_bytes
except Exception as e:
logger.error(f"Error adding EXIF data: {e}")
img.save(file_path, format="WEBP", **save_kwargs)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
results.append(
{"filename": file, "subfolder": subfolder, "type": self.type}
)
except Exception as e:
logger.error(f"Error saving image: {e}")
return results
def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
def process_image(
self,
images,
id,
filename_prefix="ComfyUI",
file_format="png",
prompt=None,
extra_pnginfo=None,
lossless_webp=True,
quality=100,
embed_workflow=False,
add_counter_to_filename=True,
):
"""Process and save image with metadata"""
# Make sure the output directory exists
os.makedirs(self.output_dir, exist_ok=True)
# If images is already a list or array of images, do nothing; otherwise, convert to list
if isinstance(images, (list, np.ndarray)):
pass
@@ -436,19 +504,19 @@ class SaveImageLM:
images = [images]
else: # Multiple images (batch, height, width, channels)
images = [img for img in images]
# Save all images
results = self.save_images(
images,
filename_prefix,
file_format,
images,
filename_prefix,
file_format,
id,
prompt,
prompt,
extra_pnginfo,
lossless_webp,
quality,
embed_workflow,
add_counter_to_filename
add_counter_to_filename,
)
return (images,)

View File

@@ -1,33 +1,35 @@
class AnyType(str):
"""A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
"""A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
def __ne__(self, __value: object) -> bool:
return False
def __ne__(self, __value: object) -> bool:
return False
# Credit to Regis Gaughan, III (rgthree)
class FlexibleOptionalInputType(dict):
"""A special class to make flexible nodes that pass data to our python handlers.
"""A special class to make flexible nodes that pass data to our python handlers.
Enables both flexible/dynamic input types (like for Any Switch) or a dynamic number of inputs
(like for Any Switch, Context Switch, Context Merge, Power Lora Loader, etc).
Enables both flexible/dynamic input types (like for Any Switch) or a dynamic number of inputs
(like for Any Switch, Context Switch, Context Merge, Power Lora Loader, etc).
Note, for ComfyUI, all that's needed is the `__contains__` override below, which tells ComfyUI
that our node will handle the input, regardless of what it is.
Note, for ComfyUI, all that's needed is the `__contains__` override below, which tells ComfyUI
that our node will handle the input, regardless of what it is.
However, with https://github.com/comfyanonymous/ComfyUI/pull/2666 a large change would occur
requiring more details on the input itself. There, we need to return a list/tuple where the first
item is the type. This can be a real type, or use the AnyType for additional flexibility.
However, with https://github.com/comfyanonymous/ComfyUI/pull/2666 a large change would occur
requiring more details on the input itself. There, we need to return a list/tuple where the first
item is the type. This can be a real type, or use the AnyType for additional flexibility.
This should be forwards compatible unless more changes occur in the PR.
"""
def __init__(self, type):
self.type = type
This should be forwards compatible unless more changes occur in the PR.
"""
def __getitem__(self, key):
return (self.type, )
def __init__(self, type):
self.type = type
def __contains__(self, key):
return True
def __getitem__(self, key):
return (self.type,)
def __contains__(self, key):
return True
any_type = AnyType("*")
@@ -37,25 +39,27 @@ import os
import logging
import copy
import sys
import folder_paths
import folder_paths # type: ignore
logger = logging.getLogger(__name__)
def extract_lora_name(lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def get_loras_list(kwargs):
"""Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs:
if "loras" not in kwargs:
return []
loras_data = kwargs['loras']
loras_data = kwargs["loras"]
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
return loras_data['__value__']
if isinstance(loras_data, dict) and "__value__" in loras_data:
return loras_data["__value__"]
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
return loras_data
@@ -64,24 +68,26 @@ def get_loras_list(kwargs):
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
import safetensors.torch
state_dict = {}
with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
with safetensors.torch.safe_open(path, framework="pt", device=device) as f: # type: ignore[attr-defined]
for k in f.keys():
if filter_prefix and not k.startswith(filter_prefix):
continue
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
return state_dict
def to_diffusers(input_lora):
"""Simplified version of to_diffusers for Flux LoRA conversion"""
import torch
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
from diffusers.loaders import FluxLoraLoaderMixin
from diffusers.loaders import FluxLoraLoaderMixin # type: ignore[attr-defined]
if isinstance(input_lora, str):
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
else:
@@ -91,22 +97,27 @@ def to_diffusers(input_lora):
for k, v in tensors.items():
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
tensors[k] = v.to(torch.bfloat16)
new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors)
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
return new_tensors
def nunchaku_load_lora(model, lora_name, lora_strength):
"""Load a Flux LoRA for Nunchaku model"""
"""Load a Flux LoRA for Nunchaku model"""
# Get full path to the LoRA file. Allow both direct paths and registered LoRA names.
lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
lora_path = (
lora_name
if os.path.isfile(lora_name)
else folder_paths.get_full_path("loras", lora_name)
)
if not lora_path or not os.path.isfile(lora_path):
logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
return model
model_wrapper = model.model.diffusion_model
# Try to find copy_with_ctx in the same module as ComfyFluxWrapper
module_name = model_wrapper.__class__.__module__
module = sys.modules.get(module_name)
@@ -118,14 +129,16 @@ def nunchaku_load_lora(model, lora_name, lora_strength):
ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)]
else:
# Fallback to legacy logic
logger.warning("Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic.")
logger.warning(
"Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic."
)
transformer = model_wrapper.model
# Save the transformer temporarily
model_wrapper.model = None
ret_model = copy.deepcopy(model) # copy everything except the model
ret_model_wrapper = ret_model.model.diffusion_model
# Restore the model and set it for the copy
model_wrapper.model = transformer
ret_model_wrapper.model = transformer
@@ -133,15 +146,15 @@ def nunchaku_load_lora(model, lora_name, lora_strength):
# Convert the LoRA to diffusers format
sd = to_diffusers(lora_path)
# Handle embedding adjustment if needed
if "transformer.x_embedder.lora_A.weight" in sd:
new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
assert new_in_channels % 4 == 0
new_in_channels = new_in_channels // 4
old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
if old_in_channels < new_in_channels:
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
return ret_model
return ret_model

View File

@@ -6,23 +6,24 @@ from .parsers import (
ComfyMetadataParser,
MetaFormatParser,
AutomaticMetadataParser,
CivitaiApiMetadataParser
CivitaiApiMetadataParser,
)
from .base import RecipeMetadataParser
logger = logging.getLogger(__name__)
class RecipeParserFactory:
"""Factory for creating recipe metadata parsers"""
@staticmethod
def create_parser(metadata) -> RecipeMetadataParser:
def create_parser(metadata) -> RecipeMetadataParser | None:
"""
Create appropriate parser based on the metadata content
Args:
metadata: The metadata from the image (dict or str)
Returns:
Appropriate RecipeMetadataParser implementation
"""
@@ -34,17 +35,18 @@ class RecipeParserFactory:
except Exception as e:
logger.debug(f"CivitaiApiMetadataParser check failed: {e}")
pass
# Convert dict to string for other parsers that expect string input
try:
import json
metadata_str = json.dumps(metadata)
except Exception as e:
logger.debug(f"Failed to convert dict to JSON string: {e}")
return None
else:
metadata_str = metadata
# Try ComfyMetadataParser which requires valid JSON
try:
if ComfyMetadataParser().is_metadata_matching(metadata_str):
@@ -52,7 +54,7 @@ class RecipeParserFactory:
except Exception:
# If JSON parsing fails, move on to other parsers
pass
# Check other parsers that expect string input
if RecipeFormatParser().is_metadata_matching(metadata_str):
return RecipeFormatParser()

View File

@@ -9,15 +9,16 @@ from ...services.metadata_service import get_default_metadata_provider
logger = logging.getLogger(__name__)
class CivitaiApiMetadataParser(RecipeMetadataParser):
"""Parser for Civitai image metadata format"""
def is_metadata_matching(self, metadata) -> bool:
"""Check if the metadata matches the Civitai image metadata format
Args:
metadata: The metadata from the image (dict)
Returns:
bool: True if this parser can handle the metadata
"""
@@ -28,7 +29,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Check for common CivitAI image metadata fields
civitai_image_fields = (
"resources",
"civitaiResources",
"civitaiResources",
"additionalResources",
"hashes",
"prompt",
@@ -40,7 +41,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
"width",
"height",
"Model",
"Model hash"
"Model hash",
)
return any(key in payload for key in civitai_image_fields)
@@ -50,7 +51,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Check for LoRA hash patterns
hashes = metadata.get("hashes")
if isinstance(hashes, dict) and any(str(key).lower().startswith("lora:") for key in hashes):
if isinstance(hashes, dict) and any(
str(key).lower().startswith("lora:") for key in hashes
):
return True
# Check nested meta object (common in CivitAI image responses)
@@ -61,22 +64,28 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Also check for LoRA hash patterns in nested meta
hashes = nested_meta.get("hashes")
if isinstance(hashes, dict) and any(str(key).lower().startswith("lora:") for key in hashes):
if isinstance(hashes, dict) and any(
str(key).lower().startswith("lora:") for key in hashes
):
return True
return False
async def parse_metadata(self, metadata, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
async def parse_metadata( # type: ignore[override]
self, user_comment, recipe_scanner=None, civitai_client=None
) -> Dict[str, Any]:
"""Parse metadata from Civitai image format
Args:
metadata: The metadata from the image (dict)
user_comment: The metadata from the image (dict)
recipe_scanner: Optional recipe scanner service
civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead)
Returns:
Dict containing parsed recipe data
"""
metadata: Dict[str, Any] = user_comment # type: ignore[assignment]
metadata = user_comment
try:
# Get metadata provider instead of using civitai_client directly
metadata_provider = await get_default_metadata_provider()
@@ -100,19 +109,19 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
)
):
metadata = inner_meta
# Initialize result structure
result = {
'base_model': None,
'loras': [],
'model': None,
'gen_params': {},
'from_civitai_image': True
"base_model": None,
"loras": [],
"model": None,
"gen_params": {},
"from_civitai_image": True,
}
# Track already added LoRAs to prevent duplicates
added_loras = {} # key: model_version_id or hash, value: index in result["loras"]
# Extract hash information from hashes field for LoRA matching
lora_hashes = {}
if "hashes" in metadata and isinstance(metadata["hashes"], dict):
@@ -121,14 +130,14 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
if key_str.lower().startswith("lora:"):
lora_name = key_str.split(":", 1)[1]
lora_hashes[lora_name] = hash_value
# Extract prompt and negative prompt
if "prompt" in metadata:
result["gen_params"]["prompt"] = metadata["prompt"]
if "negativePrompt" in metadata:
result["gen_params"]["negative_prompt"] = metadata["negativePrompt"]
# Extract other generation parameters
param_mapping = {
"steps": "steps",
@@ -138,98 +147,117 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
"Size": "size",
"clipSkip": "clip_skip",
}
for civitai_key, our_key in param_mapping.items():
if civitai_key in metadata and our_key in GEN_PARAM_KEYS:
result["gen_params"][our_key] = metadata[civitai_key]
# Extract base model information - directly if available
if "baseModel" in metadata:
result["base_model"] = metadata["baseModel"]
elif "Model hash" in metadata and metadata_provider:
model_hash = metadata["Model hash"]
model_info, error = await metadata_provider.get_model_by_hash(model_hash)
model_info, error = await metadata_provider.get_model_by_hash(
model_hash
)
if model_info:
result["base_model"] = model_info.get("baseModel", "")
elif "Model" in metadata and isinstance(metadata.get("resources"), list):
# Try to find base model in resources
for resource in metadata.get("resources", []):
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"):
if resource.get("type") == "model" and resource.get(
"name"
) == metadata.get("Model"):
# This is likely the checkpoint model
if metadata_provider and resource.get("hash"):
model_info, error = await metadata_provider.get_model_by_hash(resource.get("hash"))
(
model_info,
error,
) = await metadata_provider.get_model_by_hash(
resource.get("hash")
)
if model_info:
result["base_model"] = model_info.get("baseModel", "")
base_model_counts = {}
# Process standard resources array
if "resources" in metadata and isinstance(metadata["resources"], list):
for resource in metadata["resources"]:
# Modified to process resources without a type field as potential LoRAs
if resource.get("type", "lora") == "lora":
lora_hash = resource.get("hash", "")
# Try to get hash from the hashes field if not present in resource
if not lora_hash and resource.get("name"):
lora_hash = lora_hashes.get(resource["name"], "")
# Skip LoRAs without proper identification (hash or modelVersionId)
if not lora_hash and not resource.get("modelVersionId"):
logger.debug(f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId")
logger.debug(
f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId"
)
continue
# Skip if we've already added this LoRA by hash
if lora_hash and lora_hash in added_loras:
continue
lora_entry = {
'name': resource.get("name", "Unknown LoRA"),
'type': "lora",
'weight': float(resource.get("weight", 1.0)),
'hash': lora_hash,
'existsLocally': False,
'localPath': None,
'file_name': resource.get("name", "Unknown"),
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
"name": resource.get("name", "Unknown LoRA"),
"type": "lora",
"weight": float(resource.get("weight", 1.0)),
"hash": lora_hash,
"existsLocally": False,
"localPath": None,
"file_name": resource.get("name", "Unknown"),
"thumbnailUrl": "/loras_static/images/no-preview.png",
"baseModel": "",
"size": 0,
"downloadUrl": "",
"isDeleted": False,
}
# Try to get info from Civitai if hash is available
if lora_entry['hash'] and metadata_provider:
if lora_entry["hash"] and metadata_provider:
try:
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
civitai_info = (
await metadata_provider.get_model_by_hash(lora_hash)
)
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts,
lora_hash
lora_hash,
)
if populated_entry is None:
continue # Skip invalid LoRA types
lora_entry = populated_entry
# If we have a version ID from Civitai, track it for deduplication
if 'id' in lora_entry and lora_entry['id']:
added_loras[str(lora_entry['id'])] = len(result["loras"])
if "id" in lora_entry and lora_entry["id"]:
added_loras[str(lora_entry["id"])] = len(
result["loras"]
)
except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
logger.error(
f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}"
)
# Track by hash if we have it
if lora_hash:
added_loras[lora_hash] = len(result["loras"])
result["loras"].append(lora_entry)
# Process civitaiResources array
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list):
if "civitaiResources" in metadata and isinstance(
metadata["civitaiResources"], list
):
for resource in metadata["civitaiResources"]:
# Get resource type and identifier
resource_type = str(resource.get("type") or "").lower()
@@ -237,32 +265,39 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
if resource_type == "checkpoint":
checkpoint_entry = {
'id': resource.get("modelVersionId", 0),
'modelId': resource.get("modelId", 0),
'name': resource.get("modelName", "Unknown Checkpoint"),
'version': resource.get("modelVersionName", ""),
'type': resource.get("type", "checkpoint"),
'existsLocally': False,
'localPath': None,
'file_name': resource.get("modelName", ""),
'hash': resource.get("hash", "") or "",
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
"id": resource.get("modelVersionId", 0),
"modelId": resource.get("modelId", 0),
"name": resource.get("modelName", "Unknown Checkpoint"),
"version": resource.get("modelVersionName", ""),
"type": resource.get("type", "checkpoint"),
"existsLocally": False,
"localPath": None,
"file_name": resource.get("modelName", ""),
"hash": resource.get("hash", "") or "",
"thumbnailUrl": "/loras_static/images/no-preview.png",
"baseModel": "",
"size": 0,
"downloadUrl": "",
"isDeleted": False,
}
if version_id and metadata_provider:
try:
civitai_info = await metadata_provider.get_model_version_info(version_id)
civitai_info = (
await metadata_provider.get_model_version_info(
version_id
)
)
checkpoint_entry = await self.populate_checkpoint_from_civitai(
checkpoint_entry,
civitai_info
checkpoint_entry = (
await self.populate_checkpoint_from_civitai(
checkpoint_entry, civitai_info
)
)
except Exception as e:
logger.error(f"Error fetching Civitai info for checkpoint version {version_id}: {e}")
logger.error(
f"Error fetching Civitai info for checkpoint version {version_id}: {e}"
)
if result["model"] is None:
result["model"] = checkpoint_entry
@@ -275,31 +310,35 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Initialize lora entry
lora_entry = {
'id': resource.get("modelVersionId", 0),
'modelId': resource.get("modelId", 0),
'name': resource.get("modelName", "Unknown LoRA"),
'version': resource.get("modelVersionName", ""),
'type': resource.get("type", "lora"),
'weight': round(float(resource.get("weight", 1.0)), 2),
'existsLocally': False,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
"id": resource.get("modelVersionId", 0),
"modelId": resource.get("modelId", 0),
"name": resource.get("modelName", "Unknown LoRA"),
"version": resource.get("modelVersionName", ""),
"type": resource.get("type", "lora"),
"weight": round(float(resource.get("weight", 1.0)), 2),
"existsLocally": False,
"thumbnailUrl": "/loras_static/images/no-preview.png",
"baseModel": "",
"size": 0,
"downloadUrl": "",
"isDeleted": False,
}
# Try to get info from Civitai if modelVersionId is available
if version_id and metadata_provider:
try:
# Use get_model_version_info instead of get_model_version
civitai_info = await metadata_provider.get_model_version_info(version_id)
civitai_info = (
await metadata_provider.get_model_version_info(
version_id
)
)
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts
base_model_counts,
)
if populated_entry is None:
@@ -307,74 +346,87 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
lora_entry = populated_entry
except Exception as e:
logger.error(f"Error fetching Civitai info for model version {version_id}: {e}")
logger.error(
f"Error fetching Civitai info for model version {version_id}: {e}"
)
# Track this LoRA in our deduplication dict
if version_id:
added_loras[version_id] = len(result["loras"])
result["loras"].append(lora_entry)
# Process additionalResources array
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
if "additionalResources" in metadata and isinstance(
metadata["additionalResources"], list
):
for resource in metadata["additionalResources"]:
# Skip resources that aren't LoRAs or LyCORIS
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
if (
resource.get("type") not in ["lora", "lycoris"]
and "type" not in resource
):
continue
lora_type = resource.get("type", "lora")
name = resource.get("name", "")
# Extract ID from URN format if available
version_id = None
if name and "civitai:" in name:
parts = name.split("@")
if len(parts) > 1:
version_id = parts[1]
# Skip if we've already added this LoRA
if version_id in added_loras:
continue
lora_entry = {
'name': name,
'type': lora_type,
'weight': float(resource.get("strength", 1.0)),
'hash': "",
'existsLocally': False,
'localPath': None,
'file_name': name,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
"name": name,
"type": lora_type,
"weight": float(resource.get("strength", 1.0)),
"hash": "",
"existsLocally": False,
"localPath": None,
"file_name": name,
"thumbnailUrl": "/loras_static/images/no-preview.png",
"baseModel": "",
"size": 0,
"downloadUrl": "",
"isDeleted": False,
}
# If we have a version ID and metadata provider, try to get more info
if version_id and metadata_provider:
try:
# Use get_model_version_info with the version ID
civitai_info = await metadata_provider.get_model_version_info(version_id)
civitai_info = (
await metadata_provider.get_model_version_info(
version_id
)
)
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts
base_model_counts,
)
if populated_entry is None:
continue # Skip invalid LoRA types
lora_entry = populated_entry
# Track this LoRA for deduplication
if version_id:
added_loras[version_id] = len(result["loras"])
except Exception as e:
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
logger.error(
f"Error fetching Civitai info for model ID {version_id}: {e}"
)
result["loras"].append(lora_entry)
# If we found LoRA hashes in the metadata but haven't already
@@ -390,30 +442,32 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
continue
lora_entry = {
'name': lora_name,
'type': "lora",
'weight': 1.0,
'hash': lora_hash,
'existsLocally': False,
'localPath': None,
'file_name': lora_name,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
"name": lora_name,
"type": "lora",
"weight": 1.0,
"hash": lora_hash,
"existsLocally": False,
"localPath": None,
"file_name": lora_name,
"thumbnailUrl": "/loras_static/images/no-preview.png",
"baseModel": "",
"size": 0,
"downloadUrl": "",
"isDeleted": False,
}
if metadata_provider:
try:
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
civitai_info = await metadata_provider.get_model_by_hash(
lora_hash
)
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts,
lora_hash
lora_hash,
)
if populated_entry is None:
@@ -421,80 +475,93 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
lora_entry = populated_entry
if 'id' in lora_entry and lora_entry['id']:
added_loras[str(lora_entry['id'])] = len(result["loras"])
if "id" in lora_entry and lora_entry["id"]:
added_loras[str(lora_entry["id"])] = len(result["loras"])
except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {lora_hash}: {e}")
logger.error(
f"Error fetching Civitai info for LoRA hash {lora_hash}: {e}"
)
added_loras[lora_hash] = len(result["loras"])
result["loras"].append(lora_entry)
# Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc.
lora_index = 0
while f"Lora_{lora_index} Model hash" in metadata and f"Lora_{lora_index} Model name" in metadata:
while (
f"Lora_{lora_index} Model hash" in metadata
and f"Lora_{lora_index} Model name" in metadata
):
lora_hash = metadata[f"Lora_{lora_index} Model hash"]
lora_name = metadata[f"Lora_{lora_index} Model name"]
lora_strength_model = float(metadata.get(f"Lora_{lora_index} Strength model", 1.0))
lora_strength_model = float(
metadata.get(f"Lora_{lora_index} Strength model", 1.0)
)
# Skip if we've already added this LoRA by hash
if lora_hash and lora_hash in added_loras:
lora_index += 1
continue
lora_entry = {
'name': lora_name,
'type': "lora",
'weight': lora_strength_model,
'hash': lora_hash,
'existsLocally': False,
'localPath': None,
'file_name': lora_name,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
"name": lora_name,
"type": "lora",
"weight": lora_strength_model,
"hash": lora_hash,
"existsLocally": False,
"localPath": None,
"file_name": lora_name,
"thumbnailUrl": "/loras_static/images/no-preview.png",
"baseModel": "",
"size": 0,
"downloadUrl": "",
"isDeleted": False,
}
# Try to get info from Civitai if hash is available
if lora_entry['hash'] and metadata_provider:
if lora_entry["hash"] and metadata_provider:
try:
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
civitai_info = await metadata_provider.get_model_by_hash(
lora_hash
)
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts,
lora_hash
lora_hash,
)
if populated_entry is None:
lora_index += 1
continue # Skip invalid LoRA types
lora_entry = populated_entry
# If we have a version ID from Civitai, track it for deduplication
if 'id' in lora_entry and lora_entry['id']:
added_loras[str(lora_entry['id'])] = len(result["loras"])
if "id" in lora_entry and lora_entry["id"]:
added_loras[str(lora_entry["id"])] = len(result["loras"])
except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
logger.error(
f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}"
)
# Track by hash if we have it
if lora_hash:
added_loras[lora_hash] = len(result["loras"])
result["loras"].append(lora_entry)
lora_index += 1
# If base model wasn't found earlier, use the most common one from LoRAs
if not result["base_model"] and base_model_counts:
result["base_model"] = max(base_model_counts.items(), key=lambda x: x[1])[0]
result["base_model"] = max(
base_model_counts.items(), key=lambda x: x[1]
)[0]
return result
except Exception as e:
logger.error(f"Error parsing Civitai image metadata: {e}", exc_info=True)
return {"error": str(e), "loras": []}

View File

@@ -3,36 +3,42 @@ import copy
import logging
import os
from typing import Any, Optional, Dict, Tuple, List, Sequence
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
from .model_metadata_provider import (
CivitaiModelMetadataProvider,
ModelMetadataProviderManager,
)
from .downloader import get_downloader
from .errors import RateLimitError, ResourceNotFoundError
from ..utils.civitai_utils import resolve_license_payload
logger = logging.getLogger(__name__)
class CivitaiClient:
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls):
"""Get singleton instance of CivitaiClient"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
# Register this client as a metadata provider
provider_manager = await ModelMetadataProviderManager.get_instance()
provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True)
provider_manager.register_provider(
"civitai", CivitaiModelMetadataProvider(cls._instance), True
)
return cls._instance
def __init__(self):
# Check if already initialized for singleton pattern
if hasattr(self, '_initialized'):
if hasattr(self, "_initialized"):
return
self._initialized = True
self.base_url = "https://civitai.com/api/v1"
async def _make_request(
@@ -75,8 +81,10 @@ class CivitaiClient:
meta = image.get("meta")
if isinstance(meta, dict) and "comfy" in meta:
meta.pop("comfy", None)
async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
async def download_file(
self, url: str, save_dir: str, default_filename: str, progress_callback=None
) -> Tuple[bool, str]:
"""Download file with resumable downloads and retry mechanism
Args:
@@ -90,41 +98,48 @@ class CivitaiClient:
"""
downloader = await get_downloader()
save_path = os.path.join(save_dir, default_filename)
# Use unified downloader with CivitAI authentication
success, result = await downloader.download_file(
url=url,
save_path=save_path,
progress_callback=progress_callback,
use_auth=True, # Enable CivitAI authentication
allow_resume=True
allow_resume=True,
)
return success, result
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
async def get_model_by_hash(
self, model_hash: str
) -> Tuple[Optional[Dict], Optional[str]]:
try:
success, version = await self._make_request(
'GET',
"GET",
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
use_auth=True,
)
if not success:
message = str(version)
if "not found" in message.lower():
return None, "Model not found"
logger.error("Failed to fetch model info for %s: %s", model_hash[:10], message)
logger.error(
"Failed to fetch model info for %s: %s", model_hash[:10], message
)
return None, message
model_id = version.get('modelId')
if model_id:
model_data = await self._fetch_model_data(model_id)
if model_data:
self._enrich_version_with_model_data(version, model_data)
if isinstance(version, dict):
model_id = version.get("modelId")
if model_id:
model_data = await self._fetch_model_data(model_id)
if model_data:
self._enrich_version_with_model_data(version, model_data)
self._remove_comfy_metadata(version)
return version, None
self._remove_comfy_metadata(version)
return version, None
else:
return None, "Invalid response format"
except RateLimitError:
raise
except Exception as exc:
@@ -136,19 +151,19 @@ class CivitaiClient:
downloader = await get_downloader()
success, content, headers = await downloader.download_to_memory(
image_url,
use_auth=False # Preview images don't need auth
use_auth=False, # Preview images don't need auth
)
if success:
# Ensure directory exists
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, 'wb') as f:
with open(save_path, "wb") as f:
f.write(content)
return True
return False
except Exception as e:
logger.error(f"Download Error: {str(e)}")
return False
@staticmethod
def _extract_error_message(payload: Any) -> str:
"""Return a human-readable error message from an API payload."""
@@ -175,19 +190,17 @@ class CivitaiClient:
"""Get all versions of a model with local availability info"""
try:
success, result = await self._make_request(
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
"GET", f"{self.base_url}/models/{model_id}", use_auth=True
)
if success:
# Also return model type along with versions
return {
'modelVersions': result.get('modelVersions', []),
'type': result.get('type', ''),
'name': result.get('name', '')
"modelVersions": result.get("modelVersions", []),
"type": result.get("type", ""),
"name": result.get("name", ""),
}
message = self._extract_error_message(result)
if message and 'not found' in message.lower():
if message and "not found" in message.lower():
raise ResourceNotFoundError(f"Resource not found for model {model_id}")
if message:
raise RuntimeError(message)
@@ -221,15 +234,15 @@ class CivitaiClient:
try:
query = ",".join(normalized_ids)
success, result = await self._make_request(
'GET',
"GET",
f"{self.base_url}/models",
use_auth=True,
params={'ids': query},
params={"ids": query},
)
if not success:
return None
items = result.get('items') if isinstance(result, dict) else None
items = result.get("items") if isinstance(result, dict) else None
if not isinstance(items, list):
return {}
@@ -237,19 +250,19 @@ class CivitaiClient:
for item in items:
if not isinstance(item, dict):
continue
model_id = item.get('id')
model_id = item.get("id")
try:
normalized_id = int(model_id)
except (TypeError, ValueError):
continue
payload[normalized_id] = {
'modelVersions': item.get('modelVersions', []),
'type': item.get('type', ''),
'name': item.get('name', ''),
'allowNoCredit': item.get('allowNoCredit'),
'allowCommercialUse': item.get('allowCommercialUse'),
'allowDerivatives': item.get('allowDerivatives'),
'allowDifferentLicense': item.get('allowDifferentLicense'),
"modelVersions": item.get("modelVersions", []),
"type": item.get("type", ""),
"name": item.get("name", ""),
"allowNoCredit": item.get("allowNoCredit"),
"allowCommercialUse": item.get("allowCommercialUse"),
"allowDerivatives": item.get("allowDerivatives"),
"allowDifferentLicense": item.get("allowDifferentLicense"),
}
return payload
except RateLimitError:
@@ -257,8 +270,10 @@ class CivitaiClient:
except Exception as exc:
logger.error(f"Error fetching model versions in bulk: {exc}")
return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
async def get_model_version(
self, model_id: int = None, version_id: int = None
) -> Optional[Dict]:
"""Get specific model version with additional metadata."""
try:
if model_id is None and version_id is not None:
@@ -281,7 +296,7 @@ class CivitaiClient:
if version is None:
return None
model_id = version.get('modelId')
model_id = version.get("modelId")
if not model_id:
logger.error(f"No modelId found in version {version_id}")
return None
@@ -293,7 +308,9 @@ class CivitaiClient:
self._remove_comfy_metadata(version)
return version
async def _get_version_with_model_id(self, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
async def _get_version_with_model_id(
self, model_id: int, version_id: Optional[int]
) -> Optional[Dict]:
model_data = await self._fetch_model_data(model_id)
if not model_data:
return None
@@ -302,8 +319,12 @@ class CivitaiClient:
if target_version is None:
return None
target_version_id = target_version.get('id')
version = await self._fetch_version_by_id(target_version_id) if target_version_id else None
target_version_id = target_version.get("id")
version = (
await self._fetch_version_by_id(target_version_id)
if target_version_id
else None
)
if version is None:
model_hash = self._extract_primary_model_hash(target_version)
@@ -315,7 +336,9 @@ class CivitaiClient:
)
if version is None:
version = self._build_version_from_model_data(target_version, model_id, model_data)
version = self._build_version_from_model_data(
target_version, model_id, model_data
)
self._enrich_version_with_model_data(version, model_data)
self._remove_comfy_metadata(version)
@@ -323,9 +346,7 @@ class CivitaiClient:
async def _fetch_model_data(self, model_id: int) -> Optional[Dict]:
success, data = await self._make_request(
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
"GET", f"{self.base_url}/models/{model_id}", use_auth=True
)
if success:
return data
@@ -337,9 +358,7 @@ class CivitaiClient:
return None
success, version = await self._make_request(
'GET',
f"{self.base_url}/model-versions/{version_id}",
use_auth=True
"GET", f"{self.base_url}/model-versions/{version_id}", use_auth=True
)
if success:
return version
@@ -352,9 +371,7 @@ class CivitaiClient:
return None
success, version = await self._make_request(
'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
"GET", f"{self.base_url}/model-versions/by-hash/{model_hash}", use_auth=True
)
if success:
return version
@@ -362,16 +379,17 @@ class CivitaiClient:
logger.warning(f"Failed to fetch version by hash {model_hash}")
return None
def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
model_versions = model_data.get('modelVersions', [])
def _select_target_version(
self, model_data: Dict, model_id: int, version_id: Optional[int]
) -> Optional[Dict]:
model_versions = model_data.get("modelVersions", [])
if not model_versions:
logger.warning(f"No model versions found for model {model_id}")
return None
if version_id is not None:
target_version = next(
(item for item in model_versions if item.get('id') == version_id),
None
(item for item in model_versions if item.get("id") == version_id), None
)
if target_version is None:
logger.warning(
@@ -383,46 +401,50 @@ class CivitaiClient:
return model_versions[0]
def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]:
for file_info in version_entry.get('files', []):
if file_info.get('type') == 'Model' and file_info.get('primary'):
hashes = file_info.get('hashes', {})
model_hash = hashes.get('SHA256')
for file_info in version_entry.get("files", []):
if file_info.get("type") == "Model" and file_info.get("primary"):
hashes = file_info.get("hashes", {})
model_hash = hashes.get("SHA256")
if model_hash:
return model_hash
return None
def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict:
def _build_version_from_model_data(
self, version_entry: Dict, model_id: int, model_data: Dict
) -> Dict:
version = copy.deepcopy(version_entry)
version.pop('index', None)
version['modelId'] = model_id
version['model'] = {
'name': model_data.get('name'),
'type': model_data.get('type'),
'nsfw': model_data.get('nsfw'),
'poi': model_data.get('poi')
version.pop("index", None)
version["modelId"] = model_id
version["model"] = {
"name": model_data.get("name"),
"type": model_data.get("type"),
"nsfw": model_data.get("nsfw"),
"poi": model_data.get("poi"),
}
return version
def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None:
model_info = version.get('model')
model_info = version.get("model")
if not isinstance(model_info, dict):
model_info = {}
version['model'] = model_info
version["model"] = model_info
model_info['description'] = model_data.get("description")
model_info['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator")
model_info["description"] = model_data.get("description")
model_info["tags"] = model_data.get("tags", [])
version["creator"] = model_data.get("creator")
license_payload = resolve_license_payload(model_data)
for field, value in license_payload.items():
model_info[field] = value
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
async def get_model_version_info(
self, version_id: str
) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from Civitai
Args:
version_id: The Civitai model version ID
Returns:
Tuple[Optional[Dict], Optional[str]]: A tuple containing:
- The model version data or None if not found
@@ -430,25 +452,23 @@ class CivitaiClient:
"""
try:
url = f"{self.base_url}/model-versions/{version_id}"
logger.debug(f"Resolving DNS for model version info: {url}")
success, result = await self._make_request(
'GET',
url,
use_auth=True
)
success, result = await self._make_request("GET", url, use_auth=True)
if success:
logger.debug(f"Successfully fetched model version info for: {version_id}")
logger.debug(
f"Successfully fetched model version info for: {version_id}"
)
self._remove_comfy_metadata(result)
return result, None
# Handle specific error cases
if "not found" in str(result):
error_msg = f"Model not found"
logger.warning(f"Model version not found: {version_id} - {error_msg}")
return None, error_msg
# Other error cases
logger.error(f"Failed to fetch model info for {version_id}: {result}")
return None, str(result)
@@ -464,27 +484,23 @@ class CivitaiClient:
Args:
image_id: The Civitai image ID
Returns:
Optional[Dict]: The image data or None if not found
"""
try:
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
logger.debug(f"Fetching image info for ID: {image_id}")
success, result = await self._make_request(
'GET',
url,
use_auth=True
)
success, result = await self._make_request("GET", url, use_auth=True)
if success:
if result and "items" in result and len(result["items"]) > 0:
logger.debug(f"Successfully fetched image info for ID: {image_id}")
return result["items"][0]
logger.warning(f"No image found with ID: {image_id}")
return None
logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
return None
except RateLimitError:
@@ -501,11 +517,7 @@ class CivitaiClient:
try:
url = f"{self.base_url}/models?username={username}"
success, result = await self._make_request(
'GET',
url,
use_auth=True
)
success, result = await self._make_request("GET", url, use_auth=True)
if not success:
logger.error("Failed to fetch models for %s: %s", username, result)