feat(node-registry): add support to send checkpoint/diffusion model to workflow

- Add capabilities parsing and validation for node registration
- Implement widget_names extraction from capabilities with type safety
- Add supports_lora boolean conversion in capabilities
- Include comfy_class fallback to node_type when missing
- Add new update_node_widget API endpoint for bulk widget updates
- Improve error handling and input validation for widget updates
- Remove unused parameters from node selector event setup function

These changes improve node metadata handling and enable dynamic widget management capabilities.
This commit is contained in:
Will Miao
2025-10-23 10:44:25 +08:00
parent 13433f8cd2
commit d0aa916683
7 changed files with 706 additions and 136 deletions

View File

@@ -219,6 +219,78 @@ async def test_register_nodes_defaults_graph_name_to_none():
assert stored_node["graph_name"] is None
@pytest.mark.asyncio
async def test_register_nodes_includes_capabilities():
node_registry = NodeRegistry()
handler = NodeRegistryHandler(
node_registry=node_registry,
prompt_server=FakePromptServer,
standalone_mode=False,
)
request = FakeRequest(
json_data={
"nodes": [
{
"node_id": 9,
"graph_id": "root",
"type": "CheckpointLoaderSimple",
"title": "Checkpoint Loader",
"capabilities": {"supports_lora": False, "widget_names": ["ckpt_name", "", 42]},
}
]
}
)
response = await handler.register_nodes(request)
payload = json.loads(response.text)
assert payload["success"] is True
registry = await node_registry.get_registry()
stored_node = next(iter(registry["nodes"].values()))
assert stored_node["capabilities"] == {"supports_lora": False, "widget_names": ["ckpt_name"]}
assert stored_node["widget_names"] == ["ckpt_name"]
@pytest.mark.asyncio
async def test_update_node_widget_sends_payload():
send_calls: list[tuple[str, dict]] = []
class RecordingPromptServer:
class Instance:
def send_sync(self, event, payload):
send_calls.append((event, payload))
instance = Instance()
handler = NodeRegistryHandler(
node_registry=NodeRegistry(),
prompt_server=RecordingPromptServer,
standalone_mode=False,
)
request = FakeRequest(
json_data={
"widget_name": "ckpt_name",
"value": "models/checkpoints/model.ckpt",
"node_ids": [{"node_id": 12, "graph_id": "root"}],
}
)
response = await handler.update_node_widget(request)
payload = json.loads(response.text)
assert response.status == 200
assert payload["success"] is True
assert send_calls == [
(
"lm_widget_update",
{"id": 12, "widget_name": "ckpt_name", "value": "models/checkpoints/model.ckpt", "graph_id": "root"},
)
]
@pytest.mark.asyncio
async def test_update_lora_code_includes_graph_identifier():
send_calls: list[tuple[str, dict]] = []