mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-28 00:18:52 -03:00
fix(nodes): lazy load qwen lora helper
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@@ -5,7 +6,6 @@ import comfy.sd # type: ignore
|
|||||||
import comfy.utils # type: ignore
|
import comfy.utils # type: ignore
|
||||||
|
|
||||||
from ..utils.utils import get_lora_info_absolute
|
from ..utils.utils import get_lora_info_absolute
|
||||||
from .nunchaku_qwen import nunchaku_load_qwen_loras
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
FlexibleOptionalInputType,
|
FlexibleOptionalInputType,
|
||||||
any_type,
|
any_type,
|
||||||
@@ -18,6 +18,16 @@ from .utils import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_nunchaku_load_qwen_loras():
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(".nunchaku_qwen", __package__)
|
||||||
|
except ImportError as exc:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Qwen-Image LoRA loading requires the ComfyUI runtime with its torch dependency available."
|
||||||
|
) from exc
|
||||||
|
return module.nunchaku_load_qwen_loras
|
||||||
|
|
||||||
|
|
||||||
def _collect_stack_entries(lora_stack):
|
def _collect_stack_entries(lora_stack):
|
||||||
entries = []
|
entries = []
|
||||||
if not lora_stack:
|
if not lora_stack:
|
||||||
@@ -74,6 +84,7 @@ def _apply_entries(model, clip, lora_entries, nunchaku_model_kind):
|
|||||||
all_trigger_words = []
|
all_trigger_words = []
|
||||||
|
|
||||||
if nunchaku_model_kind == "qwen_image":
|
if nunchaku_model_kind == "qwen_image":
|
||||||
|
nunchaku_load_qwen_loras = _get_nunchaku_load_qwen_loras()
|
||||||
qwen_lora_configs = []
|
qwen_lora_configs = []
|
||||||
for entry in lora_entries:
|
for entry in lora_entries:
|
||||||
qwen_lora_configs.append((entry["absolute_path"], entry["model_strength"]))
|
qwen_lora_configs.append((entry["absolute_path"], entry["model_strength"]))
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import types
|
import types
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM
|
from py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM
|
||||||
|
|
||||||
|
|
||||||
@@ -128,7 +130,7 @@ def test_lora_loader_qwen_model_batches_loras(monkeypatch):
|
|||||||
batched_calls.append((model_arg, lora_configs))
|
batched_calls.append((model_arg, lora_configs))
|
||||||
return model_arg
|
return model_arg
|
||||||
|
|
||||||
monkeypatch.setattr("py.nodes.lora_loader.nunchaku_load_qwen_loras", mock_nunchaku_load_qwen_loras)
|
monkeypatch.setattr("py.nodes.lora_loader._get_nunchaku_load_qwen_loras", lambda: mock_nunchaku_load_qwen_loras)
|
||||||
|
|
||||||
_, result_clip, trigger_words, loaded_loras = loader.load_loras(
|
_, result_clip, trigger_words, loaded_loras = loader.load_loras(
|
||||||
qwen_model,
|
qwen_model,
|
||||||
@@ -160,8 +162,8 @@ def test_lora_text_loader_qwen_batches_text_and_stack(monkeypatch):
|
|||||||
|
|
||||||
batched_calls = []
|
batched_calls = []
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"py.nodes.lora_loader.nunchaku_load_qwen_loras",
|
"py.nodes.lora_loader._get_nunchaku_load_qwen_loras",
|
||||||
lambda model_arg, lora_configs: batched_calls.append(lora_configs) or model_arg,
|
lambda: (lambda model_arg, lora_configs: batched_calls.append(lora_configs) or model_arg),
|
||||||
)
|
)
|
||||||
|
|
||||||
_, _, trigger_words, loaded_loras = loader.load_loras_from_text(
|
_, _, trigger_words, loaded_loras = loader.load_loras_from_text(
|
||||||
@@ -174,3 +176,26 @@ def test_lora_text_loader_qwen_batches_text_and_stack(monkeypatch):
|
|||||||
assert batched_calls == [[("/abs/stack_qwen.safetensors", 0.6), ("/abs/text_qwen.safetensors", 1.2)]]
|
assert batched_calls == [[("/abs/stack_qwen.safetensors", 0.6), ("/abs/text_qwen.safetensors", 1.2)]]
|
||||||
assert trigger_words == "stack_qwen_trigger,, text_qwen_trigger"
|
assert trigger_words == "stack_qwen_trigger,, text_qwen_trigger"
|
||||||
assert loaded_loras == "<lora:stack_qwen:0.6> <lora:text_qwen:1.2>"
|
assert loaded_loras == "<lora:stack_qwen:0.6> <lora:text_qwen:1.2>"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_loader_qwen_model_raises_clear_error_when_helper_import_fails(monkeypatch):
|
||||||
|
qwen_model = _Model(type("NunchakuQwenImageTransformer2DModel", (), {})())
|
||||||
|
loader = LoraLoaderLM()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.nodes.lora_loader.get_lora_info_absolute",
|
||||||
|
lambda name: (f"/abs/{name}.safetensors", [f"{name}_trigger"]),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.nodes.lora_loader._get_nunchaku_load_qwen_loras",
|
||||||
|
lambda: (_ for _ in ()).throw( # pragma: no branch
|
||||||
|
RuntimeError("Qwen-Image LoRA loading requires the ComfyUI runtime with its torch dependency available.")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="Qwen-Image LoRA loading requires the ComfyUI runtime"):
|
||||||
|
loader.load_loras(
|
||||||
|
qwen_model,
|
||||||
|
"",
|
||||||
|
lora_stack=[("stack_qwen.safetensors", 0.6, 0.1)],
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user