mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
Merge pull request #527 from willmiao/codex/add-unit-tests-for-metadata-components
Add metadata collector unit tests and fixtures
This commit is contained in:
165
tests/metadata_collector/conftest.py
Normal file
165
tests/metadata_collector/conftest.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
import types
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.metadata_collector.metadata_registry import MetadataRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def metadata_registry():
|
||||||
|
"""Provide a clean MetadataRegistry singleton for each test."""
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
registry.clear_metadata()
|
||||||
|
yield registry
|
||||||
|
registry.clear_metadata()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def populated_registry(metadata_registry):
|
||||||
|
"""Populate the registry with a simulated ComfyUI node graph."""
|
||||||
|
import nodes
|
||||||
|
|
||||||
|
# Ensure node mappings exist for extractor lookups
|
||||||
|
class TSC_EfficientLoader: # type: ignore[too-many-ancestors]
|
||||||
|
__name__ = "TSC_EfficientLoader"
|
||||||
|
|
||||||
|
class SamplerCustomAdvanced: # type: ignore[too-many-ancestors]
|
||||||
|
__name__ = "SamplerCustomAdvanced"
|
||||||
|
|
||||||
|
class BasicScheduler: # type: ignore[too-many-ancestors]
|
||||||
|
__name__ = "BasicScheduler"
|
||||||
|
|
||||||
|
class KSamplerSelect: # type: ignore[too-many-ancestors]
|
||||||
|
__name__ = "KSamplerSelect"
|
||||||
|
|
||||||
|
class CFGGuider: # type: ignore[too-many-ancestors]
|
||||||
|
__name__ = "CFGGuider"
|
||||||
|
|
||||||
|
class CLIPTextEncode: # type: ignore[too-many-ancestors]
|
||||||
|
__name__ = "CLIPTextEncode"
|
||||||
|
|
||||||
|
class VAEDecode: # type: ignore[too-many-ancestors]
|
||||||
|
__name__ = "VAEDecode"
|
||||||
|
|
||||||
|
nodes.NODE_CLASS_MAPPINGS.update(
|
||||||
|
{
|
||||||
|
"TSC_EfficientLoader": TSC_EfficientLoader,
|
||||||
|
"SamplerCustomAdvanced": SamplerCustomAdvanced,
|
||||||
|
"BasicScheduler": BasicScheduler,
|
||||||
|
"KSamplerSelect": KSamplerSelect,
|
||||||
|
"CFGGuider": CFGGuider,
|
||||||
|
"CLIPTextEncode": CLIPTextEncode,
|
||||||
|
"VAEDecode": VAEDecode,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_graph = {
|
||||||
|
"loader": {"class_type": "TSC_EfficientLoader", "inputs": {}},
|
||||||
|
"encode_pos": {"class_type": "CLIPTextEncode", "inputs": {"text": "A castle on a hill"}},
|
||||||
|
"encode_neg": {"class_type": "CLIPTextEncode", "inputs": {"text": "low quality"}},
|
||||||
|
"cfg_guider": {
|
||||||
|
"class_type": "CFGGuider",
|
||||||
|
"inputs": {
|
||||||
|
"cfg": 7.5,
|
||||||
|
"positive": ["encode_pos", 0],
|
||||||
|
"negative": ["encode_neg", 0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"scheduler": {
|
||||||
|
"class_type": "BasicScheduler",
|
||||||
|
"inputs": {
|
||||||
|
"steps": 20,
|
||||||
|
"scheduler": "karras",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"sampler_select": {
|
||||||
|
"class_type": "KSamplerSelect",
|
||||||
|
"inputs": {"sampler_name": "Euler"},
|
||||||
|
},
|
||||||
|
"sampler": {
|
||||||
|
"class_type": "SamplerCustomAdvanced",
|
||||||
|
"inputs": {
|
||||||
|
"sigmas": ["scheduler", 0],
|
||||||
|
"sampler": ["sampler_select", 0],
|
||||||
|
"guider": ["cfg_guider", 0],
|
||||||
|
"positive": ["cfg_guider", 0],
|
||||||
|
"negative": ["cfg_guider", 0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"vae": {
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"inputs": {"samples": ["sampler", 0]},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = SimpleNamespace(original_prompt=prompt_graph)
|
||||||
|
|
||||||
|
pos_conditioning = object()
|
||||||
|
neg_conditioning = object()
|
||||||
|
latent_samples = types.SimpleNamespace(shape=(1, 4, 16, 16))
|
||||||
|
|
||||||
|
metadata_registry.start_collection("promptA")
|
||||||
|
metadata_registry.set_current_prompt(prompt)
|
||||||
|
|
||||||
|
# Loader node populates checkpoint, loras, and prompt text metadata
|
||||||
|
loader_inputs = {
|
||||||
|
"ckpt_name": "model.safetensors",
|
||||||
|
"lora_stack": (("/loras/my-lora.safetensors", 0.6, 0.5),),
|
||||||
|
"positive": "A castle on a hill",
|
||||||
|
"negative": "low quality",
|
||||||
|
}
|
||||||
|
metadata_registry.record_node_execution("loader", "TSC_EfficientLoader", loader_inputs, None)
|
||||||
|
loader_outputs = [
|
||||||
|
(
|
||||||
|
None,
|
||||||
|
pos_conditioning,
|
||||||
|
neg_conditioning,
|
||||||
|
{"samples": latent_samples},
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
metadata_registry.update_node_execution("loader", "TSC_EfficientLoader", loader_outputs)
|
||||||
|
|
||||||
|
# Positive and negative prompt encoders
|
||||||
|
metadata_registry.record_node_execution("encode_pos", "CLIPTextEncode", {"text": "A castle on a hill"}, None)
|
||||||
|
metadata_registry.update_node_execution("encode_pos", "CLIPTextEncode", [(pos_conditioning,)])
|
||||||
|
metadata_registry.record_node_execution("encode_neg", "CLIPTextEncode", {"text": "low quality"}, None)
|
||||||
|
metadata_registry.update_node_execution("encode_neg", "CLIPTextEncode", [(neg_conditioning,)])
|
||||||
|
|
||||||
|
# CFG guider and scheduler nodes
|
||||||
|
metadata_registry.record_node_execution("cfg_guider", "CFGGuider", {"cfg": 7.5}, None)
|
||||||
|
metadata_registry.record_node_execution(
|
||||||
|
"scheduler",
|
||||||
|
"BasicScheduler",
|
||||||
|
{"steps": 20, "scheduler": "karras"},
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
metadata_registry.record_node_execution(
|
||||||
|
"sampler_select", "KSamplerSelect", {"sampler_name": "Euler"}, None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sampler execution populates sampling metadata and links conditioning
|
||||||
|
sampler_inputs = {
|
||||||
|
"noise": types.SimpleNamespace(seed=999),
|
||||||
|
"positive": pos_conditioning,
|
||||||
|
"negative": neg_conditioning,
|
||||||
|
"latent_image": {"samples": latent_samples},
|
||||||
|
}
|
||||||
|
metadata_registry.record_node_execution("sampler", "SamplerCustomAdvanced", sampler_inputs, None)
|
||||||
|
|
||||||
|
# VAEDecode outputs image data
|
||||||
|
metadata_registry.record_node_execution("vae", "VAEDecode", {}, None)
|
||||||
|
metadata_registry.update_node_execution("vae", "VAEDecode", ["image-data"])
|
||||||
|
|
||||||
|
metadata = metadata_registry.get_metadata("promptA")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"registry": metadata_registry,
|
||||||
|
"prompt": prompt,
|
||||||
|
"metadata": metadata,
|
||||||
|
"pos_conditioning": pos_conditioning,
|
||||||
|
"neg_conditioning": neg_conditioning,
|
||||||
|
}
|
||||||
126
tests/metadata_collector/test_metadata_collector.py
Normal file
126
tests/metadata_collector/test_metadata_collector.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
import sys
|
||||||
|
import types
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from py.metadata_collector import metadata_processor
|
||||||
|
from py.metadata_collector.metadata_hook import MetadataHook
|
||||||
|
from py.metadata_collector.metadata_processor import MetadataProcessor
|
||||||
|
from py.metadata_collector.metadata_registry import MetadataRegistry
|
||||||
|
from py.metadata_collector.constants import LORAS, MODELS, PROMPTS, SAMPLING, SIZE
|
||||||
|
|
||||||
|
|
||||||
|
def test_metadata_hook_installs_and_traces_execution(monkeypatch, metadata_registry):
|
||||||
|
"""Ensure MetadataHook installs wrappers and records node execution."""
|
||||||
|
fake_execution = types.SimpleNamespace()
|
||||||
|
def original_map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||||
|
return {"outputs": "result"}
|
||||||
|
|
||||||
|
def original_execute(*args, **kwargs):
|
||||||
|
return "executed"
|
||||||
|
|
||||||
|
fake_execution._map_node_over_list = original_map_node_over_list
|
||||||
|
fake_execution.execute = original_execute
|
||||||
|
|
||||||
|
monkeypatch.setitem(sys.modules, "execution", fake_execution)
|
||||||
|
|
||||||
|
MetadataHook.install()
|
||||||
|
|
||||||
|
assert fake_execution._map_node_over_list is not original_map_node_over_list
|
||||||
|
assert fake_execution.execute is not original_execute
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def record_stub(self, node_id, class_type, inputs, outputs):
|
||||||
|
calls.append(("record", node_id, class_type, inputs))
|
||||||
|
|
||||||
|
def update_stub(self, node_id, class_type, outputs):
|
||||||
|
calls.append(("update", node_id, class_type, outputs))
|
||||||
|
|
||||||
|
monkeypatch.setattr(MetadataRegistry, "record_node_execution", record_stub)
|
||||||
|
monkeypatch.setattr(MetadataRegistry, "update_node_execution", update_stub)
|
||||||
|
|
||||||
|
metadata_registry.start_collection("prompt-1")
|
||||||
|
metadata_registry.set_current_prompt(SimpleNamespace(original_prompt={}))
|
||||||
|
|
||||||
|
class FakeNode:
|
||||||
|
FUNCTION = "run"
|
||||||
|
|
||||||
|
node = FakeNode()
|
||||||
|
node.unique_id = "node-1"
|
||||||
|
|
||||||
|
wrapped_map = fake_execution._map_node_over_list
|
||||||
|
result = wrapped_map(node, {"input": ["value"]}, node.FUNCTION)
|
||||||
|
|
||||||
|
assert result == {"outputs": "result"}
|
||||||
|
assert ("record", "node-1", "FakeNode", {"input": ["value"]}) in calls
|
||||||
|
assert any(call[0] == "update" for call in calls)
|
||||||
|
|
||||||
|
metadata_registry.clear_metadata()
|
||||||
|
|
||||||
|
prompt = SimpleNamespace(original_prompt={})
|
||||||
|
execute_wrapper = fake_execution.execute
|
||||||
|
execute_wrapper("server", prompt, {}, None, None, None, "prompt-2")
|
||||||
|
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
assert registry.current_prompt_id == "prompt-2"
|
||||||
|
assert registry.get_metadata("prompt-2")["current_prompt"] is prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_metadata_processor_extracts_generation_params(populated_registry, monkeypatch):
|
||||||
|
metadata = populated_registry["metadata"]
|
||||||
|
prompt = populated_registry["prompt"]
|
||||||
|
|
||||||
|
monkeypatch.setattr(metadata_processor, "standalone_mode", False)
|
||||||
|
|
||||||
|
sampler_id, sampler_data = MetadataProcessor.find_primary_sampler(metadata, downstream_id="vae")
|
||||||
|
assert sampler_id == "sampler"
|
||||||
|
assert sampler_data["parameters"]["seed"] == 999
|
||||||
|
|
||||||
|
positive_node = MetadataProcessor.trace_node_input(prompt, "cfg_guider", "positive", target_class="CLIPTextEncode")
|
||||||
|
assert positive_node == "encode_pos"
|
||||||
|
|
||||||
|
params = MetadataProcessor.extract_generation_params(metadata)
|
||||||
|
assert params["prompt"] == "A castle on a hill"
|
||||||
|
assert params["negative_prompt"] == "low quality"
|
||||||
|
assert params["seed"] == 999
|
||||||
|
assert params["steps"] == 20
|
||||||
|
assert params["cfg_scale"] == 7.5
|
||||||
|
assert params["sampler"] == "Euler"
|
||||||
|
assert params["scheduler"] == "karras"
|
||||||
|
assert params["checkpoint"] == "model.safetensors"
|
||||||
|
assert params["loras"] == "<lora:my-lora:0.6>"
|
||||||
|
assert params["size"] == "128x128"
|
||||||
|
|
||||||
|
params_dict = MetadataProcessor.to_dict(metadata)
|
||||||
|
assert params_dict["prompt"] == "A castle on a hill"
|
||||||
|
for value in params_dict.values():
|
||||||
|
if value is not None:
|
||||||
|
assert isinstance(value, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_metadata_registry_caches_and_rehydrates(populated_registry):
|
||||||
|
registry = populated_registry["registry"]
|
||||||
|
prompt = populated_registry["prompt"]
|
||||||
|
|
||||||
|
assert registry.node_cache # Cache should contain entries from the first prompt
|
||||||
|
|
||||||
|
new_prompt = SimpleNamespace(original_prompt=prompt.original_prompt)
|
||||||
|
registry.start_collection("promptB")
|
||||||
|
registry.set_current_prompt(new_prompt)
|
||||||
|
|
||||||
|
cache_entry = registry.node_cache.get("sampler:SamplerCustomAdvanced")
|
||||||
|
assert cache_entry is not None
|
||||||
|
|
||||||
|
metadata = registry.get_metadata("promptB")
|
||||||
|
|
||||||
|
assert metadata[MODELS]["loader"]["name"] == "model.safetensors"
|
||||||
|
assert metadata[PROMPTS]["loader"]["positive_text"] == "A castle on a hill"
|
||||||
|
assert metadata[SAMPLING]["sampler"]["parameters"]["seed"] == 999
|
||||||
|
assert metadata[LORAS]["loader"]["lora_list"][0]["name"] == "my-lora"
|
||||||
|
assert metadata[SIZE]["sampler"]["width"] == 128
|
||||||
|
|
||||||
|
image = registry.get_first_decoded_image("promptB")
|
||||||
|
assert image == "image-data"
|
||||||
|
|
||||||
|
registry.clear_metadata("promptA")
|
||||||
|
assert "promptA" not in registry.prompt_metadata
|
||||||
Reference in New Issue
Block a user