From f5d5bffa617ce7aa4e418ecd51d9d42ebb59c6eb Mon Sep 17 00:00:00 2001 From: Will Miao Date: Sat, 27 Dec 2025 21:46:14 +0800 Subject: [PATCH] feat: Improve Nunchaku LoRA loading with `copy_with_ctx` support and add unit tests. see #733 --- py/nodes/utils.py | 39 ++++++++----- tests/nodes/test_nunchaku_lora.py | 92 +++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 13 deletions(-) create mode 100644 tests/nodes/test_nunchaku_lora.py diff --git a/py/nodes/utils.py b/py/nodes/utils.py index 402025e1..41127e4f 100644 --- a/py/nodes/utils.py +++ b/py/nodes/utils.py @@ -36,6 +36,7 @@ any_type = AnyType("*") import os import logging import copy +import sys import folder_paths logger = logging.getLogger(__name__) @@ -98,25 +99,37 @@ def to_diffusers(input_lora): def nunchaku_load_lora(model, lora_name, lora_strength): """Load a Flux LoRA for Nunchaku model""" - model_wrapper = model.model.diffusion_model - transformer = model_wrapper.model - - # Save the transformer temporarily - model_wrapper.model = None - ret_model = model.clone() - ret_model_wrapper = ret_model.model.diffusion_model - - # Restore the model and set it for the copy - model_wrapper.model = transformer - ret_model_wrapper.model = transformer - # Get full path to the LoRA file. Allow both direct paths and registered LoRA names. lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name) if not lora_path or not os.path.isfile(lora_path): logger.warning("Skipping LoRA '%s' because it could not be found", lora_name) return model - ret_model_wrapper.loras.append((lora_path, lora_strength)) + model_wrapper = model.model.diffusion_model + + # Try to find copy_with_ctx in the same module as ComfyFluxWrapper + module_name = model_wrapper.__class__.__module__ + module = sys.modules.get(module_name) + copy_with_ctx = getattr(module, "copy_with_ctx", None) + + if copy_with_ctx is not None: + # New logic using copy_with_ctx from ComfyUI-nunchaku 1.1.0+ + ret_model_wrapper, ret_model = copy_with_ctx(model_wrapper) + ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)] + else: + # Fallback to legacy logic + logger.warning("Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic.") + transformer = model_wrapper.model + + # Save the transformer temporarily + model_wrapper.model = None + ret_model = copy.deepcopy(model) # copy everything except the model + ret_model_wrapper = ret_model.model.diffusion_model + + # Restore the model and set it for the copy + model_wrapper.model = transformer + ret_model_wrapper.model = transformer + ret_model_wrapper.loras.append((lora_path, lora_strength)) # Convert the LoRA to diffusers format sd = to_diffusers(lora_path) diff --git a/tests/nodes/test_nunchaku_lora.py b/tests/nodes/test_nunchaku_lora.py new file mode 100644 index 00000000..6b1ac783 --- /dev/null +++ b/tests/nodes/test_nunchaku_lora.py @@ -0,0 +1,92 @@ +import logging +import sys +import os +import unittest.mock as mock +from py.nodes.utils import nunchaku_load_lora + +class _DummyTransformer: + pass + +class _DummyModelConfig: + def __init__(self): + self.unet_config = {"in_channels": 4} + +class _DummyDiffusionModel: + def __init__(self): + self.model = _DummyTransformer() + self.loras = [] + +class _DummyModelWrapper: + def __init__(self): + self.diffusion_model = _DummyDiffusionModel() + self.model_config = _DummyModelConfig() + +class _DummyModel: + def __init__(self): + self.model = _DummyModelWrapper() + + def clone(self): + # This is what our legacy logic used via copy.deepcopy(model) + # But in the new logic, copy_with_ctx returns the cloned model + return self + +def test_nunchaku_load_lora_legacy_fallback(monkeypatch, caplog): + import folder_paths + import copy + + dummy_model = _DummyModel() + + # Mock folder_paths and os.path.isfile to "find" the LoRA + monkeypatch.setattr(folder_paths, "get_full_path", lambda folder, name: f"/fake/path/{name}", raising=False) + monkeypatch.setattr(os.path, "isfile", lambda path: True if "/fake/path/" in path else False) + + # Mock to_diffusers to return a dummy state dict + monkeypatch.setattr("py.nodes.utils.to_diffusers", lambda path: {}) + + # Ensure copy_with_ctx is NOT found + # model_wrapper.__class__.__module__ will be this module + module_name = _DummyDiffusionModel.__module__ + if module_name in sys.modules: + module = sys.modules[module_name] + if hasattr(module, "copy_with_ctx"): + monkeypatch.delattr(module, "copy_with_ctx") + + with caplog.at_level(logging.WARNING): + result_model = nunchaku_load_lora(dummy_model, "some_lora", 0.8) + + assert "better LoRA support" in caplog.text + assert len(result_model.model.diffusion_model.loras) == 1 + assert result_model.model.diffusion_model.loras[0][1] == 0.8 + +def test_nunchaku_load_lora_new_logic(monkeypatch): + import folder_paths + import os + + dummy_model = _DummyModel() + model_wrapper = dummy_model.model.diffusion_model + + # Mock folder_paths and os.path.isfile + monkeypatch.setattr(folder_paths, "get_full_path", lambda folder, name: f"/fake/path/{name}", raising=False) + monkeypatch.setattr(os.path, "isfile", lambda path: True if "/fake/path/" in path else False) + + # Mock to_diffusers + monkeypatch.setattr("py.nodes.utils.to_diffusers", lambda path: {}) + + # Create the cloned objects that copy_with_ctx would return + cloned_wrapper = _DummyDiffusionModel() + cloned_model = _DummyModel() + cloned_model.model.diffusion_model = cloned_wrapper + + # Define copy_with_ctx + def mock_copy_with_ctx(wrapper): + return cloned_wrapper, cloned_model + + # Inject copy_with_ctx into the module + module_name = _DummyDiffusionModel.__module__ + module = sys.modules[module_name] + monkeypatch.setattr(module, "copy_with_ctx", mock_copy_with_ctx, raising=False) + + result_model = nunchaku_load_lora(dummy_model, "new_lora", 0.7) + + assert result_model is cloned_model + assert cloned_wrapper.loras == [("/fake/path/new_lora", 0.7)]