diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 44de9054..1dda702f 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -788,6 +788,7 @@ NODE_EXTRACTORS = { "TensorRTLoader": TensorRTLoaderExtractor, # Conditioning "CLIPTextEncode": CLIPTextEncodeExtractor, + "CLIPTextEncodeAttentionBias": CLIPTextEncodeExtractor, # From https://github.com/silveroxides/ComfyUI_PromptAttention "PromptLM": CLIPTextEncodeExtractor, "CLIPTextEncodeFlux": CLIPTextEncodeFluxExtractor, # Add CLIPTextEncodeFlux "WAS_Text_to_Conditioning": CLIPTextEncodeExtractor, diff --git a/tests/metadata_collector/test_metadata_collector.py b/tests/metadata_collector/test_metadata_collector.py index edb2d65a..c62d56f7 100644 --- a/tests/metadata_collector/test_metadata_collector.py +++ b/tests/metadata_collector/test_metadata_collector.py @@ -98,6 +98,85 @@ def test_metadata_processor_extracts_generation_params(populated_registry, monke assert isinstance(value, str) +def test_attention_bias_clip_text_encode_prompts_are_collected(metadata_registry, monkeypatch): + import types + + prompt_graph = { + "encode_pos": { + "class_type": "CLIPTextEncodeAttentionBias", + "inputs": {"text": "A on a hill", "clip": ["clip", 0]}, + }, + "encode_neg": { + "class_type": "CLIPTextEncodeAttentionBias", + "inputs": {"text": "low quality", "clip": ["clip", 0]}, + }, + "sampler": { + "class_type": "KSampler", + "inputs": { + "seed": types.SimpleNamespace(seed=123), + "steps": 20, + "cfg": 7.0, + "sampler_name": "Euler", + "scheduler": "karras", + "denoise": 1.0, + "positive": ["encode_pos", 0], + "negative": ["encode_neg", 0], + "latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))}, + }, + }, + } + prompt = SimpleNamespace(original_prompt=prompt_graph) + + pos_conditioning = object() + neg_conditioning = object() + + monkeypatch.setattr(metadata_processor, "standalone_mode", False) + + metadata_registry.start_collection("prompt-attention") + metadata_registry.set_current_prompt(prompt) + + metadata_registry.record_node_execution( + "encode_pos", + "CLIPTextEncodeAttentionBias", + {"text": "A on a hill"}, + None, + ) + metadata_registry.update_node_execution( + "encode_pos", "CLIPTextEncodeAttentionBias", [(pos_conditioning,)] + ) + metadata_registry.record_node_execution( + "encode_neg", + "CLIPTextEncodeAttentionBias", + {"text": "low quality"}, + None, + ) + metadata_registry.update_node_execution( + "encode_neg", "CLIPTextEncodeAttentionBias", [(neg_conditioning,)] + ) + metadata_registry.record_node_execution( + "sampler", + "KSampler", + { + "seed": types.SimpleNamespace(seed=123), + "positive": pos_conditioning, + "negative": neg_conditioning, + "latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))}, + }, + None, + ) + + metadata = metadata_registry.get_metadata("prompt-attention") + sampler_data = metadata[SAMPLING]["sampler"] + prompt_results = MetadataProcessor.match_conditioning_to_prompts(metadata, "sampler") + + assert metadata[PROMPTS]["encode_pos"]["text"] == "A on a hill" + assert metadata[PROMPTS]["encode_neg"]["text"] == "low quality" + assert sampler_data["node_id"] == "sampler" + assert sampler_data["is_sampler"] is True + assert prompt_results["prompt"] == "A on a hill" + assert prompt_results["negative_prompt"] == "low quality" + + def test_metadata_registry_caches_and_rehydrates(populated_registry): registry = populated_registry["registry"] prompt = populated_registry["prompt"]