mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
Fix null-safety issues and apply code formatting
Bug fixes: - Add null guards for base_models_roots/embeddings_roots in backup cleanup - Fix null-safety initialization of extra_unet_roots Formatting: - Apply consistent code style across Python files - Fix line wrapping, quote consistency, and trailing commas - Add type ignore comments for dynamic/platform-specific code
This commit is contained in:
284
py/config.py
284
py/config.py
@@ -2,7 +2,7 @@ import os
|
|||||||
import platform
|
import platform
|
||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import folder_paths # type: ignore
|
import folder_paths # type: ignore
|
||||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple
|
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
@@ -10,16 +10,23 @@ import urllib.parse
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from .utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths
|
from .utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths
|
||||||
from .utils.settings_paths import ensure_settings_file, get_settings_dir, load_settings_template
|
from .utils.settings_paths import (
|
||||||
|
ensure_settings_file,
|
||||||
|
get_settings_dir,
|
||||||
|
load_settings_template,
|
||||||
|
)
|
||||||
|
|
||||||
# Use an environment variable to control standalone mode
|
# Use an environment variable to control standalone mode
|
||||||
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
standalone_mode = (
|
||||||
|
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
|
||||||
|
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _normalize_folder_paths_for_comparison(
|
def _normalize_folder_paths_for_comparison(
|
||||||
folder_paths: Mapping[str, Iterable[str]]
|
folder_paths: Mapping[str, Iterable[str]],
|
||||||
) -> Dict[str, Set[str]]:
|
) -> Dict[str, Set[str]]:
|
||||||
"""Normalize folder paths for comparison across libraries."""
|
"""Normalize folder paths for comparison across libraries."""
|
||||||
|
|
||||||
@@ -49,7 +56,7 @@ def _normalize_folder_paths_for_comparison(
|
|||||||
|
|
||||||
|
|
||||||
def _normalize_library_folder_paths(
|
def _normalize_library_folder_paths(
|
||||||
library_payload: Mapping[str, Any]
|
library_payload: Mapping[str, Any],
|
||||||
) -> Dict[str, Set[str]]:
|
) -> Dict[str, Set[str]]:
|
||||||
"""Return normalized folder paths extracted from a library payload."""
|
"""Return normalized folder paths extracted from a library payload."""
|
||||||
|
|
||||||
@@ -76,9 +83,15 @@ class Config:
|
|||||||
"""Global configuration for LoRA Manager"""
|
"""Global configuration for LoRA Manager"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.templates_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'templates')
|
self.templates_path = os.path.join(
|
||||||
self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static')
|
os.path.dirname(os.path.dirname(__file__)), "templates"
|
||||||
self.i18n_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'locales')
|
)
|
||||||
|
self.static_path = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(__file__)), "static"
|
||||||
|
)
|
||||||
|
self.i18n_path = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(__file__)), "locales"
|
||||||
|
)
|
||||||
# Path mapping dictionary, target to link mapping
|
# Path mapping dictionary, target to link mapping
|
||||||
self._path_mappings: Dict[str, str] = {}
|
self._path_mappings: Dict[str, str] = {}
|
||||||
# Normalized preview root directories used to validate preview access
|
# Normalized preview root directories used to validate preview access
|
||||||
@@ -152,17 +165,21 @@ class Config:
|
|||||||
default_library = libraries.get("default", {})
|
default_library = libraries.get("default", {})
|
||||||
|
|
||||||
target_folder_paths = {
|
target_folder_paths = {
|
||||||
'loras': list(self.loras_roots),
|
"loras": list(self.loras_roots),
|
||||||
'checkpoints': list(self.checkpoints_roots or []),
|
"checkpoints": list(self.checkpoints_roots or []),
|
||||||
'unet': list(self.unet_roots or []),
|
"unet": list(self.unet_roots or []),
|
||||||
'embeddings': list(self.embeddings_roots or []),
|
"embeddings": list(self.embeddings_roots or []),
|
||||||
}
|
}
|
||||||
|
|
||||||
normalized_target_paths = _normalize_folder_paths_for_comparison(target_folder_paths)
|
normalized_target_paths = _normalize_folder_paths_for_comparison(
|
||||||
|
target_folder_paths
|
||||||
|
)
|
||||||
|
|
||||||
normalized_default_paths: Optional[Dict[str, Set[str]]] = None
|
normalized_default_paths: Optional[Dict[str, Set[str]]] = None
|
||||||
if isinstance(default_library, Mapping):
|
if isinstance(default_library, Mapping):
|
||||||
normalized_default_paths = _normalize_library_folder_paths(default_library)
|
normalized_default_paths = _normalize_library_folder_paths(
|
||||||
|
default_library
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not comfy_library
|
not comfy_library
|
||||||
@@ -185,13 +202,19 @@ class Config:
|
|||||||
default_lora_root = self.loras_roots[0]
|
default_lora_root = self.loras_roots[0]
|
||||||
|
|
||||||
default_checkpoint_root = comfy_library.get("default_checkpoint_root", "")
|
default_checkpoint_root = comfy_library.get("default_checkpoint_root", "")
|
||||||
if (not default_checkpoint_root and self.checkpoints_roots and
|
if (
|
||||||
len(self.checkpoints_roots) == 1):
|
not default_checkpoint_root
|
||||||
|
and self.checkpoints_roots
|
||||||
|
and len(self.checkpoints_roots) == 1
|
||||||
|
):
|
||||||
default_checkpoint_root = self.checkpoints_roots[0]
|
default_checkpoint_root = self.checkpoints_roots[0]
|
||||||
|
|
||||||
default_embedding_root = comfy_library.get("default_embedding_root", "")
|
default_embedding_root = comfy_library.get("default_embedding_root", "")
|
||||||
if (not default_embedding_root and self.embeddings_roots and
|
if (
|
||||||
len(self.embeddings_roots) == 1):
|
not default_embedding_root
|
||||||
|
and self.embeddings_roots
|
||||||
|
and len(self.embeddings_roots) == 1
|
||||||
|
):
|
||||||
default_embedding_root = self.embeddings_roots[0]
|
default_embedding_root = self.embeddings_roots[0]
|
||||||
|
|
||||||
metadata = dict(comfy_library.get("metadata", {}))
|
metadata = dict(comfy_library.get("metadata", {}))
|
||||||
@@ -216,11 +239,12 @@ class Config:
|
|||||||
try:
|
try:
|
||||||
if os.path.islink(path):
|
if os.path.islink(path):
|
||||||
return True
|
return True
|
||||||
if platform.system() == 'Windows':
|
if platform.system() == "Windows":
|
||||||
try:
|
try:
|
||||||
import ctypes
|
import ctypes
|
||||||
|
|
||||||
FILE_ATTRIBUTE_REPARSE_POINT = 0x400
|
FILE_ATTRIBUTE_REPARSE_POINT = 0x400
|
||||||
attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path))
|
attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path)) # type: ignore[attr-defined]
|
||||||
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
|
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error checking Windows reparse point: {e}")
|
logger.error(f"Error checking Windows reparse point: {e}")
|
||||||
@@ -233,18 +257,19 @@ class Config:
|
|||||||
"""Check if a directory entry is a symlink, including Windows junctions."""
|
"""Check if a directory entry is a symlink, including Windows junctions."""
|
||||||
if entry.is_symlink():
|
if entry.is_symlink():
|
||||||
return True
|
return True
|
||||||
if platform.system() == 'Windows':
|
if platform.system() == "Windows":
|
||||||
try:
|
try:
|
||||||
import ctypes
|
import ctypes
|
||||||
|
|
||||||
FILE_ATTRIBUTE_REPARSE_POINT = 0x400
|
FILE_ATTRIBUTE_REPARSE_POINT = 0x400
|
||||||
attrs = ctypes.windll.kernel32.GetFileAttributesW(entry.path)
|
attrs = ctypes.windll.kernel32.GetFileAttributesW(entry.path) # type: ignore[attr-defined]
|
||||||
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
|
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _normalize_path(self, path: str) -> str:
|
def _normalize_path(self, path: str) -> str:
|
||||||
return os.path.normpath(path).replace(os.sep, '/')
|
return os.path.normpath(path).replace(os.sep, "/")
|
||||||
|
|
||||||
def _get_symlink_cache_path(self) -> Path:
|
def _get_symlink_cache_path(self) -> Path:
|
||||||
canonical_path = get_cache_file_path(CacheType.SYMLINK, create_dir=True)
|
canonical_path = get_cache_file_path(CacheType.SYMLINK, create_dir=True)
|
||||||
@@ -278,19 +303,18 @@ class Config:
|
|||||||
if self._entry_is_symlink(entry):
|
if self._entry_is_symlink(entry):
|
||||||
try:
|
try:
|
||||||
target = os.path.realpath(entry.path)
|
target = os.path.realpath(entry.path)
|
||||||
direct_symlinks.append([
|
direct_symlinks.append(
|
||||||
self._normalize_path(entry.path),
|
[
|
||||||
self._normalize_path(target)
|
self._normalize_path(entry.path),
|
||||||
])
|
self._normalize_path(target),
|
||||||
|
]
|
||||||
|
)
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
except (OSError, PermissionError):
|
except (OSError, PermissionError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return {
|
return {"roots": unique_roots, "direct_symlinks": sorted(direct_symlinks)}
|
||||||
"roots": unique_roots,
|
|
||||||
"direct_symlinks": sorted(direct_symlinks)
|
|
||||||
}
|
|
||||||
|
|
||||||
def _initialize_symlink_mappings(self) -> None:
|
def _initialize_symlink_mappings(self) -> None:
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
@@ -307,10 +331,14 @@ class Config:
|
|||||||
cached_fingerprint = self._cached_fingerprint
|
cached_fingerprint = self._cached_fingerprint
|
||||||
|
|
||||||
# Check 1: First-level symlinks unchanged (catches new symlinks at root)
|
# Check 1: First-level symlinks unchanged (catches new symlinks at root)
|
||||||
fingerprint_valid = cached_fingerprint and current_fingerprint == cached_fingerprint
|
fingerprint_valid = (
|
||||||
|
cached_fingerprint and current_fingerprint == cached_fingerprint
|
||||||
|
)
|
||||||
|
|
||||||
# Check 2: All cached mappings still valid (catches changes at any depth)
|
# Check 2: All cached mappings still valid (catches changes at any depth)
|
||||||
mappings_valid = self._validate_cached_mappings() if fingerprint_valid else False
|
mappings_valid = (
|
||||||
|
self._validate_cached_mappings() if fingerprint_valid else False
|
||||||
|
)
|
||||||
|
|
||||||
if fingerprint_valid and mappings_valid:
|
if fingerprint_valid and mappings_valid:
|
||||||
return
|
return
|
||||||
@@ -370,7 +398,9 @@ class Config:
|
|||||||
for target, link in cached_mappings.items():
|
for target, link in cached_mappings.items():
|
||||||
if not isinstance(target, str) or not isinstance(link, str):
|
if not isinstance(target, str) or not isinstance(link, str):
|
||||||
continue
|
continue
|
||||||
normalized_mappings[self._normalize_path(target)] = self._normalize_path(link)
|
normalized_mappings[self._normalize_path(target)] = self._normalize_path(
|
||||||
|
link
|
||||||
|
)
|
||||||
|
|
||||||
self._path_mappings = normalized_mappings
|
self._path_mappings = normalized_mappings
|
||||||
|
|
||||||
@@ -391,7 +421,9 @@ class Config:
|
|||||||
parent_dir = loaded_path.parent
|
parent_dir = loaded_path.parent
|
||||||
if parent_dir.name == "cache" and not any(parent_dir.iterdir()):
|
if parent_dir.name == "cache" and not any(parent_dir.iterdir()):
|
||||||
parent_dir.rmdir()
|
parent_dir.rmdir()
|
||||||
logger.info("Removed empty legacy cache directory: %s", parent_dir)
|
logger.info(
|
||||||
|
"Removed empty legacy cache directory: %s", parent_dir
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -402,7 +434,9 @@ class Config:
|
|||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Symlink cache loaded with %d mappings", len(self._path_mappings))
|
logger.info(
|
||||||
|
"Symlink cache loaded with %d mappings", len(self._path_mappings)
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -414,7 +448,7 @@ class Config:
|
|||||||
"""
|
"""
|
||||||
for target, link in self._path_mappings.items():
|
for target, link in self._path_mappings.items():
|
||||||
# Convert normalized paths back to OS paths
|
# Convert normalized paths back to OS paths
|
||||||
link_path = link.replace('/', os.sep)
|
link_path = link.replace("/", os.sep)
|
||||||
|
|
||||||
# Check if symlink still exists
|
# Check if symlink still exists
|
||||||
if not self._is_link(link_path):
|
if not self._is_link(link_path):
|
||||||
@@ -427,7 +461,9 @@ class Config:
|
|||||||
if actual_target != target:
|
if actual_target != target:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Symlink target changed: %s -> %s (cached: %s)",
|
"Symlink target changed: %s -> %s (cached: %s)",
|
||||||
link_path, actual_target, target
|
link_path,
|
||||||
|
actual_target,
|
||||||
|
target,
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
except OSError:
|
except OSError:
|
||||||
@@ -446,7 +482,11 @@ class Config:
|
|||||||
try:
|
try:
|
||||||
with cache_path.open("w", encoding="utf-8") as handle:
|
with cache_path.open("w", encoding="utf-8") as handle:
|
||||||
json.dump(payload, handle, ensure_ascii=False, indent=2)
|
json.dump(payload, handle, ensure_ascii=False, indent=2)
|
||||||
logger.debug("Symlink cache saved to %s with %d mappings", cache_path, len(self._path_mappings))
|
logger.debug(
|
||||||
|
"Symlink cache saved to %s with %d mappings",
|
||||||
|
cache_path,
|
||||||
|
len(self._path_mappings),
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.info("Failed to write symlink cache %s: %s", cache_path, exc)
|
logger.info("Failed to write symlink cache %s: %s", cache_path, exc)
|
||||||
|
|
||||||
@@ -494,13 +534,13 @@ class Config:
|
|||||||
self.add_path_mapping(entry.path, target_path)
|
self.add_path_mapping(entry.path, target_path)
|
||||||
except Exception as inner_exc:
|
except Exception as inner_exc:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Error processing directory entry %s: %s", entry.path, inner_exc
|
"Error processing directory entry %s: %s",
|
||||||
|
entry.path,
|
||||||
|
inner_exc,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error scanning links in {root}: {e}")
|
logger.error(f"Error scanning links in {root}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def add_path_mapping(self, link_path: str, target_path: str):
|
def add_path_mapping(self, link_path: str, target_path: str):
|
||||||
"""Add a symbolic link path mapping
|
"""Add a symbolic link path mapping
|
||||||
target_path: actual target path
|
target_path: actual target path
|
||||||
@@ -594,26 +634,31 @@ class Config:
|
|||||||
preview_roots.update(self._expand_preview_root(target))
|
preview_roots.update(self._expand_preview_root(target))
|
||||||
preview_roots.update(self._expand_preview_root(link))
|
preview_roots.update(self._expand_preview_root(link))
|
||||||
|
|
||||||
self._preview_root_paths = {path for path in preview_roots if path.is_absolute()}
|
self._preview_root_paths = {
|
||||||
|
path for path in preview_roots if path.is_absolute()
|
||||||
|
}
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Preview roots rebuilt: %d paths from %d lora roots (%d extra), %d checkpoint roots (%d extra), %d embedding roots (%d extra), %d symlink mappings",
|
"Preview roots rebuilt: %d paths from %d lora roots (%d extra), %d checkpoint roots (%d extra), %d embedding roots (%d extra), %d symlink mappings",
|
||||||
len(self._preview_root_paths),
|
len(self._preview_root_paths),
|
||||||
len(self.loras_roots or []), len(self.extra_loras_roots or []),
|
len(self.loras_roots or []),
|
||||||
len(self.base_models_roots or []), len(self.extra_checkpoints_roots or []),
|
len(self.extra_loras_roots or []),
|
||||||
len(self.embeddings_roots or []), len(self.extra_embeddings_roots or []),
|
len(self.base_models_roots or []),
|
||||||
|
len(self.extra_checkpoints_roots or []),
|
||||||
|
len(self.embeddings_roots or []),
|
||||||
|
len(self.extra_embeddings_roots or []),
|
||||||
len(self._path_mappings),
|
len(self._path_mappings),
|
||||||
)
|
)
|
||||||
|
|
||||||
def map_path_to_link(self, path: str) -> str:
|
def map_path_to_link(self, path: str) -> str:
|
||||||
"""Map a target path back to its symbolic link path"""
|
"""Map a target path back to its symbolic link path"""
|
||||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
normalized_path = os.path.normpath(path).replace(os.sep, "/")
|
||||||
# Check if the path is contained in any mapped target path
|
# Check if the path is contained in any mapped target path
|
||||||
for target_path, link_path in self._path_mappings.items():
|
for target_path, link_path in self._path_mappings.items():
|
||||||
# Match whole path components to avoid prefix collisions (e.g., /a/b vs /a/bc)
|
# Match whole path components to avoid prefix collisions (e.g., /a/b vs /a/bc)
|
||||||
if normalized_path == target_path:
|
if normalized_path == target_path:
|
||||||
return link_path
|
return link_path
|
||||||
|
|
||||||
if normalized_path.startswith(target_path + '/'):
|
if normalized_path.startswith(target_path + "/"):
|
||||||
# If the path starts with the target path, replace with link path
|
# If the path starts with the target path, replace with link path
|
||||||
mapped_path = normalized_path.replace(target_path, link_path, 1)
|
mapped_path = normalized_path.replace(target_path, link_path, 1)
|
||||||
return mapped_path
|
return mapped_path
|
||||||
@@ -621,14 +666,14 @@ class Config:
|
|||||||
|
|
||||||
def map_link_to_path(self, link_path: str) -> str:
|
def map_link_to_path(self, link_path: str) -> str:
|
||||||
"""Map a symbolic link path back to the actual path"""
|
"""Map a symbolic link path back to the actual path"""
|
||||||
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
normalized_link = os.path.normpath(link_path).replace(os.sep, "/")
|
||||||
# Check if the path is contained in any mapped target path
|
# Check if the path is contained in any mapped target path
|
||||||
for target_path, link_path_mapped in self._path_mappings.items():
|
for target_path, link_path_mapped in self._path_mappings.items():
|
||||||
# Match whole path components
|
# Match whole path components
|
||||||
if normalized_link == link_path_mapped:
|
if normalized_link == link_path_mapped:
|
||||||
return target_path
|
return target_path
|
||||||
|
|
||||||
if normalized_link.startswith(link_path_mapped + '/'):
|
if normalized_link.startswith(link_path_mapped + "/"):
|
||||||
# If the path starts with the link path, replace with actual path
|
# If the path starts with the link path, replace with actual path
|
||||||
mapped_path = normalized_link.replace(link_path_mapped, target_path, 1)
|
mapped_path = normalized_link.replace(link_path_mapped, target_path, 1)
|
||||||
return mapped_path
|
return mapped_path
|
||||||
@@ -641,8 +686,8 @@ class Config:
|
|||||||
continue
|
continue
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
continue
|
continue
|
||||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, "/")
|
||||||
normalized = os.path.normpath(path).replace(os.sep, '/')
|
normalized = os.path.normpath(path).replace(os.sep, "/")
|
||||||
if real_path not in dedup:
|
if real_path not in dedup:
|
||||||
dedup[real_path] = normalized
|
dedup[real_path] = normalized
|
||||||
return dedup
|
return dedup
|
||||||
@@ -652,7 +697,9 @@ class Config:
|
|||||||
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
||||||
|
|
||||||
for original_path in unique_paths:
|
for original_path in unique_paths:
|
||||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
real_path = os.path.normpath(os.path.realpath(original_path)).replace(
|
||||||
|
os.sep, "/"
|
||||||
|
)
|
||||||
if real_path != original_path:
|
if real_path != original_path:
|
||||||
self.add_path_mapping(original_path, real_path)
|
self.add_path_mapping(original_path, real_path)
|
||||||
|
|
||||||
@@ -674,7 +721,7 @@ class Config:
|
|||||||
"Please fix your ComfyUI path configuration to separate these folders. "
|
"Please fix your ComfyUI path configuration to separate these folders. "
|
||||||
"Falling back to 'checkpoints' for backward compatibility. "
|
"Falling back to 'checkpoints' for backward compatibility. "
|
||||||
"Overlapping real paths: %s",
|
"Overlapping real paths: %s",
|
||||||
[checkpoint_map.get(rp, rp) for rp in overlapping_real_paths]
|
[checkpoint_map.get(rp, rp) for rp in overlapping_real_paths],
|
||||||
)
|
)
|
||||||
# Remove overlapping paths from unet_map to prioritize checkpoints
|
# Remove overlapping paths from unet_map to prioritize checkpoints
|
||||||
for rp in overlapping_real_paths:
|
for rp in overlapping_real_paths:
|
||||||
@@ -694,7 +741,9 @@ class Config:
|
|||||||
self.unet_roots = [p for p in unique_paths if p in unet_values]
|
self.unet_roots = [p for p in unique_paths if p in unet_values]
|
||||||
|
|
||||||
for original_path in unique_paths:
|
for original_path in unique_paths:
|
||||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
real_path = os.path.normpath(os.path.realpath(original_path)).replace(
|
||||||
|
os.sep, "/"
|
||||||
|
)
|
||||||
if real_path != original_path:
|
if real_path != original_path:
|
||||||
self.add_path_mapping(original_path, real_path)
|
self.add_path_mapping(original_path, real_path)
|
||||||
|
|
||||||
@@ -705,7 +754,9 @@ class Config:
|
|||||||
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
||||||
|
|
||||||
for original_path in unique_paths:
|
for original_path in unique_paths:
|
||||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
real_path = os.path.normpath(os.path.realpath(original_path)).replace(
|
||||||
|
os.sep, "/"
|
||||||
|
)
|
||||||
if real_path != original_path:
|
if real_path != original_path:
|
||||||
self.add_path_mapping(original_path, real_path)
|
self.add_path_mapping(original_path, real_path)
|
||||||
|
|
||||||
@@ -719,42 +770,66 @@ class Config:
|
|||||||
self._path_mappings.clear()
|
self._path_mappings.clear()
|
||||||
self._preview_root_paths = set()
|
self._preview_root_paths = set()
|
||||||
|
|
||||||
lora_paths = folder_paths.get('loras', []) or []
|
lora_paths = folder_paths.get("loras", []) or []
|
||||||
checkpoint_paths = folder_paths.get('checkpoints', []) or []
|
checkpoint_paths = folder_paths.get("checkpoints", []) or []
|
||||||
unet_paths = folder_paths.get('unet', []) or []
|
unet_paths = folder_paths.get("unet", []) or []
|
||||||
embedding_paths = folder_paths.get('embeddings', []) or []
|
embedding_paths = folder_paths.get("embeddings", []) or []
|
||||||
|
|
||||||
self.loras_roots = self._prepare_lora_paths(lora_paths)
|
self.loras_roots = self._prepare_lora_paths(lora_paths)
|
||||||
self.base_models_roots = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths)
|
self.base_models_roots = self._prepare_checkpoint_paths(
|
||||||
|
checkpoint_paths, unet_paths
|
||||||
|
)
|
||||||
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
|
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
|
||||||
|
|
||||||
# Process extra paths (only for LoRA Manager, not shared with ComfyUI)
|
# Process extra paths (only for LoRA Manager, not shared with ComfyUI)
|
||||||
extra_paths = extra_folder_paths or {}
|
extra_paths = extra_folder_paths or {}
|
||||||
extra_lora_paths = extra_paths.get('loras', []) or []
|
extra_lora_paths = extra_paths.get("loras", []) or []
|
||||||
extra_checkpoint_paths = extra_paths.get('checkpoints', []) or []
|
extra_checkpoint_paths = extra_paths.get("checkpoints", []) or []
|
||||||
extra_unet_paths = extra_paths.get('unet', []) or []
|
extra_unet_paths = extra_paths.get("unet", []) or []
|
||||||
extra_embedding_paths = extra_paths.get('embeddings', []) or []
|
extra_embedding_paths = extra_paths.get("embeddings", []) or []
|
||||||
|
|
||||||
self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths)
|
self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths)
|
||||||
# Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them)
|
# Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them)
|
||||||
saved_checkpoints_roots = self.checkpoints_roots
|
saved_checkpoints_roots = self.checkpoints_roots
|
||||||
saved_unet_roots = self.unet_roots
|
saved_unet_roots = self.unet_roots
|
||||||
self.extra_checkpoints_roots = self._prepare_checkpoint_paths(extra_checkpoint_paths, extra_unet_paths)
|
self.extra_checkpoints_roots = self._prepare_checkpoint_paths(
|
||||||
self.extra_unet_roots = self.unet_roots # unet_roots was set by _prepare_checkpoint_paths
|
extra_checkpoint_paths, extra_unet_paths
|
||||||
|
)
|
||||||
|
self.extra_unet_roots = (
|
||||||
|
self.unet_roots if self.unet_roots is not None else []
|
||||||
|
) # unet_roots was set by _prepare_checkpoint_paths
|
||||||
# Restore main paths
|
# Restore main paths
|
||||||
self.checkpoints_roots = saved_checkpoints_roots
|
self.checkpoints_roots = saved_checkpoints_roots
|
||||||
self.unet_roots = saved_unet_roots
|
self.unet_roots = saved_unet_roots
|
||||||
self.extra_embeddings_roots = self._prepare_embedding_paths(extra_embedding_paths)
|
self.extra_embeddings_roots = self._prepare_embedding_paths(
|
||||||
|
extra_embedding_paths
|
||||||
|
)
|
||||||
|
|
||||||
# Log extra folder paths
|
# Log extra folder paths
|
||||||
if self.extra_loras_roots:
|
if self.extra_loras_roots:
|
||||||
logger.info("Found extra LoRA roots:" + "\n - " + "\n - ".join(self.extra_loras_roots))
|
logger.info(
|
||||||
|
"Found extra LoRA roots:"
|
||||||
|
+ "\n - "
|
||||||
|
+ "\n - ".join(self.extra_loras_roots)
|
||||||
|
)
|
||||||
if self.extra_checkpoints_roots:
|
if self.extra_checkpoints_roots:
|
||||||
logger.info("Found extra checkpoint roots:" + "\n - " + "\n - ".join(self.extra_checkpoints_roots))
|
logger.info(
|
||||||
|
"Found extra checkpoint roots:"
|
||||||
|
+ "\n - "
|
||||||
|
+ "\n - ".join(self.extra_checkpoints_roots)
|
||||||
|
)
|
||||||
if self.extra_unet_roots:
|
if self.extra_unet_roots:
|
||||||
logger.info("Found extra diffusion model roots:" + "\n - " + "\n - ".join(self.extra_unet_roots))
|
logger.info(
|
||||||
|
"Found extra diffusion model roots:"
|
||||||
|
+ "\n - "
|
||||||
|
+ "\n - ".join(self.extra_unet_roots)
|
||||||
|
)
|
||||||
if self.extra_embeddings_roots:
|
if self.extra_embeddings_roots:
|
||||||
logger.info("Found extra embedding roots:" + "\n - " + "\n - ".join(self.extra_embeddings_roots))
|
logger.info(
|
||||||
|
"Found extra embedding roots:"
|
||||||
|
+ "\n - "
|
||||||
|
+ "\n - ".join(self.extra_embeddings_roots)
|
||||||
|
)
|
||||||
|
|
||||||
self._initialize_symlink_mappings()
|
self._initialize_symlink_mappings()
|
||||||
|
|
||||||
@@ -763,7 +838,10 @@ class Config:
|
|||||||
try:
|
try:
|
||||||
raw_paths = folder_paths.get_folder_paths("loras")
|
raw_paths = folder_paths.get_folder_paths("loras")
|
||||||
unique_paths = self._prepare_lora_paths(raw_paths)
|
unique_paths = self._prepare_lora_paths(raw_paths)
|
||||||
logger.info("Found LoRA roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
|
logger.info(
|
||||||
|
"Found LoRA roots:"
|
||||||
|
+ ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")
|
||||||
|
)
|
||||||
|
|
||||||
if not unique_paths:
|
if not unique_paths:
|
||||||
logger.warning("No valid loras folders found in ComfyUI configuration")
|
logger.warning("No valid loras folders found in ComfyUI configuration")
|
||||||
@@ -779,12 +857,19 @@ class Config:
|
|||||||
try:
|
try:
|
||||||
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||||
raw_unet_paths = folder_paths.get_folder_paths("unet")
|
raw_unet_paths = folder_paths.get_folder_paths("unet")
|
||||||
unique_paths = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths)
|
unique_paths = self._prepare_checkpoint_paths(
|
||||||
|
raw_checkpoint_paths, raw_unet_paths
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
|
logger.info(
|
||||||
|
"Found checkpoint roots:"
|
||||||
|
+ ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")
|
||||||
|
)
|
||||||
|
|
||||||
if not unique_paths:
|
if not unique_paths:
|
||||||
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
|
logger.warning(
|
||||||
|
"No valid checkpoint folders found in ComfyUI configuration"
|
||||||
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return unique_paths
|
return unique_paths
|
||||||
@@ -797,10 +882,15 @@ class Config:
|
|||||||
try:
|
try:
|
||||||
raw_paths = folder_paths.get_folder_paths("embeddings")
|
raw_paths = folder_paths.get_folder_paths("embeddings")
|
||||||
unique_paths = self._prepare_embedding_paths(raw_paths)
|
unique_paths = self._prepare_embedding_paths(raw_paths)
|
||||||
logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
|
logger.info(
|
||||||
|
"Found embedding roots:"
|
||||||
|
+ ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")
|
||||||
|
)
|
||||||
|
|
||||||
if not unique_paths:
|
if not unique_paths:
|
||||||
logger.warning("No valid embeddings folders found in ComfyUI configuration")
|
logger.warning(
|
||||||
|
"No valid embeddings folders found in ComfyUI configuration"
|
||||||
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return unique_paths
|
return unique_paths
|
||||||
@@ -812,9 +902,9 @@ class Config:
|
|||||||
if not preview_path:
|
if not preview_path:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
normalized = os.path.normpath(preview_path).replace(os.sep, '/')
|
normalized = os.path.normpath(preview_path).replace(os.sep, "/")
|
||||||
encoded_path = urllib.parse.quote(normalized, safe='')
|
encoded_path = urllib.parse.quote(normalized, safe="")
|
||||||
return f'/api/lm/previews?path={encoded_path}'
|
return f"/api/lm/previews?path={encoded_path}"
|
||||||
|
|
||||||
def is_preview_path_allowed(self, preview_path: str) -> bool:
|
def is_preview_path_allowed(self, preview_path: str) -> bool:
|
||||||
"""Return ``True`` if ``preview_path`` is within an allowed directory.
|
"""Return ``True`` if ``preview_path`` is within an allowed directory.
|
||||||
@@ -889,14 +979,18 @@ class Config:
|
|||||||
normalized_link = self._normalize_path(str(current))
|
normalized_link = self._normalize_path(str(current))
|
||||||
|
|
||||||
self._path_mappings[normalized_target] = normalized_link
|
self._path_mappings[normalized_target] = normalized_link
|
||||||
self._preview_root_paths.update(self._expand_preview_root(normalized_target))
|
self._preview_root_paths.update(
|
||||||
self._preview_root_paths.update(self._expand_preview_root(normalized_link))
|
self._expand_preview_root(normalized_target)
|
||||||
|
)
|
||||||
|
self._preview_root_paths.update(
|
||||||
|
self._expand_preview_root(normalized_link)
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Discovered deep symlink: %s -> %s (preview path: %s)",
|
"Discovered deep symlink: %s -> %s (preview path: %s)",
|
||||||
normalized_link,
|
normalized_link,
|
||||||
normalized_target,
|
normalized_target,
|
||||||
preview_path
|
preview_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@@ -914,8 +1008,16 @@ class Config:
|
|||||||
|
|
||||||
def apply_library_settings(self, library_config: Mapping[str, object]) -> None:
|
def apply_library_settings(self, library_config: Mapping[str, object]) -> None:
|
||||||
"""Update runtime paths to match the provided library configuration."""
|
"""Update runtime paths to match the provided library configuration."""
|
||||||
folder_paths = library_config.get('folder_paths') if isinstance(library_config, Mapping) else {}
|
folder_paths = (
|
||||||
extra_folder_paths = library_config.get('extra_folder_paths') if isinstance(library_config, Mapping) else None
|
library_config.get("folder_paths")
|
||||||
|
if isinstance(library_config, Mapping)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
extra_folder_paths = (
|
||||||
|
library_config.get("extra_folder_paths")
|
||||||
|
if isinstance(library_config, Mapping)
|
||||||
|
else None
|
||||||
|
)
|
||||||
if not isinstance(folder_paths, Mapping):
|
if not isinstance(folder_paths, Mapping):
|
||||||
folder_paths = {}
|
folder_paths = {}
|
||||||
if not isinstance(extra_folder_paths, Mapping):
|
if not isinstance(extra_folder_paths, Mapping):
|
||||||
@@ -925,9 +1027,12 @@ class Config:
|
|||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Applied library settings with %d lora roots (%d extra), %d checkpoint roots (%d extra), and %d embedding roots (%d extra)",
|
"Applied library settings with %d lora roots (%d extra), %d checkpoint roots (%d extra), and %d embedding roots (%d extra)",
|
||||||
len(self.loras_roots or []), len(self.extra_loras_roots or []),
|
len(self.loras_roots or []),
|
||||||
len(self.base_models_roots or []), len(self.extra_checkpoints_roots or []),
|
len(self.extra_loras_roots or []),
|
||||||
len(self.embeddings_roots or []), len(self.extra_embeddings_roots or []),
|
len(self.base_models_roots or []),
|
||||||
|
len(self.extra_checkpoints_roots or []),
|
||||||
|
len(self.embeddings_roots or []),
|
||||||
|
len(self.extra_embeddings_roots or []),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_library_registry_snapshot(self) -> Dict[str, object]:
|
def get_library_registry_snapshot(self) -> Dict[str, object]:
|
||||||
@@ -947,5 +1052,6 @@ class Config:
|
|||||||
logger.debug("Failed to collect library registry snapshot: %s", exc)
|
logger.debug("Failed to collect library registry snapshot: %s", exc)
|
||||||
return {"active_library": "", "libraries": {}}
|
return {"active_library": "", "libraries": {}}
|
||||||
|
|
||||||
|
|
||||||
# Global config instance
|
# Global config instance
|
||||||
config = Config()
|
config = Config()
|
||||||
|
|||||||
@@ -5,16 +5,22 @@ import logging
|
|||||||
from .utils.logging_config import setup_logging
|
from .utils.logging_config import setup_logging
|
||||||
|
|
||||||
# Check if we're in standalone mode
|
# Check if we're in standalone mode
|
||||||
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
standalone_mode = (
|
||||||
|
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
|
||||||
|
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||||
|
)
|
||||||
|
|
||||||
# Only setup logging prefix if not in standalone mode
|
# Only setup logging prefix if not in standalone mode
|
||||||
if not standalone_mode:
|
if not standalone_mode:
|
||||||
setup_logging()
|
setup_logging()
|
||||||
|
|
||||||
from server import PromptServer # type: ignore
|
from server import PromptServer # type: ignore
|
||||||
|
|
||||||
from .config import config
|
from .config import config
|
||||||
from .services.model_service_factory import ModelServiceFactory, register_default_model_types
|
from .services.model_service_factory import (
|
||||||
|
ModelServiceFactory,
|
||||||
|
register_default_model_types,
|
||||||
|
)
|
||||||
from .routes.recipe_routes import RecipeRoutes
|
from .routes.recipe_routes import RecipeRoutes
|
||||||
from .routes.stats_routes import StatsRoutes
|
from .routes.stats_routes import StatsRoutes
|
||||||
from .routes.update_routes import UpdateRoutes
|
from .routes.update_routes import UpdateRoutes
|
||||||
@@ -61,6 +67,7 @@ class _SettingsProxy:
|
|||||||
|
|
||||||
settings = _SettingsProxy()
|
settings = _SettingsProxy()
|
||||||
|
|
||||||
|
|
||||||
class LoraManager:
|
class LoraManager:
|
||||||
"""Main entry point for LoRA Manager plugin"""
|
"""Main entry point for LoRA Manager plugin"""
|
||||||
|
|
||||||
@@ -76,7 +83,8 @@ class LoraManager:
|
|||||||
(
|
(
|
||||||
idx
|
idx
|
||||||
for idx, middleware in enumerate(app.middlewares)
|
for idx, middleware in enumerate(app.middlewares)
|
||||||
if getattr(middleware, "__name__", "") == "block_external_middleware"
|
if getattr(middleware, "__name__", "")
|
||||||
|
== "block_external_middleware"
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -84,7 +92,9 @@ class LoraManager:
|
|||||||
if block_middleware_index is None:
|
if block_middleware_index is None:
|
||||||
app.middlewares.append(relax_csp_for_remote_media)
|
app.middlewares.append(relax_csp_for_remote_media)
|
||||||
else:
|
else:
|
||||||
app.middlewares.insert(block_middleware_index, relax_csp_for_remote_media)
|
app.middlewares.insert(
|
||||||
|
block_middleware_index, relax_csp_for_remote_media
|
||||||
|
)
|
||||||
|
|
||||||
# Increase allowed header sizes so browsers with large localhost cookie
|
# Increase allowed header sizes so browsers with large localhost cookie
|
||||||
# jars (multiple UIs on 127.0.0.1) don't trip aiohttp's 8KB default
|
# jars (multiple UIs on 127.0.0.1) don't trip aiohttp's 8KB default
|
||||||
@@ -105,7 +115,7 @@ class LoraManager:
|
|||||||
app._handler_args = updated_handler_args
|
app._handler_args = updated_handler_args
|
||||||
|
|
||||||
# Configure aiohttp access logger to be less verbose
|
# Configure aiohttp access logger to be less verbose
|
||||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
logging.getLogger("aiohttp.access").setLevel(logging.WARNING)
|
||||||
|
|
||||||
# Add specific suppression for connection reset errors
|
# Add specific suppression for connection reset errors
|
||||||
class ConnectionResetFilter(logging.Filter):
|
class ConnectionResetFilter(logging.Filter):
|
||||||
@@ -124,19 +134,23 @@ class LoraManager:
|
|||||||
asyncio_logger.addFilter(ConnectionResetFilter())
|
asyncio_logger.addFilter(ConnectionResetFilter())
|
||||||
|
|
||||||
# Add static route for example images if the path exists in settings
|
# Add static route for example images if the path exists in settings
|
||||||
example_images_path = settings.get('example_images_path')
|
example_images_path = settings.get("example_images_path")
|
||||||
logger.info(f"Example images path: {example_images_path}")
|
logger.info(f"Example images path: {example_images_path}")
|
||||||
if example_images_path and os.path.exists(example_images_path):
|
if example_images_path and os.path.exists(example_images_path):
|
||||||
app.router.add_static('/example_images_static', example_images_path)
|
app.router.add_static("/example_images_static", example_images_path)
|
||||||
logger.info(f"Added static route for example images: /example_images_static -> {example_images_path}")
|
logger.info(
|
||||||
|
f"Added static route for example images: /example_images_static -> {example_images_path}"
|
||||||
|
)
|
||||||
|
|
||||||
# Add static route for locales JSON files
|
# Add static route for locales JSON files
|
||||||
if os.path.exists(config.i18n_path):
|
if os.path.exists(config.i18n_path):
|
||||||
app.router.add_static('/locales', config.i18n_path)
|
app.router.add_static("/locales", config.i18n_path)
|
||||||
logger.info(f"Added static route for locales: /locales -> {config.i18n_path}")
|
logger.info(
|
||||||
|
f"Added static route for locales: /locales -> {config.i18n_path}"
|
||||||
|
)
|
||||||
|
|
||||||
# Add static route for plugin assets
|
# Add static route for plugin assets
|
||||||
app.router.add_static('/loras_static', config.static_path)
|
app.router.add_static("/loras_static", config.static_path)
|
||||||
|
|
||||||
# Register default model types with the factory
|
# Register default model types with the factory
|
||||||
register_default_model_types()
|
register_default_model_types()
|
||||||
@@ -154,9 +168,11 @@ class LoraManager:
|
|||||||
PreviewRoutes.setup_routes(app)
|
PreviewRoutes.setup_routes(app)
|
||||||
|
|
||||||
# Setup WebSocket routes that are shared across all model types
|
# Setup WebSocket routes that are shared across all model types
|
||||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
app.router.add_get("/ws/fetch-progress", ws_manager.handle_connection)
|
||||||
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
|
app.router.add_get(
|
||||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
|
"/ws/download-progress", ws_manager.handle_download_connection
|
||||||
|
)
|
||||||
|
app.router.add_get("/ws/init-progress", ws_manager.handle_init_connection)
|
||||||
|
|
||||||
# Schedule service initialization
|
# Schedule service initialization
|
||||||
app.on_startup.append(lambda app: cls._initialize_services())
|
app.on_startup.append(lambda app: cls._initialize_services())
|
||||||
@@ -197,7 +213,9 @@ class LoraManager:
|
|||||||
extra_paths.get("embeddings", []),
|
extra_paths.get("embeddings", []),
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Failed to apply library settings during initialization: %s", exc)
|
logger.warning(
|
||||||
|
"Failed to apply library settings during initialization: %s", exc
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize CivitaiClient first to ensure it's ready for other services
|
# Initialize CivitaiClient first to ensure it's ready for other services
|
||||||
await ServiceRegistry.get_civitai_client()
|
await ServiceRegistry.get_civitai_client()
|
||||||
@@ -206,6 +224,7 @@ class LoraManager:
|
|||||||
await ServiceRegistry.get_download_manager()
|
await ServiceRegistry.get_download_manager()
|
||||||
|
|
||||||
from .services.metadata_service import initialize_metadata_providers
|
from .services.metadata_service import initialize_metadata_providers
|
||||||
|
|
||||||
await initialize_metadata_providers()
|
await initialize_metadata_providers()
|
||||||
|
|
||||||
# Initialize WebSocket manager
|
# Initialize WebSocket manager
|
||||||
@@ -221,39 +240,58 @@ class LoraManager:
|
|||||||
|
|
||||||
# Create low-priority initialization tasks
|
# Create low-priority initialization tasks
|
||||||
init_tasks = [
|
init_tasks = [
|
||||||
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init'),
|
asyncio.create_task(
|
||||||
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init'),
|
lora_scanner.initialize_in_background(), name="lora_cache_init"
|
||||||
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init'),
|
),
|
||||||
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
asyncio.create_task(
|
||||||
|
checkpoint_scanner.initialize_in_background(),
|
||||||
|
name="checkpoint_cache_init",
|
||||||
|
),
|
||||||
|
asyncio.create_task(
|
||||||
|
embedding_scanner.initialize_in_background(),
|
||||||
|
name="embedding_cache_init",
|
||||||
|
),
|
||||||
|
asyncio.create_task(
|
||||||
|
recipe_scanner.initialize_in_background(), name="recipe_cache_init"
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
await ExampleImagesMigration.check_and_run_migrations()
|
await ExampleImagesMigration.check_and_run_migrations()
|
||||||
|
|
||||||
# Schedule post-initialization tasks to run after scanners complete
|
# Schedule post-initialization tasks to run after scanners complete
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
cls._run_post_initialization_tasks(init_tasks),
|
cls._run_post_initialization_tasks(init_tasks), name="post_init_tasks"
|
||||||
name='post_init_tasks'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("LoRA Manager: All services initialized and background tasks scheduled")
|
logger.debug(
|
||||||
|
"LoRA Manager: All services initialized and background tasks scheduled"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True)
|
logger.error(
|
||||||
|
f"LoRA Manager: Error initializing services: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _run_post_initialization_tasks(cls, init_tasks):
|
async def _run_post_initialization_tasks(cls, init_tasks):
|
||||||
"""Run post-initialization tasks after all scanners complete"""
|
"""Run post-initialization tasks after all scanners complete"""
|
||||||
try:
|
try:
|
||||||
logger.debug("LoRA Manager: Waiting for scanner initialization to complete...")
|
logger.debug(
|
||||||
|
"LoRA Manager: Waiting for scanner initialization to complete..."
|
||||||
|
)
|
||||||
|
|
||||||
# Wait for all scanner initialization tasks to complete
|
# Wait for all scanner initialization tasks to complete
|
||||||
await asyncio.gather(*init_tasks, return_exceptions=True)
|
await asyncio.gather(*init_tasks, return_exceptions=True)
|
||||||
|
|
||||||
logger.debug("LoRA Manager: Scanner initialization completed, starting post-initialization tasks...")
|
logger.debug(
|
||||||
|
"LoRA Manager: Scanner initialization completed, starting post-initialization tasks..."
|
||||||
|
)
|
||||||
|
|
||||||
# Run post-initialization tasks
|
# Run post-initialization tasks
|
||||||
post_tasks = [
|
post_tasks = [
|
||||||
asyncio.create_task(cls._cleanup_backup_files(), name='cleanup_bak_files'),
|
asyncio.create_task(
|
||||||
|
cls._cleanup_backup_files(), name="cleanup_bak_files"
|
||||||
|
),
|
||||||
# Add more post-initialization tasks here as needed
|
# Add more post-initialization tasks here as needed
|
||||||
# asyncio.create_task(cls._another_post_task(), name='another_task'),
|
# asyncio.create_task(cls._another_post_task(), name='another_task'),
|
||||||
]
|
]
|
||||||
@@ -265,14 +303,20 @@ class LoraManager:
|
|||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
task_name = post_tasks[i].get_name()
|
task_name = post_tasks[i].get_name()
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
logger.error(f"Post-initialization task '{task_name}' failed: {result}")
|
logger.error(
|
||||||
|
f"Post-initialization task '{task_name}' failed: {result}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Post-initialization task '{task_name}' completed successfully")
|
logger.debug(
|
||||||
|
f"Post-initialization task '{task_name}' completed successfully"
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug("LoRA Manager: All post-initialization tasks completed")
|
logger.debug("LoRA Manager: All post-initialization tasks completed")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True)
|
logger.error(
|
||||||
|
f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _cleanup_backup_files(cls):
|
async def _cleanup_backup_files(cls):
|
||||||
@@ -283,8 +327,8 @@ class LoraManager:
|
|||||||
# Collect all model roots
|
# Collect all model roots
|
||||||
all_roots = set()
|
all_roots = set()
|
||||||
all_roots.update(config.loras_roots)
|
all_roots.update(config.loras_roots)
|
||||||
all_roots.update(config.base_models_roots)
|
all_roots.update(config.base_models_roots or [])
|
||||||
all_roots.update(config.embeddings_roots)
|
all_roots.update(config.embeddings_roots or [])
|
||||||
|
|
||||||
total_deleted = 0
|
total_deleted = 0
|
||||||
total_size_freed = 0
|
total_size_freed = 0
|
||||||
@@ -294,12 +338,17 @@ class LoraManager:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
deleted_count, size_freed = await cls._cleanup_backup_files_in_directory(root_path)
|
(
|
||||||
|
deleted_count,
|
||||||
|
size_freed,
|
||||||
|
) = await cls._cleanup_backup_files_in_directory(root_path)
|
||||||
total_deleted += deleted_count
|
total_deleted += deleted_count
|
||||||
total_size_freed += size_freed
|
total_size_freed += size_freed
|
||||||
|
|
||||||
if deleted_count > 0:
|
if deleted_count > 0:
|
||||||
logger.debug(f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024*1024):.2f} MB)")
|
logger.debug(
|
||||||
|
f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024 * 1024):.2f} MB)"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error cleaning up .bak files in {root_path}: {e}")
|
logger.error(f"Error cleaning up .bak files in {root_path}: {e}")
|
||||||
@@ -308,7 +357,9 @@ class LoraManager:
|
|||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
if total_deleted > 0:
|
if total_deleted > 0:
|
||||||
logger.debug(f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024*1024):.2f} MB total")
|
logger.debug(
|
||||||
|
f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024 * 1024):.2f} MB total"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("Backup cleanup completed: no .bak files found")
|
logger.debug("Backup cleanup completed: no .bak files found")
|
||||||
|
|
||||||
@@ -341,7 +392,9 @@ class LoraManager:
|
|||||||
with os.scandir(path) as it:
|
with os.scandir(path) as it:
|
||||||
for entry in it:
|
for entry in it:
|
||||||
try:
|
try:
|
||||||
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.bak'):
|
if entry.is_file(
|
||||||
|
follow_symlinks=True
|
||||||
|
) and entry.name.endswith(".bak"):
|
||||||
file_size = entry.stat().st_size
|
file_size = entry.stat().st_size
|
||||||
os.remove(entry.path)
|
os.remove(entry.path)
|
||||||
deleted_count += 1
|
deleted_count += 1
|
||||||
@@ -352,7 +405,9 @@ class LoraManager:
|
|||||||
cleanup_recursive(entry.path)
|
cleanup_recursive(entry.path)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not delete .bak file {entry.path}: {e}")
|
logger.warning(
|
||||||
|
f"Could not delete .bak file {entry.path}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error scanning directory {path} for .bak files: {e}")
|
logger.error(f"Error scanning directory {path} for .bak files: {e}")
|
||||||
@@ -370,21 +425,21 @@ class LoraManager:
|
|||||||
service = ExampleImagesCleanupService()
|
service = ExampleImagesCleanupService()
|
||||||
result = await service.cleanup_example_image_folders()
|
result = await service.cleanup_example_image_folders()
|
||||||
|
|
||||||
if result.get('success'):
|
if result.get("success"):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Manual example images cleanup completed: moved=%s",
|
"Manual example images cleanup completed: moved=%s",
|
||||||
result.get('moved_total'),
|
result.get("moved_total"),
|
||||||
)
|
)
|
||||||
elif result.get('partial_success'):
|
elif result.get("partial_success"):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Manual example images cleanup partially succeeded: moved=%s failures=%s",
|
"Manual example images cleanup partially succeeded: moved=%s failures=%s",
|
||||||
result.get('moved_total'),
|
result.get("moved_total"),
|
||||||
result.get('move_failures'),
|
result.get("move_failures"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Manual example images cleanup skipped or failed: %s",
|
"Manual example images cleanup skipped or failed: %s",
|
||||||
result.get('error', 'no changes'),
|
result.get("error", "no changes"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -392,9 +447,9 @@ class LoraManager:
|
|||||||
except Exception as e: # pragma: no cover - defensive guard
|
except Exception as e: # pragma: no cover - defensive guard
|
||||||
logger.error(f"Error during example images cleanup: {e}", exc_info=True)
|
logger.error(f"Error during example images cleanup: {e}", exc_info=True)
|
||||||
return {
|
return {
|
||||||
'success': False,
|
"success": False,
|
||||||
'error': str(e),
|
"error": str(e),
|
||||||
'error_code': 'unexpected_error',
|
"error_code": "unexpected_error",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -4,7 +4,10 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Check if running in standalone mode
|
# Check if running in standalone mode
|
||||||
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
standalone_mode = (
|
||||||
|
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
|
||||||
|
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||||
|
)
|
||||||
|
|
||||||
if not standalone_mode:
|
if not standalone_mode:
|
||||||
from .metadata_hook import MetadataHook
|
from .metadata_hook import MetadataHook
|
||||||
@@ -19,7 +22,7 @@ if not standalone_mode:
|
|||||||
|
|
||||||
logger.info("ComfyUI Metadata Collector initialized")
|
logger.info("ComfyUI Metadata Collector initialized")
|
||||||
|
|
||||||
def get_metadata(prompt_id=None):
|
def get_metadata(prompt_id=None): # type: ignore[no-redef]
|
||||||
"""Helper function to get metadata from the registry"""
|
"""Helper function to get metadata from the registry"""
|
||||||
registry = MetadataRegistry()
|
registry = MetadataRegistry()
|
||||||
return registry.get_metadata(prompt_id)
|
return registry.get_metadata(prompt_id)
|
||||||
@@ -28,6 +31,6 @@ else:
|
|||||||
def init():
|
def init():
|
||||||
logger.info("ComfyUI Metadata Collector disabled in standalone mode")
|
logger.info("ComfyUI Metadata Collector disabled in standalone mode")
|
||||||
|
|
||||||
def get_metadata(prompt_id=None):
|
def get_metadata(prompt_id=None): # type: ignore[no-redef]
|
||||||
"""Dummy implementation for standalone mode"""
|
"""Dummy implementation for standalone mode"""
|
||||||
return {}
|
return {}
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import time
|
import time
|
||||||
from nodes import NODE_CLASS_MAPPINGS
|
from nodes import NODE_CLASS_MAPPINGS # type: ignore
|
||||||
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
|
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
|
||||||
from .constants import METADATA_CATEGORIES, IMAGES
|
from .constants import METADATA_CATEGORIES, IMAGES
|
||||||
|
|
||||||
|
|
||||||
class MetadataRegistry:
|
class MetadataRegistry:
|
||||||
"""A singleton registry to store and retrieve workflow metadata"""
|
"""A singleton registry to store and retrieve workflow metadata"""
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|
||||||
def __new__(cls):
|
def __new__(cls):
|
||||||
@@ -37,11 +39,13 @@ class MetadataRegistry:
|
|||||||
# Sort all prompt_ids by timestamp
|
# Sort all prompt_ids by timestamp
|
||||||
sorted_prompts = sorted(
|
sorted_prompts = sorted(
|
||||||
self.prompt_metadata.keys(),
|
self.prompt_metadata.keys(),
|
||||||
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0)
|
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove oldest records
|
# Remove oldest records
|
||||||
prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history]
|
prompts_to_remove = sorted_prompts[
|
||||||
|
: len(sorted_prompts) - self.max_prompt_history
|
||||||
|
]
|
||||||
for pid in prompts_to_remove:
|
for pid in prompts_to_remove:
|
||||||
del self.prompt_metadata[pid]
|
del self.prompt_metadata[pid]
|
||||||
|
|
||||||
@@ -53,11 +57,13 @@ class MetadataRegistry:
|
|||||||
category: {} for category in METADATA_CATEGORIES
|
category: {} for category in METADATA_CATEGORIES
|
||||||
}
|
}
|
||||||
# Add additional metadata fields
|
# Add additional metadata fields
|
||||||
self.prompt_metadata[prompt_id].update({
|
self.prompt_metadata[prompt_id].update(
|
||||||
"execution_order": [],
|
{
|
||||||
"current_prompt": None, # Will store the prompt object
|
"execution_order": [],
|
||||||
"timestamp": time.time()
|
"current_prompt": None, # Will store the prompt object
|
||||||
})
|
"timestamp": time.time(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Clean up old prompt data
|
# Clean up old prompt data
|
||||||
self._clean_old_prompts()
|
self._clean_old_prompts()
|
||||||
@@ -125,7 +131,9 @@ class MetadataRegistry:
|
|||||||
for category in self.metadata_categories:
|
for category in self.metadata_categories:
|
||||||
if category in cached_data and node_id in cached_data[category]:
|
if category in cached_data and node_id in cached_data[category]:
|
||||||
if node_id not in metadata[category]:
|
if node_id not in metadata[category]:
|
||||||
metadata[category][node_id] = cached_data[category][node_id]
|
metadata[category][node_id] = cached_data[category][
|
||||||
|
node_id
|
||||||
|
]
|
||||||
|
|
||||||
def record_node_execution(self, node_id, class_type, inputs, outputs):
|
def record_node_execution(self, node_id, class_type, inputs, outputs):
|
||||||
"""Record information about a node's execution"""
|
"""Record information about a node's execution"""
|
||||||
@@ -135,7 +143,9 @@ class MetadataRegistry:
|
|||||||
# Add to execution order and mark as executed
|
# Add to execution order and mark as executed
|
||||||
if node_id not in self.executed_nodes:
|
if node_id not in self.executed_nodes:
|
||||||
self.executed_nodes.add(node_id)
|
self.executed_nodes.add(node_id)
|
||||||
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id)
|
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(
|
||||||
|
node_id
|
||||||
|
)
|
||||||
|
|
||||||
# Process inputs to simplify working with them
|
# Process inputs to simplify working with them
|
||||||
processed_inputs = {}
|
processed_inputs = {}
|
||||||
@@ -152,7 +162,7 @@ class MetadataRegistry:
|
|||||||
node_id,
|
node_id,
|
||||||
processed_inputs,
|
processed_inputs,
|
||||||
outputs,
|
outputs,
|
||||||
self.prompt_metadata[self.current_prompt_id]
|
self.prompt_metadata[self.current_prompt_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Cache this node's metadata
|
# Cache this node's metadata
|
||||||
@@ -168,11 +178,9 @@ class MetadataRegistry:
|
|||||||
|
|
||||||
# Use the same extractor to update with outputs
|
# Use the same extractor to update with outputs
|
||||||
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
|
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
|
||||||
if hasattr(extractor, 'update'):
|
if hasattr(extractor, "update"):
|
||||||
extractor.update(
|
extractor.update(
|
||||||
node_id,
|
node_id, processed_outputs, self.prompt_metadata[self.current_prompt_id]
|
||||||
processed_outputs,
|
|
||||||
self.prompt_metadata[self.current_prompt_id]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the cached metadata for this node
|
# Update the cached metadata for this node
|
||||||
@@ -214,7 +222,7 @@ class MetadataRegistry:
|
|||||||
# Find cache keys that are no longer needed
|
# Find cache keys that are no longer needed
|
||||||
keys_to_remove = []
|
keys_to_remove = []
|
||||||
for cache_key in self.node_cache:
|
for cache_key in self.node_cache:
|
||||||
node_id = cache_key.split(':')[0]
|
node_id = cache_key.split(":")[0]
|
||||||
if node_id not in active_node_ids:
|
if node_id not in active_node_ids:
|
||||||
keys_to_remove.append(cache_key)
|
keys_to_remove.append(cache_key)
|
||||||
|
|
||||||
@@ -270,7 +278,10 @@ class MetadataRegistry:
|
|||||||
if IMAGES in cached_data and node_id in cached_data[IMAGES]:
|
if IMAGES in cached_data and node_id in cached_data[IMAGES]:
|
||||||
image_data = cached_data[IMAGES][node_id]["image"]
|
image_data = cached_data[IMAGES][node_id]["image"]
|
||||||
# Handle different image formats
|
# Handle different image formats
|
||||||
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
|
if (
|
||||||
|
isinstance(image_data, (list, tuple))
|
||||||
|
and len(image_data) > 0
|
||||||
|
):
|
||||||
return image_data[0]
|
return image_data[0]
|
||||||
return image_data
|
return image_data
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import folder_paths # type: ignore
|
import folder_paths # type: ignore
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||||
from ..metadata_collector import get_metadata
|
from ..metadata_collector import get_metadata
|
||||||
@@ -12,6 +13,7 @@ import logging
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SaveImageLM:
|
class SaveImageLM:
|
||||||
NAME = "Save Image (LoraManager)"
|
NAME = "Save Image (LoraManager)"
|
||||||
CATEGORY = "Lora Manager/utils"
|
CATEGORY = "Lora Manager/utils"
|
||||||
@@ -32,33 +34,51 @@ class SaveImageLM:
|
|||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"images": ("IMAGE",),
|
"images": ("IMAGE",),
|
||||||
"filename_prefix": ("STRING", {
|
"filename_prefix": (
|
||||||
"default": "ComfyUI",
|
"STRING",
|
||||||
"tooltip": "Base filename for saved images. Supports format patterns like %seed%, %width%, %height%, %model%, etc."
|
{
|
||||||
}),
|
"default": "ComfyUI",
|
||||||
"file_format": (["png", "jpeg", "webp"], {
|
"tooltip": "Base filename for saved images. Supports format patterns like %seed%, %width%, %height%, %model%, etc.",
|
||||||
"tooltip": "Image format to save as. PNG preserves quality, JPEG is smaller, WebP balances size and quality."
|
},
|
||||||
}),
|
),
|
||||||
|
"file_format": (
|
||||||
|
["png", "jpeg", "webp"],
|
||||||
|
{
|
||||||
|
"tooltip": "Image format to save as. PNG preserves quality, JPEG is smaller, WebP balances size and quality."
|
||||||
|
},
|
||||||
|
),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"lossless_webp": ("BOOLEAN", {
|
"lossless_webp": (
|
||||||
"default": False,
|
"BOOLEAN",
|
||||||
"tooltip": "When enabled, saves WebP images with lossless compression. Results in larger files but no quality loss."
|
{
|
||||||
}),
|
"default": False,
|
||||||
"quality": ("INT", {
|
"tooltip": "When enabled, saves WebP images with lossless compression. Results in larger files but no quality loss.",
|
||||||
"default": 100,
|
},
|
||||||
"min": 1,
|
),
|
||||||
"max": 100,
|
"quality": (
|
||||||
"tooltip": "Compression quality for JPEG and lossy WebP formats (1-100). Higher values mean better quality but larger files."
|
"INT",
|
||||||
}),
|
{
|
||||||
"embed_workflow": ("BOOLEAN", {
|
"default": 100,
|
||||||
"default": False,
|
"min": 1,
|
||||||
"tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats."
|
"max": 100,
|
||||||
}),
|
"tooltip": "Compression quality for JPEG and lossy WebP formats (1-100). Higher values mean better quality but larger files.",
|
||||||
"add_counter_to_filename": ("BOOLEAN", {
|
},
|
||||||
"default": True,
|
),
|
||||||
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images."
|
"embed_workflow": (
|
||||||
}),
|
"BOOLEAN",
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"add_counter_to_filename": (
|
||||||
|
"BOOLEAN",
|
||||||
|
{
|
||||||
|
"default": True,
|
||||||
|
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images.",
|
||||||
|
},
|
||||||
|
),
|
||||||
},
|
},
|
||||||
"hidden": {
|
"hidden": {
|
||||||
"id": "UNIQUE_ID",
|
"id": "UNIQUE_ID",
|
||||||
@@ -77,9 +97,10 @@ class SaveImageLM:
|
|||||||
scanner = ServiceRegistry.get_service_sync("lora_scanner")
|
scanner = ServiceRegistry.get_service_sync("lora_scanner")
|
||||||
|
|
||||||
# Use the new direct filename lookup method
|
# Use the new direct filename lookup method
|
||||||
hash_value = scanner.get_hash_by_filename(lora_name)
|
if scanner is not None:
|
||||||
if hash_value:
|
hash_value = scanner.get_hash_by_filename(lora_name)
|
||||||
return hash_value
|
if hash_value:
|
||||||
|
return hash_value
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -95,9 +116,10 @@ class SaveImageLM:
|
|||||||
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
||||||
|
|
||||||
# Try direct filename lookup first
|
# Try direct filename lookup first
|
||||||
hash_value = scanner.get_hash_by_filename(checkpoint_name)
|
if scanner is not None:
|
||||||
if hash_value:
|
hash_value = scanner.get_hash_by_filename(checkpoint_name)
|
||||||
return hash_value
|
if hash_value:
|
||||||
|
return hash_value
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -112,11 +134,11 @@ class SaveImageLM:
|
|||||||
param_list.append(f"{label}: {value}")
|
param_list.append(f"{label}: {value}")
|
||||||
|
|
||||||
# Extract the prompt and negative prompt
|
# Extract the prompt and negative prompt
|
||||||
prompt = metadata_dict.get('prompt', '')
|
prompt = metadata_dict.get("prompt", "")
|
||||||
negative_prompt = metadata_dict.get('negative_prompt', '')
|
negative_prompt = metadata_dict.get("negative_prompt", "")
|
||||||
|
|
||||||
# Extract loras from the prompt if present
|
# Extract loras from the prompt if present
|
||||||
loras_text = metadata_dict.get('loras', '')
|
loras_text = metadata_dict.get("loras", "")
|
||||||
lora_hashes = {}
|
lora_hashes = {}
|
||||||
|
|
||||||
# If loras are found, add them on a new line after the prompt
|
# If loras are found, add them on a new line after the prompt
|
||||||
@@ -124,7 +146,7 @@ class SaveImageLM:
|
|||||||
prompt_with_loras = f"{prompt}\n{loras_text}"
|
prompt_with_loras = f"{prompt}\n{loras_text}"
|
||||||
|
|
||||||
# Extract lora names from the format <lora:name:strength>
|
# Extract lora names from the format <lora:name:strength>
|
||||||
lora_matches = re.findall(r'<lora:([^:]+):([^>]+)>', loras_text)
|
lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", loras_text)
|
||||||
|
|
||||||
# Get hash for each lora
|
# Get hash for each lora
|
||||||
for lora_name, strength in lora_matches:
|
for lora_name, strength in lora_matches:
|
||||||
@@ -145,43 +167,43 @@ class SaveImageLM:
|
|||||||
params = []
|
params = []
|
||||||
|
|
||||||
# Add standard parameters in the correct order
|
# Add standard parameters in the correct order
|
||||||
if 'steps' in metadata_dict:
|
if "steps" in metadata_dict:
|
||||||
add_param_if_not_none(params, "Steps", metadata_dict.get('steps'))
|
add_param_if_not_none(params, "Steps", metadata_dict.get("steps"))
|
||||||
|
|
||||||
# Combine sampler and scheduler information
|
# Combine sampler and scheduler information
|
||||||
sampler_name = None
|
sampler_name = None
|
||||||
scheduler_name = None
|
scheduler_name = None
|
||||||
|
|
||||||
if 'sampler' in metadata_dict:
|
if "sampler" in metadata_dict:
|
||||||
sampler = metadata_dict.get('sampler')
|
sampler = metadata_dict.get("sampler")
|
||||||
# Convert ComfyUI sampler names to user-friendly names
|
# Convert ComfyUI sampler names to user-friendly names
|
||||||
sampler_mapping = {
|
sampler_mapping = {
|
||||||
'euler': 'Euler',
|
"euler": "Euler",
|
||||||
'euler_ancestral': 'Euler a',
|
"euler_ancestral": "Euler a",
|
||||||
'dpm_2': 'DPM2',
|
"dpm_2": "DPM2",
|
||||||
'dpm_2_ancestral': 'DPM2 a',
|
"dpm_2_ancestral": "DPM2 a",
|
||||||
'heun': 'Heun',
|
"heun": "Heun",
|
||||||
'dpm_fast': 'DPM fast',
|
"dpm_fast": "DPM fast",
|
||||||
'dpm_adaptive': 'DPM adaptive',
|
"dpm_adaptive": "DPM adaptive",
|
||||||
'lms': 'LMS',
|
"lms": "LMS",
|
||||||
'dpmpp_2s_ancestral': 'DPM++ 2S a',
|
"dpmpp_2s_ancestral": "DPM++ 2S a",
|
||||||
'dpmpp_sde': 'DPM++ SDE',
|
"dpmpp_sde": "DPM++ SDE",
|
||||||
'dpmpp_sde_gpu': 'DPM++ SDE',
|
"dpmpp_sde_gpu": "DPM++ SDE",
|
||||||
'dpmpp_2m': 'DPM++ 2M',
|
"dpmpp_2m": "DPM++ 2M",
|
||||||
'dpmpp_2m_sde': 'DPM++ 2M SDE',
|
"dpmpp_2m_sde": "DPM++ 2M SDE",
|
||||||
'dpmpp_2m_sde_gpu': 'DPM++ 2M SDE',
|
"dpmpp_2m_sde_gpu": "DPM++ 2M SDE",
|
||||||
'ddim': 'DDIM'
|
"ddim": "DDIM",
|
||||||
}
|
}
|
||||||
sampler_name = sampler_mapping.get(sampler, sampler)
|
sampler_name = sampler_mapping.get(sampler, sampler)
|
||||||
|
|
||||||
if 'scheduler' in metadata_dict:
|
if "scheduler" in metadata_dict:
|
||||||
scheduler = metadata_dict.get('scheduler')
|
scheduler = metadata_dict.get("scheduler")
|
||||||
scheduler_mapping = {
|
scheduler_mapping = {
|
||||||
'normal': 'Simple',
|
"normal": "Simple",
|
||||||
'karras': 'Karras',
|
"karras": "Karras",
|
||||||
'exponential': 'Exponential',
|
"exponential": "Exponential",
|
||||||
'sgm_uniform': 'SGM Uniform',
|
"sgm_uniform": "SGM Uniform",
|
||||||
'sgm_quadratic': 'SGM Quadratic'
|
"sgm_quadratic": "SGM Quadratic",
|
||||||
}
|
}
|
||||||
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
|
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
|
||||||
|
|
||||||
@@ -193,25 +215,25 @@ class SaveImageLM:
|
|||||||
params.append(f"Sampler: {sampler_name}")
|
params.append(f"Sampler: {sampler_name}")
|
||||||
|
|
||||||
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
|
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
|
||||||
if 'guidance' in metadata_dict:
|
if "guidance" in metadata_dict:
|
||||||
add_param_if_not_none(params, "CFG scale", metadata_dict.get('guidance'))
|
add_param_if_not_none(params, "CFG scale", metadata_dict.get("guidance"))
|
||||||
elif 'cfg_scale' in metadata_dict:
|
elif "cfg_scale" in metadata_dict:
|
||||||
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg_scale'))
|
add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg_scale"))
|
||||||
elif 'cfg' in metadata_dict:
|
elif "cfg" in metadata_dict:
|
||||||
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg'))
|
add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg"))
|
||||||
|
|
||||||
# Seed
|
# Seed
|
||||||
if 'seed' in metadata_dict:
|
if "seed" in metadata_dict:
|
||||||
add_param_if_not_none(params, "Seed", metadata_dict.get('seed'))
|
add_param_if_not_none(params, "Seed", metadata_dict.get("seed"))
|
||||||
|
|
||||||
# Size
|
# Size
|
||||||
if 'size' in metadata_dict:
|
if "size" in metadata_dict:
|
||||||
add_param_if_not_none(params, "Size", metadata_dict.get('size'))
|
add_param_if_not_none(params, "Size", metadata_dict.get("size"))
|
||||||
|
|
||||||
# Model info
|
# Model info
|
||||||
if 'checkpoint' in metadata_dict:
|
if "checkpoint" in metadata_dict:
|
||||||
# Ensure checkpoint is a string before processing
|
# Ensure checkpoint is a string before processing
|
||||||
checkpoint = metadata_dict.get('checkpoint')
|
checkpoint = metadata_dict.get("checkpoint")
|
||||||
if checkpoint is not None:
|
if checkpoint is not None:
|
||||||
# Get model hash
|
# Get model hash
|
||||||
model_hash = self.get_checkpoint_hash(checkpoint)
|
model_hash = self.get_checkpoint_hash(checkpoint)
|
||||||
@@ -223,7 +245,9 @@ class SaveImageLM:
|
|||||||
|
|
||||||
# Add model hash if available
|
# Add model hash if available
|
||||||
if model_hash:
|
if model_hash:
|
||||||
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
|
params.append(
|
||||||
|
f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
params.append(f"Model: {checkpoint_name}")
|
params.append(f"Model: {checkpoint_name}")
|
||||||
|
|
||||||
@@ -234,7 +258,7 @@ class SaveImageLM:
|
|||||||
lora_hash_parts.append(f"{lora_name}: {hash_value[:10]}")
|
lora_hash_parts.append(f"{lora_name}: {hash_value[:10]}")
|
||||||
|
|
||||||
if lora_hash_parts:
|
if lora_hash_parts:
|
||||||
params.append(f"Lora hashes: \"{', '.join(lora_hash_parts)}\"")
|
params.append(f'Lora hashes: "{", ".join(lora_hash_parts)}"')
|
||||||
|
|
||||||
# Combine all parameters with commas
|
# Combine all parameters with commas
|
||||||
metadata_parts.append(", ".join(params))
|
metadata_parts.append(", ".join(params))
|
||||||
@@ -254,30 +278,30 @@ class SaveImageLM:
|
|||||||
parts = segment.replace("%", "").split(":")
|
parts = segment.replace("%", "").split(":")
|
||||||
key = parts[0]
|
key = parts[0]
|
||||||
|
|
||||||
if key == "seed" and 'seed' in metadata_dict:
|
if key == "seed" and "seed" in metadata_dict:
|
||||||
filename = filename.replace(segment, str(metadata_dict.get('seed', '')))
|
filename = filename.replace(segment, str(metadata_dict.get("seed", "")))
|
||||||
elif key == "width" and 'size' in metadata_dict:
|
elif key == "width" and "size" in metadata_dict:
|
||||||
size = metadata_dict.get('size', 'x')
|
size = metadata_dict.get("size", "x")
|
||||||
w = size.split('x')[0] if isinstance(size, str) else size[0]
|
w = size.split("x")[0] if isinstance(size, str) else size[0]
|
||||||
filename = filename.replace(segment, str(w))
|
filename = filename.replace(segment, str(w))
|
||||||
elif key == "height" and 'size' in metadata_dict:
|
elif key == "height" and "size" in metadata_dict:
|
||||||
size = metadata_dict.get('size', 'x')
|
size = metadata_dict.get("size", "x")
|
||||||
h = size.split('x')[1] if isinstance(size, str) else size[1]
|
h = size.split("x")[1] if isinstance(size, str) else size[1]
|
||||||
filename = filename.replace(segment, str(h))
|
filename = filename.replace(segment, str(h))
|
||||||
elif key == "pprompt" and 'prompt' in metadata_dict:
|
elif key == "pprompt" and "prompt" in metadata_dict:
|
||||||
prompt = metadata_dict.get('prompt', '').replace("\n", " ")
|
prompt = metadata_dict.get("prompt", "").replace("\n", " ")
|
||||||
if len(parts) >= 2:
|
if len(parts) >= 2:
|
||||||
length = int(parts[1])
|
length = int(parts[1])
|
||||||
prompt = prompt[:length]
|
prompt = prompt[:length]
|
||||||
filename = filename.replace(segment, prompt.strip())
|
filename = filename.replace(segment, prompt.strip())
|
||||||
elif key == "nprompt" and 'negative_prompt' in metadata_dict:
|
elif key == "nprompt" and "negative_prompt" in metadata_dict:
|
||||||
prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ")
|
prompt = metadata_dict.get("negative_prompt", "").replace("\n", " ")
|
||||||
if len(parts) >= 2:
|
if len(parts) >= 2:
|
||||||
length = int(parts[1])
|
length = int(parts[1])
|
||||||
prompt = prompt[:length]
|
prompt = prompt[:length]
|
||||||
filename = filename.replace(segment, prompt.strip())
|
filename = filename.replace(segment, prompt.strip())
|
||||||
elif key == "model":
|
elif key == "model":
|
||||||
model_value = metadata_dict.get('checkpoint')
|
model_value = metadata_dict.get("checkpoint")
|
||||||
if isinstance(model_value, (bytes, os.PathLike)):
|
if isinstance(model_value, (bytes, os.PathLike)):
|
||||||
model_value = str(model_value)
|
model_value = str(model_value)
|
||||||
|
|
||||||
@@ -291,6 +315,7 @@ class SaveImageLM:
|
|||||||
filename = filename.replace(segment, model)
|
filename = filename.replace(segment, model)
|
||||||
elif key == "date":
|
elif key == "date":
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
date_table = {
|
date_table = {
|
||||||
"yyyy": f"{now.year:04d}",
|
"yyyy": f"{now.year:04d}",
|
||||||
@@ -314,8 +339,19 @@ class SaveImageLM:
|
|||||||
|
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
def save_images(self, images, filename_prefix, file_format, id, prompt=None, extra_pnginfo=None,
|
def save_images(
|
||||||
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
|
self,
|
||||||
|
images,
|
||||||
|
filename_prefix,
|
||||||
|
file_format,
|
||||||
|
id,
|
||||||
|
prompt=None,
|
||||||
|
extra_pnginfo=None,
|
||||||
|
lossless_webp=True,
|
||||||
|
quality=100,
|
||||||
|
embed_workflow=False,
|
||||||
|
add_counter_to_filename=True,
|
||||||
|
):
|
||||||
"""Save images with metadata"""
|
"""Save images with metadata"""
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
@@ -329,8 +365,10 @@ class SaveImageLM:
|
|||||||
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
|
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
|
||||||
|
|
||||||
# Get initial save path info once for the batch
|
# Get initial save path info once for the batch
|
||||||
full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path(
|
full_output_folder, filename, counter, subfolder, processed_prefix = (
|
||||||
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
|
folder_paths.get_save_image_path(
|
||||||
|
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create directory if it doesn't exist
|
# Create directory if it doesn't exist
|
||||||
@@ -340,7 +378,7 @@ class SaveImageLM:
|
|||||||
# Process each image with incrementing counter
|
# Process each image with incrementing counter
|
||||||
for i, image in enumerate(images):
|
for i, image in enumerate(images):
|
||||||
# Convert the tensor image to numpy array
|
# Convert the tensor image to numpy array
|
||||||
img = 255. * image.cpu().numpy()
|
img = 255.0 * image.cpu().numpy()
|
||||||
img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))
|
img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))
|
||||||
|
|
||||||
# Generate filename with counter if needed
|
# Generate filename with counter if needed
|
||||||
@@ -351,6 +389,9 @@ class SaveImageLM:
|
|||||||
base_filename += f"_{current_counter:05}_"
|
base_filename += f"_{current_counter:05}_"
|
||||||
|
|
||||||
# Set file extension and prepare saving parameters
|
# Set file extension and prepare saving parameters
|
||||||
|
file: str
|
||||||
|
save_kwargs: Dict[str, Any]
|
||||||
|
pnginfo: Optional[PngImagePlugin.PngInfo] = None
|
||||||
if file_format == "png":
|
if file_format == "png":
|
||||||
file = base_filename + ".png"
|
file = base_filename + ".png"
|
||||||
file_extension = ".png"
|
file_extension = ".png"
|
||||||
@@ -365,7 +406,13 @@ class SaveImageLM:
|
|||||||
file = base_filename + ".webp"
|
file = base_filename + ".webp"
|
||||||
file_extension = ".webp"
|
file_extension = ".webp"
|
||||||
# Add optimization param to control performance
|
# Add optimization param to control performance
|
||||||
save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0}
|
save_kwargs = {
|
||||||
|
"quality": quality,
|
||||||
|
"lossless": lossless_webp,
|
||||||
|
"method": 0,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported file format: {file_format}")
|
||||||
|
|
||||||
# Full save path
|
# Full save path
|
||||||
file_path = os.path.join(full_output_folder, file)
|
file_path = os.path.join(full_output_folder, file)
|
||||||
@@ -373,6 +420,7 @@ class SaveImageLM:
|
|||||||
# Save the image with metadata
|
# Save the image with metadata
|
||||||
try:
|
try:
|
||||||
if file_format == "png":
|
if file_format == "png":
|
||||||
|
assert pnginfo is not None
|
||||||
if metadata:
|
if metadata:
|
||||||
pnginfo.add_text("parameters", metadata)
|
pnginfo.add_text("parameters", metadata)
|
||||||
if embed_workflow and extra_pnginfo is not None:
|
if embed_workflow and extra_pnginfo is not None:
|
||||||
@@ -384,7 +432,12 @@ class SaveImageLM:
|
|||||||
# For JPEG, use piexif
|
# For JPEG, use piexif
|
||||||
if metadata:
|
if metadata:
|
||||||
try:
|
try:
|
||||||
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}}
|
exif_dict = {
|
||||||
|
"Exif": {
|
||||||
|
piexif.ExifIFD.UserComment: b"UNICODE\0"
|
||||||
|
+ metadata.encode("utf-16be")
|
||||||
|
}
|
||||||
|
}
|
||||||
exif_bytes = piexif.dump(exif_dict)
|
exif_bytes = piexif.dump(exif_dict)
|
||||||
save_kwargs["exif"] = exif_bytes
|
save_kwargs["exif"] = exif_bytes
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -396,12 +449,18 @@ class SaveImageLM:
|
|||||||
exif_dict = {}
|
exif_dict = {}
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
exif_dict['Exif'] = {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}
|
exif_dict["Exif"] = {
|
||||||
|
piexif.ExifIFD.UserComment: b"UNICODE\0"
|
||||||
|
+ metadata.encode("utf-16be")
|
||||||
|
}
|
||||||
|
|
||||||
# Add workflow if needed
|
# Add workflow if needed
|
||||||
if embed_workflow and extra_pnginfo is not None:
|
if embed_workflow and extra_pnginfo is not None:
|
||||||
workflow_json = json.dumps(extra_pnginfo["workflow"])
|
workflow_json = json.dumps(extra_pnginfo["workflow"])
|
||||||
exif_dict['0th'] = {piexif.ImageIFD.ImageDescription: "Workflow:" + workflow_json}
|
exif_dict["0th"] = {
|
||||||
|
piexif.ImageIFD.ImageDescription: "Workflow:"
|
||||||
|
+ workflow_json
|
||||||
|
}
|
||||||
|
|
||||||
exif_bytes = piexif.dump(exif_dict)
|
exif_bytes = piexif.dump(exif_dict)
|
||||||
save_kwargs["exif"] = exif_bytes
|
save_kwargs["exif"] = exif_bytes
|
||||||
@@ -410,19 +469,28 @@ class SaveImageLM:
|
|||||||
|
|
||||||
img.save(file_path, format="WEBP", **save_kwargs)
|
img.save(file_path, format="WEBP", **save_kwargs)
|
||||||
|
|
||||||
results.append({
|
results.append(
|
||||||
"filename": file,
|
{"filename": file, "subfolder": subfolder, "type": self.type}
|
||||||
"subfolder": subfolder,
|
)
|
||||||
"type": self.type
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error saving image: {e}")
|
logger.error(f"Error saving image: {e}")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
|
def process_image(
|
||||||
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
|
self,
|
||||||
|
images,
|
||||||
|
id,
|
||||||
|
filename_prefix="ComfyUI",
|
||||||
|
file_format="png",
|
||||||
|
prompt=None,
|
||||||
|
extra_pnginfo=None,
|
||||||
|
lossless_webp=True,
|
||||||
|
quality=100,
|
||||||
|
embed_workflow=False,
|
||||||
|
add_counter_to_filename=True,
|
||||||
|
):
|
||||||
"""Process and save image with metadata"""
|
"""Process and save image with metadata"""
|
||||||
# Make sure the output directory exists
|
# Make sure the output directory exists
|
||||||
os.makedirs(self.output_dir, exist_ok=True)
|
os.makedirs(self.output_dir, exist_ok=True)
|
||||||
@@ -448,7 +516,7 @@ class SaveImageLM:
|
|||||||
lossless_webp,
|
lossless_webp,
|
||||||
quality,
|
quality,
|
||||||
embed_workflow,
|
embed_workflow,
|
||||||
add_counter_to_filename
|
add_counter_to_filename,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (images,)
|
return (images,)
|
||||||
|
|||||||
@@ -1,33 +1,35 @@
|
|||||||
class AnyType(str):
|
class AnyType(str):
|
||||||
"""A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
|
"""A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
|
||||||
|
|
||||||
|
def __ne__(self, __value: object) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
def __ne__(self, __value: object) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Credit to Regis Gaughan, III (rgthree)
|
# Credit to Regis Gaughan, III (rgthree)
|
||||||
class FlexibleOptionalInputType(dict):
|
class FlexibleOptionalInputType(dict):
|
||||||
"""A special class to make flexible nodes that pass data to our python handlers.
|
"""A special class to make flexible nodes that pass data to our python handlers.
|
||||||
|
|
||||||
Enables both flexible/dynamic input types (like for Any Switch) or a dynamic number of inputs
|
Enables both flexible/dynamic input types (like for Any Switch) or a dynamic number of inputs
|
||||||
(like for Any Switch, Context Switch, Context Merge, Power Lora Loader, etc).
|
(like for Any Switch, Context Switch, Context Merge, Power Lora Loader, etc).
|
||||||
|
|
||||||
Note, for ComfyUI, all that's needed is the `__contains__` override below, which tells ComfyUI
|
Note, for ComfyUI, all that's needed is the `__contains__` override below, which tells ComfyUI
|
||||||
that our node will handle the input, regardless of what it is.
|
that our node will handle the input, regardless of what it is.
|
||||||
|
|
||||||
However, with https://github.com/comfyanonymous/ComfyUI/pull/2666 a large change would occur
|
However, with https://github.com/comfyanonymous/ComfyUI/pull/2666 a large change would occur
|
||||||
requiring more details on the input itself. There, we need to return a list/tuple where the first
|
requiring more details on the input itself. There, we need to return a list/tuple where the first
|
||||||
item is the type. This can be a real type, or use the AnyType for additional flexibility.
|
item is the type. This can be a real type, or use the AnyType for additional flexibility.
|
||||||
|
|
||||||
This should be forwards compatible unless more changes occur in the PR.
|
This should be forwards compatible unless more changes occur in the PR.
|
||||||
"""
|
"""
|
||||||
def __init__(self, type):
|
|
||||||
self.type = type
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __init__(self, type):
|
||||||
return (self.type, )
|
self.type = type
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __getitem__(self, key):
|
||||||
return True
|
return (self.type,)
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
any_type = AnyType("*")
|
any_type = AnyType("*")
|
||||||
@@ -37,25 +39,27 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
import copy
|
import copy
|
||||||
import sys
|
import sys
|
||||||
import folder_paths
|
import folder_paths # type: ignore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def extract_lora_name(lora_path):
|
def extract_lora_name(lora_path):
|
||||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||||
# Get the basename without extension
|
# Get the basename without extension
|
||||||
basename = os.path.basename(lora_path)
|
basename = os.path.basename(lora_path)
|
||||||
return os.path.splitext(basename)[0]
|
return os.path.splitext(basename)[0]
|
||||||
|
|
||||||
|
|
||||||
def get_loras_list(kwargs):
|
def get_loras_list(kwargs):
|
||||||
"""Helper to extract loras list from either old or new kwargs format"""
|
"""Helper to extract loras list from either old or new kwargs format"""
|
||||||
if 'loras' not in kwargs:
|
if "loras" not in kwargs:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
loras_data = kwargs['loras']
|
loras_data = kwargs["loras"]
|
||||||
# Handle new format: {'loras': {'__value__': [...]}}
|
# Handle new format: {'loras': {'__value__': [...]}}
|
||||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
if isinstance(loras_data, dict) and "__value__" in loras_data:
|
||||||
return loras_data['__value__']
|
return loras_data["__value__"]
|
||||||
# Handle old format: {'loras': [...]}
|
# Handle old format: {'loras': [...]}
|
||||||
elif isinstance(loras_data, list):
|
elif isinstance(loras_data, list):
|
||||||
return loras_data
|
return loras_data
|
||||||
@@ -64,23 +68,25 @@ def get_loras_list(kwargs):
|
|||||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
|
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
|
||||||
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
|
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
|
with safetensors.torch.safe_open(path, framework="pt", device=device) as f: # type: ignore[attr-defined]
|
||||||
for k in f.keys():
|
for k in f.keys():
|
||||||
if filter_prefix and not k.startswith(filter_prefix):
|
if filter_prefix and not k.startswith(filter_prefix):
|
||||||
continue
|
continue
|
||||||
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
|
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def to_diffusers(input_lora):
|
def to_diffusers(input_lora):
|
||||||
"""Simplified version of to_diffusers for Flux LoRA conversion"""
|
"""Simplified version of to_diffusers for Flux LoRA conversion"""
|
||||||
import torch
|
import torch
|
||||||
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
|
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
|
||||||
from diffusers.loaders import FluxLoraLoaderMixin
|
from diffusers.loaders import FluxLoraLoaderMixin # type: ignore[attr-defined]
|
||||||
|
|
||||||
if isinstance(input_lora, str):
|
if isinstance(input_lora, str):
|
||||||
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
|
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
|
||||||
@@ -97,10 +103,15 @@ def to_diffusers(input_lora):
|
|||||||
|
|
||||||
return new_tensors
|
return new_tensors
|
||||||
|
|
||||||
|
|
||||||
def nunchaku_load_lora(model, lora_name, lora_strength):
|
def nunchaku_load_lora(model, lora_name, lora_strength):
|
||||||
"""Load a Flux LoRA for Nunchaku model"""
|
"""Load a Flux LoRA for Nunchaku model"""
|
||||||
# Get full path to the LoRA file. Allow both direct paths and registered LoRA names.
|
# Get full path to the LoRA file. Allow both direct paths and registered LoRA names.
|
||||||
lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
|
lora_path = (
|
||||||
|
lora_name
|
||||||
|
if os.path.isfile(lora_name)
|
||||||
|
else folder_paths.get_full_path("loras", lora_name)
|
||||||
|
)
|
||||||
if not lora_path or not os.path.isfile(lora_path):
|
if not lora_path or not os.path.isfile(lora_path):
|
||||||
logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
|
logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
|
||||||
return model
|
return model
|
||||||
@@ -118,7 +129,9 @@ def nunchaku_load_lora(model, lora_name, lora_strength):
|
|||||||
ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)]
|
ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)]
|
||||||
else:
|
else:
|
||||||
# Fallback to legacy logic
|
# Fallback to legacy logic
|
||||||
logger.warning("Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic.")
|
logger.warning(
|
||||||
|
"Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic."
|
||||||
|
)
|
||||||
transformer = model_wrapper.model
|
transformer = model_wrapper.model
|
||||||
|
|
||||||
# Save the transformer temporarily
|
# Save the transformer temporarily
|
||||||
|
|||||||
@@ -6,17 +6,18 @@ from .parsers import (
|
|||||||
ComfyMetadataParser,
|
ComfyMetadataParser,
|
||||||
MetaFormatParser,
|
MetaFormatParser,
|
||||||
AutomaticMetadataParser,
|
AutomaticMetadataParser,
|
||||||
CivitaiApiMetadataParser
|
CivitaiApiMetadataParser,
|
||||||
)
|
)
|
||||||
from .base import RecipeMetadataParser
|
from .base import RecipeMetadataParser
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RecipeParserFactory:
|
class RecipeParserFactory:
|
||||||
"""Factory for creating recipe metadata parsers"""
|
"""Factory for creating recipe metadata parsers"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_parser(metadata) -> RecipeMetadataParser:
|
def create_parser(metadata) -> RecipeMetadataParser | None:
|
||||||
"""
|
"""
|
||||||
Create appropriate parser based on the metadata content
|
Create appropriate parser based on the metadata content
|
||||||
|
|
||||||
@@ -38,6 +39,7 @@ class RecipeParserFactory:
|
|||||||
# Convert dict to string for other parsers that expect string input
|
# Convert dict to string for other parsers that expect string input
|
||||||
try:
|
try:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
metadata_str = json.dumps(metadata)
|
metadata_str = json.dumps(metadata)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Failed to convert dict to JSON string: {e}")
|
logger.debug(f"Failed to convert dict to JSON string: {e}")
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from ...services.metadata_service import get_default_metadata_provider
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CivitaiApiMetadataParser(RecipeMetadataParser):
|
class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||||
"""Parser for Civitai image metadata format"""
|
"""Parser for Civitai image metadata format"""
|
||||||
|
|
||||||
@@ -40,7 +41,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
"width",
|
"width",
|
||||||
"height",
|
"height",
|
||||||
"Model",
|
"Model",
|
||||||
"Model hash"
|
"Model hash",
|
||||||
)
|
)
|
||||||
return any(key in payload for key in civitai_image_fields)
|
return any(key in payload for key in civitai_image_fields)
|
||||||
|
|
||||||
@@ -50,7 +51,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
|
|
||||||
# Check for LoRA hash patterns
|
# Check for LoRA hash patterns
|
||||||
hashes = metadata.get("hashes")
|
hashes = metadata.get("hashes")
|
||||||
if isinstance(hashes, dict) and any(str(key).lower().startswith("lora:") for key in hashes):
|
if isinstance(hashes, dict) and any(
|
||||||
|
str(key).lower().startswith("lora:") for key in hashes
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check nested meta object (common in CivitAI image responses)
|
# Check nested meta object (common in CivitAI image responses)
|
||||||
@@ -61,22 +64,28 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
|
|
||||||
# Also check for LoRA hash patterns in nested meta
|
# Also check for LoRA hash patterns in nested meta
|
||||||
hashes = nested_meta.get("hashes")
|
hashes = nested_meta.get("hashes")
|
||||||
if isinstance(hashes, dict) and any(str(key).lower().startswith("lora:") for key in hashes):
|
if isinstance(hashes, dict) and any(
|
||||||
|
str(key).lower().startswith("lora:") for key in hashes
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def parse_metadata(self, metadata, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
async def parse_metadata( # type: ignore[override]
|
||||||
|
self, user_comment, recipe_scanner=None, civitai_client=None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Parse metadata from Civitai image format
|
"""Parse metadata from Civitai image format
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
metadata: The metadata from the image (dict)
|
user_comment: The metadata from the image (dict)
|
||||||
recipe_scanner: Optional recipe scanner service
|
recipe_scanner: Optional recipe scanner service
|
||||||
civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead)
|
civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict containing parsed recipe data
|
Dict containing parsed recipe data
|
||||||
"""
|
"""
|
||||||
|
metadata: Dict[str, Any] = user_comment # type: ignore[assignment]
|
||||||
|
metadata = user_comment
|
||||||
try:
|
try:
|
||||||
# Get metadata provider instead of using civitai_client directly
|
# Get metadata provider instead of using civitai_client directly
|
||||||
metadata_provider = await get_default_metadata_provider()
|
metadata_provider = await get_default_metadata_provider()
|
||||||
@@ -103,11 +112,11 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
|
|
||||||
# Initialize result structure
|
# Initialize result structure
|
||||||
result = {
|
result = {
|
||||||
'base_model': None,
|
"base_model": None,
|
||||||
'loras': [],
|
"loras": [],
|
||||||
'model': None,
|
"model": None,
|
||||||
'gen_params': {},
|
"gen_params": {},
|
||||||
'from_civitai_image': True
|
"from_civitai_image": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Track already added LoRAs to prevent duplicates
|
# Track already added LoRAs to prevent duplicates
|
||||||
@@ -148,16 +157,25 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
result["base_model"] = metadata["baseModel"]
|
result["base_model"] = metadata["baseModel"]
|
||||||
elif "Model hash" in metadata and metadata_provider:
|
elif "Model hash" in metadata and metadata_provider:
|
||||||
model_hash = metadata["Model hash"]
|
model_hash = metadata["Model hash"]
|
||||||
model_info, error = await metadata_provider.get_model_by_hash(model_hash)
|
model_info, error = await metadata_provider.get_model_by_hash(
|
||||||
|
model_hash
|
||||||
|
)
|
||||||
if model_info:
|
if model_info:
|
||||||
result["base_model"] = model_info.get("baseModel", "")
|
result["base_model"] = model_info.get("baseModel", "")
|
||||||
elif "Model" in metadata and isinstance(metadata.get("resources"), list):
|
elif "Model" in metadata and isinstance(metadata.get("resources"), list):
|
||||||
# Try to find base model in resources
|
# Try to find base model in resources
|
||||||
for resource in metadata.get("resources", []):
|
for resource in metadata.get("resources", []):
|
||||||
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"):
|
if resource.get("type") == "model" and resource.get(
|
||||||
|
"name"
|
||||||
|
) == metadata.get("Model"):
|
||||||
# This is likely the checkpoint model
|
# This is likely the checkpoint model
|
||||||
if metadata_provider and resource.get("hash"):
|
if metadata_provider and resource.get("hash"):
|
||||||
model_info, error = await metadata_provider.get_model_by_hash(resource.get("hash"))
|
(
|
||||||
|
model_info,
|
||||||
|
error,
|
||||||
|
) = await metadata_provider.get_model_by_hash(
|
||||||
|
resource.get("hash")
|
||||||
|
)
|
||||||
if model_info:
|
if model_info:
|
||||||
result["base_model"] = model_info.get("baseModel", "")
|
result["base_model"] = model_info.get("baseModel", "")
|
||||||
|
|
||||||
@@ -176,7 +194,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
|
|
||||||
# Skip LoRAs without proper identification (hash or modelVersionId)
|
# Skip LoRAs without proper identification (hash or modelVersionId)
|
||||||
if not lora_hash and not resource.get("modelVersionId"):
|
if not lora_hash and not resource.get("modelVersionId"):
|
||||||
logger.debug(f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId")
|
logger.debug(
|
||||||
|
f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Skip if we've already added this LoRA by hash
|
# Skip if we've already added this LoRA by hash
|
||||||
@@ -184,31 +204,33 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
lora_entry = {
|
lora_entry = {
|
||||||
'name': resource.get("name", "Unknown LoRA"),
|
"name": resource.get("name", "Unknown LoRA"),
|
||||||
'type': "lora",
|
"type": "lora",
|
||||||
'weight': float(resource.get("weight", 1.0)),
|
"weight": float(resource.get("weight", 1.0)),
|
||||||
'hash': lora_hash,
|
"hash": lora_hash,
|
||||||
'existsLocally': False,
|
"existsLocally": False,
|
||||||
'localPath': None,
|
"localPath": None,
|
||||||
'file_name': resource.get("name", "Unknown"),
|
"file_name": resource.get("name", "Unknown"),
|
||||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
"thumbnailUrl": "/loras_static/images/no-preview.png",
|
||||||
'baseModel': '',
|
"baseModel": "",
|
||||||
'size': 0,
|
"size": 0,
|
||||||
'downloadUrl': '',
|
"downloadUrl": "",
|
||||||
'isDeleted': False
|
"isDeleted": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Try to get info from Civitai if hash is available
|
# Try to get info from Civitai if hash is available
|
||||||
if lora_entry['hash'] and metadata_provider:
|
if lora_entry["hash"] and metadata_provider:
|
||||||
try:
|
try:
|
||||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
civitai_info = (
|
||||||
|
await metadata_provider.get_model_by_hash(lora_hash)
|
||||||
|
)
|
||||||
|
|
||||||
populated_entry = await self.populate_lora_from_civitai(
|
populated_entry = await self.populate_lora_from_civitai(
|
||||||
lora_entry,
|
lora_entry,
|
||||||
civitai_info,
|
civitai_info,
|
||||||
recipe_scanner,
|
recipe_scanner,
|
||||||
base_model_counts,
|
base_model_counts,
|
||||||
lora_hash
|
lora_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
if populated_entry is None:
|
if populated_entry is None:
|
||||||
@@ -217,10 +239,14 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
lora_entry = populated_entry
|
lora_entry = populated_entry
|
||||||
|
|
||||||
# If we have a version ID from Civitai, track it for deduplication
|
# If we have a version ID from Civitai, track it for deduplication
|
||||||
if 'id' in lora_entry and lora_entry['id']:
|
if "id" in lora_entry and lora_entry["id"]:
|
||||||
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
added_loras[str(lora_entry["id"])] = len(
|
||||||
|
result["loras"]
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
|
logger.error(
|
||||||
|
f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
# Track by hash if we have it
|
# Track by hash if we have it
|
||||||
if lora_hash:
|
if lora_hash:
|
||||||
@@ -229,7 +255,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
result["loras"].append(lora_entry)
|
result["loras"].append(lora_entry)
|
||||||
|
|
||||||
# Process civitaiResources array
|
# Process civitaiResources array
|
||||||
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list):
|
if "civitaiResources" in metadata and isinstance(
|
||||||
|
metadata["civitaiResources"], list
|
||||||
|
):
|
||||||
for resource in metadata["civitaiResources"]:
|
for resource in metadata["civitaiResources"]:
|
||||||
# Get resource type and identifier
|
# Get resource type and identifier
|
||||||
resource_type = str(resource.get("type") or "").lower()
|
resource_type = str(resource.get("type") or "").lower()
|
||||||
@@ -237,32 +265,39 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
|
|
||||||
if resource_type == "checkpoint":
|
if resource_type == "checkpoint":
|
||||||
checkpoint_entry = {
|
checkpoint_entry = {
|
||||||
'id': resource.get("modelVersionId", 0),
|
"id": resource.get("modelVersionId", 0),
|
||||||
'modelId': resource.get("modelId", 0),
|
"modelId": resource.get("modelId", 0),
|
||||||
'name': resource.get("modelName", "Unknown Checkpoint"),
|
"name": resource.get("modelName", "Unknown Checkpoint"),
|
||||||
'version': resource.get("modelVersionName", ""),
|
"version": resource.get("modelVersionName", ""),
|
||||||
'type': resource.get("type", "checkpoint"),
|
"type": resource.get("type", "checkpoint"),
|
||||||
'existsLocally': False,
|
"existsLocally": False,
|
||||||
'localPath': None,
|
"localPath": None,
|
||||||
'file_name': resource.get("modelName", ""),
|
"file_name": resource.get("modelName", ""),
|
||||||
'hash': resource.get("hash", "") or "",
|
"hash": resource.get("hash", "") or "",
|
||||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
"thumbnailUrl": "/loras_static/images/no-preview.png",
|
||||||
'baseModel': '',
|
"baseModel": "",
|
||||||
'size': 0,
|
"size": 0,
|
||||||
'downloadUrl': '',
|
"downloadUrl": "",
|
||||||
'isDeleted': False
|
"isDeleted": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
if version_id and metadata_provider:
|
if version_id and metadata_provider:
|
||||||
try:
|
try:
|
||||||
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
civitai_info = (
|
||||||
|
await metadata_provider.get_model_version_info(
|
||||||
|
version_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
checkpoint_entry = await self.populate_checkpoint_from_civitai(
|
checkpoint_entry = (
|
||||||
checkpoint_entry,
|
await self.populate_checkpoint_from_civitai(
|
||||||
civitai_info
|
checkpoint_entry, civitai_info
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching Civitai info for checkpoint version {version_id}: {e}")
|
logger.error(
|
||||||
|
f"Error fetching Civitai info for checkpoint version {version_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
if result["model"] is None:
|
if result["model"] is None:
|
||||||
result["model"] = checkpoint_entry
|
result["model"] = checkpoint_entry
|
||||||
@@ -275,31 +310,35 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
|
|
||||||
# Initialize lora entry
|
# Initialize lora entry
|
||||||
lora_entry = {
|
lora_entry = {
|
||||||
'id': resource.get("modelVersionId", 0),
|
"id": resource.get("modelVersionId", 0),
|
||||||
'modelId': resource.get("modelId", 0),
|
"modelId": resource.get("modelId", 0),
|
||||||
'name': resource.get("modelName", "Unknown LoRA"),
|
"name": resource.get("modelName", "Unknown LoRA"),
|
||||||
'version': resource.get("modelVersionName", ""),
|
"version": resource.get("modelVersionName", ""),
|
||||||
'type': resource.get("type", "lora"),
|
"type": resource.get("type", "lora"),
|
||||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
"weight": round(float(resource.get("weight", 1.0)), 2),
|
||||||
'existsLocally': False,
|
"existsLocally": False,
|
||||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
"thumbnailUrl": "/loras_static/images/no-preview.png",
|
||||||
'baseModel': '',
|
"baseModel": "",
|
||||||
'size': 0,
|
"size": 0,
|
||||||
'downloadUrl': '',
|
"downloadUrl": "",
|
||||||
'isDeleted': False
|
"isDeleted": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Try to get info from Civitai if modelVersionId is available
|
# Try to get info from Civitai if modelVersionId is available
|
||||||
if version_id and metadata_provider:
|
if version_id and metadata_provider:
|
||||||
try:
|
try:
|
||||||
# Use get_model_version_info instead of get_model_version
|
# Use get_model_version_info instead of get_model_version
|
||||||
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
civitai_info = (
|
||||||
|
await metadata_provider.get_model_version_info(
|
||||||
|
version_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
populated_entry = await self.populate_lora_from_civitai(
|
populated_entry = await self.populate_lora_from_civitai(
|
||||||
lora_entry,
|
lora_entry,
|
||||||
civitai_info,
|
civitai_info,
|
||||||
recipe_scanner,
|
recipe_scanner,
|
||||||
base_model_counts
|
base_model_counts,
|
||||||
)
|
)
|
||||||
|
|
||||||
if populated_entry is None:
|
if populated_entry is None:
|
||||||
@@ -307,7 +346,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
|
|
||||||
lora_entry = populated_entry
|
lora_entry = populated_entry
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching Civitai info for model version {version_id}: {e}")
|
logger.error(
|
||||||
|
f"Error fetching Civitai info for model version {version_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
# Track this LoRA in our deduplication dict
|
# Track this LoRA in our deduplication dict
|
||||||
if version_id:
|
if version_id:
|
||||||
@@ -316,10 +357,15 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
result["loras"].append(lora_entry)
|
result["loras"].append(lora_entry)
|
||||||
|
|
||||||
# Process additionalResources array
|
# Process additionalResources array
|
||||||
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
|
if "additionalResources" in metadata and isinstance(
|
||||||
|
metadata["additionalResources"], list
|
||||||
|
):
|
||||||
for resource in metadata["additionalResources"]:
|
for resource in metadata["additionalResources"]:
|
||||||
# Skip resources that aren't LoRAs or LyCORIS
|
# Skip resources that aren't LoRAs or LyCORIS
|
||||||
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
|
if (
|
||||||
|
resource.get("type") not in ["lora", "lycoris"]
|
||||||
|
and "type" not in resource
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lora_type = resource.get("type", "lora")
|
lora_type = resource.get("type", "lora")
|
||||||
@@ -337,31 +383,35 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
lora_entry = {
|
lora_entry = {
|
||||||
'name': name,
|
"name": name,
|
||||||
'type': lora_type,
|
"type": lora_type,
|
||||||
'weight': float(resource.get("strength", 1.0)),
|
"weight": float(resource.get("strength", 1.0)),
|
||||||
'hash': "",
|
"hash": "",
|
||||||
'existsLocally': False,
|
"existsLocally": False,
|
||||||
'localPath': None,
|
"localPath": None,
|
||||||
'file_name': name,
|
"file_name": name,
|
||||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
"thumbnailUrl": "/loras_static/images/no-preview.png",
|
||||||
'baseModel': '',
|
"baseModel": "",
|
||||||
'size': 0,
|
"size": 0,
|
||||||
'downloadUrl': '',
|
"downloadUrl": "",
|
||||||
'isDeleted': False
|
"isDeleted": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# If we have a version ID and metadata provider, try to get more info
|
# If we have a version ID and metadata provider, try to get more info
|
||||||
if version_id and metadata_provider:
|
if version_id and metadata_provider:
|
||||||
try:
|
try:
|
||||||
# Use get_model_version_info with the version ID
|
# Use get_model_version_info with the version ID
|
||||||
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
civitai_info = (
|
||||||
|
await metadata_provider.get_model_version_info(
|
||||||
|
version_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
populated_entry = await self.populate_lora_from_civitai(
|
populated_entry = await self.populate_lora_from_civitai(
|
||||||
lora_entry,
|
lora_entry,
|
||||||
civitai_info,
|
civitai_info,
|
||||||
recipe_scanner,
|
recipe_scanner,
|
||||||
base_model_counts
|
base_model_counts,
|
||||||
)
|
)
|
||||||
|
|
||||||
if populated_entry is None:
|
if populated_entry is None:
|
||||||
@@ -373,7 +423,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
if version_id:
|
if version_id:
|
||||||
added_loras[version_id] = len(result["loras"])
|
added_loras[version_id] = len(result["loras"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
|
logger.error(
|
||||||
|
f"Error fetching Civitai info for model ID {version_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
result["loras"].append(lora_entry)
|
result["loras"].append(lora_entry)
|
||||||
|
|
||||||
@@ -390,30 +442,32 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
lora_entry = {
|
lora_entry = {
|
||||||
'name': lora_name,
|
"name": lora_name,
|
||||||
'type': "lora",
|
"type": "lora",
|
||||||
'weight': 1.0,
|
"weight": 1.0,
|
||||||
'hash': lora_hash,
|
"hash": lora_hash,
|
||||||
'existsLocally': False,
|
"existsLocally": False,
|
||||||
'localPath': None,
|
"localPath": None,
|
||||||
'file_name': lora_name,
|
"file_name": lora_name,
|
||||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
"thumbnailUrl": "/loras_static/images/no-preview.png",
|
||||||
'baseModel': '',
|
"baseModel": "",
|
||||||
'size': 0,
|
"size": 0,
|
||||||
'downloadUrl': '',
|
"downloadUrl": "",
|
||||||
'isDeleted': False
|
"isDeleted": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
if metadata_provider:
|
if metadata_provider:
|
||||||
try:
|
try:
|
||||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
civitai_info = await metadata_provider.get_model_by_hash(
|
||||||
|
lora_hash
|
||||||
|
)
|
||||||
|
|
||||||
populated_entry = await self.populate_lora_from_civitai(
|
populated_entry = await self.populate_lora_from_civitai(
|
||||||
lora_entry,
|
lora_entry,
|
||||||
civitai_info,
|
civitai_info,
|
||||||
recipe_scanner,
|
recipe_scanner,
|
||||||
base_model_counts,
|
base_model_counts,
|
||||||
lora_hash
|
lora_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
if populated_entry is None:
|
if populated_entry is None:
|
||||||
@@ -421,20 +475,27 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
|
|
||||||
lora_entry = populated_entry
|
lora_entry = populated_entry
|
||||||
|
|
||||||
if 'id' in lora_entry and lora_entry['id']:
|
if "id" in lora_entry and lora_entry["id"]:
|
||||||
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
added_loras[str(lora_entry["id"])] = len(result["loras"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_hash}: {e}")
|
logger.error(
|
||||||
|
f"Error fetching Civitai info for LoRA hash {lora_hash}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
added_loras[lora_hash] = len(result["loras"])
|
added_loras[lora_hash] = len(result["loras"])
|
||||||
result["loras"].append(lora_entry)
|
result["loras"].append(lora_entry)
|
||||||
|
|
||||||
# Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc.
|
# Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc.
|
||||||
lora_index = 0
|
lora_index = 0
|
||||||
while f"Lora_{lora_index} Model hash" in metadata and f"Lora_{lora_index} Model name" in metadata:
|
while (
|
||||||
|
f"Lora_{lora_index} Model hash" in metadata
|
||||||
|
and f"Lora_{lora_index} Model name" in metadata
|
||||||
|
):
|
||||||
lora_hash = metadata[f"Lora_{lora_index} Model hash"]
|
lora_hash = metadata[f"Lora_{lora_index} Model hash"]
|
||||||
lora_name = metadata[f"Lora_{lora_index} Model name"]
|
lora_name = metadata[f"Lora_{lora_index} Model name"]
|
||||||
lora_strength_model = float(metadata.get(f"Lora_{lora_index} Strength model", 1.0))
|
lora_strength_model = float(
|
||||||
|
metadata.get(f"Lora_{lora_index} Strength model", 1.0)
|
||||||
|
)
|
||||||
|
|
||||||
# Skip if we've already added this LoRA by hash
|
# Skip if we've already added this LoRA by hash
|
||||||
if lora_hash and lora_hash in added_loras:
|
if lora_hash and lora_hash in added_loras:
|
||||||
@@ -442,31 +503,33 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
lora_entry = {
|
lora_entry = {
|
||||||
'name': lora_name,
|
"name": lora_name,
|
||||||
'type': "lora",
|
"type": "lora",
|
||||||
'weight': lora_strength_model,
|
"weight": lora_strength_model,
|
||||||
'hash': lora_hash,
|
"hash": lora_hash,
|
||||||
'existsLocally': False,
|
"existsLocally": False,
|
||||||
'localPath': None,
|
"localPath": None,
|
||||||
'file_name': lora_name,
|
"file_name": lora_name,
|
||||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
"thumbnailUrl": "/loras_static/images/no-preview.png",
|
||||||
'baseModel': '',
|
"baseModel": "",
|
||||||
'size': 0,
|
"size": 0,
|
||||||
'downloadUrl': '',
|
"downloadUrl": "",
|
||||||
'isDeleted': False
|
"isDeleted": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Try to get info from Civitai if hash is available
|
# Try to get info from Civitai if hash is available
|
||||||
if lora_entry['hash'] and metadata_provider:
|
if lora_entry["hash"] and metadata_provider:
|
||||||
try:
|
try:
|
||||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
civitai_info = await metadata_provider.get_model_by_hash(
|
||||||
|
lora_hash
|
||||||
|
)
|
||||||
|
|
||||||
populated_entry = await self.populate_lora_from_civitai(
|
populated_entry = await self.populate_lora_from_civitai(
|
||||||
lora_entry,
|
lora_entry,
|
||||||
civitai_info,
|
civitai_info,
|
||||||
recipe_scanner,
|
recipe_scanner,
|
||||||
base_model_counts,
|
base_model_counts,
|
||||||
lora_hash
|
lora_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
if populated_entry is None:
|
if populated_entry is None:
|
||||||
@@ -476,10 +539,12 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
lora_entry = populated_entry
|
lora_entry = populated_entry
|
||||||
|
|
||||||
# If we have a version ID from Civitai, track it for deduplication
|
# If we have a version ID from Civitai, track it for deduplication
|
||||||
if 'id' in lora_entry and lora_entry['id']:
|
if "id" in lora_entry and lora_entry["id"]:
|
||||||
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
added_loras[str(lora_entry["id"])] = len(result["loras"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
|
logger.error(
|
||||||
|
f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
# Track by hash if we have it
|
# Track by hash if we have it
|
||||||
if lora_hash:
|
if lora_hash:
|
||||||
@@ -491,7 +556,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
|
|
||||||
# If base model wasn't found earlier, use the most common one from LoRAs
|
# If base model wasn't found earlier, use the most common one from LoRAs
|
||||||
if not result["base_model"] and base_model_counts:
|
if not result["base_model"] and base_model_counts:
|
||||||
result["base_model"] = max(base_model_counts.items(), key=lambda x: x[1])[0]
|
result["base_model"] = max(
|
||||||
|
base_model_counts.items(), key=lambda x: x[1]
|
||||||
|
)[0]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -3,13 +3,17 @@ import copy
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Optional, Dict, Tuple, List, Sequence
|
from typing import Any, Optional, Dict, Tuple, List, Sequence
|
||||||
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
from .model_metadata_provider import (
|
||||||
|
CivitaiModelMetadataProvider,
|
||||||
|
ModelMetadataProviderManager,
|
||||||
|
)
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader
|
||||||
from .errors import RateLimitError, ResourceNotFoundError
|
from .errors import RateLimitError, ResourceNotFoundError
|
||||||
from ..utils.civitai_utils import resolve_license_payload
|
from ..utils.civitai_utils import resolve_license_payload
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CivitaiClient:
|
class CivitaiClient:
|
||||||
_instance = None
|
_instance = None
|
||||||
_lock = asyncio.Lock()
|
_lock = asyncio.Lock()
|
||||||
@@ -23,13 +27,15 @@ class CivitaiClient:
|
|||||||
|
|
||||||
# Register this client as a metadata provider
|
# Register this client as a metadata provider
|
||||||
provider_manager = await ModelMetadataProviderManager.get_instance()
|
provider_manager = await ModelMetadataProviderManager.get_instance()
|
||||||
provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True)
|
provider_manager.register_provider(
|
||||||
|
"civitai", CivitaiModelMetadataProvider(cls._instance), True
|
||||||
|
)
|
||||||
|
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Check if already initialized for singleton pattern
|
# Check if already initialized for singleton pattern
|
||||||
if hasattr(self, '_initialized'):
|
if hasattr(self, "_initialized"):
|
||||||
return
|
return
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
@@ -76,7 +82,9 @@ class CivitaiClient:
|
|||||||
if isinstance(meta, dict) and "comfy" in meta:
|
if isinstance(meta, dict) and "comfy" in meta:
|
||||||
meta.pop("comfy", None)
|
meta.pop("comfy", None)
|
||||||
|
|
||||||
async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
|
async def download_file(
|
||||||
|
self, url: str, save_dir: str, default_filename: str, progress_callback=None
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
"""Download file with resumable downloads and retry mechanism
|
"""Download file with resumable downloads and retry mechanism
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -97,34 +105,41 @@ class CivitaiClient:
|
|||||||
save_path=save_path,
|
save_path=save_path,
|
||||||
progress_callback=progress_callback,
|
progress_callback=progress_callback,
|
||||||
use_auth=True, # Enable CivitAI authentication
|
use_auth=True, # Enable CivitAI authentication
|
||||||
allow_resume=True
|
allow_resume=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return success, result
|
return success, result
|
||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_by_hash(
|
||||||
|
self, model_hash: str
|
||||||
|
) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
try:
|
try:
|
||||||
success, version = await self._make_request(
|
success, version = await self._make_request(
|
||||||
'GET',
|
"GET",
|
||||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||||
use_auth=True
|
use_auth=True,
|
||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
message = str(version)
|
message = str(version)
|
||||||
if "not found" in message.lower():
|
if "not found" in message.lower():
|
||||||
return None, "Model not found"
|
return None, "Model not found"
|
||||||
|
|
||||||
logger.error("Failed to fetch model info for %s: %s", model_hash[:10], message)
|
logger.error(
|
||||||
|
"Failed to fetch model info for %s: %s", model_hash[:10], message
|
||||||
|
)
|
||||||
return None, message
|
return None, message
|
||||||
|
|
||||||
model_id = version.get('modelId')
|
if isinstance(version, dict):
|
||||||
if model_id:
|
model_id = version.get("modelId")
|
||||||
model_data = await self._fetch_model_data(model_id)
|
if model_id:
|
||||||
if model_data:
|
model_data = await self._fetch_model_data(model_id)
|
||||||
self._enrich_version_with_model_data(version, model_data)
|
if model_data:
|
||||||
|
self._enrich_version_with_model_data(version, model_data)
|
||||||
|
|
||||||
self._remove_comfy_metadata(version)
|
self._remove_comfy_metadata(version)
|
||||||
return version, None
|
return version, None
|
||||||
|
else:
|
||||||
|
return None, "Invalid response format"
|
||||||
except RateLimitError:
|
except RateLimitError:
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -136,12 +151,12 @@ class CivitaiClient:
|
|||||||
downloader = await get_downloader()
|
downloader = await get_downloader()
|
||||||
success, content, headers = await downloader.download_to_memory(
|
success, content, headers = await downloader.download_to_memory(
|
||||||
image_url,
|
image_url,
|
||||||
use_auth=False # Preview images don't need auth
|
use_auth=False, # Preview images don't need auth
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
# Ensure directory exists
|
# Ensure directory exists
|
||||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||||
with open(save_path, 'wb') as f:
|
with open(save_path, "wb") as f:
|
||||||
f.write(content)
|
f.write(content)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
@@ -175,19 +190,17 @@ class CivitaiClient:
|
|||||||
"""Get all versions of a model with local availability info"""
|
"""Get all versions of a model with local availability info"""
|
||||||
try:
|
try:
|
||||||
success, result = await self._make_request(
|
success, result = await self._make_request(
|
||||||
'GET',
|
"GET", f"{self.base_url}/models/{model_id}", use_auth=True
|
||||||
f"{self.base_url}/models/{model_id}",
|
|
||||||
use_auth=True
|
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
# Also return model type along with versions
|
# Also return model type along with versions
|
||||||
return {
|
return {
|
||||||
'modelVersions': result.get('modelVersions', []),
|
"modelVersions": result.get("modelVersions", []),
|
||||||
'type': result.get('type', ''),
|
"type": result.get("type", ""),
|
||||||
'name': result.get('name', '')
|
"name": result.get("name", ""),
|
||||||
}
|
}
|
||||||
message = self._extract_error_message(result)
|
message = self._extract_error_message(result)
|
||||||
if message and 'not found' in message.lower():
|
if message and "not found" in message.lower():
|
||||||
raise ResourceNotFoundError(f"Resource not found for model {model_id}")
|
raise ResourceNotFoundError(f"Resource not found for model {model_id}")
|
||||||
if message:
|
if message:
|
||||||
raise RuntimeError(message)
|
raise RuntimeError(message)
|
||||||
@@ -221,15 +234,15 @@ class CivitaiClient:
|
|||||||
try:
|
try:
|
||||||
query = ",".join(normalized_ids)
|
query = ",".join(normalized_ids)
|
||||||
success, result = await self._make_request(
|
success, result = await self._make_request(
|
||||||
'GET',
|
"GET",
|
||||||
f"{self.base_url}/models",
|
f"{self.base_url}/models",
|
||||||
use_auth=True,
|
use_auth=True,
|
||||||
params={'ids': query},
|
params={"ids": query},
|
||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
items = result.get('items') if isinstance(result, dict) else None
|
items = result.get("items") if isinstance(result, dict) else None
|
||||||
if not isinstance(items, list):
|
if not isinstance(items, list):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@@ -237,19 +250,19 @@ class CivitaiClient:
|
|||||||
for item in items:
|
for item in items:
|
||||||
if not isinstance(item, dict):
|
if not isinstance(item, dict):
|
||||||
continue
|
continue
|
||||||
model_id = item.get('id')
|
model_id = item.get("id")
|
||||||
try:
|
try:
|
||||||
normalized_id = int(model_id)
|
normalized_id = int(model_id)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
continue
|
continue
|
||||||
payload[normalized_id] = {
|
payload[normalized_id] = {
|
||||||
'modelVersions': item.get('modelVersions', []),
|
"modelVersions": item.get("modelVersions", []),
|
||||||
'type': item.get('type', ''),
|
"type": item.get("type", ""),
|
||||||
'name': item.get('name', ''),
|
"name": item.get("name", ""),
|
||||||
'allowNoCredit': item.get('allowNoCredit'),
|
"allowNoCredit": item.get("allowNoCredit"),
|
||||||
'allowCommercialUse': item.get('allowCommercialUse'),
|
"allowCommercialUse": item.get("allowCommercialUse"),
|
||||||
'allowDerivatives': item.get('allowDerivatives'),
|
"allowDerivatives": item.get("allowDerivatives"),
|
||||||
'allowDifferentLicense': item.get('allowDifferentLicense'),
|
"allowDifferentLicense": item.get("allowDifferentLicense"),
|
||||||
}
|
}
|
||||||
return payload
|
return payload
|
||||||
except RateLimitError:
|
except RateLimitError:
|
||||||
@@ -258,7 +271,9 @@ class CivitaiClient:
|
|||||||
logger.error(f"Error fetching model versions in bulk: {exc}")
|
logger.error(f"Error fetching model versions in bulk: {exc}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
async def get_model_version(
|
||||||
|
self, model_id: int = None, version_id: int = None
|
||||||
|
) -> Optional[Dict]:
|
||||||
"""Get specific model version with additional metadata."""
|
"""Get specific model version with additional metadata."""
|
||||||
try:
|
try:
|
||||||
if model_id is None and version_id is not None:
|
if model_id is None and version_id is not None:
|
||||||
@@ -281,7 +296,7 @@ class CivitaiClient:
|
|||||||
if version is None:
|
if version is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
model_id = version.get('modelId')
|
model_id = version.get("modelId")
|
||||||
if not model_id:
|
if not model_id:
|
||||||
logger.error(f"No modelId found in version {version_id}")
|
logger.error(f"No modelId found in version {version_id}")
|
||||||
return None
|
return None
|
||||||
@@ -293,7 +308,9 @@ class CivitaiClient:
|
|||||||
self._remove_comfy_metadata(version)
|
self._remove_comfy_metadata(version)
|
||||||
return version
|
return version
|
||||||
|
|
||||||
async def _get_version_with_model_id(self, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
async def _get_version_with_model_id(
|
||||||
|
self, model_id: int, version_id: Optional[int]
|
||||||
|
) -> Optional[Dict]:
|
||||||
model_data = await self._fetch_model_data(model_id)
|
model_data = await self._fetch_model_data(model_id)
|
||||||
if not model_data:
|
if not model_data:
|
||||||
return None
|
return None
|
||||||
@@ -302,8 +319,12 @@ class CivitaiClient:
|
|||||||
if target_version is None:
|
if target_version is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
target_version_id = target_version.get('id')
|
target_version_id = target_version.get("id")
|
||||||
version = await self._fetch_version_by_id(target_version_id) if target_version_id else None
|
version = (
|
||||||
|
await self._fetch_version_by_id(target_version_id)
|
||||||
|
if target_version_id
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
if version is None:
|
if version is None:
|
||||||
model_hash = self._extract_primary_model_hash(target_version)
|
model_hash = self._extract_primary_model_hash(target_version)
|
||||||
@@ -315,7 +336,9 @@ class CivitaiClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if version is None:
|
if version is None:
|
||||||
version = self._build_version_from_model_data(target_version, model_id, model_data)
|
version = self._build_version_from_model_data(
|
||||||
|
target_version, model_id, model_data
|
||||||
|
)
|
||||||
|
|
||||||
self._enrich_version_with_model_data(version, model_data)
|
self._enrich_version_with_model_data(version, model_data)
|
||||||
self._remove_comfy_metadata(version)
|
self._remove_comfy_metadata(version)
|
||||||
@@ -323,9 +346,7 @@ class CivitaiClient:
|
|||||||
|
|
||||||
async def _fetch_model_data(self, model_id: int) -> Optional[Dict]:
|
async def _fetch_model_data(self, model_id: int) -> Optional[Dict]:
|
||||||
success, data = await self._make_request(
|
success, data = await self._make_request(
|
||||||
'GET',
|
"GET", f"{self.base_url}/models/{model_id}", use_auth=True
|
||||||
f"{self.base_url}/models/{model_id}",
|
|
||||||
use_auth=True
|
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
return data
|
return data
|
||||||
@@ -337,9 +358,7 @@ class CivitaiClient:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
success, version = await self._make_request(
|
success, version = await self._make_request(
|
||||||
'GET',
|
"GET", f"{self.base_url}/model-versions/{version_id}", use_auth=True
|
||||||
f"{self.base_url}/model-versions/{version_id}",
|
|
||||||
use_auth=True
|
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
return version
|
return version
|
||||||
@@ -352,9 +371,7 @@ class CivitaiClient:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
success, version = await self._make_request(
|
success, version = await self._make_request(
|
||||||
'GET',
|
"GET", f"{self.base_url}/model-versions/by-hash/{model_hash}", use_auth=True
|
||||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
|
||||||
use_auth=True
|
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
return version
|
return version
|
||||||
@@ -362,16 +379,17 @@ class CivitaiClient:
|
|||||||
logger.warning(f"Failed to fetch version by hash {model_hash}")
|
logger.warning(f"Failed to fetch version by hash {model_hash}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
def _select_target_version(
|
||||||
model_versions = model_data.get('modelVersions', [])
|
self, model_data: Dict, model_id: int, version_id: Optional[int]
|
||||||
|
) -> Optional[Dict]:
|
||||||
|
model_versions = model_data.get("modelVersions", [])
|
||||||
if not model_versions:
|
if not model_versions:
|
||||||
logger.warning(f"No model versions found for model {model_id}")
|
logger.warning(f"No model versions found for model {model_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if version_id is not None:
|
if version_id is not None:
|
||||||
target_version = next(
|
target_version = next(
|
||||||
(item for item in model_versions if item.get('id') == version_id),
|
(item for item in model_versions if item.get("id") == version_id), None
|
||||||
None
|
|
||||||
)
|
)
|
||||||
if target_version is None:
|
if target_version is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -383,41 +401,45 @@ class CivitaiClient:
|
|||||||
return model_versions[0]
|
return model_versions[0]
|
||||||
|
|
||||||
def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]:
|
def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]:
|
||||||
for file_info in version_entry.get('files', []):
|
for file_info in version_entry.get("files", []):
|
||||||
if file_info.get('type') == 'Model' and file_info.get('primary'):
|
if file_info.get("type") == "Model" and file_info.get("primary"):
|
||||||
hashes = file_info.get('hashes', {})
|
hashes = file_info.get("hashes", {})
|
||||||
model_hash = hashes.get('SHA256')
|
model_hash = hashes.get("SHA256")
|
||||||
if model_hash:
|
if model_hash:
|
||||||
return model_hash
|
return model_hash
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict:
|
def _build_version_from_model_data(
|
||||||
|
self, version_entry: Dict, model_id: int, model_data: Dict
|
||||||
|
) -> Dict:
|
||||||
version = copy.deepcopy(version_entry)
|
version = copy.deepcopy(version_entry)
|
||||||
version.pop('index', None)
|
version.pop("index", None)
|
||||||
version['modelId'] = model_id
|
version["modelId"] = model_id
|
||||||
version['model'] = {
|
version["model"] = {
|
||||||
'name': model_data.get('name'),
|
"name": model_data.get("name"),
|
||||||
'type': model_data.get('type'),
|
"type": model_data.get("type"),
|
||||||
'nsfw': model_data.get('nsfw'),
|
"nsfw": model_data.get("nsfw"),
|
||||||
'poi': model_data.get('poi')
|
"poi": model_data.get("poi"),
|
||||||
}
|
}
|
||||||
return version
|
return version
|
||||||
|
|
||||||
def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None:
|
def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None:
|
||||||
model_info = version.get('model')
|
model_info = version.get("model")
|
||||||
if not isinstance(model_info, dict):
|
if not isinstance(model_info, dict):
|
||||||
model_info = {}
|
model_info = {}
|
||||||
version['model'] = model_info
|
version["model"] = model_info
|
||||||
|
|
||||||
model_info['description'] = model_data.get("description")
|
model_info["description"] = model_data.get("description")
|
||||||
model_info['tags'] = model_data.get("tags", [])
|
model_info["tags"] = model_data.get("tags", [])
|
||||||
version['creator'] = model_data.get("creator")
|
version["creator"] = model_data.get("creator")
|
||||||
|
|
||||||
license_payload = resolve_license_payload(model_data)
|
license_payload = resolve_license_payload(model_data)
|
||||||
for field, value in license_payload.items():
|
for field, value in license_payload.items():
|
||||||
model_info[field] = value
|
model_info[field] = value
|
||||||
|
|
||||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_version_info(
|
||||||
|
self, version_id: str
|
||||||
|
) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
"""Fetch model version metadata from Civitai
|
"""Fetch model version metadata from Civitai
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -432,14 +454,12 @@ class CivitaiClient:
|
|||||||
url = f"{self.base_url}/model-versions/{version_id}"
|
url = f"{self.base_url}/model-versions/{version_id}"
|
||||||
|
|
||||||
logger.debug(f"Resolving DNS for model version info: {url}")
|
logger.debug(f"Resolving DNS for model version info: {url}")
|
||||||
success, result = await self._make_request(
|
success, result = await self._make_request("GET", url, use_auth=True)
|
||||||
'GET',
|
|
||||||
url,
|
|
||||||
use_auth=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
logger.debug(f"Successfully fetched model version info for: {version_id}")
|
logger.debug(
|
||||||
|
f"Successfully fetched model version info for: {version_id}"
|
||||||
|
)
|
||||||
self._remove_comfy_metadata(result)
|
self._remove_comfy_metadata(result)
|
||||||
return result, None
|
return result, None
|
||||||
|
|
||||||
@@ -472,11 +492,7 @@ class CivitaiClient:
|
|||||||
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
|
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
|
||||||
|
|
||||||
logger.debug(f"Fetching image info for ID: {image_id}")
|
logger.debug(f"Fetching image info for ID: {image_id}")
|
||||||
success, result = await self._make_request(
|
success, result = await self._make_request("GET", url, use_auth=True)
|
||||||
'GET',
|
|
||||||
url,
|
|
||||||
use_auth=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
if result and "items" in result and len(result["items"]) > 0:
|
if result and "items" in result and len(result["items"]) > 0:
|
||||||
@@ -501,11 +517,7 @@ class CivitaiClient:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
url = f"{self.base_url}/models?username={username}"
|
url = f"{self.base_url}/models?username={username}"
|
||||||
success, result = await self._make_request(
|
success, result = await self._make_request("GET", url, use_auth=True)
|
||||||
'GET',
|
|
||||||
url,
|
|
||||||
use_auth=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
logger.error("Failed to fetch models for %s: %s", username, result)
|
logger.error("Failed to fetch models for %s: %s", username, result)
|
||||||
|
|||||||
Reference in New Issue
Block a user