feat(graph): enhance node handling with graph identifiers and improve metadata updates, see #408, #538

This commit is contained in:
Will Miao
2025-10-07 23:22:38 +08:00
parent 9199950b74
commit 3118f3b43c
12 changed files with 574 additions and 103 deletions

View File

@@ -2,6 +2,120 @@ export const CONVERTED_TYPE = 'converted-widget';
import { app } from "../../scripts/app.js";
import { AutoComplete } from "./autocomplete.js";
const ROOT_GRAPH_ID = "root";
function isMapLike(collection) {
return collection && typeof collection.entries === "function" && typeof collection.values === "function";
}
function getChildGraphs(graph) {
if (!graph || !graph._subgraphs) {
return [];
}
const rawSubgraphs = isMapLike(graph._subgraphs)
? Array.from(graph._subgraphs.values())
: Object.values(graph._subgraphs);
return rawSubgraphs
.map((subgraph) => subgraph?.graph || subgraph?._graph || subgraph)
.filter((subgraph) => subgraph && subgraph !== graph);
}
function traverseGraphs(rootGraph, visitor, visited = new Set()) {
const graph = rootGraph || app.graph;
if (!graph) {
return;
}
const graphId = getGraphId(graph);
if (visited.has(graphId)) {
return;
}
visited.add(graphId);
visitor(graph);
for (const subgraph of getChildGraphs(graph)) {
traverseGraphs(subgraph, visitor, visited);
}
}
export function getGraphId(graph) {
return graph?.id ?? ROOT_GRAPH_ID;
}
export function getNodeGraphId(node) {
if (!node) {
return ROOT_GRAPH_ID;
}
return getGraphId(node.graph || app.graph);
}
export function getGraphById(graphId, rootGraph = app.graph) {
if (!graphId) {
return rootGraph;
}
let foundGraph = null;
traverseGraphs(rootGraph, (graph) => {
if (!foundGraph && getGraphId(graph) === graphId) {
foundGraph = graph;
}
});
return foundGraph;
}
export function getNodeFromGraph(graphId, nodeId) {
const graph = getGraphById(graphId) || app.graph;
if (!graph || typeof graph.getNodeById !== "function") {
return null;
}
const numericId = typeof nodeId === "string" ? Number(nodeId) : nodeId;
return graph.getNodeById(Number.isNaN(numericId) ? nodeId : numericId) || null;
}
export function getAllGraphNodes(rootGraph = app.graph) {
const nodes = [];
traverseGraphs(rootGraph, (graph) => {
if (Array.isArray(graph._nodes)) {
for (const node of graph._nodes) {
nodes.push({ graph, node });
}
}
});
return nodes;
}
export function getNodeReference(node) {
if (!node) {
return null;
}
return {
node_id: node.id,
graph_id: getNodeGraphId(node),
};
}
export function getNodeKey(node) {
if (!node) {
return null;
}
return `${getNodeGraphId(node)}:${node.id}`;
}
export function getLinkFromGraph(graph, linkId) {
if (!graph || graph.links == null) {
return null;
}
if (isMapLike(graph.links)) {
return graph.links.get(linkId) || null;
}
return graph.links[linkId] || null;
}
export function chainCallback(object, property, callback) {
if (object == undefined) {
//This should not happen.
@@ -103,42 +217,56 @@ export const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)(?::([-\d\.]+))?>/g;
// Get connected Lora Stacker nodes that feed into the current node
export function getConnectedInputStackers(node) {
const connectedStackers = [];
if (node.inputs) {
for (const input of node.inputs) {
if (input.name === "lora_stack" && input.link) {
const link = app.graph.links[input.link];
if (link) {
const sourceNode = app.graph.getNodeById(link.origin_id);
if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") {
connectedStackers.push(sourceNode);
}
}
}
if (!node?.inputs) {
return connectedStackers;
}
for (const input of node.inputs) {
if (input.name !== "lora_stack" || !input.link) {
continue;
}
const link = getLinkFromGraph(node.graph, input.link);
if (!link) {
continue;
}
const sourceNode = node.graph?.getNodeById?.(link.origin_id);
if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") {
connectedStackers.push(sourceNode);
}
}
return connectedStackers;
}
// Get connected TriggerWord Toggle nodes that receive output from the current node
export function getConnectedTriggerToggleNodes(node) {
const connectedNodes = [];
if (node.outputs && node.outputs.length > 0) {
for (const output of node.outputs) {
if (output.links && output.links.length > 0) {
for (const linkId of output.links) {
const link = app.graph.links[linkId];
if (link) {
const targetNode = app.graph.getNodeById(link.target_id);
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
connectedNodes.push(targetNode.id);
}
}
}
if (!node?.outputs) {
return connectedNodes;
}
for (const output of node.outputs) {
if (!output?.links?.length) {
continue;
}
for (const linkId of output.links) {
const link = getLinkFromGraph(node.graph, linkId);
if (!link) {
continue;
}
const targetNode = node.graph?.getNodeById?.(link.target_id);
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
connectedNodes.push(targetNode);
}
}
}
return connectedNodes;
}
@@ -161,11 +289,15 @@ export function getActiveLorasFromNode(node) {
// Recursively collect all active loras from a node and its input chain
export function collectActiveLorasFromChain(node, visited = new Set()) {
// Prevent infinite loops from circular references
if (visited.has(node.id)) {
const nodeKey = getNodeKey(node);
if (!nodeKey) {
return new Set();
}
visited.add(node.id);
if (visited.has(nodeKey)) {
return new Set();
}
visited.add(nodeKey);
// Get active loras from current node
const allActiveLoraNames = getActiveLorasFromNode(node);
@@ -181,14 +313,22 @@ export function collectActiveLorasFromChain(node, visited = new Set()) {
// Update trigger words for connected toggle nodes
export function updateConnectedTriggerWords(node, loraNames) {
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
if (connectedNodeIds.length > 0) {
const connectedNodes = getConnectedTriggerToggleNodes(node);
if (connectedNodes.length > 0) {
const nodeIds = connectedNodes
.map((connectedNode) => getNodeReference(connectedNode))
.filter((reference) => reference !== null);
if (nodeIds.length === 0) {
return;
}
fetch("/api/lm/loras/get_trigger_words", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
lora_names: Array.from(loraNames),
node_ids: connectedNodeIds
node_ids: nodeIds
})
}).catch(err => console.error("Error fetching trigger words:", err));
}