mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Add Lora Loader node support for Nunchaku SVDQuant FLUX model architecture with template workflow. Fixes #255
This commit is contained in:
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 68 KiB |
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
File diff suppressed because one or more lines are too long
@@ -2,14 +2,14 @@ import logging
|
|||||||
from nodes import LoraLoader
|
from nodes import LoraLoader
|
||||||
from comfy.comfy_types import IO # type: ignore
|
from comfy.comfy_types import IO # type: ignore
|
||||||
import asyncio
|
import asyncio
|
||||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list, nunchaku_load_lora
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class LoraManagerLoader:
|
class LoraManagerLoader:
|
||||||
NAME = "Lora Loader (LoraManager)"
|
NAME = "Lora Loader (LoraManager)"
|
||||||
CATEGORY = "Lora Manager/loaders"
|
CATEGORY = "Lora Manager/loaders"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def INPUT_TYPES(cls):
|
||||||
return {
|
return {
|
||||||
@@ -37,19 +37,39 @@ class LoraManagerLoader:
|
|||||||
|
|
||||||
clip = kwargs.get('clip', None)
|
clip = kwargs.get('clip', None)
|
||||||
lora_stack = kwargs.get('lora_stack', None)
|
lora_stack = kwargs.get('lora_stack', None)
|
||||||
|
|
||||||
|
# Check if model is a Nunchaku Flux model - simplified approach
|
||||||
|
is_nunchaku_model = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_wrapper = model.model.diffusion_model
|
||||||
|
# Check if model is a Nunchaku Flux model using only class name
|
||||||
|
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||||
|
is_nunchaku_model = True
|
||||||
|
logger.info("Detected Nunchaku Flux model")
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
# Not a model with the expected structure
|
||||||
|
pass
|
||||||
|
|
||||||
# First process lora_stack if available
|
# First process lora_stack if available
|
||||||
if lora_stack:
|
if lora_stack:
|
||||||
for lora_path, model_strength, clip_strength in lora_stack:
|
for lora_path, model_strength, clip_strength in lora_stack:
|
||||||
# Apply the LoRA using the provided path and strengths
|
# Apply the LoRA using the appropriate loader
|
||||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
if is_nunchaku_model:
|
||||||
|
# Use our custom function for Flux models
|
||||||
|
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||||
|
# clip remains unchanged for Nunchaku models
|
||||||
|
else:
|
||||||
|
# Use default loader for standard models
|
||||||
|
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||||
|
|
||||||
# Extract lora name for trigger words lookup
|
# Extract lora name for trigger words lookup
|
||||||
lora_name = extract_lora_name(lora_path)
|
lora_name = extract_lora_name(lora_path)
|
||||||
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||||
|
|
||||||
all_trigger_words.extend(trigger_words)
|
all_trigger_words.extend(trigger_words)
|
||||||
# Add clip strength to output if different from model strength
|
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||||
if abs(model_strength - clip_strength) > 0.001:
|
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||||
else:
|
else:
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||||
@@ -68,11 +88,17 @@ class LoraManagerLoader:
|
|||||||
# Get lora path and trigger words
|
# Get lora path and trigger words
|
||||||
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||||
|
|
||||||
# Apply the LoRA using the resolved path with separate strengths
|
# Apply the LoRA using the appropriate loader
|
||||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
if is_nunchaku_model:
|
||||||
|
# For Nunchaku models, use our custom function
|
||||||
|
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||||
|
# clip remains unchanged
|
||||||
|
else:
|
||||||
|
# Use default loader for standard models
|
||||||
|
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||||
|
|
||||||
# Include clip strength in output if different from model strength
|
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||||
if abs(model_strength - clip_strength) > 0.001:
|
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||||
else:
|
else:
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||||
|
|||||||
@@ -35,7 +35,12 @@ any_type = AnyType("*")
|
|||||||
# Common methods extracted from lora_loader.py and lora_stacker.py
|
# Common methods extracted from lora_loader.py and lora_stacker.py
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import copy
|
||||||
|
import folder_paths
|
||||||
|
import torch
|
||||||
|
import safetensors.torch
|
||||||
|
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
|
||||||
|
from diffusers.loaders import FluxLoraLoaderMixin
|
||||||
from ..services.lora_scanner import LoraScanner
|
from ..services.lora_scanner import LoraScanner
|
||||||
from ..config import config
|
from ..config import config
|
||||||
|
|
||||||
@@ -81,4 +86,64 @@ def get_loras_list(kwargs):
|
|||||||
# Unexpected format
|
# Unexpected format
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
|
||||||
|
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
|
||||||
|
state_dict = {}
|
||||||
|
with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
|
||||||
|
for k in f.keys():
|
||||||
|
if filter_prefix and not k.startswith(filter_prefix):
|
||||||
|
continue
|
||||||
|
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def to_diffusers(input_lora):
|
||||||
|
"""Simplified version of to_diffusers for Flux LoRA conversion"""
|
||||||
|
if isinstance(input_lora, str):
|
||||||
|
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
|
||||||
|
else:
|
||||||
|
tensors = {k: v for k, v in input_lora.items()}
|
||||||
|
|
||||||
|
# Convert FP8 tensors to BF16
|
||||||
|
for k, v in tensors.items():
|
||||||
|
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
|
||||||
|
tensors[k] = v.to(torch.bfloat16)
|
||||||
|
|
||||||
|
new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors)
|
||||||
|
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
|
||||||
|
|
||||||
|
return new_tensors
|
||||||
|
|
||||||
|
def nunchaku_load_lora(model, lora_name, lora_strength):
|
||||||
|
"""Load a Flux LoRA for Nunchaku model"""
|
||||||
|
model_wrapper = model.model.diffusion_model
|
||||||
|
transformer = model_wrapper.model
|
||||||
|
|
||||||
|
# Save the transformer temporarily
|
||||||
|
model_wrapper.model = None
|
||||||
|
ret_model = copy.deepcopy(model) # copy everything except the model
|
||||||
|
ret_model_wrapper = ret_model.model.diffusion_model
|
||||||
|
|
||||||
|
# Restore the model and set it for the copy
|
||||||
|
model_wrapper.model = transformer
|
||||||
|
ret_model_wrapper.model = transformer
|
||||||
|
|
||||||
|
# Get full path to the LoRA file
|
||||||
|
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||||
|
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
||||||
|
|
||||||
|
# Convert the LoRA to diffusers format
|
||||||
|
sd = to_diffusers(lora_path)
|
||||||
|
|
||||||
|
# Handle embedding adjustment if needed
|
||||||
|
if "transformer.x_embedder.lora_A.weight" in sd:
|
||||||
|
new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
|
||||||
|
assert new_in_channels % 4 == 0
|
||||||
|
new_in_channels = new_in_channels // 4
|
||||||
|
|
||||||
|
old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
|
||||||
|
if old_in_channels < new_in_channels:
|
||||||
|
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
|
||||||
|
|
||||||
|
return ret_model
|
||||||
@@ -7,6 +7,7 @@ dependencies = [
|
|||||||
"aiohttp",
|
"aiohttp",
|
||||||
"jinja2",
|
"jinja2",
|
||||||
"safetensors",
|
"safetensors",
|
||||||
|
"diffusers",
|
||||||
"watchdog",
|
"watchdog",
|
||||||
"beautifulsoup4",
|
"beautifulsoup4",
|
||||||
"piexif",
|
"piexif",
|
||||||
|
|||||||
Reference in New Issue
Block a user