fix(metadata): trace conditioning provenance for prompts

This commit is contained in:
Will Miao
2026-04-23 14:41:54 +08:00
parent 2eef629821
commit ebdbb36271
3 changed files with 494 additions and 38 deletions

View File

@@ -163,6 +163,193 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor):
conditioning = outputs[0][0]
metadata[PROMPTS][node_id]["conditioning"] = conditioning
def _ensure_prompt_metadata(metadata, node_id):
if node_id not in metadata[PROMPTS]:
metadata[PROMPTS][node_id] = {"node_id": node_id}
return metadata[PROMPTS][node_id]
def _first_output_tuple(outputs):
if not outputs or not isinstance(outputs, list) or len(outputs) == 0:
return None
first_output = outputs[0]
if isinstance(first_output, tuple):
return first_output
return None
def _record_conditioning_source(
metadata, node_id, output_conditioning, input_conditionings
):
if output_conditioning is None:
return
sources = [
conditioning for conditioning in input_conditionings if conditioning is not None
]
if not sources:
return
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
prompt_metadata.setdefault("conditioning_sources", []).append(
{
"output": output_conditioning,
"inputs": sources,
}
)
def _get_variable_name(inputs):
for key in ("key", "name", "variable_name", "tag", "text"):
value = inputs.get(key)
if isinstance(value, str) and value:
return value
return None
def _get_node_variable_name(metadata, node_id, inputs):
variable_name = _get_variable_name(inputs)
if variable_name:
return variable_name
prompt = metadata.get("current_prompt")
original_prompt = getattr(prompt, "original_prompt", None)
if not original_prompt or node_id not in original_prompt:
return None
node_data = original_prompt[node_id]
variable_name = _get_variable_name(node_data.get("inputs", {}))
if variable_name:
return variable_name
widgets_values = node_data.get("widgets_values", [])
if widgets_values and isinstance(widgets_values[0], str):
return widgets_values[0]
return None
class ControlNetApplyAdvancedExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
if inputs.get("positive") is not None:
prompt_metadata["orig_pos_cond"] = inputs["positive"]
if inputs.get("negative") is not None:
prompt_metadata["orig_neg_cond"] = inputs["negative"]
@staticmethod
def update(node_id, outputs, metadata):
output_tuple = _first_output_tuple(outputs)
if not output_tuple:
return
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
positive_input = prompt_metadata.get("orig_pos_cond")
negative_input = prompt_metadata.get("orig_neg_cond")
if len(output_tuple) >= 1:
prompt_metadata["positive_encoded"] = output_tuple[0]
_record_conditioning_source(
metadata, node_id, output_tuple[0], [positive_input]
)
if len(output_tuple) >= 2:
prompt_metadata["negative_encoded"] = output_tuple[1]
_record_conditioning_source(
metadata, node_id, output_tuple[1], [negative_input]
)
class ConditioningCombineExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
input_conditionings = []
for input_name in inputs:
if (
input_name.startswith("conditioning")
and inputs[input_name] is not None
):
input_conditionings.append(inputs[input_name])
if input_conditionings:
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
prompt_metadata["orig_conditionings"] = input_conditionings
@staticmethod
def update(node_id, outputs, metadata):
output_tuple = _first_output_tuple(outputs)
if not output_tuple or len(output_tuple) < 1:
return
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
output_conditioning = output_tuple[0]
prompt_metadata["conditioning"] = output_conditioning
_record_conditioning_source(
metadata,
node_id,
output_conditioning,
prompt_metadata.get("orig_conditionings", []),
)
class SetNodeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
variable_name = _get_node_variable_name(metadata, node_id, inputs)
conditioning = inputs.get("CONDITIONING")
if conditioning is None:
conditioning = inputs.get("conditioning")
if conditioning is None:
return
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
prompt_metadata["conditioning"] = conditioning
if variable_name:
prompt_metadata["variable_name"] = variable_name
metadata[PROMPTS].setdefault("__conditioning_variables__", {})[
variable_name
] = conditioning
class GetNodeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
variable_name = _get_node_variable_name(metadata, node_id, inputs or {})
if variable_name:
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
prompt_metadata["variable_name"] = variable_name
@staticmethod
def update(node_id, outputs, metadata):
output_tuple = _first_output_tuple(outputs)
if not output_tuple or len(output_tuple) < 1:
return
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
output_conditioning = output_tuple[0]
prompt_metadata["conditioning"] = output_conditioning
variable_name = prompt_metadata.get("variable_name")
if not variable_name:
return
input_conditioning = metadata[PROMPTS].get("__conditioning_variables__", {}).get(
variable_name
)
_record_conditioning_source(
metadata, node_id, output_conditioning, [input_conditioning]
)
# Base Sampler Extractor to reduce code redundancy
class BaseSamplerExtractor(NodeMetadataExtractor):
"""Base extractor for sampler nodes with common functionality"""
@@ -798,6 +985,10 @@ NODE_EXTRACTORS = {
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
"PCTextEncode": CLIPTextEncodeExtractor, # From https://github.com/asagi4/comfyui-prompt-control
"ControlNetApplyAdvanced": ControlNetApplyAdvancedExtractor,
"ConditioningCombine": ConditioningCombineExtractor,
"SetNode": SetNodeExtractor,
"GetNode": GetNodeExtractor,
# Latent
"EmptyLatentImage": ImageSizeExtractor,
# Flux