feat(node-registry): add support to send checkpoint/diffusion model to workflow

- Add capabilities parsing and validation for node registration
- Implement widget_names extraction from capabilities with type safety
- Add supports_lora boolean conversion in capabilities
- Include comfy_class fallback to node_type when missing
- Add new update_node_widget API endpoint for bulk widget updates
- Improve error handling and input validation for widget updates
- Remove unused parameters from node selector event setup function

These changes improve node metadata handling and enable dynamic widget management capabilities.
This commit is contained in:
Will Miao
2025-10-23 10:44:25 +08:00
parent 13433f8cd2
commit d0aa916683
7 changed files with 706 additions and 136 deletions

View File

@@ -100,6 +100,36 @@ class NodeRegistry:
node_type = node.get("type", "")
type_id = NODE_TYPES.get(node_type, 0)
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
raw_capabilities = node.get("capabilities")
capabilities: dict = {}
if isinstance(raw_capabilities, dict):
capabilities = dict(raw_capabilities)
raw_widget_names: list | None = node.get("widget_names")
if not isinstance(raw_widget_names, list):
capability_widget_names = capabilities.get("widget_names")
raw_widget_names = capability_widget_names if isinstance(capability_widget_names, list) else None
widget_names: list[str] = []
if isinstance(raw_widget_names, list):
widget_names = [
str(widget_name)
for widget_name in raw_widget_names
if isinstance(widget_name, str) and widget_name
]
if widget_names:
capabilities["widget_names"] = widget_names
else:
capabilities.pop("widget_names", None)
if "supports_lora" in capabilities:
capabilities["supports_lora"] = bool(capabilities["supports_lora"])
comfy_class = node.get("comfy_class")
if not isinstance(comfy_class, str) or not comfy_class:
comfy_class = node_type if isinstance(node_type, str) else None
self._nodes[unique_id] = {
"id": node_id,
"graph_id": graph_id,
@@ -109,6 +139,9 @@ class NodeRegistry:
"title": node.get("title"),
"type": type_id,
"type_name": node_type,
"comfy_class": comfy_class,
"capabilities": capabilities,
"widget_names": widget_names,
}
logger.debug("Registered %s nodes in registry", len(nodes))
self._registry_updated.set()
@@ -919,6 +952,88 @@ class NodeRegistryHandler:
logger.error("Failed to get registry: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": "Internal Error", "message": str(exc)}, status=500)
async def update_node_widget(self, request: web.Request) -> web.Response:
try:
data = await request.json()
widget_name = data.get("widget_name")
value = data.get("value")
node_ids = data.get("node_ids")
if not isinstance(widget_name, str) or not widget_name:
return web.json_response({"success": False, "error": "Missing widget_name parameter"}, status=400)
if not isinstance(value, str) or not value:
return web.json_response({"success": False, "error": "Missing value parameter"}, status=400)
if not isinstance(node_ids, list) or not node_ids:
return web.json_response(
{"success": False, "error": "node_ids must be a non-empty list"},
status=400,
)
results = []
for entry in node_ids:
node_identifier = entry
graph_identifier = None
if isinstance(entry, dict):
node_identifier = entry.get("node_id")
graph_identifier = entry.get("graph_id")
if node_identifier is None:
results.append(
{
"node_id": node_identifier,
"graph_id": graph_identifier,
"success": False,
"error": "Missing node_id parameter",
}
)
continue
try:
parsed_node_id = int(node_identifier)
except (TypeError, ValueError):
parsed_node_id = node_identifier
payload = {
"id": parsed_node_id,
"widget_name": widget_name,
"value": value,
}
if graph_identifier is not None:
payload["graph_id"] = str(graph_identifier)
try:
self._prompt_server.instance.send_sync("lm_widget_update", payload)
results.append(
{
"node_id": parsed_node_id,
"graph_id": payload.get("graph_id"),
"success": True,
}
)
except Exception as exc: # pragma: no cover - defensive logging
logger.error(
"Error sending widget update to node %s (graph %s): %s",
parsed_node_id,
graph_identifier,
exc,
)
results.append(
{
"node_id": parsed_node_id,
"graph_id": payload.get("graph_id"),
"success": False,
"error": str(exc),
}
)
return web.json_response({"success": True, "results": results})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to update node widget: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class MiscHandlerSet:
"""Aggregate handlers into a lookup compatible with the registrar."""
@@ -962,6 +1077,7 @@ class MiscHandlerSet:
"get_trained_words": self.trained_words.get_trained_words,
"get_model_example_files": self.model_examples.get_model_example_files,
"register_nodes": self.node_registry.register_nodes,
"update_node_widget": self.node_registry.update_node_widget,
"get_registry": self.node_registry.get_registry,
"check_model_exists": self.model_library.check_model_exists,
"get_civitai_user_models": self.model_library.get_civitai_user_models,

View File

@@ -33,6 +33,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/trained-words", "get_trained_words"),
RouteDefinition("GET", "/api/lm/model-example-files", "get_model_example_files"),
RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"),
RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"),
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),