feat(graph): enhance node handling with graph identifiers and improve metadata updates, see #408, #538

This commit is contained in:
Will Miao
2025-10-07 23:22:38 +08:00
parent 9199950b74
commit 3118f3b43c
12 changed files with 574 additions and 103 deletions

View File

@@ -188,7 +188,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"]))
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": ["node"]})
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": [{"node_id": "node", "graph_id": "graph-1"}]})
response = await routes.get_trigger_words(request)
payload = json.loads(response.text)
@@ -196,7 +196,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
assert payload == {"success": True}
send_mock.assert_called_once_with(
"trigger_word_update",
{"id": "node", "message": "trigger-one"},
{"id": "node", "graph_id": "graph-1", "message": "trigger-one"},
)

View File

@@ -4,7 +4,13 @@ from types import SimpleNamespace
import pytest
from aiohttp import web
from py.routes.handlers.misc_handlers import SettingsHandler, ServiceRegistryAdapter
from py.routes.handlers.misc_handlers import (
LoraCodeHandler,
NodeRegistry,
NodeRegistryHandler,
ServiceRegistryAdapter,
SettingsHandler,
)
from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar
from py.routes.misc_routes import MiscRoutes
@@ -126,6 +132,128 @@ class FakePromptServer:
instance = Instance()
@pytest.mark.asyncio
async def test_register_nodes_requires_graph_id():
node_registry = NodeRegistry()
handler = NodeRegistryHandler(
node_registry=node_registry,
prompt_server=FakePromptServer,
standalone_mode=False,
)
request = FakeRequest(json_data={"nodes": [{"node_id": 1}]})
response = await handler.register_nodes(request)
payload = json.loads(response.text)
assert response.status == 400
assert payload["success"] is False
assert "graph_id" in payload["error"]
@pytest.mark.asyncio
async def test_register_nodes_stores_graph_identifier():
node_registry = NodeRegistry()
handler = NodeRegistryHandler(
node_registry=node_registry,
prompt_server=FakePromptServer,
standalone_mode=False,
)
request = FakeRequest(
json_data={
"nodes": [
{
"node_id": 7,
"graph_id": "graph-123",
"graph_name": "Character Subgraph",
"type": "Lora Loader (LoraManager)",
"title": "Loader",
}
]
}
)
response = await handler.register_nodes(request)
payload = json.loads(response.text)
assert payload["success"] is True
registry = await node_registry.get_registry()
assert registry["node_count"] == 1
stored_node = next(iter(registry["nodes"].values()))
assert stored_node["graph_id"] == "graph-123"
assert stored_node["unique_id"] == "graph-123:7"
assert stored_node["graph_name"] == "Character Subgraph"
@pytest.mark.asyncio
async def test_register_nodes_defaults_graph_name_to_none():
node_registry = NodeRegistry()
handler = NodeRegistryHandler(
node_registry=node_registry,
prompt_server=FakePromptServer,
standalone_mode=False,
)
request = FakeRequest(
json_data={
"nodes": [
{
"node_id": 8,
"graph_id": "root",
"type": "Lora Loader (LoraManager)",
"title": "Root Loader",
}
]
}
)
response = await handler.register_nodes(request)
payload = json.loads(response.text)
assert payload["success"] is True
registry = await node_registry.get_registry()
stored_node = next(iter(registry["nodes"].values()))
assert stored_node["graph_name"] is None
@pytest.mark.asyncio
async def test_update_lora_code_includes_graph_identifier():
send_calls: list[tuple[str, dict]] = []
class RecordingPromptServer:
class Instance:
def send_sync(self, event, payload):
send_calls.append((event, payload))
instance = Instance()
handler = LoraCodeHandler(RecordingPromptServer)
request = FakeRequest(
json_data={
"node_ids": [{"node_id": 3, "graph_id": "graph-A"}],
"lora_code": "<lora>",
"mode": "replace",
}
)
response = await handler.update_lora_code(request)
payload = json.loads(response.text)
assert payload["success"] is True
assert payload["results"] == [
{"node_id": 3, "graph_id": "graph-A", "success": True}
]
assert send_calls == [
(
"lora_code_update",
{"id": 3, "graph_id": "graph-A", "lora_code": "<lora>", "mode": "replace"},
)
]
class FakeScanner:
async def check_model_version_exists(self, _version_id):
return False