diff --git a/tests/frontend/components/loraLoader.triggerWords.test.js b/tests/frontend/components/loraLoader.triggerWords.test.js new file mode 100644 index 00000000..03ba38ba --- /dev/null +++ b/tests/frontend/components/loraLoader.triggerWords.test.js @@ -0,0 +1,131 @@ +import { describe, it, expect, beforeEach, vi } from "vitest"; + +const { + APP_MODULE, + API_MODULE, + UTILS_MODULE, + LORAS_WIDGET_MODULE, + LORA_LOADER_MODULE, +} = vi.hoisted(() => ({ + APP_MODULE: new URL("../../../scripts/app.js", import.meta.url).pathname, + API_MODULE: new URL("../../../scripts/api.js", import.meta.url).pathname, + UTILS_MODULE: new URL("../../../web/comfyui/utils.js", import.meta.url).pathname, + LORAS_WIDGET_MODULE: new URL("../../../web/comfyui/loras_widget.js", import.meta.url).pathname, + LORA_LOADER_MODULE: new URL("../../../web/comfyui/lora_loader.js", import.meta.url).pathname, +})); + +const extensionState = { current: null }; +const registerExtensionMock = vi.fn((extension) => { + extensionState.current = extension; +}); + +vi.mock(APP_MODULE, () => ({ + app: { + registerExtension: registerExtensionMock, + graph: {}, + }, +})); + +vi.mock(API_MODULE, () => ({ + api: { + addEventListener: vi.fn(), + }, +})); + +const collectActiveLorasFromChain = vi.fn(); +const updateConnectedTriggerWords = vi.fn(); +const mergeLoras = vi.fn(); +const setupInputWidgetWithAutocomplete = vi.fn(); +const getAllGraphNodes = vi.fn(); +const getNodeFromGraph = vi.fn(); + +vi.mock(UTILS_MODULE, () => ({ + collectActiveLorasFromChain, + updateConnectedTriggerWords, + mergeLoras, + setupInputWidgetWithAutocomplete, + chainCallback: (proto, property, callback) => { + proto[property] = callback; + }, + getAllGraphNodes, + getNodeFromGraph, + LORA_PATTERN: //g, +})); + +const addLorasWidget = vi.fn(); + +vi.mock(LORAS_WIDGET_MODULE, () => ({ + addLorasWidget, +})); + +describe("Lora Loader trigger word updates", () => { + beforeEach(() => { + vi.resetModules(); + + extensionState.current = null; + registerExtensionMock.mockClear(); + + collectActiveLorasFromChain.mockClear(); + collectActiveLorasFromChain.mockImplementation(() => new Set(["Alpha"])); + + updateConnectedTriggerWords.mockClear(); + + mergeLoras.mockClear(); + mergeLoras.mockImplementation(() => [{ name: "Alpha", active: true }]); + + setupInputWidgetWithAutocomplete.mockClear(); + setupInputWidgetWithAutocomplete.mockImplementation( + (_node, _widget, originalCallback) => originalCallback + ); + + addLorasWidget.mockClear(); + addLorasWidget.mockImplementation((_node, _name, _opts, callback) => ({ + widget: { value: [], callback }, + })); + }); + + it("refreshes trigger word toggles after LoRA syntax edits in the input widget", async () => { + await import(LORA_LOADER_MODULE); + + expect(registerExtensionMock).toHaveBeenCalled(); + const extension = extensionState.current; + expect(extension).toBeDefined(); + + const nodeType = { comfyClass: "Lora Loader (LoraManager)", prototype: {} }; + await extension.beforeRegisterNodeDef(nodeType, {}, {}); + + const node = { + comfyClass: "Lora Loader (LoraManager)", + widgets: [ + { + value: "", + options: {}, + inputEl: {}, + }, + ], + addInput: vi.fn(), + graph: {}, + }; + + nodeType.prototype.onNodeCreated.call(node); + + expect(setupInputWidgetWithAutocomplete).toHaveBeenCalled(); + expect(node.lorasWidget).toBeDefined(); + + const inputCallback = node.widgets[0].callback; + expect(typeof inputCallback).toBe("function"); + + inputCallback(""); + + expect(mergeLoras).toHaveBeenCalledWith("", []); + expect(node.lorasWidget.value).toEqual([{ name: "Alpha", active: true }]); + expect(collectActiveLorasFromChain).toHaveBeenCalledWith(node); + + const activeSet = collectActiveLorasFromChain.mock.results.at(-1)?.value; + const [[targetNode, triggerWordSet]] = updateConnectedTriggerWords.mock.calls; + expect(targetNode).toBe(node); + expect(triggerWordSet).toBe(activeSet); + expect([...triggerWordSet]).toEqual(["Alpha"]); + }); +}); + diff --git a/web/comfyui/lora_loader.js b/web/comfyui/lora_loader.js index 382c1970..613177ba 100644 --- a/web/comfyui/lora_loader.js +++ b/web/comfyui/lora_loader.js @@ -178,6 +178,9 @@ app.registerExtension({ const mergedLoras = mergeLoras(value, currentLoras); this.lorasWidget.value = mergedLoras; + + const allActiveLoraNames = collectActiveLorasFromChain(this); + updateConnectedTriggerWords(this, allActiveLoraNames); } finally { isUpdating = false; }