From db7f57a5a4201f6b7cb21fb23004e0b4b5f3e900 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 16 Jul 2025 08:08:11 +0800 Subject: [PATCH] feat: Refactor sampler extractors to reduce redundancy and improve maintainability. Add support for KSampler [pipe] from comfyui-impact-pack and comfyui-inspire-pack --- py/metadata_collector/node_extractors.py | 187 ++++++++++++++--------- 1 file changed, 112 insertions(+), 75 deletions(-) diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 4e3a79d2..1a3759d5 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -117,15 +117,15 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor): if isinstance(outputs[0], tuple) and len(outputs[0]) > 0: conditioning = outputs[0][0] metadata[PROMPTS][node_id]["conditioning"] = conditioning - -class SamplerExtractor(NodeMetadataExtractor): + +# Base Sampler Extractor to reduce code redundancy +class BaseSamplerExtractor(NodeMetadataExtractor): + """Base extractor for sampler nodes with common functionality""" @staticmethod - def extract(node_id, inputs, outputs, metadata): - if not inputs: - return - + def extract_sampling_params(node_id, inputs, metadata, param_keys): + """Extract sampling parameters from inputs""" sampling_params = {} - for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]: + for key in param_keys: if key in inputs: sampling_params[key] = inputs[key] @@ -134,7 +134,10 @@ class SamplerExtractor(NodeMetadataExtractor): "node_id": node_id, IS_SAMPLER: True # Add sampler flag } - + + @staticmethod + def extract_conditioning(node_id, inputs, metadata): + """Extract conditioning objects from inputs""" # Store the conditioning objects directly in metadata for later matching pos_conditioning = inputs.get("positive", None) neg_conditioning = inputs.get("negative", None) @@ -146,7 +149,10 @@ class SamplerExtractor(NodeMetadataExtractor): metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning - + + @staticmethod + def extract_latent_dimensions(node_id, inputs, metadata): + """Extract dimensions from latent image""" # Extract latent image dimensions if available if "latent_image" in inputs and inputs["latent_image"] is not None: latent = inputs["latent_image"] @@ -167,59 +173,106 @@ class SamplerExtractor(NodeMetadataExtractor): "height": height, "node_id": node_id } - -class KSamplerAdvancedExtractor(NodeMetadataExtractor): + +class SamplerExtractor(BaseSamplerExtractor): @staticmethod def extract(node_id, inputs, outputs, metadata): if not inputs: return - sampling_params = {} - for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]: - if key in inputs: - sampling_params[key] = inputs[key] - - metadata[SAMPLING][node_id] = { - "parameters": sampling_params, - "node_id": node_id, - IS_SAMPLER: True # Add sampler flag - } + # Extract common sampling parameters + BaseSamplerExtractor.extract_sampling_params( + node_id, inputs, metadata, + ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"] + ) - # Store the conditioning objects directly in metadata for later matching - pos_conditioning = inputs.get("positive", None) - neg_conditioning = inputs.get("negative", None) + # Extract conditioning objects + BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata) + + # Extract latent dimensions + BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata) - # Save conditioning objects in metadata for later matching - if pos_conditioning is not None or neg_conditioning is not None: - if node_id not in metadata[PROMPTS]: - metadata[PROMPTS][node_id] = {"node_id": node_id} +class KSamplerAdvancedExtractor(BaseSamplerExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return - metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning - metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning + # Extract common sampling parameters + BaseSamplerExtractor.extract_sampling_params( + node_id, inputs, metadata, + ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"] + ) - # Extract latent image dimensions if available - if "latent_image" in inputs and inputs["latent_image"] is not None: - latent = inputs["latent_image"] - if isinstance(latent, dict) and "samples" in latent: - # Extract dimensions from latent tensor - samples = latent["samples"] - if hasattr(samples, "shape") and len(samples.shape) >= 3: - # Correct shape interpretation: [batch_size, channels, height/8, width/8] - # Multiply by 8 to get actual pixel dimensions - height = int(samples.shape[2] * 8) - width = int(samples.shape[3] * 8) - - if SIZE not in metadata: - metadata[SIZE] = {} - - metadata[SIZE][node_id] = { - "width": width, - "height": height, - "node_id": node_id - } + # Extract conditioning objects + BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata) + + # Extract latent dimensions + BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata) + +class KSamplerBasicPipeExtractor(BaseSamplerExtractor): + """Extractor for KSamplerBasicPipe and KSampler_inspire_pipe nodes""" + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + # Extract common sampling parameters + BaseSamplerExtractor.extract_sampling_params( + node_id, inputs, metadata, + ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"] + ) + + # Extract conditioning objects from basic_pipe + if "basic_pipe" in inputs and inputs["basic_pipe"] is not None: + basic_pipe = inputs["basic_pipe"] + # Typically, basic_pipe structure is (model, clip, vae, positive, negative) + if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5: + pos_conditioning = basic_pipe[3] # positive is at index 3 + neg_conditioning = basic_pipe[4] # negative is at index 4 + + # Save conditioning objects in metadata + if node_id not in metadata[PROMPTS]: + metadata[PROMPTS][node_id] = {"node_id": node_id} + + metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning + metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning + + # Extract latent dimensions + BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata) + +class KSamplerAdvancedBasicPipeExtractor(BaseSamplerExtractor): + """Extractor for KSamplerAdvancedBasicPipe nodes""" + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + # Extract common sampling parameters + BaseSamplerExtractor.extract_sampling_params( + node_id, inputs, metadata, + ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"] + ) + + # Extract conditioning objects from basic_pipe + if "basic_pipe" in inputs and inputs["basic_pipe"] is not None: + basic_pipe = inputs["basic_pipe"] + # Typically, basic_pipe structure is (model, clip, vae, positive, negative) + if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5: + pos_conditioning = basic_pipe[3] # positive is at index 3 + neg_conditioning = basic_pipe[4] # negative is at index 4 + + # Save conditioning objects in metadata + if node_id not in metadata[PROMPTS]: + metadata[PROMPTS][node_id] = {"node_id": node_id} + + metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning + metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning + + # Extract latent dimensions + BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata) class TSCSamplerBaseExtractor(NodeMetadataExtractor): - """Base extractor for handling TSC sampler node outputs""" @staticmethod def extract(node_id, inputs, outputs, metadata): # Store vae_decode setting for later use in update @@ -273,7 +326,6 @@ class TSCSamplerBaseExtractor(NodeMetadataExtractor): metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id] class TSCKSamplerExtractor(SamplerExtractor, TSCSamplerBaseExtractor): - """Extractor for TSC_KSampler nodes""" @staticmethod def extract(node_id, inputs, outputs, metadata): # Call parent extract methods @@ -284,11 +336,10 @@ class TSCKSamplerExtractor(SamplerExtractor, TSCSamplerBaseExtractor): class TSCKSamplerAdvancedExtractor(KSamplerAdvancedExtractor, TSCSamplerBaseExtractor): - """Extractor for TSC_KSamplerAdvanced nodes""" @staticmethod def extract(node_id, inputs, outputs, metadata): # Call parent extract methods - SamplerExtractor.extract(node_id, inputs, outputs, metadata) + KSamplerAdvancedExtractor.extract(node_id, inputs, outputs, metadata) TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata) # Update method is inherited from TSCSamplerBaseExtractor @@ -461,7 +512,7 @@ class BasicSchedulerExtractor(NodeMetadataExtractor): IS_SAMPLER: False # Mark as non-primary sampler } -class SamplerCustomAdvancedExtractor(NodeMetadataExtractor): +class SamplerCustomAdvancedExtractor(BaseSamplerExtractor): @staticmethod def extract(node_id, inputs, outputs, metadata): if not inputs: @@ -480,26 +531,8 @@ class SamplerCustomAdvancedExtractor(NodeMetadataExtractor): IS_SAMPLER: True # Add sampler flag } - # Extract latent image dimensions if available - if "latent_image" in inputs and inputs["latent_image"] is not None: - latent = inputs["latent_image"] - if isinstance(latent, dict) and "samples" in latent: - # Extract dimensions from latent tensor - samples = latent["samples"] - if hasattr(samples, "shape") and len(samples.shape) >= 3: - # Correct shape interpretation: [batch_size, channels, height/8, width/8] - # Multiply by 8 to get actual pixel dimensions - height = int(samples.shape[2] * 8) - width = int(samples.shape[3] * 8) - - if SIZE not in metadata: - metadata[SIZE] = {} - - metadata[SIZE][node_id] = { - "width": width, - "height": height, - "node_id": node_id - } + # Extract latent dimensions + BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata) import json @@ -612,6 +645,10 @@ NODE_EXTRACTORS = { "SamplerCustomAdvanced": SamplerCustomAdvancedExtractor, "TSC_KSampler": TSCKSamplerExtractor, # Efficient Nodes "TSC_KSamplerAdvanced": TSCKSamplerAdvancedExtractor, # Efficient Nodes + "KSamplerBasicPipe": KSamplerBasicPipeExtractor, # comfyui-impact-pack + "KSamplerAdvancedBasicPipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-impact-pack + "KSampler_inspire_pipe": KSamplerBasicPipeExtractor, # comfyui-inspire-pack + "KSamplerAdvanced_inspire_pipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-inspire-pack # Sampling Selectors "KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect "BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler