mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-26 15:38:52 -03:00
feat: Refactor SamplerCustom handling and enhance node extractor mappings for improved metadata processing
This commit is contained in:
@@ -339,44 +339,8 @@ class MetadataProcessor:
|
|||||||
is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced"
|
is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced"
|
||||||
|
|
||||||
if is_custom_advanced:
|
if is_custom_advanced:
|
||||||
# For SamplerCustomAdvanced, trace specific inputs
|
# For SamplerCustomAdvanced, use the new handler method
|
||||||
|
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
|
||||||
# 1. Trace sigmas input to find BasicScheduler
|
|
||||||
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", "BasicScheduler", max_depth=5)
|
|
||||||
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
|
|
||||||
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
|
||||||
params["steps"] = scheduler_params.get("steps")
|
|
||||||
params["scheduler"] = scheduler_params.get("scheduler")
|
|
||||||
|
|
||||||
# 2. Trace sampler input to find KSamplerSelect
|
|
||||||
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
|
|
||||||
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
|
|
||||||
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
|
|
||||||
params["sampler"] = sampler_params.get("sampler_name")
|
|
||||||
|
|
||||||
# 3. Trace guider input for CFGGuider and CLIPTextEncode
|
|
||||||
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
|
|
||||||
if guider_node_id and guider_node_id in prompt.original_prompt:
|
|
||||||
# Check if the guider node is a CFGGuider
|
|
||||||
if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
|
|
||||||
# Extract cfg value from the CFGGuider
|
|
||||||
if guider_node_id in metadata.get(SAMPLING, {}):
|
|
||||||
cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {})
|
|
||||||
params["cfg_scale"] = cfg_params.get("cfg")
|
|
||||||
|
|
||||||
# Find CLIPTextEncode for positive prompt
|
|
||||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
|
|
||||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
|
||||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
|
||||||
|
|
||||||
# Find CLIPTextEncode for negative prompt
|
|
||||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
|
|
||||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
|
||||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
|
||||||
else:
|
|
||||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
|
|
||||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
|
||||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# For standard samplers, match conditioning objects to prompts
|
# For standard samplers, match conditioning objects to prompts
|
||||||
@@ -401,6 +365,9 @@ class MetadataProcessor:
|
|||||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10)
|
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10)
|
||||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||||
|
|
||||||
|
# For SamplerCustom, handle any additional parameters
|
||||||
|
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
|
||||||
|
|
||||||
# Size extraction is same for all sampler types
|
# Size extraction is same for all sampler types
|
||||||
# Check if the sampler itself has size information (from latent_image)
|
# Check if the sampler itself has size information (from latent_image)
|
||||||
@@ -454,3 +421,59 @@ class MetadataProcessor:
|
|||||||
"""Convert metadata to JSON string"""
|
"""Convert metadata to JSON string"""
|
||||||
params = MetadataProcessor.to_dict(metadata, id)
|
params = MetadataProcessor.to_dict(metadata, id)
|
||||||
return json.dumps(params, indent=4)
|
return json.dumps(params, indent=4)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params):
|
||||||
|
"""
|
||||||
|
Handle parameter extraction for SamplerCustomAdvanced nodes
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- metadata: The workflow metadata
|
||||||
|
- prompt: The prompt object containing node connections
|
||||||
|
- primary_sampler_id: ID of the SamplerCustomAdvanced node
|
||||||
|
- params: Parameters dictionary to update
|
||||||
|
"""
|
||||||
|
if not prompt.original_prompt or primary_sampler_id not in prompt.original_prompt:
|
||||||
|
return
|
||||||
|
|
||||||
|
sampler_inputs = prompt.original_prompt[primary_sampler_id].get("inputs", {})
|
||||||
|
|
||||||
|
# 1. Trace sigmas input to find BasicScheduler (only if sigmas input exists)
|
||||||
|
if "sigmas" in sampler_inputs:
|
||||||
|
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", None, max_depth=5)
|
||||||
|
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
|
||||||
|
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
||||||
|
params["steps"] = scheduler_params.get("steps")
|
||||||
|
params["scheduler"] = scheduler_params.get("scheduler")
|
||||||
|
|
||||||
|
# 2. Trace sampler input to find KSamplerSelect (only if sampler input exists)
|
||||||
|
if "sampler" in sampler_inputs:
|
||||||
|
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
|
||||||
|
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
|
||||||
|
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
|
||||||
|
params["sampler"] = sampler_params.get("sampler_name")
|
||||||
|
|
||||||
|
# 3. Trace guider input for CFGGuider and CLIPTextEncode
|
||||||
|
if "guider" in sampler_inputs:
|
||||||
|
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
|
||||||
|
if guider_node_id and guider_node_id in prompt.original_prompt:
|
||||||
|
# Check if the guider node is a CFGGuider
|
||||||
|
if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
|
||||||
|
# Extract cfg value from the CFGGuider
|
||||||
|
if guider_node_id in metadata.get(SAMPLING, {}):
|
||||||
|
cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {})
|
||||||
|
params["cfg_scale"] = cfg_params.get("cfg")
|
||||||
|
|
||||||
|
# Find CLIPTextEncode for positive prompt
|
||||||
|
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
|
||||||
|
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||||
|
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||||
|
|
||||||
|
# Find CLIPTextEncode for negative prompt
|
||||||
|
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
|
||||||
|
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||||
|
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||||
|
else:
|
||||||
|
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
|
||||||
|
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||||
|
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||||
|
|||||||
@@ -642,6 +642,7 @@ NODE_EXTRACTORS = {
|
|||||||
# Sampling
|
# Sampling
|
||||||
"KSampler": SamplerExtractor,
|
"KSampler": SamplerExtractor,
|
||||||
"KSamplerAdvanced": KSamplerAdvancedExtractor,
|
"KSamplerAdvanced": KSamplerAdvancedExtractor,
|
||||||
|
"SamplerCustom": KSamplerAdvancedExtractor,
|
||||||
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor,
|
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor,
|
||||||
"TSC_KSampler": TSCKSamplerExtractor, # Efficient Nodes
|
"TSC_KSampler": TSCKSamplerExtractor, # Efficient Nodes
|
||||||
"TSC_KSamplerAdvanced": TSCKSamplerAdvancedExtractor, # Efficient Nodes
|
"TSC_KSamplerAdvanced": TSCKSamplerAdvancedExtractor, # Efficient Nodes
|
||||||
@@ -652,9 +653,11 @@ NODE_EXTRACTORS = {
|
|||||||
# Sampling Selectors
|
# Sampling Selectors
|
||||||
"KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect
|
"KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect
|
||||||
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
|
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
|
||||||
|
"AlignYourStepsScheduler": BasicSchedulerExtractor, # Add AlignYourStepsScheduler
|
||||||
# Loaders
|
# Loaders
|
||||||
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
|
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
|
||||||
"comfyLoader": CheckpointLoaderExtractor, # easy comfyLoader
|
"comfyLoader": CheckpointLoaderExtractor, # easy comfyLoader
|
||||||
|
"CheckpointLoaderSimpleWithImages": CheckpointLoaderExtractor, # CheckpointLoader|pysssss
|
||||||
"TSC_EfficientLoader": TSCCheckpointLoaderExtractor, # Efficient Nodes
|
"TSC_EfficientLoader": TSCCheckpointLoaderExtractor, # Efficient Nodes
|
||||||
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||||
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||||
|
|||||||
Reference in New Issue
Block a user