diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index 4cf72b73..60e5b220 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -11,7 +11,16 @@ class MetadataProcessor: primary_sampler = None primary_sampler_id = None - # First, check for KSamplerAdvanced with add_noise="enable" + # First, check for SamplerCustomAdvanced + prompt = metadata.get("current_prompt") + if prompt and prompt.original_prompt: + for node_id, node_info in prompt.original_prompt.items(): + if node_info.get("class_type") == "SamplerCustomAdvanced": + # Found a SamplerCustomAdvanced node + if node_id in metadata.get(SAMPLING, {}): + return node_id, metadata[SAMPLING][node_id] + + # Next, check for KSamplerAdvanced with add_noise="enable" for node_id, sampler_info in metadata.get(SAMPLING, {}).items(): parameters = sampler_info.get("parameters", {}) add_noise = parameters.get("add_noise") @@ -22,7 +31,7 @@ class MetadataProcessor: primary_sampler_id = node_id break - # If no KSamplerAdvanced found, fall back to traditional KSampler with denoise=1 + # If no specialized sampler found, fall back to traditional KSampler with denoise=1 if primary_sampler is None: for node_id, sampler_info in metadata.get(SAMPLING, {}).items(): parameters = sampler_info.get("parameters", {}) @@ -152,22 +161,60 @@ class MetadataProcessor: # Trace connections from the primary sampler if prompt and primary_sampler_id: - # Trace positive prompt - look specifically for CLIPTextEncode - positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10) - if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}): - params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") + # Check if this is a SamplerCustomAdvanced node + is_custom_advanced = False + if prompt.original_prompt and primary_sampler_id in prompt.original_prompt: + is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced" - # Find any FluxGuidance nodes in the positive conditioning path - flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5) - if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}): - flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {}) - params["guidance"] = flux_params.get("guidance") + if is_custom_advanced: + # For SamplerCustomAdvanced, trace specific inputs + + # 1. Trace sigmas input to find BasicScheduler + scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", "BasicScheduler", max_depth=5) + if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}): + scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {}) + params["steps"] = scheduler_params.get("steps") + params["scheduler"] = scheduler_params.get("scheduler") + + # 2. Trace sampler input to find KSamplerSelect + sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5) + if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}): + sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {}) + params["sampler"] = sampler_params.get("sampler_name") + + # 3. Trace guider input for FluxGuidance and CLIPTextEncode + guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5) + if guider_node_id: + # Look for FluxGuidance along the guider path + flux_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", "FluxGuidance", max_depth=5) + if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}): + flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {}) + params["guidance"] = flux_params.get("guidance") + + # Find CLIPTextEncode for positive prompt (through conditioning) + positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", "CLIPTextEncode", max_depth=10) + if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}): + params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") - # Trace negative prompt - look specifically for CLIPTextEncode - negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10) - if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}): - params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "") + else: + # Original tracing for standard samplers + # Trace positive prompt - look specifically for CLIPTextEncode + positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10) + if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}): + params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") + + # Find any FluxGuidance nodes in the positive conditioning path + flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5) + if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}): + flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {}) + params["guidance"] = flux_params.get("guidance") + + # Trace negative prompt - look specifically for CLIPTextEncode + negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10) + if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}): + params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "") + # Size extraction is same for all sampler types # Check if the sampler itself has size information (from latent_image) if primary_sampler_id in metadata.get(SIZE, {}): width = metadata[SIZE][primary_sampler_id].get("width") diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 64dda557..adf679c5 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -257,12 +257,85 @@ class VAEDecodeExtractor(NodeMetadataExtractor): if "first_decode" not in metadata[IMAGES]: metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id] +class KSamplerSelectExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "sampler_name" not in inputs: + return + + sampling_params = {} + if "sampler_name" in inputs: + sampling_params["sampler_name"] = inputs["sampler_name"] + + metadata[SAMPLING][node_id] = { + "parameters": sampling_params, + "node_id": node_id + } + +class BasicSchedulerExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + sampling_params = {} + for key in ["scheduler", "steps", "denoise"]: + if key in inputs: + sampling_params[key] = inputs[key] + + metadata[SAMPLING][node_id] = { + "parameters": sampling_params, + "node_id": node_id + } + +class SamplerCustomAdvancedExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + sampling_params = {} + + # Handle noise.seed as seed + if "noise" in inputs and inputs["noise"] is not None and hasattr(inputs["noise"], "seed"): + noise = inputs["noise"] + sampling_params["seed"] = noise.seed + + metadata[SAMPLING][node_id] = { + "parameters": sampling_params, + "node_id": node_id + } + + # Extract latent image dimensions if available + if "latent_image" in inputs and inputs["latent_image"] is not None: + latent = inputs["latent_image"] + if isinstance(latent, dict) and "samples" in latent: + # Extract dimensions from latent tensor + samples = latent["samples"] + if hasattr(samples, "shape") and len(samples.shape) >= 3: + # Correct shape interpretation: [batch_size, channels, height/8, width/8] + # Multiply by 8 to get actual pixel dimensions + height = int(samples.shape[2] * 8) + width = int(samples.shape[3] * 8) + + if SIZE not in metadata: + metadata[SIZE] = {} + + metadata[SIZE][node_id] = { + "width": width, + "height": height, + "node_id": node_id + } + # Registry of node-specific extractors NODE_EXTRACTORS = { # Sampling "KSampler": SamplerExtractor, - "KSamplerAdvanced": KSamplerAdvancedExtractor, # Add KSamplerAdvanced - "SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced + "KSamplerAdvanced": KSamplerAdvancedExtractor, + "SamplerCustomAdvanced": SamplerCustomAdvancedExtractor, # Updated to use dedicated extractor + # Sampling Selectors + "KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect + "BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler # Loaders "CheckpointLoaderSimple": CheckpointLoaderExtractor, "UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor