From e9e8c31ad1cd2a6bb418651ac21efd225215ce9c Mon Sep 17 00:00:00 2001 From: Will Miao Date: Sun, 28 Jun 2026 17:57:58 +0800 Subject: [PATCH] fix(registry): store nodes per-client to prevent multi-tab race condition Move NodeRegistry from a single global _nodes dict to a per-client (_tab_nodes) structure so that multiple ComfyUI browser tabs no longer overwrite each other's workflow node data during a lora_registry_refresh cycle. The merged result is a union of all known tabs' target nodes, eliminating the non-deterministic failure where send-to-workflow could randomly target a tab lacking valid targets. - NodeRegistry.register_nodes(sid, nodes) replaces per-tab data without affecting other tabs. - NodeRegistry.get_merged_registry() returns the union across all connected clients, together with tab_count / per-tab metadata. - prepare_for_refresh() snapshots the current active sockets; caller re-reads before merging so that newly-connected tabs are not pruned. - workflow_registry.js sends api.clientId in the POST body so the backend can identify which tab is registering. --- py/routes/handlers/misc_handlers.py | 254 +++++++++++++++++++--------- tests/routes/test_api_snapshots.py | 9 +- tests/routes/test_misc_routes.py | 34 ++-- web/comfyui/workflow_registry.js | 5 +- 4 files changed, 208 insertions(+), 94 deletions(-) diff --git a/py/routes/handlers/misc_handlers.py b/py/routes/handlers/misc_handlers.py index c3b86981..1450046f 100644 --- a/py/routes/handlers/misc_handlers.py +++ b/py/routes/handlers/misc_handlers.py @@ -414,9 +414,10 @@ class PromptServerProtocol(Protocol): """Subset of PromptServer used by the handlers.""" instance: "PromptServerProtocol" + sockets: dict # maps clientId (sid) → WebSocketResponse def send_sync( - self, event: str, payload: dict + self, event: str, payload: dict | None = None, sid: str | None = None ) -> None: # pragma: no cover - protocol ... @@ -471,90 +472,154 @@ class BackupServiceProtocol(Protocol): class NodeRegistry: - """Thread-safe registry for tracking LoRA nodes in active workflows.""" + """Thread-safe registry for tracking LoRA nodes across ComfyUI tabs. + + Each connected ComfyUI browser tab (identified by its ``sid`` / ``clientId``) + registers its own set of workflow nodes. Queries merge all known tabs into + a single result so that the calling LM panel always sees *every* available + target node, regardless of which tab responded fastest. + """ def __init__(self) -> None: self._lock = asyncio.Lock() - self._nodes: Dict[str, dict] = {} - self._registry_updated = asyncio.Event() + # sid → {unique_id → node_info} + self._tab_nodes: Dict[str, Dict[str, dict]] = {} + self._ready = asyncio.Event() + self._waiting_clients: set[str] = set() + + @property + def pending_client_count(self) -> int: + """Number of clients that have not yet responded in the current refresh cycle.""" + return len(self._waiting_clients) + + # ------------------------------------------------------------------ + # Helpers to build one node dict (extracted so it's reused for each tab) + # ------------------------------------------------------------------ + @staticmethod + def _build_node_dict(node: dict) -> dict: + node_id = node["node_id"] + graph_id = str(node["graph_id"]) + unique_id = f"{graph_id}:{node_id}" + node_type = node.get("type", "") + type_id = NODE_TYPES.get(node_type, 0) + bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR + + raw_capabilities = node.get("capabilities") + capabilities: dict = {} + if isinstance(raw_capabilities, dict): + capabilities = dict(raw_capabilities) + + raw_widget_names: list | None = node.get("widget_names") + if not isinstance(raw_widget_names, list): + capability_widget_names = capabilities.get("widget_names") + raw_widget_names = ( + capability_widget_names + if isinstance(capability_widget_names, list) + else None + ) + + widget_names: list[str] = [] + if isinstance(raw_widget_names, list): + widget_names = [ + str(widget_name) + for widget_name in raw_widget_names + if isinstance(widget_name, str) and widget_name + ] + + if widget_names: + capabilities["widget_names"] = widget_names + else: + capabilities.pop("widget_names", None) + + if "supports_lora" in capabilities: + capabilities["supports_lora"] = bool(capabilities["supports_lora"]) + + comfy_class = node.get("comfy_class") + if not isinstance(comfy_class, str) or not comfy_class: + comfy_class = node_type if isinstance(node_type, str) else None + + return { + "id": node_id, + "graph_id": graph_id, + "graph_name": node.get("graph_name"), + "unique_id": unique_id, + "bgcolor": bgcolor, + "title": node.get("title"), + "type": type_id, + "type_name": node_type, + "comfy_class": comfy_class, + "capabilities": capabilities, + "widget_names": widget_names, + "mode": node.get("mode"), + "marker_role": node.get("marker_role"), + } + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + async def register_nodes(self, sid: str, nodes: list[dict]) -> None: + """Register/replace the node list for a single ComfyUI tab (identified by *sid*).""" + tab_nodes: dict[str, dict] = {} + for node in nodes: + nd = self._build_node_dict(node) + tab_nodes[nd["unique_id"]] = nd - async def register_nodes(self, nodes: list[dict]) -> None: async with self._lock: - self._nodes.clear() - for node in nodes: - node_id = node["node_id"] - graph_id = str(node["graph_id"]) - unique_id = f"{graph_id}:{node_id}" - node_type = node.get("type", "") - type_id = NODE_TYPES.get(node_type, 0) - bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR - raw_capabilities = node.get("capabilities") - capabilities: dict = {} - if isinstance(raw_capabilities, dict): - capabilities = dict(raw_capabilities) + self._tab_nodes[sid] = tab_nodes + self._waiting_clients.discard(sid) + if not self._waiting_clients: + self._ready.set() - raw_widget_names: list | None = node.get("widget_names") - if not isinstance(raw_widget_names, list): - capability_widget_names = capabilities.get("widget_names") - raw_widget_names = ( - capability_widget_names - if isinstance(capability_widget_names, list) - else None - ) + logger.debug("Registered %s nodes from client %s", len(nodes), sid) - widget_names: list[str] = [] - if isinstance(raw_widget_names, list): - widget_names = [ - str(widget_name) - for widget_name in raw_widget_names - if isinstance(widget_name, str) and widget_name - ] + def prepare_for_refresh(self, active_sids: list[str]) -> None: + """Set the list of client IDs we expect to hear from during the next refresh cycle.""" + self._ready.clear() + self._waiting_clients = set(active_sids) - if widget_names: - capabilities["widget_names"] = widget_names - else: - capabilities.pop("widget_names", None) - - if "supports_lora" in capabilities: - capabilities["supports_lora"] = bool(capabilities["supports_lora"]) - - comfy_class = node.get("comfy_class") - if not isinstance(comfy_class, str) or not comfy_class: - comfy_class = node_type if isinstance(node_type, str) else None - - self._nodes[unique_id] = { - "id": node_id, - "graph_id": graph_id, - "graph_name": node.get("graph_name"), - "unique_id": unique_id, - "bgcolor": bgcolor, - "title": node.get("title"), - "type": type_id, - "type_name": node_type, - "comfy_class": comfy_class, - "capabilities": capabilities, - "widget_names": widget_names, - "mode": node.get("mode"), - "marker_role": node.get("marker_role"), - } - logger.debug("Registered %s nodes in registry", len(nodes)) - self._registry_updated.set() - - async def get_registry(self) -> dict: - async with self._lock: - return { - "nodes": dict(self._nodes), - "node_count": len(self._nodes), - } - - async def wait_for_update(self, timeout: float = 1.0) -> bool: - self._registry_updated.clear() + async def wait_for_all(self, timeout: float = 2.0) -> bool: + """Block until every client in the current waiting set has responded + (or *timeout* seconds elapse). Returns ``True`` if all responded.""" + if not self._waiting_clients: + return True try: - await asyncio.wait_for(self._registry_updated.wait(), timeout=timeout) + await asyncio.wait_for(self._ready.wait(), timeout=timeout) return True except asyncio.TimeoutError: return False + async def get_merged_registry(self, active_sids: set[str] | None = None) -> dict: + """Return the union of all known tab nodes, pruning any tab that is no + longer connected.""" + async with self._lock: + # Garbage-collect stale entries (disconnected tabs) + if active_sids is not None: + for sid in list(self._tab_nodes): + if sid not in active_sids: + del self._tab_nodes[sid] + + merged: dict[str, dict] = {} + tab_info: dict[str, dict] = {} + for sid, nodes in self._tab_nodes.items(): + tab_info[sid] = { + "node_count": len(nodes), + "graph_names": list( + { + n.get("graph_name") + for n in nodes.values() + if n.get("graph_name") + } + ), + } + merged.update(nodes) + + return { + "nodes": merged, + "node_count": len(merged), + "tab_count": len(self._tab_nodes), + "tabs": tab_info, + } + class HealthCheckHandler: async def health_check(self, request: web.Request) -> web.Response: @@ -2995,10 +3060,21 @@ class NodeRegistryHandler: try: data = await request.json() nodes = data.get("nodes", []) + client_id = data.get("client_id") if not isinstance(nodes, list): return web.json_response( {"success": False, "error": "nodes must be a list"}, status=400 ) + + if not isinstance(client_id, str) or not client_id: + return web.json_response( + { + "success": False, + "error": "Missing client_id parameter", + }, + status=400, + ) + for index, node in enumerate(nodes): if not isinstance(node, dict): return web.json_response( @@ -3042,7 +3118,7 @@ class NodeRegistryHandler: else: node["graph_name"] = str(graph_name) - await self._node_registry.register_nodes(nodes) + await self._node_registry.register_nodes(client_id, nodes) return web.json_response( { "success": True, @@ -3066,9 +3142,15 @@ class NodeRegistryHandler: status=503, ) + # Snapshot of currently-connected ComfyUI tabs + active_sids = list(self._prompt_server.instance.sockets.keys()) + self._node_registry.prepare_for_refresh(active_sids) + try: self._prompt_server.instance.send_sync("lora_registry_refresh", {}) - logger.debug("Sent registry refresh request to frontend") + logger.debug( + "Sent registry refresh request (expecting %s clients)", len(active_sids) + ) except Exception as exc: logger.error("Failed to send registry refresh message: %s", exc) return web.json_response( @@ -3080,19 +3162,31 @@ class NodeRegistryHandler: status=500, ) - registry_updated = await self._node_registry.wait_for_update(timeout=1.0) - if not registry_updated: - logger.warning("Registry refresh timeout after 1 second") + if not await self._node_registry.wait_for_all(timeout=2.0): + logger.warning( + "Registry refresh timeout after 2s (%s/%s clients responded)", + len(active_sids) - self._node_registry.pending_client_count, + len(active_sids), + ) + + # Re-read current sockets after the wait: a tab may have connected + # while we were waiting, and we don't want to garbage-collect it. + current_sids = set(self._prompt_server.instance.sockets.keys()) + registry_info = await self._node_registry.get_merged_registry( + active_sids=current_sids + ) + + if registry_info["node_count"] == 0: + logger.warning("No nodes registered after refresh") return web.json_response( { "success": False, - "error": "Timeout Error", - "message": "Registry refresh timeout - ComfyUI frontend may not be responsive", + "error": "Empty Registry", + "message": "No workflow nodes found — ensure ComfyUI is open and the extension is loaded.", }, status=408, ) - registry_info = await self._node_registry.get_registry() return web.json_response({"success": True, "data": registry_info}) except Exception as exc: # pragma: no cover - defensive logging logger.error("Failed to get registry: %s", exc, exc_info=True) diff --git a/tests/routes/test_api_snapshots.py b/tests/routes/test_api_snapshots.py index 7b83bd65..d3884df0 100644 --- a/tests/routes/test_api_snapshots.py +++ b/tests/routes/test_api_snapshots.py @@ -60,7 +60,9 @@ class FakePromptServer: sent = [] class Instance: - def send_sync(self, event, payload): + sockets: dict = {} + + def send_sync(self, event, payload, sid=None): FakePromptServer.sent.append((event, payload)) instance = Instance() @@ -148,7 +150,8 @@ class TestNodeRegistryHandlerSnapshots: "type": "Lora Loader (LoraManager)", "title": "Test Loader", } - ] + ], + "client_id": "test-client-1", } ) @@ -167,7 +170,7 @@ class TestNodeRegistryHandlerSnapshots: standalone_mode=False, ) - request = FakeRequest(json_data={"nodes": []}) + request = FakeRequest(json_data={"nodes": [], "client_id": "test-client-1"}) response = await handler.register_nodes(request) payload = json.loads(response.text) diff --git a/tests/routes/test_misc_routes.py b/tests/routes/test_misc_routes.py index 7fc1c184..3e3a8031 100644 --- a/tests/routes/test_misc_routes.py +++ b/tests/routes/test_misc_routes.py @@ -586,7 +586,9 @@ class FakePromptServer: sent = [] class Instance: - def send_sync(self, event, payload): + sockets: dict = {} + + def send_sync(self, event, payload, sid=None): FakePromptServer.sent.append((event, payload)) instance = Instance() @@ -601,7 +603,12 @@ async def test_register_nodes_requires_graph_id(): standalone_mode=False, ) - request = FakeRequest(json_data={"nodes": [{"node_id": 1}]}) + request = FakeRequest( + json_data={ + "nodes": [{"node_id": 1}], + "client_id": "test-client-1", + } + ) response = await handler.register_nodes(request) payload = json.loads(response.text) @@ -629,7 +636,8 @@ async def test_register_nodes_stores_graph_identifier(): "type": "Lora Loader (LoraManager)", "title": "Loader", } - ] + ], + "client_id": "test-client-1", } ) @@ -638,7 +646,7 @@ async def test_register_nodes_stores_graph_identifier(): assert payload["success"] is True - registry = await node_registry.get_registry() + registry = await node_registry.get_merged_registry() assert registry["node_count"] == 1 stored_node = next(iter(registry["nodes"].values())) assert stored_node["graph_id"] == "graph-123" @@ -664,7 +672,8 @@ async def test_register_nodes_defaults_graph_name_to_none(): "type": "Lora Loader (LoraManager)", "title": "Root Loader", } - ] + ], + "client_id": "test-client-1", } ) @@ -673,7 +682,7 @@ async def test_register_nodes_defaults_graph_name_to_none(): assert payload["success"] is True - registry = await node_registry.get_registry() + registry = await node_registry.get_merged_registry() stored_node = next(iter(registry["nodes"].values())) assert stored_node["graph_name"] is None @@ -700,7 +709,8 @@ async def test_register_nodes_includes_capabilities(): "widget_names": ["ckpt_name", "", 42], }, } - ] + ], + "client_id": "test-client-1", } ) @@ -709,7 +719,7 @@ async def test_register_nodes_includes_capabilities(): assert payload["success"] is True - registry = await node_registry.get_registry() + registry = await node_registry.get_merged_registry() stored_node = next(iter(registry["nodes"].values())) assert stored_node["capabilities"] == { "supports_lora": False, @@ -724,7 +734,9 @@ async def test_update_node_widget_sends_payload(): class RecordingPromptServer: class Instance: - def send_sync(self, event, payload): + sockets: dict = {} + + def send_sync(self, event, payload, sid=None): send_calls.append((event, payload)) instance = Instance() @@ -768,7 +780,9 @@ async def test_update_lora_code_includes_graph_identifier(): class RecordingPromptServer: class Instance: - def send_sync(self, event, payload): + sockets: dict = {} + + def send_sync(self, event, payload, sid=None): send_calls.append((event, payload)) instance = Instance() diff --git a/web/comfyui/workflow_registry.js b/web/comfyui/workflow_registry.js index 5984dfc4..29e968a7 100644 --- a/web/comfyui/workflow_registry.js +++ b/web/comfyui/workflow_registry.js @@ -151,7 +151,10 @@ app.registerExtension({ headers: { "Content-Type": "application/json", }, - body: JSON.stringify({ nodes: workflowNodes }), + body: JSON.stringify({ + nodes: workflowNodes, + client_id: api.clientId ?? api.initialClientId ?? "", + }), }); if (!response.ok) {