refactor: streamline prompt matching logic in MetadataProcessor

This commit is contained in:
Will Miao
2025-06-20 17:00:23 +08:00
parent 32d12bb334
commit aa34c4c84c

View File

@@ -299,14 +299,7 @@ class MetadataProcessor:
params["sampler"] = sampling_params.get("sampler_name") params["sampler"] = sampling_params.get("sampler_name")
params["scheduler"] = sampling_params.get("scheduler") params["scheduler"] = sampling_params.get("scheduler")
# First try to match conditioning objects to prompts (new method) if prompt and primary_sampler_id:
if primary_sampler_id:
prompt_results = MetadataProcessor.match_conditioning_to_prompts(metadata, primary_sampler_id)
params["prompt"] = prompt_results["prompt"]
params["negative_prompt"] = prompt_results["negative_prompt"]
# If prompts were not found by object matching, fall back to tracing connections
if not params["prompt"] and prompt and primary_sampler_id:
# Check if this is a SamplerCustomAdvanced node # Check if this is a SamplerCustomAdvanced node
is_custom_advanced = False is_custom_advanced = False
if prompt.original_prompt and primary_sampler_id in prompt.original_prompt: if prompt.original_prompt and primary_sampler_id in prompt.original_prompt:
@@ -353,6 +346,13 @@ class MetadataProcessor:
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
else: else:
# For standard samplers, match conditioning objects to prompts
prompt_results = MetadataProcessor.match_conditioning_to_prompts(metadata, primary_sampler_id)
params["prompt"] = prompt_results["prompt"]
params["negative_prompt"] = prompt_results["negative_prompt"]
# If prompts were still not found, fall back to tracing connections
if not params["prompt"]:
# Original tracing for standard samplers # Original tracing for standard samplers
# Trace positive prompt - look specifically for CLIPTextEncode # Trace positive prompt - look specifically for CLIPTextEncode
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", max_depth=10) positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", max_depth=10)