Files
ComfyUI-Lora-Manager/py/nodes/nunchaku_qwen.py
2026-03-27 19:24:43 +08:00

571 lines
27 KiB
Python

from __future__ import annotations
"""Qwen-Image LoRA support for Nunchaku models.
Portions of the LoRA mapping/application logic in this file are adapted from
ComfyUI-QwenImageLoraLoader by GitHub user ussoewwin:
https://github.com/ussoewwin/ComfyUI-QwenImageLoraLoader
The upstream project is licensed under Apache License 2.0.
"""
import copy
import logging
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import comfy.utils # type: ignore
import folder_paths # type: ignore
import torch
import torch.nn as nn
from safetensors import safe_open
from nunchaku.lora.flux.nunchaku_converter import (
pack_lowrank_weight,
unpack_lowrank_weight,
)
# Module-level logger following the standard ``logging.getLogger(__name__)`` convention.
logger = logging.getLogger(__name__)
# Translation table from normalized LoRA base keys to Nunchaku module paths.
# Each entry is ``(pattern, template, group_kind, component_fn)``:
#   * pattern      -- regex matched against the base key; ``[._]`` accepts both
#                     dot- and underscore-separated checkpoint naming.
#   * template     -- ``re.Match.expand`` template producing the target module name.
#   * group_kind   -- "qkv"/"add_qkv": per-projection weights later fused into one
#                     linear; "glu": w1/w3 fused gate projections; "regular":
#                     direct 1:1 mapping; "single_proj_out": proj_out that may be
#                     split across attn.to_out and mlp_fc2 downstream.
#   * component_fn -- callable extracting the component label (e.g. "Q"/"K"/"V"
#                     or "w1"/"w3") from the match, or None for regular mappings.
KEY_MAPPING = [
    (re.compile(r"^(layers)[._](\d+)[._]attention[._]to[._]([qkv])$"), r"\1.\2.attention.to_qkv", "qkv", lambda m: m.group(3).upper()),
    (re.compile(r"^(layers)[._](\d+)[._]feed_forward[._](w1|w3)$"), r"\1.\2.feed_forward.net.0.proj", "glu", lambda m: m.group(3)),
    (re.compile(r"^(layers)[._](\d+)[._]feed_forward[._]w2$"), r"\1.\2.feed_forward.net.2", "regular", None),
    (re.compile(r"^(layers)[._](\d+)[._](.*)$"), r"\1.\2.\3", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]to[._]([qkv])$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._](q|k|v)[._]proj$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]add[._](q|k|v)[._]proj$"), r"\1.\2.attn.add_qkv_proj", "add_qkv", lambda m: m.group(3).upper()),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._]out[._]proj[._]context$"), r"\1.\2.attn.to_add_out", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._]out[._]proj$"), r"\1.\2.attn.to_out.0", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]to[._]out$"), r"\1.\2.attn.to_out.0", "regular", None),
    (re.compile(r"^(single_transformer_blocks)[._](\d+)[._]attn[._]to[._]([qkv])$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()),
    (re.compile(r"^(single_transformer_blocks)[._](\d+)[._]attn[._]to[._]out$"), r"\1.\2.attn.to_out", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._]ff[._]net[._]0(?:[._]proj)?$"), r"\1.\2.mlp_fc1", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._]ff[._]net[._]2$"), r"\1.\2.mlp_fc2", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._]ff_context[._]net[._]0(?:[._]proj)?$"), r"\1.\2.mlp_context_fc1", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._]ff_context[._]net[._]2$"), r"\1.\2.mlp_context_fc2", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mlp)[._](net)[._](0)[._](proj)$"), r"\1.\2.\3.\4.\5.\6", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mlp)[._](net)[._](2)$"), r"\1.\2.\3.\4.\5", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mlp)[._](net)[._](0)[._](proj)$"), r"\1.\2.\3.\4.\5.\6", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mlp)[._](net)[._](2)$"), r"\1.\2.\3.\4.\5", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mod)[._](1)$"), r"\1.\2.\3.\4", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mod)[._](1)$"), r"\1.\2.\3.\4", "regular", None),
    (re.compile(r"^(single_transformer_blocks)[._](\d+)[._]proj[._]out$"), r"\1.\2.proj_out", "single_proj_out", None),
    (re.compile(r"^(single_transformer_blocks)[._](\d+)[._]proj[._]mlp$"), r"\1.\2.mlp_fc1", "regular", None),
    (re.compile(r"^(single_transformer_blocks)[._](\d+)[._]norm[._]linear$"), r"\1.\2.norm.linear", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._]norm1[._]linear$"), r"\1.\2.norm1.linear", "regular", None),
    (re.compile(r"^(transformer_blocks)[._](\d+)[._]norm1_context[._]linear$"), r"\1.\2.norm1_context.linear", "regular", None),
    (re.compile(r"^(img_in)$"), r"\1", "regular", None),
    (re.compile(r"^(txt_in)$"), r"\1", "regular", None),
    (re.compile(r"^(proj_out)$"), r"\1", "regular", None),
    (re.compile(r"^(norm_out)[._](linear)$"), r"\1.\2", "regular", None),
    (re.compile(r"^(time_text_embed)[._](timestep_embedder)[._](linear_1)$"), r"\1.\2.\3", "regular", None),
    (re.compile(r"^(time_text_embed)[._](timestep_embedder)[._](linear_2)$"), r"\1.\2.\3", "regular", None),
]
# Matches LoRA factor keys such as ``...lora_A.weight`` / ``...lora.down.weight``;
# the named group "tag" captures which factor (A/B/down/up) the tensor belongs to.
_RE_LORA_SUFFIX = re.compile(r"\.(?P<tag>lora(?:[._](?:A|B|down|up)))(?:\.[^.]+)*\.weight$")
# Matches scalar alpha keys such as ``...alpha`` or ``...lora_alpha``.
_RE_ALPHA_SUFFIX = re.compile(r"\.(?:alpha|lora_alpha)(?:\.[^.]+)*$")
def _rename_layer_underscore_layer_name(old_name: str) -> str:
rules = [
(r"_(\d+)_attn_to_out_(\d+)", r".\1.attn.to_out.\2"),
(r"_(\d+)_img_mlp_net_(\d+)_proj", r".\1.img_mlp.net.\2.proj"),
(r"_(\d+)_txt_mlp_net_(\d+)_proj", r".\1.txt_mlp.net.\2.proj"),
(r"_(\d+)_img_mlp_net_(\d+)", r".\1.img_mlp.net.\2"),
(r"_(\d+)_txt_mlp_net_(\d+)", r".\1.txt_mlp.net.\2"),
(r"_(\d+)_img_mod_(\d+)", r".\1.img_mod.\2"),
(r"_(\d+)_txt_mod_(\d+)", r".\1.txt_mod.\2"),
(r"_(\d+)_attn_", r".\1.attn."),
]
new_name = old_name
for pattern, replacement in rules:
new_name = re.sub(pattern, replacement, new_name)
return new_name
def _is_indexable_module(module):
return isinstance(module, (nn.ModuleList, nn.Sequential, list, tuple))
def _get_module_by_name(model: nn.Module, name: str) -> Optional[nn.Module]:
if not name:
return model
module = model
for part in name.split("."):
if not part:
continue
if hasattr(module, part):
module = getattr(module, part)
elif part.isdigit() and _is_indexable_module(module):
try:
module = module[int(part)]
except (IndexError, TypeError):
return None
else:
return None
return module
def _resolve_module_name(model: nn.Module, name: str) -> Tuple[str, Optional[nn.Module]]:
    """Resolve *name* against *model*, falling back to Nunchaku fused-module aliases.

    Returns ``(resolved_name, module)``; *module* is None when neither the
    original name nor any alias substitution points at an existing submodule.
    """
    direct = _get_module_by_name(model, name)
    if direct is not None:
        return name, direct
    # Diffusers-style suffixes and their Nunchaku fused-module equivalents.
    aliases = (
        (".attn.to_out.0", ".attn.to_out"),
        (".attention.to_qkv", ".attention.qkv"),
        (".attention.to_out.0", ".attention.out"),
        (".feed_forward.net.0.proj", ".feed_forward.w13"),
        (".feed_forward.net.2", ".feed_forward.w2"),
        (".ff.net.0.proj", ".mlp_fc1"),
        (".ff.net.2", ".mlp_fc2"),
        (".ff_context.net.0.proj", ".mlp_context_fc1"),
        (".ff_context.net.2", ".mlp_context_fc2"),
    )
    for needle, substitute in aliases:
        if needle not in name:
            continue
        candidate = name.replace(needle, substitute)
        found = _get_module_by_name(model, candidate)
        if found is not None:
            return candidate, found
    return name, None
def _classify_and_map_key(key: str) -> Optional[Tuple[str, str, Optional[str], str]]:
    """Parse one raw LoRA state-dict key.

    Strips known wrapper prefixes, normalizes Kohya ``lora_unet_`` names, then
    splits the key into a base layer name and a part tag. Returns
    ``(group_kind, mapped_base_key, component, part)`` where *part* is "A",
    "B" or "alpha", or None when the key is not a LoRA tensor or matches no
    entry of KEY_MAPPING.
    """
    stripped = key
    # Prefixes are stripped sequentially, mirroring how checkpoints nest them.
    for prefix in ("transformer.", "diffusion_model."):
        if stripped.startswith(prefix):
            stripped = stripped[len(prefix):]
    if stripped.startswith("lora_unet_"):
        stripped = _rename_layer_underscore_layer_name(stripped[len("lora_unet_"):])
    suffix = _RE_LORA_SUFFIX.search(stripped)
    if suffix is not None:
        tag = suffix.group("tag")
        base = stripped[:suffix.start()]
        is_down_factor = "lora_A" in tag or tag.endswith(".A") or "down" in tag
        part = "A" if is_down_factor else "B"
    else:
        suffix = _RE_ALPHA_SUFFIX.search(stripped)
        if suffix is None:
            return None
        base = stripped[:suffix.start()]
        part = "alpha"
    for pattern, template, group_kind, extractor in KEY_MAPPING:
        matched = pattern.match(base)
        if matched is None:
            continue
        component = extractor(matched) if extractor else None
        return group_kind, matched.expand(template), component, part
    return None
def _detect_lora_format(lora_state_dict: Dict[str, torch.Tensor]) -> bool:
standard_patterns = (
".lora_up.",
".lora_down.",
".lora_A.",
".lora_B.",
".lora.up.",
".lora.down.",
".lora.A.",
".lora.B.",
)
return any(pattern in key for key in lora_state_dict for pattern in standard_patterns)
def _load_lora_state_dict(path_or_dict: Union[str, Path, Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
if isinstance(path_or_dict, dict):
return path_or_dict
path = Path(path_or_dict)
if path.suffix == ".safetensors":
state_dict: Dict[str, torch.Tensor] = {}
with safe_open(path, framework="pt", device="cpu") as handle:
for key in handle.keys():
state_dict[key] = handle.get_tensor(key)
return state_dict
return comfy.utils.load_torch_file(str(path), safe_load=True)
def _fuse_glu_lora(glu_weights: Dict[str, torch.Tensor]) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
if "w1_A" not in glu_weights or "w3_A" not in glu_weights:
return None, None, None
a_w1, b_w1 = glu_weights["w1_A"], glu_weights["w1_B"]
a_w3, b_w3 = glu_weights["w3_A"], glu_weights["w3_B"]
if a_w1.shape[1] != a_w3.shape[1]:
return None, None, None
a_fused = torch.cat([a_w1, a_w3], dim=0)
out1, out3 = b_w1.shape[0], b_w3.shape[0]
rank1, rank3 = b_w1.shape[1], b_w3.shape[1]
b_fused = torch.zeros(out1 + out3, rank1 + rank3, dtype=b_w1.dtype, device=b_w1.device)
b_fused[:out1, :rank1] = b_w1
b_fused[out1:, rank1:] = b_w3
return a_fused, b_fused, glu_weights.get("w1_alpha")
def _fuse_qkv_lora(qkv_weights: Dict[str, torch.Tensor], model: Optional[nn.Module] = None, base_key: Optional[str] = None) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
required_keys = ["Q_A", "Q_B", "K_A", "K_B", "V_A", "V_B"]
if not all(key in qkv_weights for key in required_keys):
return None, None, None
a_q, a_k, a_v = qkv_weights["Q_A"], qkv_weights["K_A"], qkv_weights["V_A"]
b_q, b_k, b_v = qkv_weights["Q_B"], qkv_weights["K_B"], qkv_weights["V_B"]
if not (a_q.shape == a_k.shape == a_v.shape):
return None, None, None
if not (b_q.shape[1] == b_k.shape[1] == b_v.shape[1]):
return None, None, None
out_features = None
if model is not None and base_key is not None:
_, module = _resolve_module_name(model, base_key)
out_features = getattr(module, "out_features", None) if module is not None else None
alpha_fused = None
alpha_q = qkv_weights.get("Q_alpha")
alpha_k = qkv_weights.get("K_alpha")
alpha_v = qkv_weights.get("V_alpha")
if alpha_q is not None and alpha_k is not None and alpha_v is not None and alpha_q.item() == alpha_k.item() == alpha_v.item():
alpha_fused = alpha_q
a_fused = torch.cat([a_q, a_k, a_v], dim=0)
rank = b_q.shape[1]
out_q, out_k, out_v = b_q.shape[0], b_k.shape[0], b_v.shape[0]
total_out = out_features if out_features is not None else out_q + out_k + out_v
b_fused = torch.zeros(total_out, 3 * rank, dtype=b_q.dtype, device=b_q.device)
b_fused[:out_q, :rank] = b_q
b_fused[out_q:out_q + out_k, rank:2 * rank] = b_k
b_fused[out_q + out_k:out_q + out_k + out_v, 2 * rank:] = b_v
return a_fused, b_fused, alpha_fused
def _handle_proj_out_split(lora_dict: Dict[str, Dict[str, torch.Tensor]], base_key: str, model: nn.Module) -> Tuple[Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], List[str]]:
    """Split a single-block ``proj_out`` LoRA across ``attn.to_out`` and ``mlp_fc2``.

    In the Nunchaku single transformer blocks, ``proj_out`` consumes the
    concatenation of the attention output and the MLP activation, so a LoRA
    trained against it is split column-wise on its A factor: the first
    ``attn.to_out.in_features`` columns target ``attn.to_out`` and the rest
    target ``mlp_fc2``. The B factor (cloned) and alpha are shared by both
    halves.

    Returns ``(result, consumed)`` where *result* maps target module names to
    ``(A, B, alpha)`` and *consumed* lists the source keys absorbed; both are
    empty when the block, the expected sub-modules, or the shapes cannot be
    matched.
    """
    result: Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]] = {}
    consumed: List[str] = []
    match = re.search(r"single_transformer_blocks\.(\d+)", base_key)
    if not match or base_key not in lora_dict:
        return result, consumed
    block_idx = match.group(1)
    block = _get_module_by_name(model, f"single_transformer_blocks.{block_idx}")
    if block is None:
        return result, consumed
    a_full = lora_dict[base_key].get("A")
    b_full = lora_dict[base_key].get("B")
    alpha = lora_dict[base_key].get("alpha")
    # getattr chains tolerate blocks without the expected attn/mlp layout.
    attn_to_out = getattr(getattr(block, "attn", None), "to_out", None)
    mlp_fc2 = getattr(block, "mlp_fc2", None)
    if a_full is None or b_full is None or attn_to_out is None or mlp_fc2 is None:
        return result, consumed
    attn_in = getattr(attn_to_out, "in_features", None)
    mlp_in = getattr(mlp_fc2, "in_features", None)
    # The A factor must span exactly both inputs for the column split to be valid.
    if attn_in is None or mlp_in is None or a_full.shape[1] != attn_in + mlp_in:
        return result, consumed
    result[f"single_transformer_blocks.{block_idx}.attn.to_out"] = (a_full[:, :attn_in], b_full.clone(), alpha)
    result[f"single_transformer_blocks.{block_idx}.mlp_fc2"] = (a_full[:, attn_in:], b_full.clone(), alpha)
    consumed.append(base_key)
    return result, consumed
def _apply_lora_to_module(module: nn.Module, a_tensor: torch.Tensor, b_tensor: torch.Tensor, module_name: str, model: nn.Module) -> None:
    """Apply a fused LoRA delta ``(A, B)`` to *module* in place.

    Three module kinds are supported; each application is recorded in
    ``model._lora_slots`` so that ``reset_lora_v2`` can undo it later:

    * AWQ W4A16 quantized linears -- the forward is monkey-patched to add
      ``x @ A^T @ B^T`` on top of the quantized output (weights cannot be
      merged directly).
    * Nunchaku low-rank modules (``proj_down``/``proj_up``) -- the LoRA
      factors are concatenated onto the unpacked low-rank weights, then
      re-packed, and ``module.rank`` is bumped accordingly.
    * Plain ``nn.Linear`` -- ``B @ A`` is merged into the weight after a CPU
      copy of the original weight is saved (once, on first application).

    Raises ValueError on shape mismatch or unsupported module types.
    """
    if not hasattr(module, "in_features") or not hasattr(module, "out_features"):
        raise ValueError(f"{module_name}: unsupported module without in/out features")
    if a_tensor.shape[1] != module.in_features or b_tensor.shape[0] != module.out_features:
        raise ValueError(f"{module_name}: LoRA shape mismatch")
    if module.__class__.__name__ == "AWQW4A16Linear" and hasattr(module, "qweight"):
        # Quantized path: stash the original forward once; later LoRAs append
        # to the same bundle instead of re-wrapping.
        if not hasattr(module, "_lora_original_forward"):
            module._lora_original_forward = module.forward
        if not hasattr(module, "_nunchaku_lora_bundle"):
            module._nunchaku_lora_bundle = []
        module._nunchaku_lora_bundle.append((a_tensor, b_tensor))
        def _awq_lora_forward(x, *args, **kwargs):
            # Quantized base output first, then each low-rank term on top.
            out = module._lora_original_forward(x, *args, **kwargs)
            x_flat = x.reshape(-1, module.in_features)
            for local_a, local_b in module._nunchaku_lora_bundle:
                local_a = local_a.to(device=out.device, dtype=out.dtype)
                local_b = local_b.to(device=out.device, dtype=out.dtype)
                lora_term = (x_flat @ local_a.transpose(0, 1)) @ local_b.transpose(0, 1)
                try:
                    out = out + lora_term.reshape(out.shape)
                except Exception:
                    # Best-effort: silently skip terms that cannot be reshaped
                    # back onto the base output.
                    pass
            return out
        module.forward = _awq_lora_forward
        if not hasattr(model, "_lora_slots"):
            model._lora_slots = {}
        model._lora_slots[module_name] = {"type": "awq_w4a16"}
        return
    if hasattr(module, "proj_down") and hasattr(module, "proj_up"):
        # Nunchaku low-rank path: extend the packed factors with the LoRA rank.
        proj_down = unpack_lowrank_weight(module.proj_down.data, down=True)
        proj_up = unpack_lowrank_weight(module.proj_up.data, down=False)
        # Layout detection: the rank axis of proj_down may be dim 0 or dim 1
        # depending on whether dim 1 holds the input features.
        base_rank = proj_down.shape[0] if proj_down.shape[1] == module.in_features else proj_down.shape[1]
        if proj_down.shape[1] == module.in_features:
            updated_down = torch.cat([proj_down, a_tensor], dim=0)
            axis_down = 0
        else:
            updated_down = torch.cat([proj_down, a_tensor.T], dim=1)
            axis_down = 1
        updated_up = torch.cat([proj_up, b_tensor], dim=1)
        module.proj_down.data = pack_lowrank_weight(updated_down, down=True)
        module.proj_up.data = pack_lowrank_weight(updated_up, down=False)
        module.rank = base_rank + a_tensor.shape[0]
        if not hasattr(model, "_lora_slots"):
            model._lora_slots = {}
        # base_rank/axis_down let reset_lora_v2 truncate back to the original factors.
        model._lora_slots[module_name] = {
            "type": "nunchaku",
            "base_rank": base_rank,
            "axis_down": axis_down,
        }
        return
    if isinstance(module, nn.Linear):
        if not hasattr(model, "_lora_slots"):
            model._lora_slots = {}
        if module_name not in model._lora_slots:
            # Save the pristine weight only for the first LoRA touching this module.
            model._lora_slots[module_name] = {
                "type": "linear",
                "original_weight": module.weight.detach().cpu().clone(),
            }
        module.weight.data.add_((b_tensor @ a_tensor).to(dtype=module.weight.dtype, device=module.weight.device))
        return
    raise ValueError(f"{module_name}: unsupported module type {type(module)}")
def reset_lora_v2(model: nn.Module) -> None:
    """Undo every LoRA previously applied via ``_apply_lora_to_module``.

    Uses the bookkeeping in ``model._lora_slots``: truncates appended low-rank
    factors on Nunchaku modules, restores the saved weight on plain Linears,
    and unpatches the forward on AWQ W4A16 modules. Clears the slot registry
    afterwards. No-op when nothing was applied.
    """
    slots = getattr(model, "_lora_slots", None)
    if not slots:
        return
    for name, info in list(slots.items()):
        module = _get_module_by_name(model, name)
        if module is None:
            continue
        module_type = info.get("type", "nunchaku")
        if module_type == "nunchaku":
            # Drop the rows/cols appended beyond the original base rank,
            # honouring the axis layout recorded at apply time.
            base_rank = info["base_rank"]
            proj_down = unpack_lowrank_weight(module.proj_down.data, down=True)
            proj_up = unpack_lowrank_weight(module.proj_up.data, down=False)
            if info.get("axis_down", 0) == 0:
                proj_down = proj_down[:base_rank, :].clone()
            else:
                proj_down = proj_down[:, :base_rank].clone()
            proj_up = proj_up[:, :base_rank].clone()
            module.proj_down.data = pack_lowrank_weight(proj_down, down=True)
            module.proj_up.data = pack_lowrank_weight(proj_up, down=False)
            module.rank = base_rank
        elif module_type == "linear" and "original_weight" in info:
            # Restore the pristine weight captured before the first merge.
            module.weight.data.copy_(info["original_weight"].to(device=module.weight.device, dtype=module.weight.dtype))
        elif module_type == "awq_w4a16":
            # Remove the monkey-patched forward and its LoRA bundle.
            if hasattr(module, "_lora_original_forward"):
                module.forward = module._lora_original_forward
            for attr in ("_lora_original_forward", "_nunchaku_lora_bundle"):
                if hasattr(module, attr):
                    delattr(module, attr)
    model._lora_slots = {}
def compose_loras_v2(model: nn.Module, lora_configs: List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]], apply_awq_mod: bool = True) -> bool:
    """Reset any applied LoRAs on *model*, then compose and apply *lora_configs*.

    Each config is ``(path_or_state_dict, strength)``. Per LoRA, keys are
    classified and grouped per target module, Q/K/V and GLU factors are fused,
    and single-block ``proj_out`` entries are split. All LoRAs targeting the
    same module are concatenated along the rank axis (strength/alpha folded
    into the B factor) and applied in one pass.

    Returns True when at least one config used a recognized LoRA key format
    (even if no target module could be resolved), False otherwise.
    """
    del apply_awq_mod  # retained for interface compatibility
    reset_lora_v2(model)
    # module name -> list of {"A", "B", "alpha", "strength"}, one entry per LoRA.
    aggregated_weights: Dict[str, List[Dict[str, object]]] = defaultdict(list)
    saw_supported_format = False
    unresolved_targets = 0
    for index, (path_or_dict, strength) in enumerate(lora_configs):
        if abs(strength) < 1e-5:
            # Strength ~0 contributes nothing; skip loading entirely.
            continue
        lora_name = str(path_or_dict) if not isinstance(path_or_dict, dict) else f"lora_{index}"
        lora_state_dict = _load_lora_state_dict(path_or_dict)
        if not lora_state_dict or not _detect_lora_format(lora_state_dict):
            logger.warning("Skipping unsupported Qwen LoRA: %s", lora_name)
            continue
        saw_supported_format = True
        # Group raw tensors by mapped target module; component-tagged entries
        # (Q/K/V, w1/w3) are keyed "<component>_<part>" for the fusers.
        grouped_weights: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)
        for key, value in lora_state_dict.items():
            parsed = _classify_and_map_key(key)
            if parsed is None:
                continue
            group, base_key, component, ab = parsed
            if component and ab:
                grouped_weights[base_key][f"{component}_{ab}"] = value
            else:
                grouped_weights[base_key][ab] = value
        processed_groups: Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]] = {}
        handled: set[str] = set()
        for base_key, weights in grouped_weights.items():
            if base_key in handled:
                continue
            a_tensor = b_tensor = alpha = None
            if "qkv" in base_key or "add_qkv_proj" in base_key:
                a_tensor, b_tensor, alpha = _fuse_qkv_lora(weights, model=model, base_key=base_key)
            elif "w1_A" in weights or "w3_A" in weights:
                a_tensor, b_tensor, alpha = _fuse_glu_lora(weights)
            elif ".proj_out" in base_key and "single_transformer_blocks" in base_key:
                # proj_out of a single block feeds two modules; split and mark consumed.
                split_map, consumed = _handle_proj_out_split(grouped_weights, base_key, model)
                processed_groups.update(split_map)
                handled.update(consumed)
                continue
            else:
                a_tensor, b_tensor, alpha = weights.get("A"), weights.get("B"), weights.get("alpha")
            if a_tensor is not None and b_tensor is not None:
                processed_groups[base_key] = (a_tensor, b_tensor, alpha)
        for module_name, (a_tensor, b_tensor, alpha) in processed_groups.items():
            aggregated_weights[module_name].append({
                "A": a_tensor,
                "B": b_tensor,
                "alpha": alpha,
                "strength": strength,
            })
    for module_name, weight_list in aggregated_weights.items():
        resolved_name, module = _resolve_module_name(model, module_name)
        if module is None:
            logger.warning("Skipping unresolved Qwen LoRA target: %s", module_name)
            unresolved_targets += 1
            continue
        all_a = []
        all_b_scaled = []
        for item in weight_list:
            a_tensor = item["A"]
            b_tensor = item["B"]
            alpha = item["alpha"]
            strength = float(item["strength"])
            rank = a_tensor.shape[0]
            # Standard LoRA scaling: strength * alpha / rank (1.0 when no alpha).
            scale = strength * ((alpha / rank) if alpha is not None else 1.0)
            # Pick dtype/device matching how the module stores its weights.
            if module.__class__.__name__ == "AWQW4A16Linear" and hasattr(module, "qweight"):
                target_dtype = torch.float16
                target_device = module.qweight.device
            elif hasattr(module, "proj_down"):
                target_dtype = module.proj_down.dtype
                target_device = module.proj_down.device
            elif hasattr(module, "weight"):
                target_dtype = module.weight.dtype
                target_device = module.weight.device
            else:
                target_dtype = torch.float16
                target_device = "cuda" if torch.cuda.is_available() else "cpu"
            all_a.append(a_tensor.to(dtype=target_dtype, device=target_device))
            # Fold the scale into B so concatenation along the rank axis is exact.
            all_b_scaled.append((b_tensor * scale).to(dtype=target_dtype, device=target_device))
        if not all_a:
            continue
        _apply_lora_to_module(module, torch.cat(all_a, dim=0), torch.cat(all_b_scaled, dim=1), resolved_name, model)
    slot_count = len(getattr(model, "_lora_slots", {}) or {})
    logger.info(
        "Qwen LoRA composition finished: requested=%d supported=%s applied_targets=%d unresolved=%d",
        len(lora_configs),
        saw_supported_format,
        slot_count,
        unresolved_targets,
    )
    return saw_supported_format
class ComfyQwenImageWrapperLM(nn.Module):
    """Wraps a Nunchaku Qwen-Image transformer to apply queued LoRAs lazily.

    LoRA configs are queued on ``self.loras``; composition onto the wrapped
    model happens on the first ``forward`` after the queue changes. Unknown
    attribute access is delegated to the wrapped model.
    """
    def __init__(self, model: nn.Module, config=None, apply_awq_mod: bool = True):
        super().__init__()
        self.model = model
        self.config = {} if config is None else config
        # Mirror the wrapped model's parameter dtype so callers can query it here.
        self.dtype = next(model.parameters()).dtype
        # Queued (path_or_state_dict, strength) pairs; applied in _ensure_composed.
        self.loras: List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]] = []
        # Snapshot of the queue as of the last composition (None = never composed).
        self._applied_loras: Optional[List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]]] = None
        self.apply_awq_mod = apply_awq_mod
    def __getattr__(self, name):
        # Invoked only for attributes not found normally; delegate to the
        # wrapped model. _modules is read via object.__getattribute__ so this
        # method cannot recurse into itself during construction.
        try:
            inner = object.__getattribute__(self, "_modules").get("model")
        except (AttributeError, KeyError):
            inner = None
        if inner is None:
            raise AttributeError(f"{type(self).__name__!s} has no attribute {name}")
        if name == "model":
            return inner
        return getattr(inner, name)
    def process_img(self, *args, **kwargs):
        # Thin pass-through to the wrapped transformer's image preprocessing.
        return self.model.process_img(*args, **kwargs)
    def _ensure_composed(self):
        # Recompose when the queue changed, or when the queue is empty but
        # slots from a previous composition are still applied (needs reset).
        if self._applied_loras != self.loras or (not self.loras and getattr(self.model, "_lora_slots", None)):
            is_supported_format = compose_loras_v2(self.model, self.loras, apply_awq_mod=self.apply_awq_mod)
            self._applied_loras = self.loras.copy()
            has_slots = bool(getattr(self.model, "_lora_slots", None))
            if self.loras and is_supported_format and not has_slots:
                # Supported format but zero resolved targets: retry once from a
                # clean state before giving up.
                logger.warning("Qwen LoRA compose produced 0 target modules. Resetting and retrying once.")
                reset_lora_v2(self.model)
                compose_loras_v2(self.model, self.loras, apply_awq_mod=self.apply_awq_mod)
                has_slots = bool(getattr(self.model, "_lora_slots", None))
                logger.info("Qwen LoRA retry result: applied_targets=%d", len(getattr(self.model, "_lora_slots", {}) or {}))
            offload_manager = getattr(self.model, "offload_manager", None)
            if offload_manager is not None:
                # Composition mutates module weights, so rebuild the offload
                # manager with the same settings it was configured with.
                offload_settings = {
                    "num_blocks_on_gpu": getattr(offload_manager, "num_blocks_on_gpu", 1),
                    "use_pin_memory": getattr(offload_manager, "use_pin_memory", False),
                }
                logger.info(
                    "Rebuilding Qwen offload manager after LoRA compose: num_blocks_on_gpu=%s use_pin_memory=%s",
                    offload_settings["num_blocks_on_gpu"],
                    offload_settings["use_pin_memory"],
                )
                self.model.set_offload(False)
                self.model.set_offload(True, **offload_settings)
    def forward(self, *args, **kwargs):
        # Apply any pending LoRA changes before delegating to the transformer.
        self._ensure_composed()
        return self.model(*args, **kwargs)
def _get_qwen_wrapper_and_transformer(model):
model_wrapper = model.model.diffusion_model
if hasattr(model_wrapper, "model") and hasattr(model_wrapper, "loras"):
transformer = model_wrapper.model
if transformer.__class__.__name__.endswith("NunchakuQwenImageTransformer2DModel"):
return model_wrapper, transformer
if model_wrapper.__class__.__name__.endswith("NunchakuQwenImageTransformer2DModel"):
wrapped_model = ComfyQwenImageWrapperLM(model_wrapper, getattr(model_wrapper, "config", {}))
model.model.diffusion_model = wrapped_model
return wrapped_model, wrapped_model.model
raise TypeError(f"This LoRA loader only works with Nunchaku Qwen Image models, but got {type(model_wrapper).__name__}.")
def nunchaku_load_qwen_loras(model, lora_configs: List[Tuple[str, float]], apply_awq_mod: bool = True):
    """Return a cloned ComfyUI model with *lora_configs* queued on its Qwen wrapper.

    The original *model* is left unchanged. The (large) transformer and the
    model_config are temporarily detached so ``copy.deepcopy`` does not
    duplicate them; both the original and the returned clone then share the
    same transformer instance. LoRA names that cannot be resolved to a file
    are skipped with a warning.
    """
    model_wrapper, transformer = _get_qwen_wrapper_and_transformer(model)
    model_wrapper.apply_awq_mod = apply_awq_mod
    saved_config = None
    if hasattr(model, "model") and hasattr(model.model, "model_config"):
        # Detach model_config so deepcopy skips it (restored below).
        saved_config = model.model.model_config
        model.model.model_config = None
    # Detach the transformer so deepcopy does not duplicate its weights.
    model_wrapper.model = None
    try:
        ret_model = copy.deepcopy(model)
    finally:
        # Restore the original model regardless of deepcopy success.
        if saved_config is not None:
            model.model.model_config = saved_config
        model_wrapper.model = transformer
    ret_model_wrapper = ret_model.model.diffusion_model
    if saved_config is not None:
        ret_model.model.model_config = saved_config
    # The clone references the same transformer weights as the original.
    ret_model_wrapper.model = transformer
    ret_model_wrapper.apply_awq_mod = apply_awq_mod
    ret_model_wrapper.loras = list(getattr(model_wrapper, "loras", []))
    for lora_name, lora_strength in lora_configs:
        # Accept absolute paths directly; otherwise resolve via ComfyUI's loras folder.
        lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
        if not lora_path or not os.path.isfile(lora_path):
            logger.warning("Skipping Qwen LoRA '%s' because it could not be found", lora_name)
            continue
        ret_model_wrapper.loras.append((lora_path, lora_strength))
    return ret_model