diff --git a/py/routes/handlers/misc_handlers.py b/py/routes/handlers/misc_handlers.py index 1c9564be..9015293c 100644 --- a/py/routes/handlers/misc_handlers.py +++ b/py/routes/handlers/misc_handlers.py @@ -80,7 +80,7 @@ class NodeRegistry: def __init__(self) -> None: self._lock = asyncio.Lock() - self._nodes: Dict[int, dict] = {} + self._nodes: Dict[str, dict] = {} self._registry_updated = asyncio.Event() async def register_nodes(self, nodes: list[dict]) -> None: @@ -88,11 +88,16 @@ class NodeRegistry: self._nodes.clear() for node in nodes: node_id = node["node_id"] + graph_id = str(node["graph_id"]) + unique_id = f"{graph_id}:{node_id}" node_type = node.get("type", "") type_id = NODE_TYPES.get(node_type, 0) bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR - self._nodes[node_id] = { + self._nodes[unique_id] = { "id": node_id, + "graph_id": graph_id, + "graph_name": node.get("graph_name"), + "unique_id": unique_id, "bgcolor": bgcolor, "title": node.get("title"), "type": type_id, @@ -330,16 +335,65 @@ class LoraCodeHandler: logger.error("Error broadcasting lora code: %s", exc) results.append({"node_id": "broadcast", "success": False, "error": str(exc)}) else: - for node_id in node_ids: + for entry in node_ids: + node_identifier = entry + graph_identifier = None + if isinstance(entry, dict): + node_identifier = entry.get("node_id") + graph_identifier = entry.get("graph_id") + + if node_identifier is None: + results.append( + { + "node_id": node_identifier, + "graph_id": graph_identifier, + "success": False, + "error": "Missing node_id parameter", + } + ) + continue + + try: + parsed_node_id = int(node_identifier) + except (TypeError, ValueError): + parsed_node_id = node_identifier + + payload = { + "id": parsed_node_id, + "lora_code": lora_code, + "mode": mode, + } + + if graph_identifier is not None: + payload["graph_id"] = str(graph_identifier) + try: self._prompt_server.instance.send_sync( "lora_code_update", - {"id": node_id, "lora_code": lora_code, "mode": mode}, + payload, + ) + results.append( + { + "node_id": parsed_node_id, + "graph_id": payload.get("graph_id"), + "success": True, + } ) - results.append({"node_id": node_id, "success": True}) except Exception as exc: # pragma: no cover - defensive logging - logger.error("Error sending lora code to node %s: %s", node_id, exc) - results.append({"node_id": node_id, "success": False, "error": str(exc)}) + logger.error( + "Error sending lora code to node %s (graph %s): %s", + parsed_node_id, + graph_identifier, + exc, + ) + results.append( + { + "node_id": parsed_node_id, + "graph_id": payload.get("graph_id"), + "success": False, + "error": str(exc), + } + ) return web.json_response({"success": True, "results": results}) except Exception as exc: # pragma: no cover - defensive logging @@ -679,10 +733,21 @@ class NodeRegistryHandler: node_id = node.get("node_id") if node_id is None: return web.json_response({"success": False, "error": f"Node {index} missing node_id parameter"}, status=400) + graph_id = node.get("graph_id") + if graph_id is None: + return web.json_response({"success": False, "error": f"Node {index} missing graph_id parameter"}, status=400) + graph_name = node.get("graph_name") try: node["node_id"] = int(node_id) except (TypeError, ValueError): return web.json_response({"success": False, "error": f"Node {index} node_id must be an integer"}, status=400) + node["graph_id"] = str(graph_id) + if graph_name is None: + node["graph_name"] = None + elif isinstance(graph_name, str): + node["graph_name"] = graph_name + else: + node["graph_name"] = str(graph_name) await self._node_registry.register_nodes(nodes) return web.json_response({"success": True, "message": f"{len(nodes)} nodes registered successfully"}) diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index ee6bc151..15b6a0b7 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -229,11 +229,27 @@ class LoraRoutes(BaseModelRoutes): trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" # Send update to all connected trigger word toggle nodes - for node_id in node_ids: - PromptServer.instance.send_sync("trigger_word_update", { - "id": node_id, + for entry in node_ids: + node_identifier = entry + graph_identifier = None + if isinstance(entry, dict): + node_identifier = entry.get("node_id") + graph_identifier = entry.get("graph_id") + + try: + parsed_node_id = int(node_identifier) + except (TypeError, ValueError): + parsed_node_id = node_identifier + + payload = { + "id": parsed_node_id, "message": trigger_words_text - }) + } + + if graph_identifier is not None: + payload["graph_id"] = str(graph_identifier) + + PromptServer.instance.send_sync("trigger_word_update", payload) return web.json_response({"success": True}) diff --git a/static/js/utils/uiHelpers.js b/static/js/utils/uiHelpers.js index 3e02d086..2a43a2b6 100644 --- a/static/js/utils/uiHelpers.js +++ b/static/js/utils/uiHelpers.js @@ -435,8 +435,9 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax return true; } else { // Single node - send directly - const nodeId = Object.keys(registryData.data.nodes)[0]; - return await sendToSpecificNode([nodeId], loraSyntax, replaceMode, syntaxType); + const nodes = registryData.data.nodes; + const nodeId = Object.keys(nodes)[0]; + return await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType); } } catch (error) { console.error('Failed to get registry:', error); @@ -452,19 +453,65 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax * @param {boolean} replaceMode - Whether to replace existing LoRAs * @param {string} syntaxType - The type of syntax ('lora' or 'recipe') */ -async function sendToSpecificNode(nodeIds, loraSyntax, replaceMode, syntaxType) { +function resolveNodeReference(nodeKey, nodesMap) { + if (!nodeKey) { + return null; + } + + const directMatch = nodesMap?.[nodeKey]; + if (directMatch) { + return { + node_id: directMatch.id, + graph_id: directMatch.graph_id ?? null, + }; + } + + if (typeof nodeKey === 'string' && nodeKey.includes(':')) { + const [graphId, ...rest] = nodeKey.split(':'); + const nodeIdPart = rest.join(':'); + const numericNodeId = Number(nodeIdPart); + return { + node_id: Number.isNaN(numericNodeId) ? nodeIdPart : numericNodeId, + graph_id: graphId || null, + }; + } + + const numericId = Number(nodeKey); + return { + node_id: Number.isNaN(numericId) ? nodeKey : numericId, + graph_id: null, + }; +} + +async function sendToSpecificNode(nodeIds, nodesMap, loraSyntax, replaceMode, syntaxType) { try { // Call the backend API to update the lora code + const requestBody = { + lora_code: loraSyntax, + mode: replaceMode ? 'replace' : 'append' + }; + + if (Array.isArray(nodeIds)) { + const references = nodeIds + .map((nodeKey) => resolveNodeReference(nodeKey, nodesMap)) + .filter((reference) => reference && reference.node_id !== undefined); + + if (references.length > 0) { + requestBody.node_ids = references; + } + } else if (nodeIds) { + const reference = resolveNodeReference(nodeIds, nodesMap); + if (reference) { + requestBody.node_ids = [reference]; + } + } + const response = await fetch('/api/lm/update-lora-code', { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - node_ids: nodeIds, - lora_code: loraSyntax, - mode: replaceMode ? 'replace' : 'append' - }) + body: JSON.stringify(requestBody) }); const result = await response.json(); @@ -522,16 +569,17 @@ function showNodeSelector(nodes, loraSyntax, replaceMode, syntaxType) { hideNodeSelector(); // Generate node list HTML with icons and proper colors - const nodeItems = Object.values(nodes).map(node => { + const nodeItems = Object.entries(nodes).map(([nodeKey, node]) => { const iconClass = NODE_TYPE_ICONS[node.type] || 'fas fa-question-circle'; const bgColor = node.bgcolor || DEFAULT_NODE_COLOR; - + const graphLabel = node.graph_name ? ` (${node.graph_name})` : ''; + return ` -
+
- #${node.id} ${node.title} + #${node.id}${graphLabel} ${node.title}
`; }).join(''); @@ -610,10 +658,10 @@ function setupNodeSelectorEvents(selector, nodes, loraSyntax, replaceMode, synta if (action === 'send-all') { // Send to all nodes const allNodeIds = Object.keys(nodes); - await sendToSpecificNode(allNodeIds, loraSyntax, replaceMode, syntaxType); + await sendToSpecificNode(allNodeIds, nodes, loraSyntax, replaceMode, syntaxType); } else if (nodeId) { // Send to specific node - await sendToSpecificNode([nodeId], loraSyntax, replaceMode, syntaxType); + await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType); } hideNodeSelector(); diff --git a/tests/frontend/utils/uiHelpers.dom.test.js b/tests/frontend/utils/uiHelpers.dom.test.js index 94739553..2c74dcd4 100644 --- a/tests/frontend/utils/uiHelpers.dom.test.js +++ b/tests/frontend/utils/uiHelpers.dom.test.js @@ -46,6 +46,7 @@ vi.mock(EVENT_MANAGER_MODULE, () => ({ off: vi.fn(), addHandler: vi.fn(), removeHandler: vi.fn(), + setState: vi.fn(), }, })); @@ -62,6 +63,7 @@ describe('UI helper DOM utilities', () => { afterEach(() => { vi.useRealTimers(); + delete global.fetch; }); it('creates toast elements and cleans them up after timeout', async () => { @@ -105,4 +107,53 @@ describe('UI helper DOM utilities', () => { expect(document.body.dataset.theme).toBe('dark'); expect(document.querySelector('.theme-toggle').classList.contains('theme-dark')).toBe(true); }); + + it('renders subgraph names in the node selector list', async () => { + const registryResponse = { + success: true, + data: { + node_count: 2, + nodes: { + 'root:1': { + id: 1, + graph_id: 'root', + graph_name: null, + title: 'Root Loader', + type: 1, + bgcolor: '#123456', + }, + 'subgraph-uuid:2': { + id: 2, + graph_id: 'subgraph-uuid', + graph_name: 'Character Subgraph', + title: 'Nested Loader', + type: 1, + bgcolor: '#654321', + }, + }, + }, + }; + + global.fetch = vi.fn().mockResolvedValue({ + json: async () => registryResponse, + }); + + document.body.innerHTML = '
'; + + const { sendLoraToWorkflow } = await import(UI_HELPERS_MODULE); + + const result = await sendLoraToWorkflow(''); + + expect(result).toBe(true); + expect(global.fetch).toHaveBeenCalledWith('/api/lm/get-registry'); + + const nodeLabels = Array.from( + document.querySelectorAll('#nodeSelector .node-item[data-node-id] span') + ).map((span) => span.textContent.trim()); + + expect(nodeLabels).toEqual([ + '#1 Root Loader', + '#2 (Character Subgraph) Nested Loader', + ]); + }); }); diff --git a/tests/routes/test_lora_routes.py b/tests/routes/test_lora_routes.py index 2b447987..b2e15f01 100644 --- a/tests/routes/test_lora_routes.py +++ b/tests/routes/test_lora_routes.py @@ -188,7 +188,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes): monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"])) - request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": ["node"]}) + request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": [{"node_id": "node", "graph_id": "graph-1"}]}) response = await routes.get_trigger_words(request) payload = json.loads(response.text) @@ -196,7 +196,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes): assert payload == {"success": True} send_mock.assert_called_once_with( "trigger_word_update", - {"id": "node", "message": "trigger-one"}, + {"id": "node", "graph_id": "graph-1", "message": "trigger-one"}, ) diff --git a/tests/routes/test_misc_routes.py b/tests/routes/test_misc_routes.py index 28ac816f..5ce9b822 100644 --- a/tests/routes/test_misc_routes.py +++ b/tests/routes/test_misc_routes.py @@ -4,7 +4,13 @@ from types import SimpleNamespace import pytest from aiohttp import web -from py.routes.handlers.misc_handlers import SettingsHandler, ServiceRegistryAdapter +from py.routes.handlers.misc_handlers import ( + LoraCodeHandler, + NodeRegistry, + NodeRegistryHandler, + ServiceRegistryAdapter, + SettingsHandler, +) from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar from py.routes.misc_routes import MiscRoutes @@ -126,6 +132,128 @@ class FakePromptServer: instance = Instance() +@pytest.mark.asyncio +async def test_register_nodes_requires_graph_id(): + node_registry = NodeRegistry() + handler = NodeRegistryHandler( + node_registry=node_registry, + prompt_server=FakePromptServer, + standalone_mode=False, + ) + + request = FakeRequest(json_data={"nodes": [{"node_id": 1}]}) + response = await handler.register_nodes(request) + payload = json.loads(response.text) + + assert response.status == 400 + assert payload["success"] is False + assert "graph_id" in payload["error"] + + +@pytest.mark.asyncio +async def test_register_nodes_stores_graph_identifier(): + node_registry = NodeRegistry() + handler = NodeRegistryHandler( + node_registry=node_registry, + prompt_server=FakePromptServer, + standalone_mode=False, + ) + + request = FakeRequest( + json_data={ + "nodes": [ + { + "node_id": 7, + "graph_id": "graph-123", + "graph_name": "Character Subgraph", + "type": "Lora Loader (LoraManager)", + "title": "Loader", + } + ] + } + ) + + response = await handler.register_nodes(request) + payload = json.loads(response.text) + + assert payload["success"] is True + + registry = await node_registry.get_registry() + assert registry["node_count"] == 1 + stored_node = next(iter(registry["nodes"].values())) + assert stored_node["graph_id"] == "graph-123" + assert stored_node["unique_id"] == "graph-123:7" + assert stored_node["graph_name"] == "Character Subgraph" + + +@pytest.mark.asyncio +async def test_register_nodes_defaults_graph_name_to_none(): + node_registry = NodeRegistry() + handler = NodeRegistryHandler( + node_registry=node_registry, + prompt_server=FakePromptServer, + standalone_mode=False, + ) + + request = FakeRequest( + json_data={ + "nodes": [ + { + "node_id": 8, + "graph_id": "root", + "type": "Lora Loader (LoraManager)", + "title": "Root Loader", + } + ] + } + ) + + response = await handler.register_nodes(request) + payload = json.loads(response.text) + + assert payload["success"] is True + + registry = await node_registry.get_registry() + stored_node = next(iter(registry["nodes"].values())) + assert stored_node["graph_name"] is None + + +@pytest.mark.asyncio +async def test_update_lora_code_includes_graph_identifier(): + send_calls: list[tuple[str, dict]] = [] + + class RecordingPromptServer: + class Instance: + def send_sync(self, event, payload): + send_calls.append((event, payload)) + + instance = Instance() + + handler = LoraCodeHandler(RecordingPromptServer) + + request = FakeRequest( + json_data={ + "node_ids": [{"node_id": 3, "graph_id": "graph-A"}], + "lora_code": "", + "mode": "replace", + } + ) + + response = await handler.update_lora_code(request) + payload = json.loads(response.text) + + assert payload["success"] is True + assert payload["results"] == [ + {"node_id": 3, "graph_id": "graph-A", "success": True} + ] + assert send_calls == [ + ( + "lora_code_update", + {"id": 3, "graph_id": "graph-A", "lora_code": "", "mode": "replace"}, + ) + ] + + class FakeScanner: async def check_model_version_exists(self, _version_id): return False diff --git a/web/comfyui/debug_metadata.js b/web/comfyui/debug_metadata.js index 1c86a177..bcbb893f 100644 --- a/web/comfyui/debug_metadata.js +++ b/web/comfyui/debug_metadata.js @@ -1,6 +1,7 @@ import { app } from "../../scripts/app.js"; import { api } from "../../scripts/api.js"; import { addJsonDisplayWidget } from "./json_display_widget.js"; +import { getNodeFromGraph } from "./utils.js"; app.registerExtension({ name: "LoraManager.DebugMetadata", @@ -8,8 +9,8 @@ app.registerExtension({ setup() { // Add message handler to listen for metadata updates from Python api.addEventListener("metadata_update", (event) => { - const { id, metadata } = event.detail; - this.handleMetadataUpdate(id, metadata); + const { id, graph_id: graphId, metadata } = event.detail; + this.handleMetadataUpdate(id, graphId, metadata); }); }, @@ -37,8 +38,8 @@ app.registerExtension({ }, // Handle metadata updates from Python - handleMetadataUpdate(id, metadata) { - const node = app.graph.getNodeById(+id); + handleMetadataUpdate(id, graphId, metadata) { + const node = getNodeFromGraph(graphId, id); if (!node || node.comfyClass !== "Debug Metadata (LoraManager)") { console.warn("Node not found or not a DebugMetadata node:", id); return; diff --git a/web/comfyui/lora_loader.js b/web/comfyui/lora_loader.js index f13f48a6..4e672be6 100644 --- a/web/comfyui/lora_loader.js +++ b/web/comfyui/lora_loader.js @@ -7,6 +7,8 @@ import { chainCallback, mergeLoras, setupInputWidgetWithAutocomplete, + getAllGraphNodes, + getNodeFromGraph, } from "./utils.js"; import { addLorasWidget } from "./loras_widget.js"; @@ -16,23 +18,26 @@ app.registerExtension({ setup() { // Add message handler to listen for messages from Python api.addEventListener("lora_code_update", (event) => { - const { id, lora_code, mode } = event.detail; - this.handleLoraCodeUpdate(id, lora_code, mode); + this.handleLoraCodeUpdate(event.detail || {}); }); }, // Handle lora code updates from Python - handleLoraCodeUpdate(id, loraCode, mode) { + handleLoraCodeUpdate(message) { + const nodeId = message?.node_id ?? message?.id; + const graphId = message?.graph_id; + const loraCode = message?.lora_code ?? ""; + const mode = message?.mode ?? "append"; + + const numericNodeId = + typeof nodeId === "string" ? Number(nodeId) : nodeId; + // Handle broadcast mode (for Desktop/non-browser support) - if (id === -1) { + if (numericNodeId === -1) { // Find all Lora Loader nodes in the current graph - const loraLoaderNodes = []; - for (const nodeId in app.graph._nodes_by_id) { - const node = app.graph._nodes_by_id[nodeId]; - if (node.comfyClass === "Lora Loader (LoraManager)") { - loraLoaderNodes.push(node); - } - } + const loraLoaderNodes = getAllGraphNodes(app.graph) + .map(({ node }) => node) + .filter((node) => node?.comfyClass === "Lora Loader (LoraManager)"); // Update each Lora Loader node found if (loraLoaderNodes.length > 0) { @@ -52,14 +57,18 @@ app.registerExtension({ } // Standard mode - update a specific node - const node = app.graph.getNodeById(+id); + const node = getNodeFromGraph(graphId, numericNodeId); if ( !node || (node.comfyClass !== "Lora Loader (LoraManager)" && node.comfyClass !== "Lora Stacker (LoraManager)" && node.comfyClass !== "WanVideo Lora Select (LoraManager)") ) { - console.warn("Node not found or not a LoraLoader:", id); + console.warn( + "Node not found or not a LoraLoader:", + graphId ?? "root", + nodeId + ); return; } diff --git a/web/comfyui/lora_stacker.js b/web/comfyui/lora_stacker.js index 5648891d..af07ee7c 100644 --- a/web/comfyui/lora_stacker.js +++ b/web/comfyui/lora_stacker.js @@ -7,6 +7,8 @@ import { chainCallback, mergeLoras, setupInputWidgetWithAutocomplete, + getLinkFromGraph, + getNodeKey, } from "./utils.js"; import { addLorasWidget } from "./loras_widget.js"; @@ -124,17 +126,18 @@ app.registerExtension({ // Helper function to find and update downstream Lora Loader nodes function updateDownstreamLoaders(startNode, visited = new Set()) { - if (visited.has(startNode.id)) return; - visited.add(startNode.id); + const nodeKey = getNodeKey(startNode); + if (!nodeKey || visited.has(nodeKey)) return; + visited.add(nodeKey); // Check each output link if (startNode.outputs) { for (const output of startNode.outputs) { if (output.links) { for (const linkId of output.links) { - const link = app.graph.links[linkId]; + const link = getLinkFromGraph(startNode.graph, linkId); if (link) { - const targetNode = app.graph.getNodeById(link.target_id); + const targetNode = startNode.graph?.getNodeById?.(link.target_id); // If target is a Lora Loader, collect all active loras in the chain and update if ( diff --git a/web/comfyui/trigger_word_toggle.js b/web/comfyui/trigger_word_toggle.js index 57541bd4..47b7a0a0 100644 --- a/web/comfyui/trigger_word_toggle.js +++ b/web/comfyui/trigger_word_toggle.js @@ -1,6 +1,6 @@ import { app } from "../../scripts/app.js"; import { api } from "../../scripts/api.js"; -import { CONVERTED_TYPE } from "./utils.js"; +import { CONVERTED_TYPE, getNodeFromGraph } from "./utils.js"; import { addTagsWidget } from "./tags_widget.js"; // TriggerWordToggle extension for ComfyUI @@ -10,8 +10,8 @@ app.registerExtension({ setup() { // Add message handler to listen for messages from Python api.addEventListener("trigger_word_update", (event) => { - const { id, message } = event.detail; - this.handleTriggerWordUpdate(id, message); + const { id, graph_id: graphId, message } = event.detail; + this.handleTriggerWordUpdate(id, graphId, message); }); }, @@ -76,8 +76,8 @@ app.registerExtension({ }, // Handle trigger word updates from Python - handleTriggerWordUpdate(id, message) { - const node = app.graph.getNodeById(+id); + handleTriggerWordUpdate(id, graphId, message) { + const node = getNodeFromGraph(graphId, id); if (!node || node.comfyClass !== "TriggerWord Toggle (LoraManager)") { console.warn("Node not found or not a TriggerWordToggle:", id); return; diff --git a/web/comfyui/usage_stats.js b/web/comfyui/usage_stats.js index b89eaf46..d55624c2 100644 --- a/web/comfyui/usage_stats.js +++ b/web/comfyui/usage_stats.js @@ -1,7 +1,7 @@ // ComfyUI extension to track model usage statistics import { app } from "../../scripts/app.js"; import { api } from "../../scripts/api.js"; -import { showToast } from "./utils.js"; +import { getAllGraphNodes, getNodeReference, showToast } from "./utils.js"; // Define target nodes and their widget configurations const PATH_CORRECTION_TARGETS = [ @@ -56,25 +56,35 @@ app.registerExtension({ async refreshRegistry() { try { - // Get current workflow nodes - const prompt = await app.graphToPrompt(); - const workflow = prompt.workflow; - if (!workflow || !workflow.nodes) { - console.warn("No workflow nodes found for registry refresh"); - return; - } - - // Find all Lora nodes const loraNodes = []; - for (const node of workflow.nodes.values()) { - if (node.type === "Lora Loader (LoraManager)" || - node.type === "Lora Stacker (LoraManager)" || - node.type === "WanVideo Lora Select (LoraManager)") { + const nodeEntries = getAllGraphNodes(app.graph); + + for (const { graph, node } of nodeEntries) { + if (!node || !node.comfyClass) { + continue; + } + + if ( + node.comfyClass === "Lora Loader (LoraManager)" || + node.comfyClass === "Lora Stacker (LoraManager)" || + node.comfyClass === "WanVideo Lora Select (LoraManager)" + ) { + const reference = getNodeReference(node); + if (!reference) { + continue; + } + + const graphName = typeof graph?.name === "string" && graph.name.trim() + ? graph.name + : null; + loraNodes.push({ - node_id: node.id, - bgcolor: node.bgcolor || null, - title: node.title || node.type, - type: node.type + node_id: reference.node_id, + graph_id: reference.graph_id, + graph_name: graphName, + bgcolor: node.bgcolor ?? node.color ?? null, + title: node.title || node.comfyClass, + type: node.comfyClass, }); } } diff --git a/web/comfyui/utils.js b/web/comfyui/utils.js index 8060414d..54f41f24 100644 --- a/web/comfyui/utils.js +++ b/web/comfyui/utils.js @@ -2,6 +2,120 @@ export const CONVERTED_TYPE = 'converted-widget'; import { app } from "../../scripts/app.js"; import { AutoComplete } from "./autocomplete.js"; +const ROOT_GRAPH_ID = "root"; + +function isMapLike(collection) { + return collection && typeof collection.entries === "function" && typeof collection.values === "function"; +} + +function getChildGraphs(graph) { + if (!graph || !graph._subgraphs) { + return []; + } + + const rawSubgraphs = isMapLike(graph._subgraphs) + ? Array.from(graph._subgraphs.values()) + : Object.values(graph._subgraphs); + + return rawSubgraphs + .map((subgraph) => subgraph?.graph || subgraph?._graph || subgraph) + .filter((subgraph) => subgraph && subgraph !== graph); +} + +function traverseGraphs(rootGraph, visitor, visited = new Set()) { + const graph = rootGraph || app.graph; + if (!graph) { + return; + } + + const graphId = getGraphId(graph); + if (visited.has(graphId)) { + return; + } + visited.add(graphId); + visitor(graph); + + for (const subgraph of getChildGraphs(graph)) { + traverseGraphs(subgraph, visitor, visited); + } +} + +export function getGraphId(graph) { + return graph?.id ?? ROOT_GRAPH_ID; +} + +export function getNodeGraphId(node) { + if (!node) { + return ROOT_GRAPH_ID; + } + return getGraphId(node.graph || app.graph); +} + +export function getGraphById(graphId, rootGraph = app.graph) { + if (!graphId) { + return rootGraph; + } + + let foundGraph = null; + traverseGraphs(rootGraph, (graph) => { + if (!foundGraph && getGraphId(graph) === graphId) { + foundGraph = graph; + } + }); + return foundGraph; +} + +export function getNodeFromGraph(graphId, nodeId) { + const graph = getGraphById(graphId) || app.graph; + if (!graph || typeof graph.getNodeById !== "function") { + return null; + } + + const numericId = typeof nodeId === "string" ? Number(nodeId) : nodeId; + return graph.getNodeById(Number.isNaN(numericId) ? nodeId : numericId) || null; +} + +export function getAllGraphNodes(rootGraph = app.graph) { + const nodes = []; + traverseGraphs(rootGraph, (graph) => { + if (Array.isArray(graph._nodes)) { + for (const node of graph._nodes) { + nodes.push({ graph, node }); + } + } + }); + return nodes; +} + +export function getNodeReference(node) { + if (!node) { + return null; + } + return { + node_id: node.id, + graph_id: getNodeGraphId(node), + }; +} + +export function getNodeKey(node) { + if (!node) { + return null; + } + return `${getNodeGraphId(node)}:${node.id}`; +} + +export function getLinkFromGraph(graph, linkId) { + if (!graph || graph.links == null) { + return null; + } + + if (isMapLike(graph.links)) { + return graph.links.get(linkId) || null; + } + + return graph.links[linkId] || null; +} + export function chainCallback(object, property, callback) { if (object == undefined) { //This should not happen. @@ -103,42 +217,56 @@ export const LORA_PATTERN = //g; // Get connected Lora Stacker nodes that feed into the current node export function getConnectedInputStackers(node) { const connectedStackers = []; - - if (node.inputs) { - for (const input of node.inputs) { - if (input.name === "lora_stack" && input.link) { - const link = app.graph.links[input.link]; - if (link) { - const sourceNode = app.graph.getNodeById(link.origin_id); - if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") { - connectedStackers.push(sourceNode); - } - } - } + + if (!node?.inputs) { + return connectedStackers; + } + + for (const input of node.inputs) { + if (input.name !== "lora_stack" || !input.link) { + continue; + } + + const link = getLinkFromGraph(node.graph, input.link); + if (!link) { + continue; + } + + const sourceNode = node.graph?.getNodeById?.(link.origin_id); + if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") { + connectedStackers.push(sourceNode); } } + return connectedStackers; } // Get connected TriggerWord Toggle nodes that receive output from the current node export function getConnectedTriggerToggleNodes(node) { const connectedNodes = []; - - if (node.outputs && node.outputs.length > 0) { - for (const output of node.outputs) { - if (output.links && output.links.length > 0) { - for (const linkId of output.links) { - const link = app.graph.links[linkId]; - if (link) { - const targetNode = app.graph.getNodeById(link.target_id); - if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") { - connectedNodes.push(targetNode.id); - } - } - } + + if (!node?.outputs) { + return connectedNodes; + } + + for (const output of node.outputs) { + if (!output?.links?.length) { + continue; + } + + for (const linkId of output.links) { + const link = getLinkFromGraph(node.graph, linkId); + if (!link) { + continue; + } + + const targetNode = node.graph?.getNodeById?.(link.target_id); + if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") { + connectedNodes.push(targetNode); } } } + return connectedNodes; } @@ -161,11 +289,15 @@ export function getActiveLorasFromNode(node) { // Recursively collect all active loras from a node and its input chain export function collectActiveLorasFromChain(node, visited = new Set()) { // Prevent infinite loops from circular references - if (visited.has(node.id)) { + const nodeKey = getNodeKey(node); + if (!nodeKey) { return new Set(); } - visited.add(node.id); - + if (visited.has(nodeKey)) { + return new Set(); + } + visited.add(nodeKey); + // Get active loras from current node const allActiveLoraNames = getActiveLorasFromNode(node); @@ -181,14 +313,22 @@ export function collectActiveLorasFromChain(node, visited = new Set()) { // Update trigger words for connected toggle nodes export function updateConnectedTriggerWords(node, loraNames) { - const connectedNodeIds = getConnectedTriggerToggleNodes(node); - if (connectedNodeIds.length > 0) { + const connectedNodes = getConnectedTriggerToggleNodes(node); + if (connectedNodes.length > 0) { + const nodeIds = connectedNodes + .map((connectedNode) => getNodeReference(connectedNode)) + .filter((reference) => reference !== null); + + if (nodeIds.length === 0) { + return; + } + fetch("/api/lm/loras/get_trigger_words", { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ lora_names: Array.from(loraNames), - node_ids: connectedNodeIds + node_ids: nodeIds }) }).catch(err => console.error("Error fetching trigger words:", err)); }