diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index c2ba585c..d7dcab65 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -352,50 +352,101 @@ class MetadataProcessor: # Check if we have stored conditioning objects for this sampler if sampler_id in metadata.get(PROMPTS, {}) and ( - "pos_conditioning" in metadata[PROMPTS][sampler_id] or - "neg_conditioning" in metadata[PROMPTS][sampler_id]): - + "pos_conditioning" in metadata[PROMPTS][sampler_id] or + "neg_conditioning" in metadata[PROMPTS][sampler_id] + ): pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning") neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning") - - # Helper function to recursively find prompt text for a conditioning object - def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True): + + def extend_unique(target, values): + for value in values: + if value and value not in target: + target.append(value) + + # Helper function to recursively find prompt texts for a conditioning object. + # Transform nodes can map one output conditioning to multiple source conditionings. + def find_prompt_texts_for_conditioning( + conditioning_obj, is_positive=True, visited=None + ): if conditioning_obj is None: - return "" - + return [] + + if visited is None: + visited = set() + + conditioning_id = id(conditioning_obj) + if conditioning_id in visited: + return [] + visited.add(conditioning_id) + + prompt_texts = [] + # Try to match conditioning objects with those stored by extractors for prompt_node_id, prompt_data in metadata[PROMPTS].items(): - # For nodes with single conditioning output - if "conditioning" in prompt_data: - if id(prompt_data["conditioning"]) == id(conditioning_obj): - return prompt_data.get("text", "") - - # For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader) - if is_positive and "positive_encoded" in prompt_data: - if id(prompt_data["positive_encoded"]) == id(conditioning_obj): - if "positive_text" in prompt_data: - return prompt_data["positive_text"] - else: - orig_conditioning = prompt_data.get("orig_pos_cond", None) - if orig_conditioning is not None: - # Recursively find the prompt text for the original conditioning - return find_prompt_text_for_conditioning(orig_conditioning, is_positive=True) - - if not is_positive and "negative_encoded" in prompt_data: - if id(prompt_data["negative_encoded"]) == id(conditioning_obj): - if "negative_text" in prompt_data: - return prompt_data["negative_text"] - else: - orig_conditioning = prompt_data.get("orig_neg_cond", None) - if orig_conditioning is not None: - # Recursively find the prompt text for the original conditioning - return find_prompt_text_for_conditioning(orig_conditioning, is_positive=False) - - return "" - + if not isinstance(prompt_data, dict): + continue + + # For CLIP text nodes with a single conditioning output. + if id(prompt_data.get("conditioning")) == conditioning_id: + text = prompt_data.get("text", "") + if text: + extend_unique(prompt_texts, [text]) + + # Generic provenance for passthrough/transform/combine nodes. + for source in prompt_data.get("conditioning_sources", []): + if id(source.get("output")) != conditioning_id: + continue + for input_conditioning in source.get("inputs", []): + extend_unique( + prompt_texts, + find_prompt_texts_for_conditioning( + input_conditioning, is_positive, visited + ), + ) + + # For nodes with separate pos_conditioning and neg_conditioning outputs + # like TSC_EfficientLoader and existing ControlNet-style metadata. + if ( + is_positive + and id(prompt_data.get("positive_encoded")) == conditioning_id + ): + if prompt_data.get("positive_text"): + extend_unique(prompt_texts, [prompt_data["positive_text"]]) + else: + extend_unique( + prompt_texts, + find_prompt_texts_for_conditioning( + prompt_data.get("orig_pos_cond"), + is_positive=True, + visited=visited, + ), + ) + + if ( + not is_positive + and id(prompt_data.get("negative_encoded")) == conditioning_id + ): + if prompt_data.get("negative_text"): + extend_unique(prompt_texts, [prompt_data["negative_text"]]) + else: + extend_unique( + prompt_texts, + find_prompt_texts_for_conditioning( + prompt_data.get("orig_neg_cond"), + is_positive=False, + visited=visited, + ), + ) + + return prompt_texts + # Find prompt texts using the helper function - result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True) - result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False) + result["prompt"] = ", ".join( + find_prompt_texts_for_conditioning(pos_conditioning, is_positive=True) + ) + result["negative_prompt"] = ", ".join( + find_prompt_texts_for_conditioning(neg_conditioning, is_positive=False) + ) return result diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 7d6096a7..494306e0 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -163,6 +163,193 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor): conditioning = outputs[0][0] metadata[PROMPTS][node_id]["conditioning"] = conditioning + +def _ensure_prompt_metadata(metadata, node_id): + if node_id not in metadata[PROMPTS]: + metadata[PROMPTS][node_id] = {"node_id": node_id} + return metadata[PROMPTS][node_id] + + +def _first_output_tuple(outputs): + if not outputs or not isinstance(outputs, list) or len(outputs) == 0: + return None + first_output = outputs[0] + if isinstance(first_output, tuple): + return first_output + return None + + +def _record_conditioning_source( + metadata, node_id, output_conditioning, input_conditionings +): + if output_conditioning is None: + return + + sources = [ + conditioning for conditioning in input_conditionings if conditioning is not None + ] + if not sources: + return + + prompt_metadata = _ensure_prompt_metadata(metadata, node_id) + prompt_metadata.setdefault("conditioning_sources", []).append( + { + "output": output_conditioning, + "inputs": sources, + } + ) + + +def _get_variable_name(inputs): + for key in ("key", "name", "variable_name", "tag", "text"): + value = inputs.get(key) + if isinstance(value, str) and value: + return value + return None + + +def _get_node_variable_name(metadata, node_id, inputs): + variable_name = _get_variable_name(inputs) + if variable_name: + return variable_name + + prompt = metadata.get("current_prompt") + original_prompt = getattr(prompt, "original_prompt", None) + if not original_prompt or node_id not in original_prompt: + return None + + node_data = original_prompt[node_id] + variable_name = _get_variable_name(node_data.get("inputs", {})) + if variable_name: + return variable_name + + widgets_values = node_data.get("widgets_values", []) + if widgets_values and isinstance(widgets_values[0], str): + return widgets_values[0] + + return None + + +class ControlNetApplyAdvancedExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + prompt_metadata = _ensure_prompt_metadata(metadata, node_id) + if inputs.get("positive") is not None: + prompt_metadata["orig_pos_cond"] = inputs["positive"] + if inputs.get("negative") is not None: + prompt_metadata["orig_neg_cond"] = inputs["negative"] + + @staticmethod + def update(node_id, outputs, metadata): + output_tuple = _first_output_tuple(outputs) + if not output_tuple: + return + + prompt_metadata = _ensure_prompt_metadata(metadata, node_id) + positive_input = prompt_metadata.get("orig_pos_cond") + negative_input = prompt_metadata.get("orig_neg_cond") + + if len(output_tuple) >= 1: + prompt_metadata["positive_encoded"] = output_tuple[0] + _record_conditioning_source( + metadata, node_id, output_tuple[0], [positive_input] + ) + if len(output_tuple) >= 2: + prompt_metadata["negative_encoded"] = output_tuple[1] + _record_conditioning_source( + metadata, node_id, output_tuple[1], [negative_input] + ) + + +class ConditioningCombineExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + input_conditionings = [] + for input_name in inputs: + if ( + input_name.startswith("conditioning") + and inputs[input_name] is not None + ): + input_conditionings.append(inputs[input_name]) + + if input_conditionings: + prompt_metadata = _ensure_prompt_metadata(metadata, node_id) + prompt_metadata["orig_conditionings"] = input_conditionings + + @staticmethod + def update(node_id, outputs, metadata): + output_tuple = _first_output_tuple(outputs) + if not output_tuple or len(output_tuple) < 1: + return + + prompt_metadata = _ensure_prompt_metadata(metadata, node_id) + output_conditioning = output_tuple[0] + prompt_metadata["conditioning"] = output_conditioning + _record_conditioning_source( + metadata, + node_id, + output_conditioning, + prompt_metadata.get("orig_conditionings", []), + ) + + +class SetNodeExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + variable_name = _get_node_variable_name(metadata, node_id, inputs) + conditioning = inputs.get("CONDITIONING") + if conditioning is None: + conditioning = inputs.get("conditioning") + if conditioning is None: + return + + prompt_metadata = _ensure_prompt_metadata(metadata, node_id) + prompt_metadata["conditioning"] = conditioning + if variable_name: + prompt_metadata["variable_name"] = variable_name + metadata[PROMPTS].setdefault("__conditioning_variables__", {})[ + variable_name + ] = conditioning + + +class GetNodeExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + variable_name = _get_node_variable_name(metadata, node_id, inputs or {}) + if variable_name: + prompt_metadata = _ensure_prompt_metadata(metadata, node_id) + prompt_metadata["variable_name"] = variable_name + + @staticmethod + def update(node_id, outputs, metadata): + output_tuple = _first_output_tuple(outputs) + if not output_tuple or len(output_tuple) < 1: + return + + prompt_metadata = _ensure_prompt_metadata(metadata, node_id) + output_conditioning = output_tuple[0] + prompt_metadata["conditioning"] = output_conditioning + + variable_name = prompt_metadata.get("variable_name") + if not variable_name: + return + + input_conditioning = metadata[PROMPTS].get("__conditioning_variables__", {}).get( + variable_name + ) + _record_conditioning_source( + metadata, node_id, output_conditioning, [input_conditioning] + ) + # Base Sampler Extractor to reduce code redundancy class BaseSamplerExtractor(NodeMetadataExtractor): """Base extractor for sampler nodes with common functionality""" @@ -798,6 +985,10 @@ NODE_EXTRACTORS = { "smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes "CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack "PCTextEncode": CLIPTextEncodeExtractor, # From https://github.com/asagi4/comfyui-prompt-control + "ControlNetApplyAdvanced": ControlNetApplyAdvancedExtractor, + "ConditioningCombine": ConditioningCombineExtractor, + "SetNode": SetNodeExtractor, + "GetNode": GetNodeExtractor, # Latent "EmptyLatentImage": ImageSizeExtractor, # Flux diff --git a/tests/metadata_collector/test_metadata_collector.py b/tests/metadata_collector/test_metadata_collector.py index dcb6b770..20033119 100644 --- a/tests/metadata_collector/test_metadata_collector.py +++ b/tests/metadata_collector/test_metadata_collector.py @@ -177,6 +177,220 @@ def test_attention_bias_clip_text_encode_prompts_are_collected(metadata_registry assert prompt_results["negative_prompt"] == "low quality" +def test_conditioning_provenance_recovers_combined_controlnet_prompts( + metadata_registry, monkeypatch +): + import types + + prompt_graph = { + "encode_wd": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "wd14 tags", "clip": ["clip", 0]}, + }, + "encode_manual": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "manual tags", "clip": ["clip", 0]}, + }, + "encode_neg": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "low quality", "clip": ["clip", 0]}, + }, + "combine": { + "class_type": "ConditioningCombine", + "inputs": { + "conditioning_1": ["encode_wd", 0], + "conditioning_2": ["encode_manual", 0], + }, + }, + "controlnet": { + "class_type": "ControlNetApplyAdvanced", + "inputs": { + "positive": ["combine", 0], + "negative": ["encode_neg", 0], + }, + }, + "sampler": { + "class_type": "KSampler", + "inputs": { + "seed": 123, + "steps": 20, + "cfg": 7.0, + "sampler_name": "Euler", + "scheduler": "karras", + "denoise": 1.0, + "positive": ["controlnet", 0], + "negative": ["controlnet", 1], + "latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))}, + }, + }, + } + prompt = SimpleNamespace(original_prompt=prompt_graph) + + wd_conditioning = object() + manual_conditioning = object() + negative_conditioning = object() + combined_conditioning = object() + controlnet_positive = object() + controlnet_negative = object() + + monkeypatch.setattr(metadata_processor, "standalone_mode", False) + + metadata_registry.start_collection("prompt-provenance") + metadata_registry.set_current_prompt(prompt) + + metadata_registry.record_node_execution( + "encode_wd", "CLIPTextEncode", {"text": "wd14 tags"}, None + ) + metadata_registry.update_node_execution( + "encode_wd", "CLIPTextEncode", [(wd_conditioning,)] + ) + metadata_registry.record_node_execution( + "encode_manual", "CLIPTextEncode", {"text": "manual tags"}, None + ) + metadata_registry.update_node_execution( + "encode_manual", "CLIPTextEncode", [(manual_conditioning,)] + ) + metadata_registry.record_node_execution( + "encode_neg", "CLIPTextEncode", {"text": "low quality"}, None + ) + metadata_registry.update_node_execution( + "encode_neg", "CLIPTextEncode", [(negative_conditioning,)] + ) + metadata_registry.record_node_execution( + "combine", + "ConditioningCombine", + { + "conditioning_1": wd_conditioning, + "conditioning_2": manual_conditioning, + }, + None, + ) + metadata_registry.update_node_execution( + "combine", "ConditioningCombine", [(combined_conditioning,)] + ) + metadata_registry.record_node_execution( + "controlnet", + "ControlNetApplyAdvanced", + { + "positive": combined_conditioning, + "negative": negative_conditioning, + }, + None, + ) + metadata_registry.update_node_execution( + "controlnet", + "ControlNetApplyAdvanced", + [(controlnet_positive, controlnet_negative)], + ) + metadata_registry.record_node_execution( + "sampler", + "KSampler", + { + "seed": 123, + "steps": 20, + "cfg": 7.0, + "sampler_name": "Euler", + "scheduler": "karras", + "denoise": 1.0, + "positive": controlnet_positive, + "negative": controlnet_negative, + "latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))}, + }, + None, + ) + + metadata = metadata_registry.get_metadata("prompt-provenance") + params = MetadataProcessor.extract_generation_params(metadata) + + assert params["prompt"] == "wd14 tags, manual tags" + assert params["negative_prompt"] == "low quality" + + +def test_conditioning_provenance_recovers_kj_set_get_prompts( + metadata_registry, monkeypatch +): + import types + + prompt_graph = { + "encode_pos": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "from set node", "clip": ["clip", 0]}, + }, + "set_positive": { + "class_type": "SetNode", + "inputs": {"CONDITIONING": ["encode_pos", 0], "name": "positive"}, + }, + "get_positive": { + "class_type": "GetNode", + "inputs": {"name": "positive"}, + }, + "sampler": { + "class_type": "KSampler", + "inputs": { + "seed": 123, + "steps": 20, + "cfg": 7.0, + "sampler_name": "Euler", + "scheduler": "karras", + "denoise": 1.0, + "positive": ["get_positive", 0], + "negative": ["encode_pos", 0], + "latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))}, + }, + }, + } + prompt = SimpleNamespace(original_prompt=prompt_graph) + + original_conditioning = object() + get_conditioning = object() + + monkeypatch.setattr(metadata_processor, "standalone_mode", False) + + metadata_registry.start_collection("prompt-kj-get") + metadata_registry.set_current_prompt(prompt) + + metadata_registry.record_node_execution( + "encode_pos", "CLIPTextEncode", {"text": "from set node"}, None + ) + metadata_registry.update_node_execution( + "encode_pos", "CLIPTextEncode", [(original_conditioning,)] + ) + metadata_registry.record_node_execution( + "set_positive", + "SetNode", + {"CONDITIONING": original_conditioning, "name": "positive"}, + None, + ) + metadata_registry.record_node_execution( + "get_positive", "GetNode", {"name": "positive"}, None + ) + metadata_registry.update_node_execution( + "get_positive", "GetNode", [(get_conditioning,)] + ) + metadata_registry.record_node_execution( + "sampler", + "KSampler", + { + "seed": 123, + "steps": 20, + "cfg": 7.0, + "sampler_name": "Euler", + "scheduler": "karras", + "denoise": 1.0, + "positive": get_conditioning, + "negative": original_conditioning, + "latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))}, + }, + None, + ) + + metadata = metadata_registry.get_metadata("prompt-kj-get") + params = MetadataProcessor.extract_generation_params(metadata) + + assert params["prompt"] == "from set node" + assert params["negative_prompt"] == "from set node" + + def test_sampler_custom_advanced_recovers_prompt_text_through_guidance_nodes(metadata_registry, monkeypatch): import types