feat: Improve Nunchaku LoRA loading with copy_with_ctx support and add unit tests. See #733

Author: Will Miao
Date: 2025-12-27 21:46:14 +08:00
parent 7d6b717385
commit f5d5bffa61
2 changed files with 118 additions and 13 deletions


@@ -36,6 +36,7 @@ any_type = AnyType("*")
 import os
 import logging
 import copy
+import sys
 import folder_paths
 logger = logging.getLogger(__name__)
@@ -98,25 +99,37 @@ def to_diffusers(input_lora):
 def nunchaku_load_lora(model, lora_name, lora_strength):
     """Load a Flux LoRA for Nunchaku model"""
-    model_wrapper = model.model.diffusion_model
-    transformer = model_wrapper.model
-    # Save the transformer temporarily
-    model_wrapper.model = None
-    ret_model = model.clone()
-    ret_model_wrapper = ret_model.model.diffusion_model
-    # Restore the model and set it for the copy
-    model_wrapper.model = transformer
-    ret_model_wrapper.model = transformer
     # Get full path to the LoRA file. Allow both direct paths and registered LoRA names.
     lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
     if not lora_path or not os.path.isfile(lora_path):
         logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
         return model
-    ret_model_wrapper.loras.append((lora_path, lora_strength))
+    model_wrapper = model.model.diffusion_model
+    # Try to find copy_with_ctx in the same module as ComfyFluxWrapper
+    module_name = model_wrapper.__class__.__module__
+    module = sys.modules.get(module_name)
+    copy_with_ctx = getattr(module, "copy_with_ctx", None)
+    if copy_with_ctx is not None:
+        # New logic using copy_with_ctx from ComfyUI-nunchaku 1.1.0+
+        ret_model_wrapper, ret_model = copy_with_ctx(model_wrapper)
+        ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)]
+    else:
+        # Fallback to legacy logic
+        logger.warning("Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic.")
+        transformer = model_wrapper.model
+        # Save the transformer temporarily
+        model_wrapper.model = None
+        ret_model = copy.deepcopy(model)  # copy everything except the model
+        ret_model_wrapper = ret_model.model.diffusion_model
+        # Restore the model and set it for the copy
+        model_wrapper.model = transformer
+        ret_model_wrapper.model = transformer
+        ret_model_wrapper.loras.append((lora_path, lora_strength))
     # Convert the LoRA to diffusers format
     sd = to_diffusers(lora_path)
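
Note on the version probe: the sys.modules lookup plus getattr with a None default is ordinary feature detection, so the node degrades gracefully on older ComfyUI-nunchaku releases instead of hard-importing its internals. A minimal standalone sketch of the pattern (find_copy_with_ctx and FakeWrapper are hypothetical stand-ins; only the copy_with_ctx symbol name comes from the diff):

import sys

def find_copy_with_ctx(wrapper):
    # Resolve the module that defines the wrapper's class, then probe it.
    # Older ComfyUI-nunchaku versions do not export copy_with_ctx, so the
    # getattr default of None routes callers to the legacy fallback.
    module = sys.modules.get(wrapper.__class__.__module__)
    return getattr(module, "copy_with_ctx", None)

class FakeWrapper:
    pass

# Pre-1.1.0 case: the symbol is absent and the probe returns None.
assert find_copy_with_ctx(FakeWrapper()) is None

# Simulate 1.1.0+ by injecting the symbol into the defining module.
sys.modules[FakeWrapper.__module__].copy_with_ctx = lambda w: (w, None)
assert find_copy_with_ctx(FakeWrapper()) is not None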
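
The test file is one of the two changed files but is not shown in this excerpt. For illustration, a minimal pytest sketch covering the one path the shown hunk fully determines (the missing-file early return); the import path and all Stub* names are hypothetical:

from custom_nodes.nunchaku_nodes import nunchaku_load_lora  # hypothetical import path

class StubWrapper:
    def __init__(self):
        self.model = object()  # stands in for the transformer
        self.loras = []

class StubInner:
    def __init__(self):
        self.diffusion_model = StubWrapper()

class StubModel:
    def __init__(self):
        self.model = StubInner()

def test_missing_lora_returns_original_model(monkeypatch):
    # monkeypatch is pytest's built-in fixture; folder_paths is ComfyUI's
    # path registry and is assumed importable in the test environment.
    import folder_paths
    monkeypatch.setattr(folder_paths, "get_full_path", lambda kind, name: None)
    model = StubModel()
    ret = nunchaku_load_lora(model, "no-such-file.safetensors", 0.8)
    assert ret is model                             # early return, no copy made
    assert model.model.diffusion_model.loras == []  # nothing was appended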