feat(lora): add mode change listeners to update trigger words

Add property descriptor to listen for mode changes in Lora Loader and Lora Stacker nodes. When node mode changes, automatically update connected trigger word toggle nodes and downstream loader nodes to maintain synchronization between node modes and trigger word states.

- Lora Loader: Updates connected trigger words when mode changes
- Lora Stacker: Updates connected trigger words and downstream loaders when mode changes
- Both nodes log mode changes for debugging purposes
This commit is contained in:
Will Miao
2025-11-07 15:11:59 +08:00
parent ce5a1ae3d0
commit f76343f389
3 changed files with 321 additions and 0 deletions

View File

@@ -0,0 +1,264 @@
import { describe, it, expect, beforeEach, vi } from "vitest";
const {
APP_MODULE,
API_MODULE,
UTILS_MODULE,
LORAS_WIDGET_MODULE,
LORA_LOADER_MODULE,
LORA_STACKER_MODULE,
} = vi.hoisted(() => ({
APP_MODULE: new URL("../../../scripts/app.js", import.meta.url).pathname,
API_MODULE: new URL("../../../scripts/api.js", import.meta.url).pathname,
UTILS_MODULE: new URL("../../../web/comfyui/utils.js", import.meta.url).pathname,
LORAS_WIDGET_MODULE: new URL("../../../web/comfyui/loras_widget.js", import.meta.url).pathname,
LORA_LOADER_MODULE: new URL("../../../web/comfyui/lora_loader.js", import.meta.url).pathname,
LORA_STACKER_MODULE: new URL("../../../web/comfyui/lora_stacker.js", import.meta.url).pathname,
}));
const extensionState = {
loraLoader: null,
loraStacker: null
};
const registerExtensionMock = vi.fn((extension) => {
if (extension.name === "LoraManager.LoraLoader") {
extensionState.loraLoader = extension;
} else if (extension.name === "LoraManager.LoraStacker") {
extensionState.loraStacker = extension;
}
});
vi.mock(APP_MODULE, () => ({
app: {
registerExtension: registerExtensionMock,
graph: {},
},
}));
vi.mock(API_MODULE, () => ({
api: {
addEventListener: vi.fn(),
},
}));
const collectActiveLorasFromChain = vi.fn();
const updateConnectedTriggerWords = vi.fn();
const getActiveLorasFromNode = vi.fn();
const mergeLoras = vi.fn();
const setupInputWidgetWithAutocomplete = vi.fn();
const getAllGraphNodes = vi.fn();
const getNodeFromGraph = vi.fn();
const getNodeKey = vi.fn();
const getLinkFromGraph = vi.fn();
const chainCallback = vi.fn((proto, property, callback) => {
proto[property] = callback;
});
vi.mock(UTILS_MODULE, async (importOriginal) => {
const actual = await importOriginal();
return {
...actual,
collectActiveLorasFromChain,
updateConnectedTriggerWords,
getActiveLorasFromNode,
mergeLoras,
setupInputWidgetWithAutocomplete,
chainCallback,
getAllGraphNodes,
getNodeFromGraph,
getNodeKey,
getLinkFromGraph,
};
});
const addLorasWidget = vi.fn();
vi.mock(LORAS_WIDGET_MODULE, () => ({
addLorasWidget,
}));
describe("Node mode change handling", () => {
beforeEach(() => {
vi.resetModules();
extensionState.loraLoader = null;
extensionState.loraStacker = null;
registerExtensionMock.mockClear();
collectActiveLorasFromChain.mockClear();
collectActiveLorasFromChain.mockImplementation(() => new Set(["Alpha"]));
updateConnectedTriggerWords.mockClear();
getActiveLorasFromNode.mockClear();
getActiveLorasFromNode.mockImplementation(() => new Set(["Alpha"]));
mergeLoras.mockClear();
mergeLoras.mockImplementation(() => [{ name: "Alpha", active: true }]);
setupInputWidgetWithAutocomplete.mockClear();
setupInputWidgetWithAutocomplete.mockImplementation(
(_node, _widget, originalCallback) => originalCallback
);
addLorasWidget.mockClear();
addLorasWidget.mockImplementation((_node, _name, _opts, callback) => ({
widget: { value: [], callback },
}));
});
describe("Lora Stacker mode change handling", () => {
let node, extension;
beforeEach(async () => {
await import(LORA_STACKER_MODULE);
expect(registerExtensionMock).toHaveBeenCalled();
extension = extensionState.loraStacker;
expect(extension).toBeDefined();
const nodeType = { comfyClass: "Lora Stacker (LoraManager)", prototype: {} };
await extension.beforeRegisterNodeDef(nodeType, {}, {});
node = {
comfyClass: "Lora Stacker (LoraManager)",
widgets: [
{
value: "",
options: {},
callback: () => {},
},
],
addInput: vi.fn(),
mode: 0, // Initial mode
graph: {},
outputs: [], // Add outputs property for updateDownstreamLoaders
};
nodeType.prototype.onNodeCreated.call(node);
});
it("should handle mode property changes", () => {
const initialMode = node.mode;
expect(initialMode).toBe(0);
// Change mode from 0 to 3
node.mode = 3;
// Verify that the property was updated
expect(node.mode).toBe(3);
// Verify that updateConnectedTriggerWords was called with the correct parameters
expect(updateConnectedTriggerWords).toHaveBeenCalledWith(
node,
expect.anything() // This would be the active Lora names set
);
});
it("should update trigger words based on node activity when mode changes", () => {
// Set up the mock to return active loras when mode is 0 or 3
getActiveLorasFromNode.mockImplementation(() => new Set(["Alpha", "Beta"]));
// Change to active mode (0)
node.mode = 0;
expect(updateConnectedTriggerWords).toHaveBeenCalledWith(
node,
new Set(["Alpha", "Beta"]) // Should call with active loras
);
// Change to inactive mode (1) - should call with empty set
updateConnectedTriggerWords.mockClear();
getActiveLorasFromNode.mockImplementation(() => new Set()); // Return empty set for inactive mode
node.mode = 1;
expect(updateConnectedTriggerWords).toHaveBeenCalledWith(
node,
new Set() // Should call with empty set for inactive mode
);
});
});
describe("Lora Loader mode change handling", () => {
let node, extension;
beforeEach(async () => {
await import(LORA_LOADER_MODULE);
expect(registerExtensionMock).toHaveBeenCalled();
extension = extensionState.loraLoader;
expect(extension).toBeDefined();
const nodeType = { comfyClass: "Lora Loader (LoraManager)", prototype: {} };
await extension.beforeRegisterNodeDef(nodeType, {}, {});
node = {
comfyClass: "Lora Loader (LoraManager)",
widgets: [
{
value: "",
options: {},
callback: () => {},
},
],
addInput: vi.fn(),
mode: 0, // Initial mode
graph: {},
};
nodeType.prototype.onNodeCreated.call(node);
});
it("should handle mode property changes", () => {
const initialMode = node.mode;
expect(initialMode).toBe(0);
// Change mode from 0 to 3
node.mode = 3;
// Verify that the property was updated
expect(node.mode).toBe(3);
// Verify that updateConnectedTriggerWords was called
expect(updateConnectedTriggerWords).toHaveBeenCalledWith(
node,
expect.anything() // This would be the active Lora names set
);
});
it("should call onModeChange when mode property is changed", () => {
const initialMode = node.mode;
expect(initialMode).toBe(0);
// Mock console.log to verify it was called
const consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {});
// Change mode from 0 to 1
node.mode = 1;
// Verify console log was called
expect(consoleSpy).toHaveBeenCalledWith(
'[Lora Loader] Node mode changed from 0 to 1'
);
expect(consoleSpy).toHaveBeenCalledWith(
'Lora Loader node mode changed: from 0 to 1'
);
consoleSpy.mockRestore();
});
it("should update connected trigger words when mode changes", () => {
// Mock the collectActiveLorasFromChain to return a specific set
collectActiveLorasFromChain.mockImplementation(() => new Set(["LoaderLora1", "LoaderLora2"]));
// Change mode
node.mode = 2;
// Verify that collectActiveLorasFromChain and updateConnectedTriggerWords were called
expect(collectActiveLorasFromChain).toHaveBeenCalledWith(node);
expect(updateConnectedTriggerWords).toHaveBeenCalledWith(
node,
new Set(["LoaderLora1", "LoaderLora2"])
);
});
});
});

View File

@@ -118,6 +118,35 @@ app.registerExtension({
let isUpdating = false; let isUpdating = false;
let isSyncingInput = false; let isSyncingInput = false;
// Mechanism: Property descriptor to listen for mode changes
const self = this;
let _mode = this.mode;
Object.defineProperty(this, 'mode', {
get() {
return _mode;
},
set(value) {
const oldValue = _mode;
_mode = value;
// Trigger mode change handler
if (self.onModeChange) {
self.onModeChange(value, oldValue);
}
console.log(`[Lora Loader] Node mode changed from ${oldValue} to ${value}`);
}
});
// Define the mode change handler
this.onModeChange = function(newMode, oldMode) {
console.log(`Lora Loader node mode changed: from ${oldMode} to ${newMode}`);
// Update connected trigger word toggle nodes when mode changes
const allActiveLoraNames = collectActiveLorasFromChain(self);
updateConnectedTriggerWords(self, allActiveLoraNames);
};
const inputWidget = this.widgets[0]; const inputWidget = this.widgets[0];
inputWidget.options.getMaxHeight = () => 100; inputWidget.options.getMaxHeight = () => 100;
this.inputWidget = inputWidget; this.inputWidget = inputWidget;

View File

@@ -29,6 +29,34 @@ app.registerExtension({
let isUpdating = false; let isUpdating = false;
let isSyncingInput = false; let isSyncingInput = false;
// Mechanism 3: Property descriptor to listen for mode changes
const self = this;
let _mode = this.mode;
Object.defineProperty(this, 'mode', {
get() {
return _mode;
},
set(value) {
const oldValue = _mode;
_mode = value;
// Trigger mode change handler
if (self.onModeChange) {
self.onModeChange(value, oldValue);
}
}
});
// Define the mode change handler
this.onModeChange = function(newMode, oldMode) {
// Update connected trigger word toggle nodes and downstream loader trigger word toggle nodes
// when mode changes, similar to when loras change
const isNodeActive = newMode === 0 || newMode === 3; // Active when mode is Always (0) or On Trigger (3)
const activeLoraNames = isNodeActive ? getActiveLorasFromNode(self) : new Set();
updateConnectedTriggerWords(self, activeLoraNames);
updateDownstreamLoaders(self);
};
const inputWidget = this.widgets[0]; const inputWidget = this.widgets[0];
inputWidget.options.getMaxHeight = () => 100; inputWidget.options.getMaxHeight = () => 100;
this.inputWidget = inputWidget; this.inputWidget = inputWidget;