fix(trigger-words): propagate LORA_STACK updates through combiners (#881)

This commit is contained in:
Will Miao
2026-04-03 15:01:02 +08:00
parent 30db8c3d1d
commit 4f599aeced
6 changed files with 278 additions and 36 deletions

View File

@@ -0,0 +1,151 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
const { APP_MODULE, UTILS_MODULE } = vi.hoisted(() => ({
APP_MODULE: new URL("../../../scripts/app.js", import.meta.url).pathname,
UTILS_MODULE: new URL("../../../web/comfyui/utils.js", import.meta.url).pathname,
}));
vi.mock(APP_MODULE, () => ({
app: {
graph: null,
registerExtension: vi.fn(),
ui: {
settings: {
getSettingValue: vi.fn(),
},
},
},
}));
describe("LoRA chain traversal", () => {
let collectActiveLorasFromChain;
beforeEach(async () => {
vi.resetModules();
({ collectActiveLorasFromChain } = await import(UTILS_MODULE));
});
function createGraph(nodes, links) {
const graph = {
_nodes: nodes,
links,
getNodeById(id) {
return nodes.find((node) => node.id === id) ?? null;
},
};
nodes.forEach((node) => {
node.graph = graph;
});
return graph;
}
it("aggregates active LoRAs through a combiner with multiple LORA_STACK inputs", () => {
const randomizerA = {
id: 1,
comfyClass: "Lora Randomizer (LoraManager)",
mode: 0,
widgets: [
{
name: "loras",
value: [
{ name: "Alpha", active: true },
{ name: "Ignored", active: false },
],
},
],
inputs: [],
outputs: [],
};
const randomizerB = {
id: 2,
comfyClass: "Lora Randomizer (LoraManager)",
mode: 0,
widgets: [
{
name: "loras",
value: [{ name: "Beta", active: true }],
},
],
inputs: [],
outputs: [],
};
const combiner = {
id: 3,
comfyClass: "Lora Stack Combiner (LoraManager)",
mode: 0,
widgets: [],
inputs: [
{ name: "lora_stack_a", type: "LORA_STACK", link: 11 },
{ name: "lora_stack_b", type: "LORA_STACK", link: 12 },
],
outputs: [],
};
const loader = {
id: 4,
comfyClass: "Lora Loader (LoraManager)",
mode: 0,
widgets: [],
inputs: [{ name: "lora_stack", type: "LORA_STACK", link: 13 }],
outputs: [],
};
createGraph(
[randomizerA, randomizerB, combiner, loader],
{
11: { origin_id: 1, target_id: 3 },
12: { origin_id: 2, target_id: 3 },
13: { origin_id: 3, target_id: 4 },
}
);
const result = collectActiveLorasFromChain(loader);
expect([...result]).toEqual(["Alpha", "Beta"]);
});
it("stops propagation when the combiner is inactive", () => {
const randomizer = {
id: 1,
comfyClass: "Lora Randomizer (LoraManager)",
mode: 0,
widgets: [
{
name: "loras",
value: [{ name: "Alpha", active: true }],
},
],
inputs: [],
outputs: [],
};
const combiner = {
id: 2,
comfyClass: "Lora Stack Combiner (LoraManager)",
mode: 2,
widgets: [],
inputs: [{ name: "lora_stack_a", type: "LORA_STACK", link: 21 }],
outputs: [],
};
const loader = {
id: 3,
comfyClass: "Lora Loader (LoraManager)",
mode: 0,
widgets: [],
inputs: [{ name: "lora_stack", type: "LORA_STACK", link: 22 }],
outputs: [],
};
createGraph(
[randomizer, combiner, loader],
{
21: { origin_id: 1, target_id: 2 },
22: { origin_id: 2, target_id: 3 },
}
);
const result = collectActiveLorasFromChain(loader);
expect(result.size).toBe(0);
});
});

View File

@@ -9,7 +9,7 @@ import type { LoraPoolConfig, RandomizerConfig, CyclerConfig } from './composabl
import {
setupModeChangeHandler,
createModeChangeCallback,
LORA_PROVIDER_NODE_TYPES
LORA_CHAIN_NODE_TYPES
} from './mode-change-handler'
const LORA_POOL_WIDGET_MIN_WIDTH = 500
@@ -755,8 +755,8 @@ app.registerExtension({
}
}
// Register mode change handlers for LoRA provider nodes
if (LORA_PROVIDER_NODE_TYPES.includes(comfyClass)) {
// Register mode change handlers for LORA_STACK chain nodes
if (LORA_CHAIN_NODE_TYPES.includes(comfyClass)) {
const originalOnNodeCreated = nodeType.prototype.onNodeCreated
nodeType.prototype.onNodeCreated = function () {

View File

@@ -18,7 +18,22 @@ export const LORA_PROVIDER_NODE_TYPES = [
"Lora Cycler (LoraManager)",
] as const;
/**
* Nodes that do not own LoRA state themselves, but merge or forward LORA_STACK
* inputs so downstream trigger-word updates must traverse through them.
*/
export const LORA_STACK_AGGREGATOR_NODE_TYPES = [
"Lora Stack Combiner (LoraManager)",
] as const;
export const LORA_CHAIN_NODE_TYPES = [
...LORA_PROVIDER_NODE_TYPES,
...LORA_STACK_AGGREGATOR_NODE_TYPES,
] as const;
export type LoraProviderNodeType = typeof LORA_PROVIDER_NODE_TYPES[number];
export type LoraStackAggregatorNodeType = typeof LORA_STACK_AGGREGATOR_NODE_TYPES[number];
export type LoraChainNodeType = typeof LORA_CHAIN_NODE_TYPES[number];
/**
* Check if a node class is a LoRA provider node.
@@ -27,6 +42,16 @@ export function isLoraProviderNode(comfyClass: string): comfyClass is LoraProvid
return LORA_PROVIDER_NODE_TYPES.includes(comfyClass as LoraProviderNodeType);
}
export function isLoraStackAggregatorNode(
comfyClass: string
): comfyClass is LoraStackAggregatorNodeType {
return LORA_STACK_AGGREGATOR_NODE_TYPES.includes(comfyClass as LoraStackAggregatorNodeType);
}
export function isLoraChainNode(comfyClass: string): comfyClass is LoraChainNodeType {
return LORA_CHAIN_NODE_TYPES.includes(comfyClass as LoraChainNodeType);
}
/**
* Extract active LoRA filenames from a node based on its type.
*
@@ -40,6 +65,10 @@ export function getActiveLorasFromNodeByType(node: any): Set<string> {
return extractFromCyclerConfig(node);
}
if (isLoraStackAggregatorNode(comfyClass)) {
return new Set<string>();
}
// Default: use lorasWidget (works for Stacker and Randomizer)
return extractFromLorasWidget(node);
}

View File

@@ -10,10 +10,27 @@ export const LORA_PROVIDER_NODE_TYPES = [
"Lora Cycler (LoraManager)",
];
export const LORA_STACK_AGGREGATOR_NODE_TYPES = [
"Lora Stack Combiner (LoraManager)",
];
export const LORA_CHAIN_NODE_TYPES = [
...LORA_PROVIDER_NODE_TYPES,
...LORA_STACK_AGGREGATOR_NODE_TYPES,
];
export function isLoraProviderNode(comfyClass) {
return LORA_PROVIDER_NODE_TYPES.includes(comfyClass);
}
export function isLoraStackAggregatorNode(comfyClass) {
return LORA_STACK_AGGREGATOR_NODE_TYPES.includes(comfyClass);
}
export function isLoraChainNode(comfyClass) {
return LORA_CHAIN_NODE_TYPES.includes(comfyClass);
}
function isMapLike(collection) {
return collection && typeof collection.entries === "function" && typeof collection.values === "function";
}
@@ -245,16 +262,20 @@ export function hideWidgetForGood(node, widget, suffix = "") {
// Update pattern to match both formats: <lora:name:model_strength> or <lora:name:model_strength:clip_strength>
export const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)(?::([-\d\.]+))?>/g;
// Get connected Lora Stacker nodes that feed into the current node
export function getConnectedInputStackers(node) {
const connectedStackers = [];
function isLoraStackInput(input) {
return input?.type === "LORA_STACK";
}
// Get connected LORA_STACK chain nodes that feed into the current node
export function getConnectedInputLoraChainNodes(node) {
const connectedNodes = [];
if (!node?.inputs) {
return connectedStackers;
return connectedNodes;
}
for (const input of node.inputs) {
if (input.name !== "lora_stack" || !input.link) {
if (!isLoraStackInput(input) || !input.link) {
continue;
}
@@ -264,12 +285,12 @@ export function getConnectedInputStackers(node) {
}
const sourceNode = node.graph?.getNodeById?.(link.origin_id);
if (sourceNode && isLoraProviderNode(sourceNode.comfyClass)) {
connectedStackers.push(sourceNode);
if (sourceNode && isLoraChainNode(sourceNode.comfyClass)) {
connectedNodes.push(sourceNode);
}
}
return connectedStackers;
return connectedNodes;
}
// Get connected TriggerWord Toggle nodes that receive output from the current node
@@ -314,6 +335,11 @@ export function getActiveLorasFromNode(node) {
return activeLoraNames;
}
// Aggregator nodes do not own LoRA state directly; they only forward upstream stacks.
if (isLoraStackAggregatorNode(node.comfyClass)) {
return activeLoraNames;
}
// Handle Lora Stacker and Lora Randomizer (lorasWidget)
let lorasWidget = node.lorasWidget;
if (!lorasWidget && node.widgets) {
@@ -348,14 +374,18 @@ export function collectActiveLorasFromChain(node, visited = new Set()) {
// Mode 2 is Never, Mode 4 is Bypass
const isNodeActive = node.mode === undefined || node.mode === 0 || node.mode === 3;
// Get active loras from current node only if node is active
const allActiveLoraNames = isNodeActive ? getActiveLorasFromNode(node) : new Set();
if (!isNodeActive) {
return new Set();
}
// Get connected input stackers and collect their active loras
const inputStackers = getConnectedInputStackers(node);
for (const stacker of inputStackers) {
const stackerLoras = collectActiveLorasFromChain(stacker, visited);
stackerLoras.forEach(name => allActiveLoraNames.add(name));
// Get active loras from current node only if node is active
const allActiveLoraNames = getActiveLorasFromNode(node);
// Get connected input LORA_STACK chain nodes and collect their active loras
const inputChainNodes = getConnectedInputLoraChainNodes(node);
for (const chainNode of inputChainNodes) {
const upstreamLoras = collectActiveLorasFromChain(chainNode, visited);
upstreamLoras.forEach(name => allActiveLoraNames.add(name));
}
return allActiveLoraNames;
@@ -819,8 +849,8 @@ export function updateDownstreamLoaders(startNode, visited = new Set()) {
collectActiveLorasFromChain(targetNode);
updateConnectedTriggerWords(targetNode, allActiveLoraNames);
}
// If target is another LoRA provider node, recursively check its outputs
else if (targetNode && isLoraProviderNode(targetNode.comfyClass)) {
// If target is another LORA_STACK chain node, recursively check its outputs
else if (targetNode && isLoraChainNode(targetNode.comfyClass)) {
updateDownstreamLoaders(targetNode, visited);
}
}

View File

@@ -14938,11 +14938,24 @@ const LORA_PROVIDER_NODE_TYPES$1 = [
"Lora Randomizer (LoraManager)",
"Lora Cycler (LoraManager)"
];
const LORA_STACK_AGGREGATOR_NODE_TYPES$1 = [
"Lora Stack Combiner (LoraManager)"
];
const LORA_CHAIN_NODE_TYPES$1 = [
...LORA_PROVIDER_NODE_TYPES$1,
...LORA_STACK_AGGREGATOR_NODE_TYPES$1
];
function isLoraStackAggregatorNode$1(comfyClass) {
return LORA_STACK_AGGREGATOR_NODE_TYPES$1.includes(comfyClass);
}
function getActiveLorasFromNodeByType(node) {
const comfyClass = node == null ? void 0 : node.comfyClass;
if (comfyClass === "Lora Cycler (LoraManager)") {
return extractFromCyclerConfig(node);
}
if (isLoraStackAggregatorNode$1(comfyClass)) {
return /* @__PURE__ */ new Set();
}
return extractFromLorasWidget(node);
}
function extractFromLorasWidget(node) {
@@ -15002,8 +15015,18 @@ const LORA_PROVIDER_NODE_TYPES = [
"Lora Randomizer (LoraManager)",
"Lora Cycler (LoraManager)"
];
function isLoraProviderNode(comfyClass) {
return LORA_PROVIDER_NODE_TYPES.includes(comfyClass);
const LORA_STACK_AGGREGATOR_NODE_TYPES = [
"Lora Stack Combiner (LoraManager)"
];
const LORA_CHAIN_NODE_TYPES = [
...LORA_PROVIDER_NODE_TYPES,
...LORA_STACK_AGGREGATOR_NODE_TYPES
];
function isLoraStackAggregatorNode(comfyClass) {
return LORA_STACK_AGGREGATOR_NODE_TYPES.includes(comfyClass);
}
function isLoraChainNode(comfyClass) {
return LORA_CHAIN_NODE_TYPES.includes(comfyClass);
}
function isMapLike(collection) {
return collection && typeof collection.entries === "function" && typeof collection.values === "function";
@@ -15041,14 +15064,17 @@ function getLinkFromGraph(graph, linkId) {
}
return graph.links[linkId] || null;
}
function getConnectedInputStackers(node) {
function isLoraStackInput(input) {
return (input == null ? void 0 : input.type) === "LORA_STACK";
}
function getConnectedInputLoraChainNodes(node) {
var _a2, _b;
const connectedStackers = [];
const connectedNodes = [];
if (!(node == null ? void 0 : node.inputs)) {
return connectedStackers;
return connectedNodes;
}
for (const input of node.inputs) {
if (input.name !== "lora_stack" || !input.link) {
if (!isLoraStackInput(input) || !input.link) {
continue;
}
const link = getLinkFromGraph(node.graph, input.link);
@@ -15056,11 +15082,11 @@ function getConnectedInputStackers(node) {
continue;
}
const sourceNode = (_b = (_a2 = node.graph) == null ? void 0 : _a2.getNodeById) == null ? void 0 : _b.call(_a2, link.origin_id);
if (sourceNode && isLoraProviderNode(sourceNode.comfyClass)) {
connectedStackers.push(sourceNode);
if (sourceNode && isLoraChainNode(sourceNode.comfyClass)) {
connectedNodes.push(sourceNode);
}
}
return connectedStackers;
return connectedNodes;
}
function getConnectedTriggerToggleNodes(node) {
var _a2, _b, _c;
@@ -15095,6 +15121,9 @@ function getActiveLorasFromNode(node) {
}
return activeLoraNames;
}
if (isLoraStackAggregatorNode(node.comfyClass)) {
return activeLoraNames;
}
let lorasWidget = node.lorasWidget;
if (!lorasWidget && node.widgets) {
lorasWidget = node.widgets.find((w2) => w2.name === "loras");
@@ -15118,11 +15147,14 @@ function collectActiveLorasFromChain(node, visited = /* @__PURE__ */ new Set())
}
visited.add(nodeKey);
const isNodeActive2 = node.mode === void 0 || node.mode === 0 || node.mode === 3;
const allActiveLoraNames = isNodeActive2 ? getActiveLorasFromNode(node) : /* @__PURE__ */ new Set();
const inputStackers = getConnectedInputStackers(node);
for (const stacker of inputStackers) {
const stackerLoras = collectActiveLorasFromChain(stacker, visited);
stackerLoras.forEach((name) => allActiveLoraNames.add(name));
if (!isNodeActive2) {
return /* @__PURE__ */ new Set();
}
const allActiveLoraNames = getActiveLorasFromNode(node);
const inputChainNodes = getConnectedInputLoraChainNodes(node);
for (const chainNode of inputChainNodes) {
const upstreamLoras = collectActiveLorasFromChain(chainNode, visited);
upstreamLoras.forEach((name) => allActiveLoraNames.add(name));
}
return allActiveLoraNames;
}
@@ -15191,7 +15223,7 @@ function updateDownstreamLoaders(startNode, visited = /* @__PURE__ */ new Set())
if (targetNode && targetNode.comfyClass === "Lora Loader (LoraManager)") {
const allActiveLoraNames = collectActiveLorasFromChain(targetNode);
updateConnectedTriggerWords(targetNode, allActiveLoraNames);
} else if (targetNode && isLoraProviderNode(targetNode.comfyClass)) {
} else if (targetNode && isLoraChainNode(targetNode.comfyClass)) {
updateDownstreamLoaders(targetNode, visited);
}
}
@@ -15784,7 +15816,7 @@ app$1.registerExtension({
return originalConfigure == null ? void 0 : originalConfigure.apply(this, arguments);
};
}
if (LORA_PROVIDER_NODE_TYPES$1.includes(comfyClass)) {
if (LORA_CHAIN_NODE_TYPES$1.includes(comfyClass)) {
const originalOnNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = function() {
originalOnNodeCreated == null ? void 0 : originalOnNodeCreated.apply(this, arguments);

File diff suppressed because one or more lines are too long