mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 23:25:43 -03:00
feat(graph): enhance node handling with graph identifiers and improve metadata updates, see #408, #538
This commit is contained in:
@@ -80,7 +80,7 @@ class NodeRegistry:
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._nodes: Dict[int, dict] = {}
|
self._nodes: Dict[str, dict] = {}
|
||||||
self._registry_updated = asyncio.Event()
|
self._registry_updated = asyncio.Event()
|
||||||
|
|
||||||
async def register_nodes(self, nodes: list[dict]) -> None:
|
async def register_nodes(self, nodes: list[dict]) -> None:
|
||||||
@@ -88,11 +88,16 @@ class NodeRegistry:
|
|||||||
self._nodes.clear()
|
self._nodes.clear()
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
node_id = node["node_id"]
|
node_id = node["node_id"]
|
||||||
|
graph_id = str(node["graph_id"])
|
||||||
|
unique_id = f"{graph_id}:{node_id}"
|
||||||
node_type = node.get("type", "")
|
node_type = node.get("type", "")
|
||||||
type_id = NODE_TYPES.get(node_type, 0)
|
type_id = NODE_TYPES.get(node_type, 0)
|
||||||
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
|
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
|
||||||
self._nodes[node_id] = {
|
self._nodes[unique_id] = {
|
||||||
"id": node_id,
|
"id": node_id,
|
||||||
|
"graph_id": graph_id,
|
||||||
|
"graph_name": node.get("graph_name"),
|
||||||
|
"unique_id": unique_id,
|
||||||
"bgcolor": bgcolor,
|
"bgcolor": bgcolor,
|
||||||
"title": node.get("title"),
|
"title": node.get("title"),
|
||||||
"type": type_id,
|
"type": type_id,
|
||||||
@@ -330,16 +335,65 @@ class LoraCodeHandler:
|
|||||||
logger.error("Error broadcasting lora code: %s", exc)
|
logger.error("Error broadcasting lora code: %s", exc)
|
||||||
results.append({"node_id": "broadcast", "success": False, "error": str(exc)})
|
results.append({"node_id": "broadcast", "success": False, "error": str(exc)})
|
||||||
else:
|
else:
|
||||||
for node_id in node_ids:
|
for entry in node_ids:
|
||||||
|
node_identifier = entry
|
||||||
|
graph_identifier = None
|
||||||
|
if isinstance(entry, dict):
|
||||||
|
node_identifier = entry.get("node_id")
|
||||||
|
graph_identifier = entry.get("graph_id")
|
||||||
|
|
||||||
|
if node_identifier is None:
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"node_id": node_identifier,
|
||||||
|
"graph_id": graph_identifier,
|
||||||
|
"success": False,
|
||||||
|
"error": "Missing node_id parameter",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_node_id = int(node_identifier)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
parsed_node_id = node_identifier
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"id": parsed_node_id,
|
||||||
|
"lora_code": lora_code,
|
||||||
|
"mode": mode,
|
||||||
|
}
|
||||||
|
|
||||||
|
if graph_identifier is not None:
|
||||||
|
payload["graph_id"] = str(graph_identifier)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._prompt_server.instance.send_sync(
|
self._prompt_server.instance.send_sync(
|
||||||
"lora_code_update",
|
"lora_code_update",
|
||||||
{"id": node_id, "lora_code": lora_code, "mode": mode},
|
payload,
|
||||||
|
)
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"node_id": parsed_node_id,
|
||||||
|
"graph_id": payload.get("graph_id"),
|
||||||
|
"success": True,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
results.append({"node_id": node_id, "success": True})
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
logger.error("Error sending lora code to node %s: %s", node_id, exc)
|
logger.error(
|
||||||
results.append({"node_id": node_id, "success": False, "error": str(exc)})
|
"Error sending lora code to node %s (graph %s): %s",
|
||||||
|
parsed_node_id,
|
||||||
|
graph_identifier,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"node_id": parsed_node_id,
|
||||||
|
"graph_id": payload.get("graph_id"),
|
||||||
|
"success": False,
|
||||||
|
"error": str(exc),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return web.json_response({"success": True, "results": results})
|
return web.json_response({"success": True, "results": results})
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
@@ -679,10 +733,21 @@ class NodeRegistryHandler:
|
|||||||
node_id = node.get("node_id")
|
node_id = node.get("node_id")
|
||||||
if node_id is None:
|
if node_id is None:
|
||||||
return web.json_response({"success": False, "error": f"Node {index} missing node_id parameter"}, status=400)
|
return web.json_response({"success": False, "error": f"Node {index} missing node_id parameter"}, status=400)
|
||||||
|
graph_id = node.get("graph_id")
|
||||||
|
if graph_id is None:
|
||||||
|
return web.json_response({"success": False, "error": f"Node {index} missing graph_id parameter"}, status=400)
|
||||||
|
graph_name = node.get("graph_name")
|
||||||
try:
|
try:
|
||||||
node["node_id"] = int(node_id)
|
node["node_id"] = int(node_id)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return web.json_response({"success": False, "error": f"Node {index} node_id must be an integer"}, status=400)
|
return web.json_response({"success": False, "error": f"Node {index} node_id must be an integer"}, status=400)
|
||||||
|
node["graph_id"] = str(graph_id)
|
||||||
|
if graph_name is None:
|
||||||
|
node["graph_name"] = None
|
||||||
|
elif isinstance(graph_name, str):
|
||||||
|
node["graph_name"] = graph_name
|
||||||
|
else:
|
||||||
|
node["graph_name"] = str(graph_name)
|
||||||
|
|
||||||
await self._node_registry.register_nodes(nodes)
|
await self._node_registry.register_nodes(nodes)
|
||||||
return web.json_response({"success": True, "message": f"{len(nodes)} nodes registered successfully"})
|
return web.json_response({"success": True, "message": f"{len(nodes)} nodes registered successfully"})
|
||||||
|
|||||||
@@ -229,11 +229,27 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||||
|
|
||||||
# Send update to all connected trigger word toggle nodes
|
# Send update to all connected trigger word toggle nodes
|
||||||
for node_id in node_ids:
|
for entry in node_ids:
|
||||||
PromptServer.instance.send_sync("trigger_word_update", {
|
node_identifier = entry
|
||||||
"id": node_id,
|
graph_identifier = None
|
||||||
|
if isinstance(entry, dict):
|
||||||
|
node_identifier = entry.get("node_id")
|
||||||
|
graph_identifier = entry.get("graph_id")
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_node_id = int(node_identifier)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
parsed_node_id = node_identifier
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"id": parsed_node_id,
|
||||||
"message": trigger_words_text
|
"message": trigger_words_text
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if graph_identifier is not None:
|
||||||
|
payload["graph_id"] = str(graph_identifier)
|
||||||
|
|
||||||
|
PromptServer.instance.send_sync("trigger_word_update", payload)
|
||||||
|
|
||||||
return web.json_response({"success": True})
|
return web.json_response({"success": True})
|
||||||
|
|
||||||
|
|||||||
@@ -435,8 +435,9 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax
|
|||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
// Single node - send directly
|
// Single node - send directly
|
||||||
const nodeId = Object.keys(registryData.data.nodes)[0];
|
const nodes = registryData.data.nodes;
|
||||||
return await sendToSpecificNode([nodeId], loraSyntax, replaceMode, syntaxType);
|
const nodeId = Object.keys(nodes)[0];
|
||||||
|
return await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to get registry:', error);
|
console.error('Failed to get registry:', error);
|
||||||
@@ -452,19 +453,65 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax
|
|||||||
* @param {boolean} replaceMode - Whether to replace existing LoRAs
|
* @param {boolean} replaceMode - Whether to replace existing LoRAs
|
||||||
* @param {string} syntaxType - The type of syntax ('lora' or 'recipe')
|
* @param {string} syntaxType - The type of syntax ('lora' or 'recipe')
|
||||||
*/
|
*/
|
||||||
async function sendToSpecificNode(nodeIds, loraSyntax, replaceMode, syntaxType) {
|
function resolveNodeReference(nodeKey, nodesMap) {
|
||||||
|
if (!nodeKey) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const directMatch = nodesMap?.[nodeKey];
|
||||||
|
if (directMatch) {
|
||||||
|
return {
|
||||||
|
node_id: directMatch.id,
|
||||||
|
graph_id: directMatch.graph_id ?? null,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof nodeKey === 'string' && nodeKey.includes(':')) {
|
||||||
|
const [graphId, ...rest] = nodeKey.split(':');
|
||||||
|
const nodeIdPart = rest.join(':');
|
||||||
|
const numericNodeId = Number(nodeIdPart);
|
||||||
|
return {
|
||||||
|
node_id: Number.isNaN(numericNodeId) ? nodeIdPart : numericNodeId,
|
||||||
|
graph_id: graphId || null,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const numericId = Number(nodeKey);
|
||||||
|
return {
|
||||||
|
node_id: Number.isNaN(numericId) ? nodeKey : numericId,
|
||||||
|
graph_id: null,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async function sendToSpecificNode(nodeIds, nodesMap, loraSyntax, replaceMode, syntaxType) {
|
||||||
try {
|
try {
|
||||||
// Call the backend API to update the lora code
|
// Call the backend API to update the lora code
|
||||||
|
const requestBody = {
|
||||||
|
lora_code: loraSyntax,
|
||||||
|
mode: replaceMode ? 'replace' : 'append'
|
||||||
|
};
|
||||||
|
|
||||||
|
if (Array.isArray(nodeIds)) {
|
||||||
|
const references = nodeIds
|
||||||
|
.map((nodeKey) => resolveNodeReference(nodeKey, nodesMap))
|
||||||
|
.filter((reference) => reference && reference.node_id !== undefined);
|
||||||
|
|
||||||
|
if (references.length > 0) {
|
||||||
|
requestBody.node_ids = references;
|
||||||
|
}
|
||||||
|
} else if (nodeIds) {
|
||||||
|
const reference = resolveNodeReference(nodeIds, nodesMap);
|
||||||
|
if (reference) {
|
||||||
|
requestBody.node_ids = [reference];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const response = await fetch('/api/lm/update-lora-code', {
|
const response = await fetch('/api/lm/update-lora-code', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify(requestBody)
|
||||||
node_ids: nodeIds,
|
|
||||||
lora_code: loraSyntax,
|
|
||||||
mode: replaceMode ? 'replace' : 'append'
|
|
||||||
})
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await response.json();
|
const result = await response.json();
|
||||||
@@ -522,16 +569,17 @@ function showNodeSelector(nodes, loraSyntax, replaceMode, syntaxType) {
|
|||||||
hideNodeSelector();
|
hideNodeSelector();
|
||||||
|
|
||||||
// Generate node list HTML with icons and proper colors
|
// Generate node list HTML with icons and proper colors
|
||||||
const nodeItems = Object.values(nodes).map(node => {
|
const nodeItems = Object.entries(nodes).map(([nodeKey, node]) => {
|
||||||
const iconClass = NODE_TYPE_ICONS[node.type] || 'fas fa-question-circle';
|
const iconClass = NODE_TYPE_ICONS[node.type] || 'fas fa-question-circle';
|
||||||
const bgColor = node.bgcolor || DEFAULT_NODE_COLOR;
|
const bgColor = node.bgcolor || DEFAULT_NODE_COLOR;
|
||||||
|
const graphLabel = node.graph_name ? ` (${node.graph_name})` : '';
|
||||||
|
|
||||||
return `
|
return `
|
||||||
<div class="node-item" data-node-id="${node.id}">
|
<div class="node-item" data-node-id="${nodeKey}">
|
||||||
<div class="node-icon-indicator" style="background-color: ${bgColor}">
|
<div class="node-icon-indicator" style="background-color: ${bgColor}">
|
||||||
<i class="${iconClass}"></i>
|
<i class="${iconClass}"></i>
|
||||||
</div>
|
</div>
|
||||||
<span>#${node.id} ${node.title}</span>
|
<span>#${node.id}${graphLabel} ${node.title}</span>
|
||||||
</div>
|
</div>
|
||||||
`;
|
`;
|
||||||
}).join('');
|
}).join('');
|
||||||
@@ -610,10 +658,10 @@ function setupNodeSelectorEvents(selector, nodes, loraSyntax, replaceMode, synta
|
|||||||
if (action === 'send-all') {
|
if (action === 'send-all') {
|
||||||
// Send to all nodes
|
// Send to all nodes
|
||||||
const allNodeIds = Object.keys(nodes);
|
const allNodeIds = Object.keys(nodes);
|
||||||
await sendToSpecificNode(allNodeIds, loraSyntax, replaceMode, syntaxType);
|
await sendToSpecificNode(allNodeIds, nodes, loraSyntax, replaceMode, syntaxType);
|
||||||
} else if (nodeId) {
|
} else if (nodeId) {
|
||||||
// Send to specific node
|
// Send to specific node
|
||||||
await sendToSpecificNode([nodeId], loraSyntax, replaceMode, syntaxType);
|
await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType);
|
||||||
}
|
}
|
||||||
|
|
||||||
hideNodeSelector();
|
hideNodeSelector();
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ vi.mock(EVENT_MANAGER_MODULE, () => ({
|
|||||||
off: vi.fn(),
|
off: vi.fn(),
|
||||||
addHandler: vi.fn(),
|
addHandler: vi.fn(),
|
||||||
removeHandler: vi.fn(),
|
removeHandler: vi.fn(),
|
||||||
|
setState: vi.fn(),
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|
||||||
@@ -62,6 +63,7 @@ describe('UI helper DOM utilities', () => {
|
|||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
vi.useRealTimers();
|
vi.useRealTimers();
|
||||||
|
delete global.fetch;
|
||||||
});
|
});
|
||||||
|
|
||||||
it('creates toast elements and cleans them up after timeout', async () => {
|
it('creates toast elements and cleans them up after timeout', async () => {
|
||||||
@@ -105,4 +107,53 @@ describe('UI helper DOM utilities', () => {
|
|||||||
expect(document.body.dataset.theme).toBe('dark');
|
expect(document.body.dataset.theme).toBe('dark');
|
||||||
expect(document.querySelector('.theme-toggle').classList.contains('theme-dark')).toBe(true);
|
expect(document.querySelector('.theme-toggle').classList.contains('theme-dark')).toBe(true);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('renders subgraph names in the node selector list', async () => {
|
||||||
|
const registryResponse = {
|
||||||
|
success: true,
|
||||||
|
data: {
|
||||||
|
node_count: 2,
|
||||||
|
nodes: {
|
||||||
|
'root:1': {
|
||||||
|
id: 1,
|
||||||
|
graph_id: 'root',
|
||||||
|
graph_name: null,
|
||||||
|
title: 'Root Loader',
|
||||||
|
type: 1,
|
||||||
|
bgcolor: '#123456',
|
||||||
|
},
|
||||||
|
'subgraph-uuid:2': {
|
||||||
|
id: 2,
|
||||||
|
graph_id: 'subgraph-uuid',
|
||||||
|
graph_name: 'Character Subgraph',
|
||||||
|
title: 'Nested Loader',
|
||||||
|
type: 1,
|
||||||
|
bgcolor: '#654321',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
global.fetch = vi.fn().mockResolvedValue({
|
||||||
|
json: async () => registryResponse,
|
||||||
|
});
|
||||||
|
|
||||||
|
document.body.innerHTML = '<div id="nodeSelector"></div>';
|
||||||
|
|
||||||
|
const { sendLoraToWorkflow } = await import(UI_HELPERS_MODULE);
|
||||||
|
|
||||||
|
const result = await sendLoraToWorkflow('<lora:test:1>');
|
||||||
|
|
||||||
|
expect(result).toBe(true);
|
||||||
|
expect(global.fetch).toHaveBeenCalledWith('/api/lm/get-registry');
|
||||||
|
|
||||||
|
const nodeLabels = Array.from(
|
||||||
|
document.querySelectorAll('#nodeSelector .node-item[data-node-id] span')
|
||||||
|
).map((span) => span.textContent.trim());
|
||||||
|
|
||||||
|
expect(nodeLabels).toEqual([
|
||||||
|
'#1 Root Loader',
|
||||||
|
'#2 (Character Subgraph) Nested Loader',
|
||||||
|
]);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
|
|||||||
|
|
||||||
monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"]))
|
monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"]))
|
||||||
|
|
||||||
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": ["node"]})
|
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": [{"node_id": "node", "graph_id": "graph-1"}]})
|
||||||
|
|
||||||
response = await routes.get_trigger_words(request)
|
response = await routes.get_trigger_words(request)
|
||||||
payload = json.loads(response.text)
|
payload = json.loads(response.text)
|
||||||
@@ -196,7 +196,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
|
|||||||
assert payload == {"success": True}
|
assert payload == {"success": True}
|
||||||
send_mock.assert_called_once_with(
|
send_mock.assert_called_once_with(
|
||||||
"trigger_word_update",
|
"trigger_word_update",
|
||||||
{"id": "node", "message": "trigger-one"},
|
{"id": "node", "graph_id": "graph-1", "message": "trigger-one"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,13 @@ from types import SimpleNamespace
|
|||||||
import pytest
|
import pytest
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from py.routes.handlers.misc_handlers import SettingsHandler, ServiceRegistryAdapter
|
from py.routes.handlers.misc_handlers import (
|
||||||
|
LoraCodeHandler,
|
||||||
|
NodeRegistry,
|
||||||
|
NodeRegistryHandler,
|
||||||
|
ServiceRegistryAdapter,
|
||||||
|
SettingsHandler,
|
||||||
|
)
|
||||||
from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar
|
from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar
|
||||||
from py.routes.misc_routes import MiscRoutes
|
from py.routes.misc_routes import MiscRoutes
|
||||||
|
|
||||||
@@ -126,6 +132,128 @@ class FakePromptServer:
|
|||||||
instance = Instance()
|
instance = Instance()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_nodes_requires_graph_id():
|
||||||
|
node_registry = NodeRegistry()
|
||||||
|
handler = NodeRegistryHandler(
|
||||||
|
node_registry=node_registry,
|
||||||
|
prompt_server=FakePromptServer,
|
||||||
|
standalone_mode=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
request = FakeRequest(json_data={"nodes": [{"node_id": 1}]})
|
||||||
|
response = await handler.register_nodes(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert payload["success"] is False
|
||||||
|
assert "graph_id" in payload["error"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_nodes_stores_graph_identifier():
|
||||||
|
node_registry = NodeRegistry()
|
||||||
|
handler = NodeRegistryHandler(
|
||||||
|
node_registry=node_registry,
|
||||||
|
prompt_server=FakePromptServer,
|
||||||
|
standalone_mode=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
request = FakeRequest(
|
||||||
|
json_data={
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"node_id": 7,
|
||||||
|
"graph_id": "graph-123",
|
||||||
|
"graph_name": "Character Subgraph",
|
||||||
|
"type": "Lora Loader (LoraManager)",
|
||||||
|
"title": "Loader",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.register_nodes(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert payload["success"] is True
|
||||||
|
|
||||||
|
registry = await node_registry.get_registry()
|
||||||
|
assert registry["node_count"] == 1
|
||||||
|
stored_node = next(iter(registry["nodes"].values()))
|
||||||
|
assert stored_node["graph_id"] == "graph-123"
|
||||||
|
assert stored_node["unique_id"] == "graph-123:7"
|
||||||
|
assert stored_node["graph_name"] == "Character Subgraph"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_nodes_defaults_graph_name_to_none():
|
||||||
|
node_registry = NodeRegistry()
|
||||||
|
handler = NodeRegistryHandler(
|
||||||
|
node_registry=node_registry,
|
||||||
|
prompt_server=FakePromptServer,
|
||||||
|
standalone_mode=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
request = FakeRequest(
|
||||||
|
json_data={
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"node_id": 8,
|
||||||
|
"graph_id": "root",
|
||||||
|
"type": "Lora Loader (LoraManager)",
|
||||||
|
"title": "Root Loader",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.register_nodes(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert payload["success"] is True
|
||||||
|
|
||||||
|
registry = await node_registry.get_registry()
|
||||||
|
stored_node = next(iter(registry["nodes"].values()))
|
||||||
|
assert stored_node["graph_name"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_lora_code_includes_graph_identifier():
|
||||||
|
send_calls: list[tuple[str, dict]] = []
|
||||||
|
|
||||||
|
class RecordingPromptServer:
|
||||||
|
class Instance:
|
||||||
|
def send_sync(self, event, payload):
|
||||||
|
send_calls.append((event, payload))
|
||||||
|
|
||||||
|
instance = Instance()
|
||||||
|
|
||||||
|
handler = LoraCodeHandler(RecordingPromptServer)
|
||||||
|
|
||||||
|
request = FakeRequest(
|
||||||
|
json_data={
|
||||||
|
"node_ids": [{"node_id": 3, "graph_id": "graph-A"}],
|
||||||
|
"lora_code": "<lora>",
|
||||||
|
"mode": "replace",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.update_lora_code(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert payload["success"] is True
|
||||||
|
assert payload["results"] == [
|
||||||
|
{"node_id": 3, "graph_id": "graph-A", "success": True}
|
||||||
|
]
|
||||||
|
assert send_calls == [
|
||||||
|
(
|
||||||
|
"lora_code_update",
|
||||||
|
{"id": 3, "graph_id": "graph-A", "lora_code": "<lora>", "mode": "replace"},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class FakeScanner:
|
class FakeScanner:
|
||||||
async def check_model_version_exists(self, _version_id):
|
async def check_model_version_exists(self, _version_id):
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
import { app } from "../../scripts/app.js";
|
||||||
import { api } from "../../scripts/api.js";
|
import { api } from "../../scripts/api.js";
|
||||||
import { addJsonDisplayWidget } from "./json_display_widget.js";
|
import { addJsonDisplayWidget } from "./json_display_widget.js";
|
||||||
|
import { getNodeFromGraph } from "./utils.js";
|
||||||
|
|
||||||
app.registerExtension({
|
app.registerExtension({
|
||||||
name: "LoraManager.DebugMetadata",
|
name: "LoraManager.DebugMetadata",
|
||||||
@@ -8,8 +9,8 @@ app.registerExtension({
|
|||||||
setup() {
|
setup() {
|
||||||
// Add message handler to listen for metadata updates from Python
|
// Add message handler to listen for metadata updates from Python
|
||||||
api.addEventListener("metadata_update", (event) => {
|
api.addEventListener("metadata_update", (event) => {
|
||||||
const { id, metadata } = event.detail;
|
const { id, graph_id: graphId, metadata } = event.detail;
|
||||||
this.handleMetadataUpdate(id, metadata);
|
this.handleMetadataUpdate(id, graphId, metadata);
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -37,8 +38,8 @@ app.registerExtension({
|
|||||||
},
|
},
|
||||||
|
|
||||||
// Handle metadata updates from Python
|
// Handle metadata updates from Python
|
||||||
handleMetadataUpdate(id, metadata) {
|
handleMetadataUpdate(id, graphId, metadata) {
|
||||||
const node = app.graph.getNodeById(+id);
|
const node = getNodeFromGraph(graphId, id);
|
||||||
if (!node || node.comfyClass !== "Debug Metadata (LoraManager)") {
|
if (!node || node.comfyClass !== "Debug Metadata (LoraManager)") {
|
||||||
console.warn("Node not found or not a DebugMetadata node:", id);
|
console.warn("Node not found or not a DebugMetadata node:", id);
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import {
|
|||||||
chainCallback,
|
chainCallback,
|
||||||
mergeLoras,
|
mergeLoras,
|
||||||
setupInputWidgetWithAutocomplete,
|
setupInputWidgetWithAutocomplete,
|
||||||
|
getAllGraphNodes,
|
||||||
|
getNodeFromGraph,
|
||||||
} from "./utils.js";
|
} from "./utils.js";
|
||||||
import { addLorasWidget } from "./loras_widget.js";
|
import { addLorasWidget } from "./loras_widget.js";
|
||||||
|
|
||||||
@@ -16,23 +18,26 @@ app.registerExtension({
|
|||||||
setup() {
|
setup() {
|
||||||
// Add message handler to listen for messages from Python
|
// Add message handler to listen for messages from Python
|
||||||
api.addEventListener("lora_code_update", (event) => {
|
api.addEventListener("lora_code_update", (event) => {
|
||||||
const { id, lora_code, mode } = event.detail;
|
this.handleLoraCodeUpdate(event.detail || {});
|
||||||
this.handleLoraCodeUpdate(id, lora_code, mode);
|
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
|
||||||
// Handle lora code updates from Python
|
// Handle lora code updates from Python
|
||||||
handleLoraCodeUpdate(id, loraCode, mode) {
|
handleLoraCodeUpdate(message) {
|
||||||
|
const nodeId = message?.node_id ?? message?.id;
|
||||||
|
const graphId = message?.graph_id;
|
||||||
|
const loraCode = message?.lora_code ?? "";
|
||||||
|
const mode = message?.mode ?? "append";
|
||||||
|
|
||||||
|
const numericNodeId =
|
||||||
|
typeof nodeId === "string" ? Number(nodeId) : nodeId;
|
||||||
|
|
||||||
// Handle broadcast mode (for Desktop/non-browser support)
|
// Handle broadcast mode (for Desktop/non-browser support)
|
||||||
if (id === -1) {
|
if (numericNodeId === -1) {
|
||||||
// Find all Lora Loader nodes in the current graph
|
// Find all Lora Loader nodes in the current graph
|
||||||
const loraLoaderNodes = [];
|
const loraLoaderNodes = getAllGraphNodes(app.graph)
|
||||||
for (const nodeId in app.graph._nodes_by_id) {
|
.map(({ node }) => node)
|
||||||
const node = app.graph._nodes_by_id[nodeId];
|
.filter((node) => node?.comfyClass === "Lora Loader (LoraManager)");
|
||||||
if (node.comfyClass === "Lora Loader (LoraManager)") {
|
|
||||||
loraLoaderNodes.push(node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update each Lora Loader node found
|
// Update each Lora Loader node found
|
||||||
if (loraLoaderNodes.length > 0) {
|
if (loraLoaderNodes.length > 0) {
|
||||||
@@ -52,14 +57,18 @@ app.registerExtension({
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Standard mode - update a specific node
|
// Standard mode - update a specific node
|
||||||
const node = app.graph.getNodeById(+id);
|
const node = getNodeFromGraph(graphId, numericNodeId);
|
||||||
if (
|
if (
|
||||||
!node ||
|
!node ||
|
||||||
(node.comfyClass !== "Lora Loader (LoraManager)" &&
|
(node.comfyClass !== "Lora Loader (LoraManager)" &&
|
||||||
node.comfyClass !== "Lora Stacker (LoraManager)" &&
|
node.comfyClass !== "Lora Stacker (LoraManager)" &&
|
||||||
node.comfyClass !== "WanVideo Lora Select (LoraManager)")
|
node.comfyClass !== "WanVideo Lora Select (LoraManager)")
|
||||||
) {
|
) {
|
||||||
console.warn("Node not found or not a LoraLoader:", id);
|
console.warn(
|
||||||
|
"Node not found or not a LoraLoader:",
|
||||||
|
graphId ?? "root",
|
||||||
|
nodeId
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import {
|
|||||||
chainCallback,
|
chainCallback,
|
||||||
mergeLoras,
|
mergeLoras,
|
||||||
setupInputWidgetWithAutocomplete,
|
setupInputWidgetWithAutocomplete,
|
||||||
|
getLinkFromGraph,
|
||||||
|
getNodeKey,
|
||||||
} from "./utils.js";
|
} from "./utils.js";
|
||||||
import { addLorasWidget } from "./loras_widget.js";
|
import { addLorasWidget } from "./loras_widget.js";
|
||||||
|
|
||||||
@@ -124,17 +126,18 @@ app.registerExtension({
|
|||||||
|
|
||||||
// Helper function to find and update downstream Lora Loader nodes
|
// Helper function to find and update downstream Lora Loader nodes
|
||||||
function updateDownstreamLoaders(startNode, visited = new Set()) {
|
function updateDownstreamLoaders(startNode, visited = new Set()) {
|
||||||
if (visited.has(startNode.id)) return;
|
const nodeKey = getNodeKey(startNode);
|
||||||
visited.add(startNode.id);
|
if (!nodeKey || visited.has(nodeKey)) return;
|
||||||
|
visited.add(nodeKey);
|
||||||
|
|
||||||
// Check each output link
|
// Check each output link
|
||||||
if (startNode.outputs) {
|
if (startNode.outputs) {
|
||||||
for (const output of startNode.outputs) {
|
for (const output of startNode.outputs) {
|
||||||
if (output.links) {
|
if (output.links) {
|
||||||
for (const linkId of output.links) {
|
for (const linkId of output.links) {
|
||||||
const link = app.graph.links[linkId];
|
const link = getLinkFromGraph(startNode.graph, linkId);
|
||||||
if (link) {
|
if (link) {
|
||||||
const targetNode = app.graph.getNodeById(link.target_id);
|
const targetNode = startNode.graph?.getNodeById?.(link.target_id);
|
||||||
|
|
||||||
// If target is a Lora Loader, collect all active loras in the chain and update
|
// If target is a Lora Loader, collect all active loras in the chain and update
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
import { app } from "../../scripts/app.js";
|
||||||
import { api } from "../../scripts/api.js";
|
import { api } from "../../scripts/api.js";
|
||||||
import { CONVERTED_TYPE } from "./utils.js";
|
import { CONVERTED_TYPE, getNodeFromGraph } from "./utils.js";
|
||||||
import { addTagsWidget } from "./tags_widget.js";
|
import { addTagsWidget } from "./tags_widget.js";
|
||||||
|
|
||||||
// TriggerWordToggle extension for ComfyUI
|
// TriggerWordToggle extension for ComfyUI
|
||||||
@@ -10,8 +10,8 @@ app.registerExtension({
|
|||||||
setup() {
|
setup() {
|
||||||
// Add message handler to listen for messages from Python
|
// Add message handler to listen for messages from Python
|
||||||
api.addEventListener("trigger_word_update", (event) => {
|
api.addEventListener("trigger_word_update", (event) => {
|
||||||
const { id, message } = event.detail;
|
const { id, graph_id: graphId, message } = event.detail;
|
||||||
this.handleTriggerWordUpdate(id, message);
|
this.handleTriggerWordUpdate(id, graphId, message);
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -76,8 +76,8 @@ app.registerExtension({
|
|||||||
},
|
},
|
||||||
|
|
||||||
// Handle trigger word updates from Python
|
// Handle trigger word updates from Python
|
||||||
handleTriggerWordUpdate(id, message) {
|
handleTriggerWordUpdate(id, graphId, message) {
|
||||||
const node = app.graph.getNodeById(+id);
|
const node = getNodeFromGraph(graphId, id);
|
||||||
if (!node || node.comfyClass !== "TriggerWord Toggle (LoraManager)") {
|
if (!node || node.comfyClass !== "TriggerWord Toggle (LoraManager)") {
|
||||||
console.warn("Node not found or not a TriggerWordToggle:", id);
|
console.warn("Node not found or not a TriggerWordToggle:", id);
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
// ComfyUI extension to track model usage statistics
|
// ComfyUI extension to track model usage statistics
|
||||||
import { app } from "../../scripts/app.js";
|
import { app } from "../../scripts/app.js";
|
||||||
import { api } from "../../scripts/api.js";
|
import { api } from "../../scripts/api.js";
|
||||||
import { showToast } from "./utils.js";
|
import { getAllGraphNodes, getNodeReference, showToast } from "./utils.js";
|
||||||
|
|
||||||
// Define target nodes and their widget configurations
|
// Define target nodes and their widget configurations
|
||||||
const PATH_CORRECTION_TARGETS = [
|
const PATH_CORRECTION_TARGETS = [
|
||||||
@@ -56,25 +56,35 @@ app.registerExtension({
|
|||||||
|
|
||||||
async refreshRegistry() {
|
async refreshRegistry() {
|
||||||
try {
|
try {
|
||||||
// Get current workflow nodes
|
const loraNodes = [];
|
||||||
const prompt = await app.graphToPrompt();
|
const nodeEntries = getAllGraphNodes(app.graph);
|
||||||
const workflow = prompt.workflow;
|
|
||||||
if (!workflow || !workflow.nodes) {
|
for (const { graph, node } of nodeEntries) {
|
||||||
console.warn("No workflow nodes found for registry refresh");
|
if (!node || !node.comfyClass) {
|
||||||
return;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find all Lora nodes
|
if (
|
||||||
const loraNodes = [];
|
node.comfyClass === "Lora Loader (LoraManager)" ||
|
||||||
for (const node of workflow.nodes.values()) {
|
node.comfyClass === "Lora Stacker (LoraManager)" ||
|
||||||
if (node.type === "Lora Loader (LoraManager)" ||
|
node.comfyClass === "WanVideo Lora Select (LoraManager)"
|
||||||
node.type === "Lora Stacker (LoraManager)" ||
|
) {
|
||||||
node.type === "WanVideo Lora Select (LoraManager)") {
|
const reference = getNodeReference(node);
|
||||||
|
if (!reference) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const graphName = typeof graph?.name === "string" && graph.name.trim()
|
||||||
|
? graph.name
|
||||||
|
: null;
|
||||||
|
|
||||||
loraNodes.push({
|
loraNodes.push({
|
||||||
node_id: node.id,
|
node_id: reference.node_id,
|
||||||
bgcolor: node.bgcolor || null,
|
graph_id: reference.graph_id,
|
||||||
title: node.title || node.type,
|
graph_name: graphName,
|
||||||
type: node.type
|
bgcolor: node.bgcolor ?? node.color ?? null,
|
||||||
|
title: node.title || node.comfyClass,
|
||||||
|
type: node.comfyClass,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,120 @@ export const CONVERTED_TYPE = 'converted-widget';
|
|||||||
import { app } from "../../scripts/app.js";
|
import { app } from "../../scripts/app.js";
|
||||||
import { AutoComplete } from "./autocomplete.js";
|
import { AutoComplete } from "./autocomplete.js";
|
||||||
|
|
||||||
|
const ROOT_GRAPH_ID = "root";
|
||||||
|
|
||||||
|
function isMapLike(collection) {
|
||||||
|
return collection && typeof collection.entries === "function" && typeof collection.values === "function";
|
||||||
|
}
|
||||||
|
|
||||||
|
function getChildGraphs(graph) {
|
||||||
|
if (!graph || !graph._subgraphs) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const rawSubgraphs = isMapLike(graph._subgraphs)
|
||||||
|
? Array.from(graph._subgraphs.values())
|
||||||
|
: Object.values(graph._subgraphs);
|
||||||
|
|
||||||
|
return rawSubgraphs
|
||||||
|
.map((subgraph) => subgraph?.graph || subgraph?._graph || subgraph)
|
||||||
|
.filter((subgraph) => subgraph && subgraph !== graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
function traverseGraphs(rootGraph, visitor, visited = new Set()) {
|
||||||
|
const graph = rootGraph || app.graph;
|
||||||
|
if (!graph) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const graphId = getGraphId(graph);
|
||||||
|
if (visited.has(graphId)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
visited.add(graphId);
|
||||||
|
visitor(graph);
|
||||||
|
|
||||||
|
for (const subgraph of getChildGraphs(graph)) {
|
||||||
|
traverseGraphs(subgraph, visitor, visited);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getGraphId(graph) {
|
||||||
|
return graph?.id ?? ROOT_GRAPH_ID;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getNodeGraphId(node) {
|
||||||
|
if (!node) {
|
||||||
|
return ROOT_GRAPH_ID;
|
||||||
|
}
|
||||||
|
return getGraphId(node.graph || app.graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getGraphById(graphId, rootGraph = app.graph) {
|
||||||
|
if (!graphId) {
|
||||||
|
return rootGraph;
|
||||||
|
}
|
||||||
|
|
||||||
|
let foundGraph = null;
|
||||||
|
traverseGraphs(rootGraph, (graph) => {
|
||||||
|
if (!foundGraph && getGraphId(graph) === graphId) {
|
||||||
|
foundGraph = graph;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return foundGraph;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getNodeFromGraph(graphId, nodeId) {
|
||||||
|
const graph = getGraphById(graphId) || app.graph;
|
||||||
|
if (!graph || typeof graph.getNodeById !== "function") {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const numericId = typeof nodeId === "string" ? Number(nodeId) : nodeId;
|
||||||
|
return graph.getNodeById(Number.isNaN(numericId) ? nodeId : numericId) || null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getAllGraphNodes(rootGraph = app.graph) {
|
||||||
|
const nodes = [];
|
||||||
|
traverseGraphs(rootGraph, (graph) => {
|
||||||
|
if (Array.isArray(graph._nodes)) {
|
||||||
|
for (const node of graph._nodes) {
|
||||||
|
nodes.push({ graph, node });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return nodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getNodeReference(node) {
|
||||||
|
if (!node) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
node_id: node.id,
|
||||||
|
graph_id: getNodeGraphId(node),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getNodeKey(node) {
|
||||||
|
if (!node) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return `${getNodeGraphId(node)}:${node.id}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getLinkFromGraph(graph, linkId) {
|
||||||
|
if (!graph || graph.links == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isMapLike(graph.links)) {
|
||||||
|
return graph.links.get(linkId) || null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return graph.links[linkId] || null;
|
||||||
|
}
|
||||||
|
|
||||||
export function chainCallback(object, property, callback) {
|
export function chainCallback(object, property, callback) {
|
||||||
if (object == undefined) {
|
if (object == undefined) {
|
||||||
//This should not happen.
|
//This should not happen.
|
||||||
@@ -104,19 +218,26 @@ export const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)(?::([-\d\.]+))?>/g;
|
|||||||
export function getConnectedInputStackers(node) {
|
export function getConnectedInputStackers(node) {
|
||||||
const connectedStackers = [];
|
const connectedStackers = [];
|
||||||
|
|
||||||
if (node.inputs) {
|
if (!node?.inputs) {
|
||||||
|
return connectedStackers;
|
||||||
|
}
|
||||||
|
|
||||||
for (const input of node.inputs) {
|
for (const input of node.inputs) {
|
||||||
if (input.name === "lora_stack" && input.link) {
|
if (input.name !== "lora_stack" || !input.link) {
|
||||||
const link = app.graph.links[input.link];
|
continue;
|
||||||
if (link) {
|
}
|
||||||
const sourceNode = app.graph.getNodeById(link.origin_id);
|
|
||||||
|
const link = getLinkFromGraph(node.graph, input.link);
|
||||||
|
if (!link) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const sourceNode = node.graph?.getNodeById?.(link.origin_id);
|
||||||
if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") {
|
if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") {
|
||||||
connectedStackers.push(sourceNode);
|
connectedStackers.push(sourceNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return connectedStackers;
|
return connectedStackers;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,21 +245,28 @@ export function getConnectedInputStackers(node) {
|
|||||||
export function getConnectedTriggerToggleNodes(node) {
|
export function getConnectedTriggerToggleNodes(node) {
|
||||||
const connectedNodes = [];
|
const connectedNodes = [];
|
||||||
|
|
||||||
if (node.outputs && node.outputs.length > 0) {
|
if (!node?.outputs) {
|
||||||
|
return connectedNodes;
|
||||||
|
}
|
||||||
|
|
||||||
for (const output of node.outputs) {
|
for (const output of node.outputs) {
|
||||||
if (output.links && output.links.length > 0) {
|
if (!output?.links?.length) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
for (const linkId of output.links) {
|
for (const linkId of output.links) {
|
||||||
const link = app.graph.links[linkId];
|
const link = getLinkFromGraph(node.graph, linkId);
|
||||||
if (link) {
|
if (!link) {
|
||||||
const targetNode = app.graph.getNodeById(link.target_id);
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const targetNode = node.graph?.getNodeById?.(link.target_id);
|
||||||
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
|
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
|
||||||
connectedNodes.push(targetNode.id);
|
connectedNodes.push(targetNode);
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return connectedNodes;
|
return connectedNodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,10 +289,14 @@ export function getActiveLorasFromNode(node) {
|
|||||||
// Recursively collect all active loras from a node and its input chain
|
// Recursively collect all active loras from a node and its input chain
|
||||||
export function collectActiveLorasFromChain(node, visited = new Set()) {
|
export function collectActiveLorasFromChain(node, visited = new Set()) {
|
||||||
// Prevent infinite loops from circular references
|
// Prevent infinite loops from circular references
|
||||||
if (visited.has(node.id)) {
|
const nodeKey = getNodeKey(node);
|
||||||
|
if (!nodeKey) {
|
||||||
return new Set();
|
return new Set();
|
||||||
}
|
}
|
||||||
visited.add(node.id);
|
if (visited.has(nodeKey)) {
|
||||||
|
return new Set();
|
||||||
|
}
|
||||||
|
visited.add(nodeKey);
|
||||||
|
|
||||||
// Get active loras from current node
|
// Get active loras from current node
|
||||||
const allActiveLoraNames = getActiveLorasFromNode(node);
|
const allActiveLoraNames = getActiveLorasFromNode(node);
|
||||||
@@ -181,14 +313,22 @@ export function collectActiveLorasFromChain(node, visited = new Set()) {
|
|||||||
|
|
||||||
// Update trigger words for connected toggle nodes
|
// Update trigger words for connected toggle nodes
|
||||||
export function updateConnectedTriggerWords(node, loraNames) {
|
export function updateConnectedTriggerWords(node, loraNames) {
|
||||||
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
|
const connectedNodes = getConnectedTriggerToggleNodes(node);
|
||||||
if (connectedNodeIds.length > 0) {
|
if (connectedNodes.length > 0) {
|
||||||
|
const nodeIds = connectedNodes
|
||||||
|
.map((connectedNode) => getNodeReference(connectedNode))
|
||||||
|
.filter((reference) => reference !== null);
|
||||||
|
|
||||||
|
if (nodeIds.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
fetch("/api/lm/loras/get_trigger_words", {
|
fetch("/api/lm/loras/get_trigger_words", {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: { "Content-Type": "application/json" },
|
headers: { "Content-Type": "application/json" },
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
lora_names: Array.from(loraNames),
|
lora_names: Array.from(loraNames),
|
||||||
node_ids: connectedNodeIds
|
node_ids: nodeIds
|
||||||
})
|
})
|
||||||
}).catch(err => console.error("Error fetching trigger words:", err));
|
}).catch(err => console.error("Error fetching trigger words:", err));
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user