mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-26 07:35:44 -03:00
feat(metadata_processor): enhance primary sampler selection logic
- Add pre-processing step to populate missing parameters for candidate samplers, especially for SamplerCustomAdvanced requiring tracing - Change sampler selection from most recent (closest to downstream) to first in execution order to prioritize base samplers over refine samplers - Improve parameter handling by updating sampler parameters with traced values before ranking - Maintain backward compatibility with fallback to first sampler if no criteria match
This commit is contained in:
@@ -40,7 +40,38 @@ class MetadataProcessor:
|
|||||||
candidate_samplers[node_id] = metadata[SAMPLING][node_id]
|
candidate_samplers[node_id] = metadata[SAMPLING][node_id]
|
||||||
|
|
||||||
# If we found candidate samplers, apply primary sampler logic to these candidates only
|
# If we found candidate samplers, apply primary sampler logic to these candidates only
|
||||||
if candidate_samplers:
|
|
||||||
|
# PRE-PROCESS: Ensure all candidate samplers have their parameters populated
|
||||||
|
# This is especially important for SamplerCustomAdvanced which needs tracing
|
||||||
|
prompt = metadata.get("current_prompt")
|
||||||
|
for node_id in candidate_samplers:
|
||||||
|
# If a sampler is missing common parameters like steps or denoise,
|
||||||
|
# try to populate them using tracing before ranking
|
||||||
|
sampler_info = candidate_samplers[node_id]
|
||||||
|
params = sampler_info.get("parameters", {})
|
||||||
|
|
||||||
|
if prompt and (params.get("steps") is None or params.get("denoise") is None):
|
||||||
|
# Create a temporary params dict to use the handler
|
||||||
|
temp_params = {
|
||||||
|
"steps": params.get("steps"),
|
||||||
|
"denoise": params.get("denoise"),
|
||||||
|
"sampler": params.get("sampler_name"),
|
||||||
|
"scheduler": params.get("scheduler")
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if it's SamplerCustomAdvanced
|
||||||
|
if prompt.original_prompt and node_id in prompt.original_prompt:
|
||||||
|
if prompt.original_prompt[node_id].get("class_type") == "SamplerCustomAdvanced":
|
||||||
|
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, node_id, temp_params)
|
||||||
|
|
||||||
|
# Update the actual parameters with found values
|
||||||
|
params["steps"] = temp_params.get("steps")
|
||||||
|
params["denoise"] = temp_params.get("denoise")
|
||||||
|
if temp_params.get("sampler"):
|
||||||
|
params["sampler_name"] = temp_params.get("sampler")
|
||||||
|
if temp_params.get("scheduler"):
|
||||||
|
params["scheduler"] = temp_params.get("scheduler")
|
||||||
|
|
||||||
# Collect potential primary samplers based on different criteria
|
# Collect potential primary samplers based on different criteria
|
||||||
custom_advanced_samplers = []
|
custom_advanced_samplers = []
|
||||||
advanced_add_noise_samplers = []
|
advanced_add_noise_samplers = []
|
||||||
@@ -49,7 +80,6 @@ class MetadataProcessor:
|
|||||||
high_denoise_id = None
|
high_denoise_id = None
|
||||||
|
|
||||||
# First, check for SamplerCustomAdvanced among candidates
|
# First, check for SamplerCustomAdvanced among candidates
|
||||||
prompt = metadata.get("current_prompt")
|
|
||||||
if prompt and prompt.original_prompt:
|
if prompt and prompt.original_prompt:
|
||||||
for node_id in candidate_samplers:
|
for node_id in candidate_samplers:
|
||||||
node_info = prompt.original_prompt.get(node_id, {})
|
node_info = prompt.original_prompt.get(node_id, {})
|
||||||
@@ -77,15 +107,16 @@ class MetadataProcessor:
|
|||||||
# Combine all potential primary samplers
|
# Combine all potential primary samplers
|
||||||
potential_samplers = custom_advanced_samplers + advanced_add_noise_samplers + high_denoise_samplers
|
potential_samplers = custom_advanced_samplers + advanced_add_noise_samplers + high_denoise_samplers
|
||||||
|
|
||||||
# Find the most recent potential primary sampler (closest to downstream node)
|
# Find the first potential primary sampler (prefer base sampler over refine)
|
||||||
for i in range(downstream_index - 1, -1, -1):
|
# Use forward search to prioritize the first one in execution order
|
||||||
|
for i in range(downstream_index):
|
||||||
node_id = execution_order[i]
|
node_id = execution_order[i]
|
||||||
if node_id in potential_samplers:
|
if node_id in potential_samplers:
|
||||||
return node_id, candidate_samplers[node_id]
|
return node_id, candidate_samplers[node_id]
|
||||||
|
|
||||||
# If no potential sampler found from our criteria, return the most recent sampler
|
# If no potential sampler found from our criteria, return the first sampler
|
||||||
if candidate_samplers:
|
if candidate_samplers:
|
||||||
for i in range(downstream_index - 1, -1, -1):
|
for i in range(downstream_index):
|
||||||
node_id = execution_order[i]
|
node_id = execution_order[i]
|
||||||
if node_id in candidate_samplers:
|
if node_id in candidate_samplers:
|
||||||
return node_id, candidate_samplers[node_id]
|
return node_id, candidate_samplers[node_id]
|
||||||
@@ -176,7 +207,10 @@ class MetadataProcessor:
|
|||||||
found_node_id = input_value[0] # Connected node_id
|
found_node_id = input_value[0] # Connected node_id
|
||||||
|
|
||||||
# If we're looking for a specific node class
|
# If we're looking for a specific node class
|
||||||
if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class:
|
if target_class:
|
||||||
|
if found_node_id not in prompt.original_prompt:
|
||||||
|
return None
|
||||||
|
if prompt.original_prompt[found_node_id].get("class_type") == target_class:
|
||||||
return found_node_id
|
return found_node_id
|
||||||
|
|
||||||
# If we're not looking for a specific class, update the last valid node
|
# If we're not looking for a specific class, update the last valid node
|
||||||
@@ -185,11 +219,19 @@ class MetadataProcessor:
|
|||||||
|
|
||||||
# Continue tracing through intermediate nodes
|
# Continue tracing through intermediate nodes
|
||||||
current_node_id = found_node_id
|
current_node_id = found_node_id
|
||||||
# For most conditioning nodes, the input we want to follow is named "conditioning"
|
|
||||||
if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}):
|
# Check if current source node exists
|
||||||
|
if current_node_id not in prompt.original_prompt:
|
||||||
|
return found_node_id if not target_class else None
|
||||||
|
|
||||||
|
# Determine which input to follow next on the source node
|
||||||
|
source_node_inputs = prompt.original_prompt[current_node_id].get("inputs", {})
|
||||||
|
if input_name in source_node_inputs:
|
||||||
|
current_input = input_name
|
||||||
|
elif "conditioning" in source_node_inputs:
|
||||||
current_input = "conditioning"
|
current_input = "conditioning"
|
||||||
else:
|
else:
|
||||||
# If there's no "conditioning" input, return the current node
|
# If there's no suitable input to follow, return the current node
|
||||||
# if we're not looking for a specific target_class
|
# if we're not looking for a specific target_class
|
||||||
return found_node_id if not target_class else None
|
return found_node_id if not target_class else None
|
||||||
else:
|
else:
|
||||||
@@ -523,6 +565,7 @@ class MetadataProcessor:
|
|||||||
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
||||||
params["steps"] = scheduler_params.get("steps")
|
params["steps"] = scheduler_params.get("steps")
|
||||||
params["scheduler"] = scheduler_params.get("scheduler")
|
params["scheduler"] = scheduler_params.get("scheduler")
|
||||||
|
params["denoise"] = scheduler_params.get("denoise")
|
||||||
|
|
||||||
# 2. Trace sampler input to find KSamplerSelect (only if sampler input exists)
|
# 2. Trace sampler input to find KSamplerSelect (only if sampler input exists)
|
||||||
if "sampler" in sampler_inputs:
|
if "sampler" in sampler_inputs:
|
||||||
|
|||||||
Reference in New Issue
Block a user