fix(metadata): trace conditioning provenance for prompts

This commit is contained in:
Will Miao
2026-04-23 14:41:54 +08:00
parent 2eef629821
commit ebdbb36271
3 changed files with 494 additions and 38 deletions

View File

@@ -353,49 +353,100 @@ class MetadataProcessor:
# Check if we have stored conditioning objects for this sampler # Check if we have stored conditioning objects for this sampler
if sampler_id in metadata.get(PROMPTS, {}) and ( if sampler_id in metadata.get(PROMPTS, {}) and (
"pos_conditioning" in metadata[PROMPTS][sampler_id] or "pos_conditioning" in metadata[PROMPTS][sampler_id] or
"neg_conditioning" in metadata[PROMPTS][sampler_id]): "neg_conditioning" in metadata[PROMPTS][sampler_id]
):
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning") pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning") neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
# Helper function to recursively find prompt text for a conditioning object def extend_unique(target, values):
def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True): for value in values:
if value and value not in target:
target.append(value)
# Helper function to recursively find prompt texts for a conditioning object.
# Transform nodes can map one output conditioning to multiple source conditionings.
def find_prompt_texts_for_conditioning(
conditioning_obj, is_positive=True, visited=None
):
if conditioning_obj is None: if conditioning_obj is None:
return "" return []
if visited is None:
visited = set()
conditioning_id = id(conditioning_obj)
if conditioning_id in visited:
return []
visited.add(conditioning_id)
prompt_texts = []
# Try to match conditioning objects with those stored by extractors # Try to match conditioning objects with those stored by extractors
for prompt_node_id, prompt_data in metadata[PROMPTS].items(): for prompt_node_id, prompt_data in metadata[PROMPTS].items():
# For nodes with single conditioning output if not isinstance(prompt_data, dict):
if "conditioning" in prompt_data: continue
if id(prompt_data["conditioning"]) == id(conditioning_obj):
return prompt_data.get("text", "")
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader) # For CLIP text nodes with a single conditioning output.
if is_positive and "positive_encoded" in prompt_data: if id(prompt_data.get("conditioning")) == conditioning_id:
if id(prompt_data["positive_encoded"]) == id(conditioning_obj): text = prompt_data.get("text", "")
if "positive_text" in prompt_data: if text:
return prompt_data["positive_text"] extend_unique(prompt_texts, [text])
else:
orig_conditioning = prompt_data.get("orig_pos_cond", None)
if orig_conditioning is not None:
# Recursively find the prompt text for the original conditioning
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=True)
if not is_positive and "negative_encoded" in prompt_data: # Generic provenance for passthrough/transform/combine nodes.
if id(prompt_data["negative_encoded"]) == id(conditioning_obj): for source in prompt_data.get("conditioning_sources", []):
if "negative_text" in prompt_data: if id(source.get("output")) != conditioning_id:
return prompt_data["negative_text"] continue
else: for input_conditioning in source.get("inputs", []):
orig_conditioning = prompt_data.get("orig_neg_cond", None) extend_unique(
if orig_conditioning is not None: prompt_texts,
# Recursively find the prompt text for the original conditioning find_prompt_texts_for_conditioning(
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=False) input_conditioning, is_positive, visited
),
)
return "" # For nodes with separate pos_conditioning and neg_conditioning outputs
# like TSC_EfficientLoader and existing ControlNet-style metadata.
if (
is_positive
and id(prompt_data.get("positive_encoded")) == conditioning_id
):
if prompt_data.get("positive_text"):
extend_unique(prompt_texts, [prompt_data["positive_text"]])
else:
extend_unique(
prompt_texts,
find_prompt_texts_for_conditioning(
prompt_data.get("orig_pos_cond"),
is_positive=True,
visited=visited,
),
)
if (
not is_positive
and id(prompt_data.get("negative_encoded")) == conditioning_id
):
if prompt_data.get("negative_text"):
extend_unique(prompt_texts, [prompt_data["negative_text"]])
else:
extend_unique(
prompt_texts,
find_prompt_texts_for_conditioning(
prompt_data.get("orig_neg_cond"),
is_positive=False,
visited=visited,
),
)
return prompt_texts
# Find prompt texts using the helper function # Find prompt texts using the helper function
result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True) result["prompt"] = ", ".join(
result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False) find_prompt_texts_for_conditioning(pos_conditioning, is_positive=True)
)
result["negative_prompt"] = ", ".join(
find_prompt_texts_for_conditioning(neg_conditioning, is_positive=False)
)
return result return result

View File

@@ -163,6 +163,193 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor):
conditioning = outputs[0][0] conditioning = outputs[0][0]
metadata[PROMPTS][node_id]["conditioning"] = conditioning metadata[PROMPTS][node_id]["conditioning"] = conditioning
def _ensure_prompt_metadata(metadata, node_id):
    """Return the PROMPTS entry for ``node_id``, creating it on first use.

    A freshly created entry is seeded with its own ``node_id`` so it can be
    identified later; an existing entry is returned untouched.
    """
    prompt_entries = metadata[PROMPTS]
    return prompt_entries.setdefault(node_id, {"node_id": node_id})
def _first_output_tuple(outputs):
if not outputs or not isinstance(outputs, list) or len(outputs) == 0:
return None
first_output = outputs[0]
if isinstance(first_output, tuple):
return first_output
return None
def _record_conditioning_source(
    metadata, node_id, output_conditioning, input_conditionings
):
    """Append an output -> inputs provenance record to the node's entry.

    Nothing is recorded when ``output_conditioning`` is ``None`` or when no
    non-``None`` input remains after filtering. Otherwise a dict with keys
    ``"output"`` and ``"inputs"`` is appended to the entry's
    ``"conditioning_sources"`` list, which is created on demand.
    """
    if output_conditioning is None:
        return

    # Drop missing inputs; a record without any source carries no information.
    present_inputs = list(
        filter(lambda candidate: candidate is not None, input_conditionings)
    )
    if not present_inputs:
        return

    entry = _ensure_prompt_metadata(metadata, node_id)
    if "conditioning_sources" not in entry:
        entry["conditioning_sources"] = []
    entry["conditioning_sources"].append(
        {"output": output_conditioning, "inputs": present_inputs}
    )
def _get_variable_name(inputs):
for key in ("key", "name", "variable_name", "tag", "text"):
value = inputs.get(key)
if isinstance(value, str) and value:
return value
return None
def _get_node_variable_name(metadata, node_id, inputs):
    """Resolve a node's variable name from runtime inputs or the stored prompt.

    Lookup order:
      1. the runtime ``inputs`` of the node;
      2. the node's ``"inputs"`` in ``metadata["current_prompt"].original_prompt``;
      3. the first ``"widgets_values"`` item of that node, if it is a string.

    Returns ``None`` when none of these yields a name.
    """
    name_from_inputs = _get_variable_name(inputs)
    if name_from_inputs:
        return name_from_inputs

    # Fall back to the prompt graph captured for the current execution.
    original_prompt = getattr(
        metadata.get("current_prompt"), "original_prompt", None
    )
    if not original_prompt or node_id not in original_prompt:
        return None

    node_data = original_prompt[node_id]
    name_from_graph = _get_variable_name(node_data.get("inputs", {}))
    if name_from_graph:
        return name_from_graph

    widgets_values = node_data.get("widgets_values", [])
    has_string_widget = bool(widgets_values) and isinstance(widgets_values[0], str)
    return widgets_values[0] if has_string_widget else None
class ControlNetApplyAdvancedExtractor(NodeMetadataExtractor):
    """Track conditioning provenance through a positive/negative pair node.

    ``extract`` remembers the incoming conditionings; ``update`` stores the
    produced conditionings and links each one back to its matching input.
    """

    @staticmethod
    def extract(node_id, inputs, outputs, metadata):
        """Store the ``positive`` / ``negative`` inputs that are not ``None``."""
        if not inputs:
            return

        entry = _ensure_prompt_metadata(metadata, node_id)
        input_to_storage = (
            ("positive", "orig_pos_cond"),
            ("negative", "orig_neg_cond"),
        )
        for input_name, storage_key in input_to_storage:
            conditioning = inputs.get(input_name)
            if conditioning is not None:
                entry[storage_key] = conditioning

    @staticmethod
    def update(node_id, outputs, metadata):
        """Store the encoded outputs and record where each one came from."""
        output_tuple = _first_output_tuple(outputs)
        if not output_tuple:
            return

        entry = _ensure_prompt_metadata(metadata, node_id)
        slots = (
            ("positive_encoded", "orig_pos_cond"),
            ("negative_encoded", "orig_neg_cond"),
        )
        # zip stops at the shorter sequence: a one-element output only fills
        # the positive slot, and anything past the second element is ignored.
        for encoded, (encoded_key, source_key) in zip(output_tuple, slots):
            entry[encoded_key] = encoded
            _record_conditioning_source(
                metadata, node_id, encoded, [entry.get(source_key)]
            )
class ConditioningCombineExtractor(NodeMetadataExtractor):
    """Track provenance for a node that merges several conditionings into one."""

    @staticmethod
    def extract(node_id, inputs, outputs, metadata):
        """Remember every non-``None`` input whose name starts with ``conditioning``."""
        if not inputs:
            return

        combined_inputs = [
            value
            for input_name, value in inputs.items()
            if input_name.startswith("conditioning") and value is not None
        ]
        if not combined_inputs:
            return

        entry = _ensure_prompt_metadata(metadata, node_id)
        entry["orig_conditionings"] = combined_inputs

    @staticmethod
    def update(node_id, outputs, metadata):
        """Store the combined output and link it to the remembered inputs."""
        output_tuple = _first_output_tuple(outputs)
        # An empty tuple is falsy, so this also covers the zero-length case.
        if not output_tuple:
            return

        entry = _ensure_prompt_metadata(metadata, node_id)
        combined = output_tuple[0]
        entry["conditioning"] = combined
        _record_conditioning_source(
            metadata, node_id, combined, entry.get("orig_conditionings", [])
        )
class SetNodeExtractor(NodeMetadataExtractor):
    """Register the conditioning stored by a named "set variable" node."""

    @staticmethod
    def extract(node_id, inputs, outputs, metadata):
        """Store the node's conditioning and publish it under its variable name.

        The conditioning is read from the ``"CONDITIONING"`` input, falling
        back to ``"conditioning"``. When a variable name can be resolved the
        value is also placed in the shared ``"__conditioning_variables__"``
        table inside ``metadata[PROMPTS]``.
        """
        if not inputs:
            return

        variable_name = _get_node_variable_name(metadata, node_id, inputs)

        conditioning = next(
            (
                inputs[input_name]
                for input_name in ("CONDITIONING", "conditioning")
                if inputs.get(input_name) is not None
            ),
            None,
        )
        if conditioning is None:
            return

        entry = _ensure_prompt_metadata(metadata, node_id)
        entry["conditioning"] = conditioning
        if not variable_name:
            return

        entry["variable_name"] = variable_name
        variables = metadata[PROMPTS].setdefault("__conditioning_variables__", {})
        variables[variable_name] = conditioning
class GetNodeExtractor(NodeMetadataExtractor):
    """Link the output of a named "get variable" node to the stored value."""

    @staticmethod
    def extract(node_id, inputs, outputs, metadata):
        """Remember which variable this node reads, when a name can be resolved."""
        variable_name = _get_node_variable_name(metadata, node_id, inputs or {})
        if not variable_name:
            return

        entry = _ensure_prompt_metadata(metadata, node_id)
        entry["variable_name"] = variable_name

    @staticmethod
    def update(node_id, outputs, metadata):
        """Store the node's output and trace it to the variable's conditioning."""
        output_tuple = _first_output_tuple(outputs)
        # An empty tuple is falsy, so this also covers the zero-length case.
        if not output_tuple:
            return

        entry = _ensure_prompt_metadata(metadata, node_id)
        produced = output_tuple[0]
        entry["conditioning"] = produced

        variable_name = entry.get("variable_name")
        if variable_name:
            variables = metadata[PROMPTS].get("__conditioning_variables__", {})
            # A variable that was never set yields None, which the recorder
            # filters out, so no provenance record is written in that case.
            _record_conditioning_source(
                metadata, node_id, produced, [variables.get(variable_name)]
            )
# Base Sampler Extractor to reduce code redundancy # Base Sampler Extractor to reduce code redundancy
class BaseSamplerExtractor(NodeMetadataExtractor): class BaseSamplerExtractor(NodeMetadataExtractor):
"""Base extractor for sampler nodes with common functionality""" """Base extractor for sampler nodes with common functionality"""
@@ -798,6 +985,10 @@ NODE_EXTRACTORS = {
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes "smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack "CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
"PCTextEncode": CLIPTextEncodeExtractor, # From https://github.com/asagi4/comfyui-prompt-control "PCTextEncode": CLIPTextEncodeExtractor, # From https://github.com/asagi4/comfyui-prompt-control
"ControlNetApplyAdvanced": ControlNetApplyAdvancedExtractor,
"ConditioningCombine": ConditioningCombineExtractor,
"SetNode": SetNodeExtractor,
"GetNode": GetNodeExtractor,
# Latent # Latent
"EmptyLatentImage": ImageSizeExtractor, "EmptyLatentImage": ImageSizeExtractor,
# Flux # Flux

View File

@@ -177,6 +177,220 @@ def test_attention_bias_clip_text_encode_prompts_are_collected(metadata_registry
assert prompt_results["negative_prompt"] == "low quality" assert prompt_results["negative_prompt"] == "low quality"
def test_conditioning_provenance_recovers_combined_controlnet_prompts(
    metadata_registry, monkeypatch
):
    """Two combined prompts routed through ControlNet are both recovered."""
    import types

    def make_latent():
        # A fresh stand-in latent each time; only ``.shape`` is provided.
        return {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))}

    def encode_node(text):
        return {
            "class_type": "CLIPTextEncode",
            "inputs": {"text": text, "clip": ["clip", 0]},
        }

    sampler_settings = {
        "seed": 123,
        "steps": 20,
        "cfg": 7.0,
        "sampler_name": "Euler",
        "scheduler": "karras",
        "denoise": 1.0,
    }

    prompt_graph = {
        "encode_wd": encode_node("wd14 tags"),
        "encode_manual": encode_node("manual tags"),
        "encode_neg": encode_node("low quality"),
        "combine": {
            "class_type": "ConditioningCombine",
            "inputs": {
                "conditioning_1": ["encode_wd", 0],
                "conditioning_2": ["encode_manual", 0],
            },
        },
        "controlnet": {
            "class_type": "ControlNetApplyAdvanced",
            "inputs": {
                "positive": ["combine", 0],
                "negative": ["encode_neg", 0],
            },
        },
        "sampler": {
            "class_type": "KSampler",
            "inputs": {
                **sampler_settings,
                "positive": ["controlnet", 0],
                "negative": ["controlnet", 1],
                "latent_image": make_latent(),
            },
        },
    }
    prompt = SimpleNamespace(original_prompt=prompt_graph)

    # Distinct sentinel objects stand in for real conditioning values.
    wd_conditioning = object()
    manual_conditioning = object()
    negative_conditioning = object()
    combined_conditioning = object()
    controlnet_positive = object()
    controlnet_negative = object()

    monkeypatch.setattr(metadata_processor, "standalone_mode", False)
    metadata_registry.start_collection("prompt-provenance")
    metadata_registry.set_current_prompt(prompt)

    # (node_id, class_type, runtime inputs, outputs), in execution order.
    executions = [
        ("encode_wd", "CLIPTextEncode", {"text": "wd14 tags"}, [(wd_conditioning,)]),
        (
            "encode_manual",
            "CLIPTextEncode",
            {"text": "manual tags"},
            [(manual_conditioning,)],
        ),
        (
            "encode_neg",
            "CLIPTextEncode",
            {"text": "low quality"},
            [(negative_conditioning,)],
        ),
        (
            "combine",
            "ConditioningCombine",
            {
                "conditioning_1": wd_conditioning,
                "conditioning_2": manual_conditioning,
            },
            [(combined_conditioning,)],
        ),
        (
            "controlnet",
            "ControlNetApplyAdvanced",
            {
                "positive": combined_conditioning,
                "negative": negative_conditioning,
            },
            [(controlnet_positive, controlnet_negative)],
        ),
    ]
    for node_id, class_type, node_inputs, node_outputs in executions:
        metadata_registry.record_node_execution(
            node_id, class_type, node_inputs, None
        )
        metadata_registry.update_node_execution(node_id, class_type, node_outputs)

    # The sampler is only recorded; no output update is issued for it.
    metadata_registry.record_node_execution(
        "sampler",
        "KSampler",
        {
            **sampler_settings,
            "positive": controlnet_positive,
            "negative": controlnet_negative,
            "latent_image": make_latent(),
        },
        None,
    )

    metadata = metadata_registry.get_metadata("prompt-provenance")
    params = MetadataProcessor.extract_generation_params(metadata)

    assert params["prompt"] == "wd14 tags, manual tags"
    assert params["negative_prompt"] == "low quality"
def test_conditioning_provenance_recovers_kj_set_get_prompts(
    metadata_registry, monkeypatch
):
    """A prompt routed through a SetNode/GetNode pair is still recovered."""
    import types

    def make_latent():
        # A fresh stand-in latent each time; only ``.shape`` is provided.
        return {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))}

    sampler_settings = {
        "seed": 123,
        "steps": 20,
        "cfg": 7.0,
        "sampler_name": "Euler",
        "scheduler": "karras",
        "denoise": 1.0,
    }

    prompt_graph = {
        "encode_pos": {
            "class_type": "CLIPTextEncode",
            "inputs": {"text": "from set node", "clip": ["clip", 0]},
        },
        "set_positive": {
            "class_type": "SetNode",
            "inputs": {"CONDITIONING": ["encode_pos", 0], "name": "positive"},
        },
        "get_positive": {
            "class_type": "GetNode",
            "inputs": {"name": "positive"},
        },
        "sampler": {
            "class_type": "KSampler",
            "inputs": {
                **sampler_settings,
                "positive": ["get_positive", 0],
                "negative": ["encode_pos", 0],
                "latent_image": make_latent(),
            },
        },
    }
    prompt = SimpleNamespace(original_prompt=prompt_graph)

    # The GetNode output is deliberately a different object from the value
    # handed to the SetNode, so identity matching alone cannot link them.
    original_conditioning = object()
    get_conditioning = object()

    monkeypatch.setattr(metadata_processor, "standalone_mode", False)
    metadata_registry.start_collection("prompt-kj-get")
    metadata_registry.set_current_prompt(prompt)

    record = metadata_registry.record_node_execution
    update = metadata_registry.update_node_execution

    record("encode_pos", "CLIPTextEncode", {"text": "from set node"}, None)
    update("encode_pos", "CLIPTextEncode", [(original_conditioning,)])

    # The SetNode is only recorded; no output update is issued for it.
    record(
        "set_positive",
        "SetNode",
        {"CONDITIONING": original_conditioning, "name": "positive"},
        None,
    )

    record("get_positive", "GetNode", {"name": "positive"}, None)
    update("get_positive", "GetNode", [(get_conditioning,)])

    record(
        "sampler",
        "KSampler",
        {
            **sampler_settings,
            "positive": get_conditioning,
            "negative": original_conditioning,
            "latent_image": make_latent(),
        },
        None,
    )

    metadata = metadata_registry.get_metadata("prompt-kj-get")
    params = MetadataProcessor.extract_generation_params(metadata)

    assert params["prompt"] == "from set node"
    assert params["negative_prompt"] == "from set node"
def test_sampler_custom_advanced_recovers_prompt_text_through_guidance_nodes(metadata_registry, monkeypatch): def test_sampler_custom_advanced_recovers_prompt_text_through_guidance_nodes(metadata_registry, monkeypatch):
import types import types