feat: Improve Nunchaku LoRA loading with copy_with_ctx support and add unit tests. see #733

This commit is contained in:
Will Miao
2025-12-27 21:46:14 +08:00
parent 7d6b717385
commit f5d5bffa61
2 changed files with 118 additions and 13 deletions

View File

@@ -36,6 +36,7 @@ any_type = AnyType("*")
import os import os
import logging import logging
import copy import copy
import sys
import folder_paths import folder_paths
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -98,25 +99,37 @@ def to_diffusers(input_lora):
def nunchaku_load_lora(model, lora_name, lora_strength): def nunchaku_load_lora(model, lora_name, lora_strength):
"""Load a Flux LoRA for Nunchaku model""" """Load a Flux LoRA for Nunchaku model"""
model_wrapper = model.model.diffusion_model
transformer = model_wrapper.model
# Save the transformer temporarily
model_wrapper.model = None
ret_model = model.clone()
ret_model_wrapper = ret_model.model.diffusion_model
# Restore the model and set it for the copy
model_wrapper.model = transformer
ret_model_wrapper.model = transformer
# Get full path to the LoRA file. Allow both direct paths and registered LoRA names. # Get full path to the LoRA file. Allow both direct paths and registered LoRA names.
lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name) lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
if not lora_path or not os.path.isfile(lora_path): if not lora_path or not os.path.isfile(lora_path):
logger.warning("Skipping LoRA '%s' because it could not be found", lora_name) logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
return model return model
ret_model_wrapper.loras.append((lora_path, lora_strength)) model_wrapper = model.model.diffusion_model
# Try to find copy_with_ctx in the same module as ComfyFluxWrapper
module_name = model_wrapper.__class__.__module__
module = sys.modules.get(module_name)
copy_with_ctx = getattr(module, "copy_with_ctx", None)
if copy_with_ctx is not None:
# New logic using copy_with_ctx from ComfyUI-nunchaku 1.1.0+
ret_model_wrapper, ret_model = copy_with_ctx(model_wrapper)
ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)]
else:
# Fallback to legacy logic
logger.warning("Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic.")
transformer = model_wrapper.model
# Save the transformer temporarily
model_wrapper.model = None
ret_model = copy.deepcopy(model) # copy everything except the model
ret_model_wrapper = ret_model.model.diffusion_model
# Restore the model and set it for the copy
model_wrapper.model = transformer
ret_model_wrapper.model = transformer
ret_model_wrapper.loras.append((lora_path, lora_strength))
# Convert the LoRA to diffusers format # Convert the LoRA to diffusers format
sd = to_diffusers(lora_path) sd = to_diffusers(lora_path)

View File

@@ -0,0 +1,92 @@
import logging
import sys
import os
import unittest.mock as mock
from py.nodes.utils import nunchaku_load_lora
class _DummyTransformer:
pass
class _DummyModelConfig:
def __init__(self):
self.unet_config = {"in_channels": 4}
class _DummyDiffusionModel:
    """Fake diffusion-model wrapper holding a transformer plus a LoRA list."""

    def __init__(self):
        # Stands in for the wrapped Nunchaku transformer.
        self.model = _DummyTransformer()
        # Accumulates (lora_path, strength) tuples appended by the loader.
        self.loras = []
class _DummyModelWrapper:
    """Fake inner model pairing a diffusion model with its model config."""

    def __init__(self):
        # Mirrors the attribute layout nunchaku_load_lora walks:
        # model.model.diffusion_model and model.model.model_config.
        self.diffusion_model = _DummyDiffusionModel()
        self.model_config = _DummyModelConfig()
class _DummyModel:
    """Fake top-level ComfyUI model exposing ``.model`` and ``.clone()``."""

    def __init__(self):
        self.model = _DummyModelWrapper()

    def clone(self):
        # The legacy path copies via copy.deepcopy(model) rather than clone(),
        # and the new path receives its clone from copy_with_ctx, so simply
        # returning self is sufficient for these tests.
        return self
def test_nunchaku_load_lora_legacy_fallback(monkeypatch, caplog):
    """Without copy_with_ctx, the legacy deepcopy path runs and logs a warning.

    The loader looks for ``copy_with_ctx`` in the module that defines the
    wrapper's class; here that module is this test module, so we strip the
    attribute to force the fallback branch.
    """
    import folder_paths

    dummy_model = _DummyModel()

    # Resolve any registered LoRA name to a fake on-disk path, and make
    # os.path.isfile accept exactly those fake paths.
    monkeypatch.setattr(
        folder_paths,
        "get_full_path",
        lambda folder, name: f"/fake/path/{name}",
        raising=False,
    )
    monkeypatch.setattr(os.path, "isfile", lambda path: "/fake/path/" in path)

    # Skip real safetensors parsing.
    monkeypatch.setattr("py.nodes.utils.to_diffusers", lambda path: {})

    # Ensure copy_with_ctx is NOT found: the loader resolves it from
    # model_wrapper.__class__.__module__, which is this module.
    module = sys.modules.get(_DummyDiffusionModel.__module__)
    if module is not None and hasattr(module, "copy_with_ctx"):
        monkeypatch.delattr(module, "copy_with_ctx")

    with caplog.at_level(logging.WARNING):
        result_model = nunchaku_load_lora(dummy_model, "some_lora", 0.8)

    # The fallback warns about upgrading, and records exactly one LoRA
    # with the requested strength.
    assert "better LoRA support" in caplog.text
    assert len(result_model.model.diffusion_model.loras) == 1
    assert result_model.model.diffusion_model.loras[0][1] == 0.8
def test_nunchaku_load_lora_new_logic(monkeypatch):
    """With copy_with_ctx available, its clone is used and receives the LoRA.

    A mock ``copy_with_ctx`` is injected into this module (the wrapper
    class's defining module) so the loader's getattr lookup finds it; the
    loader must then return the mock's cloned model and append the resolved
    (path, strength) tuple to the cloned wrapper's loras.
    """
    import folder_paths

    dummy_model = _DummyModel()

    # Resolve any registered LoRA name to a fake on-disk path, and make
    # os.path.isfile accept exactly those fake paths.
    monkeypatch.setattr(
        folder_paths,
        "get_full_path",
        lambda folder, name: f"/fake/path/{name}",
        raising=False,
    )
    monkeypatch.setattr(os.path, "isfile", lambda path: "/fake/path/" in path)

    # Skip real safetensors parsing.
    monkeypatch.setattr("py.nodes.utils.to_diffusers", lambda path: {})

    # Pre-build the objects copy_with_ctx is expected to hand back.
    cloned_wrapper = _DummyDiffusionModel()
    cloned_model = _DummyModel()
    cloned_model.model.diffusion_model = cloned_wrapper

    def mock_copy_with_ctx(wrapper):
        return cloned_wrapper, cloned_model

    module = sys.modules[_DummyDiffusionModel.__module__]
    monkeypatch.setattr(module, "copy_with_ctx", mock_copy_with_ctx, raising=False)

    result_model = nunchaku_load_lora(dummy_model, "new_lora", 0.7)

    assert result_model is cloned_model
    assert cloned_wrapper.loras == [("/fake/path/new_lora", 0.7)]