fix: isolate extra unet paths from checkpoints to prevent type misclassification

Refactor _prepare_checkpoint_paths() to return a tuple instead of having
side effects on instance variables. This prevents extra unet paths from
being incorrectly classified as checkpoints when processing extra paths.

- Changed return type from List[str] to Tuple[List[str], List[str], List[str]]
  (all_paths, checkpoint_roots, unet_roots)
- Updated _init_checkpoint_paths() and _apply_library_paths() callers
- Fixed extra paths processing to properly isolate main and extra roots
- Updated test_checkpoint_path_overlap.py tests for new API

This ensures models in extra unet paths are correctly identified as
diffusion_model type and don't appear in checkpoints list.
This commit is contained in:
Will Miao
2026-03-17 22:03:57 +08:00
parent 70c150bd80
commit 2dae4c1291
8 changed files with 838 additions and 124 deletions

View File

@@ -1,6 +1,8 @@
try: # pragma: no cover - import fallback for pytest collection
from .py.lora_manager import LoraManager
from .py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM
from .py.nodes.checkpoint_loader import CheckpointLoaderLM
from .py.nodes.unet_loader import UNETLoaderLM
from .py.nodes.trigger_word_toggle import TriggerWordToggleLM
from .py.nodes.prompt import PromptLM
from .py.nodes.text import TextLM
@@ -27,12 +29,12 @@ except (
PromptLM = importlib.import_module("py.nodes.prompt").PromptLM
TextLM = importlib.import_module("py.nodes.text").TextLM
LoraManager = importlib.import_module("py.lora_manager").LoraManager
LoraLoaderLM = importlib.import_module(
"py.nodes.lora_loader"
).LoraLoaderLM
LoraTextLoaderLM = importlib.import_module(
"py.nodes.lora_loader"
).LoraTextLoaderLM
LoraLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraLoaderLM
LoraTextLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraTextLoaderLM
CheckpointLoaderLM = importlib.import_module(
"py.nodes.checkpoint_loader"
).CheckpointLoaderLM
UNETLoaderLM = importlib.import_module("py.nodes.unet_loader").UNETLoaderLM
TriggerWordToggleLM = importlib.import_module(
"py.nodes.trigger_word_toggle"
).TriggerWordToggleLM
@@ -49,9 +51,7 @@ except (
LoraRandomizerLM = importlib.import_module(
"py.nodes.lora_randomizer"
).LoraRandomizerLM
LoraCyclerLM = importlib.import_module(
"py.nodes.lora_cycler"
).LoraCyclerLM
LoraCyclerLM = importlib.import_module("py.nodes.lora_cycler").LoraCyclerLM
init_metadata_collector = importlib.import_module("py.metadata_collector").init
NODE_CLASS_MAPPINGS = {
@@ -59,6 +59,8 @@ NODE_CLASS_MAPPINGS = {
TextLM.NAME: TextLM,
LoraLoaderLM.NAME: LoraLoaderLM,
LoraTextLoaderLM.NAME: LoraTextLoaderLM,
CheckpointLoaderLM.NAME: CheckpointLoaderLM,
UNETLoaderLM.NAME: UNETLoaderLM,
TriggerWordToggleLM.NAME: TriggerWordToggleLM,
LoraStackerLM.NAME: LoraStackerLM,
SaveImageLM.NAME: SaveImageLM,

View File

@@ -707,7 +707,13 @@ class Config:
def _prepare_checkpoint_paths(
self, checkpoint_paths: Iterable[str], unet_paths: Iterable[str]
) -> List[str]:
) -> Tuple[List[str], List[str], List[str]]:
"""Prepare checkpoint paths and return (all_roots, checkpoint_roots, unet_roots).
Returns:
Tuple of (all_unique_paths, checkpoint_only_paths, unet_only_paths)
This method does NOT modify instance variables - callers must set them.
"""
checkpoint_map = self._dedupe_existing_paths(checkpoint_paths)
unet_map = self._dedupe_existing_paths(unet_paths)
@@ -737,8 +743,8 @@ class Config:
checkpoint_values = set(checkpoint_map.values())
unet_values = set(unet_map.values())
self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_values]
self.unet_roots = [p for p in unique_paths if p in unet_values]
checkpoint_roots = [p for p in unique_paths if p in checkpoint_values]
unet_roots = [p for p in unique_paths if p in unet_values]
for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(
@@ -747,7 +753,7 @@ class Config:
if real_path != original_path:
self.add_path_mapping(original_path, real_path)
return unique_paths
return unique_paths, checkpoint_roots, unet_roots
def _prepare_embedding_paths(self, raw_paths: Iterable[str]) -> List[str]:
path_map = self._dedupe_existing_paths(raw_paths)
@@ -776,9 +782,11 @@ class Config:
embedding_paths = folder_paths.get("embeddings", []) or []
self.loras_roots = self._prepare_lora_paths(lora_paths)
self.base_models_roots = self._prepare_checkpoint_paths(
checkpoint_paths, unet_paths
)
(
self.base_models_roots,
self.checkpoints_roots,
self.unet_roots,
) = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths)
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
# Process extra paths (only for LoRA Manager, not shared with ComfyUI)
@@ -789,18 +797,11 @@ class Config:
extra_embedding_paths = extra_paths.get("embeddings", []) or []
self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths)
# Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them)
saved_checkpoints_roots = self.checkpoints_roots
saved_unet_roots = self.unet_roots
self.extra_checkpoints_roots = self._prepare_checkpoint_paths(
extra_checkpoint_paths, extra_unet_paths
)
self.extra_unet_roots = (
self.unet_roots if self.unet_roots is not None else []
) # unet_roots was set by _prepare_checkpoint_paths
# Restore main paths
self.checkpoints_roots = saved_checkpoints_roots
self.unet_roots = saved_unet_roots
(
_,
self.extra_checkpoints_roots,
self.extra_unet_roots,
) = self._prepare_checkpoint_paths(extra_checkpoint_paths, extra_unet_paths)
self.extra_embeddings_roots = self._prepare_embedding_paths(
extra_embedding_paths
)
@@ -857,9 +858,11 @@ class Config:
try:
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
raw_unet_paths = folder_paths.get_folder_paths("unet")
unique_paths = self._prepare_checkpoint_paths(
raw_checkpoint_paths, raw_unet_paths
)
(
unique_paths,
self.checkpoints_roots,
self.unet_roots,
) = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths)
logger.info(
"Found checkpoint roots:"

View File

@@ -0,0 +1,184 @@
import logging
import os
from typing import List, Tuple
import comfy.sd
import folder_paths
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
logger = logging.getLogger(__name__)
class CheckpointLoaderLM:
"""Checkpoint Loader with support for extra folder paths
Loads checkpoints from both standard ComfyUI folders and LoRA Manager's
extra folder paths, providing a unified interface for checkpoint loading.
"""
NAME = "CheckpointLoaderLM"
CATEGORY = "Lora Manager/loaders"
@classmethod
def INPUT_TYPES(s):
# Get list of checkpoint names from scanner (includes extra folder paths)
checkpoint_names = s._get_checkpoint_names()
return {
"required": {
"ckpt_name": (
checkpoint_names,
{"tooltip": "The name of the checkpoint (model) to load."},
),
}
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
RETURN_NAMES = ("MODEL", "CLIP", "VAE")
OUTPUT_TOOLTIPS = (
"The model used for denoising latents.",
"The CLIP model used for encoding text prompts.",
"The VAE model used for encoding and decoding images to and from latent space.",
)
FUNCTION = "load_checkpoint"
@classmethod
def _get_checkpoint_names(cls) -> List[str]:
"""Get list of checkpoint names from scanner cache in ComfyUI format (relative path with extension)"""
try:
from ..services.service_registry import ServiceRegistry
import asyncio
async def _get_names():
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
# Get all model roots for calculating relative paths
model_roots = scanner.get_model_roots()
# Filter only checkpoint type (not diffusion_model) and format names
names = []
for item in cache.raw_data:
if item.get("sub_type") == "checkpoint":
file_path = item.get("file_path", "")
if file_path:
# Format as ComfyUI-style: "folder/model_name.ext"
formatted_name = _format_model_name_for_comfyui(
file_path, model_roots
)
if formatted_name:
names.append(formatted_name)
return sorted(names)
try:
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(_get_names())
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()
except RuntimeError:
return asyncio.run(_get_names())
except Exception as e:
logger.error(f"Error getting checkpoint names: {e}")
return []
def load_checkpoint(self, ckpt_name: str) -> Tuple:
"""Load a checkpoint by name, supporting extra folder paths
Args:
ckpt_name: The name of the checkpoint to load (format: "folder/model_name.ext")
Returns:
Tuple of (MODEL, CLIP, VAE)
"""
# Get absolute path from cache using ComfyUI-style name
ckpt_path, metadata = get_checkpoint_info_absolute(ckpt_name)
if metadata is None:
raise FileNotFoundError(
f"Checkpoint '{ckpt_name}' not found in LoRA Manager cache. "
"Make sure the checkpoint is indexed and try again."
)
# Check if it's a GGUF model
if ckpt_path.endswith(".gguf"):
return self._load_gguf_checkpoint(ckpt_path, ckpt_name)
# Load regular checkpoint using ComfyUI's API
logger.info(f"Loading checkpoint from: {ckpt_path}")
out = comfy.sd.load_checkpoint_guess_config(
ckpt_path,
output_vae=True,
output_clip=True,
embedding_directory=folder_paths.get_folder_paths("embeddings"),
)
return out[:3]
def _load_gguf_checkpoint(self, ckpt_path: str, ckpt_name: str) -> Tuple:
"""Load a GGUF format checkpoint
Args:
ckpt_path: Absolute path to the GGUF file
ckpt_name: Name of the checkpoint for error messages
Returns:
Tuple of (MODEL, CLIP, VAE) - CLIP and VAE may be None for GGUF
"""
try:
# Try to import ComfyUI-GGUF modules
from custom_nodes.ComfyUI_GGUF.loader import gguf_sd_loader
from custom_nodes.ComfyUI_GGUF.ops import GGMLOps
from custom_nodes.ComfyUI_GGUF.nodes import GGUFModelPatcher
except ImportError:
raise RuntimeError(
f"Cannot load GGUF model '{ckpt_name}'. "
"ComfyUI-GGUF is not installed. "
"Please install ComfyUI-GGUF from https://github.com/city96/ComfyUI-GGUF "
"to load GGUF format models."
)
logger.info(f"Loading GGUF checkpoint from: {ckpt_path}")
try:
# Load GGUF state dict
sd, extra = gguf_sd_loader(ckpt_path)
# Prepare kwargs for metadata if supported
kwargs = {}
import inspect
valid_params = inspect.signature(
comfy.sd.load_diffusion_model_state_dict
).parameters
if "metadata" in valid_params:
kwargs["metadata"] = extra.get("metadata", {})
# Load the model
model = comfy.sd.load_diffusion_model_state_dict(
sd, model_options={"custom_operations": GGMLOps()}, **kwargs
)
if model is None:
raise RuntimeError(
f"Could not detect model type for GGUF checkpoint: {ckpt_path}"
)
# Wrap with GGUFModelPatcher
model = GGUFModelPatcher.clone(model)
# GGUF checkpoints typically don't include CLIP/VAE
return (model, None, None)
except Exception as e:
logger.error(f"Error loading GGUF checkpoint '{ckpt_name}': {e}")
raise RuntimeError(
f"Failed to load GGUF checkpoint '{ckpt_name}': {str(e)}"
)

205
py/nodes/unet_loader.py Normal file
View File

@@ -0,0 +1,205 @@
import logging
import os
from typing import List, Tuple
import torch
import comfy.sd
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
logger = logging.getLogger(__name__)
class UNETLoaderLM:
"""UNET Loader with support for extra folder paths
Loads diffusion models/UNets from both standard ComfyUI folders and LoRA Manager's
extra folder paths, providing a unified interface for UNET loading.
Supports both regular diffusion models and GGUF format models.
"""
NAME = "UNETLoaderLM"
CATEGORY = "Lora Manager/loaders"
@classmethod
def INPUT_TYPES(s):
# Get list of unet names from scanner (includes extra folder paths)
unet_names = s._get_unet_names()
return {
"required": {
"unet_name": (
unet_names,
{"tooltip": "The name of the diffusion model to load."},
),
"weight_dtype": (
["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],
{"tooltip": "The dtype to use for the model weights."},
),
}
}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("MODEL",)
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",)
FUNCTION = "load_unet"
@classmethod
def _get_unet_names(cls) -> List[str]:
"""Get list of diffusion model names from scanner cache in ComfyUI format (relative path with extension)"""
try:
from ..services.service_registry import ServiceRegistry
import asyncio
async def _get_names():
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
# Get all model roots for calculating relative paths
model_roots = scanner.get_model_roots()
# Filter only diffusion_model type and format names
names = []
for item in cache.raw_data:
if item.get("sub_type") == "diffusion_model":
file_path = item.get("file_path", "")
if file_path:
# Format as ComfyUI-style: "folder/model_name.ext"
formatted_name = _format_model_name_for_comfyui(
file_path, model_roots
)
if formatted_name:
names.append(formatted_name)
return sorted(names)
try:
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(_get_names())
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()
except RuntimeError:
return asyncio.run(_get_names())
except Exception as e:
logger.error(f"Error getting unet names: {e}")
return []
def load_unet(self, unet_name: str, weight_dtype: str) -> Tuple:
"""Load a diffusion model by name, supporting extra folder paths
Args:
unet_name: The name of the diffusion model to load (format: "folder/model_name.ext")
weight_dtype: The dtype to use for model weights
Returns:
Tuple of (MODEL,)
"""
# Get absolute path from cache using ComfyUI-style name
unet_path, metadata = get_checkpoint_info_absolute(unet_name)
if metadata is None:
raise FileNotFoundError(
f"Diffusion model '{unet_name}' not found in LoRA Manager cache. "
"Make sure the model is indexed and try again."
)
# Check if it's a GGUF model
if unet_path.endswith(".gguf"):
return self._load_gguf_unet(unet_path, unet_name, weight_dtype)
# Load regular diffusion model using ComfyUI's API
logger.info(f"Loading diffusion model from: {unet_path}")
# Build model options based on weight_dtype
model_options = {}
if weight_dtype == "fp8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e4m3fn_fast":
model_options["dtype"] = torch.float8_e4m3fn
model_options["fp8_optimizations"] = True
elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)
def _load_gguf_unet(
self, unet_path: str, unet_name: str, weight_dtype: str
) -> Tuple:
"""Load a GGUF format diffusion model
Args:
unet_path: Absolute path to the GGUF file
unet_name: Name of the model for error messages
weight_dtype: The dtype to use for model weights
Returns:
Tuple of (MODEL,)
"""
try:
# Try to import ComfyUI-GGUF modules
from custom_nodes.ComfyUI_GGUF.loader import gguf_sd_loader
from custom_nodes.ComfyUI_GGUF.ops import GGMLOps
from custom_nodes.ComfyUI_GGUF.nodes import GGUFModelPatcher
except ImportError:
raise RuntimeError(
f"Cannot load GGUF model '{unet_name}'. "
"ComfyUI-GGUF is not installed. "
"Please install ComfyUI-GGUF from https://github.com/city96/ComfyUI-GGUF "
"to load GGUF format models."
)
logger.info(f"Loading GGUF diffusion model from: {unet_path}")
try:
# Load GGUF state dict
sd, extra = gguf_sd_loader(unet_path)
# Prepare kwargs for metadata if supported
kwargs = {}
import inspect
valid_params = inspect.signature(
comfy.sd.load_diffusion_model_state_dict
).parameters
if "metadata" in valid_params:
kwargs["metadata"] = extra.get("metadata", {})
# Setup custom operations with GGUF support
ops = GGMLOps()
# Handle weight_dtype for GGUF models
if weight_dtype in ("default", None):
ops.Linear.dequant_dtype = None
elif weight_dtype in ["target"]:
ops.Linear.dequant_dtype = weight_dtype
else:
ops.Linear.dequant_dtype = getattr(torch, weight_dtype, None)
# Load the model
model = comfy.sd.load_diffusion_model_state_dict(
sd, model_options={"custom_operations": ops}, **kwargs
)
if model is None:
raise RuntimeError(
f"Could not detect model type for GGUF diffusion model: {unet_path}"
)
# Wrap with GGUFModelPatcher
model = GGUFModelPatcher.clone(model)
return (model,)
except Exception as e:
logger.error(f"Error loading GGUF diffusion model '{unet_name}': {e}")
raise RuntimeError(
f"Failed to load GGUF diffusion model '{unet_name}': {str(e)}"
)

View File

@@ -13,20 +13,33 @@ from .model_hash_index import ModelHashIndex
logger = logging.getLogger(__name__)
class CheckpointScanner(ModelScanner):
"""Service for scanning and managing checkpoint files"""
def __init__(self):
# Define supported file extensions
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
file_extensions = {
".ckpt",
".pt",
".pt2",
".bin",
".pth",
".safetensors",
".pkl",
".sft",
".gguf",
}
super().__init__(
model_type="checkpoint",
model_class=CheckpointMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex()
hash_index=ModelHashIndex(),
)
async def _create_default_metadata(self, file_path: str) -> Optional[CheckpointMetadata]:
async def _create_default_metadata(
self, file_path: str
) -> Optional[CheckpointMetadata]:
"""Create default metadata for checkpoint without calculating hash (lazy hash).
Checkpoints are typically large (10GB+), so we skip hash calculation during initial
@@ -59,7 +72,7 @@ class CheckpointScanner(ModelScanner):
modelDescription="",
sub_type="checkpoint",
from_civitai=False, # Mark as local model since no hash yet
hash_status="pending" # Mark hash as pending
hash_status="pending", # Mark hash as pending
)
# Save the created metadata
@@ -69,7 +82,9 @@ class CheckpointScanner(ModelScanner):
return metadata
except Exception as e:
logger.error(f"Error creating default checkpoint metadata for {file_path}: {e}")
logger.error(
f"Error creating default checkpoint metadata for {file_path}: {e}"
)
return None
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
@@ -90,7 +105,9 @@ class CheckpointScanner(ModelScanner):
return None
# Load current metadata
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
metadata, _ = await MetadataManager.load_metadata(
file_path, self.model_class
)
if metadata is None:
logger.error(f"No metadata found for {file_path}")
return None
@@ -122,7 +139,9 @@ class CheckpointScanner(ModelScanner):
logger.error(f"Error calculating hash for {file_path}: {e}")
# Update status to failed
try:
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
metadata, _ = await MetadataManager.load_metadata(
file_path, self.model_class
)
if metadata:
metadata.hash_status = "failed"
await MetadataManager.save_metadata(file_path, metadata)
@@ -130,7 +149,9 @@ class CheckpointScanner(ModelScanner):
pass
return None
async def calculate_all_pending_hashes(self, progress_callback=None) -> Dict[str, int]:
async def calculate_all_pending_hashes(
self, progress_callback=None
) -> Dict[str, int]:
"""Calculate hashes for all checkpoints with pending hash status.
If cache is not initialized, scans filesystem directly for metadata files
@@ -148,22 +169,23 @@ class CheckpointScanner(ModelScanner):
if cache and cache.raw_data:
# Use cache if available
pending_models = [
item for item in cache.raw_data
if item.get('hash_status') != 'completed' or not item.get('sha256')
item
for item in cache.raw_data
if item.get("hash_status") != "completed" or not item.get("sha256")
]
else:
# Cache not initialized, scan filesystem directly
pending_models = await self._find_pending_models_from_filesystem()
if not pending_models:
return {'completed': 0, 'failed': 0, 'total': 0}
return {"completed": 0, "failed": 0, "total": 0}
total = len(pending_models)
completed = 0
failed = 0
for i, model_data in enumerate(pending_models):
file_path = model_data.get('file_path')
file_path = model_data.get("file_path")
if not file_path:
continue
@@ -183,11 +205,7 @@ class CheckpointScanner(ModelScanner):
except Exception:
pass
return {
'completed': completed,
'failed': failed,
'total': total
}
return {"completed": completed, "failed": failed, "total": total}
async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]:
"""Scan filesystem for checkpoint metadata files with pending hash status."""
@@ -199,21 +217,21 @@ class CheckpointScanner(ModelScanner):
for dirpath, _dirnames, filenames in os.walk(root_path):
for filename in filenames:
if not filename.endswith('.metadata.json'):
if not filename.endswith(".metadata.json"):
continue
metadata_path = os.path.join(dirpath, filename)
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
with open(metadata_path, "r", encoding="utf-8") as f:
data = json.load(f)
# Check if hash is pending
hash_status = data.get('hash_status', 'completed')
sha256 = data.get('sha256', '')
hash_status = data.get("hash_status", "completed")
sha256 = data.get("sha256", "")
if hash_status != 'completed' or not sha256:
if hash_status != "completed" or not sha256:
# Find corresponding model file
model_name = filename.replace('.metadata.json', '')
model_name = filename.replace(".metadata.json", "")
model_path = None
# Look for model file with matching name
@@ -224,29 +242,58 @@ class CheckpointScanner(ModelScanner):
break
if model_path:
pending_models.append({
'file_path': model_path.replace(os.sep, '/'),
'hash_status': hash_status,
'sha256': sha256,
**{k: v for k, v in data.items() if k not in ['file_path', 'hash_status', 'sha256']}
})
pending_models.append(
{
"file_path": model_path.replace(os.sep, "/"),
"hash_status": hash_status,
"sha256": sha256,
**{
k: v
for k, v in data.items()
if k
not in [
"file_path",
"hash_status",
"sha256",
]
},
}
)
except (json.JSONDecodeError, Exception) as e:
logger.debug(f"Error reading metadata file {metadata_path}: {e}")
logger.debug(
f"Error reading metadata file {metadata_path}: {e}"
)
continue
return pending_models
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
"""Resolve the sub-type based on the root path."""
"""Resolve the sub-type based on the root path.
Checks both standard ComfyUI paths and LoRA Manager's extra folder paths.
"""
if not root_path:
return None
# Check standard ComfyUI checkpoint paths
if config.checkpoints_roots and root_path in config.checkpoints_roots:
return "checkpoint"
# Check extra checkpoint paths
if (
config.extra_checkpoints_roots
and root_path in config.extra_checkpoints_roots
):
return "checkpoint"
# Check standard ComfyUI unet paths
if config.unet_roots and root_path in config.unet_roots:
return "diffusion_model"
# Check extra unet paths
if config.extra_unet_roots and root_path in config.extra_unet_roots:
return "diffusion_model"
return None
def adjust_metadata(self, metadata, file_path, root_path):

View File

@@ -112,6 +112,112 @@ def get_lora_info_absolute(lora_name):
return asyncio.run(_get_lora_info_absolute_async())
def get_checkpoint_info_absolute(checkpoint_name):
"""Get the absolute checkpoint path and metadata from cache
Supports ComfyUI-style model names (e.g., "folder/model_name.ext")
Args:
checkpoint_name: The model name, can be:
- ComfyUI format: "folder/model_name.safetensors"
- Simple name: "model_name"
Returns:
tuple: (absolute_path, metadata) where absolute_path is the full
file system path to the checkpoint file, or original checkpoint_name if not found,
metadata is the full model metadata dict or None
"""
async def _get_checkpoint_info_absolute_async():
from ..services.service_registry import ServiceRegistry
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
# Get model roots for matching
model_roots = scanner.get_model_roots()
# Normalize the checkpoint name
normalized_name = checkpoint_name.replace(os.sep, "/")
for item in cache.raw_data:
file_path = item.get("file_path", "")
if not file_path:
continue
# Format the stored path as ComfyUI-style name
formatted_name = _format_model_name_for_comfyui(file_path, model_roots)
# Match by formatted name
if formatted_name == normalized_name or formatted_name == checkpoint_name:
return file_path, item
# Also try matching by basename only (for backward compatibility)
file_name = item.get("file_name", "")
if (
file_name == checkpoint_name
or file_name == os.path.splitext(normalized_name)[0]
):
return file_path, item
return checkpoint_name, None
try:
# Check if we're already in an event loop
loop = asyncio.get_running_loop()
# If we're in a running loop, we need to use a different approach
# Create a new thread to run the async code
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
_get_checkpoint_info_absolute_async()
)
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()
except RuntimeError:
# No event loop is running, we can use asyncio.run()
return asyncio.run(_get_checkpoint_info_absolute_async())
def _format_model_name_for_comfyui(file_path: str, model_roots: list) -> str:
"""Format file path to ComfyUI-style model name (relative path with extension)
Example: /path/to/checkpoints/Illustrious/model.safetensors -> Illustrious/model.safetensors
Args:
file_path: Absolute path to the model file
model_roots: List of model root directories
Returns:
ComfyUI-style model name with relative path and extension
"""
# Normalize path separators
normalized_path = file_path.replace(os.sep, "/")
# Find the matching root and get relative path
for root in model_roots:
normalized_root = root.replace(os.sep, "/")
# Ensure root ends with / for proper matching
if not normalized_root.endswith("/"):
normalized_root += "/"
if normalized_path.startswith(normalized_root):
rel_path = normalized_path[len(normalized_root) :]
return rel_path
# If no root matches, just return the basename with extension
return os.path.basename(file_path)
def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool:
"""
Check if text matches pattern using fuzzy matching.

View File

@@ -36,8 +36,8 @@ class TestCheckpointPathOverlap:
config._preview_root_paths = set()
config._cached_fingerprint = None
# Call the method under test
result = config._prepare_checkpoint_paths(
# Call the method under test - now returns a tuple
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
[str(checkpoints_link)], [str(unet_link)]
)
@@ -50,21 +50,27 @@ class TestCheckpointPathOverlap:
]
assert len(warning_messages) == 1
assert "checkpoints" in warning_messages[0].lower()
assert "diffusion_models" in warning_messages[0].lower() or "unet" in warning_messages[0].lower()
assert (
"diffusion_models" in warning_messages[0].lower()
or "unet" in warning_messages[0].lower()
)
# Verify warning mentions backward compatibility fallback
assert "falling back" in warning_messages[0].lower() or "backward compatibility" in warning_messages[0].lower()
assert (
"falling back" in warning_messages[0].lower()
or "backward compatibility" in warning_messages[0].lower()
)
# Verify only one path is returned (deduplication still works)
assert len(result) == 1
assert len(all_paths) == 1
# Prioritizes checkpoints path for backward compatibility
assert _normalize(result[0]) == _normalize(str(checkpoints_link))
assert _normalize(all_paths[0]) == _normalize(str(checkpoints_link))
# Verify checkpoints_roots has the path (prioritized)
assert len(config.checkpoints_roots) == 1
assert _normalize(config.checkpoints_roots[0]) == _normalize(str(checkpoints_link))
# Verify checkpoint_roots has the path (prioritized)
assert len(checkpoint_roots) == 1
assert _normalize(checkpoint_roots[0]) == _normalize(str(checkpoints_link))
# Verify unet_roots is empty (overlapping paths removed)
assert config.unet_roots == []
assert unet_roots == []
def test_non_overlapping_paths_no_warning(
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
@@ -83,7 +89,7 @@ class TestCheckpointPathOverlap:
config._preview_root_paths = set()
config._cached_fingerprint = None
result = config._prepare_checkpoint_paths(
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
[str(checkpoints_dir)], [str(unet_dir)]
)
@@ -97,14 +103,14 @@ class TestCheckpointPathOverlap:
assert len(warning_messages) == 0
# Verify both paths are returned
assert len(result) == 2
normalized_result = [_normalize(p) for p in result]
assert len(all_paths) == 2
normalized_result = [_normalize(p) for p in all_paths]
assert _normalize(str(checkpoints_dir)) in normalized_result
assert _normalize(str(unet_dir)) in normalized_result
# Verify both roots are properly set
assert len(config.checkpoints_roots) == 1
assert len(config.unet_roots) == 1
assert len(checkpoint_roots) == 1
assert len(unet_roots) == 1
def test_partial_overlap_prioritizes_checkpoints(
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
@@ -129,9 +135,9 @@ class TestCheckpointPathOverlap:
config._cached_fingerprint = None
# One checkpoint path overlaps with one unet path
result = config._prepare_checkpoint_paths(
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
[str(shared_link), str(separate_checkpoint)],
[str(shared_link), str(separate_unet)]
[str(shared_link), str(separate_unet)],
)
# Verify warning was logged for the overlapping path
@@ -144,17 +150,20 @@ class TestCheckpointPathOverlap:
assert len(warning_messages) == 1
# Verify 3 unique paths (shared counted once as checkpoint, plus separate ones)
assert len(result) == 3
assert len(all_paths) == 3
# Verify the overlapping path appears in warning message
assert str(shared_link.name) in warning_messages[0] or str(shared_dir.name) in warning_messages[0]
assert (
str(shared_link.name) in warning_messages[0]
or str(shared_dir.name) in warning_messages[0]
)
# Verify checkpoints_roots includes both checkpoint paths (including the shared one)
assert len(config.checkpoints_roots) == 2
checkpoint_normalized = [_normalize(p) for p in config.checkpoints_roots]
# Verify checkpoint_roots includes both checkpoint paths (including the shared one)
assert len(checkpoint_roots) == 2
checkpoint_normalized = [_normalize(p) for p in checkpoint_roots]
assert _normalize(str(shared_link)) in checkpoint_normalized
assert _normalize(str(separate_checkpoint)) in checkpoint_normalized
# Verify unet_roots only includes the non-overlapping unet path
assert len(config.unet_roots) == 1
assert _normalize(config.unet_roots[0]) == _normalize(str(separate_unet))
assert len(unet_roots) == 1
assert _normalize(unet_roots[0]) == _normalize(str(separate_unet))

View File

@@ -0,0 +1,158 @@
"""Tests for checkpoint and unet loaders with extra folder paths support"""
import pytest
import os
# Get project root directory (ComfyUI-Lora-Manager folder)
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class TestCheckpointLoaderLM:
"""Test CheckpointLoaderLM node"""
def test_class_attributes(self):
"""Test that CheckpointLoaderLM has required class attributes"""
# Import in a way that doesn't require ComfyUI
import ast
filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "checkpoint_loader.py")
with open(filepath, "r") as f:
tree = ast.parse(f.read())
# Find CheckpointLoaderLM class
classes = {
node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
}
assert "CheckpointLoaderLM" in classes
cls = classes["CheckpointLoaderLM"]
# Check for NAME attribute
name_attr = [
n
for n in cls.body
if isinstance(n, ast.Assign)
and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name))
]
assert len(name_attr) > 0, "CheckpointLoaderLM should have NAME attribute"
# Check for CATEGORY attribute
cat_attr = [
n
for n in cls.body
if isinstance(n, ast.Assign)
and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name))
]
assert len(cat_attr) > 0, "CheckpointLoaderLM should have CATEGORY attribute"
# Check for INPUT_TYPES method
input_types = [
n
for n in cls.body
if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES"
]
assert len(input_types) > 0, "CheckpointLoaderLM should have INPUT_TYPES method"
# Check for load_checkpoint method
load_method = [
n
for n in cls.body
if isinstance(n, ast.FunctionDef) and n.name == "load_checkpoint"
]
assert len(load_method) > 0, (
"CheckpointLoaderLM should have load_checkpoint method"
)
class TestUNETLoaderLM:
"""Test UNETLoaderLM node"""
def test_class_attributes(self):
"""Test that UNETLoaderLM has required class attributes"""
# Import in a way that doesn't require ComfyUI
import ast
filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "unet_loader.py")
with open(filepath, "r") as f:
tree = ast.parse(f.read())
# Find UNETLoaderLM class
classes = {
node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
}
assert "UNETLoaderLM" in classes
cls = classes["UNETLoaderLM"]
# Check for NAME attribute
name_attr = [
n
for n in cls.body
if isinstance(n, ast.Assign)
and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name))
]
assert len(name_attr) > 0, "UNETLoaderLM should have NAME attribute"
# Check for CATEGORY attribute
cat_attr = [
n
for n in cls.body
if isinstance(n, ast.Assign)
and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name))
]
assert len(cat_attr) > 0, "UNETLoaderLM should have CATEGORY attribute"
# Check for INPUT_TYPES method
input_types = [
n
for n in cls.body
if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES"
]
assert len(input_types) > 0, "UNETLoaderLM should have INPUT_TYPES method"
# Check for load_unet method
load_method = [
n
for n in cls.body
if isinstance(n, ast.FunctionDef) and n.name == "load_unet"
]
assert len(load_method) > 0, "UNETLoaderLM should have load_unet method"
class TestUtils:
"""Test utility functions"""
def test_get_checkpoint_info_absolute_exists(self):
"""Test that get_checkpoint_info_absolute function exists in utils"""
import ast
filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py")
with open(filepath, "r") as f:
tree = ast.parse(f.read())
functions = [
node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
]
assert "get_checkpoint_info_absolute" in functions, (
"get_checkpoint_info_absolute should exist"
)
def test_format_model_name_for_comfyui_exists(self):
"""Test that _format_model_name_for_comfyui function exists in utils"""
import ast
filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py")
with open(filepath, "r") as f:
tree = ast.parse(f.read())
functions = [
node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
]
assert "_format_model_name_for_comfyui" in functions, (
"_format_model_name_for_comfyui should exist"
)