mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-26 15:38:52 -03:00
Compare commits
7 Commits
2dae4c1291
...
03e1fa75c5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
03e1fa75c5 | ||
|
|
fefcaa4a45 | ||
|
|
701a6a6c44 | ||
|
|
0ef414d17e | ||
|
|
75dccaef87 | ||
|
|
7e87ec9521 | ||
|
|
46522edb1b |
@@ -15,7 +15,7 @@ class CheckpointLoaderLM:
|
|||||||
extra folder paths, providing a unified interface for checkpoint loading.
|
extra folder paths, providing a unified interface for checkpoint loading.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
NAME = "CheckpointLoaderLM"
|
NAME = "Checkpoint Loader (LoraManager)"
|
||||||
CATEGORY = "Lora Manager/loaders"
|
CATEGORY = "Lora Manager/loaders"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -60,7 +60,7 @@ class CheckpointLoaderLM:
|
|||||||
if item.get("sub_type") == "checkpoint":
|
if item.get("sub_type") == "checkpoint":
|
||||||
file_path = item.get("file_path", "")
|
file_path = item.get("file_path", "")
|
||||||
if file_path:
|
if file_path:
|
||||||
# Format as ComfyUI-style: "folder/model_name.ext"
|
# Format using relative path with OS-native separator
|
||||||
formatted_name = _format_model_name_for_comfyui(
|
formatted_name = _format_model_name_for_comfyui(
|
||||||
file_path, model_roots
|
file_path, model_roots
|
||||||
)
|
)
|
||||||
@@ -94,7 +94,7 @@ class CheckpointLoaderLM:
|
|||||||
"""Load a checkpoint by name, supporting extra folder paths
|
"""Load a checkpoint by name, supporting extra folder paths
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ckpt_name: The name of the checkpoint to load (format: "folder/model_name.ext")
|
ckpt_name: The name of the checkpoint to load (relative path with extension)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (MODEL, CLIP, VAE)
|
Tuple of (MODEL, CLIP, VAE)
|
||||||
@@ -108,10 +108,6 @@ class CheckpointLoaderLM:
|
|||||||
"Make sure the checkpoint is indexed and try again."
|
"Make sure the checkpoint is indexed and try again."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if it's a GGUF model
|
|
||||||
if ckpt_path.endswith(".gguf"):
|
|
||||||
return self._load_gguf_checkpoint(ckpt_path, ckpt_name)
|
|
||||||
|
|
||||||
# Load regular checkpoint using ComfyUI's API
|
# Load regular checkpoint using ComfyUI's API
|
||||||
logger.info(f"Loading checkpoint from: {ckpt_path}")
|
logger.info(f"Loading checkpoint from: {ckpt_path}")
|
||||||
out = comfy.sd.load_checkpoint_guess_config(
|
out = comfy.sd.load_checkpoint_guess_config(
|
||||||
@@ -121,64 +117,3 @@ class CheckpointLoaderLM:
|
|||||||
embedding_directory=folder_paths.get_folder_paths("embeddings"),
|
embedding_directory=folder_paths.get_folder_paths("embeddings"),
|
||||||
)
|
)
|
||||||
return out[:3]
|
return out[:3]
|
||||||
|
|
||||||
def _load_gguf_checkpoint(self, ckpt_path: str, ckpt_name: str) -> Tuple:
|
|
||||||
"""Load a GGUF format checkpoint
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ckpt_path: Absolute path to the GGUF file
|
|
||||||
ckpt_name: Name of the checkpoint for error messages
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (MODEL, CLIP, VAE) - CLIP and VAE may be None for GGUF
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Try to import ComfyUI-GGUF modules
|
|
||||||
from custom_nodes.ComfyUI_GGUF.loader import gguf_sd_loader
|
|
||||||
from custom_nodes.ComfyUI_GGUF.ops import GGMLOps
|
|
||||||
from custom_nodes.ComfyUI_GGUF.nodes import GGUFModelPatcher
|
|
||||||
except ImportError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot load GGUF model '{ckpt_name}'. "
|
|
||||||
"ComfyUI-GGUF is not installed. "
|
|
||||||
"Please install ComfyUI-GGUF from https://github.com/city96/ComfyUI-GGUF "
|
|
||||||
"to load GGUF format models."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Loading GGUF checkpoint from: {ckpt_path}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load GGUF state dict
|
|
||||||
sd, extra = gguf_sd_loader(ckpt_path)
|
|
||||||
|
|
||||||
# Prepare kwargs for metadata if supported
|
|
||||||
kwargs = {}
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
valid_params = inspect.signature(
|
|
||||||
comfy.sd.load_diffusion_model_state_dict
|
|
||||||
).parameters
|
|
||||||
if "metadata" in valid_params:
|
|
||||||
kwargs["metadata"] = extra.get("metadata", {})
|
|
||||||
|
|
||||||
# Load the model
|
|
||||||
model = comfy.sd.load_diffusion_model_state_dict(
|
|
||||||
sd, model_options={"custom_operations": GGMLOps()}, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if model is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Could not detect model type for GGUF checkpoint: {ckpt_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wrap with GGUFModelPatcher
|
|
||||||
model = GGUFModelPatcher.clone(model)
|
|
||||||
|
|
||||||
# GGUF checkpoints typically don't include CLIP/VAE
|
|
||||||
return (model, None, None)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading GGUF checkpoint '{ckpt_name}': {e}")
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Failed to load GGUF checkpoint '{ckpt_name}': {str(e)}"
|
|
||||||
)
|
|
||||||
|
|||||||
161
py/nodes/gguf_import_helper.py
Normal file
161
py/nodes/gguf_import_helper.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""
|
||||||
|
Helper module to safely import ComfyUI-GGUF modules.
|
||||||
|
|
||||||
|
This module provides a robust way to import ComfyUI-GGUF functionality
|
||||||
|
regardless of how ComfyUI loaded it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import importlib.util
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Tuple, Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_gguf_path() -> str:
|
||||||
|
"""Get the path to ComfyUI-GGUF based on this file's location.
|
||||||
|
|
||||||
|
Since ComfyUI-Lora-Manager and ComfyUI-GGUF are both in custom_nodes/,
|
||||||
|
we can derive the GGUF path from our own location.
|
||||||
|
"""
|
||||||
|
# This file is at: custom_nodes/ComfyUI-Lora-Manager/py/nodes/gguf_import_helper.py
|
||||||
|
# ComfyUI-GGUF is at: custom_nodes/ComfyUI-GGUF
|
||||||
|
current_file = os.path.abspath(__file__)
|
||||||
|
# Go up 4 levels: nodes -> py -> ComfyUI-Lora-Manager -> custom_nodes
|
||||||
|
custom_nodes_dir = os.path.dirname(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.dirname(current_file)))
|
||||||
|
)
|
||||||
|
return os.path.join(custom_nodes_dir, "ComfyUI-GGUF")
|
||||||
|
|
||||||
|
|
||||||
|
def _find_gguf_module() -> Optional[Any]:
|
||||||
|
"""Find ComfyUI-GGUF module in sys.modules.
|
||||||
|
|
||||||
|
ComfyUI registers modules using the full path with dots replaced by _x_.
|
||||||
|
"""
|
||||||
|
gguf_path = _get_gguf_path()
|
||||||
|
sys_module_name = gguf_path.replace(".", "_x_")
|
||||||
|
|
||||||
|
logger.debug(f"[GGUF Import] Looking for module '{sys_module_name}' in sys.modules")
|
||||||
|
if sys_module_name in sys.modules:
|
||||||
|
logger.info(f"[GGUF Import] Found module: '{sys_module_name}'")
|
||||||
|
return sys.modules[sys_module_name]
|
||||||
|
|
||||||
|
logger.debug(f"[GGUF Import] Module not found: '{sys_module_name}'")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _load_gguf_modules_directly() -> Optional[Any]:
|
||||||
|
"""Load ComfyUI-GGUF modules directly from file paths."""
|
||||||
|
gguf_path = _get_gguf_path()
|
||||||
|
|
||||||
|
logger.info(f"[GGUF Import] Direct Load: Attempting to load from '{gguf_path}'")
|
||||||
|
|
||||||
|
if not os.path.exists(gguf_path):
|
||||||
|
logger.warning(f"[GGUF Import] Path does not exist: {gguf_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
namespace = "ComfyUI_GGUF_Dynamic"
|
||||||
|
init_path = os.path.join(gguf_path, "__init__.py")
|
||||||
|
|
||||||
|
if not os.path.exists(init_path):
|
||||||
|
logger.warning(f"[GGUF Import] __init__.py not found at '{init_path}'")
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug(f"[GGUF Import] Loading from '{init_path}'")
|
||||||
|
spec = importlib.util.spec_from_file_location(namespace, init_path)
|
||||||
|
if not spec or not spec.loader:
|
||||||
|
logger.error(f"[GGUF Import] Failed to create spec for '{init_path}'")
|
||||||
|
return None
|
||||||
|
|
||||||
|
package = importlib.util.module_from_spec(spec)
|
||||||
|
package.__path__ = [gguf_path]
|
||||||
|
sys.modules[namespace] = package
|
||||||
|
spec.loader.exec_module(package)
|
||||||
|
logger.debug(f"[GGUF Import] Loaded main package '{namespace}'")
|
||||||
|
|
||||||
|
# Load submodules
|
||||||
|
loaded = []
|
||||||
|
for submod_name in ["loader", "ops", "nodes"]:
|
||||||
|
submod_path = os.path.join(gguf_path, f"{submod_name}.py")
|
||||||
|
if os.path.exists(submod_path):
|
||||||
|
submod_spec = importlib.util.spec_from_file_location(
|
||||||
|
f"{namespace}.{submod_name}", submod_path
|
||||||
|
)
|
||||||
|
if submod_spec and submod_spec.loader:
|
||||||
|
submod = importlib.util.module_from_spec(submod_spec)
|
||||||
|
submod.__package__ = namespace
|
||||||
|
sys.modules[f"{namespace}.{submod_name}"] = submod
|
||||||
|
submod_spec.loader.exec_module(submod)
|
||||||
|
setattr(package, submod_name, submod)
|
||||||
|
loaded.append(submod_name)
|
||||||
|
logger.debug(f"[GGUF Import] Loaded submodule '{submod_name}'")
|
||||||
|
|
||||||
|
logger.info(f"[GGUF Import] Direct Load success: {loaded}")
|
||||||
|
return package
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[GGUF Import] Direct Load failed: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_gguf_modules() -> Tuple[Any, Any, Any]:
|
||||||
|
"""Get ComfyUI-GGUF modules (loader, ops, nodes).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (loader_module, ops_module, nodes_module)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If ComfyUI-GGUF cannot be found or loaded.
|
||||||
|
"""
|
||||||
|
logger.debug("[GGUF Import] Starting module search...")
|
||||||
|
|
||||||
|
# Try to find already loaded module first
|
||||||
|
gguf_module = _find_gguf_module()
|
||||||
|
|
||||||
|
if gguf_module is None:
|
||||||
|
logger.info("[GGUF Import] Not found in sys.modules, trying direct load...")
|
||||||
|
gguf_module = _load_gguf_modules_directly()
|
||||||
|
|
||||||
|
if gguf_module is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"ComfyUI-GGUF is not installed. "
|
||||||
|
"Please install from https://github.com/city96/ComfyUI-GGUF"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract submodules
|
||||||
|
loader = getattr(gguf_module, "loader", None)
|
||||||
|
ops = getattr(gguf_module, "ops", None)
|
||||||
|
nodes = getattr(gguf_module, "nodes", None)
|
||||||
|
|
||||||
|
if loader is None or ops is None or nodes is None:
|
||||||
|
missing = [
|
||||||
|
name
|
||||||
|
for name, mod in [("loader", loader), ("ops", ops), ("nodes", nodes)]
|
||||||
|
if mod is None
|
||||||
|
]
|
||||||
|
raise RuntimeError(f"ComfyUI-GGUF missing submodules: {missing}")
|
||||||
|
|
||||||
|
logger.debug("[GGUF Import] All modules loaded successfully")
|
||||||
|
return loader, ops, nodes
|
||||||
|
|
||||||
|
|
||||||
|
def get_gguf_sd_loader():
|
||||||
|
"""Get the gguf_sd_loader function from ComfyUI-GGUF."""
|
||||||
|
loader, _, _ = get_gguf_modules()
|
||||||
|
return getattr(loader, "gguf_sd_loader")
|
||||||
|
|
||||||
|
|
||||||
|
def get_ggml_ops():
|
||||||
|
"""Get the GGMLOps class from ComfyUI-GGUF."""
|
||||||
|
_, ops, _ = get_gguf_modules()
|
||||||
|
return getattr(ops, "GGMLOps")
|
||||||
|
|
||||||
|
|
||||||
|
def get_gguf_model_patcher():
|
||||||
|
"""Get the GGUFModelPatcher class from ComfyUI-GGUF."""
|
||||||
|
_, _, nodes = get_gguf_modules()
|
||||||
|
return getattr(nodes, "GGUFModelPatcher")
|
||||||
@@ -16,7 +16,7 @@ class UNETLoaderLM:
|
|||||||
Supports both regular diffusion models and GGUF format models.
|
Supports both regular diffusion models and GGUF format models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
NAME = "UNETLoaderLM"
|
NAME = "Unet Loader (LoraManager)"
|
||||||
CATEGORY = "Lora Manager/loaders"
|
CATEGORY = "Lora Manager/loaders"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -61,7 +61,7 @@ class UNETLoaderLM:
|
|||||||
if item.get("sub_type") == "diffusion_model":
|
if item.get("sub_type") == "diffusion_model":
|
||||||
file_path = item.get("file_path", "")
|
file_path = item.get("file_path", "")
|
||||||
if file_path:
|
if file_path:
|
||||||
# Format as ComfyUI-style: "folder/model_name.ext"
|
# Format using relative path with OS-native separator
|
||||||
formatted_name = _format_model_name_for_comfyui(
|
formatted_name = _format_model_name_for_comfyui(
|
||||||
file_path, model_roots
|
file_path, model_roots
|
||||||
)
|
)
|
||||||
@@ -95,7 +95,7 @@ class UNETLoaderLM:
|
|||||||
"""Load a diffusion model by name, supporting extra folder paths
|
"""Load a diffusion model by name, supporting extra folder paths
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
unet_name: The name of the diffusion model to load (format: "folder/model_name.ext")
|
unet_name: The name of the diffusion model to load (relative path with extension)
|
||||||
weight_dtype: The dtype to use for model weights
|
weight_dtype: The dtype to use for model weights
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -143,18 +143,16 @@ class UNETLoaderLM:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (MODEL,)
|
Tuple of (MODEL,)
|
||||||
"""
|
"""
|
||||||
|
from .gguf_import_helper import get_gguf_modules
|
||||||
|
|
||||||
|
# Get ComfyUI-GGUF modules using helper (handles various import scenarios)
|
||||||
try:
|
try:
|
||||||
# Try to import ComfyUI-GGUF modules
|
loader_module, ops_module, nodes_module = get_gguf_modules()
|
||||||
from custom_nodes.ComfyUI_GGUF.loader import gguf_sd_loader
|
gguf_sd_loader = getattr(loader_module, "gguf_sd_loader")
|
||||||
from custom_nodes.ComfyUI_GGUF.ops import GGMLOps
|
GGMLOps = getattr(ops_module, "GGMLOps")
|
||||||
from custom_nodes.ComfyUI_GGUF.nodes import GGUFModelPatcher
|
GGUFModelPatcher = getattr(nodes_module, "GGUFModelPatcher")
|
||||||
except ImportError:
|
except RuntimeError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"Cannot load GGUF model '{unet_name}'. {str(e)}")
|
||||||
f"Cannot load GGUF model '{unet_name}'. "
|
|
||||||
"ComfyUI-GGUF is not installed. "
|
|
||||||
"Please install ComfyUI-GGUF from https://github.com/city96/ComfyUI-GGUF "
|
|
||||||
"to load GGUF format models."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Loading GGUF diffusion model from: {unet_path}")
|
logger.info(f"Loading GGUF diffusion model from: {unet_path}")
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ class CacheEntryValidator:
|
|||||||
'preview_nsfw_level': (0, False),
|
'preview_nsfw_level': (0, False),
|
||||||
'notes': ('', False),
|
'notes': ('', False),
|
||||||
'usage_tips': ('', False),
|
'usage_tips': ('', False),
|
||||||
|
'hash_status': ('completed', False),
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -90,13 +91,31 @@ class CacheEntryValidator:
|
|||||||
|
|
||||||
errors: List[str] = []
|
errors: List[str] = []
|
||||||
repaired = False
|
repaired = False
|
||||||
|
|
||||||
|
# If auto_repair is on, we work on a copy. If not, we still need a safe way to check fields.
|
||||||
working_entry = dict(entry) if auto_repair else entry
|
working_entry = dict(entry) if auto_repair else entry
|
||||||
|
|
||||||
|
# Determine effective hash_status for validation logic
|
||||||
|
hash_status = entry.get('hash_status')
|
||||||
|
if hash_status is None:
|
||||||
|
if auto_repair:
|
||||||
|
working_entry['hash_status'] = 'completed'
|
||||||
|
repaired = True
|
||||||
|
hash_status = 'completed'
|
||||||
|
|
||||||
for field_name, (default_value, is_required) in cls.CORE_FIELDS.items():
|
for field_name, (default_value, is_required) in cls.CORE_FIELDS.items():
|
||||||
value = working_entry.get(field_name)
|
# Get current value from the original entry to avoid side effects during validation
|
||||||
|
value = entry.get(field_name)
|
||||||
|
|
||||||
# Check if field is missing or None
|
# Check if field is missing or None
|
||||||
if value is None:
|
if value is None:
|
||||||
|
# Special case: sha256 can be None/empty if hash_status is pending
|
||||||
|
if field_name == 'sha256' and hash_status == 'pending':
|
||||||
|
if auto_repair:
|
||||||
|
working_entry[field_name] = ''
|
||||||
|
repaired = True
|
||||||
|
continue
|
||||||
|
|
||||||
if is_required:
|
if is_required:
|
||||||
errors.append(f"Required field '{field_name}' is missing or None")
|
errors.append(f"Required field '{field_name}' is missing or None")
|
||||||
if auto_repair:
|
if auto_repair:
|
||||||
@@ -107,6 +126,10 @@ class CacheEntryValidator:
|
|||||||
# Validate field type and value
|
# Validate field type and value
|
||||||
field_error = cls._validate_field(field_name, value, default_value)
|
field_error = cls._validate_field(field_name, value, default_value)
|
||||||
if field_error:
|
if field_error:
|
||||||
|
# Special case: allow empty string for sha256 if pending
|
||||||
|
if field_name == 'sha256' and hash_status == 'pending' and value == '':
|
||||||
|
continue
|
||||||
|
|
||||||
errors.append(field_error)
|
errors.append(field_error)
|
||||||
if auto_repair:
|
if auto_repair:
|
||||||
working_entry[field_name] = cls._get_default_copy(default_value)
|
working_entry[field_name] = cls._get_default_copy(default_value)
|
||||||
@@ -127,7 +150,7 @@ class CacheEntryValidator:
|
|||||||
# Special validation: sha256 must not be empty for required field
|
# Special validation: sha256 must not be empty for required field
|
||||||
# BUT allow empty sha256 when hash_status is pending (lazy hash calculation)
|
# BUT allow empty sha256 when hash_status is pending (lazy hash calculation)
|
||||||
sha256 = working_entry.get('sha256', '')
|
sha256 = working_entry.get('sha256', '')
|
||||||
hash_status = working_entry.get('hash_status', 'completed')
|
# Use the effective hash_status we determined earlier
|
||||||
if not sha256 or (isinstance(sha256, str) and not sha256.strip()):
|
if not sha256 or (isinstance(sha256, str) and not sha256.strip()):
|
||||||
# Allow empty sha256 for lazy hash calculation (checkpoints)
|
# Allow empty sha256 for lazy hash calculation (checkpoints)
|
||||||
if hash_status != 'pending':
|
if hash_status != 'pending':
|
||||||
@@ -144,8 +167,13 @@ class CacheEntryValidator:
|
|||||||
if isinstance(sha256, str):
|
if isinstance(sha256, str):
|
||||||
normalized_sha = sha256.lower().strip()
|
normalized_sha = sha256.lower().strip()
|
||||||
if normalized_sha != sha256:
|
if normalized_sha != sha256:
|
||||||
|
if auto_repair:
|
||||||
working_entry['sha256'] = normalized_sha
|
working_entry['sha256'] = normalized_sha
|
||||||
repaired = True
|
repaired = True
|
||||||
|
else:
|
||||||
|
# If not auto-repairing, we don't consider case difference as a "critical error"
|
||||||
|
# that invalidates the entry, but we also don't mark it repaired.
|
||||||
|
pass
|
||||||
|
|
||||||
# Determine if entry is valid
|
# Determine if entry is valid
|
||||||
# Entry is valid if no critical required field errors remain after repair
|
# Entry is valid if no critical required field errors remain after repair
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ class PersistentModelCache:
|
|||||||
"exclude",
|
"exclude",
|
||||||
"db_checked",
|
"db_checked",
|
||||||
"last_checked_at",
|
"last_checked_at",
|
||||||
|
"hash_status",
|
||||||
)
|
)
|
||||||
_MODEL_UPDATE_COLUMNS: Tuple[str, ...] = _MODEL_COLUMNS[2:]
|
_MODEL_UPDATE_COLUMNS: Tuple[str, ...] = _MODEL_COLUMNS[2:]
|
||||||
_instances: Dict[str, "PersistentModelCache"] = {}
|
_instances: Dict[str, "PersistentModelCache"] = {}
|
||||||
@@ -186,6 +187,7 @@ class PersistentModelCache:
|
|||||||
"civitai_deleted": bool(row["civitai_deleted"]),
|
"civitai_deleted": bool(row["civitai_deleted"]),
|
||||||
"skip_metadata_refresh": bool(row["skip_metadata_refresh"]),
|
"skip_metadata_refresh": bool(row["skip_metadata_refresh"]),
|
||||||
"license_flags": int(license_value),
|
"license_flags": int(license_value),
|
||||||
|
"hash_status": row["hash_status"] or "completed",
|
||||||
}
|
}
|
||||||
raw_data.append(item)
|
raw_data.append(item)
|
||||||
|
|
||||||
@@ -449,6 +451,7 @@ class PersistentModelCache:
|
|||||||
exclude INTEGER,
|
exclude INTEGER,
|
||||||
db_checked INTEGER,
|
db_checked INTEGER,
|
||||||
last_checked_at REAL,
|
last_checked_at REAL,
|
||||||
|
hash_status TEXT,
|
||||||
PRIMARY KEY (model_type, file_path)
|
PRIMARY KEY (model_type, file_path)
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -496,6 +499,7 @@ class PersistentModelCache:
|
|||||||
"skip_metadata_refresh": "INTEGER DEFAULT 0",
|
"skip_metadata_refresh": "INTEGER DEFAULT 0",
|
||||||
# Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57).
|
# Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57).
|
||||||
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
|
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
|
||||||
|
"hash_status": "TEXT DEFAULT 'completed'",
|
||||||
}
|
}
|
||||||
|
|
||||||
for column, definition in required_columns.items():
|
for column, definition in required_columns.items():
|
||||||
@@ -570,6 +574,7 @@ class PersistentModelCache:
|
|||||||
1 if item.get("exclude") else 0,
|
1 if item.get("exclude") else 0,
|
||||||
1 if item.get("db_checked") else 0,
|
1 if item.get("db_checked") else 0,
|
||||||
float(item.get("last_checked_at") or 0.0),
|
float(item.get("last_checked_at") or 0.0),
|
||||||
|
item.get("hash_status", "completed"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _insert_model_sql(self) -> str:
|
def _insert_model_sql(self) -> str:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Services responsible for recipe metadata analysis."""
|
"""Services responsible for recipe metadata analysis."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
@@ -69,7 +70,9 @@ class RecipeAnalysisService:
|
|||||||
try:
|
try:
|
||||||
metadata = self._exif_utils.extract_image_metadata(temp_path)
|
metadata = self._exif_utils.extract_image_metadata(temp_path)
|
||||||
if not metadata:
|
if not metadata:
|
||||||
return AnalysisResult({"error": "No metadata found in this image", "loras": []})
|
return AnalysisResult(
|
||||||
|
{"error": "No metadata found in this image", "loras": []}
|
||||||
|
)
|
||||||
|
|
||||||
return await self._parse_metadata(
|
return await self._parse_metadata(
|
||||||
metadata,
|
metadata,
|
||||||
@@ -105,7 +108,9 @@ class RecipeAnalysisService:
|
|||||||
if civitai_match:
|
if civitai_match:
|
||||||
image_info = await civitai_client.get_image_info(civitai_match.group(1))
|
image_info = await civitai_client.get_image_info(civitai_match.group(1))
|
||||||
if not image_info:
|
if not image_info:
|
||||||
raise RecipeDownloadError("Failed to fetch image information from Civitai")
|
raise RecipeDownloadError(
|
||||||
|
"Failed to fetch image information from Civitai"
|
||||||
|
)
|
||||||
|
|
||||||
image_url = image_info.get("url")
|
image_url = image_info.get("url")
|
||||||
if not image_url:
|
if not image_url:
|
||||||
@@ -114,13 +119,15 @@ class RecipeAnalysisService:
|
|||||||
is_video = image_info.get("type") == "video"
|
is_video = image_info.get("type") == "video"
|
||||||
|
|
||||||
# Use optimized preview URLs if possible
|
# Use optimized preview URLs if possible
|
||||||
rewritten_url, _ = rewrite_preview_url(image_url, media_type=image_info.get("type"))
|
rewritten_url, _ = rewrite_preview_url(
|
||||||
|
image_url, media_type=image_info.get("type")
|
||||||
|
)
|
||||||
if rewritten_url:
|
if rewritten_url:
|
||||||
image_url = rewritten_url
|
image_url = rewritten_url
|
||||||
|
|
||||||
if is_video:
|
if is_video:
|
||||||
# Extract extension from URL
|
# Extract extension from URL
|
||||||
url_path = image_url.split('?')[0].split('#')[0]
|
url_path = image_url.split("?")[0].split("#")[0]
|
||||||
extension = os.path.splitext(url_path)[1].lower() or ".mp4"
|
extension = os.path.splitext(url_path)[1].lower() or ".mp4"
|
||||||
else:
|
else:
|
||||||
extension = ".jpg"
|
extension = ".jpg"
|
||||||
@@ -135,9 +142,17 @@ class RecipeAnalysisService:
|
|||||||
and isinstance(metadata["meta"], dict)
|
and isinstance(metadata["meta"], dict)
|
||||||
):
|
):
|
||||||
metadata = metadata["meta"]
|
metadata = metadata["meta"]
|
||||||
|
|
||||||
|
# Validate that metadata contains meaningful recipe fields
|
||||||
|
# If not, treat as None to trigger EXIF extraction from downloaded image
|
||||||
|
if isinstance(metadata, dict) and not self._has_recipe_fields(metadata):
|
||||||
|
self._logger.debug(
|
||||||
|
"Civitai API metadata lacks recipe fields, will extract from EXIF"
|
||||||
|
)
|
||||||
|
metadata = None
|
||||||
else:
|
else:
|
||||||
# Basic extension detection for non-Civitai URLs
|
# Basic extension detection for non-Civitai URLs
|
||||||
url_path = url.split('?')[0].split('#')[0]
|
url_path = url.split("?")[0].split("#")[0]
|
||||||
extension = os.path.splitext(url_path)[1].lower()
|
extension = os.path.splitext(url_path)[1].lower()
|
||||||
if extension in [".mp4", ".webm"]:
|
if extension in [".mp4", ".webm"]:
|
||||||
is_video = True
|
is_video = True
|
||||||
@@ -211,7 +226,9 @@ class RecipeAnalysisService:
|
|||||||
|
|
||||||
image_bytes = self._convert_tensor_to_png_bytes(latest_image)
|
image_bytes = self._convert_tensor_to_png_bytes(latest_image)
|
||||||
if image_bytes is None:
|
if image_bytes is None:
|
||||||
raise RecipeValidationError("Cannot handle this data shape from metadata registry")
|
raise RecipeValidationError(
|
||||||
|
"Cannot handle this data shape from metadata registry"
|
||||||
|
)
|
||||||
|
|
||||||
return AnalysisResult(
|
return AnalysisResult(
|
||||||
{
|
{
|
||||||
@@ -222,6 +239,22 @@ class RecipeAnalysisService:
|
|||||||
|
|
||||||
# Internal helpers -------------------------------------------------
|
# Internal helpers -------------------------------------------------
|
||||||
|
|
||||||
|
def _has_recipe_fields(self, metadata: dict[str, Any]) -> bool:
|
||||||
|
"""Check if metadata contains meaningful recipe-related fields."""
|
||||||
|
recipe_fields = {
|
||||||
|
"prompt",
|
||||||
|
"negative_prompt",
|
||||||
|
"resources",
|
||||||
|
"hashes",
|
||||||
|
"params",
|
||||||
|
"generationData",
|
||||||
|
"Workflow",
|
||||||
|
"prompt_type",
|
||||||
|
"positive",
|
||||||
|
"negative",
|
||||||
|
}
|
||||||
|
return any(field in metadata for field in recipe_fields)
|
||||||
|
|
||||||
async def _parse_metadata(
|
async def _parse_metadata(
|
||||||
self,
|
self,
|
||||||
metadata: dict[str, Any],
|
metadata: dict[str, Any],
|
||||||
@@ -234,7 +267,12 @@ class RecipeAnalysisService:
|
|||||||
) -> AnalysisResult:
|
) -> AnalysisResult:
|
||||||
parser = self._recipe_parser_factory.create_parser(metadata)
|
parser = self._recipe_parser_factory.create_parser(metadata)
|
||||||
if parser is None:
|
if parser is None:
|
||||||
payload = {"error": "No parser found for this image", "loras": []}
|
# Provide more specific error message based on metadata source
|
||||||
|
if not metadata:
|
||||||
|
error_msg = "This image does not contain any generation metadata (prompt, models, or parameters)"
|
||||||
|
else:
|
||||||
|
error_msg = "No parser found for this image"
|
||||||
|
payload = {"error": error_msg, "loras": []}
|
||||||
if include_image_base64 and image_path:
|
if include_image_base64 and image_path:
|
||||||
payload["image_base64"] = self._encode_file(image_path)
|
payload["image_base64"] = self._encode_file(image_path)
|
||||||
payload["is_video"] = is_video
|
payload["is_video"] = is_video
|
||||||
@@ -257,7 +295,9 @@ class RecipeAnalysisService:
|
|||||||
|
|
||||||
matching_recipes: list[str] = []
|
matching_recipes: list[str] = []
|
||||||
if fingerprint:
|
if fingerprint:
|
||||||
matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint)
|
matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(
|
||||||
|
fingerprint
|
||||||
|
)
|
||||||
result["matching_recipes"] = matching_recipes
|
result["matching_recipes"] = matching_recipes
|
||||||
|
|
||||||
return AnalysisResult(result)
|
return AnalysisResult(result)
|
||||||
@@ -269,7 +309,10 @@ class RecipeAnalysisService:
|
|||||||
raise RecipeDownloadError(f"Failed to download image from URL: {result}")
|
raise RecipeDownloadError(f"Failed to download image from URL: {result}")
|
||||||
|
|
||||||
def _metadata_not_found_response(self, path: str) -> AnalysisResult:
|
def _metadata_not_found_response(self, path: str) -> AnalysisResult:
|
||||||
payload: dict[str, Any] = {"error": "No metadata found in this image", "loras": []}
|
payload: dict[str, Any] = {
|
||||||
|
"error": "No metadata found in this image",
|
||||||
|
"loras": [],
|
||||||
|
}
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
payload["image_base64"] = self._encode_file(path)
|
payload["image_base64"] = self._encode_file(path)
|
||||||
return AnalysisResult(payload)
|
return AnalysisResult(payload)
|
||||||
@@ -305,7 +348,9 @@ class RecipeAnalysisService:
|
|||||||
|
|
||||||
if hasattr(tensor_image, "shape"):
|
if hasattr(tensor_image, "shape"):
|
||||||
self._logger.debug(
|
self._logger.debug(
|
||||||
"Tensor shape: %s, dtype: %s", tensor_image.shape, getattr(tensor_image, "dtype", None)
|
"Tensor shape: %s, dtype: %s",
|
||||||
|
tensor_image.shape,
|
||||||
|
getattr(tensor_image, "dtype", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
import torch # type: ignore[import-not-found]
|
import torch # type: ignore[import-not-found]
|
||||||
|
|||||||
@@ -148,8 +148,8 @@ def get_checkpoint_info_absolute(checkpoint_name):
|
|||||||
# Format the stored path as ComfyUI-style name
|
# Format the stored path as ComfyUI-style name
|
||||||
formatted_name = _format_model_name_for_comfyui(file_path, model_roots)
|
formatted_name = _format_model_name_for_comfyui(file_path, model_roots)
|
||||||
|
|
||||||
# Match by formatted name
|
# Match by formatted name (normalize separators for robust comparison)
|
||||||
if formatted_name == normalized_name or formatted_name == checkpoint_name:
|
if formatted_name.replace(os.sep, "/") == normalized_name or formatted_name == checkpoint_name:
|
||||||
return file_path, item
|
return file_path, item
|
||||||
|
|
||||||
# Also try matching by basename only (for backward compatibility)
|
# Also try matching by basename only (for backward compatibility)
|
||||||
@@ -200,19 +200,22 @@ def _format_model_name_for_comfyui(file_path: str, model_roots: list) -> str:
|
|||||||
Returns:
|
Returns:
|
||||||
ComfyUI-style model name with relative path and extension
|
ComfyUI-style model name with relative path and extension
|
||||||
"""
|
"""
|
||||||
# Normalize path separators
|
|
||||||
normalized_path = file_path.replace(os.sep, "/")
|
|
||||||
|
|
||||||
# Find the matching root and get relative path
|
# Find the matching root and get relative path
|
||||||
for root in model_roots:
|
for root in model_roots:
|
||||||
normalized_root = root.replace(os.sep, "/")
|
try:
|
||||||
# Ensure root ends with / for proper matching
|
# Normalize paths for comparison
|
||||||
if not normalized_root.endswith("/"):
|
norm_file = os.path.normcase(os.path.abspath(file_path))
|
||||||
normalized_root += "/"
|
norm_root = os.path.normcase(os.path.abspath(root))
|
||||||
|
|
||||||
if normalized_path.startswith(normalized_root):
|
# Add trailing separator for prefix check
|
||||||
rel_path = normalized_path[len(normalized_root) :]
|
if not norm_root.endswith(os.sep):
|
||||||
return rel_path
|
norm_root += os.sep
|
||||||
|
|
||||||
|
if norm_file.startswith(norm_root):
|
||||||
|
# Use os.path.relpath to get relative path with OS-native separator
|
||||||
|
return os.path.relpath(file_path, root)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
continue
|
||||||
|
|
||||||
# If no root matches, just return the basename with extension
|
# If no root matches, just return the basename with extension
|
||||||
return os.path.basename(file_path)
|
return os.path.basename(file_path)
|
||||||
|
|||||||
@@ -104,6 +104,14 @@ export class BatchImportManager {
|
|||||||
|
|
||||||
// Clean up any existing connections
|
// Clean up any existing connections
|
||||||
this.cleanupConnections();
|
this.cleanupConnections();
|
||||||
|
|
||||||
|
// Focus on the URL input field for better UX
|
||||||
|
setTimeout(() => {
|
||||||
|
const urlInput = document.getElementById('batchUrlInput');
|
||||||
|
if (urlInput) {
|
||||||
|
urlInput.focus();
|
||||||
|
}
|
||||||
|
}, 100);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -194,6 +194,7 @@ class TestCacheHealthMonitor:
|
|||||||
'preview_nsfw_level': 0,
|
'preview_nsfw_level': 0,
|
||||||
'notes': '',
|
'notes': '',
|
||||||
'usage_tips': '',
|
'usage_tips': '',
|
||||||
|
'hash_status': 'completed',
|
||||||
}
|
}
|
||||||
incomplete_entry = {
|
incomplete_entry = {
|
||||||
'file_path': '/models/test2.safetensors',
|
'file_path': '/models/test2.safetensors',
|
||||||
|
|||||||
Reference in New Issue
Block a user