mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: Improve Nunchaku LoRA loading with copy_with_ctx support and add unit tests. see #733
This commit is contained in:
@@ -36,6 +36,7 @@ any_type = AnyType("*")
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import copy
|
import copy
|
||||||
|
import sys
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -98,25 +99,37 @@ def to_diffusers(input_lora):
|
|||||||
|
|
||||||
def nunchaku_load_lora(model, lora_name, lora_strength):
|
def nunchaku_load_lora(model, lora_name, lora_strength):
|
||||||
"""Load a Flux LoRA for Nunchaku model"""
|
"""Load a Flux LoRA for Nunchaku model"""
|
||||||
model_wrapper = model.model.diffusion_model
|
|
||||||
transformer = model_wrapper.model
|
|
||||||
|
|
||||||
# Save the transformer temporarily
|
|
||||||
model_wrapper.model = None
|
|
||||||
ret_model = model.clone()
|
|
||||||
ret_model_wrapper = ret_model.model.diffusion_model
|
|
||||||
|
|
||||||
# Restore the model and set it for the copy
|
|
||||||
model_wrapper.model = transformer
|
|
||||||
ret_model_wrapper.model = transformer
|
|
||||||
|
|
||||||
# Get full path to the LoRA file. Allow both direct paths and registered LoRA names.
|
# Get full path to the LoRA file. Allow both direct paths and registered LoRA names.
|
||||||
lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
|
lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
|
||||||
if not lora_path or not os.path.isfile(lora_path):
|
if not lora_path or not os.path.isfile(lora_path):
|
||||||
logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
|
logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
model_wrapper = model.model.diffusion_model
|
||||||
|
|
||||||
|
# Try to find copy_with_ctx in the same module as ComfyFluxWrapper
|
||||||
|
module_name = model_wrapper.__class__.__module__
|
||||||
|
module = sys.modules.get(module_name)
|
||||||
|
copy_with_ctx = getattr(module, "copy_with_ctx", None)
|
||||||
|
|
||||||
|
if copy_with_ctx is not None:
|
||||||
|
# New logic using copy_with_ctx from ComfyUI-nunchaku 1.1.0+
|
||||||
|
ret_model_wrapper, ret_model = copy_with_ctx(model_wrapper)
|
||||||
|
ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)]
|
||||||
|
else:
|
||||||
|
# Fallback to legacy logic
|
||||||
|
logger.warning("Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic.")
|
||||||
|
transformer = model_wrapper.model
|
||||||
|
|
||||||
|
# Save the transformer temporarily
|
||||||
|
model_wrapper.model = None
|
||||||
|
ret_model = copy.deepcopy(model) # copy everything except the model
|
||||||
|
ret_model_wrapper = ret_model.model.diffusion_model
|
||||||
|
|
||||||
|
# Restore the model and set it for the copy
|
||||||
|
model_wrapper.model = transformer
|
||||||
|
ret_model_wrapper.model = transformer
|
||||||
|
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
||||||
|
|
||||||
# Convert the LoRA to diffusers format
|
# Convert the LoRA to diffusers format
|
||||||
sd = to_diffusers(lora_path)
|
sd = to_diffusers(lora_path)
|
||||||
|
|||||||
92
tests/nodes/test_nunchaku_lora.py
Normal file
92
tests/nodes/test_nunchaku_lora.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import unittest.mock as mock
|
||||||
|
from py.nodes.utils import nunchaku_load_lora
|
||||||
|
|
||||||
|
class _DummyTransformer:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class _DummyModelConfig:
|
||||||
|
def __init__(self):
|
||||||
|
self.unet_config = {"in_channels": 4}
|
||||||
|
|
||||||
|
class _DummyDiffusionModel:
|
||||||
|
def __init__(self):
|
||||||
|
self.model = _DummyTransformer()
|
||||||
|
self.loras = []
|
||||||
|
|
||||||
|
class _DummyModelWrapper:
|
||||||
|
def __init__(self):
|
||||||
|
self.diffusion_model = _DummyDiffusionModel()
|
||||||
|
self.model_config = _DummyModelConfig()
|
||||||
|
|
||||||
|
class _DummyModel:
|
||||||
|
def __init__(self):
|
||||||
|
self.model = _DummyModelWrapper()
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
# This is what our legacy logic used via copy.deepcopy(model)
|
||||||
|
# But in the new logic, copy_with_ctx returns the cloned model
|
||||||
|
return self
|
||||||
|
|
||||||
|
def test_nunchaku_load_lora_legacy_fallback(monkeypatch, caplog):
|
||||||
|
import folder_paths
|
||||||
|
import copy
|
||||||
|
|
||||||
|
dummy_model = _DummyModel()
|
||||||
|
|
||||||
|
# Mock folder_paths and os.path.isfile to "find" the LoRA
|
||||||
|
monkeypatch.setattr(folder_paths, "get_full_path", lambda folder, name: f"/fake/path/{name}", raising=False)
|
||||||
|
monkeypatch.setattr(os.path, "isfile", lambda path: True if "/fake/path/" in path else False)
|
||||||
|
|
||||||
|
# Mock to_diffusers to return a dummy state dict
|
||||||
|
monkeypatch.setattr("py.nodes.utils.to_diffusers", lambda path: {})
|
||||||
|
|
||||||
|
# Ensure copy_with_ctx is NOT found
|
||||||
|
# model_wrapper.__class__.__module__ will be this module
|
||||||
|
module_name = _DummyDiffusionModel.__module__
|
||||||
|
if module_name in sys.modules:
|
||||||
|
module = sys.modules[module_name]
|
||||||
|
if hasattr(module, "copy_with_ctx"):
|
||||||
|
monkeypatch.delattr(module, "copy_with_ctx")
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
result_model = nunchaku_load_lora(dummy_model, "some_lora", 0.8)
|
||||||
|
|
||||||
|
assert "better LoRA support" in caplog.text
|
||||||
|
assert len(result_model.model.diffusion_model.loras) == 1
|
||||||
|
assert result_model.model.diffusion_model.loras[0][1] == 0.8
|
||||||
|
|
||||||
|
def test_nunchaku_load_lora_new_logic(monkeypatch):
|
||||||
|
import folder_paths
|
||||||
|
import os
|
||||||
|
|
||||||
|
dummy_model = _DummyModel()
|
||||||
|
model_wrapper = dummy_model.model.diffusion_model
|
||||||
|
|
||||||
|
# Mock folder_paths and os.path.isfile
|
||||||
|
monkeypatch.setattr(folder_paths, "get_full_path", lambda folder, name: f"/fake/path/{name}", raising=False)
|
||||||
|
monkeypatch.setattr(os.path, "isfile", lambda path: True if "/fake/path/" in path else False)
|
||||||
|
|
||||||
|
# Mock to_diffusers
|
||||||
|
monkeypatch.setattr("py.nodes.utils.to_diffusers", lambda path: {})
|
||||||
|
|
||||||
|
# Create the cloned objects that copy_with_ctx would return
|
||||||
|
cloned_wrapper = _DummyDiffusionModel()
|
||||||
|
cloned_model = _DummyModel()
|
||||||
|
cloned_model.model.diffusion_model = cloned_wrapper
|
||||||
|
|
||||||
|
# Define copy_with_ctx
|
||||||
|
def mock_copy_with_ctx(wrapper):
|
||||||
|
return cloned_wrapper, cloned_model
|
||||||
|
|
||||||
|
# Inject copy_with_ctx into the module
|
||||||
|
module_name = _DummyDiffusionModel.__module__
|
||||||
|
module = sys.modules[module_name]
|
||||||
|
monkeypatch.setattr(module, "copy_with_ctx", mock_copy_with_ctx, raising=False)
|
||||||
|
|
||||||
|
result_model = nunchaku_load_lora(dummy_model, "new_lora", 0.7)
|
||||||
|
|
||||||
|
assert result_model is cloned_model
|
||||||
|
assert cloned_wrapper.loras == [("/fake/path/new_lora", 0.7)]
|
||||||
Reference in New Issue
Block a user