fix(trigger-words): propagate LORA_STACK updates through combiners (#881)

This commit is contained in:
Will Miao
2026-04-03 15:01:02 +08:00
parent 30db8c3d1d
commit 4f599aeced
6 changed files with 278 additions and 36 deletions

View File

@@ -0,0 +1,151 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
const { APP_MODULE, UTILS_MODULE } = vi.hoisted(() => ({
APP_MODULE: new URL("../../../scripts/app.js", import.meta.url).pathname,
UTILS_MODULE: new URL("../../../web/comfyui/utils.js", import.meta.url).pathname,
}));
vi.mock(APP_MODULE, () => ({
app: {
graph: null,
registerExtension: vi.fn(),
ui: {
settings: {
getSettingValue: vi.fn(),
},
},
},
}));
describe("LoRA chain traversal", () => {
let collectActiveLorasFromChain;
beforeEach(async () => {
vi.resetModules();
({ collectActiveLorasFromChain } = await import(UTILS_MODULE));
});
function createGraph(nodes, links) {
const graph = {
_nodes: nodes,
links,
getNodeById(id) {
return nodes.find((node) => node.id === id) ?? null;
},
};
nodes.forEach((node) => {
node.graph = graph;
});
return graph;
}
it("aggregates active LoRAs through a combiner with multiple LORA_STACK inputs", () => {
const randomizerA = {
id: 1,
comfyClass: "Lora Randomizer (LoraManager)",
mode: 0,
widgets: [
{
name: "loras",
value: [
{ name: "Alpha", active: true },
{ name: "Ignored", active: false },
],
},
],
inputs: [],
outputs: [],
};
const randomizerB = {
id: 2,
comfyClass: "Lora Randomizer (LoraManager)",
mode: 0,
widgets: [
{
name: "loras",
value: [{ name: "Beta", active: true }],
},
],
inputs: [],
outputs: [],
};
const combiner = {
id: 3,
comfyClass: "Lora Stack Combiner (LoraManager)",
mode: 0,
widgets: [],
inputs: [
{ name: "lora_stack_a", type: "LORA_STACK", link: 11 },
{ name: "lora_stack_b", type: "LORA_STACK", link: 12 },
],
outputs: [],
};
const loader = {
id: 4,
comfyClass: "Lora Loader (LoraManager)",
mode: 0,
widgets: [],
inputs: [{ name: "lora_stack", type: "LORA_STACK", link: 13 }],
outputs: [],
};
createGraph(
[randomizerA, randomizerB, combiner, loader],
{
11: { origin_id: 1, target_id: 3 },
12: { origin_id: 2, target_id: 3 },
13: { origin_id: 3, target_id: 4 },
}
);
const result = collectActiveLorasFromChain(loader);
expect([...result]).toEqual(["Alpha", "Beta"]);
});
it("stops propagation when the combiner is inactive", () => {
const randomizer = {
id: 1,
comfyClass: "Lora Randomizer (LoraManager)",
mode: 0,
widgets: [
{
name: "loras",
value: [{ name: "Alpha", active: true }],
},
],
inputs: [],
outputs: [],
};
const combiner = {
id: 2,
comfyClass: "Lora Stack Combiner (LoraManager)",
mode: 2,
widgets: [],
inputs: [{ name: "lora_stack_a", type: "LORA_STACK", link: 21 }],
outputs: [],
};
const loader = {
id: 3,
comfyClass: "Lora Loader (LoraManager)",
mode: 0,
widgets: [],
inputs: [{ name: "lora_stack", type: "LORA_STACK", link: 22 }],
outputs: [],
};
createGraph(
[randomizer, combiner, loader],
{
21: { origin_id: 1, target_id: 2 },
22: { origin_id: 2, target_id: 3 },
}
);
const result = collectActiveLorasFromChain(loader);
expect(result.size).toBe(0);
});
});

View File

@@ -9,7 +9,7 @@ import type { LoraPoolConfig, RandomizerConfig, CyclerConfig } from './composabl
import { import {
setupModeChangeHandler, setupModeChangeHandler,
createModeChangeCallback, createModeChangeCallback,
LORA_PROVIDER_NODE_TYPES LORA_CHAIN_NODE_TYPES
} from './mode-change-handler' } from './mode-change-handler'
const LORA_POOL_WIDGET_MIN_WIDTH = 500 const LORA_POOL_WIDGET_MIN_WIDTH = 500
@@ -755,8 +755,8 @@ app.registerExtension({
} }
} }
// Register mode change handlers for LoRA provider nodes // Register mode change handlers for LORA_STACK chain nodes
if (LORA_PROVIDER_NODE_TYPES.includes(comfyClass)) { if (LORA_CHAIN_NODE_TYPES.includes(comfyClass)) {
const originalOnNodeCreated = nodeType.prototype.onNodeCreated const originalOnNodeCreated = nodeType.prototype.onNodeCreated
nodeType.prototype.onNodeCreated = function () { nodeType.prototype.onNodeCreated = function () {

View File

@@ -18,7 +18,22 @@ export const LORA_PROVIDER_NODE_TYPES = [
"Lora Cycler (LoraManager)", "Lora Cycler (LoraManager)",
] as const; ] as const;
/**
* Nodes that do not own LoRA state themselves, but merge or forward LORA_STACK
* inputs so downstream trigger-word updates must traverse through them.
*/
export const LORA_STACK_AGGREGATOR_NODE_TYPES = [
"Lora Stack Combiner (LoraManager)",
] as const;
export const LORA_CHAIN_NODE_TYPES = [
...LORA_PROVIDER_NODE_TYPES,
...LORA_STACK_AGGREGATOR_NODE_TYPES,
] as const;
export type LoraProviderNodeType = typeof LORA_PROVIDER_NODE_TYPES[number]; export type LoraProviderNodeType = typeof LORA_PROVIDER_NODE_TYPES[number];
export type LoraStackAggregatorNodeType = typeof LORA_STACK_AGGREGATOR_NODE_TYPES[number];
export type LoraChainNodeType = typeof LORA_CHAIN_NODE_TYPES[number];
/** /**
* Check if a node class is a LoRA provider node. * Check if a node class is a LoRA provider node.
@@ -27,6 +42,16 @@ export function isLoraProviderNode(comfyClass: string): comfyClass is LoraProvid
return LORA_PROVIDER_NODE_TYPES.includes(comfyClass as LoraProviderNodeType); return LORA_PROVIDER_NODE_TYPES.includes(comfyClass as LoraProviderNodeType);
} }
export function isLoraStackAggregatorNode(
comfyClass: string
): comfyClass is LoraStackAggregatorNodeType {
return LORA_STACK_AGGREGATOR_NODE_TYPES.includes(comfyClass as LoraStackAggregatorNodeType);
}
export function isLoraChainNode(comfyClass: string): comfyClass is LoraChainNodeType {
return LORA_CHAIN_NODE_TYPES.includes(comfyClass as LoraChainNodeType);
}
/** /**
* Extract active LoRA filenames from a node based on its type. * Extract active LoRA filenames from a node based on its type.
* *
@@ -40,6 +65,10 @@ export function getActiveLorasFromNodeByType(node: any): Set<string> {
return extractFromCyclerConfig(node); return extractFromCyclerConfig(node);
} }
if (isLoraStackAggregatorNode(comfyClass)) {
return new Set<string>();
}
// Default: use lorasWidget (works for Stacker and Randomizer) // Default: use lorasWidget (works for Stacker and Randomizer)
return extractFromLorasWidget(node); return extractFromLorasWidget(node);
} }

View File

@@ -10,10 +10,27 @@ export const LORA_PROVIDER_NODE_TYPES = [
"Lora Cycler (LoraManager)", "Lora Cycler (LoraManager)",
]; ];
export const LORA_STACK_AGGREGATOR_NODE_TYPES = [
"Lora Stack Combiner (LoraManager)",
];
export const LORA_CHAIN_NODE_TYPES = [
...LORA_PROVIDER_NODE_TYPES,
...LORA_STACK_AGGREGATOR_NODE_TYPES,
];
export function isLoraProviderNode(comfyClass) { export function isLoraProviderNode(comfyClass) {
return LORA_PROVIDER_NODE_TYPES.includes(comfyClass); return LORA_PROVIDER_NODE_TYPES.includes(comfyClass);
} }
export function isLoraStackAggregatorNode(comfyClass) {
return LORA_STACK_AGGREGATOR_NODE_TYPES.includes(comfyClass);
}
export function isLoraChainNode(comfyClass) {
return LORA_CHAIN_NODE_TYPES.includes(comfyClass);
}
function isMapLike(collection) { function isMapLike(collection) {
return collection && typeof collection.entries === "function" && typeof collection.values === "function"; return collection && typeof collection.entries === "function" && typeof collection.values === "function";
} }
@@ -245,16 +262,20 @@ export function hideWidgetForGood(node, widget, suffix = "") {
// Update pattern to match both formats: <lora:name:model_strength> or <lora:name:model_strength:clip_strength> // Update pattern to match both formats: <lora:name:model_strength> or <lora:name:model_strength:clip_strength>
export const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)(?::([-\d\.]+))?>/g; export const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)(?::([-\d\.]+))?>/g;
// Get connected Lora Stacker nodes that feed into the current node function isLoraStackInput(input) {
export function getConnectedInputStackers(node) { return input?.type === "LORA_STACK";
const connectedStackers = []; }
// Get connected LORA_STACK chain nodes that feed into the current node
export function getConnectedInputLoraChainNodes(node) {
const connectedNodes = [];
if (!node?.inputs) { if (!node?.inputs) {
return connectedStackers; return connectedNodes;
} }
for (const input of node.inputs) { for (const input of node.inputs) {
if (input.name !== "lora_stack" || !input.link) { if (!isLoraStackInput(input) || !input.link) {
continue; continue;
} }
@@ -264,12 +285,12 @@ export function getConnectedInputStackers(node) {
} }
const sourceNode = node.graph?.getNodeById?.(link.origin_id); const sourceNode = node.graph?.getNodeById?.(link.origin_id);
if (sourceNode && isLoraProviderNode(sourceNode.comfyClass)) { if (sourceNode && isLoraChainNode(sourceNode.comfyClass)) {
connectedStackers.push(sourceNode); connectedNodes.push(sourceNode);
} }
} }
return connectedStackers; return connectedNodes;
} }
// Get connected TriggerWord Toggle nodes that receive output from the current node // Get connected TriggerWord Toggle nodes that receive output from the current node
@@ -314,6 +335,11 @@ export function getActiveLorasFromNode(node) {
return activeLoraNames; return activeLoraNames;
} }
// Aggregator nodes do not own LoRA state directly; they only forward upstream stacks.
if (isLoraStackAggregatorNode(node.comfyClass)) {
return activeLoraNames;
}
// Handle Lora Stacker and Lora Randomizer (lorasWidget) // Handle Lora Stacker and Lora Randomizer (lorasWidget)
let lorasWidget = node.lorasWidget; let lorasWidget = node.lorasWidget;
if (!lorasWidget && node.widgets) { if (!lorasWidget && node.widgets) {
@@ -348,14 +374,18 @@ export function collectActiveLorasFromChain(node, visited = new Set()) {
// Mode 2 is Never, Mode 4 is Bypass // Mode 2 is Never, Mode 4 is Bypass
const isNodeActive = node.mode === undefined || node.mode === 0 || node.mode === 3; const isNodeActive = node.mode === undefined || node.mode === 0 || node.mode === 3;
if (!isNodeActive) {
return new Set();
}
// Get active loras from current node only if node is active // Get active loras from current node only if node is active
const allActiveLoraNames = isNodeActive ? getActiveLorasFromNode(node) : new Set(); const allActiveLoraNames = getActiveLorasFromNode(node);
// Get connected input stackers and collect their active loras // Get connected input LORA_STACK chain nodes and collect their active loras
const inputStackers = getConnectedInputStackers(node); const inputChainNodes = getConnectedInputLoraChainNodes(node);
for (const stacker of inputStackers) { for (const chainNode of inputChainNodes) {
const stackerLoras = collectActiveLorasFromChain(stacker, visited); const upstreamLoras = collectActiveLorasFromChain(chainNode, visited);
stackerLoras.forEach(name => allActiveLoraNames.add(name)); upstreamLoras.forEach(name => allActiveLoraNames.add(name));
} }
return allActiveLoraNames; return allActiveLoraNames;
@@ -819,8 +849,8 @@ export function updateDownstreamLoaders(startNode, visited = new Set()) {
collectActiveLorasFromChain(targetNode); collectActiveLorasFromChain(targetNode);
updateConnectedTriggerWords(targetNode, allActiveLoraNames); updateConnectedTriggerWords(targetNode, allActiveLoraNames);
} }
// If target is another LoRA provider node, recursively check its outputs // If target is another LORA_STACK chain node, recursively check its outputs
else if (targetNode && isLoraProviderNode(targetNode.comfyClass)) { else if (targetNode && isLoraChainNode(targetNode.comfyClass)) {
updateDownstreamLoaders(targetNode, visited); updateDownstreamLoaders(targetNode, visited);
} }
} }

View File

@@ -14938,11 +14938,24 @@ const LORA_PROVIDER_NODE_TYPES$1 = [
"Lora Randomizer (LoraManager)", "Lora Randomizer (LoraManager)",
"Lora Cycler (LoraManager)" "Lora Cycler (LoraManager)"
]; ];
const LORA_STACK_AGGREGATOR_NODE_TYPES$1 = [
"Lora Stack Combiner (LoraManager)"
];
const LORA_CHAIN_NODE_TYPES$1 = [
...LORA_PROVIDER_NODE_TYPES$1,
...LORA_STACK_AGGREGATOR_NODE_TYPES$1
];
function isLoraStackAggregatorNode$1(comfyClass) {
return LORA_STACK_AGGREGATOR_NODE_TYPES$1.includes(comfyClass);
}
function getActiveLorasFromNodeByType(node) { function getActiveLorasFromNodeByType(node) {
const comfyClass = node == null ? void 0 : node.comfyClass; const comfyClass = node == null ? void 0 : node.comfyClass;
if (comfyClass === "Lora Cycler (LoraManager)") { if (comfyClass === "Lora Cycler (LoraManager)") {
return extractFromCyclerConfig(node); return extractFromCyclerConfig(node);
} }
if (isLoraStackAggregatorNode$1(comfyClass)) {
return /* @__PURE__ */ new Set();
}
return extractFromLorasWidget(node); return extractFromLorasWidget(node);
} }
function extractFromLorasWidget(node) { function extractFromLorasWidget(node) {
@@ -15002,8 +15015,18 @@ const LORA_PROVIDER_NODE_TYPES = [
"Lora Randomizer (LoraManager)", "Lora Randomizer (LoraManager)",
"Lora Cycler (LoraManager)" "Lora Cycler (LoraManager)"
]; ];
function isLoraProviderNode(comfyClass) { const LORA_STACK_AGGREGATOR_NODE_TYPES = [
return LORA_PROVIDER_NODE_TYPES.includes(comfyClass); "Lora Stack Combiner (LoraManager)"
];
const LORA_CHAIN_NODE_TYPES = [
...LORA_PROVIDER_NODE_TYPES,
...LORA_STACK_AGGREGATOR_NODE_TYPES
];
function isLoraStackAggregatorNode(comfyClass) {
return LORA_STACK_AGGREGATOR_NODE_TYPES.includes(comfyClass);
}
function isLoraChainNode(comfyClass) {
return LORA_CHAIN_NODE_TYPES.includes(comfyClass);
} }
function isMapLike(collection) { function isMapLike(collection) {
return collection && typeof collection.entries === "function" && typeof collection.values === "function"; return collection && typeof collection.entries === "function" && typeof collection.values === "function";
@@ -15041,14 +15064,17 @@ function getLinkFromGraph(graph, linkId) {
} }
return graph.links[linkId] || null; return graph.links[linkId] || null;
} }
function getConnectedInputStackers(node) { function isLoraStackInput(input) {
return (input == null ? void 0 : input.type) === "LORA_STACK";
}
function getConnectedInputLoraChainNodes(node) {
var _a2, _b; var _a2, _b;
const connectedStackers = []; const connectedNodes = [];
if (!(node == null ? void 0 : node.inputs)) { if (!(node == null ? void 0 : node.inputs)) {
return connectedStackers; return connectedNodes;
} }
for (const input of node.inputs) { for (const input of node.inputs) {
if (input.name !== "lora_stack" || !input.link) { if (!isLoraStackInput(input) || !input.link) {
continue; continue;
} }
const link = getLinkFromGraph(node.graph, input.link); const link = getLinkFromGraph(node.graph, input.link);
@@ -15056,11 +15082,11 @@ function getConnectedInputStackers(node) {
continue; continue;
} }
const sourceNode = (_b = (_a2 = node.graph) == null ? void 0 : _a2.getNodeById) == null ? void 0 : _b.call(_a2, link.origin_id); const sourceNode = (_b = (_a2 = node.graph) == null ? void 0 : _a2.getNodeById) == null ? void 0 : _b.call(_a2, link.origin_id);
if (sourceNode && isLoraProviderNode(sourceNode.comfyClass)) { if (sourceNode && isLoraChainNode(sourceNode.comfyClass)) {
connectedStackers.push(sourceNode); connectedNodes.push(sourceNode);
} }
} }
return connectedStackers; return connectedNodes;
} }
function getConnectedTriggerToggleNodes(node) { function getConnectedTriggerToggleNodes(node) {
var _a2, _b, _c; var _a2, _b, _c;
@@ -15095,6 +15121,9 @@ function getActiveLorasFromNode(node) {
} }
return activeLoraNames; return activeLoraNames;
} }
if (isLoraStackAggregatorNode(node.comfyClass)) {
return activeLoraNames;
}
let lorasWidget = node.lorasWidget; let lorasWidget = node.lorasWidget;
if (!lorasWidget && node.widgets) { if (!lorasWidget && node.widgets) {
lorasWidget = node.widgets.find((w2) => w2.name === "loras"); lorasWidget = node.widgets.find((w2) => w2.name === "loras");
@@ -15118,11 +15147,14 @@ function collectActiveLorasFromChain(node, visited = /* @__PURE__ */ new Set())
} }
visited.add(nodeKey); visited.add(nodeKey);
const isNodeActive2 = node.mode === void 0 || node.mode === 0 || node.mode === 3; const isNodeActive2 = node.mode === void 0 || node.mode === 0 || node.mode === 3;
const allActiveLoraNames = isNodeActive2 ? getActiveLorasFromNode(node) : /* @__PURE__ */ new Set(); if (!isNodeActive2) {
const inputStackers = getConnectedInputStackers(node); return /* @__PURE__ */ new Set();
for (const stacker of inputStackers) { }
const stackerLoras = collectActiveLorasFromChain(stacker, visited); const allActiveLoraNames = getActiveLorasFromNode(node);
stackerLoras.forEach((name) => allActiveLoraNames.add(name)); const inputChainNodes = getConnectedInputLoraChainNodes(node);
for (const chainNode of inputChainNodes) {
const upstreamLoras = collectActiveLorasFromChain(chainNode, visited);
upstreamLoras.forEach((name) => allActiveLoraNames.add(name));
} }
return allActiveLoraNames; return allActiveLoraNames;
} }
@@ -15191,7 +15223,7 @@ function updateDownstreamLoaders(startNode, visited = /* @__PURE__ */ new Set())
if (targetNode && targetNode.comfyClass === "Lora Loader (LoraManager)") { if (targetNode && targetNode.comfyClass === "Lora Loader (LoraManager)") {
const allActiveLoraNames = collectActiveLorasFromChain(targetNode); const allActiveLoraNames = collectActiveLorasFromChain(targetNode);
updateConnectedTriggerWords(targetNode, allActiveLoraNames); updateConnectedTriggerWords(targetNode, allActiveLoraNames);
} else if (targetNode && isLoraProviderNode(targetNode.comfyClass)) { } else if (targetNode && isLoraChainNode(targetNode.comfyClass)) {
updateDownstreamLoaders(targetNode, visited); updateDownstreamLoaders(targetNode, visited);
} }
} }
@@ -15784,7 +15816,7 @@ app$1.registerExtension({
return originalConfigure == null ? void 0 : originalConfigure.apply(this, arguments); return originalConfigure == null ? void 0 : originalConfigure.apply(this, arguments);
}; };
} }
if (LORA_PROVIDER_NODE_TYPES$1.includes(comfyClass)) { if (LORA_CHAIN_NODE_TYPES$1.includes(comfyClass)) {
const originalOnNodeCreated = nodeType.prototype.onNodeCreated; const originalOnNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = function() { nodeType.prototype.onNodeCreated = function() {
originalOnNodeCreated == null ? void 0 : originalOnNodeCreated.apply(this, arguments); originalOnNodeCreated == null ? void 0 : originalOnNodeCreated.apply(this, arguments);

File diff suppressed because one or more lines are too long