diff --git a/__init__.py b/__init__.py index 98034007..1966e838 100644 --- a/__init__.py +++ b/__init__.py @@ -1,6 +1,8 @@ try: # pragma: no cover - import fallback for pytest collection from .py.lora_manager import LoraManager from .py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM + from .py.nodes.checkpoint_loader import CheckpointLoaderLM + from .py.nodes.unet_loader import UNETLoaderLM from .py.nodes.trigger_word_toggle import TriggerWordToggleLM from .py.nodes.prompt import PromptLM from .py.nodes.text import TextLM @@ -27,12 +29,12 @@ except ( PromptLM = importlib.import_module("py.nodes.prompt").PromptLM TextLM = importlib.import_module("py.nodes.text").TextLM LoraManager = importlib.import_module("py.lora_manager").LoraManager - LoraLoaderLM = importlib.import_module( - "py.nodes.lora_loader" - ).LoraLoaderLM - LoraTextLoaderLM = importlib.import_module( - "py.nodes.lora_loader" - ).LoraTextLoaderLM + LoraLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraLoaderLM + LoraTextLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraTextLoaderLM + CheckpointLoaderLM = importlib.import_module( + "py.nodes.checkpoint_loader" + ).CheckpointLoaderLM + UNETLoaderLM = importlib.import_module("py.nodes.unet_loader").UNETLoaderLM TriggerWordToggleLM = importlib.import_module( "py.nodes.trigger_word_toggle" ).TriggerWordToggleLM @@ -49,9 +51,7 @@ except ( LoraRandomizerLM = importlib.import_module( "py.nodes.lora_randomizer" ).LoraRandomizerLM - LoraCyclerLM = importlib.import_module( - "py.nodes.lora_cycler" - ).LoraCyclerLM + LoraCyclerLM = importlib.import_module("py.nodes.lora_cycler").LoraCyclerLM init_metadata_collector = importlib.import_module("py.metadata_collector").init NODE_CLASS_MAPPINGS = { @@ -59,6 +59,8 @@ NODE_CLASS_MAPPINGS = { TextLM.NAME: TextLM, LoraLoaderLM.NAME: LoraLoaderLM, LoraTextLoaderLM.NAME: LoraTextLoaderLM, + CheckpointLoaderLM.NAME: CheckpointLoaderLM, + UNETLoaderLM.NAME: UNETLoaderLM, TriggerWordToggleLM.NAME: TriggerWordToggleLM, LoraStackerLM.NAME: LoraStackerLM, SaveImageLM.NAME: SaveImageLM, diff --git a/py/config.py b/py/config.py index bd2d3453..b5ea7b3b 100644 --- a/py/config.py +++ b/py/config.py @@ -707,7 +707,13 @@ class Config: def _prepare_checkpoint_paths( self, checkpoint_paths: Iterable[str], unet_paths: Iterable[str] - ) -> List[str]: + ) -> Tuple[List[str], List[str], List[str]]: + """Prepare checkpoint paths and return (all_roots, checkpoint_roots, unet_roots). + + Returns: + Tuple of (all_unique_paths, checkpoint_only_paths, unet_only_paths) + This method does NOT modify instance variables - callers must set them. + """ checkpoint_map = self._dedupe_existing_paths(checkpoint_paths) unet_map = self._dedupe_existing_paths(unet_paths) @@ -737,8 +743,8 @@ class Config: checkpoint_values = set(checkpoint_map.values()) unet_values = set(unet_map.values()) - self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_values] - self.unet_roots = [p for p in unique_paths if p in unet_values] + checkpoint_roots = [p for p in unique_paths if p in checkpoint_values] + unet_roots = [p for p in unique_paths if p in unet_values] for original_path in unique_paths: real_path = os.path.normpath(os.path.realpath(original_path)).replace( @@ -747,7 +753,7 @@ class Config: if real_path != original_path: self.add_path_mapping(original_path, real_path) - return unique_paths + return unique_paths, checkpoint_roots, unet_roots def _prepare_embedding_paths(self, raw_paths: Iterable[str]) -> List[str]: path_map = self._dedupe_existing_paths(raw_paths) @@ -776,9 +782,11 @@ class Config: embedding_paths = folder_paths.get("embeddings", []) or [] self.loras_roots = self._prepare_lora_paths(lora_paths) - self.base_models_roots = self._prepare_checkpoint_paths( - checkpoint_paths, unet_paths - ) + ( + self.base_models_roots, + self.checkpoints_roots, + self.unet_roots, + ) = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths) self.embeddings_roots = self._prepare_embedding_paths(embedding_paths) # Process extra paths (only for LoRA Manager, not shared with ComfyUI) @@ -789,18 +797,11 @@ class Config: extra_embedding_paths = extra_paths.get("embeddings", []) or [] self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths) - # Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them) - saved_checkpoints_roots = self.checkpoints_roots - saved_unet_roots = self.unet_roots - self.extra_checkpoints_roots = self._prepare_checkpoint_paths( - extra_checkpoint_paths, extra_unet_paths - ) - self.extra_unet_roots = ( - self.unet_roots if self.unet_roots is not None else [] - ) # unet_roots was set by _prepare_checkpoint_paths - # Restore main paths - self.checkpoints_roots = saved_checkpoints_roots - self.unet_roots = saved_unet_roots + ( + _, + self.extra_checkpoints_roots, + self.extra_unet_roots, + ) = self._prepare_checkpoint_paths(extra_checkpoint_paths, extra_unet_paths) self.extra_embeddings_roots = self._prepare_embedding_paths( extra_embedding_paths ) @@ -857,9 +858,11 @@ class Config: try: raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints") raw_unet_paths = folder_paths.get_folder_paths("unet") - unique_paths = self._prepare_checkpoint_paths( - raw_checkpoint_paths, raw_unet_paths - ) + ( + unique_paths, + self.checkpoints_roots, + self.unet_roots, + ) = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths) logger.info( "Found checkpoint roots:" diff --git a/py/nodes/checkpoint_loader.py b/py/nodes/checkpoint_loader.py new file mode 100644 index 00000000..0f5b57fa --- /dev/null +++ b/py/nodes/checkpoint_loader.py @@ -0,0 +1,184 @@ +import logging +import os +from typing import List, Tuple +import comfy.sd +import folder_paths +from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui + +logger = logging.getLogger(__name__) + + +class CheckpointLoaderLM: + """Checkpoint Loader with support for extra folder paths + + Loads checkpoints from both standard ComfyUI folders and LoRA Manager's + extra folder paths, providing a unified interface for checkpoint loading. + """ + + NAME = "CheckpointLoaderLM" + CATEGORY = "Lora Manager/loaders" + + @classmethod + def INPUT_TYPES(s): + # Get list of checkpoint names from scanner (includes extra folder paths) + checkpoint_names = s._get_checkpoint_names() + return { + "required": { + "ckpt_name": ( + checkpoint_names, + {"tooltip": "The name of the checkpoint (model) to load."}, + ), + } + } + + RETURN_TYPES = ("MODEL", "CLIP", "VAE") + RETURN_NAMES = ("MODEL", "CLIP", "VAE") + OUTPUT_TOOLTIPS = ( + "The model used for denoising latents.", + "The CLIP model used for encoding text prompts.", + "The VAE model used for encoding and decoding images to and from latent space.", + ) + FUNCTION = "load_checkpoint" + + @classmethod + def _get_checkpoint_names(cls) -> List[str]: + """Get list of checkpoint names from scanner cache in ComfyUI format (relative path with extension)""" + try: + from ..services.service_registry import ServiceRegistry + import asyncio + + async def _get_names(): + scanner = await ServiceRegistry.get_checkpoint_scanner() + cache = await scanner.get_cached_data() + + # Get all model roots for calculating relative paths + model_roots = scanner.get_model_roots() + + # Filter only checkpoint type (not diffusion_model) and format names + names = [] + for item in cache.raw_data: + if item.get("sub_type") == "checkpoint": + file_path = item.get("file_path", "") + if file_path: + # Format as ComfyUI-style: "folder/model_name.ext" + formatted_name = _format_model_name_for_comfyui( + file_path, model_roots + ) + if formatted_name: + names.append(formatted_name) + + return sorted(names) + + try: + loop = asyncio.get_running_loop() + import concurrent.futures + + def run_in_thread(): + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + return new_loop.run_until_complete(_get_names()) + finally: + new_loop.close() + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_in_thread) + return future.result() + except RuntimeError: + return asyncio.run(_get_names()) + except Exception as e: + logger.error(f"Error getting checkpoint names: {e}") + return [] + + def load_checkpoint(self, ckpt_name: str) -> Tuple: + """Load a checkpoint by name, supporting extra folder paths + + Args: + ckpt_name: The name of the checkpoint to load (format: "folder/model_name.ext") + + Returns: + Tuple of (MODEL, CLIP, VAE) + """ + # Get absolute path from cache using ComfyUI-style name + ckpt_path, metadata = get_checkpoint_info_absolute(ckpt_name) + + if metadata is None: + raise FileNotFoundError( + f"Checkpoint '{ckpt_name}' not found in LoRA Manager cache. " + "Make sure the checkpoint is indexed and try again." + ) + + # Check if it's a GGUF model + if ckpt_path.endswith(".gguf"): + return self._load_gguf_checkpoint(ckpt_path, ckpt_name) + + # Load regular checkpoint using ComfyUI's API + logger.info(f"Loading checkpoint from: {ckpt_path}") + out = comfy.sd.load_checkpoint_guess_config( + ckpt_path, + output_vae=True, + output_clip=True, + embedding_directory=folder_paths.get_folder_paths("embeddings"), + ) + return out[:3] + + def _load_gguf_checkpoint(self, ckpt_path: str, ckpt_name: str) -> Tuple: + """Load a GGUF format checkpoint + + Args: + ckpt_path: Absolute path to the GGUF file + ckpt_name: Name of the checkpoint for error messages + + Returns: + Tuple of (MODEL, CLIP, VAE) - CLIP and VAE may be None for GGUF + """ + try: + # Try to import ComfyUI-GGUF modules + from custom_nodes.ComfyUI_GGUF.loader import gguf_sd_loader + from custom_nodes.ComfyUI_GGUF.ops import GGMLOps + from custom_nodes.ComfyUI_GGUF.nodes import GGUFModelPatcher + except ImportError: + raise RuntimeError( + f"Cannot load GGUF model '{ckpt_name}'. " + "ComfyUI-GGUF is not installed. " + "Please install ComfyUI-GGUF from https://github.com/city96/ComfyUI-GGUF " + "to load GGUF format models." + ) + + logger.info(f"Loading GGUF checkpoint from: {ckpt_path}") + + try: + # Load GGUF state dict + sd, extra = gguf_sd_loader(ckpt_path) + + # Prepare kwargs for metadata if supported + kwargs = {} + import inspect + + valid_params = inspect.signature( + comfy.sd.load_diffusion_model_state_dict + ).parameters + if "metadata" in valid_params: + kwargs["metadata"] = extra.get("metadata", {}) + + # Load the model + model = comfy.sd.load_diffusion_model_state_dict( + sd, model_options={"custom_operations": GGMLOps()}, **kwargs + ) + + if model is None: + raise RuntimeError( + f"Could not detect model type for GGUF checkpoint: {ckpt_path}" + ) + + # Wrap with GGUFModelPatcher + model = GGUFModelPatcher.clone(model) + + # GGUF checkpoints typically don't include CLIP/VAE + return (model, None, None) + + except Exception as e: + logger.error(f"Error loading GGUF checkpoint '{ckpt_name}': {e}") + raise RuntimeError( + f"Failed to load GGUF checkpoint '{ckpt_name}': {str(e)}" + ) diff --git a/py/nodes/unet_loader.py b/py/nodes/unet_loader.py new file mode 100644 index 00000000..4478ee6d --- /dev/null +++ b/py/nodes/unet_loader.py @@ -0,0 +1,205 @@ +import logging +import os +from typing import List, Tuple +import torch +import comfy.sd +from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui + +logger = logging.getLogger(__name__) + + +class UNETLoaderLM: + """UNET Loader with support for extra folder paths + + Loads diffusion models/UNets from both standard ComfyUI folders and LoRA Manager's + extra folder paths, providing a unified interface for UNET loading. + Supports both regular diffusion models and GGUF format models. + """ + + NAME = "UNETLoaderLM" + CATEGORY = "Lora Manager/loaders" + + @classmethod + def INPUT_TYPES(s): + # Get list of unet names from scanner (includes extra folder paths) + unet_names = s._get_unet_names() + return { + "required": { + "unet_name": ( + unet_names, + {"tooltip": "The name of the diffusion model to load."}, + ), + "weight_dtype": ( + ["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], + {"tooltip": "The dtype to use for the model weights."}, + ), + } + } + + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("MODEL",) + OUTPUT_TOOLTIPS = ("The model used for denoising latents.",) + FUNCTION = "load_unet" + + @classmethod + def _get_unet_names(cls) -> List[str]: + """Get list of diffusion model names from scanner cache in ComfyUI format (relative path with extension)""" + try: + from ..services.service_registry import ServiceRegistry + import asyncio + + async def _get_names(): + scanner = await ServiceRegistry.get_checkpoint_scanner() + cache = await scanner.get_cached_data() + + # Get all model roots for calculating relative paths + model_roots = scanner.get_model_roots() + + # Filter only diffusion_model type and format names + names = [] + for item in cache.raw_data: + if item.get("sub_type") == "diffusion_model": + file_path = item.get("file_path", "") + if file_path: + # Format as ComfyUI-style: "folder/model_name.ext" + formatted_name = _format_model_name_for_comfyui( + file_path, model_roots + ) + if formatted_name: + names.append(formatted_name) + + return sorted(names) + + try: + loop = asyncio.get_running_loop() + import concurrent.futures + + def run_in_thread(): + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + return new_loop.run_until_complete(_get_names()) + finally: + new_loop.close() + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_in_thread) + return future.result() + except RuntimeError: + return asyncio.run(_get_names()) + except Exception as e: + logger.error(f"Error getting unet names: {e}") + return [] + + def load_unet(self, unet_name: str, weight_dtype: str) -> Tuple: + """Load a diffusion model by name, supporting extra folder paths + + Args: + unet_name: The name of the diffusion model to load (format: "folder/model_name.ext") + weight_dtype: The dtype to use for model weights + + Returns: + Tuple of (MODEL,) + """ + # Get absolute path from cache using ComfyUI-style name + unet_path, metadata = get_checkpoint_info_absolute(unet_name) + + if metadata is None: + raise FileNotFoundError( + f"Diffusion model '{unet_name}' not found in LoRA Manager cache. " + "Make sure the model is indexed and try again." + ) + + # Check if it's a GGUF model + if unet_path.endswith(".gguf"): + return self._load_gguf_unet(unet_path, unet_name, weight_dtype) + + # Load regular diffusion model using ComfyUI's API + logger.info(f"Loading diffusion model from: {unet_path}") + + # Build model options based on weight_dtype + model_options = {} + if weight_dtype == "fp8_e4m3fn": + model_options["dtype"] = torch.float8_e4m3fn + elif weight_dtype == "fp8_e4m3fn_fast": + model_options["dtype"] = torch.float8_e4m3fn + model_options["fp8_optimizations"] = True + elif weight_dtype == "fp8_e5m2": + model_options["dtype"] = torch.float8_e5m2 + + model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) + return (model,) + + def _load_gguf_unet( + self, unet_path: str, unet_name: str, weight_dtype: str + ) -> Tuple: + """Load a GGUF format diffusion model + + Args: + unet_path: Absolute path to the GGUF file + unet_name: Name of the model for error messages + weight_dtype: The dtype to use for model weights + + Returns: + Tuple of (MODEL,) + """ + try: + # Try to import ComfyUI-GGUF modules + from custom_nodes.ComfyUI_GGUF.loader import gguf_sd_loader + from custom_nodes.ComfyUI_GGUF.ops import GGMLOps + from custom_nodes.ComfyUI_GGUF.nodes import GGUFModelPatcher + except ImportError: + raise RuntimeError( + f"Cannot load GGUF model '{unet_name}'. " + "ComfyUI-GGUF is not installed. " + "Please install ComfyUI-GGUF from https://github.com/city96/ComfyUI-GGUF " + "to load GGUF format models." + ) + + logger.info(f"Loading GGUF diffusion model from: {unet_path}") + + try: + # Load GGUF state dict + sd, extra = gguf_sd_loader(unet_path) + + # Prepare kwargs for metadata if supported + kwargs = {} + import inspect + + valid_params = inspect.signature( + comfy.sd.load_diffusion_model_state_dict + ).parameters + if "metadata" in valid_params: + kwargs["metadata"] = extra.get("metadata", {}) + + # Setup custom operations with GGUF support + ops = GGMLOps() + + # Handle weight_dtype for GGUF models + if weight_dtype in ("default", None): + ops.Linear.dequant_dtype = None + elif weight_dtype in ["target"]: + ops.Linear.dequant_dtype = weight_dtype + else: + ops.Linear.dequant_dtype = getattr(torch, weight_dtype, None) + + # Load the model + model = comfy.sd.load_diffusion_model_state_dict( + sd, model_options={"custom_operations": ops}, **kwargs + ) + + if model is None: + raise RuntimeError( + f"Could not detect model type for GGUF diffusion model: {unet_path}" + ) + + # Wrap with GGUFModelPatcher + model = GGUFModelPatcher.clone(model) + + return (model,) + + except Exception as e: + logger.error(f"Error loading GGUF diffusion model '{unet_name}': {e}") + raise RuntimeError( + f"Failed to load GGUF diffusion model '{unet_name}': {str(e)}" + ) diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index f28fa439..42dc0580 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -13,22 +13,35 @@ from .model_hash_index import ModelHashIndex logger = logging.getLogger(__name__) + class CheckpointScanner(ModelScanner): """Service for scanning and managing checkpoint files""" - + def __init__(self): # Define supported file extensions - file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'} + file_extensions = { + ".ckpt", + ".pt", + ".pt2", + ".bin", + ".pth", + ".safetensors", + ".pkl", + ".sft", + ".gguf", + } super().__init__( model_type="checkpoint", model_class=CheckpointMetadata, file_extensions=file_extensions, - hash_index=ModelHashIndex() + hash_index=ModelHashIndex(), ) - async def _create_default_metadata(self, file_path: str) -> Optional[CheckpointMetadata]: + async def _create_default_metadata( + self, file_path: str + ) -> Optional[CheckpointMetadata]: """Create default metadata for checkpoint without calculating hash (lazy hash). - + Checkpoints are typically large (10GB+), so we skip hash calculation during initial scanning to improve startup performance. Hash will be calculated on-demand when fetching metadata from Civitai. @@ -38,13 +51,13 @@ class CheckpointScanner(ModelScanner): if not os.path.exists(real_path): logger.error(f"File not found: {file_path}") return None - + base_name = os.path.splitext(os.path.basename(file_path))[0] dir_path = os.path.dirname(file_path) - + # Find preview image preview_url = find_preview_file(base_name, dir_path) - + # Create metadata WITHOUT calculating hash metadata = CheckpointMetadata( file_name=base_name, @@ -59,70 +72,76 @@ class CheckpointScanner(ModelScanner): modelDescription="", sub_type="checkpoint", from_civitai=False, # Mark as local model since no hash yet - hash_status="pending" # Mark hash as pending + hash_status="pending", # Mark hash as pending ) - + # Save the created metadata logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}") await MetadataManager.save_metadata(file_path, metadata) - + return metadata - + except Exception as e: - logger.error(f"Error creating default checkpoint metadata for {file_path}: {e}") + logger.error( + f"Error creating default checkpoint metadata for {file_path}: {e}" + ) return None async def calculate_hash_for_model(self, file_path: str) -> Optional[str]: """Calculate hash for a checkpoint on-demand. - + Args: file_path: Path to the model file - + Returns: SHA256 hash string, or None if calculation failed """ from ..utils.file_utils import calculate_sha256 - + try: real_path = os.path.realpath(file_path) if not os.path.exists(real_path): logger.error(f"File not found for hash calculation: {file_path}") return None - + # Load current metadata - metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class) + metadata, _ = await MetadataManager.load_metadata( + file_path, self.model_class + ) if metadata is None: logger.error(f"No metadata found for {file_path}") return None - + # Check if hash is already calculated if metadata.hash_status == "completed" and metadata.sha256: return metadata.sha256 - + # Update status to calculating metadata.hash_status = "calculating" await MetadataManager.save_metadata(file_path, metadata) - + # Calculate hash logger.info(f"Calculating hash for checkpoint: {file_path}") sha256 = await calculate_sha256(real_path) - + # Update metadata with hash metadata.sha256 = sha256 metadata.hash_status = "completed" await MetadataManager.save_metadata(file_path, metadata) - + # Update hash index self._hash_index.add_entry(sha256.lower(), file_path) - + logger.info(f"Hash calculated for checkpoint: {file_path}") return sha256 - + except Exception as e: logger.error(f"Error calculating hash for {file_path}: {e}") # Update status to failed try: - metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class) + metadata, _ = await MetadataManager.load_metadata( + file_path, self.model_class + ) if metadata: metadata.hash_status = "failed" await MetadataManager.save_metadata(file_path, metadata) @@ -130,43 +149,46 @@ class CheckpointScanner(ModelScanner): pass return None - async def calculate_all_pending_hashes(self, progress_callback=None) -> Dict[str, int]: + async def calculate_all_pending_hashes( + self, progress_callback=None + ) -> Dict[str, int]: """Calculate hashes for all checkpoints with pending hash status. - + If cache is not initialized, scans filesystem directly for metadata files with hash_status != 'completed'. - + Args: progress_callback: Optional callback(progress, total, current_file) - + Returns: Dict with 'completed', 'failed', 'total' counts """ # Try to get from cache first cache = await self.get_cached_data() - + if cache and cache.raw_data: # Use cache if available pending_models = [ - item for item in cache.raw_data - if item.get('hash_status') != 'completed' or not item.get('sha256') + item + for item in cache.raw_data + if item.get("hash_status") != "completed" or not item.get("sha256") ] else: # Cache not initialized, scan filesystem directly pending_models = await self._find_pending_models_from_filesystem() - + if not pending_models: - return {'completed': 0, 'failed': 0, 'total': 0} - + return {"completed": 0, "failed": 0, "total": 0} + total = len(pending_models) completed = 0 failed = 0 - + for i, model_data in enumerate(pending_models): - file_path = model_data.get('file_path') + file_path = model_data.get("file_path") if not file_path: continue - + try: sha256 = await self.calculate_hash_for_model(file_path) if sha256: @@ -176,77 +198,102 @@ class CheckpointScanner(ModelScanner): except Exception as e: logger.error(f"Error calculating hash for {file_path}: {e}") failed += 1 - + if progress_callback: try: await progress_callback(i + 1, total, file_path) except Exception: pass - - return { - 'completed': completed, - 'failed': failed, - 'total': total - } - + + return {"completed": completed, "failed": failed, "total": total} + async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]: """Scan filesystem for checkpoint metadata files with pending hash status.""" pending_models = [] - + for root_path in self.get_model_roots(): if not os.path.exists(root_path): continue - + for dirpath, _dirnames, filenames in os.walk(root_path): for filename in filenames: - if not filename.endswith('.metadata.json'): + if not filename.endswith(".metadata.json"): continue - + metadata_path = os.path.join(dirpath, filename) try: - with open(metadata_path, 'r', encoding='utf-8') as f: + with open(metadata_path, "r", encoding="utf-8") as f: data = json.load(f) - + # Check if hash is pending - hash_status = data.get('hash_status', 'completed') - sha256 = data.get('sha256', '') - - if hash_status != 'completed' or not sha256: + hash_status = data.get("hash_status", "completed") + sha256 = data.get("sha256", "") + + if hash_status != "completed" or not sha256: # Find corresponding model file - model_name = filename.replace('.metadata.json', '') + model_name = filename.replace(".metadata.json", "") model_path = None - + # Look for model file with matching name for ext in self.file_extensions: potential_path = os.path.join(dirpath, model_name + ext) if os.path.exists(potential_path): model_path = potential_path break - + if model_path: - pending_models.append({ - 'file_path': model_path.replace(os.sep, '/'), - 'hash_status': hash_status, - 'sha256': sha256, - **{k: v for k, v in data.items() if k not in ['file_path', 'hash_status', 'sha256']} - }) + pending_models.append( + { + "file_path": model_path.replace(os.sep, "/"), + "hash_status": hash_status, + "sha256": sha256, + **{ + k: v + for k, v in data.items() + if k + not in [ + "file_path", + "hash_status", + "sha256", + ] + }, + } + ) except (json.JSONDecodeError, Exception) as e: - logger.debug(f"Error reading metadata file {metadata_path}: {e}") + logger.debug( + f"Error reading metadata file {metadata_path}: {e}" + ) continue - + return pending_models def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]: - """Resolve the sub-type based on the root path.""" + """Resolve the sub-type based on the root path. + + Checks both standard ComfyUI paths and LoRA Manager's extra folder paths. + """ if not root_path: return None + # Check standard ComfyUI checkpoint paths if config.checkpoints_roots and root_path in config.checkpoints_roots: return "checkpoint" + # Check extra checkpoint paths + if ( + config.extra_checkpoints_roots + and root_path in config.extra_checkpoints_roots + ): + return "checkpoint" + + # Check standard ComfyUI unet paths if config.unet_roots and root_path in config.unet_roots: return "diffusion_model" + # Check extra unet paths + if config.extra_unet_roots and root_path in config.extra_unet_roots: + return "diffusion_model" + return None def adjust_metadata(self, metadata, file_path, root_path): diff --git a/py/utils/utils.py b/py/utils/utils.py index 75b19738..d7bba5d9 100644 --- a/py/utils/utils.py +++ b/py/utils/utils.py @@ -112,6 +112,112 @@ def get_lora_info_absolute(lora_name): return asyncio.run(_get_lora_info_absolute_async()) +def get_checkpoint_info_absolute(checkpoint_name): + """Get the absolute checkpoint path and metadata from cache + + Supports ComfyUI-style model names (e.g., "folder/model_name.ext") + + Args: + checkpoint_name: The model name, can be: + - ComfyUI format: "folder/model_name.safetensors" + - Simple name: "model_name" + + Returns: + tuple: (absolute_path, metadata) where absolute_path is the full + file system path to the checkpoint file, or original checkpoint_name if not found, + metadata is the full model metadata dict or None + """ + + async def _get_checkpoint_info_absolute_async(): + from ..services.service_registry import ServiceRegistry + + scanner = await ServiceRegistry.get_checkpoint_scanner() + cache = await scanner.get_cached_data() + + # Get model roots for matching + model_roots = scanner.get_model_roots() + + # Normalize the checkpoint name + normalized_name = checkpoint_name.replace(os.sep, "/") + + for item in cache.raw_data: + file_path = item.get("file_path", "") + if not file_path: + continue + + # Format the stored path as ComfyUI-style name + formatted_name = _format_model_name_for_comfyui(file_path, model_roots) + + # Match by formatted name + if formatted_name == normalized_name or formatted_name == checkpoint_name: + return file_path, item + + # Also try matching by basename only (for backward compatibility) + file_name = item.get("file_name", "") + if ( + file_name == checkpoint_name + or file_name == os.path.splitext(normalized_name)[0] + ): + return file_path, item + + return checkpoint_name, None + + try: + # Check if we're already in an event loop + loop = asyncio.get_running_loop() + # If we're in a running loop, we need to use a different approach + # Create a new thread to run the async code + import concurrent.futures + + def run_in_thread(): + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + return new_loop.run_until_complete( + _get_checkpoint_info_absolute_async() + ) + finally: + new_loop.close() + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_in_thread) + return future.result() + + except RuntimeError: + # No event loop is running, we can use asyncio.run() + return asyncio.run(_get_checkpoint_info_absolute_async()) + + +def _format_model_name_for_comfyui(file_path: str, model_roots: list) -> str: + """Format file path to ComfyUI-style model name (relative path with extension) + + Example: /path/to/checkpoints/Illustrious/model.safetensors -> Illustrious/model.safetensors + + Args: + file_path: Absolute path to the model file + model_roots: List of model root directories + + Returns: + ComfyUI-style model name with relative path and extension + """ + # Normalize path separators + normalized_path = file_path.replace(os.sep, "/") + + # Find the matching root and get relative path + for root in model_roots: + normalized_root = root.replace(os.sep, "/") + # Ensure root ends with / for proper matching + if not normalized_root.endswith("/"): + normalized_root += "/" + + if normalized_path.startswith(normalized_root): + rel_path = normalized_path[len(normalized_root) :] + return rel_path + + # If no root matches, just return the basename with extension + return os.path.basename(file_path) + + def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool: """ Check if text matches pattern using fuzzy matching. diff --git a/tests/config/test_checkpoint_path_overlap.py b/tests/config/test_checkpoint_path_overlap.py index 2019e785..3624e0a0 100644 --- a/tests/config/test_checkpoint_path_overlap.py +++ b/tests/config/test_checkpoint_path_overlap.py @@ -36,8 +36,8 @@ class TestCheckpointPathOverlap: config._preview_root_paths = set() config._cached_fingerprint = None - # Call the method under test - result = config._prepare_checkpoint_paths( + # Call the method under test - now returns a tuple + all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths( [str(checkpoints_link)], [str(unet_link)] ) @@ -50,21 +50,27 @@ class TestCheckpointPathOverlap: ] assert len(warning_messages) == 1 assert "checkpoints" in warning_messages[0].lower() - assert "diffusion_models" in warning_messages[0].lower() or "unet" in warning_messages[0].lower() + assert ( + "diffusion_models" in warning_messages[0].lower() + or "unet" in warning_messages[0].lower() + ) # Verify warning mentions backward compatibility fallback - assert "falling back" in warning_messages[0].lower() or "backward compatibility" in warning_messages[0].lower() + assert ( + "falling back" in warning_messages[0].lower() + or "backward compatibility" in warning_messages[0].lower() + ) # Verify only one path is returned (deduplication still works) - assert len(result) == 1 + assert len(all_paths) == 1 # Prioritizes checkpoints path for backward compatibility - assert _normalize(result[0]) == _normalize(str(checkpoints_link)) + assert _normalize(all_paths[0]) == _normalize(str(checkpoints_link)) - # Verify checkpoints_roots has the path (prioritized) - assert len(config.checkpoints_roots) == 1 - assert _normalize(config.checkpoints_roots[0]) == _normalize(str(checkpoints_link)) + # Verify checkpoint_roots has the path (prioritized) + assert len(checkpoint_roots) == 1 + assert _normalize(checkpoint_roots[0]) == _normalize(str(checkpoints_link)) # Verify unet_roots is empty (overlapping paths removed) - assert config.unet_roots == [] + assert unet_roots == [] def test_non_overlapping_paths_no_warning( self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog @@ -83,7 +89,7 @@ class TestCheckpointPathOverlap: config._preview_root_paths = set() config._cached_fingerprint = None - result = config._prepare_checkpoint_paths( + all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths( [str(checkpoints_dir)], [str(unet_dir)] ) @@ -97,14 +103,14 @@ class TestCheckpointPathOverlap: assert len(warning_messages) == 0 # Verify both paths are returned - assert len(result) == 2 - normalized_result = [_normalize(p) for p in result] + assert len(all_paths) == 2 + normalized_result = [_normalize(p) for p in all_paths] assert _normalize(str(checkpoints_dir)) in normalized_result assert _normalize(str(unet_dir)) in normalized_result # Verify both roots are properly set - assert len(config.checkpoints_roots) == 1 - assert len(config.unet_roots) == 1 + assert len(checkpoint_roots) == 1 + assert len(unet_roots) == 1 def test_partial_overlap_prioritizes_checkpoints( self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog @@ -129,9 +135,9 @@ class TestCheckpointPathOverlap: config._cached_fingerprint = None # One checkpoint path overlaps with one unet path - result = config._prepare_checkpoint_paths( + all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths( [str(shared_link), str(separate_checkpoint)], - [str(shared_link), str(separate_unet)] + [str(shared_link), str(separate_unet)], ) # Verify warning was logged for the overlapping path @@ -144,17 +150,20 @@ class TestCheckpointPathOverlap: assert len(warning_messages) == 1 # Verify 3 unique paths (shared counted once as checkpoint, plus separate ones) - assert len(result) == 3 + assert len(all_paths) == 3 # Verify the overlapping path appears in warning message - assert str(shared_link.name) in warning_messages[0] or str(shared_dir.name) in warning_messages[0] + assert ( + str(shared_link.name) in warning_messages[0] + or str(shared_dir.name) in warning_messages[0] + ) - # Verify checkpoints_roots includes both checkpoint paths (including the shared one) - assert len(config.checkpoints_roots) == 2 - checkpoint_normalized = [_normalize(p) for p in config.checkpoints_roots] + # Verify checkpoint_roots includes both checkpoint paths (including the shared one) + assert len(checkpoint_roots) == 2 + checkpoint_normalized = [_normalize(p) for p in checkpoint_roots] assert _normalize(str(shared_link)) in checkpoint_normalized assert _normalize(str(separate_checkpoint)) in checkpoint_normalized # Verify unet_roots only includes the non-overlapping unet path - assert len(config.unet_roots) == 1 - assert _normalize(config.unet_roots[0]) == _normalize(str(separate_unet)) + assert len(unet_roots) == 1 + assert _normalize(unet_roots[0]) == _normalize(str(separate_unet)) diff --git a/tests/test_checkpoint_loaders.py b/tests/test_checkpoint_loaders.py new file mode 100644 index 00000000..9e544fb3 --- /dev/null +++ b/tests/test_checkpoint_loaders.py @@ -0,0 +1,158 @@ +"""Tests for checkpoint and unet loaders with extra folder paths support""" + +import pytest +import os + + +# Get project root directory (ComfyUI-Lora-Manager folder) +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class TestCheckpointLoaderLM: + """Test CheckpointLoaderLM node""" + + def test_class_attributes(self): + """Test that CheckpointLoaderLM has required class attributes""" + # Import in a way that doesn't require ComfyUI + import ast + + filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "checkpoint_loader.py") + + with open(filepath, "r") as f: + tree = ast.parse(f.read()) + + # Find CheckpointLoaderLM class + classes = { + node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef) + } + assert "CheckpointLoaderLM" in classes + + cls = classes["CheckpointLoaderLM"] + + # Check for NAME attribute + name_attr = [ + n + for n in cls.body + if isinstance(n, ast.Assign) + and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name)) + ] + assert len(name_attr) > 0, "CheckpointLoaderLM should have NAME attribute" + + # Check for CATEGORY attribute + cat_attr = [ + n + for n in cls.body + if isinstance(n, ast.Assign) + and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name)) + ] + assert len(cat_attr) > 0, "CheckpointLoaderLM should have CATEGORY attribute" + + # Check for INPUT_TYPES method + input_types = [ + n + for n in cls.body + if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES" + ] + assert len(input_types) > 0, "CheckpointLoaderLM should have INPUT_TYPES method" + + # Check for load_checkpoint method + load_method = [ + n + for n in cls.body + if isinstance(n, ast.FunctionDef) and n.name == "load_checkpoint" + ] + assert len(load_method) > 0, ( + "CheckpointLoaderLM should have load_checkpoint method" + ) + + +class TestUNETLoaderLM: + """Test UNETLoaderLM node""" + + def test_class_attributes(self): + """Test that UNETLoaderLM has required class attributes""" + # Import in a way that doesn't require ComfyUI + import ast + + filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "unet_loader.py") + + with open(filepath, "r") as f: + tree = ast.parse(f.read()) + + # Find UNETLoaderLM class + classes = { + node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef) + } + assert "UNETLoaderLM" in classes + + cls = classes["UNETLoaderLM"] + + # Check for NAME attribute + name_attr = [ + n + for n in cls.body + if isinstance(n, ast.Assign) + and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name)) + ] + assert len(name_attr) > 0, "UNETLoaderLM should have NAME attribute" + + # Check for CATEGORY attribute + cat_attr = [ + n + for n in cls.body + if isinstance(n, ast.Assign) + and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name)) + ] + assert len(cat_attr) > 0, "UNETLoaderLM should have CATEGORY attribute" + + # Check for INPUT_TYPES method + input_types = [ + n + for n in cls.body + if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES" + ] + assert len(input_types) > 0, "UNETLoaderLM should have INPUT_TYPES method" + + # Check for load_unet method + load_method = [ + n + for n in cls.body + if isinstance(n, ast.FunctionDef) and n.name == "load_unet" + ] + assert len(load_method) > 0, "UNETLoaderLM should have load_unet method" + + +class TestUtils: + """Test utility functions""" + + def test_get_checkpoint_info_absolute_exists(self): + """Test that get_checkpoint_info_absolute function exists in utils""" + import ast + + filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py") + + with open(filepath, "r") as f: + tree = ast.parse(f.read()) + + functions = [ + node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef) + ] + assert "get_checkpoint_info_absolute" in functions, ( + "get_checkpoint_info_absolute should exist" + ) + + def test_format_model_name_for_comfyui_exists(self): + """Test that _format_model_name_for_comfyui function exists in utils""" + import ast + + filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py") + + with open(filepath, "r") as f: + tree = ast.parse(f.read()) + + functions = [ + node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef) + ] + assert "_format_model_name_for_comfyui" in functions, ( + "_format_model_name_for_comfyui should exist" + )