From 4889955ecf33dba26d0de692e05098a2c2c2d6c8 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 20 Jun 2025 00:04:02 +0800 Subject: [PATCH] feat: Add conditioning matching to prompts and update metadata handling in node extractors. See #235 --- py/metadata_collector/metadata_processor.py | 48 ++++++++++++++++++++- py/metadata_collector/node_extractors.py | 39 +++++++++++++++++ 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index 86df086a..e4abcc26 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -209,6 +209,44 @@ class MetadataProcessor: return None + @staticmethod + def match_conditioning_to_prompts(metadata, sampler_id): + """ + Match conditioning objects from a sampler to prompts in metadata + + Parameters: + - metadata: The workflow metadata + - sampler_id: ID of the sampler node to match + + Returns: + - Dictionary with 'prompt' and 'negative_prompt' if found + """ + result = { + "prompt": "", + "negative_prompt": "" + } + + # Check if we have stored conditioning objects for this sampler + if sampler_id in metadata.get(PROMPTS, {}) and ( + "pos_conditioning" in metadata[PROMPTS][sampler_id] or + "neg_conditioning" in metadata[PROMPTS][sampler_id]): + + pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning") + neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning") + + # Try to match conditioning objects with those stored by CLIPTextEncodeExtractor + for prompt_node_id, prompt_data in metadata[PROMPTS].items(): + if "conditioning" not in prompt_data: + continue + + if pos_conditioning is not None and id(prompt_data["conditioning"]) == id(pos_conditioning): + result["prompt"] = prompt_data.get("text", "") + + if neg_conditioning is not None and id(prompt_data["conditioning"]) == id(neg_conditioning): + result["negative_prompt"] = prompt_data.get("text", "") + + return result + @staticmethod def extract_generation_params(metadata, id=None): """ @@ -261,8 +299,14 @@ class MetadataProcessor: params["sampler"] = sampling_params.get("sampler_name") params["scheduler"] = sampling_params.get("scheduler") - # Trace connections from the primary sampler - if prompt and primary_sampler_id: + # First try to match conditioning objects to prompts (new method) + if primary_sampler_id: + prompt_results = MetadataProcessor.match_conditioning_to_prompts(metadata, primary_sampler_id) + params["prompt"] = prompt_results["prompt"] + params["negative_prompt"] = prompt_results["negative_prompt"] + + # If prompts were not found by object matching, fall back to tracing connections + if not params["prompt"] and prompt and primary_sampler_id: # Check if this is a SamplerCustomAdvanced node is_custom_advanced = False if prompt.original_prompt and primary_sampler_id in prompt.original_prompt: diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 8e37cd00..a9608865 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -47,6 +47,13 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor): "text": text, "node_id": node_id } + + @staticmethod + def update(node_id, outputs, metadata): + if outputs and isinstance(outputs, list) and len(outputs) > 0: + if isinstance(outputs[0], tuple) and len(outputs[0]) > 0: + conditioning = outputs[0][0] + metadata[PROMPTS][node_id]["conditioning"] = conditioning class SamplerExtractor(NodeMetadataExtractor): @staticmethod @@ -64,6 +71,18 @@ class SamplerExtractor(NodeMetadataExtractor): "node_id": node_id, IS_SAMPLER: True # Add sampler flag } + + # Store the conditioning objects directly in metadata for later matching + pos_conditioning = inputs.get("positive", None) + neg_conditioning = inputs.get("negative", None) + + # Save conditioning objects in metadata for later matching + if pos_conditioning is not None or neg_conditioning is not None: + if node_id not in metadata[PROMPTS]: + metadata[PROMPTS][node_id] = {"node_id": node_id} + + metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning + metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning # Extract latent image dimensions if available if "latent_image" in inputs and inputs["latent_image"] is not None: @@ -103,6 +122,18 @@ class KSamplerAdvancedExtractor(NodeMetadataExtractor): IS_SAMPLER: True # Add sampler flag } + # Store the conditioning objects directly in metadata for later matching + pos_conditioning = inputs.get("positive", None) + neg_conditioning = inputs.get("negative", None) + + # Save conditioning objects in metadata for later matching + if pos_conditioning is not None or neg_conditioning is not None: + if node_id not in metadata[PROMPTS]: + metadata[PROMPTS][node_id] = {"node_id": node_id} + + metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning + metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning + # Extract latent image dimensions if available if "latent_image" in inputs and inputs["latent_image"] is not None: latent = inputs["latent_image"] @@ -376,6 +407,13 @@ class CLIPTextEncodeFluxExtractor(NodeMetadataExtractor): metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value + @staticmethod + def update(node_id, outputs, metadata): + if outputs and isinstance(outputs, list) and len(outputs) > 0: + if isinstance(outputs[0], tuple) and len(outputs[0]) > 0: + conditioning = outputs[0][0] + metadata[PROMPTS][node_id]["conditioning"] = conditioning + class CFGGuiderExtractor(NodeMetadataExtractor): @staticmethod def extract(node_id, inputs, outputs, metadata): @@ -404,6 +442,7 @@ NODE_EXTRACTORS = { "BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler # Loaders "CheckpointLoaderSimple": CheckpointLoaderExtractor, + "comfyLoader": CheckpointLoaderExtractor, # eeasy comfyLoader "UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor "UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor "LoraLoader": LoraLoaderExtractor,