fix(registry): store nodes per-client to prevent multi-tab race condition

Move NodeRegistry from a single global _nodes dict to a per-client
(_tab_nodes) structure so that multiple ComfyUI browser tabs no
longer overwrite each other's workflow node data during a
lora_registry_refresh cycle.  The merged result is a union of all
known tabs' target nodes, eliminating the non-deterministic failure
where send-to-workflow could randomly target a tab lacking valid
targets.

- NodeRegistry.register_nodes(sid, nodes) replaces per-tab data
  without affecting other tabs.
- NodeRegistry.get_merged_registry() returns the union across all
  connected clients, together with tab_count / per-tab metadata.
- prepare_for_refresh() snapshots the current active sockets; caller
  re-reads before merging so that newly-connected tabs are not pruned.
- workflow_registry.js sends api.clientId in the POST body so the
  backend can identify which tab is registering.
This commit is contained in:
Will Miao
2026-06-28 17:57:58 +08:00
parent 703a6a4ea0
commit e9e8c31ad1
4 changed files with 208 additions and 94 deletions

View File

@@ -414,9 +414,10 @@ class PromptServerProtocol(Protocol):
"""Subset of PromptServer used by the handlers."""
instance: "PromptServerProtocol"
sockets: dict # maps clientId (sid) → WebSocketResponse
def send_sync(
self, event: str, payload: dict
self, event: str, payload: dict | None = None, sid: str | None = None
) -> None: # pragma: no cover - protocol
...
@@ -471,90 +472,154 @@ class BackupServiceProtocol(Protocol):
class NodeRegistry:
"""Thread-safe registry for tracking LoRA nodes in active workflows."""
"""Thread-safe registry for tracking LoRA nodes across ComfyUI tabs.
Each connected ComfyUI browser tab (identified by its ``sid`` / ``clientId``)
registers its own set of workflow nodes. Queries merge all known tabs into
a single result so that the calling LM panel always sees *every* available
target node, regardless of which tab responded fastest.
"""
def __init__(self) -> None:
self._lock = asyncio.Lock()
self._nodes: Dict[str, dict] = {}
self._registry_updated = asyncio.Event()
# sid → {unique_id → node_info}
self._tab_nodes: Dict[str, Dict[str, dict]] = {}
self._ready = asyncio.Event()
self._waiting_clients: set[str] = set()
@property
def pending_client_count(self) -> int:
"""Number of clients that have not yet responded in the current refresh cycle."""
return len(self._waiting_clients)
# ------------------------------------------------------------------
# Helpers to build one node dict (extracted so it's reused for each tab)
# ------------------------------------------------------------------
@staticmethod
def _build_node_dict(node: dict) -> dict:
node_id = node["node_id"]
graph_id = str(node["graph_id"])
unique_id = f"{graph_id}:{node_id}"
node_type = node.get("type", "")
type_id = NODE_TYPES.get(node_type, 0)
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
raw_capabilities = node.get("capabilities")
capabilities: dict = {}
if isinstance(raw_capabilities, dict):
capabilities = dict(raw_capabilities)
raw_widget_names: list | None = node.get("widget_names")
if not isinstance(raw_widget_names, list):
capability_widget_names = capabilities.get("widget_names")
raw_widget_names = (
capability_widget_names
if isinstance(capability_widget_names, list)
else None
)
widget_names: list[str] = []
if isinstance(raw_widget_names, list):
widget_names = [
str(widget_name)
for widget_name in raw_widget_names
if isinstance(widget_name, str) and widget_name
]
if widget_names:
capabilities["widget_names"] = widget_names
else:
capabilities.pop("widget_names", None)
if "supports_lora" in capabilities:
capabilities["supports_lora"] = bool(capabilities["supports_lora"])
comfy_class = node.get("comfy_class")
if not isinstance(comfy_class, str) or not comfy_class:
comfy_class = node_type if isinstance(node_type, str) else None
return {
"id": node_id,
"graph_id": graph_id,
"graph_name": node.get("graph_name"),
"unique_id": unique_id,
"bgcolor": bgcolor,
"title": node.get("title"),
"type": type_id,
"type_name": node_type,
"comfy_class": comfy_class,
"capabilities": capabilities,
"widget_names": widget_names,
"mode": node.get("mode"),
"marker_role": node.get("marker_role"),
}
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
async def register_nodes(self, sid: str, nodes: list[dict]) -> None:
"""Register/replace the node list for a single ComfyUI tab (identified by *sid*)."""
tab_nodes: dict[str, dict] = {}
for node in nodes:
nd = self._build_node_dict(node)
tab_nodes[nd["unique_id"]] = nd
async def register_nodes(self, nodes: list[dict]) -> None:
async with self._lock:
self._nodes.clear()
for node in nodes:
node_id = node["node_id"]
graph_id = str(node["graph_id"])
unique_id = f"{graph_id}:{node_id}"
node_type = node.get("type", "")
type_id = NODE_TYPES.get(node_type, 0)
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
raw_capabilities = node.get("capabilities")
capabilities: dict = {}
if isinstance(raw_capabilities, dict):
capabilities = dict(raw_capabilities)
self._tab_nodes[sid] = tab_nodes
self._waiting_clients.discard(sid)
if not self._waiting_clients:
self._ready.set()
raw_widget_names: list | None = node.get("widget_names")
if not isinstance(raw_widget_names, list):
capability_widget_names = capabilities.get("widget_names")
raw_widget_names = (
capability_widget_names
if isinstance(capability_widget_names, list)
else None
)
logger.debug("Registered %s nodes from client %s", len(nodes), sid)
widget_names: list[str] = []
if isinstance(raw_widget_names, list):
widget_names = [
str(widget_name)
for widget_name in raw_widget_names
if isinstance(widget_name, str) and widget_name
]
def prepare_for_refresh(self, active_sids: list[str]) -> None:
"""Set the list of client IDs we expect to hear from during the next refresh cycle."""
self._ready.clear()
self._waiting_clients = set(active_sids)
if widget_names:
capabilities["widget_names"] = widget_names
else:
capabilities.pop("widget_names", None)
if "supports_lora" in capabilities:
capabilities["supports_lora"] = bool(capabilities["supports_lora"])
comfy_class = node.get("comfy_class")
if not isinstance(comfy_class, str) or not comfy_class:
comfy_class = node_type if isinstance(node_type, str) else None
self._nodes[unique_id] = {
"id": node_id,
"graph_id": graph_id,
"graph_name": node.get("graph_name"),
"unique_id": unique_id,
"bgcolor": bgcolor,
"title": node.get("title"),
"type": type_id,
"type_name": node_type,
"comfy_class": comfy_class,
"capabilities": capabilities,
"widget_names": widget_names,
"mode": node.get("mode"),
"marker_role": node.get("marker_role"),
}
logger.debug("Registered %s nodes in registry", len(nodes))
self._registry_updated.set()
async def get_registry(self) -> dict:
async with self._lock:
return {
"nodes": dict(self._nodes),
"node_count": len(self._nodes),
}
async def wait_for_update(self, timeout: float = 1.0) -> bool:
self._registry_updated.clear()
async def wait_for_all(self, timeout: float = 2.0) -> bool:
"""Block until every client in the current waiting set has responded
(or *timeout* seconds elapse). Returns ``True`` if all responded."""
if not self._waiting_clients:
return True
try:
await asyncio.wait_for(self._registry_updated.wait(), timeout=timeout)
await asyncio.wait_for(self._ready.wait(), timeout=timeout)
return True
except asyncio.TimeoutError:
return False
async def get_merged_registry(self, active_sids: set[str] | None = None) -> dict:
"""Return the union of all known tab nodes, pruning any tab that is no
longer connected."""
async with self._lock:
# Garbage-collect stale entries (disconnected tabs)
if active_sids is not None:
for sid in list(self._tab_nodes):
if sid not in active_sids:
del self._tab_nodes[sid]
merged: dict[str, dict] = {}
tab_info: dict[str, dict] = {}
for sid, nodes in self._tab_nodes.items():
tab_info[sid] = {
"node_count": len(nodes),
"graph_names": list(
{
n.get("graph_name")
for n in nodes.values()
if n.get("graph_name")
}
),
}
merged.update(nodes)
return {
"nodes": merged,
"node_count": len(merged),
"tab_count": len(self._tab_nodes),
"tabs": tab_info,
}
class HealthCheckHandler:
async def health_check(self, request: web.Request) -> web.Response:
@@ -2995,10 +3060,21 @@ class NodeRegistryHandler:
try:
data = await request.json()
nodes = data.get("nodes", [])
client_id = data.get("client_id")
if not isinstance(nodes, list):
return web.json_response(
{"success": False, "error": "nodes must be a list"}, status=400
)
if not isinstance(client_id, str) or not client_id:
return web.json_response(
{
"success": False,
"error": "Missing client_id parameter",
},
status=400,
)
for index, node in enumerate(nodes):
if not isinstance(node, dict):
return web.json_response(
@@ -3042,7 +3118,7 @@ class NodeRegistryHandler:
else:
node["graph_name"] = str(graph_name)
await self._node_registry.register_nodes(nodes)
await self._node_registry.register_nodes(client_id, nodes)
return web.json_response(
{
"success": True,
@@ -3066,9 +3142,15 @@ class NodeRegistryHandler:
status=503,
)
# Snapshot of currently-connected ComfyUI tabs
active_sids = list(self._prompt_server.instance.sockets.keys())
self._node_registry.prepare_for_refresh(active_sids)
try:
self._prompt_server.instance.send_sync("lora_registry_refresh", {})
logger.debug("Sent registry refresh request to frontend")
logger.debug(
"Sent registry refresh request (expecting %s clients)", len(active_sids)
)
except Exception as exc:
logger.error("Failed to send registry refresh message: %s", exc)
return web.json_response(
@@ -3080,19 +3162,31 @@ class NodeRegistryHandler:
status=500,
)
registry_updated = await self._node_registry.wait_for_update(timeout=1.0)
if not registry_updated:
logger.warning("Registry refresh timeout after 1 second")
if not await self._node_registry.wait_for_all(timeout=2.0):
logger.warning(
"Registry refresh timeout after 2s (%s/%s clients responded)",
len(active_sids) - self._node_registry.pending_client_count,
len(active_sids),
)
# Re-read current sockets after the wait: a tab may have connected
# while we were waiting, and we don't want to garbage-collect it.
current_sids = set(self._prompt_server.instance.sockets.keys())
registry_info = await self._node_registry.get_merged_registry(
active_sids=current_sids
)
if registry_info["node_count"] == 0:
logger.warning("No nodes registered after refresh")
return web.json_response(
{
"success": False,
"error": "Timeout Error",
"message": "Registry refresh timeout - ComfyUI frontend may not be responsive",
"error": "Empty Registry",
"message": "No workflow nodes found — ensure ComfyUI is open and the extension is loaded.",
},
status=408,
)
registry_info = await self._node_registry.get_registry()
return web.json_response({"success": True, "data": registry_info})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to get registry: %s", exc, exc_info=True)

View File

@@ -60,7 +60,9 @@ class FakePromptServer:
sent = []
class Instance:
def send_sync(self, event, payload):
sockets: dict = {}
def send_sync(self, event, payload, sid=None):
FakePromptServer.sent.append((event, payload))
instance = Instance()
@@ -148,7 +150,8 @@ class TestNodeRegistryHandlerSnapshots:
"type": "Lora Loader (LoraManager)",
"title": "Test Loader",
}
]
],
"client_id": "test-client-1",
}
)
@@ -167,7 +170,7 @@ class TestNodeRegistryHandlerSnapshots:
standalone_mode=False,
)
request = FakeRequest(json_data={"nodes": []})
request = FakeRequest(json_data={"nodes": [], "client_id": "test-client-1"})
response = await handler.register_nodes(request)
payload = json.loads(response.text)

View File

@@ -586,7 +586,9 @@ class FakePromptServer:
sent = []
class Instance:
def send_sync(self, event, payload):
sockets: dict = {}
def send_sync(self, event, payload, sid=None):
FakePromptServer.sent.append((event, payload))
instance = Instance()
@@ -601,7 +603,12 @@ async def test_register_nodes_requires_graph_id():
standalone_mode=False,
)
request = FakeRequest(json_data={"nodes": [{"node_id": 1}]})
request = FakeRequest(
json_data={
"nodes": [{"node_id": 1}],
"client_id": "test-client-1",
}
)
response = await handler.register_nodes(request)
payload = json.loads(response.text)
@@ -629,7 +636,8 @@ async def test_register_nodes_stores_graph_identifier():
"type": "Lora Loader (LoraManager)",
"title": "Loader",
}
]
],
"client_id": "test-client-1",
}
)
@@ -638,7 +646,7 @@ async def test_register_nodes_stores_graph_identifier():
assert payload["success"] is True
registry = await node_registry.get_registry()
registry = await node_registry.get_merged_registry()
assert registry["node_count"] == 1
stored_node = next(iter(registry["nodes"].values()))
assert stored_node["graph_id"] == "graph-123"
@@ -664,7 +672,8 @@ async def test_register_nodes_defaults_graph_name_to_none():
"type": "Lora Loader (LoraManager)",
"title": "Root Loader",
}
]
],
"client_id": "test-client-1",
}
)
@@ -673,7 +682,7 @@ async def test_register_nodes_defaults_graph_name_to_none():
assert payload["success"] is True
registry = await node_registry.get_registry()
registry = await node_registry.get_merged_registry()
stored_node = next(iter(registry["nodes"].values()))
assert stored_node["graph_name"] is None
@@ -700,7 +709,8 @@ async def test_register_nodes_includes_capabilities():
"widget_names": ["ckpt_name", "", 42],
},
}
]
],
"client_id": "test-client-1",
}
)
@@ -709,7 +719,7 @@ async def test_register_nodes_includes_capabilities():
assert payload["success"] is True
registry = await node_registry.get_registry()
registry = await node_registry.get_merged_registry()
stored_node = next(iter(registry["nodes"].values()))
assert stored_node["capabilities"] == {
"supports_lora": False,
@@ -724,7 +734,9 @@ async def test_update_node_widget_sends_payload():
class RecordingPromptServer:
class Instance:
def send_sync(self, event, payload):
sockets: dict = {}
def send_sync(self, event, payload, sid=None):
send_calls.append((event, payload))
instance = Instance()
@@ -768,7 +780,9 @@ async def test_update_lora_code_includes_graph_identifier():
class RecordingPromptServer:
class Instance:
def send_sync(self, event, payload):
sockets: dict = {}
def send_sync(self, event, payload, sid=None):
send_calls.append((event, payload))
instance = Instance()

View File

@@ -151,7 +151,10 @@ app.registerExtension({
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ nodes: workflowNodes }),
body: JSON.stringify({
nodes: workflowNodes,
client_id: api.clientId ?? api.initialClientId ?? "",
}),
});
if (!response.ok) {