mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 08:26:45 -03:00
fix(metadata): trace conditioning provenance for prompts
This commit is contained in:
@@ -353,49 +353,100 @@ class MetadataProcessor:
|
||||
# Check if we have stored conditioning objects for this sampler
|
||||
if sampler_id in metadata.get(PROMPTS, {}) and (
|
||||
"pos_conditioning" in metadata[PROMPTS][sampler_id] or
|
||||
"neg_conditioning" in metadata[PROMPTS][sampler_id]):
|
||||
|
||||
"neg_conditioning" in metadata[PROMPTS][sampler_id]
|
||||
):
|
||||
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
|
||||
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
|
||||
|
||||
# Helper function to recursively find prompt text for a conditioning object
|
||||
def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True):
|
||||
def extend_unique(target, values):
|
||||
for value in values:
|
||||
if value and value not in target:
|
||||
target.append(value)
|
||||
|
||||
# Helper function to recursively find prompt texts for a conditioning object.
|
||||
# Transform nodes can map one output conditioning to multiple source conditionings.
|
||||
def find_prompt_texts_for_conditioning(
|
||||
conditioning_obj, is_positive=True, visited=None
|
||||
):
|
||||
if conditioning_obj is None:
|
||||
return ""
|
||||
return []
|
||||
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
conditioning_id = id(conditioning_obj)
|
||||
if conditioning_id in visited:
|
||||
return []
|
||||
visited.add(conditioning_id)
|
||||
|
||||
prompt_texts = []
|
||||
|
||||
# Try to match conditioning objects with those stored by extractors
|
||||
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
|
||||
# For nodes with single conditioning output
|
||||
if "conditioning" in prompt_data:
|
||||
if id(prompt_data["conditioning"]) == id(conditioning_obj):
|
||||
return prompt_data.get("text", "")
|
||||
if not isinstance(prompt_data, dict):
|
||||
continue
|
||||
|
||||
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
|
||||
if is_positive and "positive_encoded" in prompt_data:
|
||||
if id(prompt_data["positive_encoded"]) == id(conditioning_obj):
|
||||
if "positive_text" in prompt_data:
|
||||
return prompt_data["positive_text"]
|
||||
# For CLIP text nodes with a single conditioning output.
|
||||
if id(prompt_data.get("conditioning")) == conditioning_id:
|
||||
text = prompt_data.get("text", "")
|
||||
if text:
|
||||
extend_unique(prompt_texts, [text])
|
||||
|
||||
# Generic provenance for passthrough/transform/combine nodes.
|
||||
for source in prompt_data.get("conditioning_sources", []):
|
||||
if id(source.get("output")) != conditioning_id:
|
||||
continue
|
||||
for input_conditioning in source.get("inputs", []):
|
||||
extend_unique(
|
||||
prompt_texts,
|
||||
find_prompt_texts_for_conditioning(
|
||||
input_conditioning, is_positive, visited
|
||||
),
|
||||
)
|
||||
|
||||
# For nodes with separate pos_conditioning and neg_conditioning outputs
|
||||
# like TSC_EfficientLoader and existing ControlNet-style metadata.
|
||||
if (
|
||||
is_positive
|
||||
and id(prompt_data.get("positive_encoded")) == conditioning_id
|
||||
):
|
||||
if prompt_data.get("positive_text"):
|
||||
extend_unique(prompt_texts, [prompt_data["positive_text"]])
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_pos_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=True)
|
||||
extend_unique(
|
||||
prompt_texts,
|
||||
find_prompt_texts_for_conditioning(
|
||||
prompt_data.get("orig_pos_cond"),
|
||||
is_positive=True,
|
||||
visited=visited,
|
||||
),
|
||||
)
|
||||
|
||||
if not is_positive and "negative_encoded" in prompt_data:
|
||||
if id(prompt_data["negative_encoded"]) == id(conditioning_obj):
|
||||
if "negative_text" in prompt_data:
|
||||
return prompt_data["negative_text"]
|
||||
if (
|
||||
not is_positive
|
||||
and id(prompt_data.get("negative_encoded")) == conditioning_id
|
||||
):
|
||||
if prompt_data.get("negative_text"):
|
||||
extend_unique(prompt_texts, [prompt_data["negative_text"]])
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_neg_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=False)
|
||||
extend_unique(
|
||||
prompt_texts,
|
||||
find_prompt_texts_for_conditioning(
|
||||
prompt_data.get("orig_neg_cond"),
|
||||
is_positive=False,
|
||||
visited=visited,
|
||||
),
|
||||
)
|
||||
|
||||
return ""
|
||||
return prompt_texts
|
||||
|
||||
# Find prompt texts using the helper function
|
||||
result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True)
|
||||
result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False)
|
||||
result["prompt"] = ", ".join(
|
||||
find_prompt_texts_for_conditioning(pos_conditioning, is_positive=True)
|
||||
)
|
||||
result["negative_prompt"] = ", ".join(
|
||||
find_prompt_texts_for_conditioning(neg_conditioning, is_positive=False)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -163,6 +163,193 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor):
|
||||
conditioning = outputs[0][0]
|
||||
metadata[PROMPTS][node_id]["conditioning"] = conditioning
|
||||
|
||||
|
||||
def _ensure_prompt_metadata(metadata, node_id):
    """Return the ``PROMPTS`` entry for ``node_id``, creating it on first use.

    A newly created entry contains only ``{"node_id": node_id}``; an entry
    that already exists is returned unchanged.
    """
    prompt_entries = metadata[PROMPTS]
    # setdefault only stores the fresh dict when node_id has no entry yet.
    return prompt_entries.setdefault(node_id, {"node_id": node_id})
|
||||
|
||||
|
||||
def _first_output_tuple(outputs):
    """Return ``outputs[0]`` when it is a tuple, otherwise ``None``.

    Args:
        outputs: Expected to be a list whose first element is a tuple of
            node output values. Anything else is tolerated.

    Returns:
        The first element of ``outputs`` if ``outputs`` is a non-empty list
        and that element is a tuple; ``None`` in every other case.
    """
    # Check the type before the truthiness: truth-testing an arbitrary
    # object can raise, whereas an emptiness test on a list never does.
    if not isinstance(outputs, list) or not outputs:
        return None

    head = outputs[0]
    return head if isinstance(head, tuple) else None
|
||||
|
||||
|
||||
def _record_conditioning_source(
    metadata, node_id, output_conditioning, input_conditionings
):
    """Append an output-to-inputs provenance record to the node's entry.

    Nothing is recorded when ``output_conditioning`` is ``None`` or when
    every element of ``input_conditionings`` is ``None``. Otherwise a dict
    with keys ``"output"`` and ``"inputs"`` is appended to the node's
    ``"conditioning_sources"`` list, which is created if missing.
    """
    if output_conditioning is None:
        return

    # Drop missing inputs; a record with no real inputs carries no information.
    known_inputs = list(
        filter(lambda candidate: candidate is not None, input_conditionings)
    )
    if not known_inputs:
        return

    record = {"output": output_conditioning, "inputs": known_inputs}
    node_entry = _ensure_prompt_metadata(metadata, node_id)
    node_entry.setdefault("conditioning_sources", []).append(record)
|
||||
|
||||
|
||||
def _get_variable_name(inputs):
    """Return the first non-empty string found under a known name key.

    The keys ``"key"``, ``"name"``, ``"variable_name"``, ``"tag"`` and
    ``"text"`` are tried in that order.

    Args:
        inputs: Mapping of input names to values. ``None`` and objects
            without a ``get`` method are accepted and yield ``None``.

    Returns:
        The first value that is a non-empty ``str``, or ``None`` when no
        candidate key holds one.
    """
    lookup = getattr(inputs, "get", None)
    if not callable(lookup):
        # None, lists, etc. cannot be searched; treat them as "no name".
        return None

    for candidate_key in ("key", "name", "variable_name", "tag", "text"):
        value = lookup(candidate_key)
        # Skip non-strings (e.g. link references) and empty strings.
        if isinstance(value, str) and value:
            return value
    return None
|
||||
|
||||
|
||||
def _get_node_variable_name(metadata, node_id, inputs):
    """Resolve a variable name for a node, trying three sources in order.

    1. The ``inputs`` passed in.
    2. The ``"inputs"`` stored for ``node_id`` in
       ``metadata["current_prompt"].original_prompt``.
    3. The first element of that stored node's ``"widgets_values"``.

    Args:
        metadata: Mapping that may hold a ``"current_prompt"`` object with an
            ``original_prompt`` attribute.
        node_id: Key of the node inside ``original_prompt``.
        inputs: The node's inputs; ``None`` is treated as empty.

    Returns:
        A non-empty string, or ``None`` when no source provides one.
    """
    variable_name = _get_variable_name(inputs or {})
    if variable_name:
        return variable_name

    prompt = metadata.get("current_prompt")
    original_prompt = getattr(prompt, "original_prompt", None)
    if not original_prompt or node_id not in original_prompt:
        return None

    node_data = original_prompt[node_id]
    if not hasattr(node_data, "get"):
        # A malformed node entry cannot be searched for a name.
        return None

    # "inputs" may be present but None; normalise to an empty mapping.
    variable_name = _get_variable_name(node_data.get("inputs") or {})
    if variable_name:
        return variable_name

    # Only index widgets_values positionally when it is a real sequence;
    # a dict or other container would raise on [0].
    widgets_values = node_data.get("widgets_values")
    if isinstance(widgets_values, (list, tuple)) and widgets_values:
        first_widget = widgets_values[0]
        if isinstance(first_widget, str) and first_widget:
            return first_widget

    return None
|
||||
|
||||
|
||||
class ControlNetApplyAdvancedExtractor(NodeMetadataExtractor):
    """Track which input conditionings a ControlNetApplyAdvanced node received
    and which output conditionings it produced."""

    @staticmethod
    def extract(node_id, inputs, outputs, metadata):
        """Store the ``positive``/``negative`` inputs as ``orig_pos_cond``/``orig_neg_cond``."""
        if not inputs:
            return

        node_entry = _ensure_prompt_metadata(metadata, node_id)
        for input_name, stored_key in (
            ("positive", "orig_pos_cond"),
            ("negative", "orig_neg_cond"),
        ):
            incoming = inputs.get(input_name)
            if incoming is not None:
                node_entry[stored_key] = incoming

    @staticmethod
    def update(node_id, outputs, metadata):
        """Store the first two outputs and link each one to its stored input."""
        produced = _first_output_tuple(outputs)
        if not produced:
            return

        node_entry = _ensure_prompt_metadata(metadata, node_id)
        # Output 0 pairs with the positive input, output 1 with the negative one.
        original_inputs = (
            node_entry.get("orig_pos_cond"),
            node_entry.get("orig_neg_cond"),
        )
        encoded_keys = ("positive_encoded", "negative_encoded")

        # zip stops at the shorter sequence, so a one-element output tuple
        # only fills the positive slot.
        for encoded_key, source_input, output_value in zip(
            encoded_keys, original_inputs, produced
        ):
            node_entry[encoded_key] = output_value
            _record_conditioning_source(
                metadata, node_id, output_value, [source_input]
            )
|
||||
|
||||
|
||||
class ConditioningCombineExtractor(NodeMetadataExtractor):
    """Track the conditionings fed into a ConditioningCombine node and the
    conditioning it outputs."""

    @staticmethod
    def extract(node_id, inputs, outputs, metadata):
        """Store every non-None input whose name starts with ``conditioning``."""
        if not inputs:
            return

        gathered = [
            value
            for input_name, value in inputs.items()
            if input_name.startswith("conditioning") and value is not None
        ]
        if not gathered:
            return

        _ensure_prompt_metadata(metadata, node_id)["orig_conditionings"] = gathered

    @staticmethod
    def update(node_id, outputs, metadata):
        """Store the first output and link it to the stored inputs."""
        produced = _first_output_tuple(outputs)
        # An empty tuple is falsy, so this also covers the zero-length case.
        if not produced:
            return

        node_entry = _ensure_prompt_metadata(metadata, node_id)
        combined = produced[0]
        node_entry["conditioning"] = combined

        stored_inputs = node_entry.get("orig_conditionings", [])
        _record_conditioning_source(metadata, node_id, combined, stored_inputs)
|
||||
|
||||
|
||||
class SetNodeExtractor(NodeMetadataExtractor):
    """Remember the conditioning passed into a SetNode, keyed by variable name."""

    @staticmethod
    def extract(node_id, inputs, outputs, metadata):
        """Store the node's conditioning input and publish it under its variable name.

        The conditioning is read from the ``CONDITIONING`` input, falling back
        to ``conditioning``. When a variable name can be resolved, the
        conditioning is also stored in the ``"__conditioning_variables__"``
        mapping, which lives inside ``metadata[PROMPTS]`` next to the
        per-node entries.
        """
        if not inputs:
            return

        variable_name = _get_node_variable_name(metadata, node_id, inputs)

        upper_case_value = inputs.get("CONDITIONING")
        stored_conditioning = (
            inputs.get("conditioning") if upper_case_value is None else upper_case_value
        )
        if stored_conditioning is None:
            return

        node_entry = _ensure_prompt_metadata(metadata, node_id)
        node_entry["conditioning"] = stored_conditioning
        if not variable_name:
            return

        node_entry["variable_name"] = variable_name
        variables = metadata[PROMPTS].setdefault("__conditioning_variables__", {})
        variables[variable_name] = stored_conditioning
|
||||
|
||||
|
||||
class GetNodeExtractor(NodeMetadataExtractor):
    """Link a GetNode's output conditioning to the conditioning stored under
    the same variable name."""

    @staticmethod
    def extract(node_id, inputs, outputs, metadata):
        """Store the node's variable name when one can be resolved."""
        resolved_name = _get_node_variable_name(metadata, node_id, inputs or {})
        if not resolved_name:
            return

        _ensure_prompt_metadata(metadata, node_id)["variable_name"] = resolved_name

    @staticmethod
    def update(node_id, outputs, metadata):
        """Store the first output and link it to the named variable's conditioning."""
        produced = _first_output_tuple(outputs)
        # An empty tuple is falsy, so this also covers the zero-length case.
        if not produced:
            return

        node_entry = _ensure_prompt_metadata(metadata, node_id)
        fetched_conditioning = produced[0]
        node_entry["conditioning"] = fetched_conditioning

        variable_name = node_entry.get("variable_name")
        if not variable_name:
            return

        # Look up the conditioning stored under this name; None when absent,
        # in which case no provenance record is written.
        variables = metadata[PROMPTS].get("__conditioning_variables__", {})
        _record_conditioning_source(
            metadata,
            node_id,
            fetched_conditioning,
            [variables.get(variable_name)],
        )
|
||||
|
||||
# Base Sampler Extractor to reduce code redundancy
|
||||
class BaseSamplerExtractor(NodeMetadataExtractor):
|
||||
"""Base extractor for sampler nodes with common functionality"""
|
||||
@@ -798,6 +985,10 @@ NODE_EXTRACTORS = {
|
||||
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
|
||||
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
|
||||
"PCTextEncode": CLIPTextEncodeExtractor, # From https://github.com/asagi4/comfyui-prompt-control
|
||||
"ControlNetApplyAdvanced": ControlNetApplyAdvancedExtractor,
|
||||
"ConditioningCombine": ConditioningCombineExtractor,
|
||||
"SetNode": SetNodeExtractor,
|
||||
"GetNode": GetNodeExtractor,
|
||||
# Latent
|
||||
"EmptyLatentImage": ImageSizeExtractor,
|
||||
# Flux
|
||||
|
||||
@@ -177,6 +177,220 @@ def test_attention_bias_clip_text_encode_prompts_are_collected(metadata_registry
|
||||
assert prompt_results["negative_prompt"] == "low quality"
|
||||
|
||||
|
||||
def test_conditioning_provenance_recovers_combined_controlnet_prompts(
    metadata_registry, monkeypatch
):
    """Prompts are recovered through ConditioningCombine and ControlNetApplyAdvanced.

    Two positive text encodes are combined, the combination and a negative
    encode pass through a ControlNet node, and the sampler consumes the
    ControlNet outputs. The extracted prompt must join both positive texts.
    """
    import types

    def make_latent():
        return {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))}

    def make_encode_node(text):
        return {
            "class_type": "CLIPTextEncode",
            "inputs": {"text": text, "clip": ["clip", 0]},
        }

    sampler_settings = {
        "seed": 123,
        "steps": 20,
        "cfg": 7.0,
        "sampler_name": "Euler",
        "scheduler": "karras",
        "denoise": 1.0,
    }

    prompt_graph = {
        "encode_wd": make_encode_node("wd14 tags"),
        "encode_manual": make_encode_node("manual tags"),
        "encode_neg": make_encode_node("low quality"),
        "combine": {
            "class_type": "ConditioningCombine",
            "inputs": {
                "conditioning_1": ["encode_wd", 0],
                "conditioning_2": ["encode_manual", 0],
            },
        },
        "controlnet": {
            "class_type": "ControlNetApplyAdvanced",
            "inputs": {
                "positive": ["combine", 0],
                "negative": ["encode_neg", 0],
            },
        },
        "sampler": {
            "class_type": "KSampler",
            "inputs": {
                **sampler_settings,
                "positive": ["controlnet", 0],
                "negative": ["controlnet", 1],
                "latent_image": make_latent(),
            },
        },
    }
    prompt = SimpleNamespace(original_prompt=prompt_graph)

    # Distinct sentinel objects stand in for real conditioning tensors.
    wd_conditioning = object()
    manual_conditioning = object()
    negative_conditioning = object()
    combined_conditioning = object()
    controlnet_positive = object()
    controlnet_negative = object()

    monkeypatch.setattr(metadata_processor, "standalone_mode", False)

    metadata_registry.start_collection("prompt-provenance")
    metadata_registry.set_current_prompt(prompt)

    # Text encodes run first, in graph order.
    encode_steps = (
        ("encode_wd", "wd14 tags", wd_conditioning),
        ("encode_manual", "manual tags", manual_conditioning),
        ("encode_neg", "low quality", negative_conditioning),
    )
    for encode_id, text, conditioning in encode_steps:
        metadata_registry.record_node_execution(
            encode_id, "CLIPTextEncode", {"text": text}, None
        )
        metadata_registry.update_node_execution(
            encode_id, "CLIPTextEncode", [(conditioning,)]
        )

    metadata_registry.record_node_execution(
        "combine",
        "ConditioningCombine",
        {
            "conditioning_1": wd_conditioning,
            "conditioning_2": manual_conditioning,
        },
        None,
    )
    metadata_registry.update_node_execution(
        "combine", "ConditioningCombine", [(combined_conditioning,)]
    )

    metadata_registry.record_node_execution(
        "controlnet",
        "ControlNetApplyAdvanced",
        {
            "positive": combined_conditioning,
            "negative": negative_conditioning,
        },
        None,
    )
    metadata_registry.update_node_execution(
        "controlnet",
        "ControlNetApplyAdvanced",
        [(controlnet_positive, controlnet_negative)],
    )

    metadata_registry.record_node_execution(
        "sampler",
        "KSampler",
        {
            **sampler_settings,
            "positive": controlnet_positive,
            "negative": controlnet_negative,
            "latent_image": make_latent(),
        },
        None,
    )

    metadata = metadata_registry.get_metadata("prompt-provenance")
    params = MetadataProcessor.extract_generation_params(metadata)

    assert params["prompt"] == "wd14 tags, manual tags"
    assert params["negative_prompt"] == "low quality"
|
||||
|
||||
|
||||
def test_conditioning_provenance_recovers_kj_set_get_prompts(
    metadata_registry, monkeypatch
):
    """Prompt text is recovered when conditioning travels through SetNode/GetNode.

    The sampler's positive input is the GetNode output and its negative input
    is the original text-encode output; both must resolve to the same text.
    """
    import types

    def make_latent():
        return {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))}

    sampler_settings = {
        "seed": 123,
        "steps": 20,
        "cfg": 7.0,
        "sampler_name": "Euler",
        "scheduler": "karras",
        "denoise": 1.0,
    }

    prompt_graph = {
        "encode_pos": {
            "class_type": "CLIPTextEncode",
            "inputs": {"text": "from set node", "clip": ["clip", 0]},
        },
        "set_positive": {
            "class_type": "SetNode",
            "inputs": {"CONDITIONING": ["encode_pos", 0], "name": "positive"},
        },
        "get_positive": {
            "class_type": "GetNode",
            "inputs": {"name": "positive"},
        },
        "sampler": {
            "class_type": "KSampler",
            "inputs": {
                **sampler_settings,
                "positive": ["get_positive", 0],
                "negative": ["encode_pos", 0],
                "latent_image": make_latent(),
            },
        },
    }
    prompt = SimpleNamespace(original_prompt=prompt_graph)

    # Distinct sentinels: the GetNode output is a different object from the
    # conditioning that was stored by the SetNode.
    original_conditioning = object()
    get_conditioning = object()

    monkeypatch.setattr(metadata_processor, "standalone_mode", False)

    metadata_registry.start_collection("prompt-kj-get")
    metadata_registry.set_current_prompt(prompt)

    metadata_registry.record_node_execution(
        "encode_pos", "CLIPTextEncode", {"text": "from set node"}, None
    )
    metadata_registry.update_node_execution(
        "encode_pos", "CLIPTextEncode", [(original_conditioning,)]
    )

    # The SetNode is only recorded; no output update is issued for it here.
    metadata_registry.record_node_execution(
        "set_positive",
        "SetNode",
        {"CONDITIONING": original_conditioning, "name": "positive"},
        None,
    )

    metadata_registry.record_node_execution(
        "get_positive", "GetNode", {"name": "positive"}, None
    )
    metadata_registry.update_node_execution(
        "get_positive", "GetNode", [(get_conditioning,)]
    )

    metadata_registry.record_node_execution(
        "sampler",
        "KSampler",
        {
            **sampler_settings,
            "positive": get_conditioning,
            "negative": original_conditioning,
            "latent_image": make_latent(),
        },
        None,
    )

    metadata = metadata_registry.get_metadata("prompt-kj-get")
    params = MetadataProcessor.extract_generation_params(metadata)

    assert params["prompt"] == "from set node"
    assert params["negative_prompt"] == "from set node"
|
||||
|
||||
|
||||
def test_sampler_custom_advanced_recovers_prompt_text_through_guidance_nodes(metadata_registry, monkeypatch):
|
||||
import types
|
||||
|
||||
|
||||
Reference in New Issue
Block a user