From 9f69822221052a37a0d378003b9d806a5523c935 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sun, 17 Aug 2025 20:42:52 +0800 Subject: [PATCH] feat: Refactor SamplerCustom handling and enhance node extractor mappings for improved metadata processing --- py/metadata_collector/metadata_processor.py | 99 +++++++++++++-------- py/metadata_collector/node_extractors.py | 3 + 2 files changed, 64 insertions(+), 38 deletions(-) diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index add9dbfc..65ad16bd 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -339,44 +339,8 @@ class MetadataProcessor: is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced" if is_custom_advanced: - # For SamplerCustomAdvanced, trace specific inputs - - # 1. Trace sigmas input to find BasicScheduler - scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", "BasicScheduler", max_depth=5) - if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}): - scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {}) - params["steps"] = scheduler_params.get("steps") - params["scheduler"] = scheduler_params.get("scheduler") - - # 2. Trace sampler input to find KSamplerSelect - sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5) - if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}): - sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {}) - params["sampler"] = sampler_params.get("sampler_name") - - # 3. Trace guider input for CFGGuider and CLIPTextEncode - guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5) - if guider_node_id and guider_node_id in prompt.original_prompt: - # Check if the guider node is a CFGGuider - if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider": - # Extract cfg value from the CFGGuider - if guider_node_id in metadata.get(SAMPLING, {}): - cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {}) - params["cfg_scale"] = cfg_params.get("cfg") - - # Find CLIPTextEncode for positive prompt - positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10) - if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}): - params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") - - # Find CLIPTextEncode for negative prompt - negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10) - if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}): - params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "") - else: - positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10) - if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}): - params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") + # For SamplerCustomAdvanced, use the new handler method + MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params) else: # For standard samplers, match conditioning objects to prompts @@ -401,6 +365,9 @@ class MetadataProcessor: negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10) if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}): params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "") + + # For SamplerCustom, handle any additional parameters + MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params) # Size extraction is same for all sampler types # Check if the sampler itself has size information (from latent_image) @@ -454,3 +421,59 @@ class MetadataProcessor: """Convert metadata to JSON string""" params = MetadataProcessor.to_dict(metadata, id) return json.dumps(params, indent=4) + + @staticmethod + def handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params): + """ + Handle parameter extraction for SamplerCustomAdvanced nodes + + Parameters: + - metadata: The workflow metadata + - prompt: The prompt object containing node connections + - primary_sampler_id: ID of the SamplerCustomAdvanced node + - params: Parameters dictionary to update + """ + if not prompt.original_prompt or primary_sampler_id not in prompt.original_prompt: + return + + sampler_inputs = prompt.original_prompt[primary_sampler_id].get("inputs", {}) + + # 1. Trace sigmas input to find BasicScheduler (only if sigmas input exists) + if "sigmas" in sampler_inputs: + scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", None, max_depth=5) + if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}): + scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {}) + params["steps"] = scheduler_params.get("steps") + params["scheduler"] = scheduler_params.get("scheduler") + + # 2. Trace sampler input to find KSamplerSelect (only if sampler input exists) + if "sampler" in sampler_inputs: + sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5) + if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}): + sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {}) + params["sampler"] = sampler_params.get("sampler_name") + + # 3. Trace guider input for CFGGuider and CLIPTextEncode + if "guider" in sampler_inputs: + guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5) + if guider_node_id and guider_node_id in prompt.original_prompt: + # Check if the guider node is a CFGGuider + if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider": + # Extract cfg value from the CFGGuider + if guider_node_id in metadata.get(SAMPLING, {}): + cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {}) + params["cfg_scale"] = cfg_params.get("cfg") + + # Find CLIPTextEncode for positive prompt + positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10) + if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}): + params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") + + # Find CLIPTextEncode for negative prompt + negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10) + if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}): + params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "") + else: + positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10) + if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}): + params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 1a3759d5..79f87ee2 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -642,6 +642,7 @@ NODE_EXTRACTORS = { # Sampling "KSampler": SamplerExtractor, "KSamplerAdvanced": KSamplerAdvancedExtractor, + "SamplerCustom": KSamplerAdvancedExtractor, "SamplerCustomAdvanced": SamplerCustomAdvancedExtractor, "TSC_KSampler": TSCKSamplerExtractor, # Efficient Nodes "TSC_KSamplerAdvanced": TSCKSamplerAdvancedExtractor, # Efficient Nodes @@ -652,9 +653,11 @@ NODE_EXTRACTORS = { # Sampling Selectors "KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect "BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler + "AlignYourStepsScheduler": BasicSchedulerExtractor, # Add AlignYourStepsScheduler # Loaders "CheckpointLoaderSimple": CheckpointLoaderExtractor, "comfyLoader": CheckpointLoaderExtractor, # easy comfyLoader + "CheckpointLoaderSimpleWithImages": CheckpointLoaderExtractor, # CheckpointLoader|pysssss "TSC_EfficientLoader": TSCCheckpointLoaderExtractor, # Efficient Nodes "UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor "UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor