checkpoint

This commit is contained in:
Will Miao
2025-03-11 19:29:31 +08:00
parent 5a6c412845
commit ad56cafd62
7 changed files with 255 additions and 16 deletions

View File

@@ -1,10 +1,12 @@
from .py.lora_manager import LoraManager from .py.lora_manager import LoraManager
from .py.nodes.lora_loader import LoraManagerLoader from .py.nodes.lora_loader import LoraManagerLoader
from .py.nodes.trigger_word_toggle import TriggerWordToggle from .py.nodes.trigger_word_toggle import TriggerWordToggle
from .py.nodes.lora_stacker import LoraStacker
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
LoraManagerLoader.NAME: LoraManagerLoader, LoraManagerLoader.NAME: LoraManagerLoader,
TriggerWordToggle.NAME: TriggerWordToggle TriggerWordToggle.NAME: TriggerWordToggle
# LoraStacker.NAME: LoraStacker
} }
WEB_DIRECTORY = "./web/comfyui" WEB_DIRECTORY = "./web/comfyui"

View File

@@ -8,7 +8,7 @@ from .utils import FlexibleOptionalInputType, any_type
class LoraManagerLoader: class LoraManagerLoader:
NAME = "Lora Loader (LoraManager)" NAME = "Lora Loader (LoraManager)"
CATEGORY = "loaders" CATEGORY = "Lora Manager/loaders"
@classmethod @classmethod
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
@@ -23,7 +23,10 @@ class LoraManagerLoader:
"placeholder": "LoRA syntax input: <lora:name:strength>" "placeholder": "LoRA syntax input: <lora:name:strength>"
}), }),
}, },
"optional": FlexibleOptionalInputType(any_type), "optional": {
**FlexibleOptionalInputType(any_type),
"lora_stack": ("LORA_STACK", {"default": None}),
}
} }
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING) RETURN_TYPES = ("MODEL", "CLIP", IO.STRING)
@@ -49,11 +52,32 @@ class LoraManagerLoader:
return relative_path, trigger_words return relative_path, trigger_words
return lora_name, [] # Fallback if not found return lora_name, [] # Fallback if not found
def load_loras(self, model, clip, text, **kwargs): def extract_lora_name(self, lora_path):
"""Loads multiple LoRAs based on the kwargs input.""" """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def load_loras(self, model, clip, text, lora_stack=None, **kwargs):
print("load_loras kwargs: ", kwargs)
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
loaded_loras = [] loaded_loras = []
all_trigger_words = [] all_trigger_words = []
# First process lora_stack if available
if lora_stack:
for lora_path, model_strength, clip_strength in lora_stack:
# Apply the LoRA using the provided path and strengths
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Extract lora name for trigger words lookup
lora_name = self.extract_lora_name(lora_path)
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
all_trigger_words.extend(trigger_words)
loaded_loras.append(f"{lora_name}: {model_strength}")
# Then process loras from kwargs
if 'loras' in kwargs: if 'loras' in kwargs:
for lora in kwargs['loras']: for lora in kwargs['loras']:
if not lora.get('active', False): if not lora.get('active', False):

94
py/nodes/lora_stacker.py Normal file
View File

@@ -0,0 +1,94 @@
from comfy.comfy_types import IO # type: ignore
from ..services.lora_scanner import LoraScanner
from ..config import config
import asyncio
import os
from .utils import FlexibleOptionalInputType, any_type
class LoraStacker:
NAME = "Lora Stacker (LoraManager)"
CATEGORY = "Lora Manager/stackers"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"text": (IO.STRING, {
"multiline": True,
"dynamicPrompts": True,
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
"placeholder": "LoRA syntax input: <lora:name:strength>"
}),
},
"optional": {
**FlexibleOptionalInputType(any_type),
"lora_stack": ("LORA_STACK", {"default": None}),
}
}
RETURN_TYPES = ("LORA_STACK", IO.STRING)
RETURN_NAMES = ("LORA_STACK", "trigger_words")
FUNCTION = "stack_loras"
async def get_lora_info(self, lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(self, lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def stack_loras(self, text, lora_stack=None, **kwargs):
print("stack_loras kwargs: ", kwargs)
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
stack = []
all_trigger_words = []
# Process existing lora_stack if available
if lora_stack:
stack.extend(lora_stack)
# Get trigger words from existing stack entries
for lora_path, _, _ in lora_stack:
lora_name = self.extract_lora_name(lora_path)
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
all_trigger_words.extend(trigger_words)
if 'loras' in kwargs:
for lora in kwargs['loras']:
if not lora.get('active', False):
continue
lora_name = lora['name']
model_strength = float(lora['strength'])
clip_strength = model_strength # Using same strength for both as in the original loader
# Get lora path and trigger words
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
# Add to stack without loading
stack.append((lora_path, model_strength, clip_strength))
# Add trigger words to collection
all_trigger_words.extend(trigger_words)
# use ',, ' to separate trigger words for group mode
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
return (stack, trigger_words_text)

View File

@@ -4,7 +4,7 @@ from .utils import FlexibleOptionalInputType, any_type
class TriggerWordToggle: class TriggerWordToggle:
NAME = "TriggerWord Toggle (LoraManager)" NAME = "TriggerWord Toggle (LoraManager)"
CATEGORY = "lora manager" CATEGORY = "Lora Manager/utils"
DESCRIPTION = "Toggle trigger words on/off" DESCRIPTION = "Toggle trigger words on/off"
@classmethod @classmethod
@@ -13,10 +13,7 @@ class TriggerWordToggle:
"required": { "required": {
"group_mode": ("BOOLEAN", {"default": True}), "group_mode": ("BOOLEAN", {"default": True}),
}, },
"optional": { "optional": FlexibleOptionalInputType(any_type),
**FlexibleOptionalInputType(any_type),
"trigger_words": ("STRING", {"default": "", "defaultInput": True}),
},
"hidden": { "hidden": {
"id": "UNIQUE_ID", # 会被 ComfyUI 自动替换为唯一ID "id": "UNIQUE_ID", # 会被 ComfyUI 自动替换为唯一ID
}, },
@@ -26,7 +23,9 @@ class TriggerWordToggle:
RETURN_NAMES = ("filtered_trigger_words",) RETURN_NAMES = ("filtered_trigger_words",)
FUNCTION = "process_trigger_words" FUNCTION = "process_trigger_words"
def process_trigger_words(self, id, trigger_words="", **kwargs): def process_trigger_words(self, id, **kwargs):
print("trigger_words ", kwargs)
trigger_words = kwargs.get("trigger_words", "")
# Send trigger words to frontend # Send trigger words to frontend
PromptServer.instance.send_sync("trigger_word_update", { PromptServer.instance.send_sync("trigger_word_update", {
"id": id, "id": id,

112
web/comfyui/lora_stacker.js Normal file
View File

@@ -0,0 +1,112 @@
import { app } from "../../scripts/app.js";
import { addLorasWidget } from "./loras_widget.js";
// Extract pattern into a constant for consistent use
const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)>/g;
function mergeLoras(lorasText, lorasArr) {
const result = [];
let match;
// Parse text input and create initial entries
while ((match = LORA_PATTERN.exec(lorasText)) !== null) {
const name = match[1];
const inputStrength = Number(match[2]);
// Find if this lora exists in the array data
const existingLora = lorasArr.find(l => l.name === name);
result.push({
name: name,
// Use existing strength if available, otherwise use input strength
strength: existingLora ? existingLora.strength : inputStrength,
active: existingLora ? existingLora.active : true
});
}
return result;
}
app.registerExtension({
name: "LoraManager.LoraStacker",
async nodeCreated(node) {
if (node.comfyClass === "Lora Stacker (LoraManager)") {
// Enable widget serialization
node.serialize_widgets = true;
// Wait for node to be properly initialized
requestAnimationFrame(() => {
// Restore saved value if exists
let existingLoras = [];
if (node.widgets_values && node.widgets_values.length > 0) {
const savedValue = node.widgets_values[1];
// TODO: clean up this code
try {
// Check if the value is already an array/object
if (typeof savedValue === 'object' && savedValue !== null) {
existingLoras = savedValue;
} else if (typeof savedValue === 'string') {
existingLoras = JSON.parse(savedValue);
}
} catch (e) {
console.warn("Failed to parse loras data:", e);
existingLoras = [];
}
}
// Merge the loras data
const mergedLoras = mergeLoras(node.widgets[0].value, existingLoras);
// Add flag to prevent callback loops
let isUpdating = false;
// Get the widget object directly from the returned object
const result = addLorasWidget(node, "loras", {
defaultVal: mergedLoras // Pass object directly
}, (value) => {
// Prevent recursive calls
if (isUpdating) return;
isUpdating = true;
try {
// Remove loras that are not in the value array
const inputWidget = node.widgets[0];
const currentLoras = value.map(l => l.name);
// Use the constant pattern here as well
let newText = inputWidget.value.replace(LORA_PATTERN, (match, name, strength) => {
return currentLoras.includes(name) ? match : '';
});
// Clean up multiple spaces and trim
newText = newText.replace(/\s+/g, ' ').trim();
inputWidget.value = newText;
} finally {
isUpdating = false;
}
});
node.lorasWidget = result.widget;
// Update input widget callback
const inputWidget = node.widgets[0];
inputWidget.callback = (value) => {
if (isUpdating) return;
isUpdating = true;
try {
const currentLoras = node.lorasWidget.value || [];
const mergedLoras = mergeLoras(value, currentLoras);
node.lorasWidget.value = mergedLoras;
} finally {
isUpdating = false;
}
};
});
console.log("Lora Stacker node created:", node);
}
},
});

View File

@@ -750,6 +750,7 @@ export function addLorasWidget(node, name, opts, callback) {
widget.callback = callback; widget.callback = callback;
widget.serializeValue = () => { widget.serializeValue = () => {
console.log("Serializing loras data: ", widgetValue);
// Add dummy items to avoid the 2-element serialization issue, a bug in comfyui // Add dummy items to avoid the 2-element serialization issue, a bug in comfyui
return [...widgetValue, return [...widgetValue,
{ name: "__dummy_item1__", strength: 0, active: false, _isDummy: true }, { name: "__dummy_item1__", strength: 0, active: false, _isDummy: true },

View File

@@ -22,7 +22,12 @@ app.registerExtension({
node.serialize_widgets = true; node.serialize_widgets = true;
// Wait for node to be properly initialized // Wait for node to be properly initialized
requestAnimationFrame(() => { requestAnimationFrame(() => {
node.addInput("trigger_words", 'string', {
"default": "",
"defaultInput": false, // Changed to make it optional
"optional": true // Marking the input as optional
});
// Get the widget object directly from the returned object // Get the widget object directly from the returned object
const result = addTagsWidget(node, "toggle_trigger_words", { const result = addTagsWidget(node, "toggle_trigger_words", {
defaultVal: [] defaultVal: []
@@ -39,11 +44,11 @@ app.registerExtension({
// Restore saved value if exists // Restore saved value if exists
if (node.widgets_values && node.widgets_values.length > 0) { if (node.widgets_values && node.widgets_values.length > 0) {
// 0 is group mode, 1 is input, 2 is tag widget, 3 is original message // 0 is group mode, 1 is input, 2 is tag widget, 3 is original message
const savedValue = node.widgets_values[2]; const savedValue = node.widgets_values[1];
if (savedValue) { if (savedValue) {
result.widget.value = savedValue; result.widget.value = savedValue;
} }
const originalMessage = node.widgets_values[3]; const originalMessage = node.widgets_values[2];
if (originalMessage) { if (originalMessage) {
hiddenWidget.value = originalMessage; hiddenWidget.value = originalMessage;
} }
@@ -51,10 +56,12 @@ app.registerExtension({
const groupModeWidget = node.widgets[0]; const groupModeWidget = node.widgets[0];
groupModeWidget.callback = (value) => { groupModeWidget.callback = (value) => {
if (node.widgets[3].value) { if (node.widgets[2].value) {
this.updateTagsBasedOnMode(node, node.widgets[3].value, value); this.updateTagsBasedOnMode(node, node.widgets[2].value, value);
} }
} }
console.log("node ", node);
}); });
} }
}, },
@@ -68,7 +75,7 @@ app.registerExtension({
} }
// Store the original message for mode switching // Store the original message for mode switching
node.widgets[3].value = message; node.widgets[2].value = message;
if (node.tagWidget) { if (node.tagWidget) {
// Parse tags based on current group mode // Parse tags based on current group mode