mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-28 00:18:52 -03:00
Add experimental Nunchaku Qwen LoRA support (#873)
This commit is contained in:
@@ -346,6 +346,7 @@ We appreciate your understanding and look forward to potentially accepting code
|
||||
|
||||
This project has been inspired by and benefited from other excellent ComfyUI extensions:
|
||||
|
||||
- [ComfyUI-QwenImageLoraLoader](https://github.com/ussoewwin/ComfyUI-QwenImageLoraLoader) - For the experimental Nunchaku Qwen-Image LoRA support
|
||||
- [ComfyUI-SaveImageWithMetaData](https://github.com/nkchocoai/ComfyUI-SaveImageWithMetaData) - For the image metadata functionality
|
||||
- [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) - For the lora loader functionality
|
||||
|
||||
|
||||
@@ -1,22 +1,127 @@
|
||||
import logging
|
||||
import re
|
||||
import comfy.utils # type: ignore
|
||||
import comfy.sd # type: ignore
|
||||
|
||||
import comfy.sd # type: ignore
|
||||
import comfy.utils # type: ignore
|
||||
|
||||
from ..utils.utils import get_lora_info_absolute
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
|
||||
from .nunchaku_qwen import nunchaku_load_qwen_loras
|
||||
from .utils import (
|
||||
FlexibleOptionalInputType,
|
||||
any_type,
|
||||
detect_nunchaku_model_kind,
|
||||
extract_lora_name,
|
||||
get_loras_list,
|
||||
nunchaku_load_lora,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _collect_stack_entries(lora_stack):
|
||||
entries = []
|
||||
if not lora_stack:
|
||||
return entries
|
||||
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
absolute_lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||
entries.append({
|
||||
"name": lora_name,
|
||||
"absolute_path": absolute_lora_path,
|
||||
"input_path": lora_path,
|
||||
"model_strength": float(model_strength),
|
||||
"clip_strength": float(clip_strength),
|
||||
"trigger_words": trigger_words,
|
||||
})
|
||||
return entries
|
||||
|
||||
|
||||
def _collect_widget_entries(kwargs):
|
||||
entries = []
|
||||
for lora in get_loras_list(kwargs):
|
||||
if not lora.get("active", False):
|
||||
continue
|
||||
lora_name = lora["name"]
|
||||
model_strength = float(lora["strength"])
|
||||
clip_strength = float(lora.get("clipStrength", model_strength))
|
||||
lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||
entries.append({
|
||||
"name": lora_name,
|
||||
"absolute_path": lora_path,
|
||||
"input_path": lora_path,
|
||||
"model_strength": model_strength,
|
||||
"clip_strength": clip_strength,
|
||||
"trigger_words": trigger_words,
|
||||
})
|
||||
return entries
|
||||
|
||||
|
||||
def _format_loaded_loras(loaded_loras):
|
||||
formatted_loras = []
|
||||
for item in loaded_loras:
|
||||
if item["include_clip_strength"]:
|
||||
formatted_loras.append(
|
||||
f"<lora:{item['name']}:{item['model_strength']}:{item['clip_strength']}>"
|
||||
)
|
||||
else:
|
||||
formatted_loras.append(f"<lora:{item['name']}:{item['model_strength']}>")
|
||||
return " ".join(formatted_loras)
|
||||
|
||||
|
||||
def _apply_entries(model, clip, lora_entries, nunchaku_model_kind):
|
||||
loaded_loras = []
|
||||
all_trigger_words = []
|
||||
|
||||
if nunchaku_model_kind == "qwen_image":
|
||||
qwen_lora_configs = []
|
||||
for entry in lora_entries:
|
||||
qwen_lora_configs.append((entry["absolute_path"], entry["model_strength"]))
|
||||
loaded_loras.append({
|
||||
"name": entry["name"],
|
||||
"model_strength": entry["model_strength"],
|
||||
"clip_strength": entry["model_strength"],
|
||||
"include_clip_strength": False,
|
||||
})
|
||||
all_trigger_words.extend(entry["trigger_words"])
|
||||
if qwen_lora_configs:
|
||||
model = nunchaku_load_qwen_loras(model, qwen_lora_configs)
|
||||
return model, clip, loaded_loras, all_trigger_words
|
||||
|
||||
for entry in lora_entries:
|
||||
if nunchaku_model_kind == "flux":
|
||||
model = nunchaku_load_lora(model, entry["input_path"], entry["model_strength"])
|
||||
else:
|
||||
lora = comfy.utils.load_torch_file(entry["absolute_path"], safe_load=True)
|
||||
model, clip = comfy.sd.load_lora_for_models(
|
||||
model,
|
||||
clip,
|
||||
lora,
|
||||
entry["model_strength"],
|
||||
entry["clip_strength"],
|
||||
)
|
||||
|
||||
include_clip_strength = nunchaku_model_kind is None and abs(entry["model_strength"] - entry["clip_strength"]) > 0.001
|
||||
loaded_loras.append({
|
||||
"name": entry["name"],
|
||||
"model_strength": entry["model_strength"],
|
||||
"clip_strength": entry["clip_strength"],
|
||||
"include_clip_strength": include_clip_strength,
|
||||
})
|
||||
all_trigger_words.extend(entry["trigger_words"])
|
||||
|
||||
return model, clip, loaded_loras, all_trigger_words
|
||||
|
||||
|
||||
class LoraLoaderLM:
|
||||
NAME = "Lora Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
# "clip": ("CLIP",),
|
||||
"text": ("AUTOCOMPLETE_TEXT_LORAS", {
|
||||
"placeholder": "Search LoRAs to add...",
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
@@ -28,114 +133,30 @@ class LoraLoaderLM:
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
||||
FUNCTION = "load_loras"
|
||||
|
||||
|
||||
def load_loras(self, model, text, **kwargs):
|
||||
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
||||
loaded_loras = []
|
||||
all_trigger_words = []
|
||||
|
||||
clip = kwargs.get('clip', None)
|
||||
lora_stack = kwargs.get('lora_stack', None)
|
||||
|
||||
# Check if model is a Nunchaku Flux model - simplified approach
|
||||
is_nunchaku_model = False
|
||||
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
# Check if model is a Nunchaku Flux model using only class name
|
||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||
is_nunchaku_model = True
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
except (AttributeError, TypeError):
|
||||
# Not a model with the expected structure
|
||||
pass
|
||||
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Extract lora name and convert to absolute path
|
||||
# lora_stack stores relative paths, but load_torch_file needs absolute paths
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
absolute_lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# Use our custom function for Flux models
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged for Nunchaku models
|
||||
else:
|
||||
# Use lower-level API to load LoRA directly without folder_paths validation
|
||||
lora = comfy.utils.load_torch_file(absolute_lora_path, safe_load=True)
|
||||
model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength)
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Then process loras from kwargs with support for both old and new formats
|
||||
loras_list = get_loras_list(kwargs)
|
||||
for lora in loras_list:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
lora_name = lora['name']
|
||||
model_strength = float(lora['strength'])
|
||||
# Get clip strength - use model strength as default if not specified
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# For Nunchaku models, use our custom function
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged
|
||||
else:
|
||||
# Use lower-level API to load LoRA directly without folder_paths validation
|
||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||
model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength)
|
||||
|
||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format loaded_loras with support for both formats
|
||||
formatted_loras = []
|
||||
for item in loaded_loras:
|
||||
parts = item.split(":")
|
||||
lora_name = parts[0]
|
||||
strength_parts = parts[1].strip().split(",")
|
||||
|
||||
if len(strength_parts) > 1:
|
||||
# Different model and clip strengths
|
||||
model_str = strength_parts[0].strip()
|
||||
clip_str = strength_parts[1].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
model_str = strength_parts[0].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
|
||||
|
||||
formatted_loras_text = " ".join(formatted_loras)
|
||||
del text
|
||||
clip = kwargs.get("clip", None)
|
||||
lora_entries = _collect_stack_entries(kwargs.get("lora_stack", None))
|
||||
lora_entries.extend(_collect_widget_entries(kwargs))
|
||||
|
||||
nunchaku_model_kind = detect_nunchaku_model_kind(model)
|
||||
if nunchaku_model_kind == "flux":
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
elif nunchaku_model_kind == "qwen_image":
|
||||
logger.info("Detected Nunchaku Qwen-Image model")
|
||||
|
||||
model, clip, loaded_loras, all_trigger_words = _apply_entries(model, clip, lora_entries, nunchaku_model_kind)
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
formatted_loras_text = _format_loaded_loras(loaded_loras)
|
||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||
|
||||
|
||||
class LoraTextLoaderLM:
|
||||
NAME = "LoRA Text Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
@@ -143,131 +164,55 @@ class LoraTextLoaderLM:
|
||||
"model": ("MODEL",),
|
||||
"lora_syntax": ("STRING", {
|
||||
"forceInput": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation"
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"clip": ("CLIP",),
|
||||
"lora_stack": ("LORA_STACK",),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
||||
FUNCTION = "load_loras_from_text"
|
||||
|
||||
|
||||
def parse_lora_syntax(self, text):
|
||||
"""Parse LoRA syntax from text input."""
|
||||
# Pattern to match <lora:name:strength> or <lora:name:model_strength:clip_strength>
|
||||
pattern = r'<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>'
|
||||
pattern = r"<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>"
|
||||
matches = re.findall(pattern, text, re.IGNORECASE)
|
||||
|
||||
|
||||
loras = []
|
||||
for match in matches:
|
||||
lora_name = match[0]
|
||||
model_strength = float(match[1])
|
||||
clip_strength = float(match[2]) if match[2] else model_strength
|
||||
|
||||
loras.append({
|
||||
'name': lora_name,
|
||||
'model_strength': model_strength,
|
||||
'clip_strength': clip_strength
|
||||
"name": match[0],
|
||||
"model_strength": model_strength,
|
||||
"clip_strength": float(match[2]) if match[2] else model_strength,
|
||||
})
|
||||
|
||||
return loras
|
||||
|
||||
|
||||
def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None):
|
||||
"""Load LoRAs based on text syntax input."""
|
||||
loaded_loras = []
|
||||
all_trigger_words = []
|
||||
|
||||
# Check if model is a Nunchaku Flux model - simplified approach
|
||||
is_nunchaku_model = False
|
||||
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
# Check if model is a Nunchaku Flux model using only class name
|
||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||
is_nunchaku_model = True
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
except (AttributeError, TypeError):
|
||||
# Not a model with the expected structure
|
||||
pass
|
||||
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Extract lora name and convert to absolute path
|
||||
# lora_stack stores relative paths, but load_torch_file needs absolute paths
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
absolute_lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# Use our custom function for Flux models
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged for Nunchaku models
|
||||
else:
|
||||
# Use lower-level API to load LoRA directly without folder_paths validation
|
||||
lora = comfy.utils.load_torch_file(absolute_lora_path, safe_load=True)
|
||||
model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength)
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Parse and process LoRAs from text syntax
|
||||
parsed_loras = self.parse_lora_syntax(lora_syntax)
|
||||
for lora in parsed_loras:
|
||||
lora_name = lora['name']
|
||||
model_strength = lora['model_strength']
|
||||
clip_strength = lora['clip_strength']
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# For Nunchaku models, use our custom function
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged
|
||||
else:
|
||||
# Use lower-level API to load LoRA directly without folder_paths validation
|
||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||
model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength)
|
||||
|
||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
lora_entries = _collect_stack_entries(lora_stack)
|
||||
for lora in self.parse_lora_syntax(lora_syntax):
|
||||
lora_path, trigger_words = get_lora_info_absolute(lora["name"])
|
||||
lora_entries.append({
|
||||
"name": lora["name"],
|
||||
"absolute_path": lora_path,
|
||||
"input_path": lora_path,
|
||||
"model_strength": lora["model_strength"],
|
||||
"clip_strength": lora["clip_strength"],
|
||||
"trigger_words": trigger_words,
|
||||
})
|
||||
|
||||
nunchaku_model_kind = detect_nunchaku_model_kind(model)
|
||||
if nunchaku_model_kind == "flux":
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
elif nunchaku_model_kind == "qwen_image":
|
||||
logger.info("Detected Nunchaku Qwen-Image model")
|
||||
|
||||
model, clip, loaded_loras, all_trigger_words = _apply_entries(model, clip, lora_entries, nunchaku_model_kind)
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format loaded_loras with support for both formats
|
||||
formatted_loras = []
|
||||
for item in loaded_loras:
|
||||
parts = item.split(":")
|
||||
lora_name = parts[0].strip()
|
||||
strength_parts = parts[1].strip().split(",")
|
||||
|
||||
if len(strength_parts) > 1:
|
||||
# Different model and clip strengths
|
||||
model_str = strength_parts[0].strip()
|
||||
clip_str = strength_parts[1].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
model_str = strength_parts[0].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
|
||||
|
||||
formatted_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||
formatted_loras_text = _format_loaded_loras(loaded_loras)
|
||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||
|
||||
570
py/nodes/nunchaku_qwen.py
Normal file
570
py/nodes/nunchaku_qwen.py
Normal file
@@ -0,0 +1,570 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Qwen-Image LoRA support for Nunchaku models.
|
||||
|
||||
Portions of the LoRA mapping/application logic in this file are adapted from
|
||||
ComfyUI-QwenImageLoraLoader by GitHub user ussoewwin:
|
||||
https://github.com/ussoewwin/ComfyUI-QwenImageLoraLoader
|
||||
|
||||
The upstream project is licensed under Apache License 2.0.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import comfy.utils # type: ignore
|
||||
import folder_paths # type: ignore
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors import safe_open
|
||||
|
||||
from nunchaku.lora.flux.nunchaku_converter import (
|
||||
pack_lowrank_weight,
|
||||
unpack_lowrank_weight,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KEY_MAPPING = [
|
||||
(re.compile(r"^(layers)[._](\d+)[._]attention[._]to[._]([qkv])$"), r"\1.\2.attention.to_qkv", "qkv", lambda m: m.group(3).upper()),
|
||||
(re.compile(r"^(layers)[._](\d+)[._]feed_forward[._](w1|w3)$"), r"\1.\2.feed_forward.net.0.proj", "glu", lambda m: m.group(3)),
|
||||
(re.compile(r"^(layers)[._](\d+)[._]feed_forward[._]w2$"), r"\1.\2.feed_forward.net.2", "regular", None),
|
||||
(re.compile(r"^(layers)[._](\d+)[._](.*)$"), r"\1.\2.\3", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]to[._]([qkv])$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._](q|k|v)[._]proj$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]add[._](q|k|v)[._]proj$"), r"\1.\2.attn.add_qkv_proj", "add_qkv", lambda m: m.group(3).upper()),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]out[._]proj[._]context$"), r"\1.\2.attn.to_add_out", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]out[._]proj$"), r"\1.\2.attn.to_out.0", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]to[._]out$"), r"\1.\2.attn.to_out.0", "regular", None),
|
||||
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]attn[._]to[._]([qkv])$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()),
|
||||
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]attn[._]to[._]out$"), r"\1.\2.attn.to_out", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]ff[._]net[._]0(?:[._]proj)?$"), r"\1.\2.mlp_fc1", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]ff[._]net[._]2$"), r"\1.\2.mlp_fc2", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]ff_context[._]net[._]0(?:[._]proj)?$"), r"\1.\2.mlp_context_fc1", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]ff_context[._]net[._]2$"), r"\1.\2.mlp_context_fc2", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mlp)[._](net)[._](0)[._](proj)$"), r"\1.\2.\3.\4.\5.\6", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mlp)[._](net)[._](2)$"), r"\1.\2.\3.\4.\5", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mlp)[._](net)[._](0)[._](proj)$"), r"\1.\2.\3.\4.\5.\6", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mlp)[._](net)[._](2)$"), r"\1.\2.\3.\4.\5", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mod)[._](1)$"), r"\1.\2.\3.\4", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mod)[._](1)$"), r"\1.\2.\3.\4", "regular", None),
|
||||
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]proj[._]out$"), r"\1.\2.proj_out", "single_proj_out", None),
|
||||
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]proj[._]mlp$"), r"\1.\2.mlp_fc1", "regular", None),
|
||||
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]norm[._]linear$"), r"\1.\2.norm.linear", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]norm1[._]linear$"), r"\1.\2.norm1.linear", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]norm1_context[._]linear$"), r"\1.\2.norm1_context.linear", "regular", None),
|
||||
(re.compile(r"^(img_in)$"), r"\1", "regular", None),
|
||||
(re.compile(r"^(txt_in)$"), r"\1", "regular", None),
|
||||
(re.compile(r"^(proj_out)$"), r"\1", "regular", None),
|
||||
(re.compile(r"^(norm_out)[._](linear)$"), r"\1.\2", "regular", None),
|
||||
(re.compile(r"^(time_text_embed)[._](timestep_embedder)[._](linear_1)$"), r"\1.\2.\3", "regular", None),
|
||||
(re.compile(r"^(time_text_embed)[._](timestep_embedder)[._](linear_2)$"), r"\1.\2.\3", "regular", None),
|
||||
]
|
||||
|
||||
_RE_LORA_SUFFIX = re.compile(r"\.(?P<tag>lora(?:[._](?:A|B|down|up)))(?:\.[^.]+)*\.weight$")
|
||||
_RE_ALPHA_SUFFIX = re.compile(r"\.(?:alpha|lora_alpha)(?:\.[^.]+)*$")
|
||||
|
||||
|
||||
def _rename_layer_underscore_layer_name(old_name: str) -> str:
|
||||
rules = [
|
||||
(r"_(\d+)_attn_to_out_(\d+)", r".\1.attn.to_out.\2"),
|
||||
(r"_(\d+)_img_mlp_net_(\d+)_proj", r".\1.img_mlp.net.\2.proj"),
|
||||
(r"_(\d+)_txt_mlp_net_(\d+)_proj", r".\1.txt_mlp.net.\2.proj"),
|
||||
(r"_(\d+)_img_mlp_net_(\d+)", r".\1.img_mlp.net.\2"),
|
||||
(r"_(\d+)_txt_mlp_net_(\d+)", r".\1.txt_mlp.net.\2"),
|
||||
(r"_(\d+)_img_mod_(\d+)", r".\1.img_mod.\2"),
|
||||
(r"_(\d+)_txt_mod_(\d+)", r".\1.txt_mod.\2"),
|
||||
(r"_(\d+)_attn_", r".\1.attn."),
|
||||
]
|
||||
new_name = old_name
|
||||
for pattern, replacement in rules:
|
||||
new_name = re.sub(pattern, replacement, new_name)
|
||||
return new_name
|
||||
|
||||
|
||||
def _is_indexable_module(module):
|
||||
return isinstance(module, (nn.ModuleList, nn.Sequential, list, tuple))
|
||||
|
||||
|
||||
def _get_module_by_name(model: nn.Module, name: str) -> Optional[nn.Module]:
|
||||
if not name:
|
||||
return model
|
||||
module = model
|
||||
for part in name.split("."):
|
||||
if not part:
|
||||
continue
|
||||
if hasattr(module, part):
|
||||
module = getattr(module, part)
|
||||
elif part.isdigit() and _is_indexable_module(module):
|
||||
try:
|
||||
module = module[int(part)]
|
||||
except (IndexError, TypeError):
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
return module
|
||||
|
||||
|
||||
def _resolve_module_name(model: nn.Module, name: str) -> Tuple[str, Optional[nn.Module]]:
|
||||
module = _get_module_by_name(model, name)
|
||||
if module is not None:
|
||||
return name, module
|
||||
|
||||
replacements = [
|
||||
(".attn.to_out.0", ".attn.to_out"),
|
||||
(".attention.to_qkv", ".attention.qkv"),
|
||||
(".attention.to_out.0", ".attention.out"),
|
||||
(".feed_forward.net.0.proj", ".feed_forward.w13"),
|
||||
(".feed_forward.net.2", ".feed_forward.w2"),
|
||||
(".ff.net.0.proj", ".mlp_fc1"),
|
||||
(".ff.net.2", ".mlp_fc2"),
|
||||
(".ff_context.net.0.proj", ".mlp_context_fc1"),
|
||||
(".ff_context.net.2", ".mlp_context_fc2"),
|
||||
]
|
||||
for src, dst in replacements:
|
||||
if src in name:
|
||||
alt = name.replace(src, dst)
|
||||
module = _get_module_by_name(model, alt)
|
||||
if module is not None:
|
||||
return alt, module
|
||||
return name, None
|
||||
|
||||
|
||||
def _classify_and_map_key(key: str) -> Optional[Tuple[str, str, Optional[str], str]]:
|
||||
normalized = key
|
||||
if normalized.startswith("transformer."):
|
||||
normalized = normalized[len("transformer."):]
|
||||
if normalized.startswith("diffusion_model."):
|
||||
normalized = normalized[len("diffusion_model."):]
|
||||
if normalized.startswith("lora_unet_"):
|
||||
normalized = _rename_layer_underscore_layer_name(normalized[len("lora_unet_"):])
|
||||
|
||||
match = _RE_LORA_SUFFIX.search(normalized)
|
||||
if match:
|
||||
tag = match.group("tag")
|
||||
base = normalized[:match.start()]
|
||||
ab = "A" if ("lora_A" in tag or tag.endswith(".A") or "down" in tag) else "B"
|
||||
else:
|
||||
match = _RE_ALPHA_SUFFIX.search(normalized)
|
||||
if not match:
|
||||
return None
|
||||
base = normalized[:match.start()]
|
||||
ab = "alpha"
|
||||
|
||||
for pattern, template, group, comp_fn in KEY_MAPPING:
|
||||
key_match = pattern.match(base)
|
||||
if key_match:
|
||||
return group, key_match.expand(template), comp_fn(key_match) if comp_fn else None, ab
|
||||
return None
|
||||
|
||||
|
||||
def _detect_lora_format(lora_state_dict: Dict[str, torch.Tensor]) -> bool:
|
||||
standard_patterns = (
|
||||
".lora_up.",
|
||||
".lora_down.",
|
||||
".lora_A.",
|
||||
".lora_B.",
|
||||
".lora.up.",
|
||||
".lora.down.",
|
||||
".lora.A.",
|
||||
".lora.B.",
|
||||
)
|
||||
return any(pattern in key for key in lora_state_dict for pattern in standard_patterns)
|
||||
|
||||
|
||||
def _load_lora_state_dict(path_or_dict: Union[str, Path, Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
||||
if isinstance(path_or_dict, dict):
|
||||
return path_or_dict
|
||||
path = Path(path_or_dict)
|
||||
if path.suffix == ".safetensors":
|
||||
state_dict: Dict[str, torch.Tensor] = {}
|
||||
with safe_open(path, framework="pt", device="cpu") as handle:
|
||||
for key in handle.keys():
|
||||
state_dict[key] = handle.get_tensor(key)
|
||||
return state_dict
|
||||
return comfy.utils.load_torch_file(str(path), safe_load=True)
|
||||
|
||||
|
||||
def _fuse_glu_lora(glu_weights: Dict[str, torch.Tensor]) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
if "w1_A" not in glu_weights or "w3_A" not in glu_weights:
|
||||
return None, None, None
|
||||
a_w1, b_w1 = glu_weights["w1_A"], glu_weights["w1_B"]
|
||||
a_w3, b_w3 = glu_weights["w3_A"], glu_weights["w3_B"]
|
||||
if a_w1.shape[1] != a_w3.shape[1]:
|
||||
return None, None, None
|
||||
a_fused = torch.cat([a_w1, a_w3], dim=0)
|
||||
out1, out3 = b_w1.shape[0], b_w3.shape[0]
|
||||
rank1, rank3 = b_w1.shape[1], b_w3.shape[1]
|
||||
b_fused = torch.zeros(out1 + out3, rank1 + rank3, dtype=b_w1.dtype, device=b_w1.device)
|
||||
b_fused[:out1, :rank1] = b_w1
|
||||
b_fused[out1:, rank1:] = b_w3
|
||||
return a_fused, b_fused, glu_weights.get("w1_alpha")
|
||||
|
||||
|
||||
def _fuse_qkv_lora(qkv_weights: Dict[str, torch.Tensor], model: Optional[nn.Module] = None, base_key: Optional[str] = None) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
required_keys = ["Q_A", "Q_B", "K_A", "K_B", "V_A", "V_B"]
|
||||
if not all(key in qkv_weights for key in required_keys):
|
||||
return None, None, None
|
||||
a_q, a_k, a_v = qkv_weights["Q_A"], qkv_weights["K_A"], qkv_weights["V_A"]
|
||||
b_q, b_k, b_v = qkv_weights["Q_B"], qkv_weights["K_B"], qkv_weights["V_B"]
|
||||
if not (a_q.shape == a_k.shape == a_v.shape):
|
||||
return None, None, None
|
||||
if not (b_q.shape[1] == b_k.shape[1] == b_v.shape[1]):
|
||||
return None, None, None
|
||||
|
||||
out_features = None
|
||||
if model is not None and base_key is not None:
|
||||
_, module = _resolve_module_name(model, base_key)
|
||||
out_features = getattr(module, "out_features", None) if module is not None else None
|
||||
|
||||
alpha_fused = None
|
||||
alpha_q = qkv_weights.get("Q_alpha")
|
||||
alpha_k = qkv_weights.get("K_alpha")
|
||||
alpha_v = qkv_weights.get("V_alpha")
|
||||
if alpha_q is not None and alpha_k is not None and alpha_v is not None and alpha_q.item() == alpha_k.item() == alpha_v.item():
|
||||
alpha_fused = alpha_q
|
||||
|
||||
a_fused = torch.cat([a_q, a_k, a_v], dim=0)
|
||||
rank = b_q.shape[1]
|
||||
out_q, out_k, out_v = b_q.shape[0], b_k.shape[0], b_v.shape[0]
|
||||
total_out = out_features if out_features is not None else out_q + out_k + out_v
|
||||
b_fused = torch.zeros(total_out, 3 * rank, dtype=b_q.dtype, device=b_q.device)
|
||||
b_fused[:out_q, :rank] = b_q
|
||||
b_fused[out_q:out_q + out_k, rank:2 * rank] = b_k
|
||||
b_fused[out_q + out_k:out_q + out_k + out_v, 2 * rank:] = b_v
|
||||
return a_fused, b_fused, alpha_fused
|
||||
|
||||
|
||||
def _handle_proj_out_split(lora_dict: Dict[str, Dict[str, torch.Tensor]], base_key: str, model: nn.Module) -> Tuple[Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], List[str]]:
|
||||
result: Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]] = {}
|
||||
consumed: List[str] = []
|
||||
match = re.search(r"single_transformer_blocks\.(\d+)", base_key)
|
||||
if not match or base_key not in lora_dict:
|
||||
return result, consumed
|
||||
block_idx = match.group(1)
|
||||
block = _get_module_by_name(model, f"single_transformer_blocks.{block_idx}")
|
||||
if block is None:
|
||||
return result, consumed
|
||||
a_full = lora_dict[base_key].get("A")
|
||||
b_full = lora_dict[base_key].get("B")
|
||||
alpha = lora_dict[base_key].get("alpha")
|
||||
attn_to_out = getattr(getattr(block, "attn", None), "to_out", None)
|
||||
mlp_fc2 = getattr(block, "mlp_fc2", None)
|
||||
if a_full is None or b_full is None or attn_to_out is None or mlp_fc2 is None:
|
||||
return result, consumed
|
||||
attn_in = getattr(attn_to_out, "in_features", None)
|
||||
mlp_in = getattr(mlp_fc2, "in_features", None)
|
||||
if attn_in is None or mlp_in is None or a_full.shape[1] != attn_in + mlp_in:
|
||||
return result, consumed
|
||||
result[f"single_transformer_blocks.{block_idx}.attn.to_out"] = (a_full[:, :attn_in], b_full.clone(), alpha)
|
||||
result[f"single_transformer_blocks.{block_idx}.mlp_fc2"] = (a_full[:, attn_in:], b_full.clone(), alpha)
|
||||
consumed.append(base_key)
|
||||
return result, consumed
|
||||
|
||||
|
||||
def _apply_lora_to_module(module: nn.Module, a_tensor: torch.Tensor, b_tensor: torch.Tensor, module_name: str, model: nn.Module) -> None:
|
||||
if not hasattr(module, "in_features") or not hasattr(module, "out_features"):
|
||||
raise ValueError(f"{module_name}: unsupported module without in/out features")
|
||||
if a_tensor.shape[1] != module.in_features or b_tensor.shape[0] != module.out_features:
|
||||
raise ValueError(f"{module_name}: LoRA shape mismatch")
|
||||
|
||||
if module.__class__.__name__ == "AWQW4A16Linear" and hasattr(module, "qweight"):
|
||||
if not hasattr(module, "_lora_original_forward"):
|
||||
module._lora_original_forward = module.forward
|
||||
if not hasattr(module, "_nunchaku_lora_bundle"):
|
||||
module._nunchaku_lora_bundle = []
|
||||
module._nunchaku_lora_bundle.append((a_tensor, b_tensor))
|
||||
|
||||
def _awq_lora_forward(x, *args, **kwargs):
|
||||
out = module._lora_original_forward(x, *args, **kwargs)
|
||||
x_flat = x.reshape(-1, module.in_features)
|
||||
for local_a, local_b in module._nunchaku_lora_bundle:
|
||||
local_a = local_a.to(device=out.device, dtype=out.dtype)
|
||||
local_b = local_b.to(device=out.device, dtype=out.dtype)
|
||||
lora_term = (x_flat @ local_a.transpose(0, 1)) @ local_b.transpose(0, 1)
|
||||
try:
|
||||
out = out + lora_term.reshape(out.shape)
|
||||
except Exception:
|
||||
pass
|
||||
return out
|
||||
|
||||
module.forward = _awq_lora_forward
|
||||
if not hasattr(model, "_lora_slots"):
|
||||
model._lora_slots = {}
|
||||
model._lora_slots[module_name] = {"type": "awq_w4a16"}
|
||||
return
|
||||
|
||||
if hasattr(module, "proj_down") and hasattr(module, "proj_up"):
|
||||
proj_down = unpack_lowrank_weight(module.proj_down.data, down=True)
|
||||
proj_up = unpack_lowrank_weight(module.proj_up.data, down=False)
|
||||
base_rank = proj_down.shape[0] if proj_down.shape[1] == module.in_features else proj_down.shape[1]
|
||||
if proj_down.shape[1] == module.in_features:
|
||||
updated_down = torch.cat([proj_down, a_tensor], dim=0)
|
||||
axis_down = 0
|
||||
else:
|
||||
updated_down = torch.cat([proj_down, a_tensor.T], dim=1)
|
||||
axis_down = 1
|
||||
updated_up = torch.cat([proj_up, b_tensor], dim=1)
|
||||
module.proj_down.data = pack_lowrank_weight(updated_down, down=True)
|
||||
module.proj_up.data = pack_lowrank_weight(updated_up, down=False)
|
||||
module.rank = base_rank + a_tensor.shape[0]
|
||||
if not hasattr(model, "_lora_slots"):
|
||||
model._lora_slots = {}
|
||||
model._lora_slots[module_name] = {
|
||||
"type": "nunchaku",
|
||||
"base_rank": base_rank,
|
||||
"axis_down": axis_down,
|
||||
}
|
||||
return
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
if not hasattr(model, "_lora_slots"):
|
||||
model._lora_slots = {}
|
||||
if module_name not in model._lora_slots:
|
||||
model._lora_slots[module_name] = {
|
||||
"type": "linear",
|
||||
"original_weight": module.weight.detach().cpu().clone(),
|
||||
}
|
||||
module.weight.data.add_((b_tensor @ a_tensor).to(dtype=module.weight.dtype, device=module.weight.device))
|
||||
return
|
||||
|
||||
raise ValueError(f"{module_name}: unsupported module type {type(module)}")
|
||||
|
||||
|
||||
def reset_lora_v2(model: nn.Module) -> None:
|
||||
slots = getattr(model, "_lora_slots", None)
|
||||
if not slots:
|
||||
return
|
||||
for name, info in list(slots.items()):
|
||||
module = _get_module_by_name(model, name)
|
||||
if module is None:
|
||||
continue
|
||||
module_type = info.get("type", "nunchaku")
|
||||
if module_type == "nunchaku":
|
||||
base_rank = info["base_rank"]
|
||||
proj_down = unpack_lowrank_weight(module.proj_down.data, down=True)
|
||||
proj_up = unpack_lowrank_weight(module.proj_up.data, down=False)
|
||||
if info.get("axis_down", 0) == 0:
|
||||
proj_down = proj_down[:base_rank, :].clone()
|
||||
else:
|
||||
proj_down = proj_down[:, :base_rank].clone()
|
||||
proj_up = proj_up[:, :base_rank].clone()
|
||||
module.proj_down.data = pack_lowrank_weight(proj_down, down=True)
|
||||
module.proj_up.data = pack_lowrank_weight(proj_up, down=False)
|
||||
module.rank = base_rank
|
||||
elif module_type == "linear" and "original_weight" in info:
|
||||
module.weight.data.copy_(info["original_weight"].to(device=module.weight.device, dtype=module.weight.dtype))
|
||||
elif module_type == "awq_w4a16":
|
||||
if hasattr(module, "_lora_original_forward"):
|
||||
module.forward = module._lora_original_forward
|
||||
for attr in ("_lora_original_forward", "_nunchaku_lora_bundle"):
|
||||
if hasattr(module, attr):
|
||||
delattr(module, attr)
|
||||
model._lora_slots = {}
|
||||
|
||||
|
||||
def compose_loras_v2(model: nn.Module, lora_configs: List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]], apply_awq_mod: bool = True) -> bool:
|
||||
del apply_awq_mod # retained for interface compatibility
|
||||
reset_lora_v2(model)
|
||||
aggregated_weights: Dict[str, List[Dict[str, object]]] = defaultdict(list)
|
||||
saw_supported_format = False
|
||||
unresolved_targets = 0
|
||||
|
||||
for index, (path_or_dict, strength) in enumerate(lora_configs):
|
||||
if abs(strength) < 1e-5:
|
||||
continue
|
||||
lora_name = str(path_or_dict) if not isinstance(path_or_dict, dict) else f"lora_{index}"
|
||||
lora_state_dict = _load_lora_state_dict(path_or_dict)
|
||||
if not lora_state_dict or not _detect_lora_format(lora_state_dict):
|
||||
logger.warning("Skipping unsupported Qwen LoRA: %s", lora_name)
|
||||
continue
|
||||
saw_supported_format = True
|
||||
|
||||
grouped_weights: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)
|
||||
for key, value in lora_state_dict.items():
|
||||
parsed = _classify_and_map_key(key)
|
||||
if parsed is None:
|
||||
continue
|
||||
group, base_key, component, ab = parsed
|
||||
if component and ab:
|
||||
grouped_weights[base_key][f"{component}_{ab}"] = value
|
||||
else:
|
||||
grouped_weights[base_key][ab] = value
|
||||
|
||||
processed_groups: Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]] = {}
|
||||
handled: set[str] = set()
|
||||
for base_key, weights in grouped_weights.items():
|
||||
if base_key in handled:
|
||||
continue
|
||||
a_tensor = b_tensor = alpha = None
|
||||
if "qkv" in base_key or "add_qkv_proj" in base_key:
|
||||
a_tensor, b_tensor, alpha = _fuse_qkv_lora(weights, model=model, base_key=base_key)
|
||||
elif "w1_A" in weights or "w3_A" in weights:
|
||||
a_tensor, b_tensor, alpha = _fuse_glu_lora(weights)
|
||||
elif ".proj_out" in base_key and "single_transformer_blocks" in base_key:
|
||||
split_map, consumed = _handle_proj_out_split(grouped_weights, base_key, model)
|
||||
processed_groups.update(split_map)
|
||||
handled.update(consumed)
|
||||
continue
|
||||
else:
|
||||
a_tensor, b_tensor, alpha = weights.get("A"), weights.get("B"), weights.get("alpha")
|
||||
if a_tensor is not None and b_tensor is not None:
|
||||
processed_groups[base_key] = (a_tensor, b_tensor, alpha)
|
||||
|
||||
for module_name, (a_tensor, b_tensor, alpha) in processed_groups.items():
|
||||
aggregated_weights[module_name].append({
|
||||
"A": a_tensor,
|
||||
"B": b_tensor,
|
||||
"alpha": alpha,
|
||||
"strength": strength,
|
||||
})
|
||||
|
||||
for module_name, weight_list in aggregated_weights.items():
|
||||
resolved_name, module = _resolve_module_name(model, module_name)
|
||||
if module is None:
|
||||
logger.warning("Skipping unresolved Qwen LoRA target: %s", module_name)
|
||||
unresolved_targets += 1
|
||||
continue
|
||||
all_a = []
|
||||
all_b_scaled = []
|
||||
for item in weight_list:
|
||||
a_tensor = item["A"]
|
||||
b_tensor = item["B"]
|
||||
alpha = item["alpha"]
|
||||
strength = float(item["strength"])
|
||||
rank = a_tensor.shape[0]
|
||||
scale = strength * ((alpha / rank) if alpha is not None else 1.0)
|
||||
if module.__class__.__name__ == "AWQW4A16Linear" and hasattr(module, "qweight"):
|
||||
target_dtype = torch.float16
|
||||
target_device = module.qweight.device
|
||||
elif hasattr(module, "proj_down"):
|
||||
target_dtype = module.proj_down.dtype
|
||||
target_device = module.proj_down.device
|
||||
elif hasattr(module, "weight"):
|
||||
target_dtype = module.weight.dtype
|
||||
target_device = module.weight.device
|
||||
else:
|
||||
target_dtype = torch.float16
|
||||
target_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
all_a.append(a_tensor.to(dtype=target_dtype, device=target_device))
|
||||
all_b_scaled.append((b_tensor * scale).to(dtype=target_dtype, device=target_device))
|
||||
if not all_a:
|
||||
continue
|
||||
_apply_lora_to_module(module, torch.cat(all_a, dim=0), torch.cat(all_b_scaled, dim=1), resolved_name, model)
|
||||
|
||||
slot_count = len(getattr(model, "_lora_slots", {}) or {})
|
||||
logger.info(
|
||||
"Qwen LoRA composition finished: requested=%d supported=%s applied_targets=%d unresolved=%d",
|
||||
len(lora_configs),
|
||||
saw_supported_format,
|
||||
slot_count,
|
||||
unresolved_targets,
|
||||
)
|
||||
return saw_supported_format
|
||||
|
||||
|
||||
class ComfyQwenImageWrapperLM(nn.Module):
|
||||
def __init__(self, model: nn.Module, config=None, apply_awq_mod: bool = True):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.config = {} if config is None else config
|
||||
self.dtype = next(model.parameters()).dtype
|
||||
self.loras: List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]] = []
|
||||
self._applied_loras: Optional[List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]]] = None
|
||||
self.apply_awq_mod = apply_awq_mod
|
||||
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
inner = object.__getattribute__(self, "_modules").get("model")
|
||||
except (AttributeError, KeyError):
|
||||
inner = None
|
||||
if inner is None:
|
||||
raise AttributeError(f"{type(self).__name__!s} has no attribute {name}")
|
||||
if name == "model":
|
||||
return inner
|
||||
return getattr(inner, name)
|
||||
|
||||
def process_img(self, *args, **kwargs):
|
||||
return self.model.process_img(*args, **kwargs)
|
||||
|
||||
def _ensure_composed(self):
|
||||
if self._applied_loras != self.loras or (not self.loras and getattr(self.model, "_lora_slots", None)):
|
||||
is_supported_format = compose_loras_v2(self.model, self.loras, apply_awq_mod=self.apply_awq_mod)
|
||||
self._applied_loras = self.loras.copy()
|
||||
has_slots = bool(getattr(self.model, "_lora_slots", None))
|
||||
if self.loras and is_supported_format and not has_slots:
|
||||
logger.warning("Qwen LoRA compose produced 0 target modules. Resetting and retrying once.")
|
||||
reset_lora_v2(self.model)
|
||||
compose_loras_v2(self.model, self.loras, apply_awq_mod=self.apply_awq_mod)
|
||||
has_slots = bool(getattr(self.model, "_lora_slots", None))
|
||||
logger.info("Qwen LoRA retry result: applied_targets=%d", len(getattr(self.model, "_lora_slots", {}) or {}))
|
||||
|
||||
offload_manager = getattr(self.model, "offload_manager", None)
|
||||
if offload_manager is not None:
|
||||
offload_settings = {
|
||||
"num_blocks_on_gpu": getattr(offload_manager, "num_blocks_on_gpu", 1),
|
||||
"use_pin_memory": getattr(offload_manager, "use_pin_memory", False),
|
||||
}
|
||||
logger.info(
|
||||
"Rebuilding Qwen offload manager after LoRA compose: num_blocks_on_gpu=%s use_pin_memory=%s",
|
||||
offload_settings["num_blocks_on_gpu"],
|
||||
offload_settings["use_pin_memory"],
|
||||
)
|
||||
self.model.set_offload(False)
|
||||
self.model.set_offload(True, **offload_settings)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
self._ensure_composed()
|
||||
return self.model(*args, **kwargs)
|
||||
|
||||
|
||||
def _get_qwen_wrapper_and_transformer(model):
|
||||
model_wrapper = model.model.diffusion_model
|
||||
if hasattr(model_wrapper, "model") and hasattr(model_wrapper, "loras"):
|
||||
transformer = model_wrapper.model
|
||||
if transformer.__class__.__name__.endswith("NunchakuQwenImageTransformer2DModel"):
|
||||
return model_wrapper, transformer
|
||||
if model_wrapper.__class__.__name__.endswith("NunchakuQwenImageTransformer2DModel"):
|
||||
wrapped_model = ComfyQwenImageWrapperLM(model_wrapper, getattr(model_wrapper, "config", {}))
|
||||
model.model.diffusion_model = wrapped_model
|
||||
return wrapped_model, wrapped_model.model
|
||||
raise TypeError(f"This LoRA loader only works with Nunchaku Qwen Image models, but got {type(model_wrapper).__name__}.")
|
||||
|
||||
|
||||
def nunchaku_load_qwen_loras(model, lora_configs: List[Tuple[str, float]], apply_awq_mod: bool = True):
|
||||
model_wrapper, transformer = _get_qwen_wrapper_and_transformer(model)
|
||||
model_wrapper.apply_awq_mod = apply_awq_mod
|
||||
|
||||
saved_config = None
|
||||
if hasattr(model, "model") and hasattr(model.model, "model_config"):
|
||||
saved_config = model.model.model_config
|
||||
model.model.model_config = None
|
||||
|
||||
model_wrapper.model = None
|
||||
try:
|
||||
ret_model = copy.deepcopy(model)
|
||||
finally:
|
||||
if saved_config is not None:
|
||||
model.model.model_config = saved_config
|
||||
model_wrapper.model = transformer
|
||||
|
||||
ret_model_wrapper = ret_model.model.diffusion_model
|
||||
if saved_config is not None:
|
||||
ret_model.model.model_config = saved_config
|
||||
ret_model_wrapper.model = transformer
|
||||
ret_model_wrapper.apply_awq_mod = apply_awq_mod
|
||||
ret_model_wrapper.loras = list(getattr(model_wrapper, "loras", []))
|
||||
|
||||
for lora_name, lora_strength in lora_configs:
|
||||
lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
|
||||
if not lora_path or not os.path.isfile(lora_path):
|
||||
logger.warning("Skipping Qwen LoRA '%s' because it could not be found", lora_name)
|
||||
continue
|
||||
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
||||
|
||||
return ret_model
|
||||
@@ -158,3 +158,24 @@ def nunchaku_load_lora(model, lora_name, lora_strength):
|
||||
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
|
||||
|
||||
return ret_model
|
||||
|
||||
|
||||
def detect_nunchaku_model_kind(model):
|
||||
"""Return the supported Nunchaku model kind for a Comfy model, if any."""
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
except (AttributeError, TypeError):
|
||||
return None
|
||||
|
||||
wrapper_name = model_wrapper.__class__.__name__
|
||||
if wrapper_name == "ComfyFluxWrapper":
|
||||
return "flux"
|
||||
|
||||
inner_model = getattr(model_wrapper, "model", None)
|
||||
inner_name = inner_model.__class__.__name__ if inner_model is not None else ""
|
||||
if wrapper_name.endswith("NunchakuQwenImageTransformer2DModel"):
|
||||
return "qwen_image"
|
||||
if inner_name.endswith("NunchakuQwenImageTransformer2DModel"):
|
||||
return "qwen_image"
|
||||
|
||||
return None
|
||||
|
||||
176
tests/nodes/test_lora_loader.py
Normal file
176
tests/nodes/test_lora_loader.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import types
|
||||
|
||||
from py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM
|
||||
|
||||
|
||||
class _ModelContainer:
|
||||
def __init__(self, diffusion_model):
|
||||
self.diffusion_model = diffusion_model
|
||||
|
||||
|
||||
class _Model:
|
||||
def __init__(self, diffusion_model):
|
||||
self.model = _ModelContainer(diffusion_model)
|
||||
|
||||
|
||||
def test_lora_loader_standard_model_uses_comfy_loader(monkeypatch):
|
||||
loader = LoraLoaderLM()
|
||||
model = _Model(object())
|
||||
clip = object()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"py.nodes.lora_loader.get_lora_info_absolute",
|
||||
lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]),
|
||||
)
|
||||
|
||||
load_calls = []
|
||||
|
||||
def mock_load_torch_file(path, safe_load=True):
|
||||
load_calls.append((path, safe_load))
|
||||
return {"path": path}
|
||||
|
||||
def mock_load_lora_for_models(model_arg, clip_arg, lora_arg, model_strength, clip_strength):
|
||||
return model_arg, clip_arg
|
||||
|
||||
monkeypatch.setattr("comfy.utils.load_torch_file", mock_load_torch_file)
|
||||
monkeypatch.setattr("comfy.sd.load_lora_for_models", mock_load_lora_for_models)
|
||||
|
||||
result_model, result_clip, trigger_words, loaded_loras = loader.load_loras(
|
||||
model,
|
||||
"",
|
||||
clip=clip,
|
||||
loras={
|
||||
"__value__": [
|
||||
{"active": True, "name": "demo", "strength": 0.75, "clipStrength": 0.5},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
assert result_model is model
|
||||
assert result_clip is clip
|
||||
assert load_calls == [("/abs/demo.safetensors", True)]
|
||||
assert trigger_words == "demo_trigger"
|
||||
assert loaded_loras == "<lora:demo:0.75:0.5>"
|
||||
|
||||
|
||||
def test_lora_loader_formats_widget_lora_names_with_colons(monkeypatch):
|
||||
loader = LoraLoaderLM()
|
||||
model = _Model(object())
|
||||
clip = object()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"py.nodes.lora_loader.get_lora_info_absolute",
|
||||
lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]),
|
||||
)
|
||||
monkeypatch.setattr("comfy.utils.load_torch_file", lambda path, safe_load=True: {"path": path})
|
||||
monkeypatch.setattr(
|
||||
"comfy.sd.load_lora_for_models",
|
||||
lambda model_arg, clip_arg, lora_arg, model_strength, clip_strength: (model_arg, clip_arg),
|
||||
)
|
||||
|
||||
_, _, trigger_words, loaded_loras = loader.load_loras(
|
||||
model,
|
||||
"",
|
||||
clip=clip,
|
||||
loras={
|
||||
"__value__": [
|
||||
{"active": True, "name": "demo:variant", "strength": 0.75, "clipStrength": 0.5},
|
||||
{"active": True, "name": "demo:single", "strength": 0.3},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
assert trigger_words == "demo:variant_trigger,, demo:single_trigger"
|
||||
assert loaded_loras == "<lora:demo:variant:0.75:0.5> <lora:demo:single:0.3>"
|
||||
|
||||
|
||||
def test_lora_loader_flux_model_uses_flux_helper(monkeypatch):
|
||||
flux_model = _Model(type("ComfyFluxWrapper", (), {})())
|
||||
loader = LoraLoaderLM()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"py.nodes.lora_loader.get_lora_info_absolute",
|
||||
lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]),
|
||||
)
|
||||
|
||||
calls = []
|
||||
|
||||
def mock_nunchaku_load_lora(model_arg, lora_name, strength):
|
||||
calls.append((lora_name, strength))
|
||||
return model_arg
|
||||
|
||||
monkeypatch.setattr("py.nodes.lora_loader.nunchaku_load_lora", mock_nunchaku_load_lora)
|
||||
|
||||
_, _, trigger_words, loaded_loras = loader.load_loras(
|
||||
flux_model,
|
||||
"",
|
||||
lora_stack=[("stack_lora.safetensors", 0.4, 0.2)],
|
||||
loras={"__value__": [{"active": True, "name": "widget_lora", "strength": 0.8}]},
|
||||
)
|
||||
|
||||
assert calls == [("stack_lora.safetensors", 0.4), ("/abs/widget_lora.safetensors", 0.8)]
|
||||
assert trigger_words == "stack_lora_trigger,, widget_lora_trigger"
|
||||
assert loaded_loras == "<lora:stack_lora:0.4> <lora:widget_lora:0.8>"
|
||||
|
||||
|
||||
def test_lora_loader_qwen_model_batches_loras(monkeypatch):
|
||||
qwen_model = _Model(type("NunchakuQwenImageTransformer2DModel", (), {})())
|
||||
loader = LoraLoaderLM()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"py.nodes.lora_loader.get_lora_info_absolute",
|
||||
lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]),
|
||||
)
|
||||
|
||||
batched_calls = []
|
||||
|
||||
def mock_nunchaku_load_qwen_loras(model_arg, lora_configs):
|
||||
batched_calls.append((model_arg, lora_configs))
|
||||
return model_arg
|
||||
|
||||
monkeypatch.setattr("py.nodes.lora_loader.nunchaku_load_qwen_loras", mock_nunchaku_load_qwen_loras)
|
||||
|
||||
_, result_clip, trigger_words, loaded_loras = loader.load_loras(
|
||||
qwen_model,
|
||||
"",
|
||||
clip="clip",
|
||||
lora_stack=[("stack_qwen.safetensors", 0.6, 0.1)],
|
||||
loras={"__value__": [{"active": True, "name": "widget_qwen", "strength": 0.9, "clipStrength": 0.3}]},
|
||||
)
|
||||
|
||||
assert result_clip == "clip"
|
||||
assert len(batched_calls) == 1
|
||||
assert batched_calls[0][0] is qwen_model
|
||||
assert batched_calls[0][1] == [
|
||||
("/abs/stack_qwen.safetensors", 0.6),
|
||||
("/abs/widget_qwen.safetensors", 0.9),
|
||||
]
|
||||
assert trigger_words == "stack_qwen_trigger,, widget_qwen_trigger"
|
||||
assert loaded_loras == "<lora:stack_qwen:0.6> <lora:widget_qwen:0.9>"
|
||||
|
||||
|
||||
def test_lora_text_loader_qwen_batches_text_and_stack(monkeypatch):
|
||||
qwen_model = _Model(type("NunchakuQwenImageTransformer2DModel", (), {})())
|
||||
loader = LoraTextLoaderLM()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"py.nodes.lora_loader.get_lora_info_absolute",
|
||||
lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]),
|
||||
)
|
||||
|
||||
batched_calls = []
|
||||
monkeypatch.setattr(
|
||||
"py.nodes.lora_loader.nunchaku_load_qwen_loras",
|
||||
lambda model_arg, lora_configs: batched_calls.append(lora_configs) or model_arg,
|
||||
)
|
||||
|
||||
_, _, trigger_words, loaded_loras = loader.load_loras_from_text(
|
||||
qwen_model,
|
||||
"<lora:text_qwen:1.2:0.4>",
|
||||
clip="clip",
|
||||
lora_stack=[("stack_qwen.safetensors", 0.6, 0.1)],
|
||||
)
|
||||
|
||||
assert batched_calls == [[("/abs/stack_qwen.safetensors", 0.6), ("/abs/text_qwen.safetensors", 1.2)]]
|
||||
assert trigger_words == "stack_qwen_trigger,, text_qwen_trigger"
|
||||
assert loaded_loras == "<lora:stack_qwen:0.6> <lora:text_qwen:1.2>"
|
||||
Reference in New Issue
Block a user