mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-06-29 05:51:16 -03:00
fix(registry): store nodes per-client to prevent multi-tab race condition
Move NodeRegistry from a single global _nodes dict to a per-client (_tab_nodes) structure so that multiple ComfyUI browser tabs no longer overwrite each other's workflow node data during a lora_registry_refresh cycle. The merged result is a union of all known tabs' target nodes, eliminating the non-deterministic failure where send-to-workflow could randomly target a tab lacking valid targets. - NodeRegistry.register_nodes(sid, nodes) replaces per-tab data without affecting other tabs. - NodeRegistry.get_merged_registry() returns the union across all connected clients, together with tab_count / per-tab metadata. - prepare_for_refresh() snapshots the current active sockets; caller re-reads before merging so that newly-connected tabs are not pruned. - workflow_registry.js sends api.clientId in the POST body so the backend can identify which tab is registering.
This commit is contained in:
@@ -414,9 +414,10 @@ class PromptServerProtocol(Protocol):
|
||||
"""Subset of PromptServer used by the handlers."""
|
||||
|
||||
instance: "PromptServerProtocol"
|
||||
sockets: dict # maps clientId (sid) → WebSocketResponse
|
||||
|
||||
def send_sync(
|
||||
self, event: str, payload: dict
|
||||
self, event: str, payload: dict | None = None, sid: str | None = None
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
@@ -471,90 +472,154 @@ class BackupServiceProtocol(Protocol):
|
||||
|
||||
|
||||
class NodeRegistry:
|
||||
"""Thread-safe registry for tracking LoRA nodes in active workflows."""
|
||||
"""Thread-safe registry for tracking LoRA nodes across ComfyUI tabs.
|
||||
|
||||
Each connected ComfyUI browser tab (identified by its ``sid`` / ``clientId``)
|
||||
registers its own set of workflow nodes. Queries merge all known tabs into
|
||||
a single result so that the calling LM panel always sees *every* available
|
||||
target node, regardless of which tab responded fastest.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = asyncio.Lock()
|
||||
self._nodes: Dict[str, dict] = {}
|
||||
self._registry_updated = asyncio.Event()
|
||||
# sid → {unique_id → node_info}
|
||||
self._tab_nodes: Dict[str, Dict[str, dict]] = {}
|
||||
self._ready = asyncio.Event()
|
||||
self._waiting_clients: set[str] = set()
|
||||
|
||||
@property
|
||||
def pending_client_count(self) -> int:
|
||||
"""Number of clients that have not yet responded in the current refresh cycle."""
|
||||
return len(self._waiting_clients)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers to build one node dict (extracted so it's reused for each tab)
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _build_node_dict(node: dict) -> dict:
|
||||
node_id = node["node_id"]
|
||||
graph_id = str(node["graph_id"])
|
||||
unique_id = f"{graph_id}:{node_id}"
|
||||
node_type = node.get("type", "")
|
||||
type_id = NODE_TYPES.get(node_type, 0)
|
||||
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
|
||||
|
||||
raw_capabilities = node.get("capabilities")
|
||||
capabilities: dict = {}
|
||||
if isinstance(raw_capabilities, dict):
|
||||
capabilities = dict(raw_capabilities)
|
||||
|
||||
raw_widget_names: list | None = node.get("widget_names")
|
||||
if not isinstance(raw_widget_names, list):
|
||||
capability_widget_names = capabilities.get("widget_names")
|
||||
raw_widget_names = (
|
||||
capability_widget_names
|
||||
if isinstance(capability_widget_names, list)
|
||||
else None
|
||||
)
|
||||
|
||||
widget_names: list[str] = []
|
||||
if isinstance(raw_widget_names, list):
|
||||
widget_names = [
|
||||
str(widget_name)
|
||||
for widget_name in raw_widget_names
|
||||
if isinstance(widget_name, str) and widget_name
|
||||
]
|
||||
|
||||
if widget_names:
|
||||
capabilities["widget_names"] = widget_names
|
||||
else:
|
||||
capabilities.pop("widget_names", None)
|
||||
|
||||
if "supports_lora" in capabilities:
|
||||
capabilities["supports_lora"] = bool(capabilities["supports_lora"])
|
||||
|
||||
comfy_class = node.get("comfy_class")
|
||||
if not isinstance(comfy_class, str) or not comfy_class:
|
||||
comfy_class = node_type if isinstance(node_type, str) else None
|
||||
|
||||
return {
|
||||
"id": node_id,
|
||||
"graph_id": graph_id,
|
||||
"graph_name": node.get("graph_name"),
|
||||
"unique_id": unique_id,
|
||||
"bgcolor": bgcolor,
|
||||
"title": node.get("title"),
|
||||
"type": type_id,
|
||||
"type_name": node_type,
|
||||
"comfy_class": comfy_class,
|
||||
"capabilities": capabilities,
|
||||
"widget_names": widget_names,
|
||||
"mode": node.get("mode"),
|
||||
"marker_role": node.get("marker_role"),
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
async def register_nodes(self, sid: str, nodes: list[dict]) -> None:
|
||||
"""Register/replace the node list for a single ComfyUI tab (identified by *sid*)."""
|
||||
tab_nodes: dict[str, dict] = {}
|
||||
for node in nodes:
|
||||
nd = self._build_node_dict(node)
|
||||
tab_nodes[nd["unique_id"]] = nd
|
||||
|
||||
async def register_nodes(self, nodes: list[dict]) -> None:
|
||||
async with self._lock:
|
||||
self._nodes.clear()
|
||||
for node in nodes:
|
||||
node_id = node["node_id"]
|
||||
graph_id = str(node["graph_id"])
|
||||
unique_id = f"{graph_id}:{node_id}"
|
||||
node_type = node.get("type", "")
|
||||
type_id = NODE_TYPES.get(node_type, 0)
|
||||
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
|
||||
raw_capabilities = node.get("capabilities")
|
||||
capabilities: dict = {}
|
||||
if isinstance(raw_capabilities, dict):
|
||||
capabilities = dict(raw_capabilities)
|
||||
self._tab_nodes[sid] = tab_nodes
|
||||
self._waiting_clients.discard(sid)
|
||||
if not self._waiting_clients:
|
||||
self._ready.set()
|
||||
|
||||
raw_widget_names: list | None = node.get("widget_names")
|
||||
if not isinstance(raw_widget_names, list):
|
||||
capability_widget_names = capabilities.get("widget_names")
|
||||
raw_widget_names = (
|
||||
capability_widget_names
|
||||
if isinstance(capability_widget_names, list)
|
||||
else None
|
||||
)
|
||||
logger.debug("Registered %s nodes from client %s", len(nodes), sid)
|
||||
|
||||
widget_names: list[str] = []
|
||||
if isinstance(raw_widget_names, list):
|
||||
widget_names = [
|
||||
str(widget_name)
|
||||
for widget_name in raw_widget_names
|
||||
if isinstance(widget_name, str) and widget_name
|
||||
]
|
||||
def prepare_for_refresh(self, active_sids: list[str]) -> None:
|
||||
"""Set the list of client IDs we expect to hear from during the next refresh cycle."""
|
||||
self._ready.clear()
|
||||
self._waiting_clients = set(active_sids)
|
||||
|
||||
if widget_names:
|
||||
capabilities["widget_names"] = widget_names
|
||||
else:
|
||||
capabilities.pop("widget_names", None)
|
||||
|
||||
if "supports_lora" in capabilities:
|
||||
capabilities["supports_lora"] = bool(capabilities["supports_lora"])
|
||||
|
||||
comfy_class = node.get("comfy_class")
|
||||
if not isinstance(comfy_class, str) or not comfy_class:
|
||||
comfy_class = node_type if isinstance(node_type, str) else None
|
||||
|
||||
self._nodes[unique_id] = {
|
||||
"id": node_id,
|
||||
"graph_id": graph_id,
|
||||
"graph_name": node.get("graph_name"),
|
||||
"unique_id": unique_id,
|
||||
"bgcolor": bgcolor,
|
||||
"title": node.get("title"),
|
||||
"type": type_id,
|
||||
"type_name": node_type,
|
||||
"comfy_class": comfy_class,
|
||||
"capabilities": capabilities,
|
||||
"widget_names": widget_names,
|
||||
"mode": node.get("mode"),
|
||||
"marker_role": node.get("marker_role"),
|
||||
}
|
||||
logger.debug("Registered %s nodes in registry", len(nodes))
|
||||
self._registry_updated.set()
|
||||
|
||||
async def get_registry(self) -> dict:
|
||||
async with self._lock:
|
||||
return {
|
||||
"nodes": dict(self._nodes),
|
||||
"node_count": len(self._nodes),
|
||||
}
|
||||
|
||||
async def wait_for_update(self, timeout: float = 1.0) -> bool:
|
||||
self._registry_updated.clear()
|
||||
async def wait_for_all(self, timeout: float = 2.0) -> bool:
|
||||
"""Block until every client in the current waiting set has responded
|
||||
(or *timeout* seconds elapse). Returns ``True`` if all responded."""
|
||||
if not self._waiting_clients:
|
||||
return True
|
||||
try:
|
||||
await asyncio.wait_for(self._registry_updated.wait(), timeout=timeout)
|
||||
await asyncio.wait_for(self._ready.wait(), timeout=timeout)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
|
||||
async def get_merged_registry(self, active_sids: set[str] | None = None) -> dict:
|
||||
"""Return the union of all known tab nodes, pruning any tab that is no
|
||||
longer connected."""
|
||||
async with self._lock:
|
||||
# Garbage-collect stale entries (disconnected tabs)
|
||||
if active_sids is not None:
|
||||
for sid in list(self._tab_nodes):
|
||||
if sid not in active_sids:
|
||||
del self._tab_nodes[sid]
|
||||
|
||||
merged: dict[str, dict] = {}
|
||||
tab_info: dict[str, dict] = {}
|
||||
for sid, nodes in self._tab_nodes.items():
|
||||
tab_info[sid] = {
|
||||
"node_count": len(nodes),
|
||||
"graph_names": list(
|
||||
{
|
||||
n.get("graph_name")
|
||||
for n in nodes.values()
|
||||
if n.get("graph_name")
|
||||
}
|
||||
),
|
||||
}
|
||||
merged.update(nodes)
|
||||
|
||||
return {
|
||||
"nodes": merged,
|
||||
"node_count": len(merged),
|
||||
"tab_count": len(self._tab_nodes),
|
||||
"tabs": tab_info,
|
||||
}
|
||||
|
||||
|
||||
class HealthCheckHandler:
|
||||
async def health_check(self, request: web.Request) -> web.Response:
|
||||
@@ -2995,10 +3060,21 @@ class NodeRegistryHandler:
|
||||
try:
|
||||
data = await request.json()
|
||||
nodes = data.get("nodes", [])
|
||||
client_id = data.get("client_id")
|
||||
if not isinstance(nodes, list):
|
||||
return web.json_response(
|
||||
{"success": False, "error": "nodes must be a list"}, status=400
|
||||
)
|
||||
|
||||
if not isinstance(client_id, str) or not client_id:
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Missing client_id parameter",
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
|
||||
for index, node in enumerate(nodes):
|
||||
if not isinstance(node, dict):
|
||||
return web.json_response(
|
||||
@@ -3042,7 +3118,7 @@ class NodeRegistryHandler:
|
||||
else:
|
||||
node["graph_name"] = str(graph_name)
|
||||
|
||||
await self._node_registry.register_nodes(nodes)
|
||||
await self._node_registry.register_nodes(client_id, nodes)
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
@@ -3066,9 +3142,15 @@ class NodeRegistryHandler:
|
||||
status=503,
|
||||
)
|
||||
|
||||
# Snapshot of currently-connected ComfyUI tabs
|
||||
active_sids = list(self._prompt_server.instance.sockets.keys())
|
||||
self._node_registry.prepare_for_refresh(active_sids)
|
||||
|
||||
try:
|
||||
self._prompt_server.instance.send_sync("lora_registry_refresh", {})
|
||||
logger.debug("Sent registry refresh request to frontend")
|
||||
logger.debug(
|
||||
"Sent registry refresh request (expecting %s clients)", len(active_sids)
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to send registry refresh message: %s", exc)
|
||||
return web.json_response(
|
||||
@@ -3080,19 +3162,31 @@ class NodeRegistryHandler:
|
||||
status=500,
|
||||
)
|
||||
|
||||
registry_updated = await self._node_registry.wait_for_update(timeout=1.0)
|
||||
if not registry_updated:
|
||||
logger.warning("Registry refresh timeout after 1 second")
|
||||
if not await self._node_registry.wait_for_all(timeout=2.0):
|
||||
logger.warning(
|
||||
"Registry refresh timeout after 2s (%s/%s clients responded)",
|
||||
len(active_sids) - self._node_registry.pending_client_count,
|
||||
len(active_sids),
|
||||
)
|
||||
|
||||
# Re-read current sockets after the wait: a tab may have connected
|
||||
# while we were waiting, and we don't want to garbage-collect it.
|
||||
current_sids = set(self._prompt_server.instance.sockets.keys())
|
||||
registry_info = await self._node_registry.get_merged_registry(
|
||||
active_sids=current_sids
|
||||
)
|
||||
|
||||
if registry_info["node_count"] == 0:
|
||||
logger.warning("No nodes registered after refresh")
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Timeout Error",
|
||||
"message": "Registry refresh timeout - ComfyUI frontend may not be responsive",
|
||||
"error": "Empty Registry",
|
||||
"message": "No workflow nodes found — ensure ComfyUI is open and the extension is loaded.",
|
||||
},
|
||||
status=408,
|
||||
)
|
||||
|
||||
registry_info = await self._node_registry.get_registry()
|
||||
return web.json_response({"success": True, "data": registry_info})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to get registry: %s", exc, exc_info=True)
|
||||
|
||||
Reference in New Issue
Block a user