From cfec5447d306eaad3857f7a1d89bcc2cb58b82ba Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Sun, 5 Oct 2025 14:44:17 +0800 Subject: [PATCH] test(metadata): add collector coverage --- tests/metadata_collector/conftest.py | 165 ++++++++++++++++++ .../test_metadata_collector.py | 126 +++++++++++++ 2 files changed, 291 insertions(+) create mode 100644 tests/metadata_collector/conftest.py create mode 100644 tests/metadata_collector/test_metadata_collector.py diff --git a/tests/metadata_collector/conftest.py b/tests/metadata_collector/conftest.py new file mode 100644 index 00000000..64b86bb8 --- /dev/null +++ b/tests/metadata_collector/conftest.py @@ -0,0 +1,165 @@ +import types +from types import SimpleNamespace + +import pytest + +from py.metadata_collector.metadata_registry import MetadataRegistry + + +@pytest.fixture +def metadata_registry(): + """Provide a clean MetadataRegistry singleton for each test.""" + registry = MetadataRegistry() + registry.clear_metadata() + yield registry + registry.clear_metadata() + + +@pytest.fixture +def populated_registry(metadata_registry): + """Populate the registry with a simulated ComfyUI node graph.""" + import nodes + + # Ensure node mappings exist for extractor lookups + class TSC_EfficientLoader: # type: ignore[too-many-ancestors] + __name__ = "TSC_EfficientLoader" + + class SamplerCustomAdvanced: # type: ignore[too-many-ancestors] + __name__ = "SamplerCustomAdvanced" + + class BasicScheduler: # type: ignore[too-many-ancestors] + __name__ = "BasicScheduler" + + class KSamplerSelect: # type: ignore[too-many-ancestors] + __name__ = "KSamplerSelect" + + class CFGGuider: # type: ignore[too-many-ancestors] + __name__ = "CFGGuider" + + class CLIPTextEncode: # type: ignore[too-many-ancestors] + __name__ = "CLIPTextEncode" + + class VAEDecode: # type: ignore[too-many-ancestors] + __name__ = "VAEDecode" + + nodes.NODE_CLASS_MAPPINGS.update( + { + "TSC_EfficientLoader": TSC_EfficientLoader, + "SamplerCustomAdvanced": SamplerCustomAdvanced, + "BasicScheduler": BasicScheduler, + "KSamplerSelect": KSamplerSelect, + "CFGGuider": CFGGuider, + "CLIPTextEncode": CLIPTextEncode, + "VAEDecode": VAEDecode, + } + ) + + prompt_graph = { + "loader": {"class_type": "TSC_EfficientLoader", "inputs": {}}, + "encode_pos": {"class_type": "CLIPTextEncode", "inputs": {"text": "A castle on a hill"}}, + "encode_neg": {"class_type": "CLIPTextEncode", "inputs": {"text": "low quality"}}, + "cfg_guider": { + "class_type": "CFGGuider", + "inputs": { + "cfg": 7.5, + "positive": ["encode_pos", 0], + "negative": ["encode_neg", 0], + }, + }, + "scheduler": { + "class_type": "BasicScheduler", + "inputs": { + "steps": 20, + "scheduler": "karras", + }, + }, + "sampler_select": { + "class_type": "KSamplerSelect", + "inputs": {"sampler_name": "Euler"}, + }, + "sampler": { + "class_type": "SamplerCustomAdvanced", + "inputs": { + "sigmas": ["scheduler", 0], + "sampler": ["sampler_select", 0], + "guider": ["cfg_guider", 0], + "positive": ["cfg_guider", 0], + "negative": ["cfg_guider", 0], + }, + }, + "vae": { + "class_type": "VAEDecode", + "inputs": {"samples": ["sampler", 0]}, + }, + } + + prompt = SimpleNamespace(original_prompt=prompt_graph) + + pos_conditioning = object() + neg_conditioning = object() + latent_samples = types.SimpleNamespace(shape=(1, 4, 16, 16)) + + metadata_registry.start_collection("promptA") + metadata_registry.set_current_prompt(prompt) + + # Loader node populates checkpoint, loras, and prompt text metadata + loader_inputs = { + "ckpt_name": "model.safetensors", + "lora_stack": (("/loras/my-lora.safetensors", 0.6, 0.5),), + "positive": "A castle on a hill", + "negative": "low quality", + } + metadata_registry.record_node_execution("loader", "TSC_EfficientLoader", loader_inputs, None) + loader_outputs = [ + ( + None, + pos_conditioning, + neg_conditioning, + {"samples": latent_samples}, + None, + None, + {}, + ) + ] + metadata_registry.update_node_execution("loader", "TSC_EfficientLoader", loader_outputs) + + # Positive and negative prompt encoders + metadata_registry.record_node_execution("encode_pos", "CLIPTextEncode", {"text": "A castle on a hill"}, None) + metadata_registry.update_node_execution("encode_pos", "CLIPTextEncode", [(pos_conditioning,)]) + metadata_registry.record_node_execution("encode_neg", "CLIPTextEncode", {"text": "low quality"}, None) + metadata_registry.update_node_execution("encode_neg", "CLIPTextEncode", [(neg_conditioning,)]) + + # CFG guider and scheduler nodes + metadata_registry.record_node_execution("cfg_guider", "CFGGuider", {"cfg": 7.5}, None) + metadata_registry.record_node_execution( + "scheduler", + "BasicScheduler", + {"steps": 20, "scheduler": "karras"}, + None, + ) + metadata_registry.record_node_execution( + "sampler_select", "KSamplerSelect", {"sampler_name": "Euler"}, None + ) + + # Sampler execution populates sampling metadata and links conditioning + sampler_inputs = { + "noise": types.SimpleNamespace(seed=999), + "positive": pos_conditioning, + "negative": neg_conditioning, + "latent_image": {"samples": latent_samples}, + } + metadata_registry.record_node_execution("sampler", "SamplerCustomAdvanced", sampler_inputs, None) + + # VAEDecode outputs image data + metadata_registry.record_node_execution("vae", "VAEDecode", {}, None) + metadata_registry.update_node_execution("vae", "VAEDecode", ["image-data"]) + + metadata = metadata_registry.get_metadata("promptA") + + return { + "registry": metadata_registry, + "prompt": prompt, + "metadata": metadata, + "pos_conditioning": pos_conditioning, + "neg_conditioning": neg_conditioning, + } diff --git a/tests/metadata_collector/test_metadata_collector.py b/tests/metadata_collector/test_metadata_collector.py new file mode 100644 index 00000000..3663f2a9 --- /dev/null +++ b/tests/metadata_collector/test_metadata_collector.py @@ -0,0 +1,126 @@ +import sys +import types +from types import SimpleNamespace + +from py.metadata_collector import metadata_processor +from py.metadata_collector.metadata_hook import MetadataHook +from py.metadata_collector.metadata_processor import MetadataProcessor +from py.metadata_collector.metadata_registry import MetadataRegistry +from py.metadata_collector.constants import LORAS, MODELS, PROMPTS, SAMPLING, SIZE + + +def test_metadata_hook_installs_and_traces_execution(monkeypatch, metadata_registry): + """Ensure MetadataHook installs wrappers and records node execution.""" + fake_execution = types.SimpleNamespace() + def original_map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): + return {"outputs": "result"} + + def original_execute(*args, **kwargs): + return "executed" + + fake_execution._map_node_over_list = original_map_node_over_list + fake_execution.execute = original_execute + + monkeypatch.setitem(sys.modules, "execution", fake_execution) + + MetadataHook.install() + + assert fake_execution._map_node_over_list is not original_map_node_over_list + assert fake_execution.execute is not original_execute + + calls = [] + + def record_stub(self, node_id, class_type, inputs, outputs): + calls.append(("record", node_id, class_type, inputs)) + + def update_stub(self, node_id, class_type, outputs): + calls.append(("update", node_id, class_type, outputs)) + + monkeypatch.setattr(MetadataRegistry, "record_node_execution", record_stub) + monkeypatch.setattr(MetadataRegistry, "update_node_execution", update_stub) + + metadata_registry.start_collection("prompt-1") + metadata_registry.set_current_prompt(SimpleNamespace(original_prompt={})) + + class FakeNode: + FUNCTION = "run" + + node = FakeNode() + node.unique_id = "node-1" + + wrapped_map = fake_execution._map_node_over_list + result = wrapped_map(node, {"input": ["value"]}, node.FUNCTION) + + assert result == {"outputs": "result"} + assert ("record", "node-1", "FakeNode", {"input": ["value"]}) in calls + assert any(call[0] == "update" for call in calls) + + metadata_registry.clear_metadata() + + prompt = SimpleNamespace(original_prompt={}) + execute_wrapper = fake_execution.execute + execute_wrapper("server", prompt, {}, None, None, None, "prompt-2") + + registry = MetadataRegistry() + assert registry.current_prompt_id == "prompt-2" + assert registry.get_metadata("prompt-2")["current_prompt"] is prompt + + +def test_metadata_processor_extracts_generation_params(populated_registry, monkeypatch): + metadata = populated_registry["metadata"] + prompt = populated_registry["prompt"] + + monkeypatch.setattr(metadata_processor, "standalone_mode", False) + + sampler_id, sampler_data = MetadataProcessor.find_primary_sampler(metadata, downstream_id="vae") + assert sampler_id == "sampler" + assert sampler_data["parameters"]["seed"] == 999 + + positive_node = MetadataProcessor.trace_node_input(prompt, "cfg_guider", "positive", target_class="CLIPTextEncode") + assert positive_node == "encode_pos" + + params = MetadataProcessor.extract_generation_params(metadata) + assert params["prompt"] == "A castle on a hill" + assert params["negative_prompt"] == "low quality" + assert params["seed"] == 999 + assert params["steps"] == 20 + assert params["cfg_scale"] == 7.5 + assert params["sampler"] == "Euler" + assert params["scheduler"] == "karras" + assert params["checkpoint"] == "model.safetensors" + assert params["loras"] == "" + assert params["size"] == "128x128" + + params_dict = MetadataProcessor.to_dict(metadata) + assert params_dict["prompt"] == "A castle on a hill" + for value in params_dict.values(): + if value is not None: + assert isinstance(value, str) + + +def test_metadata_registry_caches_and_rehydrates(populated_registry): + registry = populated_registry["registry"] + prompt = populated_registry["prompt"] + + assert registry.node_cache # Cache should contain entries from the first prompt + + new_prompt = SimpleNamespace(original_prompt=prompt.original_prompt) + registry.start_collection("promptB") + registry.set_current_prompt(new_prompt) + + cache_entry = registry.node_cache.get("sampler:SamplerCustomAdvanced") + assert cache_entry is not None + + metadata = registry.get_metadata("promptB") + + assert metadata[MODELS]["loader"]["name"] == "model.safetensors" + assert metadata[PROMPTS]["loader"]["positive_text"] == "A castle on a hill" + assert metadata[SAMPLING]["sampler"]["parameters"]["seed"] == 999 + assert metadata[LORAS]["loader"]["lora_list"][0]["name"] == "my-lora" + assert metadata[SIZE]["sampler"]["width"] == 128 + + image = registry.get_first_decoded_image("promptB") + assert image == "image-data" + + registry.clear_metadata("promptA") + assert "promptA" not in registry.prompt_metadata