mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(graph): enhance node handling with graph identifiers and improve metadata updates, see #408, #538
This commit is contained in:
@@ -188,7 +188,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
|
||||
|
||||
monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"]))
|
||||
|
||||
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": ["node"]})
|
||||
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": [{"node_id": "node", "graph_id": "graph-1"}]})
|
||||
|
||||
response = await routes.get_trigger_words(request)
|
||||
payload = json.loads(response.text)
|
||||
@@ -196,7 +196,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
|
||||
assert payload == {"success": True}
|
||||
send_mock.assert_called_once_with(
|
||||
"trigger_word_update",
|
||||
{"id": "node", "message": "trigger-one"},
|
||||
{"id": "node", "graph_id": "graph-1", "message": "trigger-one"},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,13 @@ from types import SimpleNamespace
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from py.routes.handlers.misc_handlers import SettingsHandler, ServiceRegistryAdapter
|
||||
from py.routes.handlers.misc_handlers import (
|
||||
LoraCodeHandler,
|
||||
NodeRegistry,
|
||||
NodeRegistryHandler,
|
||||
ServiceRegistryAdapter,
|
||||
SettingsHandler,
|
||||
)
|
||||
from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar
|
||||
from py.routes.misc_routes import MiscRoutes
|
||||
|
||||
@@ -126,6 +132,128 @@ class FakePromptServer:
|
||||
instance = Instance()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_nodes_requires_graph_id():
|
||||
node_registry = NodeRegistry()
|
||||
handler = NodeRegistryHandler(
|
||||
node_registry=node_registry,
|
||||
prompt_server=FakePromptServer,
|
||||
standalone_mode=False,
|
||||
)
|
||||
|
||||
request = FakeRequest(json_data={"nodes": [{"node_id": 1}]})
|
||||
response = await handler.register_nodes(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
assert "graph_id" in payload["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_nodes_stores_graph_identifier():
|
||||
node_registry = NodeRegistry()
|
||||
handler = NodeRegistryHandler(
|
||||
node_registry=node_registry,
|
||||
prompt_server=FakePromptServer,
|
||||
standalone_mode=False,
|
||||
)
|
||||
|
||||
request = FakeRequest(
|
||||
json_data={
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": 7,
|
||||
"graph_id": "graph-123",
|
||||
"graph_name": "Character Subgraph",
|
||||
"type": "Lora Loader (LoraManager)",
|
||||
"title": "Loader",
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
response = await handler.register_nodes(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
|
||||
registry = await node_registry.get_registry()
|
||||
assert registry["node_count"] == 1
|
||||
stored_node = next(iter(registry["nodes"].values()))
|
||||
assert stored_node["graph_id"] == "graph-123"
|
||||
assert stored_node["unique_id"] == "graph-123:7"
|
||||
assert stored_node["graph_name"] == "Character Subgraph"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_nodes_defaults_graph_name_to_none():
|
||||
node_registry = NodeRegistry()
|
||||
handler = NodeRegistryHandler(
|
||||
node_registry=node_registry,
|
||||
prompt_server=FakePromptServer,
|
||||
standalone_mode=False,
|
||||
)
|
||||
|
||||
request = FakeRequest(
|
||||
json_data={
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": 8,
|
||||
"graph_id": "root",
|
||||
"type": "Lora Loader (LoraManager)",
|
||||
"title": "Root Loader",
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
response = await handler.register_nodes(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
|
||||
registry = await node_registry.get_registry()
|
||||
stored_node = next(iter(registry["nodes"].values()))
|
||||
assert stored_node["graph_name"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_lora_code_includes_graph_identifier():
|
||||
send_calls: list[tuple[str, dict]] = []
|
||||
|
||||
class RecordingPromptServer:
|
||||
class Instance:
|
||||
def send_sync(self, event, payload):
|
||||
send_calls.append((event, payload))
|
||||
|
||||
instance = Instance()
|
||||
|
||||
handler = LoraCodeHandler(RecordingPromptServer)
|
||||
|
||||
request = FakeRequest(
|
||||
json_data={
|
||||
"node_ids": [{"node_id": 3, "graph_id": "graph-A"}],
|
||||
"lora_code": "<lora>",
|
||||
"mode": "replace",
|
||||
}
|
||||
)
|
||||
|
||||
response = await handler.update_lora_code(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert payload["results"] == [
|
||||
{"node_id": 3, "graph_id": "graph-A", "success": True}
|
||||
]
|
||||
assert send_calls == [
|
||||
(
|
||||
"lora_code_update",
|
||||
{"id": 3, "graph_id": "graph-A", "lora_code": "<lora>", "mode": "replace"},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class FakeScanner:
|
||||
async def check_model_version_exists(self, _version_id):
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user