feat: Add conditioning matching to prompts and update metadata handling in node extractors. See #235

This commit is contained in:
Will Miao
2025-06-20 00:04:02 +08:00
parent d840fd53da
commit 4889955ecf
2 changed files with 85 additions and 2 deletions

View File

@@ -209,6 +209,44 @@ class MetadataProcessor:
return None
@staticmethod
def match_conditioning_to_prompts(metadata, sampler_id):
"""
Match conditioning objects from a sampler to prompts in metadata
Parameters:
- metadata: The workflow metadata
- sampler_id: ID of the sampler node to match
Returns:
- Dictionary with 'prompt' and 'negative_prompt' if found
"""
result = {
"prompt": "",
"negative_prompt": ""
}
# Check if we have stored conditioning objects for this sampler
if sampler_id in metadata.get(PROMPTS, {}) and (
"pos_conditioning" in metadata[PROMPTS][sampler_id] or
"neg_conditioning" in metadata[PROMPTS][sampler_id]):
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
# Try to match conditioning objects with those stored by CLIPTextEncodeExtractor
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
if "conditioning" not in prompt_data:
continue
if pos_conditioning is not None and id(prompt_data["conditioning"]) == id(pos_conditioning):
result["prompt"] = prompt_data.get("text", "")
if neg_conditioning is not None and id(prompt_data["conditioning"]) == id(neg_conditioning):
result["negative_prompt"] = prompt_data.get("text", "")
return result
@staticmethod
def extract_generation_params(metadata, id=None):
"""
@@ -261,8 +299,14 @@ class MetadataProcessor:
params["sampler"] = sampling_params.get("sampler_name")
params["scheduler"] = sampling_params.get("scheduler")
# Trace connections from the primary sampler
if prompt and primary_sampler_id:
# First try to match conditioning objects to prompts (new method)
if primary_sampler_id:
prompt_results = MetadataProcessor.match_conditioning_to_prompts(metadata, primary_sampler_id)
params["prompt"] = prompt_results["prompt"]
params["negative_prompt"] = prompt_results["negative_prompt"]
# If prompts were not found by object matching, fall back to tracing connections
if not params["prompt"] and prompt and primary_sampler_id:
# Check if this is a SamplerCustomAdvanced node
is_custom_advanced = False
if prompt.original_prompt and primary_sampler_id in prompt.original_prompt: