diff --git a/py/metadata_collector/constants.py b/py/metadata_collector/constants.py
index 9a3ba95f..b38f010a 100644
--- a/py/metadata_collector/constants.py
+++ b/py/metadata_collector/constants.py
@@ -1,7 +1,5 @@
 """Constants used by the metadata collector"""

-# Metadata collection constants
-
 # Metadata categories
 MODELS = "models"
 PROMPTS = "prompts"
@@ -9,6 +7,7 @@ SAMPLING = "sampling"
 LORAS = "loras"
 SIZE = "size"
 IMAGES = "images"
+IS_SAMPLER = "is_sampler"  # New constant to mark sampler nodes

 # Complete list of categories to track
 METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]
diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py
index 48b1d246..eef0fb0a 100644
--- a/py/metadata_collector/metadata_processor.py
+++ b/py/metadata_collector/metadata_processor.py
@@ -4,33 +4,109 @@ import sys
 # Check if running in standalone mode
 standalone_mode = 'nodes' not in sys.modules

-from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE
+from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IS_SAMPLER

 class MetadataProcessor:
     """Process and format collected metadata"""

     @staticmethod
-    def find_primary_sampler(metadata):
-        """Find the primary KSampler node (with highest denoise value)"""
+    def find_primary_sampler(metadata, downstream_id=None):
+        """
+        Find the primary KSampler node that executed before the given downstream node
+
+        Parameters:
+        - metadata: The workflow metadata
+        - downstream_id: Optional ID of a downstream node to help identify the specific primary sampler
+        """
+        # If we have a downstream_id and execution_order, use it to narrow down potential samplers
+        if downstream_id and "execution_order" in metadata:
+            execution_order = metadata["execution_order"]
+
+            # Find the index of the downstream node in the execution order
+            if downstream_id in execution_order:
+                downstream_index = execution_order.index(downstream_id)
+
+                # Extract all sampler nodes that executed before the downstream node
+                candidate_samplers = {}
+                for i in range(downstream_index):
+                    node_id = execution_order[i]
+                    # Use IS_SAMPLER flag to identify true sampler nodes
+                    if node_id in metadata.get(SAMPLING, {}) and metadata[SAMPLING][node_id].get(IS_SAMPLER, False):
+                        candidate_samplers[node_id] = metadata[SAMPLING][node_id]
+
+                # If we found candidate samplers, apply primary sampler logic to these candidates only
+                if candidate_samplers:
+                    # Collect potential primary samplers based on different criteria
+                    custom_advanced_samplers = []
+                    advanced_add_noise_samplers = []
+                    high_denoise_samplers = []
+                    max_denoise = -1
+                    high_denoise_id = None
+
+                    # First, check for SamplerCustomAdvanced among candidates
+                    prompt = metadata.get("current_prompt")
+                    if prompt and prompt.original_prompt:
+                        for node_id in candidate_samplers:
+                            node_info = prompt.original_prompt.get(node_id, {})
+                            if node_info.get("class_type") == "SamplerCustomAdvanced":
+                                custom_advanced_samplers.append(node_id)
+
+                    # Next, check for KSamplerAdvanced with add_noise="enable" among candidates
+                    for node_id, sampler_info in candidate_samplers.items():
+                        parameters = sampler_info.get("parameters", {})
+                        add_noise = parameters.get("add_noise")
+                        if add_noise == "enable":
+                            advanced_add_noise_samplers.append(node_id)
+
+                    # Find the sampler with highest denoise value among candidates
+                    for node_id, sampler_info in candidate_samplers.items():
+                        parameters = sampler_info.get("parameters", {})
+                        denoise = parameters.get("denoise")
+                        if denoise is not None and denoise > max_denoise:
+                            max_denoise = denoise
+                            high_denoise_id = node_id
+
+                    if high_denoise_id:
+                        high_denoise_samplers.append(high_denoise_id)
+
+                    # Combine all potential primary samplers
+                    potential_samplers = custom_advanced_samplers + advanced_add_noise_samplers + high_denoise_samplers
+
+                    # Find the most recent potential primary sampler (closest to downstream node)
+                    for i in range(downstream_index - 1, -1, -1):
+                        node_id = execution_order[i]
+                        if node_id in potential_samplers:
+                            return node_id, candidate_samplers[node_id]
+
+                    # If no potential sampler found from our criteria, return the most recent sampler
+                    if candidate_samplers:
+                        for i in range(downstream_index - 1, -1, -1):
+                            node_id = execution_order[i]
+                            if node_id in candidate_samplers:
+                                return node_id, candidate_samplers[node_id]
+
+        # If no downstream_id provided or no suitable sampler found, fall back to original logic
         primary_sampler = None
         primary_sampler_id = None
-        max_denoise = -1  # Track the highest denoise value
+        max_denoise = -1

         # First, check for SamplerCustomAdvanced
         prompt = metadata.get("current_prompt")
         if prompt and prompt.original_prompt:
             for node_id, node_info in prompt.original_prompt.items():
                 if node_info.get("class_type") == "SamplerCustomAdvanced":
-                    # Found a SamplerCustomAdvanced node
-                    if node_id in metadata.get(SAMPLING, {}):
+                    # Check if the node is in SAMPLING and has IS_SAMPLER flag
+                    if node_id in metadata.get(SAMPLING, {}) and metadata[SAMPLING][node_id].get(IS_SAMPLER, False):
                         return node_id, metadata[SAMPLING][node_id]

-        # Next, check for KSamplerAdvanced with add_noise="enable"
+        # Next, check for KSamplerAdvanced with add_noise="enable" using IS_SAMPLER flag
         for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
+            # Skip if not marked as a sampler
+            if not sampler_info.get(IS_SAMPLER, False):
+                continue
+
             parameters = sampler_info.get("parameters", {})
             add_noise = parameters.get("add_noise")
-
-            # If add_noise is "enable", this is likely the primary sampler for KSamplerAdvanced
             if add_noise == "enable":
                 primary_sampler = sampler_info
                 primary_sampler_id = node_id
@@ -39,10 +115,12 @@ class MetadataProcessor:
         # If no specialized sampler found, find the sampler with highest denoise value
         if primary_sampler is None:
             for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
+                # Skip if not marked as a sampler
+                if not sampler_info.get(IS_SAMPLER, False):
+                    continue
+
                 parameters = sampler_info.get("parameters", {})
                 denoise = parameters.get("denoise")
-
-                # If denoise exists and is higher than current max, use this sampler
                 if denoise is not None and denoise > max_denoise:
                     max_denoise = denoise
                     primary_sampler = sampler_info
@@ -74,13 +152,18 @@ class MetadataProcessor:
         current_node_id = node_id
         current_input = input_name

+        # If we're just tracing to origin (no target_class), keep track of the last valid node
+        last_valid_node = None
+
         while current_depth < max_depth:
             if current_node_id not in prompt.original_prompt:
-                return None
+                return last_valid_node if not target_class else None

             node_inputs = prompt.original_prompt[current_node_id].get("inputs", {})
             if current_input not in node_inputs:
-                return None
+                # We've reached a node without the specified input - this is our origin node
+                # if we're not looking for a specific target_class
+                return current_node_id if not target_class else None

             input_value = node_inputs[current_input]
             # Input connections are formatted as [node_id, output_index]
@@ -91,9 +174,9 @@ class MetadataProcessor:
                 if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class:
                     return found_node_id

-                # If we're not looking for a specific class or haven't found it yet
+                # If we're not looking for a specific class, update the last valid node
                 if not target_class:
-                    return found_node_id
+                    last_valid_node = found_node_id

                 # Continue tracing through intermediate nodes
                 current_node_id = found_node_id
@@ -101,16 +184,17 @@ class MetadataProcessor:
                 if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}):
                     current_input = "conditioning"
                 else:
-                    # If there's no "conditioning" input, we can't trace further
+                    # If there's no "conditioning" input, return the current node
+                    # if we're not looking for a specific target_class
                     return found_node_id if not target_class else None
             else:
                 # We've reached a node with no further connections
-                return None
+                return last_valid_node if not target_class else None

             current_depth += 1

         # If we've reached max depth without finding target_class
-        return None
+        return last_valid_node if not target_class else None

     @staticmethod
     def find_primary_checkpoint(metadata):
@@ -126,8 +210,14 @@ class MetadataProcessor:
         return None

     @staticmethod
-    def extract_generation_params(metadata):
-        """Extract generation parameters from metadata using node relationships"""
+    def extract_generation_params(metadata, id=None):
+        """
+        Extract generation parameters from metadata using node relationships
+
+        Parameters:
+        - metadata: The workflow metadata
+        - id: Optional ID of a downstream node to help identify the specific primary sampler
+        """
         params = {
             "prompt": "",
             "negative_prompt": "",
@@ -147,13 +237,21 @@
         prompt = metadata.get("current_prompt")

         # Find the primary KSampler node
-        primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata)
+        primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id)
+        print(f"Primary sampler ID: {primary_sampler_id}, downstream ID: {id}")

         # Directly get checkpoint from metadata instead of tracing
         checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
         if checkpoint:
             params["checkpoint"] = checkpoint

+        # Check if guidance parameter exists in any sampling node
+        for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
+            parameters = sampler_info.get("parameters", {})
+            if "guidance" in parameters and parameters["guidance"] is not None:
+                params["guidance"] = parameters["guidance"]
+                break
+
         if primary_sampler:
             # Extract sampling parameters
             sampling_params = primary_sampler.get("parameters", {})
@@ -187,7 +285,7 @@ class MetadataProcessor:
                     sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
                     params["sampler"] = sampler_params.get("sampler_name")

-                # 3. Trace guider input for CFGGuider, FluxGuidance and CLIPTextEncode
+                # 3. Trace guider input for CFGGuider and CLIPTextEncode
                 guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
                 if guider_node_id and guider_node_id in prompt.original_prompt:
                     # Check if the guider node is a CFGGuider
@@ -207,21 +305,14 @@ class MetadataProcessor:
                         if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
                             params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
                     else:
-                        # Look for FluxGuidance along the guider path
-                        flux_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", "FluxGuidance", max_depth=5)
-                        if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
-                            flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
-                            params["guidance"] = flux_params.get("guidance")
-
-                        # Find CLIPTextEncode for positive prompt (through conditioning)
-                        positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", "CLIPTextEncode", max_depth=10)
+                        positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
                         if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
                             params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")

         else:
             # Original tracing for standard samplers
             # Trace positive prompt - look specifically for CLIPTextEncode
-            positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10)
+            positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", max_depth=10)
             if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
                 params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
             else:
@@ -229,21 +320,9 @@ class MetadataProcessor:
                 positive_flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncodeFlux", max_depth=10)
                 if positive_flux_node_id and positive_flux_node_id in metadata.get(PROMPTS, {}):
                     params["prompt"] = metadata[PROMPTS][positive_flux_node_id].get("text", "")
-
-                    # Also extract guidance value if present in the sampling data
-                    if positive_flux_node_id in metadata.get(SAMPLING, {}):
-                        flux_params = metadata[SAMPLING][positive_flux_node_id].get("parameters", {})
-                        if "guidance" in flux_params:
-                            params["guidance"] = flux_params.get("guidance")
-
-            # Find any FluxGuidance nodes in the positive conditioning path
-            flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5)
-            if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
-                flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
-                params["guidance"] = flux_params.get("guidance")

             # Trace negative prompt - look specifically for CLIPTextEncode
-            negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10)
+            negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10)
             if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
                 params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")

@@ -273,13 +352,19 @@ class MetadataProcessor:
         return params

     @staticmethod
-    def to_dict(metadata):
-        """Convert extracted metadata to the ComfyUI output.json format"""
+    def to_dict(metadata, id=None):
+        """
+        Convert extracted metadata to the ComfyUI output.json format
+
+        Parameters:
+        - metadata: The workflow metadata
+        - id: Optional ID of a downstream node to help identify the specific primary sampler
+        """
         if standalone_mode:
             # Return empty dictionary in standalone mode
             return {}

-        params = MetadataProcessor.extract_generation_params(metadata)
+        params = MetadataProcessor.extract_generation_params(metadata, id)

         # Convert all values to strings to match output.json format
         for key in params:
@@ -289,7 +374,7 @@ class MetadataProcessor:
         return params

     @staticmethod
-    def to_json(metadata):
+    def to_json(metadata, id=None):
         """Convert metadata to JSON string"""
-        params = MetadataProcessor.to_dict(metadata)
+        params = MetadataProcessor.to_dict(metadata, id)
         return json.dumps(params, indent=4)
diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py
index 7abc50e1..8e37cd00 100644
--- a/py/metadata_collector/node_extractors.py
+++ b/py/metadata_collector/node_extractors.py
@@ -1,6 +1,6 @@
 import os

-from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES
+from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES, IS_SAMPLER


 class NodeMetadataExtractor:
@@ -61,7 +61,8 @@ class SamplerExtractor(NodeMetadataExtractor):

         metadata[SAMPLING][node_id] = {
             "parameters": sampling_params,
-            "node_id": node_id
+            "node_id": node_id,
+            IS_SAMPLER: True  # Add sampler flag
         }

         # Extract latent image dimensions if available
@@ -98,7 +99,8 @@ class KSamplerAdvancedExtractor(NodeMetadataExtractor):

         metadata[SAMPLING][node_id] = {
             "parameters": sampling_params,
-            "node_id": node_id
+            "node_id": node_id,
+            IS_SAMPLER: True  # Add sampler flag
         }

         # Extract latent image dimensions if available
@@ -269,7 +271,8 @@ class KSamplerSelectExtractor(NodeMetadataExtractor):

         metadata[SAMPLING][node_id] = {
             "parameters": sampling_params,
-            "node_id": node_id
+            "node_id": node_id,
+            IS_SAMPLER: False  # Mark as non-primary sampler
         }

 class BasicSchedulerExtractor(NodeMetadataExtractor):
@@ -285,7 +288,8 @@ class BasicSchedulerExtractor(NodeMetadataExtractor):

         metadata[SAMPLING][node_id] = {
             "parameters": sampling_params,
-            "node_id": node_id
+            "node_id": node_id,
+            IS_SAMPLER: False  # Mark as non-primary sampler
         }

 class SamplerCustomAdvancedExtractor(NodeMetadataExtractor):
@@ -303,7 +307,8 @@ class SamplerCustomAdvancedExtractor(NodeMetadataExtractor):

         metadata[SAMPLING][node_id] = {
             "parameters": sampling_params,
-            "node_id": node_id
+            "node_id": node_id,
+            IS_SAMPLER: True  # Add sampler flag
         }

         # Extract latent image dimensions if available
@@ -338,11 +343,20 @@ class CLIPTextEncodeFluxExtractor(NodeMetadataExtractor):
         clip_l_text = inputs.get("clip_l", "")
         t5xxl_text = inputs.get("t5xxl", "")

-        # Create JSON string with T5 content first, then CLIP-L
-        combined_text = json.dumps({
-            "T5": t5xxl_text,
-            "CLIP-L": clip_l_text
-        })
+        # If both are empty, use empty string
+        if not clip_l_text and not t5xxl_text:
+            combined_text = ""
+        # If one is empty, use the non-empty one
+        elif not clip_l_text:
+            combined_text = t5xxl_text
+        elif not t5xxl_text:
+            combined_text = clip_l_text
+        # If both have content, use JSON format
+        else:
+            combined_text = json.dumps({
+                "T5": t5xxl_text,
+                "CLIP-L": clip_l_text
+            })

         metadata[PROMPTS][node_id] = {
             "text": combined_text,
@@ -391,11 +405,13 @@ NODE_EXTRACTORS = {
     # Loaders
     "CheckpointLoaderSimple": CheckpointLoaderExtractor,
     "UNETLoader": UNETLoaderExtractor,  # Updated to use dedicated extractor
+    "UnetLoaderGGUF": UNETLoaderExtractor,  # Updated to use dedicated extractor
     "LoraLoader": LoraLoaderExtractor,
     "LoraManagerLoader": LoraLoaderManagerExtractor,
     # Conditioning
     "CLIPTextEncode": CLIPTextEncodeExtractor,
     "CLIPTextEncodeFlux": CLIPTextEncodeFluxExtractor,  # Add CLIPTextEncodeFlux
+    "WAS_Text_to_Conditioning": CLIPTextEncodeExtractor,
     # Latent
     "EmptyLatentImage": ImageSizeExtractor,
     # Flux
diff --git a/py/nodes/debug_metadata.py b/py/nodes/debug_metadata.py
index ee13e3d8..839d1431 100644
--- a/py/nodes/debug_metadata.py
+++ b/py/nodes/debug_metadata.py
@@ -14,20 +14,23 @@ class DebugMetadata:
             "required": {
                 "images": ("IMAGE",),
             },
+            "hidden": {
+                "id": "UNIQUE_ID",
+            },
         }

     RETURN_TYPES = ("STRING",)
     RETURN_NAMES = ("metadata_json",)
     FUNCTION = "process_metadata"

-    def process_metadata(self, images):
+    def process_metadata(self, images, id):
         try:
             # Get the current execution context's metadata
             from ..metadata_collector import get_metadata
             metadata = get_metadata()

             # Use the MetadataProcessor to convert it to JSON string
-            metadata_json = MetadataProcessor.to_json(metadata)
+            metadata_json = MetadataProcessor.to_json(metadata, id)

             return (metadata_json,)
         except Exception as e:
diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py
index 50bdf226..69a68847 100644
--- a/py/nodes/save_image.py
+++ b/py/nodes/save_image.py
@@ -41,6 +41,7 @@ class SaveImage:
                 "add_counter_to_filename": ("BOOLEAN", {"default": True}),
             },
             "hidden": {
+                "id": "UNIQUE_ID",
                 "prompt": "PROMPT",
                 "extra_pnginfo": "EXTRA_PNGINFO",
             },
@@ -300,14 +301,14 @@ class SaveImage:
         return filename

-    def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None,
+    def save_images(self, images, filename_prefix, file_format, id, prompt=None, extra_pnginfo=None,
                     lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
         """Save images with metadata"""

         results = []
-        
+
         # Get metadata using the metadata collector
         raw_metadata = get_metadata()
-        metadata_dict = MetadataProcessor.to_dict(raw_metadata)
+        metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)

         # Get or create metadata asynchronously
         metadata = asyncio.run(self.format_metadata(metadata_dict))
@@ -399,7 +400,7 @@ class SaveImage:

         return results

-    def process_image(self, images, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
+    def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
                       lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
         """Process and save image with metadata"""
         # Make sure the output directory exists
@@ -416,6 +417,7 @@ class SaveImage:
             images,
             filename_prefix,
             file_format,
+            id,
             prompt,
             extra_pnginfo,
             lossless_webp,