mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-06-29 05:51:16 -03:00
Move NodeRegistry from a single global _nodes dict to a per-client (_tab_nodes) structure so that multiple ComfyUI browser tabs no longer overwrite each other's workflow node data during a lora_registry_refresh cycle. The merged result is a union of all known tabs' target nodes, eliminating the non-deterministic failure where send-to-workflow could randomly target a tab lacking valid targets. - NodeRegistry.register_nodes(sid, nodes) replaces per-tab data without affecting other tabs. - NodeRegistry.get_merged_registry() returns the union across all connected clients, together with tab_count / per-tab metadata. - prepare_for_refresh() snapshots the current active sockets; caller re-reads before merging so that newly-connected tabs are not pruned. - workflow_registry.js sends api.clientId in the POST body so the backend can identify which tab is registering.
259 lines
7.9 KiB
Python
259 lines
7.9 KiB
Python
"""Snapshot tests for API response formats using Syrupy.
|
|
|
|
These tests verify that API responses maintain consistent structure and format
|
|
by comparing against stored snapshots. This catches unexpected changes to
|
|
response schemas.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import pytest
|
|
from types import SimpleNamespace
|
|
from syrupy import SnapshotAssertion
|
|
|
|
from py.routes.handlers.misc_handlers import (
|
|
ModelLibraryHandler,
|
|
NodeRegistry,
|
|
NodeRegistryHandler,
|
|
ServiceRegistryAdapter,
|
|
SettingsHandler,
|
|
)
|
|
from py.utils.utils import calculate_recipe_fingerprint, sanitize_folder_name
|
|
|
|
|
|
class FakeRequest:
|
|
"""Fake HTTP request for testing."""
|
|
|
|
def __init__(self, *, json_data=None, query=None):
|
|
self._json_data = json_data or {}
|
|
self.query = query or {}
|
|
|
|
async def json(self):
|
|
return self._json_data
|
|
|
|
|
|
class DummySettings:
|
|
"""Dummy settings service for testing."""
|
|
|
|
def __init__(self, data=None):
|
|
self.data = data or {}
|
|
|
|
def get(self, key, default=None):
|
|
return self.data.get(key, default)
|
|
|
|
def set(self, key, value):
|
|
self.data[key] = value
|
|
|
|
def keys(self):
|
|
return self.data.keys()
|
|
|
|
|
|
async def noop_async(*_args, **_kwargs):
|
|
"""No-op async function."""
|
|
return None
|
|
|
|
|
|
class FakePromptServer:
|
|
"""Fake prompt server for testing."""
|
|
|
|
sent = []
|
|
|
|
class Instance:
|
|
sockets: dict = {}
|
|
|
|
def send_sync(self, event, payload, sid=None):
|
|
FakePromptServer.sent.append((event, payload))
|
|
|
|
instance = Instance()
|
|
|
|
|
|
class FakeDownloadHistoryService:
|
|
async def has_been_downloaded(self, _model_type, _version_id):
|
|
return False
|
|
|
|
async def get_downloaded_version_ids(self, _model_type, _model_id):
|
|
return []
|
|
|
|
async def get_downloaded_version_ids_bulk(self, _model_type, _model_ids):
|
|
return {}
|
|
|
|
async def mark_downloaded(self, *_args, **_kwargs):
|
|
return None
|
|
|
|
async def mark_as_deleted(self, *_args, **_kwargs):
|
|
return None
|
|
|
|
|
|
async def fake_download_history_service_factory():
|
|
return FakeDownloadHistoryService()
|
|
|
|
|
|
class TestSettingsHandlerSnapshots:
|
|
"""Snapshot tests for SettingsHandler responses."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_settings_response_format(self, snapshot: SnapshotAssertion):
|
|
"""Verify get_settings response format matches snapshot."""
|
|
settings_service = DummySettings({
|
|
"civitai_api_key": "test-key",
|
|
"language": "en",
|
|
"theme": "dark"
|
|
})
|
|
handler = SettingsHandler(
|
|
settings_service=settings_service,
|
|
metadata_provider_updater=noop_async,
|
|
downloader_factory=lambda: None,
|
|
)
|
|
|
|
response = await handler.get_settings(FakeRequest())
|
|
payload = json.loads(response.text)
|
|
|
|
assert payload == snapshot
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_update_settings_success_response(self, snapshot: SnapshotAssertion):
|
|
"""Verify successful update_settings response format."""
|
|
settings_service = DummySettings()
|
|
handler = SettingsHandler(
|
|
settings_service=settings_service,
|
|
metadata_provider_updater=noop_async,
|
|
downloader_factory=lambda: None,
|
|
)
|
|
|
|
request = FakeRequest(json_data={"language": "zh"})
|
|
response = await handler.update_settings(request)
|
|
payload = json.loads(response.text)
|
|
|
|
assert payload == snapshot
|
|
|
|
|
|
class TestNodeRegistryHandlerSnapshots:
|
|
"""Snapshot tests for NodeRegistryHandler responses."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_register_nodes_success_response(self, snapshot: SnapshotAssertion):
|
|
"""Verify successful register_nodes response format."""
|
|
node_registry = NodeRegistry()
|
|
handler = NodeRegistryHandler(
|
|
node_registry=node_registry,
|
|
prompt_server=FakePromptServer,
|
|
standalone_mode=False,
|
|
)
|
|
|
|
request = FakeRequest(
|
|
json_data={
|
|
"nodes": [
|
|
{
|
|
"node_id": 1,
|
|
"graph_id": "root",
|
|
"type": "Lora Loader (LoraManager)",
|
|
"title": "Test Loader",
|
|
}
|
|
],
|
|
"client_id": "test-client-1",
|
|
}
|
|
)
|
|
|
|
response = await handler.register_nodes(request)
|
|
payload = json.loads(response.text)
|
|
|
|
assert payload == snapshot
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_register_nodes_error_response(self, snapshot: SnapshotAssertion):
|
|
"""Verify error register_nodes response format."""
|
|
node_registry = NodeRegistry()
|
|
handler = NodeRegistryHandler(
|
|
node_registry=node_registry,
|
|
prompt_server=FakePromptServer,
|
|
standalone_mode=False,
|
|
)
|
|
|
|
request = FakeRequest(json_data={"nodes": [], "client_id": "test-client-1"})
|
|
response = await handler.register_nodes(request)
|
|
payload = json.loads(response.text)
|
|
|
|
assert payload == snapshot
|
|
|
|
|
|
class TestUtilityFunctionSnapshots:
|
|
"""Snapshot tests for utility function outputs."""
|
|
|
|
def test_sanitize_folder_name_various_inputs(self, snapshot: SnapshotAssertion):
|
|
"""Verify sanitize_folder_name produces expected outputs."""
|
|
test_inputs = [
|
|
"normal_folder",
|
|
"folder with spaces",
|
|
"folder/with/slashes",
|
|
'folder\\with\\backslashes',
|
|
'folder<with>brackets',
|
|
'folder"with"quotes',
|
|
'folder|with|pipes',
|
|
'folder?with?questions',
|
|
'folder*with*asterisks',
|
|
'',
|
|
' spaces ',
|
|
'folder.with.dots',
|
|
'___underscores___',
|
|
]
|
|
|
|
results = {input_name: sanitize_folder_name(input_name) for input_name in test_inputs}
|
|
assert results == snapshot
|
|
|
|
def test_calculate_recipe_fingerprint_various_inputs(self, snapshot: SnapshotAssertion):
|
|
"""Verify calculate_recipe_fingerprint produces expected outputs."""
|
|
test_cases = [
|
|
[],
|
|
[{"hash": "abc123", "strength": 1.0}],
|
|
[
|
|
{"hash": "abc123", "strength": 1.0},
|
|
{"hash": "def456", "strength": 0.75},
|
|
],
|
|
[
|
|
{"hash": "DEF456", "strength": 1.0},
|
|
{"hash": "ABC123", "strength": 0.5},
|
|
],
|
|
[{"hash": "abc123", "weight": 0.8}],
|
|
[{"modelVersionId": 12345, "strength": 1.0}],
|
|
[{"hash": "abc123", "exclude": True, "strength": 1.0}],
|
|
[{"hash": "", "strength": 1.0}],
|
|
[{"strength": 1.0}],
|
|
]
|
|
|
|
results = [calculate_recipe_fingerprint(loras) for loras in test_cases]
|
|
assert results == snapshot
|
|
|
|
|
|
class TestModelLibraryHandlerSnapshots:
|
|
"""Snapshot tests for ModelLibraryHandler responses."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_check_model_exists_empty_response(self, snapshot: SnapshotAssertion):
|
|
"""Verify check_model_exists with no versions response format."""
|
|
|
|
class EmptyVersionScanner:
|
|
async def check_model_version_exists(self, _version_id):
|
|
return False
|
|
|
|
async def get_model_versions_by_id(self, _model_id):
|
|
return []
|
|
|
|
async def scanner_factory():
|
|
return EmptyVersionScanner()
|
|
|
|
handler = ModelLibraryHandler(
|
|
ServiceRegistryAdapter(
|
|
get_lora_scanner=scanner_factory,
|
|
get_checkpoint_scanner=scanner_factory,
|
|
get_embedding_scanner=scanner_factory,
|
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
|
),
|
|
metadata_provider_factory=lambda: None,
|
|
)
|
|
|
|
response = await handler.check_model_exists(FakeRequest(query={"modelId": "1"}))
|
|
payload = json.loads(response.text)
|
|
|
|
assert payload == snapshot
|