From d0aa9166832b5b2cf4688b983af3846eeb8500ee Mon Sep 17 00:00:00 2001 From: Will Miao Date: Thu, 23 Oct 2025 10:44:25 +0800 Subject: [PATCH] feat(node-registry): add support to send checkpoint/diffusion model to workflow - Add capabilities parsing and validation for node registration - Implement widget_names extraction from capabilities with type safety - Add supports_lora boolean conversion in capabilities - Include comfy_class fallback to node_type when missing - Add new update_node_widget API endpoint for bulk widget updates - Improve error handling and input validation for widget updates - Remove unused parameters from node selector event setup function These changes improve node metadata handling and enable dynamic widget management capabilities. --- py/routes/handlers/misc_handlers.py | 116 +++++++ py/routes/misc_route_registrar.py | 1 + static/js/components/shared/ModelCard.js | 68 ++++- static/js/utils/uiHelpers.js | 373 ++++++++++++++++++----- tests/routes/test_misc_routes.py | 72 +++++ web/comfyui/usage_stats.js | 61 +--- web/comfyui/workflow_registry.js | 151 +++++++++ 7 files changed, 706 insertions(+), 136 deletions(-) create mode 100644 web/comfyui/workflow_registry.js diff --git a/py/routes/handlers/misc_handlers.py b/py/routes/handlers/misc_handlers.py index a902c2ae..638689dc 100644 --- a/py/routes/handlers/misc_handlers.py +++ b/py/routes/handlers/misc_handlers.py @@ -100,6 +100,36 @@ class NodeRegistry: node_type = node.get("type", "") type_id = NODE_TYPES.get(node_type, 0) bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR + raw_capabilities = node.get("capabilities") + capabilities: dict = {} + if isinstance(raw_capabilities, dict): + capabilities = dict(raw_capabilities) + + raw_widget_names: list | None = node.get("widget_names") + if not isinstance(raw_widget_names, list): + capability_widget_names = capabilities.get("widget_names") + raw_widget_names = capability_widget_names if isinstance(capability_widget_names, list) else None + + widget_names: list[str] = [] + if isinstance(raw_widget_names, list): + widget_names = [ + str(widget_name) + for widget_name in raw_widget_names + if isinstance(widget_name, str) and widget_name + ] + + if widget_names: + capabilities["widget_names"] = widget_names + else: + capabilities.pop("widget_names", None) + + if "supports_lora" in capabilities: + capabilities["supports_lora"] = bool(capabilities["supports_lora"]) + + comfy_class = node.get("comfy_class") + if not isinstance(comfy_class, str) or not comfy_class: + comfy_class = node_type if isinstance(node_type, str) else None + self._nodes[unique_id] = { "id": node_id, "graph_id": graph_id, @@ -109,6 +139,9 @@ class NodeRegistry: "title": node.get("title"), "type": type_id, "type_name": node_type, + "comfy_class": comfy_class, + "capabilities": capabilities, + "widget_names": widget_names, } logger.debug("Registered %s nodes in registry", len(nodes)) self._registry_updated.set() @@ -919,6 +952,88 @@ class NodeRegistryHandler: logger.error("Failed to get registry: %s", exc, exc_info=True) return web.json_response({"success": False, "error": "Internal Error", "message": str(exc)}, status=500) + async def update_node_widget(self, request: web.Request) -> web.Response: + try: + data = await request.json() + widget_name = data.get("widget_name") + value = data.get("value") + node_ids = data.get("node_ids") + + if not isinstance(widget_name, str) or not widget_name: + return web.json_response({"success": False, "error": "Missing widget_name parameter"}, status=400) + + if not isinstance(value, str) or not value: + return web.json_response({"success": False, "error": "Missing value parameter"}, status=400) + + if not isinstance(node_ids, list) or not node_ids: + return web.json_response( + {"success": False, "error": "node_ids must be a non-empty list"}, + status=400, + ) + + results = [] + for entry in node_ids: + node_identifier = entry + graph_identifier = None + if isinstance(entry, dict): + node_identifier = entry.get("node_id") + graph_identifier = entry.get("graph_id") + + if node_identifier is None: + results.append( + { + "node_id": node_identifier, + "graph_id": graph_identifier, + "success": False, + "error": "Missing node_id parameter", + } + ) + continue + + try: + parsed_node_id = int(node_identifier) + except (TypeError, ValueError): + parsed_node_id = node_identifier + + payload = { + "id": parsed_node_id, + "widget_name": widget_name, + "value": value, + } + + if graph_identifier is not None: + payload["graph_id"] = str(graph_identifier) + + try: + self._prompt_server.instance.send_sync("lm_widget_update", payload) + results.append( + { + "node_id": parsed_node_id, + "graph_id": payload.get("graph_id"), + "success": True, + } + ) + except Exception as exc: # pragma: no cover - defensive logging + logger.error( + "Error sending widget update to node %s (graph %s): %s", + parsed_node_id, + graph_identifier, + exc, + ) + results.append( + { + "node_id": parsed_node_id, + "graph_id": payload.get("graph_id"), + "success": False, + "error": str(exc), + } + ) + + return web.json_response({"success": True, "results": results}) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to update node widget: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + class MiscHandlerSet: """Aggregate handlers into a lookup compatible with the registrar.""" @@ -962,6 +1077,7 @@ class MiscHandlerSet: "get_trained_words": self.trained_words.get_trained_words, "get_model_example_files": self.model_examples.get_model_example_files, "register_nodes": self.node_registry.register_nodes, + "update_node_widget": self.node_registry.update_node_widget, "get_registry": self.node_registry.get_registry, "check_model_exists": self.model_library.check_model_exists, "get_civitai_user_models": self.model_library.get_civitai_user_models, diff --git a/py/routes/misc_route_registrar.py b/py/routes/misc_route_registrar.py index 53068566..a68aa8eb 100644 --- a/py/routes/misc_route_registrar.py +++ b/py/routes/misc_route_registrar.py @@ -33,6 +33,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("GET", "/api/lm/trained-words", "get_trained_words"), RouteDefinition("GET", "/api/lm/model-example-files", "get_model_example_files"), RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"), + RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"), RouteDefinition("GET", "/api/lm/get-registry", "get_registry"), RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"), RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"), diff --git a/static/js/components/shared/ModelCard.js b/static/js/components/shared/ModelCard.js index a60b1642..aa777292 100644 --- a/static/js/components/shared/ModelCard.js +++ b/static/js/components/shared/ModelCard.js @@ -1,4 +1,4 @@ -import { showToast, openCivitai, copyToClipboard, copyLoraSyntax, sendLoraToWorkflow, openExampleImagesFolder, buildLoraSyntax } from '../../utils/uiHelpers.js'; +import { showToast, openCivitai, copyToClipboard, copyLoraSyntax, sendLoraToWorkflow, openExampleImagesFolder, buildLoraSyntax, sendModelPathToWorkflow } from '../../utils/uiHelpers.js'; import { state, getCurrentPageState } from '../../state/index.js'; import { showModelModal } from './ModelModal.js'; import { toggleShowcase } from './showcase/ShowcaseView.js'; @@ -168,8 +168,53 @@ function handleSendToWorkflow(card, replaceMode, modelType) { const usageTips = JSON.parse(card.dataset.usage_tips || '{}'); const loraSyntax = buildLoraSyntax(card.dataset.file_name, usageTips); sendLoraToWorkflow(loraSyntax, replaceMode, 'lora'); + } else if (modelType === MODEL_TYPES.CHECKPOINT) { + const modelPath = card.dataset.filepath; + if (!modelPath) { + const message = translate('modelCard.sendToWorkflow.missingPath', {}, 'Unable to determine model path for this card'); + showToast(message, {}, 'error'); + return; + } + + const subtype = (card.dataset.model_type || 'checkpoint').toLowerCase(); + const isDiffusionModel = subtype === 'diffusion_model'; + const widgetName = isDiffusionModel ? 'unet_name' : 'ckpt_name'; + const actionTypeText = translate( + isDiffusionModel ? 'uiHelpers.nodeSelector.diffusionModel' : 'uiHelpers.nodeSelector.checkpoint', + {}, + isDiffusionModel ? 'Diffusion Model' : 'Checkpoint' + ); + const successMessage = translate( + isDiffusionModel ? 'uiHelpers.workflow.diffusionModelUpdated' : 'uiHelpers.workflow.checkpointUpdated', + {}, + isDiffusionModel ? 'Diffusion model updated in workflow' : 'Checkpoint updated in workflow' + ); + const failureMessage = translate( + isDiffusionModel ? 'uiHelpers.workflow.diffusionModelFailed' : 'uiHelpers.workflow.checkpointFailed', + {}, + isDiffusionModel ? 'Failed to update diffusion model node' : 'Failed to update checkpoint node' + ); + const missingNodesMessage = translate( + 'uiHelpers.workflow.noMatchingNodes', + {}, + 'No compatible nodes available in the current workflow' + ); + const missingTargetMessage = translate( + 'uiHelpers.workflow.noTargetNodeSelected', + {}, + 'No target node selected' + ); + + sendModelPathToWorkflow(modelPath, { + widgetName, + collectionType: MODEL_TYPES.CHECKPOINT, + actionTypeText, + successMessage, + failureMessage, + missingNodesMessage, + missingTargetMessage, + }); } else { - // Checkpoint send functionality - to be implemented showToast('modelCard.sendToWorkflow.checkpointNotImplemented', {}, 'info'); } } @@ -470,8 +515,21 @@ export function createModelCard(model, modelType) { const globeTitle = model.from_civitai ? translate('modelCard.actions.viewOnCivitai', {}, 'View on Civitai') : translate('modelCard.actions.notAvailableFromCivitai', {}, 'Not available from Civitai'); - const sendTitle = translate('modelCard.actions.sendToWorkflow', {}, 'Send to ComfyUI (Click: Append, Shift+Click: Replace)'); - const copyTitle = translate('modelCard.actions.copyLoRASyntax', {}, 'Copy LoRA Syntax'); + let sendTitle; + let copyTitle; + if (modelType === MODEL_TYPES.LORA) { + sendTitle = translate('modelCard.actions.sendToWorkflow', {}, 'Send to ComfyUI (Click: Append, Shift+Click: Replace)'); + copyTitle = translate('modelCard.actions.copyLoRASyntax', {}, 'Copy LoRA Syntax'); + } else if (modelType === MODEL_TYPES.CHECKPOINT) { + sendTitle = translate('modelCard.actions.sendCheckpointToWorkflow', {}, 'Send to ComfyUI'); + copyTitle = translate('modelCard.actions.copyCheckpointName', {}, 'Copy checkpoint name'); + } else if (modelType === MODEL_TYPES.EMBEDDING) { + sendTitle = translate('modelCard.actions.sendEmbeddingToWorkflow', {}, 'Send to ComfyUI'); + copyTitle = translate('modelCard.actions.copyEmbeddingName', {}, 'Copy embedding name'); + } else { + sendTitle = translate('modelCard.actions.sendToWorkflow', {}, 'Send to ComfyUI'); + copyTitle = translate('modelCard.actions.copyLoRASyntax', {}, 'Copy value'); + } const actionIcons = ` { + try { + return predicate(node); + } catch (error) { + console.warn('Failed to evaluate registry node predicate', error); + return false; + } + }), + ); +} + +function getWidgetNames(node) { + if (!node) { + return []; + } + + if (Array.isArray(node.widget_names)) { + return node.widget_names; + } + + if (node.capabilities && Array.isArray(node.capabilities.widget_names)) { + return node.capabilities.widget_names; + } + + return []; +} + +function isAbsolutePath(path) { + if (typeof path !== 'string') { + return false; + } + + return path.startsWith('/') || path.startsWith('\\') || /^[a-zA-Z]:[\\/]/.test(path); +} + +async function ensureRelativeModelPath(modelPath, collectionType) { + if (!modelPath || !isAbsolutePath(modelPath)) { + return modelPath; + } + + const fileName = modelPath.split(/[/\\]/).pop(); + if (!fileName) { + return modelPath; + } + + try { + const response = await fetch(`/api/lm/${collectionType}/relative-paths?search=${encodeURIComponent(fileName)}&limit=10`); + if (!response.ok) { + return modelPath; + } + const data = await response.json(); + const relativePaths = Array.isArray(data?.relative_paths) ? data.relative_paths : []; + if (relativePaths.length === 0) { + return modelPath; + } + const exactMatch = relativePaths.find((path) => path.endsWith(fileName)); + return exactMatch || relativePaths[0] || modelPath; + } catch (error) { + console.warn('LoRA Manager: failed to resolve relative path for model', error); + return modelPath; + } +} + /** * Sends LoRA syntax to the active ComfyUI workflow * @param {string} loraSyntax - The LoRA syntax to send @@ -406,44 +497,106 @@ export function copyLoraSyntax(card) { * @returns {Promise} - Whether the operation was successful */ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntaxType = 'lora') { - try { - // Get registry information from the new endpoint - const registryResponse = await fetch('/api/lm/get-registry'); - const registryData = await registryResponse.json(); - - if (!registryData.success) { - // Handle specific error cases - if (registryData.error === 'Standalone Mode Active') { - // Standalone mode - show warning with specific message - showToast('toast.general.cannotInteractStandalone', {}, 'warning'); - return false; - } else { - // Other errors - show error toast - showToast('toast.general.failedWorkflowInfo', {}, 'error'); - return false; - } - } - - // Success case - check node count - if (registryData.data.node_count === 0) { - // No nodes found - show warning - showToast('uiHelpers.workflow.noSupportedNodes', {}, 'warning'); - return false; - } else if (registryData.data.node_count > 1) { - // Multiple nodes - show selector - showNodeSelector(registryData.data.nodes, loraSyntax, replaceMode, syntaxType); - return true; - } else { - // Single node - send directly - const nodes = registryData.data.nodes; - const nodeId = Object.keys(nodes)[0]; - return await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType); - } - } catch (error) { - console.error('Failed to get registry:', error); - showToast('uiHelpers.workflow.communicationFailed', {}, 'error'); + const registry = await fetchWorkflowRegistry(); + if (!registry) { return false; } + + const loraNodes = filterRegistryNodes(registry.nodes, (node) => { + if (!node) { + return false; + } + if (node.capabilities && typeof node.capabilities === 'object') { + if (node.capabilities.supports_lora === true) { + return true; + } + } + return typeof node.type === 'number' && node.type > 0; + }); + + const nodeKeys = Object.keys(loraNodes); + if (nodeKeys.length === 0) { + showToast('uiHelpers.workflow.noSupportedNodes', {}, 'warning'); + return false; + } + + if (nodeKeys.length === 1) { + return await sendLoraToNodes([nodeKeys[0]], loraNodes, loraSyntax, replaceMode, syntaxType); + } + + const actionType = + syntaxType === 'recipe' + ? translate('uiHelpers.nodeSelector.recipe', {}, 'Recipe') + : translate('uiHelpers.nodeSelector.lora', {}, 'LoRA'); + const actionMode = replaceMode + ? translate('uiHelpers.nodeSelector.replace', {}, 'Replace') + : translate('uiHelpers.nodeSelector.append', {}, 'Append'); + + showNodeSelector(loraNodes, { + actionType, + actionMode, + onSend: (selectedNodeIds) => + sendLoraToNodes(selectedNodeIds, loraNodes, loraSyntax, replaceMode, syntaxType), + }); + return true; +} + +export async function sendModelPathToWorkflow(modelPath, options) { + const { + widgetName, + collectionType = 'checkpoints', + actionTypeText = 'Checkpoint', + successMessage = 'Updated workflow node', + failureMessage = 'Failed to update workflow node', + missingNodesMessage = 'No compatible nodes available in the current workflow', + missingTargetMessage = 'No target node selected', + } = options; + + if (!widgetName) { + console.warn('LoRA Manager: widget name is required to send model to workflow'); + return false; + } + + const relativePath = await ensureRelativeModelPath(modelPath, collectionType); + + const registry = await fetchWorkflowRegistry(); + if (!registry) { + return false; + } + + const targetNodes = filterRegistryNodes(registry.nodes, (node) => { + const widgetNames = getWidgetNames(node); + return widgetNames.includes(widgetName); + }); + + const nodeKeys = Object.keys(targetNodes); + if (nodeKeys.length === 0) { + showToast(missingNodesMessage, {}, 'warning'); + return false; + } + + const actionType = actionTypeText; + const actionMode = translate('uiHelpers.nodeSelector.replace', {}, 'Replace'); + + const messages = { + successMessage, + failureMessage, + missingTargetMessage, + }; + + const handleSend = (selectedNodeIds) => + sendWidgetValueToNodes(selectedNodeIds, targetNodes, widgetName, relativePath, messages); + + if (nodeKeys.length === 1) { + return await handleSend([nodeKeys[0]]); + } + + showNodeSelector(targetNodes, { + actionType, + actionMode, + onSend: handleSend, + }); + return true; } /** @@ -483,7 +636,7 @@ function resolveNodeReference(nodeKey, nodesMap) { }; } -async function sendToSpecificNode(nodeIds, nodesMap, loraSyntax, replaceMode, syntaxType) { +async function sendLoraToNodes(nodeIds, nodesMap, loraSyntax, replaceMode, syntaxType) { try { // Call the backend API to update the lora code const requestBody = { @@ -547,29 +700,96 @@ async function sendToSpecificNode(nodeIds, nodesMap, loraSyntax, replaceMode, sy } } +async function sendWidgetValueToNodes(nodeIds, nodesMap, widgetName, value, messages = {}) { + const { + successMessage = 'Updated workflow node', + failureMessage = 'Failed to update workflow node', + missingTargetMessage = 'No target node selected', + } = messages; + + const targetIds = Array.isArray(nodeIds) ? nodeIds : []; + if (targetIds.length === 0) { + showToast(missingTargetMessage, {}, 'warning'); + return false; + } + + const references = targetIds + .map((nodeKey) => resolveNodeReference(nodeKey, nodesMap)) + .filter((reference) => reference && reference.node_id !== undefined); + + if (references.length === 0) { + showToast(missingTargetMessage, {}, 'warning'); + return false; + } + + try { + const response = await fetch('/api/lm/update-node-widget', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + widget_name: widgetName, + value, + node_ids: references, + }), + }); + + const result = await response.json(); + if (result.success) { + showToast(successMessage, {}, 'success'); + return true; + } + + const errorMessage = result?.error || failureMessage; + showToast(errorMessage, {}, 'error'); + return false; + } catch (error) { + console.error('Failed to send widget value to workflow:', error); + showToast(failureMessage, {}, 'error'); + return false; + } +} + // Global variable to track active node selector state let nodeSelectorState = { isActive: false, clickHandler: null, - selectorClickHandler: null + selectorClickHandler: null, + currentNodes: {}, + onSend: null, + enableSendAll: true, }; /** * Show node selector popup near mouse position * @param {Object} nodes - Registry nodes data - * @param {string} loraSyntax - The LoRA syntax to send - * @param {boolean} replaceMode - Whether to replace existing LoRAs - * @param {string} syntaxType - The type of syntax ('lora' or 'recipe') + * @param {Object} options - Configuration for display and actions + * @param {string} options.actionType - Display label for the action type (e.g. LoRA) + * @param {string} options.actionMode - Display label for the action mode (e.g. Replace) + * @param {Function} options.onSend - Callback invoked with selected node ids + * @param {boolean} [options.enableSendAll=true] - Whether to show the "send to all" option */ -function showNodeSelector(nodes, loraSyntax, replaceMode, syntaxType) { +function showNodeSelector(nodes, options = {}) { const selector = document.getElementById('nodeSelector'); if (!selector) return; // Clean up any existing state hideNodeSelector(); + + const safeNodes = nodes || {}; + const onSend = typeof options.onSend === 'function' ? options.onSend : null; + if (!onSend) { + console.warn('LoRA Manager: node selector invoked without send handler'); + return; + } + + nodeSelectorState.currentNodes = safeNodes; + nodeSelectorState.onSend = onSend; + nodeSelectorState.enableSendAll = options.enableSendAll !== false; // Generate node list HTML with icons and proper colors - const nodeItems = Object.entries(nodes).map(([nodeKey, node]) => { + const nodeItems = Object.entries(safeNodes).map(([nodeKey, node]) => { const iconClass = NODE_TYPE_ICONS[node.type] || 'fas fa-question-circle'; const bgColor = node.bgcolor || DEFAULT_NODE_COLOR; const graphLabel = node.graph_name ? ` (${node.graph_name})` : ''; @@ -585,14 +805,20 @@ function showNodeSelector(nodes, loraSyntax, replaceMode, syntaxType) { }).join(''); // Add header with action mode indicator - const actionType = syntaxType === 'recipe' ? - translate('uiHelpers.nodeSelector.recipe', {}, 'Recipe') : - translate('uiHelpers.nodeSelector.lora', {}, 'LoRA'); - const actionMode = replaceMode ? - translate('uiHelpers.nodeSelector.replace', {}, 'Replace') : - translate('uiHelpers.nodeSelector.append', {}, 'Append'); + const actionType = options.actionType ?? translate('uiHelpers.nodeSelector.lora', {}, 'LoRA'); + const actionMode = options.actionMode ?? translate('uiHelpers.nodeSelector.replace', {}, 'Replace'); const selectTargetNodeText = translate('uiHelpers.nodeSelector.selectTargetNode', {}, 'Select target node'); const sendToAllText = translate('uiHelpers.nodeSelector.sendToAll', {}, 'Send to All'); + + const sendAllMarkup = nodeSelectorState.enableSendAll + ? ` +
+
+ +
+ ${sendToAllText} +
` + : ''; selector.innerHTML = `
@@ -600,12 +826,7 @@ function showNodeSelector(nodes, loraSyntax, replaceMode, syntaxType) { ${selectTargetNodeText}
${nodeItems} -
-
- -
- ${sendToAllText} -
+ ${sendAllMarkup} `; // Position near mouse @@ -619,18 +840,14 @@ function showNodeSelector(nodes, loraSyntax, replaceMode, syntaxType) { eventManager.setState('nodeSelectorActive', true); // Setup event listeners with proper cleanup through event manager - setupNodeSelectorEvents(selector, nodes, loraSyntax, replaceMode, syntaxType); + setupNodeSelectorEvents(selector); } /** * Setup event listeners for node selector using event manager * @param {HTMLElement} selector - The selector element - * @param {Object} nodes - Registry nodes data - * @param {string} loraSyntax - The LoRA syntax to send - * @param {boolean} replaceMode - Whether to replace existing LoRAs - * @param {string} syntaxType - The type of syntax ('lora' or 'recipe') */ -function setupNodeSelectorEvents(selector, nodes, loraSyntax, replaceMode, syntaxType) { +function setupNodeSelectorEvents(selector) { // Clean up any existing event listeners cleanupNodeSelectorEvents(); @@ -650,21 +867,32 @@ function setupNodeSelectorEvents(selector, nodes, loraSyntax, replaceMode, synta const nodeItem = e.target.closest('.node-item'); if (!nodeItem) return false; // Continue with other handlers + const onSend = nodeSelectorState.onSend; + if (typeof onSend !== 'function') { + hideNodeSelector(); + return true; + } + e.stopPropagation(); const action = nodeItem.dataset.action; const nodeId = nodeItem.dataset.nodeId; + const nodes = nodeSelectorState.currentNodes || {}; - if (action === 'send-all') { - // Send to all nodes - const allNodeIds = Object.keys(nodes); - await sendToSpecificNode(allNodeIds, nodes, loraSyntax, replaceMode, syntaxType); - } else if (nodeId) { - // Send to specific node - await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType); + try { + if (action === 'send-all') { + if (!nodeSelectorState.enableSendAll) { + return true; + } + const allNodeIds = Object.keys(nodes); + await onSend(allNodeIds); + } else if (nodeId) { + await onSend([nodeId]); + } + } finally { + hideNodeSelector(); } - - hideNodeSelector(); + return true; // Stop propagation }, { priority: 150, // High priority but lower than outside click @@ -699,6 +927,9 @@ function hideNodeSelector() { // Clean up event listeners cleanupNodeSelectorEvents(); nodeSelectorState.isActive = false; + nodeSelectorState.currentNodes = {}; + nodeSelectorState.onSend = null; + nodeSelectorState.enableSendAll = true; // Update event manager state eventManager.setState('nodeSelectorActive', false); @@ -787,4 +1018,4 @@ export async function openExampleImagesFolder(modelHash) { showToast('uiHelpers.exampleImages.failedToOpen', {}, 'error'); return false; } -} \ No newline at end of file +} diff --git a/tests/routes/test_misc_routes.py b/tests/routes/test_misc_routes.py index 79dab8e4..77dda37e 100644 --- a/tests/routes/test_misc_routes.py +++ b/tests/routes/test_misc_routes.py @@ -219,6 +219,78 @@ async def test_register_nodes_defaults_graph_name_to_none(): assert stored_node["graph_name"] is None +@pytest.mark.asyncio +async def test_register_nodes_includes_capabilities(): + node_registry = NodeRegistry() + handler = NodeRegistryHandler( + node_registry=node_registry, + prompt_server=FakePromptServer, + standalone_mode=False, + ) + + request = FakeRequest( + json_data={ + "nodes": [ + { + "node_id": 9, + "graph_id": "root", + "type": "CheckpointLoaderSimple", + "title": "Checkpoint Loader", + "capabilities": {"supports_lora": False, "widget_names": ["ckpt_name", "", 42]}, + } + ] + } + ) + + response = await handler.register_nodes(request) + payload = json.loads(response.text) + + assert payload["success"] is True + + registry = await node_registry.get_registry() + stored_node = next(iter(registry["nodes"].values())) + assert stored_node["capabilities"] == {"supports_lora": False, "widget_names": ["ckpt_name"]} + assert stored_node["widget_names"] == ["ckpt_name"] + + +@pytest.mark.asyncio +async def test_update_node_widget_sends_payload(): + send_calls: list[tuple[str, dict]] = [] + + class RecordingPromptServer: + class Instance: + def send_sync(self, event, payload): + send_calls.append((event, payload)) + + instance = Instance() + + handler = NodeRegistryHandler( + node_registry=NodeRegistry(), + prompt_server=RecordingPromptServer, + standalone_mode=False, + ) + + request = FakeRequest( + json_data={ + "widget_name": "ckpt_name", + "value": "models/checkpoints/model.ckpt", + "node_ids": [{"node_id": 12, "graph_id": "root"}], + } + ) + + response = await handler.update_node_widget(request) + payload = json.loads(response.text) + + assert response.status == 200 + assert payload["success"] is True + assert send_calls == [ + ( + "lm_widget_update", + {"id": 12, "widget_name": "ckpt_name", "value": "models/checkpoints/model.ckpt", "graph_id": "root"}, + ) + ] + + @pytest.mark.asyncio async def test_update_lora_code_includes_graph_identifier(): send_calls: list[tuple[str, dict]] = [] diff --git a/web/comfyui/usage_stats.js b/web/comfyui/usage_stats.js index da1e2b13..4a6ace12 100644 --- a/web/comfyui/usage_stats.js +++ b/web/comfyui/usage_stats.js @@ -1,7 +1,7 @@ // ComfyUI extension to track model usage statistics import { app } from "../../scripts/app.js"; import { api } from "../../scripts/api.js"; -import { getAllGraphNodes, getNodeReference, showToast } from "./utils.js"; +import { showToast } from "./utils.js"; // Define target nodes and their widget configurations const PATH_CORRECTION_TARGETS = [ @@ -68,12 +68,8 @@ app.registerExtension({ } }); - // Listen for registry refresh requests - api.addEventListener("lora_registry_refresh", () => { - this.refreshRegistry(); - }); }, - + async updateUsageStats(promptId) { try { // Call backend endpoint with the prompt_id @@ -93,59 +89,6 @@ app.registerExtension({ } }, - async refreshRegistry() { - try { - const loraNodes = []; - const nodeEntries = getAllGraphNodes(app.graph); - - for (const { graph, node } of nodeEntries) { - if (!node || !node.comfyClass) { - continue; - } - - if ( - node.comfyClass === "Lora Loader (LoraManager)" || - node.comfyClass === "Lora Stacker (LoraManager)" || - node.comfyClass === "WanVideo Lora Select (LoraManager)" - ) { - const reference = getNodeReference(node); - if (!reference) { - continue; - } - - const graphName = typeof graph?.name === "string" && graph.name.trim() - ? graph.name - : null; - - loraNodes.push({ - node_id: reference.node_id, - graph_id: reference.graph_id, - graph_name: graphName, - bgcolor: node.bgcolor ?? node.color ?? null, - title: node.title || node.comfyClass, - type: node.comfyClass, - }); - } - } - - const response = await fetch('/api/lm/register-nodes', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ nodes: loraNodes }), - }); - - if (!response.ok) { - console.warn("Failed to register Lora nodes:", response.statusText); - } else { - console.log(`Successfully registered ${loraNodes.length} Lora nodes`); - } - } catch (error) { - console.error("Error refreshing registry:", error); - } - }, - async loadedGraphNode(node) { if (!getAutoPathCorrectionPreference()) { return; diff --git a/web/comfyui/workflow_registry.js b/web/comfyui/workflow_registry.js new file mode 100644 index 00000000..9f200c5e --- /dev/null +++ b/web/comfyui/workflow_registry.js @@ -0,0 +1,151 @@ +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; +import { getAllGraphNodes, getNodeReference, getNodeFromGraph } from "./utils.js"; + +const LORA_NODE_CLASSES = new Set([ + "Lora Loader (LoraManager)", + "Lora Stacker (LoraManager)", + "WanVideo Lora Select (LoraManager)", +]); + +const TARGET_WIDGET_NAMES = new Set(["ckpt_name", "unet_name"]); + +app.registerExtension({ + name: "LoraManager.WorkflowRegistry", + + setup() { + api.addEventListener("lora_registry_refresh", () => { + this.refreshRegistry(); + }); + + api.addEventListener("lm_widget_update", (event) => { + this.applyWidgetUpdate(event?.detail ?? {}); + }); + }, + + async refreshRegistry() { + try { + const workflowNodes = []; + const nodeEntries = getAllGraphNodes(app.graph); + + for (const { graph, node } of nodeEntries) { + if (!node) { + continue; + } + + const widgetNames = Array.isArray(node.widgets) + ? node.widgets + .map((widget) => widget?.name) + .filter((name) => typeof name === "string" && name.length > 0) + : []; + + const supportsLora = LORA_NODE_CLASSES.has(node.comfyClass); + const hasTargetWidget = widgetNames.some((name) => TARGET_WIDGET_NAMES.has(name)); + + if (!supportsLora && !hasTargetWidget) { + continue; + } + + const reference = getNodeReference(node); + if (!reference) { + continue; + } + + const graphName = + typeof graph?.name === "string" && graph.name.trim() ? graph.name : null; + + workflowNodes.push({ + node_id: reference.node_id, + graph_id: reference.graph_id, + graph_name: graphName, + bgcolor: node.bgcolor ?? node.color ?? null, + title: node.title || node.comfyClass, + type: node.comfyClass, + comfy_class: node.comfyClass, + capabilities: { + supports_lora: supportsLora, + widget_names: widgetNames, + }, + }); + } + + const response = await fetch("/api/lm/register-nodes", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ nodes: workflowNodes }), + }); + + if (!response.ok) { + console.warn("LoRA Manager: failed to register workflow nodes", response.statusText); + } else { + console.debug( + `LoRA Manager: registered ${workflowNodes.length} workflow nodes` + ); + } + } catch (error) { + console.error("LoRA Manager: error refreshing workflow registry", error); + } + }, + + applyWidgetUpdate(message) { + const nodeId = message?.node_id ?? message?.id; + const graphId = message?.graph_id; + const widgetName = message?.widget_name; + const value = message?.value; + + if (nodeId == null || !widgetName) { + console.warn("LoRA Manager: invalid widget update payload", message); + return; + } + + const node = getNodeFromGraph(graphId, nodeId); + if (!node) { + console.warn( + "LoRA Manager: target node not found for widget update", + graphId ?? "root", + nodeId + ); + return; + } + + if (!Array.isArray(node.widgets)) { + console.warn("LoRA Manager: node does not expose widgets", node); + return; + } + + const widgetIndex = node.widgets.findIndex((widget) => widget?.name === widgetName); + if (widgetIndex === -1) { + console.warn( + "LoRA Manager: target widget not found on node", + widgetName, + node + ); + return; + } + + const widget = node.widgets[widgetIndex]; + widget.value = value; + + if (Array.isArray(node.widgets_values) && node.widgets_values.length > widgetIndex) { + node.widgets_values[widgetIndex] = value; + } + + if (typeof widget.callback === "function") { + try { + widget.callback(value); + } catch (callbackError) { + console.error("LoRA Manager: widget callback failed", callbackError); + } + } + + if (typeof node.setDirtyCanvas === "function") { + node.setDirtyCanvas(true); + } + + if (typeof app.graph?.setDirtyCanvas === "function") { + app.graph.setDirtyCanvas(true, true); + } + }, +});