diff --git a/__init__.py b/__init__.py index db97c0e1..88ec0edc 100644 --- a/__init__.py +++ b/__init__.py @@ -1,10 +1,12 @@ from .py.lora_manager import LoraManager from .py.nodes.lora_loader import LoraManagerLoader from .py.nodes.trigger_word_toggle import TriggerWordToggle +from .py.nodes.lora_stacker import LoraStacker NODE_CLASS_MAPPINGS = { LoraManagerLoader.NAME: LoraManagerLoader, TriggerWordToggle.NAME: TriggerWordToggle + # LoraStacker.NAME: LoraStacker } WEB_DIRECTORY = "./web/comfyui" diff --git a/py/nodes/lora_loader.py b/py/nodes/lora_loader.py index 295ca6d3..9e3398f2 100644 --- a/py/nodes/lora_loader.py +++ b/py/nodes/lora_loader.py @@ -8,7 +8,7 @@ from .utils import FlexibleOptionalInputType, any_type class LoraManagerLoader: NAME = "Lora Loader (LoraManager)" - CATEGORY = "loaders" + CATEGORY = "Lora Manager/loaders" @classmethod def INPUT_TYPES(cls): @@ -23,7 +23,10 @@ class LoraManagerLoader: "placeholder": "LoRA syntax input: " }), }, - "optional": FlexibleOptionalInputType(any_type), + "optional": { + **FlexibleOptionalInputType(any_type), + "lora_stack": ("LORA_STACK", {"default": None}), + } } RETURN_TYPES = ("MODEL", "CLIP", IO.STRING) @@ -49,11 +52,32 @@ class LoraManagerLoader: return relative_path, trigger_words return lora_name, [] # Fallback if not found - def load_loras(self, model, clip, text, **kwargs): - """Loads multiple LoRAs based on the kwargs input.""" + def extract_lora_name(self, lora_path): + """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" + # Get the basename without extension + basename = os.path.basename(lora_path) + return os.path.splitext(basename)[0] + + def load_loras(self, model, clip, text, lora_stack=None, **kwargs): + print("load_loras kwargs: ", kwargs) + """Loads multiple LoRAs based on the kwargs input and lora_stack.""" loaded_loras = [] all_trigger_words = [] + # First process lora_stack if available + if lora_stack: + for lora_path, model_strength, clip_strength in lora_stack: + # Apply the LoRA using the provided path and strengths + model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength) + + # Extract lora name for trigger words lookup + lora_name = self.extract_lora_name(lora_path) + _, trigger_words = asyncio.run(self.get_lora_info(lora_name)) + + all_trigger_words.extend(trigger_words) + loaded_loras.append(f"{lora_name}: {model_strength}") + + # Then process loras from kwargs if 'loras' in kwargs: for lora in kwargs['loras']: if not lora.get('active', False): diff --git a/py/nodes/lora_stacker.py b/py/nodes/lora_stacker.py new file mode 100644 index 00000000..20c69dbc --- /dev/null +++ b/py/nodes/lora_stacker.py @@ -0,0 +1,94 @@ +from comfy.comfy_types import IO # type: ignore +from ..services.lora_scanner import LoraScanner +from ..config import config +import asyncio +import os +from .utils import FlexibleOptionalInputType, any_type + +class LoraStacker: + NAME = "Lora Stacker (LoraManager)" + CATEGORY = "Lora Manager/stackers" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "text": (IO.STRING, { + "multiline": True, + "dynamicPrompts": True, + "tooltip": "Format: separated by spaces or punctuation", + "placeholder": "LoRA syntax input: " + }), + }, + "optional": { + **FlexibleOptionalInputType(any_type), + "lora_stack": ("LORA_STACK", {"default": None}), + } + } + + RETURN_TYPES = ("LORA_STACK", IO.STRING) + RETURN_NAMES = ("LORA_STACK", "trigger_words") + FUNCTION = "stack_loras" + + async def get_lora_info(self, lora_name): + """Get the lora path and trigger words from cache""" + scanner = await LoraScanner.get_instance() + cache = await scanner.get_cached_data() + + for item in cache.raw_data: + if item.get('file_name') == lora_name: + file_path = item.get('file_path') + if file_path: + for root in config.loras_roots: + root = root.replace(os.sep, '/') + if file_path.startswith(root): + relative_path = os.path.relpath(file_path, root).replace(os.sep, '/') + # Get trigger words from civitai metadata + civitai = item.get('civitai', {}) + trigger_words = civitai.get('trainedWords', []) if civitai else [] + return relative_path, trigger_words + return lora_name, [] # Fallback if not found + + def extract_lora_name(self, lora_path): + """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" + # Get the basename without extension + basename = os.path.basename(lora_path) + return os.path.splitext(basename)[0] + + def stack_loras(self, text, lora_stack=None, **kwargs): + print("stack_loras kwargs: ", kwargs) + """Stacks multiple LoRAs based on the kwargs input without loading them.""" + stack = [] + all_trigger_words = [] + + # Process existing lora_stack if available + if lora_stack: + stack.extend(lora_stack) + # Get trigger words from existing stack entries + for lora_path, _, _ in lora_stack: + lora_name = self.extract_lora_name(lora_path) + _, trigger_words = asyncio.run(self.get_lora_info(lora_name)) + all_trigger_words.extend(trigger_words) + + if 'loras' in kwargs: + for lora in kwargs['loras']: + if not lora.get('active', False): + continue + + lora_name = lora['name'] + model_strength = float(lora['strength']) + clip_strength = model_strength # Using same strength for both as in the original loader + + # Get lora path and trigger words + lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name)) + + # Add to stack without loading + stack.append((lora_path, model_strength, clip_strength)) + + # Add trigger words to collection + all_trigger_words.extend(trigger_words) + + # use ',, ' to separate trigger words for group mode + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + + return (stack, trigger_words_text) diff --git a/py/nodes/trigger_word_toggle.py b/py/nodes/trigger_word_toggle.py index 3e934e08..9e353043 100644 --- a/py/nodes/trigger_word_toggle.py +++ b/py/nodes/trigger_word_toggle.py @@ -4,7 +4,7 @@ from .utils import FlexibleOptionalInputType, any_type class TriggerWordToggle: NAME = "TriggerWord Toggle (LoraManager)" - CATEGORY = "lora manager" + CATEGORY = "Lora Manager/utils" DESCRIPTION = "Toggle trigger words on/off" @classmethod @@ -13,10 +13,7 @@ class TriggerWordToggle: "required": { "group_mode": ("BOOLEAN", {"default": True}), }, - "optional": { - **FlexibleOptionalInputType(any_type), - "trigger_words": ("STRING", {"default": "", "defaultInput": True}), - }, + "optional": FlexibleOptionalInputType(any_type), "hidden": { "id": "UNIQUE_ID", # 会被 ComfyUI 自动替换为唯一ID }, @@ -26,7 +23,9 @@ class TriggerWordToggle: RETURN_NAMES = ("filtered_trigger_words",) FUNCTION = "process_trigger_words" - def process_trigger_words(self, id, trigger_words="", **kwargs): + def process_trigger_words(self, id, **kwargs): + print("trigger_words ", kwargs) + trigger_words = kwargs.get("trigger_words", "") # Send trigger words to frontend PromptServer.instance.send_sync("trigger_word_update", { "id": id, diff --git a/web/comfyui/lora_stacker.js b/web/comfyui/lora_stacker.js new file mode 100644 index 00000000..39a199ea --- /dev/null +++ b/web/comfyui/lora_stacker.js @@ -0,0 +1,112 @@ +import { app } from "../../scripts/app.js"; +import { addLorasWidget } from "./loras_widget.js"; + +// Extract pattern into a constant for consistent use +const LORA_PATTERN = //g; + +function mergeLoras(lorasText, lorasArr) { + const result = []; + let match; + + // Parse text input and create initial entries + while ((match = LORA_PATTERN.exec(lorasText)) !== null) { + const name = match[1]; + const inputStrength = Number(match[2]); + + // Find if this lora exists in the array data + const existingLora = lorasArr.find(l => l.name === name); + + result.push({ + name: name, + // Use existing strength if available, otherwise use input strength + strength: existingLora ? existingLora.strength : inputStrength, + active: existingLora ? existingLora.active : true + }); + } + + return result; +} + +app.registerExtension({ + name: "LoraManager.LoraStacker", + + async nodeCreated(node) { + if (node.comfyClass === "Lora Stacker (LoraManager)") { + // Enable widget serialization + node.serialize_widgets = true; + + // Wait for node to be properly initialized + requestAnimationFrame(() => { + // Restore saved value if exists + let existingLoras = []; + if (node.widgets_values && node.widgets_values.length > 0) { + const savedValue = node.widgets_values[1]; + // TODO: clean up this code + try { + // Check if the value is already an array/object + if (typeof savedValue === 'object' && savedValue !== null) { + existingLoras = savedValue; + } else if (typeof savedValue === 'string') { + existingLoras = JSON.parse(savedValue); + } + } catch (e) { + console.warn("Failed to parse loras data:", e); + existingLoras = []; + } + } + // Merge the loras data + const mergedLoras = mergeLoras(node.widgets[0].value, existingLoras); + + // Add flag to prevent callback loops + let isUpdating = false; + + // Get the widget object directly from the returned object + const result = addLorasWidget(node, "loras", { + defaultVal: mergedLoras // Pass object directly + }, (value) => { + // Prevent recursive calls + if (isUpdating) return; + isUpdating = true; + + try { + // Remove loras that are not in the value array + const inputWidget = node.widgets[0]; + const currentLoras = value.map(l => l.name); + + // Use the constant pattern here as well + let newText = inputWidget.value.replace(LORA_PATTERN, (match, name, strength) => { + return currentLoras.includes(name) ? match : ''; + }); + + // Clean up multiple spaces and trim + newText = newText.replace(/\s+/g, ' ').trim(); + + inputWidget.value = newText; + } finally { + isUpdating = false; + } + }); + + node.lorasWidget = result.widget; + + // Update input widget callback + const inputWidget = node.widgets[0]; + inputWidget.callback = (value) => { + if (isUpdating) return; + isUpdating = true; + + try { + const currentLoras = node.lorasWidget.value || []; + const mergedLoras = mergeLoras(value, currentLoras); + + node.lorasWidget.value = mergedLoras; + } finally { + isUpdating = false; + } + }; + }); + + console.log("Lora Stacker node created:", node); + } + }, +}); \ No newline at end of file diff --git a/web/comfyui/loras_widget.js b/web/comfyui/loras_widget.js index 7d544c0b..63fb609b 100644 --- a/web/comfyui/loras_widget.js +++ b/web/comfyui/loras_widget.js @@ -750,6 +750,7 @@ export function addLorasWidget(node, name, opts, callback) { widget.callback = callback; widget.serializeValue = () => { + console.log("Serializing loras data: ", widgetValue); // Add dummy items to avoid the 2-element serialization issue, a bug in comfyui return [...widgetValue, { name: "__dummy_item1__", strength: 0, active: false, _isDummy: true }, diff --git a/web/comfyui/trigger_word_toggle.js b/web/comfyui/trigger_word_toggle.js index 5687c238..3e24c313 100644 --- a/web/comfyui/trigger_word_toggle.js +++ b/web/comfyui/trigger_word_toggle.js @@ -22,7 +22,12 @@ app.registerExtension({ node.serialize_widgets = true; // Wait for node to be properly initialized - requestAnimationFrame(() => { + requestAnimationFrame(() => { + node.addInput("trigger_words", 'string', { + "default": "", + "defaultInput": false, // Changed to make it optional + "optional": true // Marking the input as optional + }); // Get the widget object directly from the returned object const result = addTagsWidget(node, "toggle_trigger_words", { defaultVal: [] @@ -39,11 +44,11 @@ app.registerExtension({ // Restore saved value if exists if (node.widgets_values && node.widgets_values.length > 0) { // 0 is group mode, 1 is input, 2 is tag widget, 3 is original message - const savedValue = node.widgets_values[2]; + const savedValue = node.widgets_values[1]; if (savedValue) { result.widget.value = savedValue; } - const originalMessage = node.widgets_values[3]; + const originalMessage = node.widgets_values[2]; if (originalMessage) { hiddenWidget.value = originalMessage; } @@ -51,10 +56,12 @@ app.registerExtension({ const groupModeWidget = node.widgets[0]; groupModeWidget.callback = (value) => { - if (node.widgets[3].value) { - this.updateTagsBasedOnMode(node, node.widgets[3].value, value); + if (node.widgets[2].value) { + this.updateTagsBasedOnMode(node, node.widgets[2].value, value); } } + + console.log("node ", node); }); } }, @@ -68,7 +75,7 @@ app.registerExtension({ } // Store the original message for mode switching - node.widgets[3].value = message; + node.widgets[2].value = message; if (node.tagWidget) { // Parse tags based on current group mode