refactor: streamline prompt matching logic in MetadataProcessor

This commit is contained in:
Will Miao
2025-06-20 17:00:23 +08:00
parent 32d12bb334
commit aa34c4c84c

View File

@@ -299,14 +299,7 @@ class MetadataProcessor:
params["sampler"] = sampling_params.get("sampler_name") params["sampler"] = sampling_params.get("sampler_name")
params["scheduler"] = sampling_params.get("scheduler") params["scheduler"] = sampling_params.get("scheduler")
# First try to match conditioning objects to prompts (new method) if prompt and primary_sampler_id:
if primary_sampler_id:
prompt_results = MetadataProcessor.match_conditioning_to_prompts(metadata, primary_sampler_id)
params["prompt"] = prompt_results["prompt"]
params["negative_prompt"] = prompt_results["negative_prompt"]
# If prompts were not found by object matching, fall back to tracing connections
if not params["prompt"] and prompt and primary_sampler_id:
# Check if this is a SamplerCustomAdvanced node # Check if this is a SamplerCustomAdvanced node
is_custom_advanced = False is_custom_advanced = False
if prompt.original_prompt and primary_sampler_id in prompt.original_prompt: if prompt.original_prompt and primary_sampler_id in prompt.original_prompt:
@@ -353,29 +346,36 @@ class MetadataProcessor:
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
else: else:
# Original tracing for standard samplers # For standard samplers, match conditioning objects to prompts
# Trace positive prompt - look specifically for CLIPTextEncode prompt_results = MetadataProcessor.match_conditioning_to_prompts(metadata, primary_sampler_id)
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", max_depth=10) params["prompt"] = prompt_results["prompt"]
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}): params["negative_prompt"] = prompt_results["negative_prompt"]
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
else: # If prompts were still not found, fall back to tracing connections
# If CLIPTextEncode is not found, try to find CLIPTextEncodeFlux if not params["prompt"]:
positive_flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncodeFlux", max_depth=10) # Original tracing for standard samplers
if positive_flux_node_id and positive_flux_node_id in metadata.get(PROMPTS, {}): # Trace positive prompt - look specifically for CLIPTextEncode
params["prompt"] = metadata[PROMPTS][positive_flux_node_id].get("text", "") positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
# Trace negative prompt - look specifically for CLIPTextEncode params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10) else:
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}): # If CLIPTextEncode is not found, try to find CLIPTextEncodeFlux
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "") positive_flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncodeFlux", max_depth=10)
if positive_flux_node_id and positive_flux_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_flux_node_id].get("text", "")
# Trace negative prompt - look specifically for CLIPTextEncode
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10)
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
# Size extraction is same for all sampler types # Size extraction is same for all sampler types
# Check if the sampler itself has size information (from latent_image) # Check if the sampler itself has size information (from latent_image)
if primary_sampler_id in metadata.get(SIZE, {}): if primary_sampler_id in metadata.get(SIZE, {}):
width = metadata[SIZE][primary_sampler_id].get("width") width = metadata[SIZE][primary_sampler_id].get("width")
height = metadata[SIZE][primary_sampler_id].get("height") height = metadata[SIZE][primary_sampler_id].get("height")
if width and height: if width and height:
params["size"] = f"{width}x{height}" params["size"] = f"{width}x{height}"
# Extract LoRAs using the standardized format # Extract LoRAs using the standardized format
lora_parts = [] lora_parts = []