fix(registry): store nodes per-client to prevent multi-tab race condition

Move NodeRegistry from a single global _nodes dict to a per-client
(_tab_nodes) structure so that multiple ComfyUI browser tabs no
longer overwrite each other's workflow node data during a
lora_registry_refresh cycle.  The merged result is a union of all
known tabs' target nodes, eliminating the non-deterministic failure
where send-to-workflow could randomly target a tab lacking valid
targets.

- NodeRegistry.register_nodes(sid, nodes) replaces per-tab data
  without affecting other tabs.
- NodeRegistry.get_merged_registry() returns the union across all
  connected clients, together with tab_count / per-tab metadata.
- prepare_for_refresh() snapshots the current active sockets; caller
  re-reads before merging so that newly-connected tabs are not pruned.
- workflow_registry.js sends api.clientId in the POST body so the
  backend can identify which tab is registering.
This commit is contained in:
Will Miao
2026-06-28 17:57:58 +08:00
parent 703a6a4ea0
commit e9e8c31ad1
4 changed files with 208 additions and 94 deletions

View File

@@ -60,7 +60,9 @@ class FakePromptServer:
sent = []
class Instance:
def send_sync(self, event, payload):
sockets: dict = {}
def send_sync(self, event, payload, sid=None):
FakePromptServer.sent.append((event, payload))
instance = Instance()
@@ -148,7 +150,8 @@ class TestNodeRegistryHandlerSnapshots:
"type": "Lora Loader (LoraManager)",
"title": "Test Loader",
}
]
],
"client_id": "test-client-1",
}
)
@@ -167,7 +170,7 @@ class TestNodeRegistryHandlerSnapshots:
standalone_mode=False,
)
request = FakeRequest(json_data={"nodes": []})
request = FakeRequest(json_data={"nodes": [], "client_id": "test-client-1"})
response = await handler.register_nodes(request)
payload = json.loads(response.text)

View File

@@ -586,7 +586,9 @@ class FakePromptServer:
sent = []
class Instance:
def send_sync(self, event, payload):
sockets: dict = {}
def send_sync(self, event, payload, sid=None):
FakePromptServer.sent.append((event, payload))
instance = Instance()
@@ -601,7 +603,12 @@ async def test_register_nodes_requires_graph_id():
standalone_mode=False,
)
request = FakeRequest(json_data={"nodes": [{"node_id": 1}]})
request = FakeRequest(
json_data={
"nodes": [{"node_id": 1}],
"client_id": "test-client-1",
}
)
response = await handler.register_nodes(request)
payload = json.loads(response.text)
@@ -629,7 +636,8 @@ async def test_register_nodes_stores_graph_identifier():
"type": "Lora Loader (LoraManager)",
"title": "Loader",
}
]
],
"client_id": "test-client-1",
}
)
@@ -638,7 +646,7 @@ async def test_register_nodes_stores_graph_identifier():
assert payload["success"] is True
registry = await node_registry.get_registry()
registry = await node_registry.get_merged_registry()
assert registry["node_count"] == 1
stored_node = next(iter(registry["nodes"].values()))
assert stored_node["graph_id"] == "graph-123"
@@ -664,7 +672,8 @@ async def test_register_nodes_defaults_graph_name_to_none():
"type": "Lora Loader (LoraManager)",
"title": "Root Loader",
}
]
],
"client_id": "test-client-1",
}
)
@@ -673,7 +682,7 @@ async def test_register_nodes_defaults_graph_name_to_none():
assert payload["success"] is True
registry = await node_registry.get_registry()
registry = await node_registry.get_merged_registry()
stored_node = next(iter(registry["nodes"].values()))
assert stored_node["graph_name"] is None
@@ -700,7 +709,8 @@ async def test_register_nodes_includes_capabilities():
"widget_names": ["ckpt_name", "", 42],
},
}
]
],
"client_id": "test-client-1",
}
)
@@ -709,7 +719,7 @@ async def test_register_nodes_includes_capabilities():
assert payload["success"] is True
registry = await node_registry.get_registry()
registry = await node_registry.get_merged_registry()
stored_node = next(iter(registry["nodes"].values()))
assert stored_node["capabilities"] == {
"supports_lora": False,
@@ -724,7 +734,9 @@ async def test_update_node_widget_sends_payload():
class RecordingPromptServer:
class Instance:
def send_sync(self, event, payload):
sockets: dict = {}
def send_sync(self, event, payload, sid=None):
send_calls.append((event, payload))
instance = Instance()
@@ -768,7 +780,9 @@ async def test_update_lora_code_includes_graph_identifier():
class RecordingPromptServer:
class Instance:
def send_sync(self, event, payload):
sockets: dict = {}
def send_sync(self, event, payload, sid=None):
send_calls.append((event, payload))
instance = Instance()