From 154ae825193a696d348a722ca5b98884c9e28073 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 19 Dec 2025 01:30:08 +0800 Subject: [PATCH] feat(metadata_processor): enhance primary sampler selection logic - Add pre-processing step to populate missing parameters for candidate samplers, especially for SamplerCustomAdvanced requiring tracing - Change sampler selection from most recent (closest to downstream) to first in execution order to prioritize base samplers over refine samplers - Improve parameter handling by updating sampler parameters with traced values before ranking - Maintain backward compatibility with fallback to first sampler if no criteria match --- py/metadata_collector/metadata_processor.py | 67 +++++++++++++++++---- 1 file changed, 55 insertions(+), 12 deletions(-) diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index 9dd85542..2d39f2ba 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -39,8 +39,39 @@ class MetadataProcessor: if node_id in metadata.get(SAMPLING, {}) and metadata[SAMPLING][node_id].get(IS_SAMPLER, False): candidate_samplers[node_id] = metadata[SAMPLING][node_id] - # If we found candidate samplers, apply primary sampler logic to these candidates only - if candidate_samplers: + # If we found candidate samplers, apply primary sampler logic to these candidates only + + # PRE-PROCESS: Ensure all candidate samplers have their parameters populated + # This is especially important for SamplerCustomAdvanced which needs tracing + prompt = metadata.get("current_prompt") + for node_id in candidate_samplers: + # If a sampler is missing common parameters like steps or denoise, + # try to populate them using tracing before ranking + sampler_info = candidate_samplers[node_id] + params = sampler_info.get("parameters", {}) + + if prompt and (params.get("steps") is None or params.get("denoise") is None): + # Create a temporary params dict to use the handler + temp_params = { + "steps": params.get("steps"), + "denoise": params.get("denoise"), + "sampler": params.get("sampler_name"), + "scheduler": params.get("scheduler") + } + + # Check if it's SamplerCustomAdvanced + if prompt.original_prompt and node_id in prompt.original_prompt: + if prompt.original_prompt[node_id].get("class_type") == "SamplerCustomAdvanced": + MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, node_id, temp_params) + + # Update the actual parameters with found values + params["steps"] = temp_params.get("steps") + params["denoise"] = temp_params.get("denoise") + if temp_params.get("sampler"): + params["sampler_name"] = temp_params.get("sampler") + if temp_params.get("scheduler"): + params["scheduler"] = temp_params.get("scheduler") + # Collect potential primary samplers based on different criteria custom_advanced_samplers = [] advanced_add_noise_samplers = [] @@ -49,7 +80,6 @@ class MetadataProcessor: high_denoise_id = None # First, check for SamplerCustomAdvanced among candidates - prompt = metadata.get("current_prompt") if prompt and prompt.original_prompt: for node_id in candidate_samplers: node_info = prompt.original_prompt.get(node_id, {}) @@ -77,15 +107,16 @@ class MetadataProcessor: # Combine all potential primary samplers potential_samplers = custom_advanced_samplers + advanced_add_noise_samplers + high_denoise_samplers - # Find the most recent potential primary sampler (closest to downstream node) - for i in range(downstream_index - 1, -1, -1): + # Find the first potential primary sampler (prefer base sampler over refine) + # Use forward search to prioritize the first one in execution order + for i in range(downstream_index): node_id = execution_order[i] if node_id in potential_samplers: return node_id, candidate_samplers[node_id] - # If no potential sampler found from our criteria, return the most recent sampler + # If no potential sampler found from our criteria, return the first sampler if candidate_samplers: - for i in range(downstream_index - 1, -1, -1): + for i in range(downstream_index): node_id = execution_order[i] if node_id in candidate_samplers: return node_id, candidate_samplers[node_id] @@ -176,8 +207,11 @@ class MetadataProcessor: found_node_id = input_value[0] # Connected node_id # If we're looking for a specific node class - if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class: - return found_node_id + if target_class: + if found_node_id not in prompt.original_prompt: + return None + if prompt.original_prompt[found_node_id].get("class_type") == target_class: + return found_node_id # If we're not looking for a specific class, update the last valid node if not target_class: @@ -185,11 +219,19 @@ class MetadataProcessor: # Continue tracing through intermediate nodes current_node_id = found_node_id - # For most conditioning nodes, the input we want to follow is named "conditioning" - if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}): + + # Check if current source node exists + if current_node_id not in prompt.original_prompt: + return found_node_id if not target_class else None + + # Determine which input to follow next on the source node + source_node_inputs = prompt.original_prompt[current_node_id].get("inputs", {}) + if input_name in source_node_inputs: + current_input = input_name + elif "conditioning" in source_node_inputs: current_input = "conditioning" else: - # If there's no "conditioning" input, return the current node + # If there's no suitable input to follow, return the current node # if we're not looking for a specific target_class return found_node_id if not target_class else None else: @@ -523,6 +565,7 @@ class MetadataProcessor: scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {}) params["steps"] = scheduler_params.get("steps") params["scheduler"] = scheduler_params.get("scheduler") + params["denoise"] = scheduler_params.get("denoise") # 2. Trace sampler input to find KSamplerSelect (only if sampler input exists) if "sampler" in sampler_inputs: