From bccabe40c01b8922594472b4a2b211bc640dacfe Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 18 Apr 2025 05:29:36 +0800 Subject: [PATCH] feat: Enhance KSamplerAdvancedExtractor to include additional sampling parameters and update metadata processing --- py/metadata_collector/metadata_processor.py | 24 ++++++++-- py/metadata_collector/node_extractors.py | 53 +++++++++++++++++++-- 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index 9cf2cb83..4cf72b73 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -11,15 +11,28 @@ class MetadataProcessor: primary_sampler = None primary_sampler_id = None + # First, check for KSamplerAdvanced with add_noise="enable" for node_id, sampler_info in metadata.get(SAMPLING, {}).items(): parameters = sampler_info.get("parameters", {}) - denoise = parameters.get("denoise") + add_noise = parameters.get("add_noise") - # If denoise is 1.0, this is likely the primary sampler - if denoise == 1.0 or denoise == 1: + # If add_noise is "enable", this is likely the primary sampler for KSamplerAdvanced + if add_noise == "enable": primary_sampler = sampler_info primary_sampler_id = node_id break + + # If no KSamplerAdvanced found, fall back to traditional KSampler with denoise=1 + if primary_sampler is None: + for node_id, sampler_info in metadata.get(SAMPLING, {}).items(): + parameters = sampler_info.get("parameters", {}) + denoise = parameters.get("denoise") + + # If denoise is 1.0, this is likely the primary sampler + if denoise == 1.0 or denoise == 1: + primary_sampler = sampler_info + primary_sampler_id = node_id + break return primary_sampler_id, primary_sampler @@ -109,6 +122,7 @@ class MetadataProcessor: "cfg_scale": None, "guidance": None, # Add guidance parameter "sampler": None, + "scheduler": None, "checkpoint": None, "loras": "", "size": None, @@ -129,10 +143,12 @@ class MetadataProcessor: if primary_sampler: # Extract sampling parameters sampling_params = primary_sampler.get("parameters", {}) - params["seed"] = sampling_params.get("seed") + # Handle both seed and noise_seed + params["seed"] = sampling_params.get("seed") if sampling_params.get("seed") is not None else sampling_params.get("noise_seed") params["steps"] = sampling_params.get("steps") params["cfg_scale"] = sampling_params.get("cfg") params["sampler"] = sampling_params.get("sampler_name") + params["scheduler"] = sampling_params.get("scheduler") # Trace connections from the primary sampler if prompt and primary_sampler_id: diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index bdda89ad..210ab29e 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -85,6 +85,43 @@ class SamplerExtractor(NodeMetadataExtractor): "node_id": node_id } +class KSamplerAdvancedExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + sampling_params = {} + for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]: + if key in inputs: + sampling_params[key] = inputs[key] + + metadata[SAMPLING][node_id] = { + "parameters": sampling_params, + "node_id": node_id + } + + # Extract latent image dimensions if available + if "latent_image" in inputs and inputs["latent_image"] is not None: + latent = inputs["latent_image"] + if isinstance(latent, dict) and "samples" in latent: + # Extract dimensions from latent tensor + samples = latent["samples"] + if hasattr(samples, "shape") and len(samples.shape) >= 3: + # Correct shape interpretation: [batch_size, channels, height/8, width/8] + # Multiply by 8 to get actual pixel dimensions + height = int(samples.shape[2] * 8) + width = int(samples.shape[3] * 8) + + if SIZE not in metadata: + metadata[SIZE] = {} + + metadata[SIZE][node_id] = { + "width": width, + "height": height, + "node_id": node_id + } + class LoraLoaderExtractor(NodeMetadataExtractor): @staticmethod def extract(node_id, inputs, outputs, metadata): @@ -201,14 +238,20 @@ class UNETLoaderExtractor(NodeMetadataExtractor): # Registry of node-specific extractors NODE_EXTRACTORS = { - "CheckpointLoaderSimple": CheckpointLoaderExtractor, - "CLIPTextEncode": CLIPTextEncodeExtractor, + # Sampling "KSampler": SamplerExtractor, - "LoraLoader": LoraLoaderExtractor, - "EmptyLatentImage": ImageSizeExtractor, - "LoraManagerLoader": LoraLoaderManagerExtractor, + "KSamplerAdvanced": KSamplerAdvancedExtractor, # Add KSamplerAdvanced "SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced + # Loaders + "CheckpointLoaderSimple": CheckpointLoaderExtractor, "UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor + "LoraLoader": LoraLoaderExtractor, + "LoraManagerLoader": LoraLoaderManagerExtractor, + # Conditioning + "CLIPTextEncode": CLIPTextEncodeExtractor, + # Latent + "EmptyLatentImage": ImageSizeExtractor, + # Flux "FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance # Add other nodes as needed }