mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
fix: isolate extra unet paths from checkpoints to prevent type misclassification
Refactor _prepare_checkpoint_paths() to return a tuple instead of having side effects on instance variables. This prevents extra unet paths from being incorrectly classified as checkpoints when processing extra paths. - Changed return type from List[str] to Tuple[List[str], List[str], List[str]] (all_paths, checkpoint_roots, unet_roots) - Updated _init_checkpoint_paths() and _apply_library_paths() callers - Fixed extra paths processing to properly isolate main and extra roots - Updated test_checkpoint_path_overlap.py tests for new API This ensures models in extra unet paths are correctly identified as diffusion_model type and don't appear in checkpoints list.
This commit is contained in:
47
py/config.py
47
py/config.py
@@ -707,7 +707,13 @@ class Config:
|
||||
|
||||
def _prepare_checkpoint_paths(
|
||||
self, checkpoint_paths: Iterable[str], unet_paths: Iterable[str]
|
||||
) -> List[str]:
|
||||
) -> Tuple[List[str], List[str], List[str]]:
|
||||
"""Prepare checkpoint paths and return (all_roots, checkpoint_roots, unet_roots).
|
||||
|
||||
Returns:
|
||||
Tuple of (all_unique_paths, checkpoint_only_paths, unet_only_paths)
|
||||
This method does NOT modify instance variables - callers must set them.
|
||||
"""
|
||||
checkpoint_map = self._dedupe_existing_paths(checkpoint_paths)
|
||||
unet_map = self._dedupe_existing_paths(unet_paths)
|
||||
|
||||
@@ -737,8 +743,8 @@ class Config:
|
||||
|
||||
checkpoint_values = set(checkpoint_map.values())
|
||||
unet_values = set(unet_map.values())
|
||||
self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_values]
|
||||
self.unet_roots = [p for p in unique_paths if p in unet_values]
|
||||
checkpoint_roots = [p for p in unique_paths if p in checkpoint_values]
|
||||
unet_roots = [p for p in unique_paths if p in unet_values]
|
||||
|
||||
for original_path in unique_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(
|
||||
@@ -747,7 +753,7 @@ class Config:
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return unique_paths
|
||||
return unique_paths, checkpoint_roots, unet_roots
|
||||
|
||||
def _prepare_embedding_paths(self, raw_paths: Iterable[str]) -> List[str]:
|
||||
path_map = self._dedupe_existing_paths(raw_paths)
|
||||
@@ -776,9 +782,11 @@ class Config:
|
||||
embedding_paths = folder_paths.get("embeddings", []) or []
|
||||
|
||||
self.loras_roots = self._prepare_lora_paths(lora_paths)
|
||||
self.base_models_roots = self._prepare_checkpoint_paths(
|
||||
checkpoint_paths, unet_paths
|
||||
)
|
||||
(
|
||||
self.base_models_roots,
|
||||
self.checkpoints_roots,
|
||||
self.unet_roots,
|
||||
) = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths)
|
||||
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
|
||||
|
||||
# Process extra paths (only for LoRA Manager, not shared with ComfyUI)
|
||||
@@ -789,18 +797,11 @@ class Config:
|
||||
extra_embedding_paths = extra_paths.get("embeddings", []) or []
|
||||
|
||||
self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths)
|
||||
# Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them)
|
||||
saved_checkpoints_roots = self.checkpoints_roots
|
||||
saved_unet_roots = self.unet_roots
|
||||
self.extra_checkpoints_roots = self._prepare_checkpoint_paths(
|
||||
extra_checkpoint_paths, extra_unet_paths
|
||||
)
|
||||
self.extra_unet_roots = (
|
||||
self.unet_roots if self.unet_roots is not None else []
|
||||
) # unet_roots was set by _prepare_checkpoint_paths
|
||||
# Restore main paths
|
||||
self.checkpoints_roots = saved_checkpoints_roots
|
||||
self.unet_roots = saved_unet_roots
|
||||
(
|
||||
_,
|
||||
self.extra_checkpoints_roots,
|
||||
self.extra_unet_roots,
|
||||
) = self._prepare_checkpoint_paths(extra_checkpoint_paths, extra_unet_paths)
|
||||
self.extra_embeddings_roots = self._prepare_embedding_paths(
|
||||
extra_embedding_paths
|
||||
)
|
||||
@@ -857,9 +858,11 @@ class Config:
|
||||
try:
|
||||
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
raw_unet_paths = folder_paths.get_folder_paths("unet")
|
||||
unique_paths = self._prepare_checkpoint_paths(
|
||||
raw_checkpoint_paths, raw_unet_paths
|
||||
)
|
||||
(
|
||||
unique_paths,
|
||||
self.checkpoints_roots,
|
||||
self.unet_roots,
|
||||
) = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths)
|
||||
|
||||
logger.info(
|
||||
"Found checkpoint roots:"
|
||||
|
||||
184
py/nodes/checkpoint_loader.py
Normal file
184
py/nodes/checkpoint_loader.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
import comfy.sd
|
||||
import folder_paths
|
||||
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CheckpointLoaderLM:
|
||||
"""Checkpoint Loader with support for extra folder paths
|
||||
|
||||
Loads checkpoints from both standard ComfyUI folders and LoRA Manager's
|
||||
extra folder paths, providing a unified interface for checkpoint loading.
|
||||
"""
|
||||
|
||||
NAME = "CheckpointLoaderLM"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
# Get list of checkpoint names from scanner (includes extra folder paths)
|
||||
checkpoint_names = s._get_checkpoint_names()
|
||||
return {
|
||||
"required": {
|
||||
"ckpt_name": (
|
||||
checkpoint_names,
|
||||
{"tooltip": "The name of the checkpoint (model) to load."},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "VAE")
|
||||
OUTPUT_TOOLTIPS = (
|
||||
"The model used for denoising latents.",
|
||||
"The CLIP model used for encoding text prompts.",
|
||||
"The VAE model used for encoding and decoding images to and from latent space.",
|
||||
)
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
@classmethod
|
||||
def _get_checkpoint_names(cls) -> List[str]:
|
||||
"""Get list of checkpoint names from scanner cache in ComfyUI format (relative path with extension)"""
|
||||
try:
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
import asyncio
|
||||
|
||||
async def _get_names():
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
# Get all model roots for calculating relative paths
|
||||
model_roots = scanner.get_model_roots()
|
||||
|
||||
# Filter only checkpoint type (not diffusion_model) and format names
|
||||
names = []
|
||||
for item in cache.raw_data:
|
||||
if item.get("sub_type") == "checkpoint":
|
||||
file_path = item.get("file_path", "")
|
||||
if file_path:
|
||||
# Format as ComfyUI-style: "folder/model_name.ext"
|
||||
formatted_name = _format_model_name_for_comfyui(
|
||||
file_path, model_roots
|
||||
)
|
||||
if formatted_name:
|
||||
names.append(formatted_name)
|
||||
|
||||
return sorted(names)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(_get_names())
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(_get_names())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint names: {e}")
|
||||
return []
|
||||
|
||||
def load_checkpoint(self, ckpt_name: str) -> Tuple:
|
||||
"""Load a checkpoint by name, supporting extra folder paths
|
||||
|
||||
Args:
|
||||
ckpt_name: The name of the checkpoint to load (format: "folder/model_name.ext")
|
||||
|
||||
Returns:
|
||||
Tuple of (MODEL, CLIP, VAE)
|
||||
"""
|
||||
# Get absolute path from cache using ComfyUI-style name
|
||||
ckpt_path, metadata = get_checkpoint_info_absolute(ckpt_name)
|
||||
|
||||
if metadata is None:
|
||||
raise FileNotFoundError(
|
||||
f"Checkpoint '{ckpt_name}' not found in LoRA Manager cache. "
|
||||
"Make sure the checkpoint is indexed and try again."
|
||||
)
|
||||
|
||||
# Check if it's a GGUF model
|
||||
if ckpt_path.endswith(".gguf"):
|
||||
return self._load_gguf_checkpoint(ckpt_path, ckpt_name)
|
||||
|
||||
# Load regular checkpoint using ComfyUI's API
|
||||
logger.info(f"Loading checkpoint from: {ckpt_path}")
|
||||
out = comfy.sd.load_checkpoint_guess_config(
|
||||
ckpt_path,
|
||||
output_vae=True,
|
||||
output_clip=True,
|
||||
embedding_directory=folder_paths.get_folder_paths("embeddings"),
|
||||
)
|
||||
return out[:3]
|
||||
|
||||
def _load_gguf_checkpoint(self, ckpt_path: str, ckpt_name: str) -> Tuple:
|
||||
"""Load a GGUF format checkpoint
|
||||
|
||||
Args:
|
||||
ckpt_path: Absolute path to the GGUF file
|
||||
ckpt_name: Name of the checkpoint for error messages
|
||||
|
||||
Returns:
|
||||
Tuple of (MODEL, CLIP, VAE) - CLIP and VAE may be None for GGUF
|
||||
"""
|
||||
try:
|
||||
# Try to import ComfyUI-GGUF modules
|
||||
from custom_nodes.ComfyUI_GGUF.loader import gguf_sd_loader
|
||||
from custom_nodes.ComfyUI_GGUF.ops import GGMLOps
|
||||
from custom_nodes.ComfyUI_GGUF.nodes import GGUFModelPatcher
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load GGUF model '{ckpt_name}'. "
|
||||
"ComfyUI-GGUF is not installed. "
|
||||
"Please install ComfyUI-GGUF from https://github.com/city96/ComfyUI-GGUF "
|
||||
"to load GGUF format models."
|
||||
)
|
||||
|
||||
logger.info(f"Loading GGUF checkpoint from: {ckpt_path}")
|
||||
|
||||
try:
|
||||
# Load GGUF state dict
|
||||
sd, extra = gguf_sd_loader(ckpt_path)
|
||||
|
||||
# Prepare kwargs for metadata if supported
|
||||
kwargs = {}
|
||||
import inspect
|
||||
|
||||
valid_params = inspect.signature(
|
||||
comfy.sd.load_diffusion_model_state_dict
|
||||
).parameters
|
||||
if "metadata" in valid_params:
|
||||
kwargs["metadata"] = extra.get("metadata", {})
|
||||
|
||||
# Load the model
|
||||
model = comfy.sd.load_diffusion_model_state_dict(
|
||||
sd, model_options={"custom_operations": GGMLOps()}, **kwargs
|
||||
)
|
||||
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
f"Could not detect model type for GGUF checkpoint: {ckpt_path}"
|
||||
)
|
||||
|
||||
# Wrap with GGUFModelPatcher
|
||||
model = GGUFModelPatcher.clone(model)
|
||||
|
||||
# GGUF checkpoints typically don't include CLIP/VAE
|
||||
return (model, None, None)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading GGUF checkpoint '{ckpt_name}': {e}")
|
||||
raise RuntimeError(
|
||||
f"Failed to load GGUF checkpoint '{ckpt_name}': {str(e)}"
|
||||
)
|
||||
205
py/nodes/unet_loader.py
Normal file
205
py/nodes/unet_loader.py
Normal file
@@ -0,0 +1,205 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
import torch
|
||||
import comfy.sd
|
||||
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UNETLoaderLM:
|
||||
"""UNET Loader with support for extra folder paths
|
||||
|
||||
Loads diffusion models/UNets from both standard ComfyUI folders and LoRA Manager's
|
||||
extra folder paths, providing a unified interface for UNET loading.
|
||||
Supports both regular diffusion models and GGUF format models.
|
||||
"""
|
||||
|
||||
NAME = "UNETLoaderLM"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
# Get list of unet names from scanner (includes extra folder paths)
|
||||
unet_names = s._get_unet_names()
|
||||
return {
|
||||
"required": {
|
||||
"unet_name": (
|
||||
unet_names,
|
||||
{"tooltip": "The name of the diffusion model to load."},
|
||||
),
|
||||
"weight_dtype": (
|
||||
["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],
|
||||
{"tooltip": "The dtype to use for the model weights."},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("MODEL",)
|
||||
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",)
|
||||
FUNCTION = "load_unet"
|
||||
|
||||
@classmethod
|
||||
def _get_unet_names(cls) -> List[str]:
|
||||
"""Get list of diffusion model names from scanner cache in ComfyUI format (relative path with extension)"""
|
||||
try:
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
import asyncio
|
||||
|
||||
async def _get_names():
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
# Get all model roots for calculating relative paths
|
||||
model_roots = scanner.get_model_roots()
|
||||
|
||||
# Filter only diffusion_model type and format names
|
||||
names = []
|
||||
for item in cache.raw_data:
|
||||
if item.get("sub_type") == "diffusion_model":
|
||||
file_path = item.get("file_path", "")
|
||||
if file_path:
|
||||
# Format as ComfyUI-style: "folder/model_name.ext"
|
||||
formatted_name = _format_model_name_for_comfyui(
|
||||
file_path, model_roots
|
||||
)
|
||||
if formatted_name:
|
||||
names.append(formatted_name)
|
||||
|
||||
return sorted(names)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(_get_names())
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(_get_names())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting unet names: {e}")
|
||||
return []
|
||||
|
||||
def load_unet(self, unet_name: str, weight_dtype: str) -> Tuple:
|
||||
"""Load a diffusion model by name, supporting extra folder paths
|
||||
|
||||
Args:
|
||||
unet_name: The name of the diffusion model to load (format: "folder/model_name.ext")
|
||||
weight_dtype: The dtype to use for model weights
|
||||
|
||||
Returns:
|
||||
Tuple of (MODEL,)
|
||||
"""
|
||||
# Get absolute path from cache using ComfyUI-style name
|
||||
unet_path, metadata = get_checkpoint_info_absolute(unet_name)
|
||||
|
||||
if metadata is None:
|
||||
raise FileNotFoundError(
|
||||
f"Diffusion model '{unet_name}' not found in LoRA Manager cache. "
|
||||
"Make sure the model is indexed and try again."
|
||||
)
|
||||
|
||||
# Check if it's a GGUF model
|
||||
if unet_path.endswith(".gguf"):
|
||||
return self._load_gguf_unet(unet_path, unet_name, weight_dtype)
|
||||
|
||||
# Load regular diffusion model using ComfyUI's API
|
||||
logger.info(f"Loading diffusion model from: {unet_path}")
|
||||
|
||||
# Build model options based on weight_dtype
|
||||
model_options = {}
|
||||
if weight_dtype == "fp8_e4m3fn":
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
elif weight_dtype == "fp8_e4m3fn_fast":
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
model_options["fp8_optimizations"] = True
|
||||
elif weight_dtype == "fp8_e5m2":
|
||||
model_options["dtype"] = torch.float8_e5m2
|
||||
|
||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||
return (model,)
|
||||
|
||||
def _load_gguf_unet(
|
||||
self, unet_path: str, unet_name: str, weight_dtype: str
|
||||
) -> Tuple:
|
||||
"""Load a GGUF format diffusion model
|
||||
|
||||
Args:
|
||||
unet_path: Absolute path to the GGUF file
|
||||
unet_name: Name of the model for error messages
|
||||
weight_dtype: The dtype to use for model weights
|
||||
|
||||
Returns:
|
||||
Tuple of (MODEL,)
|
||||
"""
|
||||
try:
|
||||
# Try to import ComfyUI-GGUF modules
|
||||
from custom_nodes.ComfyUI_GGUF.loader import gguf_sd_loader
|
||||
from custom_nodes.ComfyUI_GGUF.ops import GGMLOps
|
||||
from custom_nodes.ComfyUI_GGUF.nodes import GGUFModelPatcher
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load GGUF model '{unet_name}'. "
|
||||
"ComfyUI-GGUF is not installed. "
|
||||
"Please install ComfyUI-GGUF from https://github.com/city96/ComfyUI-GGUF "
|
||||
"to load GGUF format models."
|
||||
)
|
||||
|
||||
logger.info(f"Loading GGUF diffusion model from: {unet_path}")
|
||||
|
||||
try:
|
||||
# Load GGUF state dict
|
||||
sd, extra = gguf_sd_loader(unet_path)
|
||||
|
||||
# Prepare kwargs for metadata if supported
|
||||
kwargs = {}
|
||||
import inspect
|
||||
|
||||
valid_params = inspect.signature(
|
||||
comfy.sd.load_diffusion_model_state_dict
|
||||
).parameters
|
||||
if "metadata" in valid_params:
|
||||
kwargs["metadata"] = extra.get("metadata", {})
|
||||
|
||||
# Setup custom operations with GGUF support
|
||||
ops = GGMLOps()
|
||||
|
||||
# Handle weight_dtype for GGUF models
|
||||
if weight_dtype in ("default", None):
|
||||
ops.Linear.dequant_dtype = None
|
||||
elif weight_dtype in ["target"]:
|
||||
ops.Linear.dequant_dtype = weight_dtype
|
||||
else:
|
||||
ops.Linear.dequant_dtype = getattr(torch, weight_dtype, None)
|
||||
|
||||
# Load the model
|
||||
model = comfy.sd.load_diffusion_model_state_dict(
|
||||
sd, model_options={"custom_operations": ops}, **kwargs
|
||||
)
|
||||
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
f"Could not detect model type for GGUF diffusion model: {unet_path}"
|
||||
)
|
||||
|
||||
# Wrap with GGUFModelPatcher
|
||||
model = GGUFModelPatcher.clone(model)
|
||||
|
||||
return (model,)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading GGUF diffusion model '{unet_name}': {e}")
|
||||
raise RuntimeError(
|
||||
f"Failed to load GGUF diffusion model '{unet_name}': {str(e)}"
|
||||
)
|
||||
@@ -13,22 +13,35 @@ from .model_hash_index import ModelHashIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CheckpointScanner(ModelScanner):
|
||||
"""Service for scanning and managing checkpoint files"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
|
||||
file_extensions = {
|
||||
".ckpt",
|
||||
".pt",
|
||||
".pt2",
|
||||
".bin",
|
||||
".pth",
|
||||
".safetensors",
|
||||
".pkl",
|
||||
".sft",
|
||||
".gguf",
|
||||
}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
hash_index=ModelHashIndex(),
|
||||
)
|
||||
|
||||
async def _create_default_metadata(self, file_path: str) -> Optional[CheckpointMetadata]:
|
||||
async def _create_default_metadata(
|
||||
self, file_path: str
|
||||
) -> Optional[CheckpointMetadata]:
|
||||
"""Create default metadata for checkpoint without calculating hash (lazy hash).
|
||||
|
||||
|
||||
Checkpoints are typically large (10GB+), so we skip hash calculation during initial
|
||||
scanning to improve startup performance. Hash will be calculated on-demand when
|
||||
fetching metadata from Civitai.
|
||||
@@ -38,13 +51,13 @@ class CheckpointScanner(ModelScanner):
|
||||
if not os.path.exists(real_path):
|
||||
logger.error(f"File not found: {file_path}")
|
||||
return None
|
||||
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
dir_path = os.path.dirname(file_path)
|
||||
|
||||
|
||||
# Find preview image
|
||||
preview_url = find_preview_file(base_name, dir_path)
|
||||
|
||||
|
||||
# Create metadata WITHOUT calculating hash
|
||||
metadata = CheckpointMetadata(
|
||||
file_name=base_name,
|
||||
@@ -59,70 +72,76 @@ class CheckpointScanner(ModelScanner):
|
||||
modelDescription="",
|
||||
sub_type="checkpoint",
|
||||
from_civitai=False, # Mark as local model since no hash yet
|
||||
hash_status="pending" # Mark hash as pending
|
||||
hash_status="pending", # Mark hash as pending
|
||||
)
|
||||
|
||||
|
||||
# Save the created metadata
|
||||
logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}")
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating default checkpoint metadata for {file_path}: {e}")
|
||||
logger.error(
|
||||
f"Error creating default checkpoint metadata for {file_path}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
|
||||
"""Calculate hash for a checkpoint on-demand.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the model file
|
||||
|
||||
|
||||
Returns:
|
||||
SHA256 hash string, or None if calculation failed
|
||||
"""
|
||||
from ..utils.file_utils import calculate_sha256
|
||||
|
||||
|
||||
try:
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not os.path.exists(real_path):
|
||||
logger.error(f"File not found for hash calculation: {file_path}")
|
||||
return None
|
||||
|
||||
|
||||
# Load current metadata
|
||||
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
|
||||
metadata, _ = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
)
|
||||
if metadata is None:
|
||||
logger.error(f"No metadata found for {file_path}")
|
||||
return None
|
||||
|
||||
|
||||
# Check if hash is already calculated
|
||||
if metadata.hash_status == "completed" and metadata.sha256:
|
||||
return metadata.sha256
|
||||
|
||||
|
||||
# Update status to calculating
|
||||
metadata.hash_status = "calculating"
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
|
||||
# Calculate hash
|
||||
logger.info(f"Calculating hash for checkpoint: {file_path}")
|
||||
sha256 = await calculate_sha256(real_path)
|
||||
|
||||
|
||||
# Update metadata with hash
|
||||
metadata.sha256 = sha256
|
||||
metadata.hash_status = "completed"
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
|
||||
# Update hash index
|
||||
self._hash_index.add_entry(sha256.lower(), file_path)
|
||||
|
||||
|
||||
logger.info(f"Hash calculated for checkpoint: {file_path}")
|
||||
return sha256
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating hash for {file_path}: {e}")
|
||||
# Update status to failed
|
||||
try:
|
||||
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
|
||||
metadata, _ = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
)
|
||||
if metadata:
|
||||
metadata.hash_status = "failed"
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
@@ -130,43 +149,46 @@ class CheckpointScanner(ModelScanner):
|
||||
pass
|
||||
return None
|
||||
|
||||
async def calculate_all_pending_hashes(self, progress_callback=None) -> Dict[str, int]:
|
||||
async def calculate_all_pending_hashes(
|
||||
self, progress_callback=None
|
||||
) -> Dict[str, int]:
|
||||
"""Calculate hashes for all checkpoints with pending hash status.
|
||||
|
||||
|
||||
If cache is not initialized, scans filesystem directly for metadata files
|
||||
with hash_status != 'completed'.
|
||||
|
||||
|
||||
Args:
|
||||
progress_callback: Optional callback(progress, total, current_file)
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with 'completed', 'failed', 'total' counts
|
||||
"""
|
||||
# Try to get from cache first
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
|
||||
if cache and cache.raw_data:
|
||||
# Use cache if available
|
||||
pending_models = [
|
||||
item for item in cache.raw_data
|
||||
if item.get('hash_status') != 'completed' or not item.get('sha256')
|
||||
item
|
||||
for item in cache.raw_data
|
||||
if item.get("hash_status") != "completed" or not item.get("sha256")
|
||||
]
|
||||
else:
|
||||
# Cache not initialized, scan filesystem directly
|
||||
pending_models = await self._find_pending_models_from_filesystem()
|
||||
|
||||
|
||||
if not pending_models:
|
||||
return {'completed': 0, 'failed': 0, 'total': 0}
|
||||
|
||||
return {"completed": 0, "failed": 0, "total": 0}
|
||||
|
||||
total = len(pending_models)
|
||||
completed = 0
|
||||
failed = 0
|
||||
|
||||
|
||||
for i, model_data in enumerate(pending_models):
|
||||
file_path = model_data.get('file_path')
|
||||
file_path = model_data.get("file_path")
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
sha256 = await self.calculate_hash_for_model(file_path)
|
||||
if sha256:
|
||||
@@ -176,77 +198,102 @@ class CheckpointScanner(ModelScanner):
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating hash for {file_path}: {e}")
|
||||
failed += 1
|
||||
|
||||
|
||||
if progress_callback:
|
||||
try:
|
||||
await progress_callback(i + 1, total, file_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
'completed': completed,
|
||||
'failed': failed,
|
||||
'total': total
|
||||
}
|
||||
|
||||
|
||||
return {"completed": completed, "failed": failed, "total": total}
|
||||
|
||||
async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]:
|
||||
"""Scan filesystem for checkpoint metadata files with pending hash status."""
|
||||
pending_models = []
|
||||
|
||||
|
||||
for root_path in self.get_model_roots():
|
||||
if not os.path.exists(root_path):
|
||||
continue
|
||||
|
||||
|
||||
for dirpath, _dirnames, filenames in os.walk(root_path):
|
||||
for filename in filenames:
|
||||
if not filename.endswith('.metadata.json'):
|
||||
if not filename.endswith(".metadata.json"):
|
||||
continue
|
||||
|
||||
|
||||
metadata_path = os.path.join(dirpath, filename)
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
with open(metadata_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
# Check if hash is pending
|
||||
hash_status = data.get('hash_status', 'completed')
|
||||
sha256 = data.get('sha256', '')
|
||||
|
||||
if hash_status != 'completed' or not sha256:
|
||||
hash_status = data.get("hash_status", "completed")
|
||||
sha256 = data.get("sha256", "")
|
||||
|
||||
if hash_status != "completed" or not sha256:
|
||||
# Find corresponding model file
|
||||
model_name = filename.replace('.metadata.json', '')
|
||||
model_name = filename.replace(".metadata.json", "")
|
||||
model_path = None
|
||||
|
||||
|
||||
# Look for model file with matching name
|
||||
for ext in self.file_extensions:
|
||||
potential_path = os.path.join(dirpath, model_name + ext)
|
||||
if os.path.exists(potential_path):
|
||||
model_path = potential_path
|
||||
break
|
||||
|
||||
|
||||
if model_path:
|
||||
pending_models.append({
|
||||
'file_path': model_path.replace(os.sep, '/'),
|
||||
'hash_status': hash_status,
|
||||
'sha256': sha256,
|
||||
**{k: v for k, v in data.items() if k not in ['file_path', 'hash_status', 'sha256']}
|
||||
})
|
||||
pending_models.append(
|
||||
{
|
||||
"file_path": model_path.replace(os.sep, "/"),
|
||||
"hash_status": hash_status,
|
||||
"sha256": sha256,
|
||||
**{
|
||||
k: v
|
||||
for k, v in data.items()
|
||||
if k
|
||||
not in [
|
||||
"file_path",
|
||||
"hash_status",
|
||||
"sha256",
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
logger.debug(f"Error reading metadata file {metadata_path}: {e}")
|
||||
logger.debug(
|
||||
f"Error reading metadata file {metadata_path}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
return pending_models
|
||||
|
||||
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
|
||||
"""Resolve the sub-type based on the root path."""
|
||||
"""Resolve the sub-type based on the root path.
|
||||
|
||||
Checks both standard ComfyUI paths and LoRA Manager's extra folder paths.
|
||||
"""
|
||||
if not root_path:
|
||||
return None
|
||||
|
||||
# Check standard ComfyUI checkpoint paths
|
||||
if config.checkpoints_roots and root_path in config.checkpoints_roots:
|
||||
return "checkpoint"
|
||||
|
||||
# Check extra checkpoint paths
|
||||
if (
|
||||
config.extra_checkpoints_roots
|
||||
and root_path in config.extra_checkpoints_roots
|
||||
):
|
||||
return "checkpoint"
|
||||
|
||||
# Check standard ComfyUI unet paths
|
||||
if config.unet_roots and root_path in config.unet_roots:
|
||||
return "diffusion_model"
|
||||
|
||||
# Check extra unet paths
|
||||
if config.extra_unet_roots and root_path in config.extra_unet_roots:
|
||||
return "diffusion_model"
|
||||
|
||||
return None
|
||||
|
||||
def adjust_metadata(self, metadata, file_path, root_path):
|
||||
|
||||
@@ -112,6 +112,112 @@ def get_lora_info_absolute(lora_name):
|
||||
return asyncio.run(_get_lora_info_absolute_async())
|
||||
|
||||
|
||||
def get_checkpoint_info_absolute(checkpoint_name):
|
||||
"""Get the absolute checkpoint path and metadata from cache
|
||||
|
||||
Supports ComfyUI-style model names (e.g., "folder/model_name.ext")
|
||||
|
||||
Args:
|
||||
checkpoint_name: The model name, can be:
|
||||
- ComfyUI format: "folder/model_name.safetensors"
|
||||
- Simple name: "model_name"
|
||||
|
||||
Returns:
|
||||
tuple: (absolute_path, metadata) where absolute_path is the full
|
||||
file system path to the checkpoint file, or original checkpoint_name if not found,
|
||||
metadata is the full model metadata dict or None
|
||||
"""
|
||||
|
||||
async def _get_checkpoint_info_absolute_async():
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
# Get model roots for matching
|
||||
model_roots = scanner.get_model_roots()
|
||||
|
||||
# Normalize the checkpoint name
|
||||
normalized_name = checkpoint_name.replace(os.sep, "/")
|
||||
|
||||
for item in cache.raw_data:
|
||||
file_path = item.get("file_path", "")
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
# Format the stored path as ComfyUI-style name
|
||||
formatted_name = _format_model_name_for_comfyui(file_path, model_roots)
|
||||
|
||||
# Match by formatted name
|
||||
if formatted_name == normalized_name or formatted_name == checkpoint_name:
|
||||
return file_path, item
|
||||
|
||||
# Also try matching by basename only (for backward compatibility)
|
||||
file_name = item.get("file_name", "")
|
||||
if (
|
||||
file_name == checkpoint_name
|
||||
or file_name == os.path.splitext(normalized_name)[0]
|
||||
):
|
||||
return file_path, item
|
||||
|
||||
return checkpoint_name, None
|
||||
|
||||
try:
|
||||
# Check if we're already in an event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in a running loop, we need to use a different approach
|
||||
# Create a new thread to run the async code
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(
|
||||
_get_checkpoint_info_absolute_async()
|
||||
)
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result()
|
||||
|
||||
except RuntimeError:
|
||||
# No event loop is running, we can use asyncio.run()
|
||||
return asyncio.run(_get_checkpoint_info_absolute_async())
|
||||
|
||||
|
||||
def _format_model_name_for_comfyui(file_path: str, model_roots: list) -> str:
|
||||
"""Format file path to ComfyUI-style model name (relative path with extension)
|
||||
|
||||
Example: /path/to/checkpoints/Illustrious/model.safetensors -> Illustrious/model.safetensors
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to the model file
|
||||
model_roots: List of model root directories
|
||||
|
||||
Returns:
|
||||
ComfyUI-style model name with relative path and extension
|
||||
"""
|
||||
# Normalize path separators
|
||||
normalized_path = file_path.replace(os.sep, "/")
|
||||
|
||||
# Find the matching root and get relative path
|
||||
for root in model_roots:
|
||||
normalized_root = root.replace(os.sep, "/")
|
||||
# Ensure root ends with / for proper matching
|
||||
if not normalized_root.endswith("/"):
|
||||
normalized_root += "/"
|
||||
|
||||
if normalized_path.startswith(normalized_root):
|
||||
rel_path = normalized_path[len(normalized_root) :]
|
||||
return rel_path
|
||||
|
||||
# If no root matches, just return the basename with extension
|
||||
return os.path.basename(file_path)
|
||||
|
||||
|
||||
def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool:
|
||||
"""
|
||||
Check if text matches pattern using fuzzy matching.
|
||||
|
||||
Reference in New Issue
Block a user