From 1e8aca4787c3cac7601c765fe632b7a1b188d9b6 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Fri, 27 Mar 2026 19:24:43 +0800 Subject: [PATCH] Add experimental Nunchaku Qwen LoRA support (#873) --- README.md | 1 + py/nodes/lora_loader.py | 367 +++++++++----------- py/nodes/nunchaku_qwen.py | 570 ++++++++++++++++++++++++++++++++ py/nodes/utils.py | 21 ++ tests/nodes/test_lora_loader.py | 176 ++++++++++ 5 files changed, 924 insertions(+), 211 deletions(-) create mode 100644 py/nodes/nunchaku_qwen.py create mode 100644 tests/nodes/test_lora_loader.py diff --git a/README.md b/README.md index d35cd93a..8199772f 100644 --- a/README.md +++ b/README.md @@ -346,6 +346,7 @@ We appreciate your understanding and look forward to potentially accepting code This project has been inspired by and benefited from other excellent ComfyUI extensions: +- [ComfyUI-QwenImageLoraLoader](https://github.com/ussoewwin/ComfyUI-QwenImageLoraLoader) - For the experimental Nunchaku Qwen-Image LoRA support - [ComfyUI-SaveImageWithMetaData](https://github.com/nkchocoai/ComfyUI-SaveImageWithMetaData) - For the image metadata functionality - [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) - For the lora loader functionality diff --git a/py/nodes/lora_loader.py b/py/nodes/lora_loader.py index e6a5830f..3e90bdfc 100644 --- a/py/nodes/lora_loader.py +++ b/py/nodes/lora_loader.py @@ -1,22 +1,127 @@ import logging import re -import comfy.utils # type: ignore -import comfy.sd # type: ignore + +import comfy.sd # type: ignore +import comfy.utils # type: ignore + from ..utils.utils import get_lora_info_absolute -from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora +from .nunchaku_qwen import nunchaku_load_qwen_loras +from .utils import ( + FlexibleOptionalInputType, + any_type, + detect_nunchaku_model_kind, + extract_lora_name, + get_loras_list, + nunchaku_load_lora, +) logger = logging.getLogger(__name__) + +def _collect_stack_entries(lora_stack): + entries = [] + if not lora_stack: + return entries + + for lora_path, model_strength, clip_strength in lora_stack: + lora_name = extract_lora_name(lora_path) + absolute_lora_path, trigger_words = get_lora_info_absolute(lora_name) + entries.append({ + "name": lora_name, + "absolute_path": absolute_lora_path, + "input_path": lora_path, + "model_strength": float(model_strength), + "clip_strength": float(clip_strength), + "trigger_words": trigger_words, + }) + return entries + + +def _collect_widget_entries(kwargs): + entries = [] + for lora in get_loras_list(kwargs): + if not lora.get("active", False): + continue + lora_name = lora["name"] + model_strength = float(lora["strength"]) + clip_strength = float(lora.get("clipStrength", model_strength)) + lora_path, trigger_words = get_lora_info_absolute(lora_name) + entries.append({ + "name": lora_name, + "absolute_path": lora_path, + "input_path": lora_path, + "model_strength": model_strength, + "clip_strength": clip_strength, + "trigger_words": trigger_words, + }) + return entries + + +def _format_loaded_loras(loaded_loras): + formatted_loras = [] + for item in loaded_loras: + if item["include_clip_strength"]: + formatted_loras.append( + f"" + ) + else: + formatted_loras.append(f"") + return " ".join(formatted_loras) + + +def _apply_entries(model, clip, lora_entries, nunchaku_model_kind): + loaded_loras = [] + all_trigger_words = [] + + if nunchaku_model_kind == "qwen_image": + qwen_lora_configs = [] + for entry in lora_entries: + qwen_lora_configs.append((entry["absolute_path"], entry["model_strength"])) + loaded_loras.append({ + "name": entry["name"], + "model_strength": entry["model_strength"], + "clip_strength": entry["model_strength"], + "include_clip_strength": False, + }) + all_trigger_words.extend(entry["trigger_words"]) + if qwen_lora_configs: + model = nunchaku_load_qwen_loras(model, qwen_lora_configs) + return model, clip, loaded_loras, all_trigger_words + + for entry in lora_entries: + if nunchaku_model_kind == "flux": + model = nunchaku_load_lora(model, entry["input_path"], entry["model_strength"]) + else: + lora = comfy.utils.load_torch_file(entry["absolute_path"], safe_load=True) + model, clip = comfy.sd.load_lora_for_models( + model, + clip, + lora, + entry["model_strength"], + entry["clip_strength"], + ) + + include_clip_strength = nunchaku_model_kind is None and abs(entry["model_strength"] - entry["clip_strength"]) > 0.001 + loaded_loras.append({ + "name": entry["name"], + "model_strength": entry["model_strength"], + "clip_strength": entry["clip_strength"], + "include_clip_strength": include_clip_strength, + }) + all_trigger_words.extend(entry["trigger_words"]) + + return model, clip, loaded_loras, all_trigger_words + + class LoraLoaderLM: NAME = "Lora Loader (LoraManager)" CATEGORY = "Lora Manager/loaders" - + @classmethod def INPUT_TYPES(cls): return { "required": { "model": ("MODEL",), - # "clip": ("CLIP",), "text": ("AUTOCOMPLETE_TEXT_LORAS", { "placeholder": "Search LoRAs to add...", "tooltip": "Format: separated by spaces or punctuation", @@ -28,114 +133,30 @@ class LoraLoaderLM: RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING") RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras") FUNCTION = "load_loras" - + def load_loras(self, model, text, **kwargs): """Loads multiple LoRAs based on the kwargs input and lora_stack.""" - loaded_loras = [] - all_trigger_words = [] - - clip = kwargs.get('clip', None) - lora_stack = kwargs.get('lora_stack', None) - - # Check if model is a Nunchaku Flux model - simplified approach - is_nunchaku_model = False - - try: - model_wrapper = model.model.diffusion_model - # Check if model is a Nunchaku Flux model using only class name - if model_wrapper.__class__.__name__ == "ComfyFluxWrapper": - is_nunchaku_model = True - logger.info("Detected Nunchaku Flux model") - except (AttributeError, TypeError): - # Not a model with the expected structure - pass - - # First process lora_stack if available - if lora_stack: - for lora_path, model_strength, clip_strength in lora_stack: - # Extract lora name and convert to absolute path - # lora_stack stores relative paths, but load_torch_file needs absolute paths - lora_name = extract_lora_name(lora_path) - absolute_lora_path, trigger_words = get_lora_info_absolute(lora_name) - - # Apply the LoRA using the appropriate loader - if is_nunchaku_model: - # Use our custom function for Flux models - model = nunchaku_load_lora(model, lora_path, model_strength) - # clip remains unchanged for Nunchaku models - else: - # Use lower-level API to load LoRA directly without folder_paths validation - lora = comfy.utils.load_torch_file(absolute_lora_path, safe_load=True) - model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength) - - all_trigger_words.extend(trigger_words) - # Add clip strength to output if different from model strength (except for Nunchaku models) - if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001: - loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}") - else: - loaded_loras.append(f"{lora_name}: {model_strength}") - - # Then process loras from kwargs with support for both old and new formats - loras_list = get_loras_list(kwargs) - for lora in loras_list: - if not lora.get('active', False): - continue - - lora_name = lora['name'] - model_strength = float(lora['strength']) - # Get clip strength - use model strength as default if not specified - clip_strength = float(lora.get('clipStrength', model_strength)) - - # Get lora path and trigger words - lora_path, trigger_words = get_lora_info_absolute(lora_name) - - # Apply the LoRA using the appropriate loader - if is_nunchaku_model: - # For Nunchaku models, use our custom function - model = nunchaku_load_lora(model, lora_path, model_strength) - # clip remains unchanged - else: - # Use lower-level API to load LoRA directly without folder_paths validation - lora = comfy.utils.load_torch_file(lora_path, safe_load=True) - model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength) - - # Include clip strength in output if different from model strength and not a Nunchaku model - if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001: - loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}") - else: - loaded_loras.append(f"{lora_name}: {model_strength}") - - # Add trigger words to collection - all_trigger_words.extend(trigger_words) - - # use ',, ' to separate trigger words for group mode - trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" - - # Format loaded_loras with support for both formats - formatted_loras = [] - for item in loaded_loras: - parts = item.split(":") - lora_name = parts[0] - strength_parts = parts[1].strip().split(",") - - if len(strength_parts) > 1: - # Different model and clip strengths - model_str = strength_parts[0].strip() - clip_str = strength_parts[1].strip() - formatted_loras.append(f"") - else: - # Same strength for both - model_str = strength_parts[0].strip() - formatted_loras.append(f"") - - formatted_loras_text = " ".join(formatted_loras) + del text + clip = kwargs.get("clip", None) + lora_entries = _collect_stack_entries(kwargs.get("lora_stack", None)) + lora_entries.extend(_collect_widget_entries(kwargs)) + nunchaku_model_kind = detect_nunchaku_model_kind(model) + if nunchaku_model_kind == "flux": + logger.info("Detected Nunchaku Flux model") + elif nunchaku_model_kind == "qwen_image": + logger.info("Detected Nunchaku Qwen-Image model") + + model, clip, loaded_loras, all_trigger_words = _apply_entries(model, clip, lora_entries, nunchaku_model_kind) + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + formatted_loras_text = _format_loaded_loras(loaded_loras) return (model, clip, trigger_words_text, formatted_loras_text) + class LoraTextLoaderLM: NAME = "LoRA Text Loader (LoraManager)" CATEGORY = "Lora Manager/loaders" - + @classmethod def INPUT_TYPES(cls): return { @@ -143,131 +164,55 @@ class LoraTextLoaderLM: "model": ("MODEL",), "lora_syntax": ("STRING", { "forceInput": True, - "tooltip": "Format: separated by spaces or punctuation" + "tooltip": "Format: separated by spaces or punctuation", }), }, "optional": { "clip": ("CLIP",), "lora_stack": ("LORA_STACK",), - } + }, } RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING") RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras") FUNCTION = "load_loras_from_text" - + def parse_lora_syntax(self, text): """Parse LoRA syntax from text input.""" - # Pattern to match or - pattern = r']+):([^:>]+)(?::([^:>]+))?>' + pattern = r"]+):([^:>]+)(?::([^:>]+))?>" matches = re.findall(pattern, text, re.IGNORECASE) - + loras = [] for match in matches: - lora_name = match[0] model_strength = float(match[1]) - clip_strength = float(match[2]) if match[2] else model_strength - loras.append({ - 'name': lora_name, - 'model_strength': model_strength, - 'clip_strength': clip_strength + "name": match[0], + "model_strength": model_strength, + "clip_strength": float(match[2]) if match[2] else model_strength, }) - return loras - + def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None): """Load LoRAs based on text syntax input.""" - loaded_loras = [] - all_trigger_words = [] - - # Check if model is a Nunchaku Flux model - simplified approach - is_nunchaku_model = False - - try: - model_wrapper = model.model.diffusion_model - # Check if model is a Nunchaku Flux model using only class name - if model_wrapper.__class__.__name__ == "ComfyFluxWrapper": - is_nunchaku_model = True - logger.info("Detected Nunchaku Flux model") - except (AttributeError, TypeError): - # Not a model with the expected structure - pass - - # First process lora_stack if available - if lora_stack: - for lora_path, model_strength, clip_strength in lora_stack: - # Extract lora name and convert to absolute path - # lora_stack stores relative paths, but load_torch_file needs absolute paths - lora_name = extract_lora_name(lora_path) - absolute_lora_path, trigger_words = get_lora_info_absolute(lora_name) - - # Apply the LoRA using the appropriate loader - if is_nunchaku_model: - # Use our custom function for Flux models - model = nunchaku_load_lora(model, lora_path, model_strength) - # clip remains unchanged for Nunchaku models - else: - # Use lower-level API to load LoRA directly without folder_paths validation - lora = comfy.utils.load_torch_file(absolute_lora_path, safe_load=True) - model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength) - - all_trigger_words.extend(trigger_words) - # Add clip strength to output if different from model strength (except for Nunchaku models) - if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001: - loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}") - else: - loaded_loras.append(f"{lora_name}: {model_strength}") - - # Parse and process LoRAs from text syntax - parsed_loras = self.parse_lora_syntax(lora_syntax) - for lora in parsed_loras: - lora_name = lora['name'] - model_strength = lora['model_strength'] - clip_strength = lora['clip_strength'] - - # Get lora path and trigger words - lora_path, trigger_words = get_lora_info_absolute(lora_name) - - # Apply the LoRA using the appropriate loader - if is_nunchaku_model: - # For Nunchaku models, use our custom function - model = nunchaku_load_lora(model, lora_path, model_strength) - # clip remains unchanged - else: - # Use lower-level API to load LoRA directly without folder_paths validation - lora = comfy.utils.load_torch_file(lora_path, safe_load=True) - model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength) - - # Include clip strength in output if different from model strength and not a Nunchaku model - if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001: - loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}") - else: - loaded_loras.append(f"{lora_name}: {model_strength}") - - # Add trigger words to collection - all_trigger_words.extend(trigger_words) - - # use ',, ' to separate trigger words for group mode + lora_entries = _collect_stack_entries(lora_stack) + for lora in self.parse_lora_syntax(lora_syntax): + lora_path, trigger_words = get_lora_info_absolute(lora["name"]) + lora_entries.append({ + "name": lora["name"], + "absolute_path": lora_path, + "input_path": lora_path, + "model_strength": lora["model_strength"], + "clip_strength": lora["clip_strength"], + "trigger_words": trigger_words, + }) + + nunchaku_model_kind = detect_nunchaku_model_kind(model) + if nunchaku_model_kind == "flux": + logger.info("Detected Nunchaku Flux model") + elif nunchaku_model_kind == "qwen_image": + logger.info("Detected Nunchaku Qwen-Image model") + + model, clip, loaded_loras, all_trigger_words = _apply_entries(model, clip, lora_entries, nunchaku_model_kind) trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" - - # Format loaded_loras with support for both formats - formatted_loras = [] - for item in loaded_loras: - parts = item.split(":") - lora_name = parts[0].strip() - strength_parts = parts[1].strip().split(",") - - if len(strength_parts) > 1: - # Different model and clip strengths - model_str = strength_parts[0].strip() - clip_str = strength_parts[1].strip() - formatted_loras.append(f"") - else: - # Same strength for both - model_str = strength_parts[0].strip() - formatted_loras.append(f"") - - formatted_loras_text = " ".join(formatted_loras) - - return (model, clip, trigger_words_text, formatted_loras_text) \ No newline at end of file + formatted_loras_text = _format_loaded_loras(loaded_loras) + return (model, clip, trigger_words_text, formatted_loras_text) diff --git a/py/nodes/nunchaku_qwen.py b/py/nodes/nunchaku_qwen.py new file mode 100644 index 00000000..70c2e7d5 --- /dev/null +++ b/py/nodes/nunchaku_qwen.py @@ -0,0 +1,570 @@ +from __future__ import annotations + +"""Qwen-Image LoRA support for Nunchaku models. + +Portions of the LoRA mapping/application logic in this file are adapted from +ComfyUI-QwenImageLoraLoader by GitHub user ussoewwin: +https://github.com/ussoewwin/ComfyUI-QwenImageLoraLoader + +The upstream project is licensed under Apache License 2.0. +""" + +import copy +import logging +import os +import re +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import comfy.utils # type: ignore +import folder_paths # type: ignore +import torch +import torch.nn as nn +from safetensors import safe_open + +from nunchaku.lora.flux.nunchaku_converter import ( + pack_lowrank_weight, + unpack_lowrank_weight, +) + +logger = logging.getLogger(__name__) + +KEY_MAPPING = [ + (re.compile(r"^(layers)[._](\d+)[._]attention[._]to[._]([qkv])$"), r"\1.\2.attention.to_qkv", "qkv", lambda m: m.group(3).upper()), + (re.compile(r"^(layers)[._](\d+)[._]feed_forward[._](w1|w3)$"), r"\1.\2.feed_forward.net.0.proj", "glu", lambda m: m.group(3)), + (re.compile(r"^(layers)[._](\d+)[._]feed_forward[._]w2$"), r"\1.\2.feed_forward.net.2", "regular", None), + (re.compile(r"^(layers)[._](\d+)[._](.*)$"), r"\1.\2.\3", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]to[._]([qkv])$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()), + (re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._](q|k|v)[._]proj$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()), + (re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]add[._](q|k|v)[._]proj$"), r"\1.\2.attn.add_qkv_proj", "add_qkv", lambda m: m.group(3).upper()), + (re.compile(r"^(transformer_blocks)[._](\d+)[._]out[._]proj[._]context$"), r"\1.\2.attn.to_add_out", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._]out[._]proj$"), r"\1.\2.attn.to_out.0", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]to[._]out$"), r"\1.\2.attn.to_out.0", "regular", None), + (re.compile(r"^(single_transformer_blocks)[._](\d+)[._]attn[._]to[._]([qkv])$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()), + (re.compile(r"^(single_transformer_blocks)[._](\d+)[._]attn[._]to[._]out$"), r"\1.\2.attn.to_out", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._]ff[._]net[._]0(?:[._]proj)?$"), r"\1.\2.mlp_fc1", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._]ff[._]net[._]2$"), r"\1.\2.mlp_fc2", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._]ff_context[._]net[._]0(?:[._]proj)?$"), r"\1.\2.mlp_context_fc1", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._]ff_context[._]net[._]2$"), r"\1.\2.mlp_context_fc2", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mlp)[._](net)[._](0)[._](proj)$"), r"\1.\2.\3.\4.\5.\6", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mlp)[._](net)[._](2)$"), r"\1.\2.\3.\4.\5", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mlp)[._](net)[._](0)[._](proj)$"), r"\1.\2.\3.\4.\5.\6", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mlp)[._](net)[._](2)$"), r"\1.\2.\3.\4.\5", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mod)[._](1)$"), r"\1.\2.\3.\4", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mod)[._](1)$"), r"\1.\2.\3.\4", "regular", None), + (re.compile(r"^(single_transformer_blocks)[._](\d+)[._]proj[._]out$"), r"\1.\2.proj_out", "single_proj_out", None), + (re.compile(r"^(single_transformer_blocks)[._](\d+)[._]proj[._]mlp$"), r"\1.\2.mlp_fc1", "regular", None), + (re.compile(r"^(single_transformer_blocks)[._](\d+)[._]norm[._]linear$"), r"\1.\2.norm.linear", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._]norm1[._]linear$"), r"\1.\2.norm1.linear", "regular", None), + (re.compile(r"^(transformer_blocks)[._](\d+)[._]norm1_context[._]linear$"), r"\1.\2.norm1_context.linear", "regular", None), + (re.compile(r"^(img_in)$"), r"\1", "regular", None), + (re.compile(r"^(txt_in)$"), r"\1", "regular", None), + (re.compile(r"^(proj_out)$"), r"\1", "regular", None), + (re.compile(r"^(norm_out)[._](linear)$"), r"\1.\2", "regular", None), + (re.compile(r"^(time_text_embed)[._](timestep_embedder)[._](linear_1)$"), r"\1.\2.\3", "regular", None), + (re.compile(r"^(time_text_embed)[._](timestep_embedder)[._](linear_2)$"), r"\1.\2.\3", "regular", None), +] + +_RE_LORA_SUFFIX = re.compile(r"\.(?Plora(?:[._](?:A|B|down|up)))(?:\.[^.]+)*\.weight$") +_RE_ALPHA_SUFFIX = re.compile(r"\.(?:alpha|lora_alpha)(?:\.[^.]+)*$") + + +def _rename_layer_underscore_layer_name(old_name: str) -> str: + rules = [ + (r"_(\d+)_attn_to_out_(\d+)", r".\1.attn.to_out.\2"), + (r"_(\d+)_img_mlp_net_(\d+)_proj", r".\1.img_mlp.net.\2.proj"), + (r"_(\d+)_txt_mlp_net_(\d+)_proj", r".\1.txt_mlp.net.\2.proj"), + (r"_(\d+)_img_mlp_net_(\d+)", r".\1.img_mlp.net.\2"), + (r"_(\d+)_txt_mlp_net_(\d+)", r".\1.txt_mlp.net.\2"), + (r"_(\d+)_img_mod_(\d+)", r".\1.img_mod.\2"), + (r"_(\d+)_txt_mod_(\d+)", r".\1.txt_mod.\2"), + (r"_(\d+)_attn_", r".\1.attn."), + ] + new_name = old_name + for pattern, replacement in rules: + new_name = re.sub(pattern, replacement, new_name) + return new_name + + +def _is_indexable_module(module): + return isinstance(module, (nn.ModuleList, nn.Sequential, list, tuple)) + + +def _get_module_by_name(model: nn.Module, name: str) -> Optional[nn.Module]: + if not name: + return model + module = model + for part in name.split("."): + if not part: + continue + if hasattr(module, part): + module = getattr(module, part) + elif part.isdigit() and _is_indexable_module(module): + try: + module = module[int(part)] + except (IndexError, TypeError): + return None + else: + return None + return module + + +def _resolve_module_name(model: nn.Module, name: str) -> Tuple[str, Optional[nn.Module]]: + module = _get_module_by_name(model, name) + if module is not None: + return name, module + + replacements = [ + (".attn.to_out.0", ".attn.to_out"), + (".attention.to_qkv", ".attention.qkv"), + (".attention.to_out.0", ".attention.out"), + (".feed_forward.net.0.proj", ".feed_forward.w13"), + (".feed_forward.net.2", ".feed_forward.w2"), + (".ff.net.0.proj", ".mlp_fc1"), + (".ff.net.2", ".mlp_fc2"), + (".ff_context.net.0.proj", ".mlp_context_fc1"), + (".ff_context.net.2", ".mlp_context_fc2"), + ] + for src, dst in replacements: + if src in name: + alt = name.replace(src, dst) + module = _get_module_by_name(model, alt) + if module is not None: + return alt, module + return name, None + + +def _classify_and_map_key(key: str) -> Optional[Tuple[str, str, Optional[str], str]]: + normalized = key + if normalized.startswith("transformer."): + normalized = normalized[len("transformer."):] + if normalized.startswith("diffusion_model."): + normalized = normalized[len("diffusion_model."):] + if normalized.startswith("lora_unet_"): + normalized = _rename_layer_underscore_layer_name(normalized[len("lora_unet_"):]) + + match = _RE_LORA_SUFFIX.search(normalized) + if match: + tag = match.group("tag") + base = normalized[:match.start()] + ab = "A" if ("lora_A" in tag or tag.endswith(".A") or "down" in tag) else "B" + else: + match = _RE_ALPHA_SUFFIX.search(normalized) + if not match: + return None + base = normalized[:match.start()] + ab = "alpha" + + for pattern, template, group, comp_fn in KEY_MAPPING: + key_match = pattern.match(base) + if key_match: + return group, key_match.expand(template), comp_fn(key_match) if comp_fn else None, ab + return None + + +def _detect_lora_format(lora_state_dict: Dict[str, torch.Tensor]) -> bool: + standard_patterns = ( + ".lora_up.", + ".lora_down.", + ".lora_A.", + ".lora_B.", + ".lora.up.", + ".lora.down.", + ".lora.A.", + ".lora.B.", + ) + return any(pattern in key for key in lora_state_dict for pattern in standard_patterns) + + +def _load_lora_state_dict(path_or_dict: Union[str, Path, Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + if isinstance(path_or_dict, dict): + return path_or_dict + path = Path(path_or_dict) + if path.suffix == ".safetensors": + state_dict: Dict[str, torch.Tensor] = {} + with safe_open(path, framework="pt", device="cpu") as handle: + for key in handle.keys(): + state_dict[key] = handle.get_tensor(key) + return state_dict + return comfy.utils.load_torch_file(str(path), safe_load=True) + + +def _fuse_glu_lora(glu_weights: Dict[str, torch.Tensor]) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + if "w1_A" not in glu_weights or "w3_A" not in glu_weights: + return None, None, None + a_w1, b_w1 = glu_weights["w1_A"], glu_weights["w1_B"] + a_w3, b_w3 = glu_weights["w3_A"], glu_weights["w3_B"] + if a_w1.shape[1] != a_w3.shape[1]: + return None, None, None + a_fused = torch.cat([a_w1, a_w3], dim=0) + out1, out3 = b_w1.shape[0], b_w3.shape[0] + rank1, rank3 = b_w1.shape[1], b_w3.shape[1] + b_fused = torch.zeros(out1 + out3, rank1 + rank3, dtype=b_w1.dtype, device=b_w1.device) + b_fused[:out1, :rank1] = b_w1 + b_fused[out1:, rank1:] = b_w3 + return a_fused, b_fused, glu_weights.get("w1_alpha") + + +def _fuse_qkv_lora(qkv_weights: Dict[str, torch.Tensor], model: Optional[nn.Module] = None, base_key: Optional[str] = None) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + required_keys = ["Q_A", "Q_B", "K_A", "K_B", "V_A", "V_B"] + if not all(key in qkv_weights for key in required_keys): + return None, None, None + a_q, a_k, a_v = qkv_weights["Q_A"], qkv_weights["K_A"], qkv_weights["V_A"] + b_q, b_k, b_v = qkv_weights["Q_B"], qkv_weights["K_B"], qkv_weights["V_B"] + if not (a_q.shape == a_k.shape == a_v.shape): + return None, None, None + if not (b_q.shape[1] == b_k.shape[1] == b_v.shape[1]): + return None, None, None + + out_features = None + if model is not None and base_key is not None: + _, module = _resolve_module_name(model, base_key) + out_features = getattr(module, "out_features", None) if module is not None else None + + alpha_fused = None + alpha_q = qkv_weights.get("Q_alpha") + alpha_k = qkv_weights.get("K_alpha") + alpha_v = qkv_weights.get("V_alpha") + if alpha_q is not None and alpha_k is not None and alpha_v is not None and alpha_q.item() == alpha_k.item() == alpha_v.item(): + alpha_fused = alpha_q + + a_fused = torch.cat([a_q, a_k, a_v], dim=0) + rank = b_q.shape[1] + out_q, out_k, out_v = b_q.shape[0], b_k.shape[0], b_v.shape[0] + total_out = out_features if out_features is not None else out_q + out_k + out_v + b_fused = torch.zeros(total_out, 3 * rank, dtype=b_q.dtype, device=b_q.device) + b_fused[:out_q, :rank] = b_q + b_fused[out_q:out_q + out_k, rank:2 * rank] = b_k + b_fused[out_q + out_k:out_q + out_k + out_v, 2 * rank:] = b_v + return a_fused, b_fused, alpha_fused + + +def _handle_proj_out_split(lora_dict: Dict[str, Dict[str, torch.Tensor]], base_key: str, model: nn.Module) -> Tuple[Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], List[str]]: + result: Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]] = {} + consumed: List[str] = [] + match = re.search(r"single_transformer_blocks\.(\d+)", base_key) + if not match or base_key not in lora_dict: + return result, consumed + block_idx = match.group(1) + block = _get_module_by_name(model, f"single_transformer_blocks.{block_idx}") + if block is None: + return result, consumed + a_full = lora_dict[base_key].get("A") + b_full = lora_dict[base_key].get("B") + alpha = lora_dict[base_key].get("alpha") + attn_to_out = getattr(getattr(block, "attn", None), "to_out", None) + mlp_fc2 = getattr(block, "mlp_fc2", None) + if a_full is None or b_full is None or attn_to_out is None or mlp_fc2 is None: + return result, consumed + attn_in = getattr(attn_to_out, "in_features", None) + mlp_in = getattr(mlp_fc2, "in_features", None) + if attn_in is None or mlp_in is None or a_full.shape[1] != attn_in + mlp_in: + return result, consumed + result[f"single_transformer_blocks.{block_idx}.attn.to_out"] = (a_full[:, :attn_in], b_full.clone(), alpha) + result[f"single_transformer_blocks.{block_idx}.mlp_fc2"] = (a_full[:, attn_in:], b_full.clone(), alpha) + consumed.append(base_key) + return result, consumed + + +def _apply_lora_to_module(module: nn.Module, a_tensor: torch.Tensor, b_tensor: torch.Tensor, module_name: str, model: nn.Module) -> None: + if not hasattr(module, "in_features") or not hasattr(module, "out_features"): + raise ValueError(f"{module_name}: unsupported module without in/out features") + if a_tensor.shape[1] != module.in_features or b_tensor.shape[0] != module.out_features: + raise ValueError(f"{module_name}: LoRA shape mismatch") + + if module.__class__.__name__ == "AWQW4A16Linear" and hasattr(module, "qweight"): + if not hasattr(module, "_lora_original_forward"): + module._lora_original_forward = module.forward + if not hasattr(module, "_nunchaku_lora_bundle"): + module._nunchaku_lora_bundle = [] + module._nunchaku_lora_bundle.append((a_tensor, b_tensor)) + + def _awq_lora_forward(x, *args, **kwargs): + out = module._lora_original_forward(x, *args, **kwargs) + x_flat = x.reshape(-1, module.in_features) + for local_a, local_b in module._nunchaku_lora_bundle: + local_a = local_a.to(device=out.device, dtype=out.dtype) + local_b = local_b.to(device=out.device, dtype=out.dtype) + lora_term = (x_flat @ local_a.transpose(0, 1)) @ local_b.transpose(0, 1) + try: + out = out + lora_term.reshape(out.shape) + except Exception: + pass + return out + + module.forward = _awq_lora_forward + if not hasattr(model, "_lora_slots"): + model._lora_slots = {} + model._lora_slots[module_name] = {"type": "awq_w4a16"} + return + + if hasattr(module, "proj_down") and hasattr(module, "proj_up"): + proj_down = unpack_lowrank_weight(module.proj_down.data, down=True) + proj_up = unpack_lowrank_weight(module.proj_up.data, down=False) + base_rank = proj_down.shape[0] if proj_down.shape[1] == module.in_features else proj_down.shape[1] + if proj_down.shape[1] == module.in_features: + updated_down = torch.cat([proj_down, a_tensor], dim=0) + axis_down = 0 + else: + updated_down = torch.cat([proj_down, a_tensor.T], dim=1) + axis_down = 1 + updated_up = torch.cat([proj_up, b_tensor], dim=1) + module.proj_down.data = pack_lowrank_weight(updated_down, down=True) + module.proj_up.data = pack_lowrank_weight(updated_up, down=False) + module.rank = base_rank + a_tensor.shape[0] + if not hasattr(model, "_lora_slots"): + model._lora_slots = {} + model._lora_slots[module_name] = { + "type": "nunchaku", + "base_rank": base_rank, + "axis_down": axis_down, + } + return + + if isinstance(module, nn.Linear): + if not hasattr(model, "_lora_slots"): + model._lora_slots = {} + if module_name not in model._lora_slots: + model._lora_slots[module_name] = { + "type": "linear", + "original_weight": module.weight.detach().cpu().clone(), + } + module.weight.data.add_((b_tensor @ a_tensor).to(dtype=module.weight.dtype, device=module.weight.device)) + return + + raise ValueError(f"{module_name}: unsupported module type {type(module)}") + + +def reset_lora_v2(model: nn.Module) -> None: + slots = getattr(model, "_lora_slots", None) + if not slots: + return + for name, info in list(slots.items()): + module = _get_module_by_name(model, name) + if module is None: + continue + module_type = info.get("type", "nunchaku") + if module_type == "nunchaku": + base_rank = info["base_rank"] + proj_down = unpack_lowrank_weight(module.proj_down.data, down=True) + proj_up = unpack_lowrank_weight(module.proj_up.data, down=False) + if info.get("axis_down", 0) == 0: + proj_down = proj_down[:base_rank, :].clone() + else: + proj_down = proj_down[:, :base_rank].clone() + proj_up = proj_up[:, :base_rank].clone() + module.proj_down.data = pack_lowrank_weight(proj_down, down=True) + module.proj_up.data = pack_lowrank_weight(proj_up, down=False) + module.rank = base_rank + elif module_type == "linear" and "original_weight" in info: + module.weight.data.copy_(info["original_weight"].to(device=module.weight.device, dtype=module.weight.dtype)) + elif module_type == "awq_w4a16": + if hasattr(module, "_lora_original_forward"): + module.forward = module._lora_original_forward + for attr in ("_lora_original_forward", "_nunchaku_lora_bundle"): + if hasattr(module, attr): + delattr(module, attr) + model._lora_slots = {} + + +def compose_loras_v2(model: nn.Module, lora_configs: List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]], apply_awq_mod: bool = True) -> bool: + del apply_awq_mod # retained for interface compatibility + reset_lora_v2(model) + aggregated_weights: Dict[str, List[Dict[str, object]]] = defaultdict(list) + saw_supported_format = False + unresolved_targets = 0 + + for index, (path_or_dict, strength) in enumerate(lora_configs): + if abs(strength) < 1e-5: + continue + lora_name = str(path_or_dict) if not isinstance(path_or_dict, dict) else f"lora_{index}" + lora_state_dict = _load_lora_state_dict(path_or_dict) + if not lora_state_dict or not _detect_lora_format(lora_state_dict): + logger.warning("Skipping unsupported Qwen LoRA: %s", lora_name) + continue + saw_supported_format = True + + grouped_weights: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict) + for key, value in lora_state_dict.items(): + parsed = _classify_and_map_key(key) + if parsed is None: + continue + group, base_key, component, ab = parsed + if component and ab: + grouped_weights[base_key][f"{component}_{ab}"] = value + else: + grouped_weights[base_key][ab] = value + + processed_groups: Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]] = {} + handled: set[str] = set() + for base_key, weights in grouped_weights.items(): + if base_key in handled: + continue + a_tensor = b_tensor = alpha = None + if "qkv" in base_key or "add_qkv_proj" in base_key: + a_tensor, b_tensor, alpha = _fuse_qkv_lora(weights, model=model, base_key=base_key) + elif "w1_A" in weights or "w3_A" in weights: + a_tensor, b_tensor, alpha = _fuse_glu_lora(weights) + elif ".proj_out" in base_key and "single_transformer_blocks" in base_key: + split_map, consumed = _handle_proj_out_split(grouped_weights, base_key, model) + processed_groups.update(split_map) + handled.update(consumed) + continue + else: + a_tensor, b_tensor, alpha = weights.get("A"), weights.get("B"), weights.get("alpha") + if a_tensor is not None and b_tensor is not None: + processed_groups[base_key] = (a_tensor, b_tensor, alpha) + + for module_name, (a_tensor, b_tensor, alpha) in processed_groups.items(): + aggregated_weights[module_name].append({ + "A": a_tensor, + "B": b_tensor, + "alpha": alpha, + "strength": strength, + }) + + for module_name, weight_list in aggregated_weights.items(): + resolved_name, module = _resolve_module_name(model, module_name) + if module is None: + logger.warning("Skipping unresolved Qwen LoRA target: %s", module_name) + unresolved_targets += 1 + continue + all_a = [] + all_b_scaled = [] + for item in weight_list: + a_tensor = item["A"] + b_tensor = item["B"] + alpha = item["alpha"] + strength = float(item["strength"]) + rank = a_tensor.shape[0] + scale = strength * ((alpha / rank) if alpha is not None else 1.0) + if module.__class__.__name__ == "AWQW4A16Linear" and hasattr(module, "qweight"): + target_dtype = torch.float16 + target_device = module.qweight.device + elif hasattr(module, "proj_down"): + target_dtype = module.proj_down.dtype + target_device = module.proj_down.device + elif hasattr(module, "weight"): + target_dtype = module.weight.dtype + target_device = module.weight.device + else: + target_dtype = torch.float16 + target_device = "cuda" if torch.cuda.is_available() else "cpu" + all_a.append(a_tensor.to(dtype=target_dtype, device=target_device)) + all_b_scaled.append((b_tensor * scale).to(dtype=target_dtype, device=target_device)) + if not all_a: + continue + _apply_lora_to_module(module, torch.cat(all_a, dim=0), torch.cat(all_b_scaled, dim=1), resolved_name, model) + + slot_count = len(getattr(model, "_lora_slots", {}) or {}) + logger.info( + "Qwen LoRA composition finished: requested=%d supported=%s applied_targets=%d unresolved=%d", + len(lora_configs), + saw_supported_format, + slot_count, + unresolved_targets, + ) + return saw_supported_format + + +class ComfyQwenImageWrapperLM(nn.Module): + def __init__(self, model: nn.Module, config=None, apply_awq_mod: bool = True): + super().__init__() + self.model = model + self.config = {} if config is None else config + self.dtype = next(model.parameters()).dtype + self.loras: List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]] = [] + self._applied_loras: Optional[List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]]] = None + self.apply_awq_mod = apply_awq_mod + + def __getattr__(self, name): + try: + inner = object.__getattribute__(self, "_modules").get("model") + except (AttributeError, KeyError): + inner = None + if inner is None: + raise AttributeError(f"{type(self).__name__!s} has no attribute {name}") + if name == "model": + return inner + return getattr(inner, name) + + def process_img(self, *args, **kwargs): + return self.model.process_img(*args, **kwargs) + + def _ensure_composed(self): + if self._applied_loras != self.loras or (not self.loras and getattr(self.model, "_lora_slots", None)): + is_supported_format = compose_loras_v2(self.model, self.loras, apply_awq_mod=self.apply_awq_mod) + self._applied_loras = self.loras.copy() + has_slots = bool(getattr(self.model, "_lora_slots", None)) + if self.loras and is_supported_format and not has_slots: + logger.warning("Qwen LoRA compose produced 0 target modules. Resetting and retrying once.") + reset_lora_v2(self.model) + compose_loras_v2(self.model, self.loras, apply_awq_mod=self.apply_awq_mod) + has_slots = bool(getattr(self.model, "_lora_slots", None)) + logger.info("Qwen LoRA retry result: applied_targets=%d", len(getattr(self.model, "_lora_slots", {}) or {})) + + offload_manager = getattr(self.model, "offload_manager", None) + if offload_manager is not None: + offload_settings = { + "num_blocks_on_gpu": getattr(offload_manager, "num_blocks_on_gpu", 1), + "use_pin_memory": getattr(offload_manager, "use_pin_memory", False), + } + logger.info( + "Rebuilding Qwen offload manager after LoRA compose: num_blocks_on_gpu=%s use_pin_memory=%s", + offload_settings["num_blocks_on_gpu"], + offload_settings["use_pin_memory"], + ) + self.model.set_offload(False) + self.model.set_offload(True, **offload_settings) + + def forward(self, *args, **kwargs): + self._ensure_composed() + return self.model(*args, **kwargs) + + +def _get_qwen_wrapper_and_transformer(model): + model_wrapper = model.model.diffusion_model + if hasattr(model_wrapper, "model") and hasattr(model_wrapper, "loras"): + transformer = model_wrapper.model + if transformer.__class__.__name__.endswith("NunchakuQwenImageTransformer2DModel"): + return model_wrapper, transformer + if model_wrapper.__class__.__name__.endswith("NunchakuQwenImageTransformer2DModel"): + wrapped_model = ComfyQwenImageWrapperLM(model_wrapper, getattr(model_wrapper, "config", {})) + model.model.diffusion_model = wrapped_model + return wrapped_model, wrapped_model.model + raise TypeError(f"This LoRA loader only works with Nunchaku Qwen Image models, but got {type(model_wrapper).__name__}.") + + +def nunchaku_load_qwen_loras(model, lora_configs: List[Tuple[str, float]], apply_awq_mod: bool = True): + model_wrapper, transformer = _get_qwen_wrapper_and_transformer(model) + model_wrapper.apply_awq_mod = apply_awq_mod + + saved_config = None + if hasattr(model, "model") and hasattr(model.model, "model_config"): + saved_config = model.model.model_config + model.model.model_config = None + + model_wrapper.model = None + try: + ret_model = copy.deepcopy(model) + finally: + if saved_config is not None: + model.model.model_config = saved_config + model_wrapper.model = transformer + + ret_model_wrapper = ret_model.model.diffusion_model + if saved_config is not None: + ret_model.model.model_config = saved_config + ret_model_wrapper.model = transformer + ret_model_wrapper.apply_awq_mod = apply_awq_mod + ret_model_wrapper.loras = list(getattr(model_wrapper, "loras", [])) + + for lora_name, lora_strength in lora_configs: + lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name) + if not lora_path or not os.path.isfile(lora_path): + logger.warning("Skipping Qwen LoRA '%s' because it could not be found", lora_name) + continue + ret_model_wrapper.loras.append((lora_path, lora_strength)) + + return ret_model diff --git a/py/nodes/utils.py b/py/nodes/utils.py index 89183508..905126c7 100644 --- a/py/nodes/utils.py +++ b/py/nodes/utils.py @@ -158,3 +158,24 @@ def nunchaku_load_lora(model, lora_name, lora_strength): ret_model.model.model_config.unet_config["in_channels"] = new_in_channels return ret_model + + +def detect_nunchaku_model_kind(model): + """Return the supported Nunchaku model kind for a Comfy model, if any.""" + try: + model_wrapper = model.model.diffusion_model + except (AttributeError, TypeError): + return None + + wrapper_name = model_wrapper.__class__.__name__ + if wrapper_name == "ComfyFluxWrapper": + return "flux" + + inner_model = getattr(model_wrapper, "model", None) + inner_name = inner_model.__class__.__name__ if inner_model is not None else "" + if wrapper_name.endswith("NunchakuQwenImageTransformer2DModel"): + return "qwen_image" + if inner_name.endswith("NunchakuQwenImageTransformer2DModel"): + return "qwen_image" + + return None diff --git a/tests/nodes/test_lora_loader.py b/tests/nodes/test_lora_loader.py new file mode 100644 index 00000000..4a43af15 --- /dev/null +++ b/tests/nodes/test_lora_loader.py @@ -0,0 +1,176 @@ +import types + +from py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM + + +class _ModelContainer: + def __init__(self, diffusion_model): + self.diffusion_model = diffusion_model + + +class _Model: + def __init__(self, diffusion_model): + self.model = _ModelContainer(diffusion_model) + + +def test_lora_loader_standard_model_uses_comfy_loader(monkeypatch): + loader = LoraLoaderLM() + model = _Model(object()) + clip = object() + + monkeypatch.setattr( + "py.nodes.lora_loader.get_lora_info_absolute", + lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]), + ) + + load_calls = [] + + def mock_load_torch_file(path, safe_load=True): + load_calls.append((path, safe_load)) + return {"path": path} + + def mock_load_lora_for_models(model_arg, clip_arg, lora_arg, model_strength, clip_strength): + return model_arg, clip_arg + + monkeypatch.setattr("comfy.utils.load_torch_file", mock_load_torch_file) + monkeypatch.setattr("comfy.sd.load_lora_for_models", mock_load_lora_for_models) + + result_model, result_clip, trigger_words, loaded_loras = loader.load_loras( + model, + "", + clip=clip, + loras={ + "__value__": [ + {"active": True, "name": "demo", "strength": 0.75, "clipStrength": 0.5}, + ] + }, + ) + + assert result_model is model + assert result_clip is clip + assert load_calls == [("/abs/demo.safetensors", True)] + assert trigger_words == "demo_trigger" + assert loaded_loras == "" + + +def test_lora_loader_formats_widget_lora_names_with_colons(monkeypatch): + loader = LoraLoaderLM() + model = _Model(object()) + clip = object() + + monkeypatch.setattr( + "py.nodes.lora_loader.get_lora_info_absolute", + lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]), + ) + monkeypatch.setattr("comfy.utils.load_torch_file", lambda path, safe_load=True: {"path": path}) + monkeypatch.setattr( + "comfy.sd.load_lora_for_models", + lambda model_arg, clip_arg, lora_arg, model_strength, clip_strength: (model_arg, clip_arg), + ) + + _, _, trigger_words, loaded_loras = loader.load_loras( + model, + "", + clip=clip, + loras={ + "__value__": [ + {"active": True, "name": "demo:variant", "strength": 0.75, "clipStrength": 0.5}, + {"active": True, "name": "demo:single", "strength": 0.3}, + ] + }, + ) + + assert trigger_words == "demo:variant_trigger,, demo:single_trigger" + assert loaded_loras == " " + + +def test_lora_loader_flux_model_uses_flux_helper(monkeypatch): + flux_model = _Model(type("ComfyFluxWrapper", (), {})()) + loader = LoraLoaderLM() + + monkeypatch.setattr( + "py.nodes.lora_loader.get_lora_info_absolute", + lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]), + ) + + calls = [] + + def mock_nunchaku_load_lora(model_arg, lora_name, strength): + calls.append((lora_name, strength)) + return model_arg + + monkeypatch.setattr("py.nodes.lora_loader.nunchaku_load_lora", mock_nunchaku_load_lora) + + _, _, trigger_words, loaded_loras = loader.load_loras( + flux_model, + "", + lora_stack=[("stack_lora.safetensors", 0.4, 0.2)], + loras={"__value__": [{"active": True, "name": "widget_lora", "strength": 0.8}]}, + ) + + assert calls == [("stack_lora.safetensors", 0.4), ("/abs/widget_lora.safetensors", 0.8)] + assert trigger_words == "stack_lora_trigger,, widget_lora_trigger" + assert loaded_loras == " " + + +def test_lora_loader_qwen_model_batches_loras(monkeypatch): + qwen_model = _Model(type("NunchakuQwenImageTransformer2DModel", (), {})()) + loader = LoraLoaderLM() + + monkeypatch.setattr( + "py.nodes.lora_loader.get_lora_info_absolute", + lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]), + ) + + batched_calls = [] + + def mock_nunchaku_load_qwen_loras(model_arg, lora_configs): + batched_calls.append((model_arg, lora_configs)) + return model_arg + + monkeypatch.setattr("py.nodes.lora_loader.nunchaku_load_qwen_loras", mock_nunchaku_load_qwen_loras) + + _, result_clip, trigger_words, loaded_loras = loader.load_loras( + qwen_model, + "", + clip="clip", + lora_stack=[("stack_qwen.safetensors", 0.6, 0.1)], + loras={"__value__": [{"active": True, "name": "widget_qwen", "strength": 0.9, "clipStrength": 0.3}]}, + ) + + assert result_clip == "clip" + assert len(batched_calls) == 1 + assert batched_calls[0][0] is qwen_model + assert batched_calls[0][1] == [ + ("/abs/stack_qwen.safetensors", 0.6), + ("/abs/widget_qwen.safetensors", 0.9), + ] + assert trigger_words == "stack_qwen_trigger,, widget_qwen_trigger" + assert loaded_loras == " " + + +def test_lora_text_loader_qwen_batches_text_and_stack(monkeypatch): + qwen_model = _Model(type("NunchakuQwenImageTransformer2DModel", (), {})()) + loader = LoraTextLoaderLM() + + monkeypatch.setattr( + "py.nodes.lora_loader.get_lora_info_absolute", + lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]), + ) + + batched_calls = [] + monkeypatch.setattr( + "py.nodes.lora_loader.nunchaku_load_qwen_loras", + lambda model_arg, lora_configs: batched_calls.append(lora_configs) or model_arg, + ) + + _, _, trigger_words, loaded_loras = loader.load_loras_from_text( + qwen_model, + "", + clip="clip", + lora_stack=[("stack_qwen.safetensors", 0.6, 0.1)], + ) + + assert batched_calls == [[("/abs/stack_qwen.safetensors", 0.6), ("/abs/text_qwen.safetensors", 1.2)]] + assert trigger_words == "stack_qwen_trigger,, text_qwen_trigger" + assert loaded_loras == " "