feat: Enhance trace_node_input method with depth tracking and target class filtering; add FluxGuidanceExtractor for guidance parameter extraction

This commit is contained in:
Will Miao
2025-04-17 08:06:21 +08:00
parent 18eb605605
commit 32d34d1748
2 changed files with 80 additions and 13 deletions

View File

@@ -24,20 +24,65 @@ class MetadataProcessor:
return primary_sampler_id, primary_sampler
@staticmethod
def trace_node_input(prompt, node_id, input_name):
"""Trace an input connection from a node to find the source node"""
def trace_node_input(prompt, node_id, input_name, target_class=None, max_depth=10):
"""
Trace an input connection from a node to find the source node
Parameters:
- prompt: The prompt object containing node connections
- node_id: ID of the starting node
- input_name: Name of the input to trace
- target_class: Optional class name to search for (e.g., "CLIPTextEncode")
- max_depth: Maximum depth to follow the node chain to prevent infinite loops
Returns:
- node_id of the found node, or None if not found
"""
if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt:
return None
node_inputs = prompt.original_prompt[node_id].get("inputs", {})
if input_name not in node_inputs:
return None
# For depth tracking
current_depth = 0
current_node_id = node_id
current_input = input_name
while current_depth < max_depth:
if current_node_id not in prompt.original_prompt:
return None
node_inputs = prompt.original_prompt[current_node_id].get("inputs", {})
if current_input not in node_inputs:
return None
input_value = node_inputs[current_input]
# Input connections are formatted as [node_id, output_index]
if isinstance(input_value, list) and len(input_value) >= 2:
found_node_id = input_value[0] # Connected node_id
# If we're looking for a specific node class
if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class:
return found_node_id
# If we're not looking for a specific class or haven't found it yet
if not target_class:
return found_node_id
# Continue tracing through intermediate nodes
current_node_id = found_node_id
# For most conditioning nodes, the input we want to follow is named "conditioning"
if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}):
current_input = "conditioning"
else:
# If there's no "conditioning" input, we can't trace further
return found_node_id if not target_class else None
else:
# We've reached a node with no further connections
return None
input_value = node_inputs[input_name]
# Input connections are formatted as [node_id, output_index]
if isinstance(input_value, list) and len(input_value) >= 2:
return input_value[0] # Return connected node_id
current_depth += 1
# If we've reached max depth without finding target_class
return None
@staticmethod
@@ -62,6 +107,7 @@ class MetadataProcessor:
"seed": None,
"steps": None,
"cfg_scale": None,
"guidance": None, # Add guidance parameter
"sampler": None,
"checkpoint": None,
"loras": "",
@@ -90,13 +136,19 @@ class MetadataProcessor:
# Trace connections from the primary sampler
if prompt and primary_sampler_id:
# Trace positive prompt
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive")
# Trace positive prompt - look specifically for CLIPTextEncode
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
# Trace negative prompt
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative")
# Find any FluxGuidance nodes in the positive conditioning path
flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5)
if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
params["guidance"] = flux_params.get("guidance")
# Trace negative prompt - look specifically for CLIPTextEncode
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10)
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")

View File

@@ -170,6 +170,20 @@ class LoraLoaderManagerExtractor(NodeMetadataExtractor):
"lora_list": active_loras,
"node_id": node_id
}
class FluxGuidanceExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "guidance" not in inputs:
return
guidance_value = inputs.get("guidance")
# Store the guidance value in SAMPLING category
if node_id not in metadata[SAMPLING]:
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
# Registry of node-specific extractors
NODE_EXTRACTORS = {
@@ -181,5 +195,6 @@ NODE_EXTRACTORS = {
"LoraManagerLoader": LoraLoaderManagerExtractor,
"SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced
"UNETLoader": CheckpointLoaderExtractor, # Add UNETLoader
"FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance
# Add other nodes as needed
}