fix: isolate extra unet paths from checkpoints to prevent type misclassification

Refactor _prepare_checkpoint_paths() to return a tuple instead of mutating
instance variables as a side effect. This prevents extra unet paths from
being incorrectly classified as checkpoints when extra paths are processed.

- Changed return type from List[str] to Tuple[List[str], List[str], List[str]]
  (all_paths, checkpoint_roots, unet_roots)
- Updated _init_checkpoint_paths() and _apply_library_paths() callers
- Fixed extra paths processing to properly isolate main and extra roots
- Updated test_checkpoint_path_overlap.py tests for new API

This ensures models in extra unet paths are correctly identified as the
diffusion_model type and no longer appear in the checkpoints list.
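
For illustration, a minimal standalone sketch of the new contract (directory
names here are hypothetical; the real method also dedupes paths, resolves
symlinks, and warns about overlapping roots):

    from typing import Iterable, List, Tuple

    def prepare_checkpoint_paths(
        checkpoint_paths: Iterable[str], unet_paths: Iterable[str]
    ) -> Tuple[List[str], List[str], List[str]]:
        # Simplified stand-in for Config._prepare_checkpoint_paths: no self.* is touched.
        checkpoint_roots = list(dict.fromkeys(checkpoint_paths))
        unet_roots = [p for p in dict.fromkeys(unet_paths) if p not in checkpoint_roots]
        return checkpoint_roots + unet_roots, checkpoint_roots, unet_roots

    # Main roots and extra roots are computed independently, so processing the
    # extra paths can no longer overwrite the main checkpoint/unet roots.
    base, ckpt_roots, unet_roots = prepare_checkpoint_paths(
        ["/models/checkpoints"], ["/models/unet"]
    )
    _, extra_ckpt_roots, extra_unet_roots = prepare_checkpoint_paths(
        ["/extra/checkpoints"], ["/extra/unet"]
    )
    assert unet_roots == ["/models/unet"] and extra_unet_roots == ["/extra/unet"]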
Will Miao
2026-03-17 22:03:57 +08:00
parent 70c150bd80
commit 2dae4c1291
8 changed files with 838 additions and 124 deletions

View File

@@ -1,6 +1,8 @@
 try: # pragma: no cover - import fallback for pytest collection
     from .py.lora_manager import LoraManager
     from .py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM
+    from .py.nodes.checkpoint_loader import CheckpointLoaderLM
+    from .py.nodes.unet_loader import UNETLoaderLM
     from .py.nodes.trigger_word_toggle import TriggerWordToggleLM
     from .py.nodes.prompt import PromptLM
     from .py.nodes.text import TextLM
@@ -27,12 +29,12 @@ except (
     PromptLM = importlib.import_module("py.nodes.prompt").PromptLM
     TextLM = importlib.import_module("py.nodes.text").TextLM
     LoraManager = importlib.import_module("py.lora_manager").LoraManager
-    LoraLoaderLM = importlib.import_module(
-        "py.nodes.lora_loader"
-    ).LoraLoaderLM
-    LoraTextLoaderLM = importlib.import_module(
-        "py.nodes.lora_loader"
-    ).LoraTextLoaderLM
+    LoraLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraLoaderLM
+    LoraTextLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraTextLoaderLM
+    CheckpointLoaderLM = importlib.import_module(
+        "py.nodes.checkpoint_loader"
+    ).CheckpointLoaderLM
+    UNETLoaderLM = importlib.import_module("py.nodes.unet_loader").UNETLoaderLM
     TriggerWordToggleLM = importlib.import_module(
         "py.nodes.trigger_word_toggle"
     ).TriggerWordToggleLM
@@ -49,9 +51,7 @@ except (
     LoraRandomizerLM = importlib.import_module(
         "py.nodes.lora_randomizer"
     ).LoraRandomizerLM
-    LoraCyclerLM = importlib.import_module(
-        "py.nodes.lora_cycler"
-    ).LoraCyclerLM
+    LoraCyclerLM = importlib.import_module("py.nodes.lora_cycler").LoraCyclerLM

 init_metadata_collector = importlib.import_module("py.metadata_collector").init

 NODE_CLASS_MAPPINGS = {
@@ -59,6 +59,8 @@ NODE_CLASS_MAPPINGS = {
     TextLM.NAME: TextLM,
     LoraLoaderLM.NAME: LoraLoaderLM,
     LoraTextLoaderLM.NAME: LoraTextLoaderLM,
+    CheckpointLoaderLM.NAME: CheckpointLoaderLM,
+    UNETLoaderLM.NAME: UNETLoaderLM,
     TriggerWordToggleLM.NAME: TriggerWordToggleLM,
     LoraStackerLM.NAME: LoraStackerLM,
     SaveImageLM.NAME: SaveImageLM,

View File

@@ -707,7 +707,13 @@ class Config:
     def _prepare_checkpoint_paths(
         self, checkpoint_paths: Iterable[str], unet_paths: Iterable[str]
-    ) -> List[str]:
+    ) -> Tuple[List[str], List[str], List[str]]:
+        """Prepare checkpoint paths and return (all_roots, checkpoint_roots, unet_roots).
+
+        Returns:
+            Tuple of (all_unique_paths, checkpoint_only_paths, unet_only_paths)
+
+        This method does NOT modify instance variables - callers must set them.
+        """
         checkpoint_map = self._dedupe_existing_paths(checkpoint_paths)
         unet_map = self._dedupe_existing_paths(unet_paths)
@@ -737,8 +743,8 @@ class Config:
         checkpoint_values = set(checkpoint_map.values())
         unet_values = set(unet_map.values())

-        self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_values]
-        self.unet_roots = [p for p in unique_paths if p in unet_values]
+        checkpoint_roots = [p for p in unique_paths if p in checkpoint_values]
+        unet_roots = [p for p in unique_paths if p in unet_values]

         for original_path in unique_paths:
             real_path = os.path.normpath(os.path.realpath(original_path)).replace(
@@ -747,7 +753,7 @@ class Config:
             if real_path != original_path:
                 self.add_path_mapping(original_path, real_path)

-        return unique_paths
+        return unique_paths, checkpoint_roots, unet_roots

     def _prepare_embedding_paths(self, raw_paths: Iterable[str]) -> List[str]:
         path_map = self._dedupe_existing_paths(raw_paths)
@@ -776,9 +782,11 @@ class Config:
         embedding_paths = folder_paths.get("embeddings", []) or []

         self.loras_roots = self._prepare_lora_paths(lora_paths)
-        self.base_models_roots = self._prepare_checkpoint_paths(
-            checkpoint_paths, unet_paths
-        )
+        (
+            self.base_models_roots,
+            self.checkpoints_roots,
+            self.unet_roots,
+        ) = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths)
         self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)

         # Process extra paths (only for LoRA Manager, not shared with ComfyUI)
@@ -789,18 +797,11 @@ class Config:
         extra_embedding_paths = extra_paths.get("embeddings", []) or []

         self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths)
-        # Save main paths before processing extra paths (_prepare_checkpoint_paths overwrites them)
-        saved_checkpoints_roots = self.checkpoints_roots
-        saved_unet_roots = self.unet_roots
-        self.extra_checkpoints_roots = self._prepare_checkpoint_paths(
-            extra_checkpoint_paths, extra_unet_paths
-        )
-        self.extra_unet_roots = (
-            self.unet_roots if self.unet_roots is not None else []
-        )  # unet_roots was set by _prepare_checkpoint_paths
-        # Restore main paths
-        self.checkpoints_roots = saved_checkpoints_roots
-        self.unet_roots = saved_unet_roots
+        (
+            _,
+            self.extra_checkpoints_roots,
+            self.extra_unet_roots,
+        ) = self._prepare_checkpoint_paths(extra_checkpoint_paths, extra_unet_paths)
         self.extra_embeddings_roots = self._prepare_embedding_paths(
             extra_embedding_paths
         )
@@ -857,9 +858,11 @@ class Config:
         try:
             raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
             raw_unet_paths = folder_paths.get_folder_paths("unet")
-            unique_paths = self._prepare_checkpoint_paths(
-                raw_checkpoint_paths, raw_unet_paths
-            )
+            (
+                unique_paths,
+                self.checkpoints_roots,
+                self.unet_roots,
+            ) = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths)
             logger.info(
                 "Found checkpoint roots:"

View File

@@ -0,0 +1,184 @@
import logging
import os
from typing import List, Tuple
import comfy.sd
import folder_paths
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
logger = logging.getLogger(__name__)
class CheckpointLoaderLM:
"""Checkpoint Loader with support for extra folder paths
Loads checkpoints from both standard ComfyUI folders and LoRA Manager's
extra folder paths, providing a unified interface for checkpoint loading.
"""
NAME = "CheckpointLoaderLM"
CATEGORY = "Lora Manager/loaders"
@classmethod
def INPUT_TYPES(s):
# Get list of checkpoint names from scanner (includes extra folder paths)
checkpoint_names = s._get_checkpoint_names()
return {
"required": {
"ckpt_name": (
checkpoint_names,
{"tooltip": "The name of the checkpoint (model) to load."},
),
}
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
RETURN_NAMES = ("MODEL", "CLIP", "VAE")
OUTPUT_TOOLTIPS = (
"The model used for denoising latents.",
"The CLIP model used for encoding text prompts.",
"The VAE model used for encoding and decoding images to and from latent space.",
)
FUNCTION = "load_checkpoint"
@classmethod
def _get_checkpoint_names(cls) -> List[str]:
"""Get list of checkpoint names from scanner cache in ComfyUI format (relative path with extension)"""
try:
from ..services.service_registry import ServiceRegistry
import asyncio
async def _get_names():
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
# Get all model roots for calculating relative paths
model_roots = scanner.get_model_roots()
# Filter only checkpoint type (not diffusion_model) and format names
names = []
for item in cache.raw_data:
if item.get("sub_type") == "checkpoint":
file_path = item.get("file_path", "")
if file_path:
# Format as ComfyUI-style: "folder/model_name.ext"
formatted_name = _format_model_name_for_comfyui(
file_path, model_roots
)
if formatted_name:
names.append(formatted_name)
return sorted(names)
try:
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(_get_names())
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()
except RuntimeError:
return asyncio.run(_get_names())
except Exception as e:
logger.error(f"Error getting checkpoint names: {e}")
return []
def load_checkpoint(self, ckpt_name: str) -> Tuple:
"""Load a checkpoint by name, supporting extra folder paths
Args:
ckpt_name: The name of the checkpoint to load (format: "folder/model_name.ext")
Returns:
Tuple of (MODEL, CLIP, VAE)
"""
# Get absolute path from cache using ComfyUI-style name
ckpt_path, metadata = get_checkpoint_info_absolute(ckpt_name)
if metadata is None:
raise FileNotFoundError(
f"Checkpoint '{ckpt_name}' not found in LoRA Manager cache. "
"Make sure the checkpoint is indexed and try again."
)
# Check if it's a GGUF model
if ckpt_path.endswith(".gguf"):
return self._load_gguf_checkpoint(ckpt_path, ckpt_name)
# Load regular checkpoint using ComfyUI's API
logger.info(f"Loading checkpoint from: {ckpt_path}")
out = comfy.sd.load_checkpoint_guess_config(
ckpt_path,
output_vae=True,
output_clip=True,
embedding_directory=folder_paths.get_folder_paths("embeddings"),
)
return out[:3]
def _load_gguf_checkpoint(self, ckpt_path: str, ckpt_name: str) -> Tuple:
"""Load a GGUF format checkpoint
Args:
ckpt_path: Absolute path to the GGUF file
ckpt_name: Name of the checkpoint for error messages
Returns:
Tuple of (MODEL, CLIP, VAE) - CLIP and VAE may be None for GGUF
"""
try:
# Try to import ComfyUI-GGUF modules
from custom_nodes.ComfyUI_GGUF.loader import gguf_sd_loader
from custom_nodes.ComfyUI_GGUF.ops import GGMLOps
from custom_nodes.ComfyUI_GGUF.nodes import GGUFModelPatcher
except ImportError:
raise RuntimeError(
f"Cannot load GGUF model '{ckpt_name}'. "
"ComfyUI-GGUF is not installed. "
"Please install ComfyUI-GGUF from https://github.com/city96/ComfyUI-GGUF "
"to load GGUF format models."
)
logger.info(f"Loading GGUF checkpoint from: {ckpt_path}")
try:
# Load GGUF state dict
sd, extra = gguf_sd_loader(ckpt_path)
# Prepare kwargs for metadata if supported
kwargs = {}
import inspect
valid_params = inspect.signature(
comfy.sd.load_diffusion_model_state_dict
).parameters
if "metadata" in valid_params:
kwargs["metadata"] = extra.get("metadata", {})
# Load the model
model = comfy.sd.load_diffusion_model_state_dict(
sd, model_options={"custom_operations": GGMLOps()}, **kwargs
)
if model is None:
raise RuntimeError(
f"Could not detect model type for GGUF checkpoint: {ckpt_path}"
)
# Wrap with GGUFModelPatcher
model = GGUFModelPatcher.clone(model)
# GGUF checkpoints typically don't include CLIP/VAE
return (model, None, None)
except Exception as e:
logger.error(f"Error loading GGUF checkpoint '{ckpt_name}': {e}")
raise RuntimeError(
f"Failed to load GGUF checkpoint '{ckpt_name}': {str(e)}"
)

py/nodes/unet_loader.py (new file, 205 lines)
View File

@@ -0,0 +1,205 @@
import logging
import os
from typing import List, Tuple
import torch
import comfy.sd
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
logger = logging.getLogger(__name__)
class UNETLoaderLM:
"""UNET Loader with support for extra folder paths
Loads diffusion models/UNets from both standard ComfyUI folders and LoRA Manager's
extra folder paths, providing a unified interface for UNET loading.
Supports both regular diffusion models and GGUF format models.
"""
NAME = "UNETLoaderLM"
CATEGORY = "Lora Manager/loaders"
@classmethod
def INPUT_TYPES(s):
# Get list of unet names from scanner (includes extra folder paths)
unet_names = s._get_unet_names()
return {
"required": {
"unet_name": (
unet_names,
{"tooltip": "The name of the diffusion model to load."},
),
"weight_dtype": (
["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],
{"tooltip": "The dtype to use for the model weights."},
),
}
}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("MODEL",)
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",)
FUNCTION = "load_unet"
@classmethod
def _get_unet_names(cls) -> List[str]:
"""Get list of diffusion model names from scanner cache in ComfyUI format (relative path with extension)"""
try:
from ..services.service_registry import ServiceRegistry
import asyncio
async def _get_names():
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
# Get all model roots for calculating relative paths
model_roots = scanner.get_model_roots()
# Filter only diffusion_model type and format names
names = []
for item in cache.raw_data:
if item.get("sub_type") == "diffusion_model":
file_path = item.get("file_path", "")
if file_path:
# Format as ComfyUI-style: "folder/model_name.ext"
formatted_name = _format_model_name_for_comfyui(
file_path, model_roots
)
if formatted_name:
names.append(formatted_name)
return sorted(names)
try:
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(_get_names())
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()
except RuntimeError:
return asyncio.run(_get_names())
except Exception as e:
logger.error(f"Error getting unet names: {e}")
return []
def load_unet(self, unet_name: str, weight_dtype: str) -> Tuple:
"""Load a diffusion model by name, supporting extra folder paths
Args:
unet_name: The name of the diffusion model to load (format: "folder/model_name.ext")
weight_dtype: The dtype to use for model weights
Returns:
Tuple of (MODEL,)
"""
# Get absolute path from cache using ComfyUI-style name
unet_path, metadata = get_checkpoint_info_absolute(unet_name)
if metadata is None:
raise FileNotFoundError(
f"Diffusion model '{unet_name}' not found in LoRA Manager cache. "
"Make sure the model is indexed and try again."
)
# Check if it's a GGUF model
if unet_path.endswith(".gguf"):
return self._load_gguf_unet(unet_path, unet_name, weight_dtype)
# Load regular diffusion model using ComfyUI's API
logger.info(f"Loading diffusion model from: {unet_path}")
# Build model options based on weight_dtype
model_options = {}
if weight_dtype == "fp8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e4m3fn_fast":
model_options["dtype"] = torch.float8_e4m3fn
model_options["fp8_optimizations"] = True
elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)
def _load_gguf_unet(
self, unet_path: str, unet_name: str, weight_dtype: str
) -> Tuple:
"""Load a GGUF format diffusion model
Args:
unet_path: Absolute path to the GGUF file
unet_name: Name of the model for error messages
weight_dtype: The dtype to use for model weights
Returns:
Tuple of (MODEL,)
"""
try:
# Try to import ComfyUI-GGUF modules
from custom_nodes.ComfyUI_GGUF.loader import gguf_sd_loader
from custom_nodes.ComfyUI_GGUF.ops import GGMLOps
from custom_nodes.ComfyUI_GGUF.nodes import GGUFModelPatcher
except ImportError:
raise RuntimeError(
f"Cannot load GGUF model '{unet_name}'. "
"ComfyUI-GGUF is not installed. "
"Please install ComfyUI-GGUF from https://github.com/city96/ComfyUI-GGUF "
"to load GGUF format models."
)
logger.info(f"Loading GGUF diffusion model from: {unet_path}")
try:
# Load GGUF state dict
sd, extra = gguf_sd_loader(unet_path)
# Prepare kwargs for metadata if supported
kwargs = {}
import inspect
valid_params = inspect.signature(
comfy.sd.load_diffusion_model_state_dict
).parameters
if "metadata" in valid_params:
kwargs["metadata"] = extra.get("metadata", {})
# Setup custom operations with GGUF support
ops = GGMLOps()
# Handle weight_dtype for GGUF models
if weight_dtype in ("default", None):
ops.Linear.dequant_dtype = None
elif weight_dtype in ["target"]:
ops.Linear.dequant_dtype = weight_dtype
else:
ops.Linear.dequant_dtype = getattr(torch, weight_dtype, None)
# Load the model
model = comfy.sd.load_diffusion_model_state_dict(
sd, model_options={"custom_operations": ops}, **kwargs
)
if model is None:
raise RuntimeError(
f"Could not detect model type for GGUF diffusion model: {unet_path}"
)
# Wrap with GGUFModelPatcher
model = GGUFModelPatcher.clone(model)
return (model,)
except Exception as e:
logger.error(f"Error loading GGUF diffusion model '{unet_name}': {e}")
raise RuntimeError(
f"Failed to load GGUF diffusion model '{unet_name}': {str(e)}"
)

View File

@@ -13,22 +13,35 @@ from .model_hash_index import ModelHashIndex
 logger = logging.getLogger(__name__)


 class CheckpointScanner(ModelScanner):
     """Service for scanning and managing checkpoint files"""

     def __init__(self):
         # Define supported file extensions
-        file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
+        file_extensions = {
+            ".ckpt",
+            ".pt",
+            ".pt2",
+            ".bin",
+            ".pth",
+            ".safetensors",
+            ".pkl",
+            ".sft",
+            ".gguf",
+        }
         super().__init__(
             model_type="checkpoint",
             model_class=CheckpointMetadata,
             file_extensions=file_extensions,
-            hash_index=ModelHashIndex()
+            hash_index=ModelHashIndex(),
         )

-    async def _create_default_metadata(self, file_path: str) -> Optional[CheckpointMetadata]:
+    async def _create_default_metadata(
+        self, file_path: str
+    ) -> Optional[CheckpointMetadata]:
         """Create default metadata for checkpoint without calculating hash (lazy hash).

         Checkpoints are typically large (10GB+), so we skip hash calculation during initial
         scanning to improve startup performance. Hash will be calculated on-demand when
         fetching metadata from Civitai.
@@ -38,13 +51,13 @@ class CheckpointScanner(ModelScanner):
             if not os.path.exists(real_path):
                 logger.error(f"File not found: {file_path}")
                 return None

             base_name = os.path.splitext(os.path.basename(file_path))[0]
             dir_path = os.path.dirname(file_path)

             # Find preview image
             preview_url = find_preview_file(base_name, dir_path)

             # Create metadata WITHOUT calculating hash
             metadata = CheckpointMetadata(
                 file_name=base_name,
@@ -59,70 +72,76 @@ class CheckpointScanner(ModelScanner):
modelDescription="", modelDescription="",
sub_type="checkpoint", sub_type="checkpoint",
from_civitai=False, # Mark as local model since no hash yet from_civitai=False, # Mark as local model since no hash yet
hash_status="pending" # Mark hash as pending hash_status="pending", # Mark hash as pending
) )
# Save the created metadata # Save the created metadata
logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}") logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}")
await MetadataManager.save_metadata(file_path, metadata) await MetadataManager.save_metadata(file_path, metadata)
return metadata return metadata
except Exception as e: except Exception as e:
logger.error(f"Error creating default checkpoint metadata for {file_path}: {e}") logger.error(
f"Error creating default checkpoint metadata for {file_path}: {e}"
)
return None return None
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]: async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
"""Calculate hash for a checkpoint on-demand. """Calculate hash for a checkpoint on-demand.
Args: Args:
file_path: Path to the model file file_path: Path to the model file
Returns: Returns:
SHA256 hash string, or None if calculation failed SHA256 hash string, or None if calculation failed
""" """
from ..utils.file_utils import calculate_sha256 from ..utils.file_utils import calculate_sha256
try: try:
real_path = os.path.realpath(file_path) real_path = os.path.realpath(file_path)
if not os.path.exists(real_path): if not os.path.exists(real_path):
logger.error(f"File not found for hash calculation: {file_path}") logger.error(f"File not found for hash calculation: {file_path}")
return None return None
# Load current metadata # Load current metadata
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class) metadata, _ = await MetadataManager.load_metadata(
file_path, self.model_class
)
if metadata is None: if metadata is None:
logger.error(f"No metadata found for {file_path}") logger.error(f"No metadata found for {file_path}")
return None return None
# Check if hash is already calculated # Check if hash is already calculated
if metadata.hash_status == "completed" and metadata.sha256: if metadata.hash_status == "completed" and metadata.sha256:
return metadata.sha256 return metadata.sha256
# Update status to calculating # Update status to calculating
metadata.hash_status = "calculating" metadata.hash_status = "calculating"
await MetadataManager.save_metadata(file_path, metadata) await MetadataManager.save_metadata(file_path, metadata)
# Calculate hash # Calculate hash
logger.info(f"Calculating hash for checkpoint: {file_path}") logger.info(f"Calculating hash for checkpoint: {file_path}")
sha256 = await calculate_sha256(real_path) sha256 = await calculate_sha256(real_path)
# Update metadata with hash # Update metadata with hash
metadata.sha256 = sha256 metadata.sha256 = sha256
metadata.hash_status = "completed" metadata.hash_status = "completed"
await MetadataManager.save_metadata(file_path, metadata) await MetadataManager.save_metadata(file_path, metadata)
# Update hash index # Update hash index
self._hash_index.add_entry(sha256.lower(), file_path) self._hash_index.add_entry(sha256.lower(), file_path)
logger.info(f"Hash calculated for checkpoint: {file_path}") logger.info(f"Hash calculated for checkpoint: {file_path}")
return sha256 return sha256
except Exception as e: except Exception as e:
logger.error(f"Error calculating hash for {file_path}: {e}") logger.error(f"Error calculating hash for {file_path}: {e}")
# Update status to failed # Update status to failed
try: try:
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class) metadata, _ = await MetadataManager.load_metadata(
file_path, self.model_class
)
if metadata: if metadata:
metadata.hash_status = "failed" metadata.hash_status = "failed"
await MetadataManager.save_metadata(file_path, metadata) await MetadataManager.save_metadata(file_path, metadata)
@@ -130,43 +149,46 @@ class CheckpointScanner(ModelScanner):
                 pass
             return None

-    async def calculate_all_pending_hashes(self, progress_callback=None) -> Dict[str, int]:
+    async def calculate_all_pending_hashes(
+        self, progress_callback=None
+    ) -> Dict[str, int]:
         """Calculate hashes for all checkpoints with pending hash status.

         If cache is not initialized, scans filesystem directly for metadata files
         with hash_status != 'completed'.

         Args:
             progress_callback: Optional callback(progress, total, current_file)

         Returns:
             Dict with 'completed', 'failed', 'total' counts
         """
         # Try to get from cache first
         cache = await self.get_cached_data()

         if cache and cache.raw_data:
             # Use cache if available
             pending_models = [
-                item for item in cache.raw_data
-                if item.get('hash_status') != 'completed' or not item.get('sha256')
+                item
+                for item in cache.raw_data
+                if item.get("hash_status") != "completed" or not item.get("sha256")
             ]
         else:
             # Cache not initialized, scan filesystem directly
             pending_models = await self._find_pending_models_from_filesystem()

         if not pending_models:
-            return {'completed': 0, 'failed': 0, 'total': 0}
+            return {"completed": 0, "failed": 0, "total": 0}

         total = len(pending_models)
         completed = 0
         failed = 0

         for i, model_data in enumerate(pending_models):
-            file_path = model_data.get('file_path')
+            file_path = model_data.get("file_path")
             if not file_path:
                 continue

             try:
                 sha256 = await self.calculate_hash_for_model(file_path)
                 if sha256:
@@ -176,77 +198,102 @@ class CheckpointScanner(ModelScanner):
                 except Exception as e:
                     logger.error(f"Error calculating hash for {file_path}: {e}")
                     failed += 1

             if progress_callback:
                 try:
                     await progress_callback(i + 1, total, file_path)
                 except Exception:
                     pass

-        return {
-            'completed': completed,
-            'failed': failed,
-            'total': total
-        }
+        return {"completed": completed, "failed": failed, "total": total}

     async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]:
         """Scan filesystem for checkpoint metadata files with pending hash status."""
         pending_models = []

         for root_path in self.get_model_roots():
             if not os.path.exists(root_path):
                 continue

             for dirpath, _dirnames, filenames in os.walk(root_path):
                 for filename in filenames:
-                    if not filename.endswith('.metadata.json'):
+                    if not filename.endswith(".metadata.json"):
                         continue

                     metadata_path = os.path.join(dirpath, filename)
                     try:
-                        with open(metadata_path, 'r', encoding='utf-8') as f:
+                        with open(metadata_path, "r", encoding="utf-8") as f:
                             data = json.load(f)

                         # Check if hash is pending
-                        hash_status = data.get('hash_status', 'completed')
-                        sha256 = data.get('sha256', '')
+                        hash_status = data.get("hash_status", "completed")
+                        sha256 = data.get("sha256", "")

-                        if hash_status != 'completed' or not sha256:
+                        if hash_status != "completed" or not sha256:
                             # Find corresponding model file
-                            model_name = filename.replace('.metadata.json', '')
+                            model_name = filename.replace(".metadata.json", "")
                             model_path = None

                             # Look for model file with matching name
                             for ext in self.file_extensions:
                                 potential_path = os.path.join(dirpath, model_name + ext)
                                 if os.path.exists(potential_path):
                                     model_path = potential_path
                                     break

                             if model_path:
-                                pending_models.append({
-                                    'file_path': model_path.replace(os.sep, '/'),
-                                    'hash_status': hash_status,
-                                    'sha256': sha256,
-                                    **{k: v for k, v in data.items() if k not in ['file_path', 'hash_status', 'sha256']}
-                                })
+                                pending_models.append(
+                                    {
+                                        "file_path": model_path.replace(os.sep, "/"),
+                                        "hash_status": hash_status,
+                                        "sha256": sha256,
+                                        **{
+                                            k: v
+                                            for k, v in data.items()
+                                            if k
+                                            not in [
+                                                "file_path",
+                                                "hash_status",
+                                                "sha256",
+                                            ]
+                                        },
+                                    }
+                                )
                     except (json.JSONDecodeError, Exception) as e:
-                        logger.debug(f"Error reading metadata file {metadata_path}: {e}")
+                        logger.debug(
+                            f"Error reading metadata file {metadata_path}: {e}"
+                        )
                         continue

         return pending_models

     def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
-        """Resolve the sub-type based on the root path."""
+        """Resolve the sub-type based on the root path.
+
+        Checks both standard ComfyUI paths and LoRA Manager's extra folder paths.
+        """
         if not root_path:
             return None
+        # Check standard ComfyUI checkpoint paths
         if config.checkpoints_roots and root_path in config.checkpoints_roots:
             return "checkpoint"
+        # Check extra checkpoint paths
+        if (
+            config.extra_checkpoints_roots
+            and root_path in config.extra_checkpoints_roots
+        ):
+            return "checkpoint"
+        # Check standard ComfyUI unet paths
         if config.unet_roots and root_path in config.unet_roots:
             return "diffusion_model"
+        # Check extra unet paths
+        if config.extra_unet_roots and root_path in config.extra_unet_roots:
+            return "diffusion_model"
         return None

     def adjust_metadata(self, metadata, file_path, root_path):
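
To make the resolution order concrete, a minimal standalone sketch of the logic
above (the config object and paths are hypothetical stand-ins):

    def resolve_sub_type(root_path, checkpoints_roots, extra_checkpoints_roots,
                         unet_roots, extra_unet_roots):
        # Same precedence as CheckpointScanner._resolve_sub_type: checkpoint roots
        # (standard, then extra) win over unet roots (standard, then extra).
        if not root_path:
            return None
        if root_path in checkpoints_roots or root_path in extra_checkpoints_roots:
            return "checkpoint"
        if root_path in unet_roots or root_path in extra_unet_roots:
            return "diffusion_model"
        return None

    # With the config fix, an extra unet root now resolves to "diffusion_model":
    assert resolve_sub_type(
        "/extra/unet", ["/models/checkpoints"], [], ["/models/unet"], ["/extra/unet"]
    ) == "diffusion_model"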

View File

@@ -112,6 +112,112 @@ def get_lora_info_absolute(lora_name):
    return asyncio.run(_get_lora_info_absolute_async())
def get_checkpoint_info_absolute(checkpoint_name):
"""Get the absolute checkpoint path and metadata from cache
Supports ComfyUI-style model names (e.g., "folder/model_name.ext")
Args:
checkpoint_name: The model name, can be:
- ComfyUI format: "folder/model_name.safetensors"
- Simple name: "model_name"
Returns:
tuple: (absolute_path, metadata) where absolute_path is the full
file system path to the checkpoint file, or original checkpoint_name if not found,
metadata is the full model metadata dict or None
"""
async def _get_checkpoint_info_absolute_async():
from ..services.service_registry import ServiceRegistry
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
# Get model roots for matching
model_roots = scanner.get_model_roots()
# Normalize the checkpoint name
normalized_name = checkpoint_name.replace(os.sep, "/")
for item in cache.raw_data:
file_path = item.get("file_path", "")
if not file_path:
continue
# Format the stored path as ComfyUI-style name
formatted_name = _format_model_name_for_comfyui(file_path, model_roots)
# Match by formatted name
if formatted_name == normalized_name or formatted_name == checkpoint_name:
return file_path, item
# Also try matching by basename only (for backward compatibility)
file_name = item.get("file_name", "")
if (
file_name == checkpoint_name
or file_name == os.path.splitext(normalized_name)[0]
):
return file_path, item
return checkpoint_name, None
try:
# Check if we're already in an event loop
loop = asyncio.get_running_loop()
# If we're in a running loop, we need to use a different approach
# Create a new thread to run the async code
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
_get_checkpoint_info_absolute_async()
)
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()
except RuntimeError:
# No event loop is running, we can use asyncio.run()
return asyncio.run(_get_checkpoint_info_absolute_async())
def _format_model_name_for_comfyui(file_path: str, model_roots: list) -> str:
"""Format file path to ComfyUI-style model name (relative path with extension)
Example: /path/to/checkpoints/Illustrious/model.safetensors -> Illustrious/model.safetensors
Args:
file_path: Absolute path to the model file
model_roots: List of model root directories
Returns:
ComfyUI-style model name with relative path and extension
"""
# Normalize path separators
normalized_path = file_path.replace(os.sep, "/")
# Find the matching root and get relative path
for root in model_roots:
normalized_root = root.replace(os.sep, "/")
# Ensure root ends with / for proper matching
if not normalized_root.endswith("/"):
normalized_root += "/"
if normalized_path.startswith(normalized_root):
rel_path = normalized_path[len(normalized_root) :]
return rel_path
# If no root matches, just return the basename with extension
return os.path.basename(file_path)
def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool:
    """
    Check if text matches pattern using fuzzy matching.
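
A worked example of the name mapping implemented by _format_model_name_for_comfyui
above, restated as a minimal standalone sketch (paths are hypothetical):

    import os

    def format_name(file_path: str, model_roots: list) -> str:
        # Prefix match against each root yields the relative, extension-keeping name;
        # a path outside every root falls back to the basename.
        normalized = file_path.replace(os.sep, "/")
        for root in model_roots:
            root = root.replace(os.sep, "/").rstrip("/") + "/"
            if normalized.startswith(root):
                return normalized[len(root):]
        return os.path.basename(file_path)

    assert format_name(
        "/data/ComfyUI/models/checkpoints/Illustrious/model.safetensors",
        ["/data/ComfyUI/models/checkpoints"],
    ) == "Illustrious/model.safetensors"
    assert format_name("/elsewhere/model.ckpt", ["/data/ComfyUI/models/checkpoints"]) == "model.ckpt"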

View File

@@ -36,8 +36,8 @@ class TestCheckpointPathOverlap:
         config._preview_root_paths = set()
         config._cached_fingerprint = None

-        # Call the method under test
-        result = config._prepare_checkpoint_paths(
+        # Call the method under test - now returns a tuple
+        all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
             [str(checkpoints_link)], [str(unet_link)]
         )
@@ -50,21 +50,27 @@ class TestCheckpointPathOverlap:
         ]
         assert len(warning_messages) == 1
         assert "checkpoints" in warning_messages[0].lower()
-        assert "diffusion_models" in warning_messages[0].lower() or "unet" in warning_messages[0].lower()
+        assert (
+            "diffusion_models" in warning_messages[0].lower()
+            or "unet" in warning_messages[0].lower()
+        )

         # Verify warning mentions backward compatibility fallback
-        assert "falling back" in warning_messages[0].lower() or "backward compatibility" in warning_messages[0].lower()
+        assert (
+            "falling back" in warning_messages[0].lower()
+            or "backward compatibility" in warning_messages[0].lower()
+        )

         # Verify only one path is returned (deduplication still works)
-        assert len(result) == 1
+        assert len(all_paths) == 1

         # Prioritizes checkpoints path for backward compatibility
-        assert _normalize(result[0]) == _normalize(str(checkpoints_link))
+        assert _normalize(all_paths[0]) == _normalize(str(checkpoints_link))

-        # Verify checkpoints_roots has the path (prioritized)
-        assert len(config.checkpoints_roots) == 1
-        assert _normalize(config.checkpoints_roots[0]) == _normalize(str(checkpoints_link))
+        # Verify checkpoint_roots has the path (prioritized)
+        assert len(checkpoint_roots) == 1
+        assert _normalize(checkpoint_roots[0]) == _normalize(str(checkpoints_link))

         # Verify unet_roots is empty (overlapping paths removed)
-        assert config.unet_roots == []
+        assert unet_roots == []

     def test_non_overlapping_paths_no_warning(
         self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
@@ -83,7 +89,7 @@ class TestCheckpointPathOverlap:
         config._preview_root_paths = set()
         config._cached_fingerprint = None

-        result = config._prepare_checkpoint_paths(
+        all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
             [str(checkpoints_dir)], [str(unet_dir)]
         )
@@ -97,14 +103,14 @@ class TestCheckpointPathOverlap:
         assert len(warning_messages) == 0

         # Verify both paths are returned
-        assert len(result) == 2
-        normalized_result = [_normalize(p) for p in result]
+        assert len(all_paths) == 2
+        normalized_result = [_normalize(p) for p in all_paths]
         assert _normalize(str(checkpoints_dir)) in normalized_result
         assert _normalize(str(unet_dir)) in normalized_result

         # Verify both roots are properly set
-        assert len(config.checkpoints_roots) == 1
-        assert len(config.unet_roots) == 1
+        assert len(checkpoint_roots) == 1
+        assert len(unet_roots) == 1

     def test_partial_overlap_prioritizes_checkpoints(
         self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
@@ -129,9 +135,9 @@ class TestCheckpointPathOverlap:
         config._cached_fingerprint = None

         # One checkpoint path overlaps with one unet path
-        result = config._prepare_checkpoint_paths(
+        all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
             [str(shared_link), str(separate_checkpoint)],
-            [str(shared_link), str(separate_unet)]
+            [str(shared_link), str(separate_unet)],
         )

         # Verify warning was logged for the overlapping path
@@ -144,17 +150,20 @@ class TestCheckpointPathOverlap:
         assert len(warning_messages) == 1

         # Verify 3 unique paths (shared counted once as checkpoint, plus separate ones)
-        assert len(result) == 3
+        assert len(all_paths) == 3

         # Verify the overlapping path appears in warning message
-        assert str(shared_link.name) in warning_messages[0] or str(shared_dir.name) in warning_messages[0]
+        assert (
+            str(shared_link.name) in warning_messages[0]
+            or str(shared_dir.name) in warning_messages[0]
+        )

-        # Verify checkpoints_roots includes both checkpoint paths (including the shared one)
-        assert len(config.checkpoints_roots) == 2
-        checkpoint_normalized = [_normalize(p) for p in config.checkpoints_roots]
+        # Verify checkpoint_roots includes both checkpoint paths (including the shared one)
+        assert len(checkpoint_roots) == 2
+        checkpoint_normalized = [_normalize(p) for p in checkpoint_roots]
         assert _normalize(str(shared_link)) in checkpoint_normalized
         assert _normalize(str(separate_checkpoint)) in checkpoint_normalized

         # Verify unet_roots only includes the non-overlapping unet path
-        assert len(config.unet_roots) == 1
-        assert _normalize(config.unet_roots[0]) == _normalize(str(separate_unet))
+        assert len(unet_roots) == 1
+        assert _normalize(unet_roots[0]) == _normalize(str(separate_unet))

View File

@@ -0,0 +1,158 @@
"""Tests for checkpoint and unet loaders with extra folder paths support"""
import pytest
import os
# Get project root directory (ComfyUI-Lora-Manager folder)
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class TestCheckpointLoaderLM:
"""Test CheckpointLoaderLM node"""
def test_class_attributes(self):
"""Test that CheckpointLoaderLM has required class attributes"""
# Import in a way that doesn't require ComfyUI
import ast
filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "checkpoint_loader.py")
with open(filepath, "r") as f:
tree = ast.parse(f.read())
# Find CheckpointLoaderLM class
classes = {
node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
}
assert "CheckpointLoaderLM" in classes
cls = classes["CheckpointLoaderLM"]
# Check for NAME attribute
name_attr = [
n
for n in cls.body
if isinstance(n, ast.Assign)
and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name))
]
assert len(name_attr) > 0, "CheckpointLoaderLM should have NAME attribute"
# Check for CATEGORY attribute
cat_attr = [
n
for n in cls.body
if isinstance(n, ast.Assign)
and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name))
]
assert len(cat_attr) > 0, "CheckpointLoaderLM should have CATEGORY attribute"
# Check for INPUT_TYPES method
input_types = [
n
for n in cls.body
if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES"
]
assert len(input_types) > 0, "CheckpointLoaderLM should have INPUT_TYPES method"
# Check for load_checkpoint method
load_method = [
n
for n in cls.body
if isinstance(n, ast.FunctionDef) and n.name == "load_checkpoint"
]
assert len(load_method) > 0, (
"CheckpointLoaderLM should have load_checkpoint method"
)
class TestUNETLoaderLM:
"""Test UNETLoaderLM node"""
def test_class_attributes(self):
"""Test that UNETLoaderLM has required class attributes"""
# Import in a way that doesn't require ComfyUI
import ast
filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "unet_loader.py")
with open(filepath, "r") as f:
tree = ast.parse(f.read())
# Find UNETLoaderLM class
classes = {
node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
}
assert "UNETLoaderLM" in classes
cls = classes["UNETLoaderLM"]
# Check for NAME attribute
name_attr = [
n
for n in cls.body
if isinstance(n, ast.Assign)
and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name))
]
assert len(name_attr) > 0, "UNETLoaderLM should have NAME attribute"
# Check for CATEGORY attribute
cat_attr = [
n
for n in cls.body
if isinstance(n, ast.Assign)
and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name))
]
assert len(cat_attr) > 0, "UNETLoaderLM should have CATEGORY attribute"
# Check for INPUT_TYPES method
input_types = [
n
for n in cls.body
if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES"
]
assert len(input_types) > 0, "UNETLoaderLM should have INPUT_TYPES method"
# Check for load_unet method
load_method = [
n
for n in cls.body
if isinstance(n, ast.FunctionDef) and n.name == "load_unet"
]
assert len(load_method) > 0, "UNETLoaderLM should have load_unet method"
class TestUtils:
"""Test utility functions"""
def test_get_checkpoint_info_absolute_exists(self):
"""Test that get_checkpoint_info_absolute function exists in utils"""
import ast
filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py")
with open(filepath, "r") as f:
tree = ast.parse(f.read())
functions = [
node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
]
assert "get_checkpoint_info_absolute" in functions, (
"get_checkpoint_info_absolute should exist"
)
def test_format_model_name_for_comfyui_exists(self):
"""Test that _format_model_name_for_comfyui function exists in utils"""
import ast
filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py")
with open(filepath, "r") as f:
tree = ast.parse(f.read())
functions = [
node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
]
assert "_format_model_name_for_comfyui" in functions, (
"_format_model_name_for_comfyui should exist"
)