mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 13:12:12 -03:00
fix: isolate extra unet paths from checkpoints to prevent type misclassification
Refactor _prepare_checkpoint_paths() to return a tuple instead of having side effects on instance variables. This prevents extra unet paths from being incorrectly classified as checkpoints when processing extra paths. - Changed return type from List[str] to Tuple[List[str], List[str], List[str]] (all_paths, checkpoint_roots, unet_roots) - Updated _init_checkpoint_paths() and _apply_library_paths() callers - Fixed extra paths processing to properly isolate main and extra roots - Updated test_checkpoint_path_overlap.py tests for new API This ensures models in extra unet paths are correctly identified as diffusion_model type and don't appear in checkpoints list.
This commit is contained in:
20
__init__.py
20
__init__.py
@@ -1,6 +1,8 @@
|
|||||||
try: # pragma: no cover - import fallback for pytest collection
|
try: # pragma: no cover - import fallback for pytest collection
|
||||||
from .py.lora_manager import LoraManager
|
from .py.lora_manager import LoraManager
|
||||||
from .py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM
|
from .py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM
|
||||||
|
from .py.nodes.checkpoint_loader import CheckpointLoaderLM
|
||||||
|
from .py.nodes.unet_loader import UNETLoaderLM
|
||||||
from .py.nodes.trigger_word_toggle import TriggerWordToggleLM
|
from .py.nodes.trigger_word_toggle import TriggerWordToggleLM
|
||||||
from .py.nodes.prompt import PromptLM
|
from .py.nodes.prompt import PromptLM
|
||||||
from .py.nodes.text import TextLM
|
from .py.nodes.text import TextLM
|
||||||
@@ -27,12 +29,12 @@ except (
|
|||||||
PromptLM = importlib.import_module("py.nodes.prompt").PromptLM
|
PromptLM = importlib.import_module("py.nodes.prompt").PromptLM
|
||||||
TextLM = importlib.import_module("py.nodes.text").TextLM
|
TextLM = importlib.import_module("py.nodes.text").TextLM
|
||||||
LoraManager = importlib.import_module("py.lora_manager").LoraManager
|
LoraManager = importlib.import_module("py.lora_manager").LoraManager
|
||||||
LoraLoaderLM = importlib.import_module(
|
LoraLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraLoaderLM
|
||||||
"py.nodes.lora_loader"
|
LoraTextLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraTextLoaderLM
|
||||||
).LoraLoaderLM
|
CheckpointLoaderLM = importlib.import_module(
|
||||||
LoraTextLoaderLM = importlib.import_module(
|
"py.nodes.checkpoint_loader"
|
||||||
"py.nodes.lora_loader"
|
).CheckpointLoaderLM
|
||||||
).LoraTextLoaderLM
|
UNETLoaderLM = importlib.import_module("py.nodes.unet_loader").UNETLoaderLM
|
||||||
TriggerWordToggleLM = importlib.import_module(
|
TriggerWordToggleLM = importlib.import_module(
|
||||||
"py.nodes.trigger_word_toggle"
|
"py.nodes.trigger_word_toggle"
|
||||||
).TriggerWordToggleLM
|
).TriggerWordToggleLM
|
||||||
@@ -49,9 +51,7 @@ except (
|
|||||||
LoraRandomizerLM = importlib.import_module(
|
LoraRandomizerLM = importlib.import_module(
|
||||||
"py.nodes.lora_randomizer"
|
"py.nodes.lora_randomizer"
|
||||||
).LoraRandomizerLM
|
).LoraRandomizerLM
|
||||||
LoraCyclerLM = importlib.import_module(
|
LoraCyclerLM = importlib.import_module("py.nodes.lora_cycler").LoraCyclerLM
|
||||||
"py.nodes.lora_cycler"
|
|
||||||
).LoraCyclerLM
|
|
||||||
init_metadata_collector = importlib.import_module("py.metadata_collector").init
|
init_metadata_collector = importlib.import_module("py.metadata_collector").init
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
@@ -59,6 +59,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
TextLM.NAME: TextLM,
|
TextLM.NAME: TextLM,
|
||||||
LoraLoaderLM.NAME: LoraLoaderLM,
|
LoraLoaderLM.NAME: LoraLoaderLM,
|
||||||
LoraTextLoaderLM.NAME: LoraTextLoaderLM,
|
LoraTextLoaderLM.NAME: LoraTextLoaderLM,
|
||||||
|
CheckpointLoaderLM.NAME: CheckpointLoaderLM,
|
||||||
|
UNETLoaderLM.NAME: UNETLoaderLM,
|
||||||
TriggerWordToggleLM.NAME: TriggerWordToggleLM,
|
TriggerWordToggleLM.NAME: TriggerWordToggleLM,
|
||||||
LoraStackerLM.NAME: LoraStackerLM,
|
LoraStackerLM.NAME: LoraStackerLM,
|
||||||
SaveImageLM.NAME: SaveImageLM,
|
SaveImageLM.NAME: SaveImageLM,
|
||||||
|
|||||||
47
py/config.py
47
py/config.py
@@ -707,7 +707,13 @@ class Config:
|
|||||||
|
|
||||||
def _prepare_checkpoint_paths(
|
def _prepare_checkpoint_paths(
|
||||||
self, checkpoint_paths: Iterable[str], unet_paths: Iterable[str]
|
self, checkpoint_paths: Iterable[str], unet_paths: Iterable[str]
|
||||||
) -> List[str]:
|
) -> Tuple[List[str], List[str], List[str]]:
|
||||||
|
"""Prepare checkpoint paths and return (all_roots, checkpoint_roots, unet_roots).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (all_unique_paths, checkpoint_only_paths, unet_only_paths)
|
||||||
|
This method does NOT modify instance variables - callers must set them.
|
||||||
|
"""
|
||||||
checkpoint_map = self._dedupe_existing_paths(checkpoint_paths)
|
checkpoint_map = self._dedupe_existing_paths(checkpoint_paths)
|
||||||
unet_map = self._dedupe_existing_paths(unet_paths)
|
unet_map = self._dedupe_existing_paths(unet_paths)
|
||||||
|
|
||||||
@@ -737,8 +743,8 @@ class Config:
|
|||||||
|
|
||||||
checkpoint_values = set(checkpoint_map.values())
|
checkpoint_values = set(checkpoint_map.values())
|
||||||
unet_values = set(unet_map.values())
|
unet_values = set(unet_map.values())
|
||||||
self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_values]
|
checkpoint_roots = [p for p in unique_paths if p in checkpoint_values]
|
||||||
self.unet_roots = [p for p in unique_paths if p in unet_values]
|
unet_roots = [p for p in unique_paths if p in unet_values]
|
||||||
|
|
||||||
for original_path in unique_paths:
|
for original_path in unique_paths:
|
||||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(
|
real_path = os.path.normpath(os.path.realpath(original_path)).replace(
|
||||||
@@ -747,7 +753,7 @@ class Config:
|
|||||||
if real_path != original_path:
|
if real_path != original_path:
|
||||||
self.add_path_mapping(original_path, real_path)
|
self.add_path_mapping(original_path, real_path)
|
||||||
|
|
||||||
return unique_paths
|
return unique_paths, checkpoint_roots, unet_roots
|
||||||
|
|
||||||
def _prepare_embedding_paths(self, raw_paths: Iterable[str]) -> List[str]:
|
def _prepare_embedding_paths(self, raw_paths: Iterable[str]) -> List[str]:
|
||||||
path_map = self._dedupe_existing_paths(raw_paths)
|
path_map = self._dedupe_existing_paths(raw_paths)
|
||||||
@@ -776,9 +782,11 @@ class Config:
|
|||||||
embedding_paths = folder_paths.get("embeddings", []) or []
|
embedding_paths = folder_paths.get("embeddings", []) or []
|
||||||
|
|
||||||
self.loras_roots = self._prepare_lora_paths(lora_paths)
|
self.loras_roots = self._prepare_lora_paths(lora_paths)
|
||||||
self.base_models_roots = self._prepare_checkpoint_paths(
|
(
|
||||||
checkpoint_paths, unet_paths
|
self.base_models_roots,
|
||||||
)
|
self.checkpoints_roots,
|
||||||
|
self.unet_roots,
|
||||||
|
) = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths)
|
||||||
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
|
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
|
||||||
|
|
||||||
# Process extra paths (only for LoRA Manager, not shared with ComfyUI)
|
# Process extra paths (only for LoRA Manager, not shared with ComfyUI)
|
||||||
@@ -789,18 +797,11 @@ class Config:
|
|||||||
extra_embedding_paths = extra_paths.get("embeddings", []) or []
|
extra_embedding_paths = extra_paths.get("embeddings", []) or []
|
||||||
|
|
||||||
self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths)
|
self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths)
|
||||||
# Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them)
|
(
|
||||||
saved_checkpoints_roots = self.checkpoints_roots
|
_,
|
||||||
saved_unet_roots = self.unet_roots
|
self.extra_checkpoints_roots,
|
||||||
self.extra_checkpoints_roots = self._prepare_checkpoint_paths(
|
self.extra_unet_roots,
|
||||||
extra_checkpoint_paths, extra_unet_paths
|
) = self._prepare_checkpoint_paths(extra_checkpoint_paths, extra_unet_paths)
|
||||||
)
|
|
||||||
self.extra_unet_roots = (
|
|
||||||
self.unet_roots if self.unet_roots is not None else []
|
|
||||||
) # unet_roots was set by _prepare_checkpoint_paths
|
|
||||||
# Restore main paths
|
|
||||||
self.checkpoints_roots = saved_checkpoints_roots
|
|
||||||
self.unet_roots = saved_unet_roots
|
|
||||||
self.extra_embeddings_roots = self._prepare_embedding_paths(
|
self.extra_embeddings_roots = self._prepare_embedding_paths(
|
||||||
extra_embedding_paths
|
extra_embedding_paths
|
||||||
)
|
)
|
||||||
@@ -857,9 +858,11 @@ class Config:
|
|||||||
try:
|
try:
|
||||||
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||||
raw_unet_paths = folder_paths.get_folder_paths("unet")
|
raw_unet_paths = folder_paths.get_folder_paths("unet")
|
||||||
unique_paths = self._prepare_checkpoint_paths(
|
(
|
||||||
raw_checkpoint_paths, raw_unet_paths
|
unique_paths,
|
||||||
)
|
self.checkpoints_roots,
|
||||||
|
self.unet_roots,
|
||||||
|
) = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Found checkpoint roots:"
|
"Found checkpoint roots:"
|
||||||
|
|||||||
184
py/nodes/checkpoint_loader.py
Normal file
184
py/nodes/checkpoint_loader.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import List, Tuple
|
||||||
|
import comfy.sd
|
||||||
|
import folder_paths
|
||||||
|
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointLoaderLM:
|
||||||
|
"""Checkpoint Loader with support for extra folder paths
|
||||||
|
|
||||||
|
Loads checkpoints from both standard ComfyUI folders and LoRA Manager's
|
||||||
|
extra folder paths, providing a unified interface for checkpoint loading.
|
||||||
|
"""
|
||||||
|
|
||||||
|
NAME = "CheckpointLoaderLM"
|
||||||
|
CATEGORY = "Lora Manager/loaders"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
# Get list of checkpoint names from scanner (includes extra folder paths)
|
||||||
|
checkpoint_names = s._get_checkpoint_names()
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"ckpt_name": (
|
||||||
|
checkpoint_names,
|
||||||
|
{"tooltip": "The name of the checkpoint (model) to load."},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||||
|
RETURN_NAMES = ("MODEL", "CLIP", "VAE")
|
||||||
|
OUTPUT_TOOLTIPS = (
|
||||||
|
"The model used for denoising latents.",
|
||||||
|
"The CLIP model used for encoding text prompts.",
|
||||||
|
"The VAE model used for encoding and decoding images to and from latent space.",
|
||||||
|
)
|
||||||
|
FUNCTION = "load_checkpoint"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_checkpoint_names(cls) -> List[str]:
|
||||||
|
"""Get list of checkpoint names from scanner cache in ComfyUI format (relative path with extension)"""
|
||||||
|
try:
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def _get_names():
|
||||||
|
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
|
||||||
|
# Get all model roots for calculating relative paths
|
||||||
|
model_roots = scanner.get_model_roots()
|
||||||
|
|
||||||
|
# Filter only checkpoint type (not diffusion_model) and format names
|
||||||
|
names = []
|
||||||
|
for item in cache.raw_data:
|
||||||
|
if item.get("sub_type") == "checkpoint":
|
||||||
|
file_path = item.get("file_path", "")
|
||||||
|
if file_path:
|
||||||
|
# Format as ComfyUI-style: "folder/model_name.ext"
|
||||||
|
formatted_name = _format_model_name_for_comfyui(
|
||||||
|
file_path, model_roots
|
||||||
|
)
|
||||||
|
if formatted_name:
|
||||||
|
names.append(formatted_name)
|
||||||
|
|
||||||
|
return sorted(names)
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
def run_in_thread():
|
||||||
|
new_loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(new_loop)
|
||||||
|
try:
|
||||||
|
return new_loop.run_until_complete(_get_names())
|
||||||
|
finally:
|
||||||
|
new_loop.close()
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
future = executor.submit(run_in_thread)
|
||||||
|
return future.result()
|
||||||
|
except RuntimeError:
|
||||||
|
return asyncio.run(_get_names())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting checkpoint names: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def load_checkpoint(self, ckpt_name: str) -> Tuple:
|
||||||
|
"""Load a checkpoint by name, supporting extra folder paths
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ckpt_name: The name of the checkpoint to load (format: "folder/model_name.ext")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (MODEL, CLIP, VAE)
|
||||||
|
"""
|
||||||
|
# Get absolute path from cache using ComfyUI-style name
|
||||||
|
ckpt_path, metadata = get_checkpoint_info_absolute(ckpt_name)
|
||||||
|
|
||||||
|
if metadata is None:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Checkpoint '{ckpt_name}' not found in LoRA Manager cache. "
|
||||||
|
"Make sure the checkpoint is indexed and try again."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if it's a GGUF model
|
||||||
|
if ckpt_path.endswith(".gguf"):
|
||||||
|
return self._load_gguf_checkpoint(ckpt_path, ckpt_name)
|
||||||
|
|
||||||
|
# Load regular checkpoint using ComfyUI's API
|
||||||
|
logger.info(f"Loading checkpoint from: {ckpt_path}")
|
||||||
|
out = comfy.sd.load_checkpoint_guess_config(
|
||||||
|
ckpt_path,
|
||||||
|
output_vae=True,
|
||||||
|
output_clip=True,
|
||||||
|
embedding_directory=folder_paths.get_folder_paths("embeddings"),
|
||||||
|
)
|
||||||
|
return out[:3]
|
||||||
|
|
||||||
|
def _load_gguf_checkpoint(self, ckpt_path: str, ckpt_name: str) -> Tuple:
|
||||||
|
"""Load a GGUF format checkpoint
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ckpt_path: Absolute path to the GGUF file
|
||||||
|
ckpt_name: Name of the checkpoint for error messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (MODEL, CLIP, VAE) - CLIP and VAE may be None for GGUF
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try to import ComfyUI-GGUF modules
|
||||||
|
from custom_nodes.ComfyUI_GGUF.loader import gguf_sd_loader
|
||||||
|
from custom_nodes.ComfyUI_GGUF.ops import GGMLOps
|
||||||
|
from custom_nodes.ComfyUI_GGUF.nodes import GGUFModelPatcher
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load GGUF model '{ckpt_name}'. "
|
||||||
|
"ComfyUI-GGUF is not installed. "
|
||||||
|
"Please install ComfyUI-GGUF from https://github.com/city96/ComfyUI-GGUF "
|
||||||
|
"to load GGUF format models."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Loading GGUF checkpoint from: {ckpt_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load GGUF state dict
|
||||||
|
sd, extra = gguf_sd_loader(ckpt_path)
|
||||||
|
|
||||||
|
# Prepare kwargs for metadata if supported
|
||||||
|
kwargs = {}
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
valid_params = inspect.signature(
|
||||||
|
comfy.sd.load_diffusion_model_state_dict
|
||||||
|
).parameters
|
||||||
|
if "metadata" in valid_params:
|
||||||
|
kwargs["metadata"] = extra.get("metadata", {})
|
||||||
|
|
||||||
|
# Load the model
|
||||||
|
model = comfy.sd.load_diffusion_model_state_dict(
|
||||||
|
sd, model_options={"custom_operations": GGMLOps()}, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not detect model type for GGUF checkpoint: {ckpt_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wrap with GGUFModelPatcher
|
||||||
|
model = GGUFModelPatcher.clone(model)
|
||||||
|
|
||||||
|
# GGUF checkpoints typically don't include CLIP/VAE
|
||||||
|
return (model, None, None)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading GGUF checkpoint '{ckpt_name}': {e}")
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to load GGUF checkpoint '{ckpt_name}': {str(e)}"
|
||||||
|
)
|
||||||
205
py/nodes/unet_loader.py
Normal file
205
py/nodes/unet_loader.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import List, Tuple
|
||||||
|
import torch
|
||||||
|
import comfy.sd
|
||||||
|
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UNETLoaderLM:
|
||||||
|
"""UNET Loader with support for extra folder paths
|
||||||
|
|
||||||
|
Loads diffusion models/UNets from both standard ComfyUI folders and LoRA Manager's
|
||||||
|
extra folder paths, providing a unified interface for UNET loading.
|
||||||
|
Supports both regular diffusion models and GGUF format models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
NAME = "UNETLoaderLM"
|
||||||
|
CATEGORY = "Lora Manager/loaders"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
# Get list of unet names from scanner (includes extra folder paths)
|
||||||
|
unet_names = s._get_unet_names()
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"unet_name": (
|
||||||
|
unet_names,
|
||||||
|
{"tooltip": "The name of the diffusion model to load."},
|
||||||
|
),
|
||||||
|
"weight_dtype": (
|
||||||
|
["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],
|
||||||
|
{"tooltip": "The dtype to use for the model weights."},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
RETURN_NAMES = ("MODEL",)
|
||||||
|
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",)
|
||||||
|
FUNCTION = "load_unet"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_unet_names(cls) -> List[str]:
|
||||||
|
"""Get list of diffusion model names from scanner cache in ComfyUI format (relative path with extension)"""
|
||||||
|
try:
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def _get_names():
|
||||||
|
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
|
||||||
|
# Get all model roots for calculating relative paths
|
||||||
|
model_roots = scanner.get_model_roots()
|
||||||
|
|
||||||
|
# Filter only diffusion_model type and format names
|
||||||
|
names = []
|
||||||
|
for item in cache.raw_data:
|
||||||
|
if item.get("sub_type") == "diffusion_model":
|
||||||
|
file_path = item.get("file_path", "")
|
||||||
|
if file_path:
|
||||||
|
# Format as ComfyUI-style: "folder/model_name.ext"
|
||||||
|
formatted_name = _format_model_name_for_comfyui(
|
||||||
|
file_path, model_roots
|
||||||
|
)
|
||||||
|
if formatted_name:
|
||||||
|
names.append(formatted_name)
|
||||||
|
|
||||||
|
return sorted(names)
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
def run_in_thread():
|
||||||
|
new_loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(new_loop)
|
||||||
|
try:
|
||||||
|
return new_loop.run_until_complete(_get_names())
|
||||||
|
finally:
|
||||||
|
new_loop.close()
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
future = executor.submit(run_in_thread)
|
||||||
|
return future.result()
|
||||||
|
except RuntimeError:
|
||||||
|
return asyncio.run(_get_names())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting unet names: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def load_unet(self, unet_name: str, weight_dtype: str) -> Tuple:
|
||||||
|
"""Load a diffusion model by name, supporting extra folder paths
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unet_name: The name of the diffusion model to load (format: "folder/model_name.ext")
|
||||||
|
weight_dtype: The dtype to use for model weights
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (MODEL,)
|
||||||
|
"""
|
||||||
|
# Get absolute path from cache using ComfyUI-style name
|
||||||
|
unet_path, metadata = get_checkpoint_info_absolute(unet_name)
|
||||||
|
|
||||||
|
if metadata is None:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Diffusion model '{unet_name}' not found in LoRA Manager cache. "
|
||||||
|
"Make sure the model is indexed and try again."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if it's a GGUF model
|
||||||
|
if unet_path.endswith(".gguf"):
|
||||||
|
return self._load_gguf_unet(unet_path, unet_name, weight_dtype)
|
||||||
|
|
||||||
|
# Load regular diffusion model using ComfyUI's API
|
||||||
|
logger.info(f"Loading diffusion model from: {unet_path}")
|
||||||
|
|
||||||
|
# Build model options based on weight_dtype
|
||||||
|
model_options = {}
|
||||||
|
if weight_dtype == "fp8_e4m3fn":
|
||||||
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
|
elif weight_dtype == "fp8_e4m3fn_fast":
|
||||||
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
|
model_options["fp8_optimizations"] = True
|
||||||
|
elif weight_dtype == "fp8_e5m2":
|
||||||
|
model_options["dtype"] = torch.float8_e5m2
|
||||||
|
|
||||||
|
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||||
|
return (model,)
|
||||||
|
|
||||||
|
def _load_gguf_unet(
|
||||||
|
self, unet_path: str, unet_name: str, weight_dtype: str
|
||||||
|
) -> Tuple:
|
||||||
|
"""Load a GGUF format diffusion model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unet_path: Absolute path to the GGUF file
|
||||||
|
unet_name: Name of the model for error messages
|
||||||
|
weight_dtype: The dtype to use for model weights
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (MODEL,)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try to import ComfyUI-GGUF modules
|
||||||
|
from custom_nodes.ComfyUI_GGUF.loader import gguf_sd_loader
|
||||||
|
from custom_nodes.ComfyUI_GGUF.ops import GGMLOps
|
||||||
|
from custom_nodes.ComfyUI_GGUF.nodes import GGUFModelPatcher
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load GGUF model '{unet_name}'. "
|
||||||
|
"ComfyUI-GGUF is not installed. "
|
||||||
|
"Please install ComfyUI-GGUF from https://github.com/city96/ComfyUI-GGUF "
|
||||||
|
"to load GGUF format models."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Loading GGUF diffusion model from: {unet_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load GGUF state dict
|
||||||
|
sd, extra = gguf_sd_loader(unet_path)
|
||||||
|
|
||||||
|
# Prepare kwargs for metadata if supported
|
||||||
|
kwargs = {}
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
valid_params = inspect.signature(
|
||||||
|
comfy.sd.load_diffusion_model_state_dict
|
||||||
|
).parameters
|
||||||
|
if "metadata" in valid_params:
|
||||||
|
kwargs["metadata"] = extra.get("metadata", {})
|
||||||
|
|
||||||
|
# Setup custom operations with GGUF support
|
||||||
|
ops = GGMLOps()
|
||||||
|
|
||||||
|
# Handle weight_dtype for GGUF models
|
||||||
|
if weight_dtype in ("default", None):
|
||||||
|
ops.Linear.dequant_dtype = None
|
||||||
|
elif weight_dtype in ["target"]:
|
||||||
|
ops.Linear.dequant_dtype = weight_dtype
|
||||||
|
else:
|
||||||
|
ops.Linear.dequant_dtype = getattr(torch, weight_dtype, None)
|
||||||
|
|
||||||
|
# Load the model
|
||||||
|
model = comfy.sd.load_diffusion_model_state_dict(
|
||||||
|
sd, model_options={"custom_operations": ops}, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not detect model type for GGUF diffusion model: {unet_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wrap with GGUFModelPatcher
|
||||||
|
model = GGUFModelPatcher.clone(model)
|
||||||
|
|
||||||
|
return (model,)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading GGUF diffusion model '{unet_name}': {e}")
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to load GGUF diffusion model '{unet_name}': {str(e)}"
|
||||||
|
)
|
||||||
@@ -13,22 +13,35 @@ from .model_hash_index import ModelHashIndex
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CheckpointScanner(ModelScanner):
|
class CheckpointScanner(ModelScanner):
|
||||||
"""Service for scanning and managing checkpoint files"""
|
"""Service for scanning and managing checkpoint files"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Define supported file extensions
|
# Define supported file extensions
|
||||||
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
|
file_extensions = {
|
||||||
|
".ckpt",
|
||||||
|
".pt",
|
||||||
|
".pt2",
|
||||||
|
".bin",
|
||||||
|
".pth",
|
||||||
|
".safetensors",
|
||||||
|
".pkl",
|
||||||
|
".sft",
|
||||||
|
".gguf",
|
||||||
|
}
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_type="checkpoint",
|
model_type="checkpoint",
|
||||||
model_class=CheckpointMetadata,
|
model_class=CheckpointMetadata,
|
||||||
file_extensions=file_extensions,
|
file_extensions=file_extensions,
|
||||||
hash_index=ModelHashIndex()
|
hash_index=ModelHashIndex(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _create_default_metadata(self, file_path: str) -> Optional[CheckpointMetadata]:
|
async def _create_default_metadata(
|
||||||
|
self, file_path: str
|
||||||
|
) -> Optional[CheckpointMetadata]:
|
||||||
"""Create default metadata for checkpoint without calculating hash (lazy hash).
|
"""Create default metadata for checkpoint without calculating hash (lazy hash).
|
||||||
|
|
||||||
Checkpoints are typically large (10GB+), so we skip hash calculation during initial
|
Checkpoints are typically large (10GB+), so we skip hash calculation during initial
|
||||||
scanning to improve startup performance. Hash will be calculated on-demand when
|
scanning to improve startup performance. Hash will be calculated on-demand when
|
||||||
fetching metadata from Civitai.
|
fetching metadata from Civitai.
|
||||||
@@ -38,13 +51,13 @@ class CheckpointScanner(ModelScanner):
|
|||||||
if not os.path.exists(real_path):
|
if not os.path.exists(real_path):
|
||||||
logger.error(f"File not found: {file_path}")
|
logger.error(f"File not found: {file_path}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
dir_path = os.path.dirname(file_path)
|
dir_path = os.path.dirname(file_path)
|
||||||
|
|
||||||
# Find preview image
|
# Find preview image
|
||||||
preview_url = find_preview_file(base_name, dir_path)
|
preview_url = find_preview_file(base_name, dir_path)
|
||||||
|
|
||||||
# Create metadata WITHOUT calculating hash
|
# Create metadata WITHOUT calculating hash
|
||||||
metadata = CheckpointMetadata(
|
metadata = CheckpointMetadata(
|
||||||
file_name=base_name,
|
file_name=base_name,
|
||||||
@@ -59,70 +72,76 @@ class CheckpointScanner(ModelScanner):
|
|||||||
modelDescription="",
|
modelDescription="",
|
||||||
sub_type="checkpoint",
|
sub_type="checkpoint",
|
||||||
from_civitai=False, # Mark as local model since no hash yet
|
from_civitai=False, # Mark as local model since no hash yet
|
||||||
hash_status="pending" # Mark hash as pending
|
hash_status="pending", # Mark hash as pending
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save the created metadata
|
# Save the created metadata
|
||||||
logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}")
|
logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}")
|
||||||
await MetadataManager.save_metadata(file_path, metadata)
|
await MetadataManager.save_metadata(file_path, metadata)
|
||||||
|
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating default checkpoint metadata for {file_path}: {e}")
|
logger.error(
|
||||||
|
f"Error creating default checkpoint metadata for {file_path}: {e}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
|
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
|
||||||
"""Calculate hash for a checkpoint on-demand.
|
"""Calculate hash for a checkpoint on-demand.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: Path to the model file
|
file_path: Path to the model file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SHA256 hash string, or None if calculation failed
|
SHA256 hash string, or None if calculation failed
|
||||||
"""
|
"""
|
||||||
from ..utils.file_utils import calculate_sha256
|
from ..utils.file_utils import calculate_sha256
|
||||||
|
|
||||||
try:
|
try:
|
||||||
real_path = os.path.realpath(file_path)
|
real_path = os.path.realpath(file_path)
|
||||||
if not os.path.exists(real_path):
|
if not os.path.exists(real_path):
|
||||||
logger.error(f"File not found for hash calculation: {file_path}")
|
logger.error(f"File not found for hash calculation: {file_path}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Load current metadata
|
# Load current metadata
|
||||||
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
|
metadata, _ = await MetadataManager.load_metadata(
|
||||||
|
file_path, self.model_class
|
||||||
|
)
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
logger.error(f"No metadata found for {file_path}")
|
logger.error(f"No metadata found for {file_path}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check if hash is already calculated
|
# Check if hash is already calculated
|
||||||
if metadata.hash_status == "completed" and metadata.sha256:
|
if metadata.hash_status == "completed" and metadata.sha256:
|
||||||
return metadata.sha256
|
return metadata.sha256
|
||||||
|
|
||||||
# Update status to calculating
|
# Update status to calculating
|
||||||
metadata.hash_status = "calculating"
|
metadata.hash_status = "calculating"
|
||||||
await MetadataManager.save_metadata(file_path, metadata)
|
await MetadataManager.save_metadata(file_path, metadata)
|
||||||
|
|
||||||
# Calculate hash
|
# Calculate hash
|
||||||
logger.info(f"Calculating hash for checkpoint: {file_path}")
|
logger.info(f"Calculating hash for checkpoint: {file_path}")
|
||||||
sha256 = await calculate_sha256(real_path)
|
sha256 = await calculate_sha256(real_path)
|
||||||
|
|
||||||
# Update metadata with hash
|
# Update metadata with hash
|
||||||
metadata.sha256 = sha256
|
metadata.sha256 = sha256
|
||||||
metadata.hash_status = "completed"
|
metadata.hash_status = "completed"
|
||||||
await MetadataManager.save_metadata(file_path, metadata)
|
await MetadataManager.save_metadata(file_path, metadata)
|
||||||
|
|
||||||
# Update hash index
|
# Update hash index
|
||||||
self._hash_index.add_entry(sha256.lower(), file_path)
|
self._hash_index.add_entry(sha256.lower(), file_path)
|
||||||
|
|
||||||
logger.info(f"Hash calculated for checkpoint: {file_path}")
|
logger.info(f"Hash calculated for checkpoint: {file_path}")
|
||||||
return sha256
|
return sha256
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error calculating hash for {file_path}: {e}")
|
logger.error(f"Error calculating hash for {file_path}: {e}")
|
||||||
# Update status to failed
|
# Update status to failed
|
||||||
try:
|
try:
|
||||||
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
|
metadata, _ = await MetadataManager.load_metadata(
|
||||||
|
file_path, self.model_class
|
||||||
|
)
|
||||||
if metadata:
|
if metadata:
|
||||||
metadata.hash_status = "failed"
|
metadata.hash_status = "failed"
|
||||||
await MetadataManager.save_metadata(file_path, metadata)
|
await MetadataManager.save_metadata(file_path, metadata)
|
||||||
@@ -130,43 +149,46 @@ class CheckpointScanner(ModelScanner):
|
|||||||
pass
|
pass
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def calculate_all_pending_hashes(self, progress_callback=None) -> Dict[str, int]:
|
async def calculate_all_pending_hashes(
|
||||||
|
self, progress_callback=None
|
||||||
|
) -> Dict[str, int]:
|
||||||
"""Calculate hashes for all checkpoints with pending hash status.
|
"""Calculate hashes for all checkpoints with pending hash status.
|
||||||
|
|
||||||
If cache is not initialized, scans filesystem directly for metadata files
|
If cache is not initialized, scans filesystem directly for metadata files
|
||||||
with hash_status != 'completed'.
|
with hash_status != 'completed'.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
progress_callback: Optional callback(progress, total, current_file)
|
progress_callback: Optional callback(progress, total, current_file)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with 'completed', 'failed', 'total' counts
|
Dict with 'completed', 'failed', 'total' counts
|
||||||
"""
|
"""
|
||||||
# Try to get from cache first
|
# Try to get from cache first
|
||||||
cache = await self.get_cached_data()
|
cache = await self.get_cached_data()
|
||||||
|
|
||||||
if cache and cache.raw_data:
|
if cache and cache.raw_data:
|
||||||
# Use cache if available
|
# Use cache if available
|
||||||
pending_models = [
|
pending_models = [
|
||||||
item for item in cache.raw_data
|
item
|
||||||
if item.get('hash_status') != 'completed' or not item.get('sha256')
|
for item in cache.raw_data
|
||||||
|
if item.get("hash_status") != "completed" or not item.get("sha256")
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
# Cache not initialized, scan filesystem directly
|
# Cache not initialized, scan filesystem directly
|
||||||
pending_models = await self._find_pending_models_from_filesystem()
|
pending_models = await self._find_pending_models_from_filesystem()
|
||||||
|
|
||||||
if not pending_models:
|
if not pending_models:
|
||||||
return {'completed': 0, 'failed': 0, 'total': 0}
|
return {"completed": 0, "failed": 0, "total": 0}
|
||||||
|
|
||||||
total = len(pending_models)
|
total = len(pending_models)
|
||||||
completed = 0
|
completed = 0
|
||||||
failed = 0
|
failed = 0
|
||||||
|
|
||||||
for i, model_data in enumerate(pending_models):
|
for i, model_data in enumerate(pending_models):
|
||||||
file_path = model_data.get('file_path')
|
file_path = model_data.get("file_path")
|
||||||
if not file_path:
|
if not file_path:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sha256 = await self.calculate_hash_for_model(file_path)
|
sha256 = await self.calculate_hash_for_model(file_path)
|
||||||
if sha256:
|
if sha256:
|
||||||
@@ -176,77 +198,102 @@ class CheckpointScanner(ModelScanner):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error calculating hash for {file_path}: {e}")
|
logger.error(f"Error calculating hash for {file_path}: {e}")
|
||||||
failed += 1
|
failed += 1
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
try:
|
try:
|
||||||
await progress_callback(i + 1, total, file_path)
|
await progress_callback(i + 1, total, file_path)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return {
|
return {"completed": completed, "failed": failed, "total": total}
|
||||||
'completed': completed,
|
|
||||||
'failed': failed,
|
|
||||||
'total': total
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]:
|
async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]:
|
||||||
"""Scan filesystem for checkpoint metadata files with pending hash status."""
|
"""Scan filesystem for checkpoint metadata files with pending hash status."""
|
||||||
pending_models = []
|
pending_models = []
|
||||||
|
|
||||||
for root_path in self.get_model_roots():
|
for root_path in self.get_model_roots():
|
||||||
if not os.path.exists(root_path):
|
if not os.path.exists(root_path):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for dirpath, _dirnames, filenames in os.walk(root_path):
|
for dirpath, _dirnames, filenames in os.walk(root_path):
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
if not filename.endswith('.metadata.json'):
|
if not filename.endswith(".metadata.json"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
metadata_path = os.path.join(dirpath, filename)
|
metadata_path = os.path.join(dirpath, filename)
|
||||||
try:
|
try:
|
||||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
with open(metadata_path, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
# Check if hash is pending
|
# Check if hash is pending
|
||||||
hash_status = data.get('hash_status', 'completed')
|
hash_status = data.get("hash_status", "completed")
|
||||||
sha256 = data.get('sha256', '')
|
sha256 = data.get("sha256", "")
|
||||||
|
|
||||||
if hash_status != 'completed' or not sha256:
|
if hash_status != "completed" or not sha256:
|
||||||
# Find corresponding model file
|
# Find corresponding model file
|
||||||
model_name = filename.replace('.metadata.json', '')
|
model_name = filename.replace(".metadata.json", "")
|
||||||
model_path = None
|
model_path = None
|
||||||
|
|
||||||
# Look for model file with matching name
|
# Look for model file with matching name
|
||||||
for ext in self.file_extensions:
|
for ext in self.file_extensions:
|
||||||
potential_path = os.path.join(dirpath, model_name + ext)
|
potential_path = os.path.join(dirpath, model_name + ext)
|
||||||
if os.path.exists(potential_path):
|
if os.path.exists(potential_path):
|
||||||
model_path = potential_path
|
model_path = potential_path
|
||||||
break
|
break
|
||||||
|
|
||||||
if model_path:
|
if model_path:
|
||||||
pending_models.append({
|
pending_models.append(
|
||||||
'file_path': model_path.replace(os.sep, '/'),
|
{
|
||||||
'hash_status': hash_status,
|
"file_path": model_path.replace(os.sep, "/"),
|
||||||
'sha256': sha256,
|
"hash_status": hash_status,
|
||||||
**{k: v for k, v in data.items() if k not in ['file_path', 'hash_status', 'sha256']}
|
"sha256": sha256,
|
||||||
})
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in data.items()
|
||||||
|
if k
|
||||||
|
not in [
|
||||||
|
"file_path",
|
||||||
|
"hash_status",
|
||||||
|
"sha256",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
except (json.JSONDecodeError, Exception) as e:
|
except (json.JSONDecodeError, Exception) as e:
|
||||||
logger.debug(f"Error reading metadata file {metadata_path}: {e}")
|
logger.debug(
|
||||||
|
f"Error reading metadata file {metadata_path}: {e}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return pending_models
|
return pending_models
|
||||||
|
|
||||||
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
|
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
|
||||||
"""Resolve the sub-type based on the root path."""
|
"""Resolve the sub-type based on the root path.
|
||||||
|
|
||||||
|
Checks both standard ComfyUI paths and LoRA Manager's extra folder paths.
|
||||||
|
"""
|
||||||
if not root_path:
|
if not root_path:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Check standard ComfyUI checkpoint paths
|
||||||
if config.checkpoints_roots and root_path in config.checkpoints_roots:
|
if config.checkpoints_roots and root_path in config.checkpoints_roots:
|
||||||
return "checkpoint"
|
return "checkpoint"
|
||||||
|
|
||||||
|
# Check extra checkpoint paths
|
||||||
|
if (
|
||||||
|
config.extra_checkpoints_roots
|
||||||
|
and root_path in config.extra_checkpoints_roots
|
||||||
|
):
|
||||||
|
return "checkpoint"
|
||||||
|
|
||||||
|
# Check standard ComfyUI unet paths
|
||||||
if config.unet_roots and root_path in config.unet_roots:
|
if config.unet_roots and root_path in config.unet_roots:
|
||||||
return "diffusion_model"
|
return "diffusion_model"
|
||||||
|
|
||||||
|
# Check extra unet paths
|
||||||
|
if config.extra_unet_roots and root_path in config.extra_unet_roots:
|
||||||
|
return "diffusion_model"
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def adjust_metadata(self, metadata, file_path, root_path):
|
def adjust_metadata(self, metadata, file_path, root_path):
|
||||||
|
|||||||
@@ -112,6 +112,112 @@ def get_lora_info_absolute(lora_name):
|
|||||||
return asyncio.run(_get_lora_info_absolute_async())
|
return asyncio.run(_get_lora_info_absolute_async())
|
||||||
|
|
||||||
|
|
||||||
|
def get_checkpoint_info_absolute(checkpoint_name):
|
||||||
|
"""Get the absolute checkpoint path and metadata from cache
|
||||||
|
|
||||||
|
Supports ComfyUI-style model names (e.g., "folder/model_name.ext")
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_name: The model name, can be:
|
||||||
|
- ComfyUI format: "folder/model_name.safetensors"
|
||||||
|
- Simple name: "model_name"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (absolute_path, metadata) where absolute_path is the full
|
||||||
|
file system path to the checkpoint file, or original checkpoint_name if not found,
|
||||||
|
metadata is the full model metadata dict or None
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def _get_checkpoint_info_absolute_async():
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
|
||||||
|
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
|
||||||
|
# Get model roots for matching
|
||||||
|
model_roots = scanner.get_model_roots()
|
||||||
|
|
||||||
|
# Normalize the checkpoint name
|
||||||
|
normalized_name = checkpoint_name.replace(os.sep, "/")
|
||||||
|
|
||||||
|
for item in cache.raw_data:
|
||||||
|
file_path = item.get("file_path", "")
|
||||||
|
if not file_path:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Format the stored path as ComfyUI-style name
|
||||||
|
formatted_name = _format_model_name_for_comfyui(file_path, model_roots)
|
||||||
|
|
||||||
|
# Match by formatted name
|
||||||
|
if formatted_name == normalized_name or formatted_name == checkpoint_name:
|
||||||
|
return file_path, item
|
||||||
|
|
||||||
|
# Also try matching by basename only (for backward compatibility)
|
||||||
|
file_name = item.get("file_name", "")
|
||||||
|
if (
|
||||||
|
file_name == checkpoint_name
|
||||||
|
or file_name == os.path.splitext(normalized_name)[0]
|
||||||
|
):
|
||||||
|
return file_path, item
|
||||||
|
|
||||||
|
return checkpoint_name, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if we're already in an event loop
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
# If we're in a running loop, we need to use a different approach
|
||||||
|
# Create a new thread to run the async code
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
def run_in_thread():
|
||||||
|
new_loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(new_loop)
|
||||||
|
try:
|
||||||
|
return new_loop.run_until_complete(
|
||||||
|
_get_checkpoint_info_absolute_async()
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
new_loop.close()
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
future = executor.submit(run_in_thread)
|
||||||
|
return future.result()
|
||||||
|
|
||||||
|
except RuntimeError:
|
||||||
|
# No event loop is running, we can use asyncio.run()
|
||||||
|
return asyncio.run(_get_checkpoint_info_absolute_async())
|
||||||
|
|
||||||
|
|
||||||
|
def _format_model_name_for_comfyui(file_path: str, model_roots: list) -> str:
|
||||||
|
"""Format file path to ComfyUI-style model name (relative path with extension)
|
||||||
|
|
||||||
|
Example: /path/to/checkpoints/Illustrious/model.safetensors -> Illustrious/model.safetensors
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Absolute path to the model file
|
||||||
|
model_roots: List of model root directories
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ComfyUI-style model name with relative path and extension
|
||||||
|
"""
|
||||||
|
# Normalize path separators
|
||||||
|
normalized_path = file_path.replace(os.sep, "/")
|
||||||
|
|
||||||
|
# Find the matching root and get relative path
|
||||||
|
for root in model_roots:
|
||||||
|
normalized_root = root.replace(os.sep, "/")
|
||||||
|
# Ensure root ends with / for proper matching
|
||||||
|
if not normalized_root.endswith("/"):
|
||||||
|
normalized_root += "/"
|
||||||
|
|
||||||
|
if normalized_path.startswith(normalized_root):
|
||||||
|
rel_path = normalized_path[len(normalized_root) :]
|
||||||
|
return rel_path
|
||||||
|
|
||||||
|
# If no root matches, just return the basename with extension
|
||||||
|
return os.path.basename(file_path)
|
||||||
|
|
||||||
|
|
||||||
def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool:
|
def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if text matches pattern using fuzzy matching.
|
Check if text matches pattern using fuzzy matching.
|
||||||
|
|||||||
@@ -36,8 +36,8 @@ class TestCheckpointPathOverlap:
|
|||||||
config._preview_root_paths = set()
|
config._preview_root_paths = set()
|
||||||
config._cached_fingerprint = None
|
config._cached_fingerprint = None
|
||||||
|
|
||||||
# Call the method under test
|
# Call the method under test - now returns a tuple
|
||||||
result = config._prepare_checkpoint_paths(
|
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
|
||||||
[str(checkpoints_link)], [str(unet_link)]
|
[str(checkpoints_link)], [str(unet_link)]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,21 +50,27 @@ class TestCheckpointPathOverlap:
|
|||||||
]
|
]
|
||||||
assert len(warning_messages) == 1
|
assert len(warning_messages) == 1
|
||||||
assert "checkpoints" in warning_messages[0].lower()
|
assert "checkpoints" in warning_messages[0].lower()
|
||||||
assert "diffusion_models" in warning_messages[0].lower() or "unet" in warning_messages[0].lower()
|
assert (
|
||||||
|
"diffusion_models" in warning_messages[0].lower()
|
||||||
|
or "unet" in warning_messages[0].lower()
|
||||||
|
)
|
||||||
# Verify warning mentions backward compatibility fallback
|
# Verify warning mentions backward compatibility fallback
|
||||||
assert "falling back" in warning_messages[0].lower() or "backward compatibility" in warning_messages[0].lower()
|
assert (
|
||||||
|
"falling back" in warning_messages[0].lower()
|
||||||
|
or "backward compatibility" in warning_messages[0].lower()
|
||||||
|
)
|
||||||
|
|
||||||
# Verify only one path is returned (deduplication still works)
|
# Verify only one path is returned (deduplication still works)
|
||||||
assert len(result) == 1
|
assert len(all_paths) == 1
|
||||||
# Prioritizes checkpoints path for backward compatibility
|
# Prioritizes checkpoints path for backward compatibility
|
||||||
assert _normalize(result[0]) == _normalize(str(checkpoints_link))
|
assert _normalize(all_paths[0]) == _normalize(str(checkpoints_link))
|
||||||
|
|
||||||
# Verify checkpoints_roots has the path (prioritized)
|
# Verify checkpoint_roots has the path (prioritized)
|
||||||
assert len(config.checkpoints_roots) == 1
|
assert len(checkpoint_roots) == 1
|
||||||
assert _normalize(config.checkpoints_roots[0]) == _normalize(str(checkpoints_link))
|
assert _normalize(checkpoint_roots[0]) == _normalize(str(checkpoints_link))
|
||||||
|
|
||||||
# Verify unet_roots is empty (overlapping paths removed)
|
# Verify unet_roots is empty (overlapping paths removed)
|
||||||
assert config.unet_roots == []
|
assert unet_roots == []
|
||||||
|
|
||||||
def test_non_overlapping_paths_no_warning(
|
def test_non_overlapping_paths_no_warning(
|
||||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
|
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
|
||||||
@@ -83,7 +89,7 @@ class TestCheckpointPathOverlap:
|
|||||||
config._preview_root_paths = set()
|
config._preview_root_paths = set()
|
||||||
config._cached_fingerprint = None
|
config._cached_fingerprint = None
|
||||||
|
|
||||||
result = config._prepare_checkpoint_paths(
|
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
|
||||||
[str(checkpoints_dir)], [str(unet_dir)]
|
[str(checkpoints_dir)], [str(unet_dir)]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -97,14 +103,14 @@ class TestCheckpointPathOverlap:
|
|||||||
assert len(warning_messages) == 0
|
assert len(warning_messages) == 0
|
||||||
|
|
||||||
# Verify both paths are returned
|
# Verify both paths are returned
|
||||||
assert len(result) == 2
|
assert len(all_paths) == 2
|
||||||
normalized_result = [_normalize(p) for p in result]
|
normalized_result = [_normalize(p) for p in all_paths]
|
||||||
assert _normalize(str(checkpoints_dir)) in normalized_result
|
assert _normalize(str(checkpoints_dir)) in normalized_result
|
||||||
assert _normalize(str(unet_dir)) in normalized_result
|
assert _normalize(str(unet_dir)) in normalized_result
|
||||||
|
|
||||||
# Verify both roots are properly set
|
# Verify both roots are properly set
|
||||||
assert len(config.checkpoints_roots) == 1
|
assert len(checkpoint_roots) == 1
|
||||||
assert len(config.unet_roots) == 1
|
assert len(unet_roots) == 1
|
||||||
|
|
||||||
def test_partial_overlap_prioritizes_checkpoints(
|
def test_partial_overlap_prioritizes_checkpoints(
|
||||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
|
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
|
||||||
@@ -129,9 +135,9 @@ class TestCheckpointPathOverlap:
|
|||||||
config._cached_fingerprint = None
|
config._cached_fingerprint = None
|
||||||
|
|
||||||
# One checkpoint path overlaps with one unet path
|
# One checkpoint path overlaps with one unet path
|
||||||
result = config._prepare_checkpoint_paths(
|
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
|
||||||
[str(shared_link), str(separate_checkpoint)],
|
[str(shared_link), str(separate_checkpoint)],
|
||||||
[str(shared_link), str(separate_unet)]
|
[str(shared_link), str(separate_unet)],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify warning was logged for the overlapping path
|
# Verify warning was logged for the overlapping path
|
||||||
@@ -144,17 +150,20 @@ class TestCheckpointPathOverlap:
|
|||||||
assert len(warning_messages) == 1
|
assert len(warning_messages) == 1
|
||||||
|
|
||||||
# Verify 3 unique paths (shared counted once as checkpoint, plus separate ones)
|
# Verify 3 unique paths (shared counted once as checkpoint, plus separate ones)
|
||||||
assert len(result) == 3
|
assert len(all_paths) == 3
|
||||||
|
|
||||||
# Verify the overlapping path appears in warning message
|
# Verify the overlapping path appears in warning message
|
||||||
assert str(shared_link.name) in warning_messages[0] or str(shared_dir.name) in warning_messages[0]
|
assert (
|
||||||
|
str(shared_link.name) in warning_messages[0]
|
||||||
|
or str(shared_dir.name) in warning_messages[0]
|
||||||
|
)
|
||||||
|
|
||||||
# Verify checkpoints_roots includes both checkpoint paths (including the shared one)
|
# Verify checkpoint_roots includes both checkpoint paths (including the shared one)
|
||||||
assert len(config.checkpoints_roots) == 2
|
assert len(checkpoint_roots) == 2
|
||||||
checkpoint_normalized = [_normalize(p) for p in config.checkpoints_roots]
|
checkpoint_normalized = [_normalize(p) for p in checkpoint_roots]
|
||||||
assert _normalize(str(shared_link)) in checkpoint_normalized
|
assert _normalize(str(shared_link)) in checkpoint_normalized
|
||||||
assert _normalize(str(separate_checkpoint)) in checkpoint_normalized
|
assert _normalize(str(separate_checkpoint)) in checkpoint_normalized
|
||||||
|
|
||||||
# Verify unet_roots only includes the non-overlapping unet path
|
# Verify unet_roots only includes the non-overlapping unet path
|
||||||
assert len(config.unet_roots) == 1
|
assert len(unet_roots) == 1
|
||||||
assert _normalize(config.unet_roots[0]) == _normalize(str(separate_unet))
|
assert _normalize(unet_roots[0]) == _normalize(str(separate_unet))
|
||||||
|
|||||||
158
tests/test_checkpoint_loaders.py
Normal file
158
tests/test_checkpoint_loaders.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""Tests for checkpoint and unet loaders with extra folder paths support"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
# Get project root directory (ComfyUI-Lora-Manager folder)
|
||||||
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointLoaderLM:
|
||||||
|
"""Test CheckpointLoaderLM node"""
|
||||||
|
|
||||||
|
def test_class_attributes(self):
|
||||||
|
"""Test that CheckpointLoaderLM has required class attributes"""
|
||||||
|
# Import in a way that doesn't require ComfyUI
|
||||||
|
import ast
|
||||||
|
|
||||||
|
filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "checkpoint_loader.py")
|
||||||
|
|
||||||
|
with open(filepath, "r") as f:
|
||||||
|
tree = ast.parse(f.read())
|
||||||
|
|
||||||
|
# Find CheckpointLoaderLM class
|
||||||
|
classes = {
|
||||||
|
node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
|
||||||
|
}
|
||||||
|
assert "CheckpointLoaderLM" in classes
|
||||||
|
|
||||||
|
cls = classes["CheckpointLoaderLM"]
|
||||||
|
|
||||||
|
# Check for NAME attribute
|
||||||
|
name_attr = [
|
||||||
|
n
|
||||||
|
for n in cls.body
|
||||||
|
if isinstance(n, ast.Assign)
|
||||||
|
and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name))
|
||||||
|
]
|
||||||
|
assert len(name_attr) > 0, "CheckpointLoaderLM should have NAME attribute"
|
||||||
|
|
||||||
|
# Check for CATEGORY attribute
|
||||||
|
cat_attr = [
|
||||||
|
n
|
||||||
|
for n in cls.body
|
||||||
|
if isinstance(n, ast.Assign)
|
||||||
|
and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name))
|
||||||
|
]
|
||||||
|
assert len(cat_attr) > 0, "CheckpointLoaderLM should have CATEGORY attribute"
|
||||||
|
|
||||||
|
# Check for INPUT_TYPES method
|
||||||
|
input_types = [
|
||||||
|
n
|
||||||
|
for n in cls.body
|
||||||
|
if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES"
|
||||||
|
]
|
||||||
|
assert len(input_types) > 0, "CheckpointLoaderLM should have INPUT_TYPES method"
|
||||||
|
|
||||||
|
# Check for load_checkpoint method
|
||||||
|
load_method = [
|
||||||
|
n
|
||||||
|
for n in cls.body
|
||||||
|
if isinstance(n, ast.FunctionDef) and n.name == "load_checkpoint"
|
||||||
|
]
|
||||||
|
assert len(load_method) > 0, (
|
||||||
|
"CheckpointLoaderLM should have load_checkpoint method"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUNETLoaderLM:
|
||||||
|
"""Test UNETLoaderLM node"""
|
||||||
|
|
||||||
|
def test_class_attributes(self):
|
||||||
|
"""Test that UNETLoaderLM has required class attributes"""
|
||||||
|
# Import in a way that doesn't require ComfyUI
|
||||||
|
import ast
|
||||||
|
|
||||||
|
filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "unet_loader.py")
|
||||||
|
|
||||||
|
with open(filepath, "r") as f:
|
||||||
|
tree = ast.parse(f.read())
|
||||||
|
|
||||||
|
# Find UNETLoaderLM class
|
||||||
|
classes = {
|
||||||
|
node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
|
||||||
|
}
|
||||||
|
assert "UNETLoaderLM" in classes
|
||||||
|
|
||||||
|
cls = classes["UNETLoaderLM"]
|
||||||
|
|
||||||
|
# Check for NAME attribute
|
||||||
|
name_attr = [
|
||||||
|
n
|
||||||
|
for n in cls.body
|
||||||
|
if isinstance(n, ast.Assign)
|
||||||
|
and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name))
|
||||||
|
]
|
||||||
|
assert len(name_attr) > 0, "UNETLoaderLM should have NAME attribute"
|
||||||
|
|
||||||
|
# Check for CATEGORY attribute
|
||||||
|
cat_attr = [
|
||||||
|
n
|
||||||
|
for n in cls.body
|
||||||
|
if isinstance(n, ast.Assign)
|
||||||
|
and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name))
|
||||||
|
]
|
||||||
|
assert len(cat_attr) > 0, "UNETLoaderLM should have CATEGORY attribute"
|
||||||
|
|
||||||
|
# Check for INPUT_TYPES method
|
||||||
|
input_types = [
|
||||||
|
n
|
||||||
|
for n in cls.body
|
||||||
|
if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES"
|
||||||
|
]
|
||||||
|
assert len(input_types) > 0, "UNETLoaderLM should have INPUT_TYPES method"
|
||||||
|
|
||||||
|
# Check for load_unet method
|
||||||
|
load_method = [
|
||||||
|
n
|
||||||
|
for n in cls.body
|
||||||
|
if isinstance(n, ast.FunctionDef) and n.name == "load_unet"
|
||||||
|
]
|
||||||
|
assert len(load_method) > 0, "UNETLoaderLM should have load_unet method"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUtils:
|
||||||
|
"""Test utility functions"""
|
||||||
|
|
||||||
|
def test_get_checkpoint_info_absolute_exists(self):
|
||||||
|
"""Test that get_checkpoint_info_absolute function exists in utils"""
|
||||||
|
import ast
|
||||||
|
|
||||||
|
filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py")
|
||||||
|
|
||||||
|
with open(filepath, "r") as f:
|
||||||
|
tree = ast.parse(f.read())
|
||||||
|
|
||||||
|
functions = [
|
||||||
|
node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
|
||||||
|
]
|
||||||
|
assert "get_checkpoint_info_absolute" in functions, (
|
||||||
|
"get_checkpoint_info_absolute should exist"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_format_model_name_for_comfyui_exists(self):
|
||||||
|
"""Test that _format_model_name_for_comfyui function exists in utils"""
|
||||||
|
import ast
|
||||||
|
|
||||||
|
filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py")
|
||||||
|
|
||||||
|
with open(filepath, "r") as f:
|
||||||
|
tree = ast.parse(f.read())
|
||||||
|
|
||||||
|
functions = [
|
||||||
|
node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
|
||||||
|
]
|
||||||
|
assert "_format_model_name_for_comfyui" in functions, (
|
||||||
|
"_format_model_name_for_comfyui should exist"
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user