From 2626dbab8ed08093fbd64e17c23dbefdf3b7fe84 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Sun, 29 Mar 2026 08:28:00 +0800 Subject: [PATCH] feat: add lora stack combiner node --- __init__.py | 5 +++ py/nodes/lora_stack_combiner.py | 26 +++++++++++++ tests/nodes/test_lora_stack_combiner.py | 51 +++++++++++++++++++++++++ 3 files changed, 82 insertions(+) create mode 100644 py/nodes/lora_stack_combiner.py create mode 100644 tests/nodes/test_lora_stack_combiner.py diff --git a/__init__.py b/__init__.py index 1966e838..1f58f2ab 100644 --- a/__init__.py +++ b/__init__.py @@ -7,6 +7,7 @@ try: # pragma: no cover - import fallback for pytest collection from .py.nodes.prompt import PromptLM from .py.nodes.text import TextLM from .py.nodes.lora_stacker import LoraStackerLM + from .py.nodes.lora_stack_combiner import LoraStackCombinerLM from .py.nodes.save_image import SaveImageLM from .py.nodes.debug_metadata import DebugMetadataLM from .py.nodes.wanvideo_lora_select import WanVideoLoraSelectLM @@ -39,6 +40,9 @@ except ( "py.nodes.trigger_word_toggle" ).TriggerWordToggleLM LoraStackerLM = importlib.import_module("py.nodes.lora_stacker").LoraStackerLM + LoraStackCombinerLM = importlib.import_module( + "py.nodes.lora_stack_combiner" + ).LoraStackCombinerLM SaveImageLM = importlib.import_module("py.nodes.save_image").SaveImageLM DebugMetadataLM = importlib.import_module("py.nodes.debug_metadata").DebugMetadataLM WanVideoLoraSelectLM = importlib.import_module( @@ -63,6 +67,7 @@ NODE_CLASS_MAPPINGS = { UNETLoaderLM.NAME: UNETLoaderLM, TriggerWordToggleLM.NAME: TriggerWordToggleLM, LoraStackerLM.NAME: LoraStackerLM, + LoraStackCombinerLM.NAME: LoraStackCombinerLM, SaveImageLM.NAME: SaveImageLM, DebugMetadataLM.NAME: DebugMetadataLM, WanVideoLoraSelectLM.NAME: WanVideoLoraSelectLM, diff --git a/py/nodes/lora_stack_combiner.py b/py/nodes/lora_stack_combiner.py new file mode 100644 index 00000000..9b5c4fbc --- /dev/null +++ b/py/nodes/lora_stack_combiner.py @@ -0,0 +1,26 @@ +class LoraStackCombinerLM: + NAME = "Lora Stack Combiner (LoraManager)" + CATEGORY = "Lora Manager/stackers" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "lora_stack_a": ("LORA_STACK",), + "lora_stack_b": ("LORA_STACK",), + }, + } + + RETURN_TYPES = ("LORA_STACK",) + RETURN_NAMES = ("LORA_STACK",) + FUNCTION = "combine_stacks" + + def combine_stacks(self, lora_stack_a, lora_stack_b): + combined_stack = [] + + if lora_stack_a: + combined_stack.extend(lora_stack_a) + if lora_stack_b: + combined_stack.extend(lora_stack_b) + + return (combined_stack,) diff --git a/tests/nodes/test_lora_stack_combiner.py b/tests/nodes/test_lora_stack_combiner.py new file mode 100644 index 00000000..b5b83954 --- /dev/null +++ b/tests/nodes/test_lora_stack_combiner.py @@ -0,0 +1,51 @@ +from py.nodes.lora_stack_combiner import LoraStackCombinerLM + + +def test_combine_stacks_preserves_order(): + node = LoraStackCombinerLM() + stack_a = [ + ("folder/a.safetensors", 0.7, 0.6), + ("folder/b.safetensors", 0.8, 0.8), + ] + stack_b = [ + ("folder/c.safetensors", 1.0, 0.9), + ] + + (combined_stack,) = node.combine_stacks(stack_a, stack_b) + + assert combined_stack == stack_a + stack_b + + +def test_combine_stacks_returns_second_when_first_empty(): + node = LoraStackCombinerLM() + stack_b = [("folder/c.safetensors", 1.0, 0.9)] + + (combined_stack,) = node.combine_stacks([], stack_b) + + assert combined_stack == stack_b + + +def test_combine_stacks_returns_first_when_second_empty(): + node = LoraStackCombinerLM() + stack_a = [("folder/a.safetensors", 0.7, 0.6)] + + (combined_stack,) = node.combine_stacks(stack_a, []) + + assert combined_stack == stack_a + + +def test_combine_stacks_returns_empty_when_both_empty(): + node = LoraStackCombinerLM() + + (combined_stack,) = node.combine_stacks([], []) + + assert combined_stack == [] + + +def test_combine_stacks_allows_duplicate_entries(): + node = LoraStackCombinerLM() + duplicate_entry = ("folder/shared.safetensors", 0.9, 0.5) + + (combined_stack,) = node.combine_stacks([duplicate_entry], [duplicate_entry]) + + assert combined_stack == [duplicate_entry, duplicate_entry]