Add Lora Loader node support for Nunchaku SVDQuant FLUX model architecture with template workflow. Fixes #255

This commit is contained in:
Will Miao
2025-06-29 23:57:50 +08:00
parent 6472e00fb0
commit 71762d788f
5 changed files with 105 additions and 12 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 68 KiB

File diff suppressed because one or more lines are too long

View File

@@ -2,14 +2,14 @@ import logging
from nodes import LoraLoader
from comfy.comfy_types import IO # type: ignore
import asyncio
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list, nunchaku_load_lora
logger = logging.getLogger(__name__)
class LoraManagerLoader:
NAME = "Lora Loader (LoraManager)"
CATEGORY = "Lora Manager/loaders"
@classmethod
def INPUT_TYPES(cls):
return {
@@ -37,19 +37,39 @@ class LoraManagerLoader:
clip = kwargs.get('clip', None)
lora_stack = kwargs.get('lora_stack', None)
# Check if model is a Nunchaku Flux model - simplified approach
is_nunchaku_model = False
try:
model_wrapper = model.model.diffusion_model
# Check if model is a Nunchaku Flux model using only class name
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
is_nunchaku_model = True
logger.info("Detected Nunchaku Flux model")
except (AttributeError, TypeError):
# Not a model with the expected structure
pass
# First process lora_stack if available
if lora_stack:
for lora_path, model_strength, clip_strength in lora_stack:
# Apply the LoRA using the provided path and strengths
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Apply the LoRA using the appropriate loader
if is_nunchaku_model:
# Use our custom function for Flux models
model = nunchaku_load_lora(model, lora_path, model_strength)
# clip remains unchanged for Nunchaku models
else:
# Use default loader for standard models
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Extract lora name for trigger words lookup
lora_name = extract_lora_name(lora_path)
_, trigger_words = asyncio.run(get_lora_info(lora_name))
all_trigger_words.extend(trigger_words)
# Add clip strength to output if different from model strength
if abs(model_strength - clip_strength) > 0.001:
# Add clip strength to output if different from model strength (except for Nunchaku models)
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else:
loaded_loras.append(f"{lora_name}: {model_strength}")
@@ -68,11 +88,17 @@ class LoraManagerLoader:
# Get lora path and trigger words
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
# Apply the LoRA using the resolved path with separate strengths
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Apply the LoRA using the appropriate loader
if is_nunchaku_model:
# For Nunchaku models, use our custom function
model = nunchaku_load_lora(model, lora_path, model_strength)
# clip remains unchanged
else:
# Use default loader for standard models
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Include clip strength in output if different from model strength
if abs(model_strength - clip_strength) > 0.001:
# Include clip strength in output if different from model strength and not a Nunchaku model
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else:
loaded_loras.append(f"{lora_name}: {model_strength}")

View File

@@ -35,7 +35,12 @@ any_type = AnyType("*")
# Common methods extracted from lora_loader.py and lora_stacker.py
import os
import logging
import asyncio
import copy
import folder_paths
import torch
import safetensors.torch
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
from diffusers.loaders import FluxLoraLoaderMixin
from ..services.lora_scanner import LoraScanner
from ..config import config
@@ -81,4 +86,64 @@ def get_loras_list(kwargs):
# Unexpected format
else:
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []
return []
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
state_dict = {}
with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
for k in f.keys():
if filter_prefix and not k.startswith(filter_prefix):
continue
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
return state_dict
def to_diffusers(input_lora):
"""Simplified version of to_diffusers for Flux LoRA conversion"""
if isinstance(input_lora, str):
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
else:
tensors = {k: v for k, v in input_lora.items()}
# Convert FP8 tensors to BF16
for k, v in tensors.items():
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
tensors[k] = v.to(torch.bfloat16)
new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors)
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
return new_tensors
def nunchaku_load_lora(model, lora_name, lora_strength):
"""Load a Flux LoRA for Nunchaku model"""
model_wrapper = model.model.diffusion_model
transformer = model_wrapper.model
# Save the transformer temporarily
model_wrapper.model = None
ret_model = copy.deepcopy(model) # copy everything except the model
ret_model_wrapper = ret_model.model.diffusion_model
# Restore the model and set it for the copy
model_wrapper.model = transformer
ret_model_wrapper.model = transformer
# Get full path to the LoRA file
lora_path = folder_paths.get_full_path("loras", lora_name)
ret_model_wrapper.loras.append((lora_path, lora_strength))
# Convert the LoRA to diffusers format
sd = to_diffusers(lora_path)
# Handle embedding adjustment if needed
if "transformer.x_embedder.lora_A.weight" in sd:
new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
assert new_in_channels % 4 == 0
new_in_channels = new_in_channels // 4
old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
if old_in_channels < new_in_channels:
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
return ret_model

View File

@@ -7,6 +7,7 @@ dependencies = [
"aiohttp",
"jinja2",
"safetensors",
"diffusers",
"watchdog",
"beautifulsoup4",
"piexif",