mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-28 08:28:53 -03:00
Add experimental Nunchaku Qwen LoRA support (#873)
This commit is contained in:
@@ -346,6 +346,7 @@ We appreciate your understanding and look forward to potentially accepting code
|
|||||||
|
|
||||||
This project has been inspired by and benefited from other excellent ComfyUI extensions:
|
This project has been inspired by and benefited from other excellent ComfyUI extensions:
|
||||||
|
|
||||||
|
- [ComfyUI-QwenImageLoraLoader](https://github.com/ussoewwin/ComfyUI-QwenImageLoraLoader) - For the experimental Nunchaku Qwen-Image LoRA support
|
||||||
- [ComfyUI-SaveImageWithMetaData](https://github.com/nkchocoai/ComfyUI-SaveImageWithMetaData) - For the image metadata functionality
|
- [ComfyUI-SaveImageWithMetaData](https://github.com/nkchocoai/ComfyUI-SaveImageWithMetaData) - For the image metadata functionality
|
||||||
- [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) - For the lora loader functionality
|
- [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) - For the lora loader functionality
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,118 @@
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import comfy.utils # type: ignore
|
|
||||||
import comfy.sd # type: ignore
|
import comfy.sd # type: ignore
|
||||||
|
import comfy.utils # type: ignore
|
||||||
|
|
||||||
from ..utils.utils import get_lora_info_absolute
|
from ..utils.utils import get_lora_info_absolute
|
||||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
|
from .nunchaku_qwen import nunchaku_load_qwen_loras
|
||||||
|
from .utils import (
|
||||||
|
FlexibleOptionalInputType,
|
||||||
|
any_type,
|
||||||
|
detect_nunchaku_model_kind,
|
||||||
|
extract_lora_name,
|
||||||
|
get_loras_list,
|
||||||
|
nunchaku_load_lora,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_stack_entries(lora_stack):
|
||||||
|
entries = []
|
||||||
|
if not lora_stack:
|
||||||
|
return entries
|
||||||
|
|
||||||
|
for lora_path, model_strength, clip_strength in lora_stack:
|
||||||
|
lora_name = extract_lora_name(lora_path)
|
||||||
|
absolute_lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||||
|
entries.append({
|
||||||
|
"name": lora_name,
|
||||||
|
"absolute_path": absolute_lora_path,
|
||||||
|
"input_path": lora_path,
|
||||||
|
"model_strength": float(model_strength),
|
||||||
|
"clip_strength": float(clip_strength),
|
||||||
|
"trigger_words": trigger_words,
|
||||||
|
})
|
||||||
|
return entries
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_widget_entries(kwargs):
|
||||||
|
entries = []
|
||||||
|
for lora in get_loras_list(kwargs):
|
||||||
|
if not lora.get("active", False):
|
||||||
|
continue
|
||||||
|
lora_name = lora["name"]
|
||||||
|
model_strength = float(lora["strength"])
|
||||||
|
clip_strength = float(lora.get("clipStrength", model_strength))
|
||||||
|
lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||||
|
entries.append({
|
||||||
|
"name": lora_name,
|
||||||
|
"absolute_path": lora_path,
|
||||||
|
"input_path": lora_path,
|
||||||
|
"model_strength": model_strength,
|
||||||
|
"clip_strength": clip_strength,
|
||||||
|
"trigger_words": trigger_words,
|
||||||
|
})
|
||||||
|
return entries
|
||||||
|
|
||||||
|
|
||||||
|
def _format_loaded_loras(loaded_loras):
|
||||||
|
formatted_loras = []
|
||||||
|
for item in loaded_loras:
|
||||||
|
if item["include_clip_strength"]:
|
||||||
|
formatted_loras.append(
|
||||||
|
f"<lora:{item['name']}:{item['model_strength']}:{item['clip_strength']}>"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
formatted_loras.append(f"<lora:{item['name']}:{item['model_strength']}>")
|
||||||
|
return " ".join(formatted_loras)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_entries(model, clip, lora_entries, nunchaku_model_kind):
|
||||||
|
loaded_loras = []
|
||||||
|
all_trigger_words = []
|
||||||
|
|
||||||
|
if nunchaku_model_kind == "qwen_image":
|
||||||
|
qwen_lora_configs = []
|
||||||
|
for entry in lora_entries:
|
||||||
|
qwen_lora_configs.append((entry["absolute_path"], entry["model_strength"]))
|
||||||
|
loaded_loras.append({
|
||||||
|
"name": entry["name"],
|
||||||
|
"model_strength": entry["model_strength"],
|
||||||
|
"clip_strength": entry["model_strength"],
|
||||||
|
"include_clip_strength": False,
|
||||||
|
})
|
||||||
|
all_trigger_words.extend(entry["trigger_words"])
|
||||||
|
if qwen_lora_configs:
|
||||||
|
model = nunchaku_load_qwen_loras(model, qwen_lora_configs)
|
||||||
|
return model, clip, loaded_loras, all_trigger_words
|
||||||
|
|
||||||
|
for entry in lora_entries:
|
||||||
|
if nunchaku_model_kind == "flux":
|
||||||
|
model = nunchaku_load_lora(model, entry["input_path"], entry["model_strength"])
|
||||||
|
else:
|
||||||
|
lora = comfy.utils.load_torch_file(entry["absolute_path"], safe_load=True)
|
||||||
|
model, clip = comfy.sd.load_lora_for_models(
|
||||||
|
model,
|
||||||
|
clip,
|
||||||
|
lora,
|
||||||
|
entry["model_strength"],
|
||||||
|
entry["clip_strength"],
|
||||||
|
)
|
||||||
|
|
||||||
|
include_clip_strength = nunchaku_model_kind is None and abs(entry["model_strength"] - entry["clip_strength"]) > 0.001
|
||||||
|
loaded_loras.append({
|
||||||
|
"name": entry["name"],
|
||||||
|
"model_strength": entry["model_strength"],
|
||||||
|
"clip_strength": entry["clip_strength"],
|
||||||
|
"include_clip_strength": include_clip_strength,
|
||||||
|
})
|
||||||
|
all_trigger_words.extend(entry["trigger_words"])
|
||||||
|
|
||||||
|
return model, clip, loaded_loras, all_trigger_words
|
||||||
|
|
||||||
|
|
||||||
class LoraLoaderLM:
|
class LoraLoaderLM:
|
||||||
NAME = "Lora Loader (LoraManager)"
|
NAME = "Lora Loader (LoraManager)"
|
||||||
CATEGORY = "Lora Manager/loaders"
|
CATEGORY = "Lora Manager/loaders"
|
||||||
@@ -16,7 +122,6 @@ class LoraLoaderLM:
|
|||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"model": ("MODEL",),
|
"model": ("MODEL",),
|
||||||
# "clip": ("CLIP",),
|
|
||||||
"text": ("AUTOCOMPLETE_TEXT_LORAS", {
|
"text": ("AUTOCOMPLETE_TEXT_LORAS", {
|
||||||
"placeholder": "Search LoRAs to add...",
|
"placeholder": "Search LoRAs to add...",
|
||||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||||
@@ -31,107 +136,23 @@ class LoraLoaderLM:
|
|||||||
|
|
||||||
def load_loras(self, model, text, **kwargs):
|
def load_loras(self, model, text, **kwargs):
|
||||||
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
||||||
loaded_loras = []
|
del text
|
||||||
all_trigger_words = []
|
clip = kwargs.get("clip", None)
|
||||||
|
lora_entries = _collect_stack_entries(kwargs.get("lora_stack", None))
|
||||||
|
lora_entries.extend(_collect_widget_entries(kwargs))
|
||||||
|
|
||||||
clip = kwargs.get('clip', None)
|
nunchaku_model_kind = detect_nunchaku_model_kind(model)
|
||||||
lora_stack = kwargs.get('lora_stack', None)
|
if nunchaku_model_kind == "flux":
|
||||||
|
|
||||||
# Check if model is a Nunchaku Flux model - simplified approach
|
|
||||||
is_nunchaku_model = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
model_wrapper = model.model.diffusion_model
|
|
||||||
# Check if model is a Nunchaku Flux model using only class name
|
|
||||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
|
||||||
is_nunchaku_model = True
|
|
||||||
logger.info("Detected Nunchaku Flux model")
|
logger.info("Detected Nunchaku Flux model")
|
||||||
except (AttributeError, TypeError):
|
elif nunchaku_model_kind == "qwen_image":
|
||||||
# Not a model with the expected structure
|
logger.info("Detected Nunchaku Qwen-Image model")
|
||||||
pass
|
|
||||||
|
|
||||||
# First process lora_stack if available
|
model, clip, loaded_loras, all_trigger_words = _apply_entries(model, clip, lora_entries, nunchaku_model_kind)
|
||||||
if lora_stack:
|
|
||||||
for lora_path, model_strength, clip_strength in lora_stack:
|
|
||||||
# Extract lora name and convert to absolute path
|
|
||||||
# lora_stack stores relative paths, but load_torch_file needs absolute paths
|
|
||||||
lora_name = extract_lora_name(lora_path)
|
|
||||||
absolute_lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
|
||||||
|
|
||||||
# Apply the LoRA using the appropriate loader
|
|
||||||
if is_nunchaku_model:
|
|
||||||
# Use our custom function for Flux models
|
|
||||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
|
||||||
# clip remains unchanged for Nunchaku models
|
|
||||||
else:
|
|
||||||
# Use lower-level API to load LoRA directly without folder_paths validation
|
|
||||||
lora = comfy.utils.load_torch_file(absolute_lora_path, safe_load=True)
|
|
||||||
model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength)
|
|
||||||
|
|
||||||
all_trigger_words.extend(trigger_words)
|
|
||||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
|
||||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
|
||||||
else:
|
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
|
||||||
|
|
||||||
# Then process loras from kwargs with support for both old and new formats
|
|
||||||
loras_list = get_loras_list(kwargs)
|
|
||||||
for lora in loras_list:
|
|
||||||
if not lora.get('active', False):
|
|
||||||
continue
|
|
||||||
|
|
||||||
lora_name = lora['name']
|
|
||||||
model_strength = float(lora['strength'])
|
|
||||||
# Get clip strength - use model strength as default if not specified
|
|
||||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
|
||||||
|
|
||||||
# Get lora path and trigger words
|
|
||||||
lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
|
||||||
|
|
||||||
# Apply the LoRA using the appropriate loader
|
|
||||||
if is_nunchaku_model:
|
|
||||||
# For Nunchaku models, use our custom function
|
|
||||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
|
||||||
# clip remains unchanged
|
|
||||||
else:
|
|
||||||
# Use lower-level API to load LoRA directly without folder_paths validation
|
|
||||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
|
||||||
model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength)
|
|
||||||
|
|
||||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
|
||||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
|
||||||
else:
|
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
|
||||||
|
|
||||||
# Add trigger words to collection
|
|
||||||
all_trigger_words.extend(trigger_words)
|
|
||||||
|
|
||||||
# use ',, ' to separate trigger words for group mode
|
|
||||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||||
|
formatted_loras_text = _format_loaded_loras(loaded_loras)
|
||||||
# Format loaded_loras with support for both formats
|
|
||||||
formatted_loras = []
|
|
||||||
for item in loaded_loras:
|
|
||||||
parts = item.split(":")
|
|
||||||
lora_name = parts[0]
|
|
||||||
strength_parts = parts[1].strip().split(",")
|
|
||||||
|
|
||||||
if len(strength_parts) > 1:
|
|
||||||
# Different model and clip strengths
|
|
||||||
model_str = strength_parts[0].strip()
|
|
||||||
clip_str = strength_parts[1].strip()
|
|
||||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
|
|
||||||
else:
|
|
||||||
# Same strength for both
|
|
||||||
model_str = strength_parts[0].strip()
|
|
||||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
|
|
||||||
|
|
||||||
formatted_loras_text = " ".join(formatted_loras)
|
|
||||||
|
|
||||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||||
|
|
||||||
|
|
||||||
class LoraTextLoaderLM:
|
class LoraTextLoaderLM:
|
||||||
NAME = "LoRA Text Loader (LoraManager)"
|
NAME = "LoRA Text Loader (LoraManager)"
|
||||||
CATEGORY = "Lora Manager/loaders"
|
CATEGORY = "Lora Manager/loaders"
|
||||||
@@ -143,13 +164,13 @@ class LoraTextLoaderLM:
|
|||||||
"model": ("MODEL",),
|
"model": ("MODEL",),
|
||||||
"lora_syntax": ("STRING", {
|
"lora_syntax": ("STRING", {
|
||||||
"forceInput": True,
|
"forceInput": True,
|
||||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation"
|
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"clip": ("CLIP",),
|
"clip": ("CLIP",),
|
||||||
"lora_stack": ("LORA_STACK",),
|
"lora_stack": ("LORA_STACK",),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
|
RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
|
||||||
@@ -158,116 +179,40 @@ class LoraTextLoaderLM:
|
|||||||
|
|
||||||
def parse_lora_syntax(self, text):
|
def parse_lora_syntax(self, text):
|
||||||
"""Parse LoRA syntax from text input."""
|
"""Parse LoRA syntax from text input."""
|
||||||
# Pattern to match <lora:name:strength> or <lora:name:model_strength:clip_strength>
|
pattern = r"<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>"
|
||||||
pattern = r'<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>'
|
|
||||||
matches = re.findall(pattern, text, re.IGNORECASE)
|
matches = re.findall(pattern, text, re.IGNORECASE)
|
||||||
|
|
||||||
loras = []
|
loras = []
|
||||||
for match in matches:
|
for match in matches:
|
||||||
lora_name = match[0]
|
|
||||||
model_strength = float(match[1])
|
model_strength = float(match[1])
|
||||||
clip_strength = float(match[2]) if match[2] else model_strength
|
|
||||||
|
|
||||||
loras.append({
|
loras.append({
|
||||||
'name': lora_name,
|
"name": match[0],
|
||||||
'model_strength': model_strength,
|
"model_strength": model_strength,
|
||||||
'clip_strength': clip_strength
|
"clip_strength": float(match[2]) if match[2] else model_strength,
|
||||||
})
|
})
|
||||||
|
|
||||||
return loras
|
return loras
|
||||||
|
|
||||||
def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None):
|
def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None):
|
||||||
"""Load LoRAs based on text syntax input."""
|
"""Load LoRAs based on text syntax input."""
|
||||||
loaded_loras = []
|
lora_entries = _collect_stack_entries(lora_stack)
|
||||||
all_trigger_words = []
|
for lora in self.parse_lora_syntax(lora_syntax):
|
||||||
|
lora_path, trigger_words = get_lora_info_absolute(lora["name"])
|
||||||
|
lora_entries.append({
|
||||||
|
"name": lora["name"],
|
||||||
|
"absolute_path": lora_path,
|
||||||
|
"input_path": lora_path,
|
||||||
|
"model_strength": lora["model_strength"],
|
||||||
|
"clip_strength": lora["clip_strength"],
|
||||||
|
"trigger_words": trigger_words,
|
||||||
|
})
|
||||||
|
|
||||||
# Check if model is a Nunchaku Flux model - simplified approach
|
nunchaku_model_kind = detect_nunchaku_model_kind(model)
|
||||||
is_nunchaku_model = False
|
if nunchaku_model_kind == "flux":
|
||||||
|
|
||||||
try:
|
|
||||||
model_wrapper = model.model.diffusion_model
|
|
||||||
# Check if model is a Nunchaku Flux model using only class name
|
|
||||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
|
||||||
is_nunchaku_model = True
|
|
||||||
logger.info("Detected Nunchaku Flux model")
|
logger.info("Detected Nunchaku Flux model")
|
||||||
except (AttributeError, TypeError):
|
elif nunchaku_model_kind == "qwen_image":
|
||||||
# Not a model with the expected structure
|
logger.info("Detected Nunchaku Qwen-Image model")
|
||||||
pass
|
|
||||||
|
|
||||||
# First process lora_stack if available
|
model, clip, loaded_loras, all_trigger_words = _apply_entries(model, clip, lora_entries, nunchaku_model_kind)
|
||||||
if lora_stack:
|
|
||||||
for lora_path, model_strength, clip_strength in lora_stack:
|
|
||||||
# Extract lora name and convert to absolute path
|
|
||||||
# lora_stack stores relative paths, but load_torch_file needs absolute paths
|
|
||||||
lora_name = extract_lora_name(lora_path)
|
|
||||||
absolute_lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
|
||||||
|
|
||||||
# Apply the LoRA using the appropriate loader
|
|
||||||
if is_nunchaku_model:
|
|
||||||
# Use our custom function for Flux models
|
|
||||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
|
||||||
# clip remains unchanged for Nunchaku models
|
|
||||||
else:
|
|
||||||
# Use lower-level API to load LoRA directly without folder_paths validation
|
|
||||||
lora = comfy.utils.load_torch_file(absolute_lora_path, safe_load=True)
|
|
||||||
model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength)
|
|
||||||
|
|
||||||
all_trigger_words.extend(trigger_words)
|
|
||||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
|
||||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
|
||||||
else:
|
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
|
||||||
|
|
||||||
# Parse and process LoRAs from text syntax
|
|
||||||
parsed_loras = self.parse_lora_syntax(lora_syntax)
|
|
||||||
for lora in parsed_loras:
|
|
||||||
lora_name = lora['name']
|
|
||||||
model_strength = lora['model_strength']
|
|
||||||
clip_strength = lora['clip_strength']
|
|
||||||
|
|
||||||
# Get lora path and trigger words
|
|
||||||
lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
|
||||||
|
|
||||||
# Apply the LoRA using the appropriate loader
|
|
||||||
if is_nunchaku_model:
|
|
||||||
# For Nunchaku models, use our custom function
|
|
||||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
|
||||||
# clip remains unchanged
|
|
||||||
else:
|
|
||||||
# Use lower-level API to load LoRA directly without folder_paths validation
|
|
||||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
|
||||||
model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength)
|
|
||||||
|
|
||||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
|
||||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
|
||||||
else:
|
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
|
||||||
|
|
||||||
# Add trigger words to collection
|
|
||||||
all_trigger_words.extend(trigger_words)
|
|
||||||
|
|
||||||
# use ',, ' to separate trigger words for group mode
|
|
||||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||||
|
formatted_loras_text = _format_loaded_loras(loaded_loras)
|
||||||
# Format loaded_loras with support for both formats
|
|
||||||
formatted_loras = []
|
|
||||||
for item in loaded_loras:
|
|
||||||
parts = item.split(":")
|
|
||||||
lora_name = parts[0].strip()
|
|
||||||
strength_parts = parts[1].strip().split(",")
|
|
||||||
|
|
||||||
if len(strength_parts) > 1:
|
|
||||||
# Different model and clip strengths
|
|
||||||
model_str = strength_parts[0].strip()
|
|
||||||
clip_str = strength_parts[1].strip()
|
|
||||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
|
|
||||||
else:
|
|
||||||
# Same strength for both
|
|
||||||
model_str = strength_parts[0].strip()
|
|
||||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
|
|
||||||
|
|
||||||
formatted_loras_text = " ".join(formatted_loras)
|
|
||||||
|
|
||||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||||
570
py/nodes/nunchaku_qwen.py
Normal file
570
py/nodes/nunchaku_qwen.py
Normal file
@@ -0,0 +1,570 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
"""Qwen-Image LoRA support for Nunchaku models.
|
||||||
|
|
||||||
|
Portions of the LoRA mapping/application logic in this file are adapted from
|
||||||
|
ComfyUI-QwenImageLoraLoader by GitHub user ussoewwin:
|
||||||
|
https://github.com/ussoewwin/ComfyUI-QwenImageLoraLoader
|
||||||
|
|
||||||
|
The upstream project is licensed under Apache License 2.0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import comfy.utils # type: ignore
|
||||||
|
import folder_paths # type: ignore
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
|
from nunchaku.lora.flux.nunchaku_converter import (
|
||||||
|
pack_lowrank_weight,
|
||||||
|
unpack_lowrank_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
KEY_MAPPING = [
|
||||||
|
(re.compile(r"^(layers)[._](\d+)[._]attention[._]to[._]([qkv])$"), r"\1.\2.attention.to_qkv", "qkv", lambda m: m.group(3).upper()),
|
||||||
|
(re.compile(r"^(layers)[._](\d+)[._]feed_forward[._](w1|w3)$"), r"\1.\2.feed_forward.net.0.proj", "glu", lambda m: m.group(3)),
|
||||||
|
(re.compile(r"^(layers)[._](\d+)[._]feed_forward[._]w2$"), r"\1.\2.feed_forward.net.2", "regular", None),
|
||||||
|
(re.compile(r"^(layers)[._](\d+)[._](.*)$"), r"\1.\2.\3", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]to[._]([qkv])$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._](q|k|v)[._]proj$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]add[._](q|k|v)[._]proj$"), r"\1.\2.attn.add_qkv_proj", "add_qkv", lambda m: m.group(3).upper()),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._]out[._]proj[._]context$"), r"\1.\2.attn.to_add_out", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._]out[._]proj$"), r"\1.\2.attn.to_out.0", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]to[._]out$"), r"\1.\2.attn.to_out.0", "regular", None),
|
||||||
|
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]attn[._]to[._]([qkv])$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()),
|
||||||
|
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]attn[._]to[._]out$"), r"\1.\2.attn.to_out", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._]ff[._]net[._]0(?:[._]proj)?$"), r"\1.\2.mlp_fc1", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._]ff[._]net[._]2$"), r"\1.\2.mlp_fc2", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._]ff_context[._]net[._]0(?:[._]proj)?$"), r"\1.\2.mlp_context_fc1", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._]ff_context[._]net[._]2$"), r"\1.\2.mlp_context_fc2", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mlp)[._](net)[._](0)[._](proj)$"), r"\1.\2.\3.\4.\5.\6", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mlp)[._](net)[._](2)$"), r"\1.\2.\3.\4.\5", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mlp)[._](net)[._](0)[._](proj)$"), r"\1.\2.\3.\4.\5.\6", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mlp)[._](net)[._](2)$"), r"\1.\2.\3.\4.\5", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mod)[._](1)$"), r"\1.\2.\3.\4", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mod)[._](1)$"), r"\1.\2.\3.\4", "regular", None),
|
||||||
|
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]proj[._]out$"), r"\1.\2.proj_out", "single_proj_out", None),
|
||||||
|
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]proj[._]mlp$"), r"\1.\2.mlp_fc1", "regular", None),
|
||||||
|
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]norm[._]linear$"), r"\1.\2.norm.linear", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._]norm1[._]linear$"), r"\1.\2.norm1.linear", "regular", None),
|
||||||
|
(re.compile(r"^(transformer_blocks)[._](\d+)[._]norm1_context[._]linear$"), r"\1.\2.norm1_context.linear", "regular", None),
|
||||||
|
(re.compile(r"^(img_in)$"), r"\1", "regular", None),
|
||||||
|
(re.compile(r"^(txt_in)$"), r"\1", "regular", None),
|
||||||
|
(re.compile(r"^(proj_out)$"), r"\1", "regular", None),
|
||||||
|
(re.compile(r"^(norm_out)[._](linear)$"), r"\1.\2", "regular", None),
|
||||||
|
(re.compile(r"^(time_text_embed)[._](timestep_embedder)[._](linear_1)$"), r"\1.\2.\3", "regular", None),
|
||||||
|
(re.compile(r"^(time_text_embed)[._](timestep_embedder)[._](linear_2)$"), r"\1.\2.\3", "regular", None),
|
||||||
|
]
|
||||||
|
|
||||||
|
_RE_LORA_SUFFIX = re.compile(r"\.(?P<tag>lora(?:[._](?:A|B|down|up)))(?:\.[^.]+)*\.weight$")
|
||||||
|
_RE_ALPHA_SUFFIX = re.compile(r"\.(?:alpha|lora_alpha)(?:\.[^.]+)*$")
|
||||||
|
|
||||||
|
|
||||||
|
def _rename_layer_underscore_layer_name(old_name: str) -> str:
|
||||||
|
rules = [
|
||||||
|
(r"_(\d+)_attn_to_out_(\d+)", r".\1.attn.to_out.\2"),
|
||||||
|
(r"_(\d+)_img_mlp_net_(\d+)_proj", r".\1.img_mlp.net.\2.proj"),
|
||||||
|
(r"_(\d+)_txt_mlp_net_(\d+)_proj", r".\1.txt_mlp.net.\2.proj"),
|
||||||
|
(r"_(\d+)_img_mlp_net_(\d+)", r".\1.img_mlp.net.\2"),
|
||||||
|
(r"_(\d+)_txt_mlp_net_(\d+)", r".\1.txt_mlp.net.\2"),
|
||||||
|
(r"_(\d+)_img_mod_(\d+)", r".\1.img_mod.\2"),
|
||||||
|
(r"_(\d+)_txt_mod_(\d+)", r".\1.txt_mod.\2"),
|
||||||
|
(r"_(\d+)_attn_", r".\1.attn."),
|
||||||
|
]
|
||||||
|
new_name = old_name
|
||||||
|
for pattern, replacement in rules:
|
||||||
|
new_name = re.sub(pattern, replacement, new_name)
|
||||||
|
return new_name
|
||||||
|
|
||||||
|
|
||||||
|
def _is_indexable_module(module):
|
||||||
|
return isinstance(module, (nn.ModuleList, nn.Sequential, list, tuple))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_module_by_name(model: nn.Module, name: str) -> Optional[nn.Module]:
|
||||||
|
if not name:
|
||||||
|
return model
|
||||||
|
module = model
|
||||||
|
for part in name.split("."):
|
||||||
|
if not part:
|
||||||
|
continue
|
||||||
|
if hasattr(module, part):
|
||||||
|
module = getattr(module, part)
|
||||||
|
elif part.isdigit() and _is_indexable_module(module):
|
||||||
|
try:
|
||||||
|
module = module[int(part)]
|
||||||
|
except (IndexError, TypeError):
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_module_name(model: nn.Module, name: str) -> Tuple[str, Optional[nn.Module]]:
|
||||||
|
module = _get_module_by_name(model, name)
|
||||||
|
if module is not None:
|
||||||
|
return name, module
|
||||||
|
|
||||||
|
replacements = [
|
||||||
|
(".attn.to_out.0", ".attn.to_out"),
|
||||||
|
(".attention.to_qkv", ".attention.qkv"),
|
||||||
|
(".attention.to_out.0", ".attention.out"),
|
||||||
|
(".feed_forward.net.0.proj", ".feed_forward.w13"),
|
||||||
|
(".feed_forward.net.2", ".feed_forward.w2"),
|
||||||
|
(".ff.net.0.proj", ".mlp_fc1"),
|
||||||
|
(".ff.net.2", ".mlp_fc2"),
|
||||||
|
(".ff_context.net.0.proj", ".mlp_context_fc1"),
|
||||||
|
(".ff_context.net.2", ".mlp_context_fc2"),
|
||||||
|
]
|
||||||
|
for src, dst in replacements:
|
||||||
|
if src in name:
|
||||||
|
alt = name.replace(src, dst)
|
||||||
|
module = _get_module_by_name(model, alt)
|
||||||
|
if module is not None:
|
||||||
|
return alt, module
|
||||||
|
return name, None
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_and_map_key(key: str) -> Optional[Tuple[str, str, Optional[str], str]]:
|
||||||
|
normalized = key
|
||||||
|
if normalized.startswith("transformer."):
|
||||||
|
normalized = normalized[len("transformer."):]
|
||||||
|
if normalized.startswith("diffusion_model."):
|
||||||
|
normalized = normalized[len("diffusion_model."):]
|
||||||
|
if normalized.startswith("lora_unet_"):
|
||||||
|
normalized = _rename_layer_underscore_layer_name(normalized[len("lora_unet_"):])
|
||||||
|
|
||||||
|
match = _RE_LORA_SUFFIX.search(normalized)
|
||||||
|
if match:
|
||||||
|
tag = match.group("tag")
|
||||||
|
base = normalized[:match.start()]
|
||||||
|
ab = "A" if ("lora_A" in tag or tag.endswith(".A") or "down" in tag) else "B"
|
||||||
|
else:
|
||||||
|
match = _RE_ALPHA_SUFFIX.search(normalized)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
base = normalized[:match.start()]
|
||||||
|
ab = "alpha"
|
||||||
|
|
||||||
|
for pattern, template, group, comp_fn in KEY_MAPPING:
|
||||||
|
key_match = pattern.match(base)
|
||||||
|
if key_match:
|
||||||
|
return group, key_match.expand(template), comp_fn(key_match) if comp_fn else None, ab
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_lora_format(lora_state_dict: Dict[str, torch.Tensor]) -> bool:
|
||||||
|
standard_patterns = (
|
||||||
|
".lora_up.",
|
||||||
|
".lora_down.",
|
||||||
|
".lora_A.",
|
||||||
|
".lora_B.",
|
||||||
|
".lora.up.",
|
||||||
|
".lora.down.",
|
||||||
|
".lora.A.",
|
||||||
|
".lora.B.",
|
||||||
|
)
|
||||||
|
return any(pattern in key for key in lora_state_dict for pattern in standard_patterns)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_lora_state_dict(path_or_dict: Union[str, Path, Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
||||||
|
if isinstance(path_or_dict, dict):
|
||||||
|
return path_or_dict
|
||||||
|
path = Path(path_or_dict)
|
||||||
|
if path.suffix == ".safetensors":
|
||||||
|
state_dict: Dict[str, torch.Tensor] = {}
|
||||||
|
with safe_open(path, framework="pt", device="cpu") as handle:
|
||||||
|
for key in handle.keys():
|
||||||
|
state_dict[key] = handle.get_tensor(key)
|
||||||
|
return state_dict
|
||||||
|
return comfy.utils.load_torch_file(str(path), safe_load=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _fuse_glu_lora(glu_weights: Dict[str, torch.Tensor]) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
if "w1_A" not in glu_weights or "w3_A" not in glu_weights:
|
||||||
|
return None, None, None
|
||||||
|
a_w1, b_w1 = glu_weights["w1_A"], glu_weights["w1_B"]
|
||||||
|
a_w3, b_w3 = glu_weights["w3_A"], glu_weights["w3_B"]
|
||||||
|
if a_w1.shape[1] != a_w3.shape[1]:
|
||||||
|
return None, None, None
|
||||||
|
a_fused = torch.cat([a_w1, a_w3], dim=0)
|
||||||
|
out1, out3 = b_w1.shape[0], b_w3.shape[0]
|
||||||
|
rank1, rank3 = b_w1.shape[1], b_w3.shape[1]
|
||||||
|
b_fused = torch.zeros(out1 + out3, rank1 + rank3, dtype=b_w1.dtype, device=b_w1.device)
|
||||||
|
b_fused[:out1, :rank1] = b_w1
|
||||||
|
b_fused[out1:, rank1:] = b_w3
|
||||||
|
return a_fused, b_fused, glu_weights.get("w1_alpha")
|
||||||
|
|
||||||
|
|
||||||
|
def _fuse_qkv_lora(qkv_weights: Dict[str, torch.Tensor], model: Optional[nn.Module] = None, base_key: Optional[str] = None) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
required_keys = ["Q_A", "Q_B", "K_A", "K_B", "V_A", "V_B"]
|
||||||
|
if not all(key in qkv_weights for key in required_keys):
|
||||||
|
return None, None, None
|
||||||
|
a_q, a_k, a_v = qkv_weights["Q_A"], qkv_weights["K_A"], qkv_weights["V_A"]
|
||||||
|
b_q, b_k, b_v = qkv_weights["Q_B"], qkv_weights["K_B"], qkv_weights["V_B"]
|
||||||
|
if not (a_q.shape == a_k.shape == a_v.shape):
|
||||||
|
return None, None, None
|
||||||
|
if not (b_q.shape[1] == b_k.shape[1] == b_v.shape[1]):
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
out_features = None
|
||||||
|
if model is not None and base_key is not None:
|
||||||
|
_, module = _resolve_module_name(model, base_key)
|
||||||
|
out_features = getattr(module, "out_features", None) if module is not None else None
|
||||||
|
|
||||||
|
alpha_fused = None
|
||||||
|
alpha_q = qkv_weights.get("Q_alpha")
|
||||||
|
alpha_k = qkv_weights.get("K_alpha")
|
||||||
|
alpha_v = qkv_weights.get("V_alpha")
|
||||||
|
if alpha_q is not None and alpha_k is not None and alpha_v is not None and alpha_q.item() == alpha_k.item() == alpha_v.item():
|
||||||
|
alpha_fused = alpha_q
|
||||||
|
|
||||||
|
a_fused = torch.cat([a_q, a_k, a_v], dim=0)
|
||||||
|
rank = b_q.shape[1]
|
||||||
|
out_q, out_k, out_v = b_q.shape[0], b_k.shape[0], b_v.shape[0]
|
||||||
|
total_out = out_features if out_features is not None else out_q + out_k + out_v
|
||||||
|
b_fused = torch.zeros(total_out, 3 * rank, dtype=b_q.dtype, device=b_q.device)
|
||||||
|
b_fused[:out_q, :rank] = b_q
|
||||||
|
b_fused[out_q:out_q + out_k, rank:2 * rank] = b_k
|
||||||
|
b_fused[out_q + out_k:out_q + out_k + out_v, 2 * rank:] = b_v
|
||||||
|
return a_fused, b_fused, alpha_fused
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_proj_out_split(lora_dict: Dict[str, Dict[str, torch.Tensor]], base_key: str, model: nn.Module) -> Tuple[Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], List[str]]:
|
||||||
|
result: Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]] = {}
|
||||||
|
consumed: List[str] = []
|
||||||
|
match = re.search(r"single_transformer_blocks\.(\d+)", base_key)
|
||||||
|
if not match or base_key not in lora_dict:
|
||||||
|
return result, consumed
|
||||||
|
block_idx = match.group(1)
|
||||||
|
block = _get_module_by_name(model, f"single_transformer_blocks.{block_idx}")
|
||||||
|
if block is None:
|
||||||
|
return result, consumed
|
||||||
|
a_full = lora_dict[base_key].get("A")
|
||||||
|
b_full = lora_dict[base_key].get("B")
|
||||||
|
alpha = lora_dict[base_key].get("alpha")
|
||||||
|
attn_to_out = getattr(getattr(block, "attn", None), "to_out", None)
|
||||||
|
mlp_fc2 = getattr(block, "mlp_fc2", None)
|
||||||
|
if a_full is None or b_full is None or attn_to_out is None or mlp_fc2 is None:
|
||||||
|
return result, consumed
|
||||||
|
attn_in = getattr(attn_to_out, "in_features", None)
|
||||||
|
mlp_in = getattr(mlp_fc2, "in_features", None)
|
||||||
|
if attn_in is None or mlp_in is None or a_full.shape[1] != attn_in + mlp_in:
|
||||||
|
return result, consumed
|
||||||
|
result[f"single_transformer_blocks.{block_idx}.attn.to_out"] = (a_full[:, :attn_in], b_full.clone(), alpha)
|
||||||
|
result[f"single_transformer_blocks.{block_idx}.mlp_fc2"] = (a_full[:, attn_in:], b_full.clone(), alpha)
|
||||||
|
consumed.append(base_key)
|
||||||
|
return result, consumed
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_lora_to_module(module: nn.Module, a_tensor: torch.Tensor, b_tensor: torch.Tensor, module_name: str, model: nn.Module) -> None:
|
||||||
|
if not hasattr(module, "in_features") or not hasattr(module, "out_features"):
|
||||||
|
raise ValueError(f"{module_name}: unsupported module without in/out features")
|
||||||
|
if a_tensor.shape[1] != module.in_features or b_tensor.shape[0] != module.out_features:
|
||||||
|
raise ValueError(f"{module_name}: LoRA shape mismatch")
|
||||||
|
|
||||||
|
if module.__class__.__name__ == "AWQW4A16Linear" and hasattr(module, "qweight"):
|
||||||
|
if not hasattr(module, "_lora_original_forward"):
|
||||||
|
module._lora_original_forward = module.forward
|
||||||
|
if not hasattr(module, "_nunchaku_lora_bundle"):
|
||||||
|
module._nunchaku_lora_bundle = []
|
||||||
|
module._nunchaku_lora_bundle.append((a_tensor, b_tensor))
|
||||||
|
|
||||||
|
def _awq_lora_forward(x, *args, **kwargs):
|
||||||
|
out = module._lora_original_forward(x, *args, **kwargs)
|
||||||
|
x_flat = x.reshape(-1, module.in_features)
|
||||||
|
for local_a, local_b in module._nunchaku_lora_bundle:
|
||||||
|
local_a = local_a.to(device=out.device, dtype=out.dtype)
|
||||||
|
local_b = local_b.to(device=out.device, dtype=out.dtype)
|
||||||
|
lora_term = (x_flat @ local_a.transpose(0, 1)) @ local_b.transpose(0, 1)
|
||||||
|
try:
|
||||||
|
out = out + lora_term.reshape(out.shape)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return out
|
||||||
|
|
||||||
|
module.forward = _awq_lora_forward
|
||||||
|
if not hasattr(model, "_lora_slots"):
|
||||||
|
model._lora_slots = {}
|
||||||
|
model._lora_slots[module_name] = {"type": "awq_w4a16"}
|
||||||
|
return
|
||||||
|
|
||||||
|
if hasattr(module, "proj_down") and hasattr(module, "proj_up"):
|
||||||
|
proj_down = unpack_lowrank_weight(module.proj_down.data, down=True)
|
||||||
|
proj_up = unpack_lowrank_weight(module.proj_up.data, down=False)
|
||||||
|
base_rank = proj_down.shape[0] if proj_down.shape[1] == module.in_features else proj_down.shape[1]
|
||||||
|
if proj_down.shape[1] == module.in_features:
|
||||||
|
updated_down = torch.cat([proj_down, a_tensor], dim=0)
|
||||||
|
axis_down = 0
|
||||||
|
else:
|
||||||
|
updated_down = torch.cat([proj_down, a_tensor.T], dim=1)
|
||||||
|
axis_down = 1
|
||||||
|
updated_up = torch.cat([proj_up, b_tensor], dim=1)
|
||||||
|
module.proj_down.data = pack_lowrank_weight(updated_down, down=True)
|
||||||
|
module.proj_up.data = pack_lowrank_weight(updated_up, down=False)
|
||||||
|
module.rank = base_rank + a_tensor.shape[0]
|
||||||
|
if not hasattr(model, "_lora_slots"):
|
||||||
|
model._lora_slots = {}
|
||||||
|
model._lora_slots[module_name] = {
|
||||||
|
"type": "nunchaku",
|
||||||
|
"base_rank": base_rank,
|
||||||
|
"axis_down": axis_down,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
if not hasattr(model, "_lora_slots"):
|
||||||
|
model._lora_slots = {}
|
||||||
|
if module_name not in model._lora_slots:
|
||||||
|
model._lora_slots[module_name] = {
|
||||||
|
"type": "linear",
|
||||||
|
"original_weight": module.weight.detach().cpu().clone(),
|
||||||
|
}
|
||||||
|
module.weight.data.add_((b_tensor @ a_tensor).to(dtype=module.weight.dtype, device=module.weight.device))
|
||||||
|
return
|
||||||
|
|
||||||
|
raise ValueError(f"{module_name}: unsupported module type {type(module)}")
|
||||||
|
|
||||||
|
|
||||||
|
def reset_lora_v2(model: nn.Module) -> None:
|
||||||
|
slots = getattr(model, "_lora_slots", None)
|
||||||
|
if not slots:
|
||||||
|
return
|
||||||
|
for name, info in list(slots.items()):
|
||||||
|
module = _get_module_by_name(model, name)
|
||||||
|
if module is None:
|
||||||
|
continue
|
||||||
|
module_type = info.get("type", "nunchaku")
|
||||||
|
if module_type == "nunchaku":
|
||||||
|
base_rank = info["base_rank"]
|
||||||
|
proj_down = unpack_lowrank_weight(module.proj_down.data, down=True)
|
||||||
|
proj_up = unpack_lowrank_weight(module.proj_up.data, down=False)
|
||||||
|
if info.get("axis_down", 0) == 0:
|
||||||
|
proj_down = proj_down[:base_rank, :].clone()
|
||||||
|
else:
|
||||||
|
proj_down = proj_down[:, :base_rank].clone()
|
||||||
|
proj_up = proj_up[:, :base_rank].clone()
|
||||||
|
module.proj_down.data = pack_lowrank_weight(proj_down, down=True)
|
||||||
|
module.proj_up.data = pack_lowrank_weight(proj_up, down=False)
|
||||||
|
module.rank = base_rank
|
||||||
|
elif module_type == "linear" and "original_weight" in info:
|
||||||
|
module.weight.data.copy_(info["original_weight"].to(device=module.weight.device, dtype=module.weight.dtype))
|
||||||
|
elif module_type == "awq_w4a16":
|
||||||
|
if hasattr(module, "_lora_original_forward"):
|
||||||
|
module.forward = module._lora_original_forward
|
||||||
|
for attr in ("_lora_original_forward", "_nunchaku_lora_bundle"):
|
||||||
|
if hasattr(module, attr):
|
||||||
|
delattr(module, attr)
|
||||||
|
model._lora_slots = {}
|
||||||
|
|
||||||
|
|
||||||
|
def compose_loras_v2(model: nn.Module, lora_configs: List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]], apply_awq_mod: bool = True) -> bool:
|
||||||
|
del apply_awq_mod # retained for interface compatibility
|
||||||
|
reset_lora_v2(model)
|
||||||
|
aggregated_weights: Dict[str, List[Dict[str, object]]] = defaultdict(list)
|
||||||
|
saw_supported_format = False
|
||||||
|
unresolved_targets = 0
|
||||||
|
|
||||||
|
for index, (path_or_dict, strength) in enumerate(lora_configs):
|
||||||
|
if abs(strength) < 1e-5:
|
||||||
|
continue
|
||||||
|
lora_name = str(path_or_dict) if not isinstance(path_or_dict, dict) else f"lora_{index}"
|
||||||
|
lora_state_dict = _load_lora_state_dict(path_or_dict)
|
||||||
|
if not lora_state_dict or not _detect_lora_format(lora_state_dict):
|
||||||
|
logger.warning("Skipping unsupported Qwen LoRA: %s", lora_name)
|
||||||
|
continue
|
||||||
|
saw_supported_format = True
|
||||||
|
|
||||||
|
grouped_weights: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)
|
||||||
|
for key, value in lora_state_dict.items():
|
||||||
|
parsed = _classify_and_map_key(key)
|
||||||
|
if parsed is None:
|
||||||
|
continue
|
||||||
|
group, base_key, component, ab = parsed
|
||||||
|
if component and ab:
|
||||||
|
grouped_weights[base_key][f"{component}_{ab}"] = value
|
||||||
|
else:
|
||||||
|
grouped_weights[base_key][ab] = value
|
||||||
|
|
||||||
|
processed_groups: Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]] = {}
|
||||||
|
handled: set[str] = set()
|
||||||
|
for base_key, weights in grouped_weights.items():
|
||||||
|
if base_key in handled:
|
||||||
|
continue
|
||||||
|
a_tensor = b_tensor = alpha = None
|
||||||
|
if "qkv" in base_key or "add_qkv_proj" in base_key:
|
||||||
|
a_tensor, b_tensor, alpha = _fuse_qkv_lora(weights, model=model, base_key=base_key)
|
||||||
|
elif "w1_A" in weights or "w3_A" in weights:
|
||||||
|
a_tensor, b_tensor, alpha = _fuse_glu_lora(weights)
|
||||||
|
elif ".proj_out" in base_key and "single_transformer_blocks" in base_key:
|
||||||
|
split_map, consumed = _handle_proj_out_split(grouped_weights, base_key, model)
|
||||||
|
processed_groups.update(split_map)
|
||||||
|
handled.update(consumed)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
a_tensor, b_tensor, alpha = weights.get("A"), weights.get("B"), weights.get("alpha")
|
||||||
|
if a_tensor is not None and b_tensor is not None:
|
||||||
|
processed_groups[base_key] = (a_tensor, b_tensor, alpha)
|
||||||
|
|
||||||
|
for module_name, (a_tensor, b_tensor, alpha) in processed_groups.items():
|
||||||
|
aggregated_weights[module_name].append({
|
||||||
|
"A": a_tensor,
|
||||||
|
"B": b_tensor,
|
||||||
|
"alpha": alpha,
|
||||||
|
"strength": strength,
|
||||||
|
})
|
||||||
|
|
||||||
|
for module_name, weight_list in aggregated_weights.items():
|
||||||
|
resolved_name, module = _resolve_module_name(model, module_name)
|
||||||
|
if module is None:
|
||||||
|
logger.warning("Skipping unresolved Qwen LoRA target: %s", module_name)
|
||||||
|
unresolved_targets += 1
|
||||||
|
continue
|
||||||
|
all_a = []
|
||||||
|
all_b_scaled = []
|
||||||
|
for item in weight_list:
|
||||||
|
a_tensor = item["A"]
|
||||||
|
b_tensor = item["B"]
|
||||||
|
alpha = item["alpha"]
|
||||||
|
strength = float(item["strength"])
|
||||||
|
rank = a_tensor.shape[0]
|
||||||
|
scale = strength * ((alpha / rank) if alpha is not None else 1.0)
|
||||||
|
if module.__class__.__name__ == "AWQW4A16Linear" and hasattr(module, "qweight"):
|
||||||
|
target_dtype = torch.float16
|
||||||
|
target_device = module.qweight.device
|
||||||
|
elif hasattr(module, "proj_down"):
|
||||||
|
target_dtype = module.proj_down.dtype
|
||||||
|
target_device = module.proj_down.device
|
||||||
|
elif hasattr(module, "weight"):
|
||||||
|
target_dtype = module.weight.dtype
|
||||||
|
target_device = module.weight.device
|
||||||
|
else:
|
||||||
|
target_dtype = torch.float16
|
||||||
|
target_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
all_a.append(a_tensor.to(dtype=target_dtype, device=target_device))
|
||||||
|
all_b_scaled.append((b_tensor * scale).to(dtype=target_dtype, device=target_device))
|
||||||
|
if not all_a:
|
||||||
|
continue
|
||||||
|
_apply_lora_to_module(module, torch.cat(all_a, dim=0), torch.cat(all_b_scaled, dim=1), resolved_name, model)
|
||||||
|
|
||||||
|
slot_count = len(getattr(model, "_lora_slots", {}) or {})
|
||||||
|
logger.info(
|
||||||
|
"Qwen LoRA composition finished: requested=%d supported=%s applied_targets=%d unresolved=%d",
|
||||||
|
len(lora_configs),
|
||||||
|
saw_supported_format,
|
||||||
|
slot_count,
|
||||||
|
unresolved_targets,
|
||||||
|
)
|
||||||
|
return saw_supported_format
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyQwenImageWrapperLM(nn.Module):
|
||||||
|
def __init__(self, model: nn.Module, config=None, apply_awq_mod: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.config = {} if config is None else config
|
||||||
|
self.dtype = next(model.parameters()).dtype
|
||||||
|
self.loras: List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]] = []
|
||||||
|
self._applied_loras: Optional[List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]]] = None
|
||||||
|
self.apply_awq_mod = apply_awq_mod
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
try:
|
||||||
|
inner = object.__getattribute__(self, "_modules").get("model")
|
||||||
|
except (AttributeError, KeyError):
|
||||||
|
inner = None
|
||||||
|
if inner is None:
|
||||||
|
raise AttributeError(f"{type(self).__name__!s} has no attribute {name}")
|
||||||
|
if name == "model":
|
||||||
|
return inner
|
||||||
|
return getattr(inner, name)
|
||||||
|
|
||||||
|
def process_img(self, *args, **kwargs):
|
||||||
|
return self.model.process_img(*args, **kwargs)
|
||||||
|
|
||||||
|
def _ensure_composed(self):
|
||||||
|
if self._applied_loras != self.loras or (not self.loras and getattr(self.model, "_lora_slots", None)):
|
||||||
|
is_supported_format = compose_loras_v2(self.model, self.loras, apply_awq_mod=self.apply_awq_mod)
|
||||||
|
self._applied_loras = self.loras.copy()
|
||||||
|
has_slots = bool(getattr(self.model, "_lora_slots", None))
|
||||||
|
if self.loras and is_supported_format and not has_slots:
|
||||||
|
logger.warning("Qwen LoRA compose produced 0 target modules. Resetting and retrying once.")
|
||||||
|
reset_lora_v2(self.model)
|
||||||
|
compose_loras_v2(self.model, self.loras, apply_awq_mod=self.apply_awq_mod)
|
||||||
|
has_slots = bool(getattr(self.model, "_lora_slots", None))
|
||||||
|
logger.info("Qwen LoRA retry result: applied_targets=%d", len(getattr(self.model, "_lora_slots", {}) or {}))
|
||||||
|
|
||||||
|
offload_manager = getattr(self.model, "offload_manager", None)
|
||||||
|
if offload_manager is not None:
|
||||||
|
offload_settings = {
|
||||||
|
"num_blocks_on_gpu": getattr(offload_manager, "num_blocks_on_gpu", 1),
|
||||||
|
"use_pin_memory": getattr(offload_manager, "use_pin_memory", False),
|
||||||
|
}
|
||||||
|
logger.info(
|
||||||
|
"Rebuilding Qwen offload manager after LoRA compose: num_blocks_on_gpu=%s use_pin_memory=%s",
|
||||||
|
offload_settings["num_blocks_on_gpu"],
|
||||||
|
offload_settings["use_pin_memory"],
|
||||||
|
)
|
||||||
|
self.model.set_offload(False)
|
||||||
|
self.model.set_offload(True, **offload_settings)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
self._ensure_composed()
|
||||||
|
return self.model(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_qwen_wrapper_and_transformer(model):
|
||||||
|
model_wrapper = model.model.diffusion_model
|
||||||
|
if hasattr(model_wrapper, "model") and hasattr(model_wrapper, "loras"):
|
||||||
|
transformer = model_wrapper.model
|
||||||
|
if transformer.__class__.__name__.endswith("NunchakuQwenImageTransformer2DModel"):
|
||||||
|
return model_wrapper, transformer
|
||||||
|
if model_wrapper.__class__.__name__.endswith("NunchakuQwenImageTransformer2DModel"):
|
||||||
|
wrapped_model = ComfyQwenImageWrapperLM(model_wrapper, getattr(model_wrapper, "config", {}))
|
||||||
|
model.model.diffusion_model = wrapped_model
|
||||||
|
return wrapped_model, wrapped_model.model
|
||||||
|
raise TypeError(f"This LoRA loader only works with Nunchaku Qwen Image models, but got {type(model_wrapper).__name__}.")
|
||||||
|
|
||||||
|
|
||||||
|
def nunchaku_load_qwen_loras(model, lora_configs: List[Tuple[str, float]], apply_awq_mod: bool = True):
|
||||||
|
model_wrapper, transformer = _get_qwen_wrapper_and_transformer(model)
|
||||||
|
model_wrapper.apply_awq_mod = apply_awq_mod
|
||||||
|
|
||||||
|
saved_config = None
|
||||||
|
if hasattr(model, "model") and hasattr(model.model, "model_config"):
|
||||||
|
saved_config = model.model.model_config
|
||||||
|
model.model.model_config = None
|
||||||
|
|
||||||
|
model_wrapper.model = None
|
||||||
|
try:
|
||||||
|
ret_model = copy.deepcopy(model)
|
||||||
|
finally:
|
||||||
|
if saved_config is not None:
|
||||||
|
model.model.model_config = saved_config
|
||||||
|
model_wrapper.model = transformer
|
||||||
|
|
||||||
|
ret_model_wrapper = ret_model.model.diffusion_model
|
||||||
|
if saved_config is not None:
|
||||||
|
ret_model.model.model_config = saved_config
|
||||||
|
ret_model_wrapper.model = transformer
|
||||||
|
ret_model_wrapper.apply_awq_mod = apply_awq_mod
|
||||||
|
ret_model_wrapper.loras = list(getattr(model_wrapper, "loras", []))
|
||||||
|
|
||||||
|
for lora_name, lora_strength in lora_configs:
|
||||||
|
lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
|
||||||
|
if not lora_path or not os.path.isfile(lora_path):
|
||||||
|
logger.warning("Skipping Qwen LoRA '%s' because it could not be found", lora_name)
|
||||||
|
continue
|
||||||
|
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
||||||
|
|
||||||
|
return ret_model
|
||||||
@@ -158,3 +158,24 @@ def nunchaku_load_lora(model, lora_name, lora_strength):
|
|||||||
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
|
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
|
||||||
|
|
||||||
return ret_model
|
return ret_model
|
||||||
|
|
||||||
|
|
||||||
|
def detect_nunchaku_model_kind(model):
|
||||||
|
"""Return the supported Nunchaku model kind for a Comfy model, if any."""
|
||||||
|
try:
|
||||||
|
model_wrapper = model.model.diffusion_model
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
wrapper_name = model_wrapper.__class__.__name__
|
||||||
|
if wrapper_name == "ComfyFluxWrapper":
|
||||||
|
return "flux"
|
||||||
|
|
||||||
|
inner_model = getattr(model_wrapper, "model", None)
|
||||||
|
inner_name = inner_model.__class__.__name__ if inner_model is not None else ""
|
||||||
|
if wrapper_name.endswith("NunchakuQwenImageTransformer2DModel"):
|
||||||
|
return "qwen_image"
|
||||||
|
if inner_name.endswith("NunchakuQwenImageTransformer2DModel"):
|
||||||
|
return "qwen_image"
|
||||||
|
|
||||||
|
return None
|
||||||
|
|||||||
176
tests/nodes/test_lora_loader.py
Normal file
176
tests/nodes/test_lora_loader.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
import types
|
||||||
|
|
||||||
|
from py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM
|
||||||
|
|
||||||
|
|
||||||
|
class _ModelContainer:
|
||||||
|
def __init__(self, diffusion_model):
|
||||||
|
self.diffusion_model = diffusion_model
|
||||||
|
|
||||||
|
|
||||||
|
class _Model:
|
||||||
|
def __init__(self, diffusion_model):
|
||||||
|
self.model = _ModelContainer(diffusion_model)
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_loader_standard_model_uses_comfy_loader(monkeypatch):
|
||||||
|
loader = LoraLoaderLM()
|
||||||
|
model = _Model(object())
|
||||||
|
clip = object()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.nodes.lora_loader.get_lora_info_absolute",
|
||||||
|
lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
load_calls = []
|
||||||
|
|
||||||
|
def mock_load_torch_file(path, safe_load=True):
|
||||||
|
load_calls.append((path, safe_load))
|
||||||
|
return {"path": path}
|
||||||
|
|
||||||
|
def mock_load_lora_for_models(model_arg, clip_arg, lora_arg, model_strength, clip_strength):
|
||||||
|
return model_arg, clip_arg
|
||||||
|
|
||||||
|
monkeypatch.setattr("comfy.utils.load_torch_file", mock_load_torch_file)
|
||||||
|
monkeypatch.setattr("comfy.sd.load_lora_for_models", mock_load_lora_for_models)
|
||||||
|
|
||||||
|
result_model, result_clip, trigger_words, loaded_loras = loader.load_loras(
|
||||||
|
model,
|
||||||
|
"",
|
||||||
|
clip=clip,
|
||||||
|
loras={
|
||||||
|
"__value__": [
|
||||||
|
{"active": True, "name": "demo", "strength": 0.75, "clipStrength": 0.5},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result_model is model
|
||||||
|
assert result_clip is clip
|
||||||
|
assert load_calls == [("/abs/demo.safetensors", True)]
|
||||||
|
assert trigger_words == "demo_trigger"
|
||||||
|
assert loaded_loras == "<lora:demo:0.75:0.5>"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_loader_formats_widget_lora_names_with_colons(monkeypatch):
|
||||||
|
loader = LoraLoaderLM()
|
||||||
|
model = _Model(object())
|
||||||
|
clip = object()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.nodes.lora_loader.get_lora_info_absolute",
|
||||||
|
lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("comfy.utils.load_torch_file", lambda path, safe_load=True: {"path": path})
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"comfy.sd.load_lora_for_models",
|
||||||
|
lambda model_arg, clip_arg, lora_arg, model_strength, clip_strength: (model_arg, clip_arg),
|
||||||
|
)
|
||||||
|
|
||||||
|
_, _, trigger_words, loaded_loras = loader.load_loras(
|
||||||
|
model,
|
||||||
|
"",
|
||||||
|
clip=clip,
|
||||||
|
loras={
|
||||||
|
"__value__": [
|
||||||
|
{"active": True, "name": "demo:variant", "strength": 0.75, "clipStrength": 0.5},
|
||||||
|
{"active": True, "name": "demo:single", "strength": 0.3},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert trigger_words == "demo:variant_trigger,, demo:single_trigger"
|
||||||
|
assert loaded_loras == "<lora:demo:variant:0.75:0.5> <lora:demo:single:0.3>"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_loader_flux_model_uses_flux_helper(monkeypatch):
|
||||||
|
flux_model = _Model(type("ComfyFluxWrapper", (), {})())
|
||||||
|
loader = LoraLoaderLM()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.nodes.lora_loader.get_lora_info_absolute",
|
||||||
|
lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def mock_nunchaku_load_lora(model_arg, lora_name, strength):
|
||||||
|
calls.append((lora_name, strength))
|
||||||
|
return model_arg
|
||||||
|
|
||||||
|
monkeypatch.setattr("py.nodes.lora_loader.nunchaku_load_lora", mock_nunchaku_load_lora)
|
||||||
|
|
||||||
|
_, _, trigger_words, loaded_loras = loader.load_loras(
|
||||||
|
flux_model,
|
||||||
|
"",
|
||||||
|
lora_stack=[("stack_lora.safetensors", 0.4, 0.2)],
|
||||||
|
loras={"__value__": [{"active": True, "name": "widget_lora", "strength": 0.8}]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert calls == [("stack_lora.safetensors", 0.4), ("/abs/widget_lora.safetensors", 0.8)]
|
||||||
|
assert trigger_words == "stack_lora_trigger,, widget_lora_trigger"
|
||||||
|
assert loaded_loras == "<lora:stack_lora:0.4> <lora:widget_lora:0.8>"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_loader_qwen_model_batches_loras(monkeypatch):
|
||||||
|
qwen_model = _Model(type("NunchakuQwenImageTransformer2DModel", (), {})())
|
||||||
|
loader = LoraLoaderLM()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.nodes.lora_loader.get_lora_info_absolute",
|
||||||
|
lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
batched_calls = []
|
||||||
|
|
||||||
|
def mock_nunchaku_load_qwen_loras(model_arg, lora_configs):
|
||||||
|
batched_calls.append((model_arg, lora_configs))
|
||||||
|
return model_arg
|
||||||
|
|
||||||
|
monkeypatch.setattr("py.nodes.lora_loader.nunchaku_load_qwen_loras", mock_nunchaku_load_qwen_loras)
|
||||||
|
|
||||||
|
_, result_clip, trigger_words, loaded_loras = loader.load_loras(
|
||||||
|
qwen_model,
|
||||||
|
"",
|
||||||
|
clip="clip",
|
||||||
|
lora_stack=[("stack_qwen.safetensors", 0.6, 0.1)],
|
||||||
|
loras={"__value__": [{"active": True, "name": "widget_qwen", "strength": 0.9, "clipStrength": 0.3}]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result_clip == "clip"
|
||||||
|
assert len(batched_calls) == 1
|
||||||
|
assert batched_calls[0][0] is qwen_model
|
||||||
|
assert batched_calls[0][1] == [
|
||||||
|
("/abs/stack_qwen.safetensors", 0.6),
|
||||||
|
("/abs/widget_qwen.safetensors", 0.9),
|
||||||
|
]
|
||||||
|
assert trigger_words == "stack_qwen_trigger,, widget_qwen_trigger"
|
||||||
|
assert loaded_loras == "<lora:stack_qwen:0.6> <lora:widget_qwen:0.9>"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_text_loader_qwen_batches_text_and_stack(monkeypatch):
|
||||||
|
qwen_model = _Model(type("NunchakuQwenImageTransformer2DModel", (), {})())
|
||||||
|
loader = LoraTextLoaderLM()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.nodes.lora_loader.get_lora_info_absolute",
|
||||||
|
lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
batched_calls = []
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.nodes.lora_loader.nunchaku_load_qwen_loras",
|
||||||
|
lambda model_arg, lora_configs: batched_calls.append(lora_configs) or model_arg,
|
||||||
|
)
|
||||||
|
|
||||||
|
_, _, trigger_words, loaded_loras = loader.load_loras_from_text(
|
||||||
|
qwen_model,
|
||||||
|
"<lora:text_qwen:1.2:0.4>",
|
||||||
|
clip="clip",
|
||||||
|
lora_stack=[("stack_qwen.safetensors", 0.6, 0.1)],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert batched_calls == [[("/abs/stack_qwen.safetensors", 0.6), ("/abs/text_qwen.safetensors", 1.2)]]
|
||||||
|
assert trigger_words == "stack_qwen_trigger,, text_qwen_trigger"
|
||||||
|
assert loaded_loras == "<lora:stack_qwen:0.6> <lora:text_qwen:1.2>"
|
||||||
Reference in New Issue
Block a user