From 32d34d17481b779ab2f140d25eced3093a159f7a Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 17 Apr 2025 08:06:21 +0800 Subject: [PATCH] feat: Enhance trace_node_input method with depth tracking and target class filtering; add FluxGuidanceExtractor for guidance parameter extraction --- py/metadata_collector/metadata_processor.py | 78 +++++++++++++++++---- py/metadata_collector/node_extractors.py | 15 ++++ 2 files changed, 80 insertions(+), 13 deletions(-) diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index 9c80e68c..9cf2cb83 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -24,20 +24,65 @@ class MetadataProcessor: return primary_sampler_id, primary_sampler @staticmethod - def trace_node_input(prompt, node_id, input_name): - """Trace an input connection from a node to find the source node""" + def trace_node_input(prompt, node_id, input_name, target_class=None, max_depth=10): + """ + Trace an input connection from a node to find the source node + + Parameters: + - prompt: The prompt object containing node connections + - node_id: ID of the starting node + - input_name: Name of the input to trace + - target_class: Optional class name to search for (e.g., "CLIPTextEncode") + - max_depth: Maximum depth to follow the node chain to prevent infinite loops + + Returns: + - node_id of the found node, or None if not found + """ if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt: return None - node_inputs = prompt.original_prompt[node_id].get("inputs", {}) - if input_name not in node_inputs: - return None + # For depth tracking + current_depth = 0 + + current_node_id = node_id + current_input = input_name + + while current_depth < max_depth: + if current_node_id not in prompt.original_prompt: + return None + + node_inputs = prompt.original_prompt[current_node_id].get("inputs", {}) + if current_input not in node_inputs: + return None + + input_value = node_inputs[current_input] + # Input connections are formatted as [node_id, output_index] + if isinstance(input_value, list) and len(input_value) >= 2: + found_node_id = input_value[0] # Connected node_id + + # If we're looking for a specific node class + if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class: + return found_node_id + + # If we're not looking for a specific class or haven't found it yet + if not target_class: + return found_node_id + + # Continue tracing through intermediate nodes + current_node_id = found_node_id + # For most conditioning nodes, the input we want to follow is named "conditioning" + if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}): + current_input = "conditioning" + else: + # If there's no "conditioning" input, we can't trace further + return found_node_id if not target_class else None + else: + # We've reached a node with no further connections + return None - input_value = node_inputs[input_name] - # Input connections are formatted as [node_id, output_index] - if isinstance(input_value, list) and len(input_value) >= 2: - return input_value[0] # Return connected node_id + current_depth += 1 + # If we've reached max depth without finding target_class return None @staticmethod @@ -62,6 +107,7 @@ class MetadataProcessor: "seed": None, "steps": None, "cfg_scale": None, + "guidance": None, # Add guidance parameter "sampler": None, "checkpoint": None, "loras": "", @@ -90,13 +136,19 @@ class MetadataProcessor: # Trace connections from the primary sampler if prompt and primary_sampler_id: - # Trace positive prompt - positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive") + # Trace positive prompt - look specifically for CLIPTextEncode + positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10) if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}): params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") - # Trace negative prompt - negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative") + # Find any FluxGuidance nodes in the positive conditioning path + flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5) + if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}): + flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {}) + params["guidance"] = flux_params.get("guidance") + + # Trace negative prompt - look specifically for CLIPTextEncode + negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10) if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}): params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "") diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 6fb018b5..0d599c93 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -170,6 +170,20 @@ class LoraLoaderManagerExtractor(NodeMetadataExtractor): "lora_list": active_loras, "node_id": node_id } + +class FluxGuidanceExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "guidance" not in inputs: + return + + guidance_value = inputs.get("guidance") + + # Store the guidance value in SAMPLING category + if node_id not in metadata[SAMPLING]: + metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id} + + metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value # Registry of node-specific extractors NODE_EXTRACTORS = { @@ -181,5 +195,6 @@ NODE_EXTRACTORS = { "LoraManagerLoader": LoraLoaderManagerExtractor, "SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced "UNETLoader": CheckpointLoaderExtractor, # Add UNETLoader + "FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance # Add other nodes as needed }