From 8ed38527d061ff3d97d470ffb1bec247934ab514 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 16 Apr 2025 11:44:14 +0800 Subject: [PATCH 01/15] feat: Implement metadata collection and processing framework with debug node for verification --- __init__.py | 9 +- py/metadata_collector/__init__.py | 18 +++ py/metadata_collector/metadata_hook.py | 123 ++++++++++++++ py/metadata_collector/metadata_processor.py | 171 ++++++++++++++++++++ py/metadata_collector/metadata_registry.py | 171 ++++++++++++++++++++ py/metadata_collector/node_extractors.py | 163 +++++++++++++++++++ py/nodes/debug_metadata.py | 35 ++++ 7 files changed, 689 insertions(+), 1 deletion(-) create mode 100644 py/metadata_collector/__init__.py create mode 100644 py/metadata_collector/metadata_hook.py create mode 100644 py/metadata_collector/metadata_processor.py create mode 100644 py/metadata_collector/metadata_registry.py create mode 100644 py/metadata_collector/node_extractors.py create mode 100644 py/nodes/debug_metadata.py diff --git a/__init__.py b/__init__.py index 008219c8..375dfc13 100644 --- a/__init__.py +++ b/__init__.py @@ -3,16 +3,23 @@ from .py.nodes.lora_loader import LoraManagerLoader from .py.nodes.trigger_word_toggle import TriggerWordToggle from .py.nodes.lora_stacker import LoraStacker from .py.nodes.save_image import SaveImage +from .py.nodes.debug_metadata import DebugMetadata +# Import metadata collector to install hooks on startup +from .py.metadata_collector import init as init_metadata_collector NODE_CLASS_MAPPINGS = { LoraManagerLoader.NAME: LoraManagerLoader, TriggerWordToggle.NAME: TriggerWordToggle, LoraStacker.NAME: LoraStacker, - SaveImage.NAME: SaveImage + SaveImage.NAME: SaveImage, + DebugMetadata.NAME: DebugMetadata } WEB_DIRECTORY = "./web/comfyui" +# Initialize metadata collector +init_metadata_collector() + # Register routes on import LoraManager.add_routes() __all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY'] diff --git a/py/metadata_collector/__init__.py b/py/metadata_collector/__init__.py new file mode 100644 index 00000000..3fea3a6b --- /dev/null +++ b/py/metadata_collector/__init__.py @@ -0,0 +1,18 @@ +import os +import importlib +from .metadata_hook import MetadataHook +from .metadata_registry import MetadataRegistry + +def init(): + # Install hooks to collect metadata during execution + MetadataHook.install() + + # Initialize registry + registry = MetadataRegistry() + + print("ComfyUI Metadata Collector initialized") + +def get_metadata(prompt_id=None): + """Helper function to get metadata from the registry""" + registry = MetadataRegistry() + return registry.get_metadata(prompt_id) diff --git a/py/metadata_collector/metadata_hook.py b/py/metadata_collector/metadata_hook.py new file mode 100644 index 00000000..2b6a7b6d --- /dev/null +++ b/py/metadata_collector/metadata_hook.py @@ -0,0 +1,123 @@ +import sys +import inspect +from .metadata_registry import MetadataRegistry + +class MetadataHook: + """Install hooks for metadata collection""" + + @staticmethod + def install(): + """Install hooks to collect metadata during execution""" + try: + # Import ComfyUI's execution module + execution = None + try: + # Try direct import first + import execution # type: ignore + except ImportError: + # Try to locate from system modules + for module_name in sys.modules: + if module_name.endswith('.execution'): + execution = sys.modules[module_name] + break + + # If we can't find the execution module, we can't install hooks + if execution is None: + print("Could not locate ComfyUI execution module, metadata collection disabled") + return + + # Store the original _map_node_over_list function + original_map_node_over_list = execution._map_node_over_list + + # Define the wrapped _map_node_over_list function + def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): + # Only collect metadata when calling the main function of nodes + if func == obj.FUNCTION and hasattr(obj, '__class__'): + try: + # Get the current prompt_id from the registry + registry = MetadataRegistry() + prompt_id = registry.current_prompt_id + + if prompt_id is not None: + # Get node class type + class_type = obj.__class__.__name__ + + # Unique ID might be available through the obj if it has a unique_id field + node_id = getattr(obj, 'unique_id', None) + if node_id is None and pre_execute_cb: + # Try to extract node_id through reflection on GraphBuilder.set_default_prefix + frame = inspect.currentframe() + while frame: + if 'unique_id' in frame.f_locals: + node_id = frame.f_locals['unique_id'] + break + frame = frame.f_back + + # Record inputs before execution + if node_id is not None: + registry.record_node_execution(node_id, class_type, input_data_all, None) + except Exception as e: + print(f"Error collecting metadata (pre-execution): {str(e)}") + + # Execute the original function + results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb) + + # After execution, collect outputs for relevant nodes + if func == obj.FUNCTION and hasattr(obj, '__class__'): + try: + # Get the current prompt_id from the registry + registry = MetadataRegistry() + prompt_id = registry.current_prompt_id + + if prompt_id is not None: + # Get node class type + class_type = obj.__class__.__name__ + + # Unique ID might be available through the obj if it has a unique_id field + node_id = getattr(obj, 'unique_id', None) + if node_id is None and pre_execute_cb: + # Try to extract node_id through reflection + frame = inspect.currentframe() + while frame: + if 'unique_id' in frame.f_locals: + node_id = frame.f_locals['unique_id'] + break + frame = frame.f_back + + # Record outputs after execution + if node_id is not None: + registry.update_node_execution(node_id, class_type, results) + except Exception as e: + print(f"Error collecting metadata (post-execution): {str(e)}") + + return results + + # Also hook the execute function to track the current prompt_id + original_execute = execution.execute + + def execute_with_prompt_tracking(*args, **kwargs): + if len(args) >= 7: # Check if we have enough arguments + server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7] + registry = MetadataRegistry() + + # Start collection if this is a new prompt + if not registry.current_prompt_id or registry.current_prompt_id != prompt_id: + registry.start_collection(prompt_id) + + # Store the dynprompt reference for node lookups + if hasattr(prompt, 'original_prompt'): + registry.set_current_prompt(prompt) + + # Execute the original function + return original_execute(*args, **kwargs) + + # Replace the functions + execution._map_node_over_list = map_node_over_list_with_metadata + execution.execute = execute_with_prompt_tracking + # Make map_node_over_list public to avoid it being hidden by hooks + execution.map_node_over_list = original_map_node_over_list + + print("Metadata collection hooks installed for runtime values") + + except Exception as e: + print(f"Error installing metadata hooks: {str(e)}") diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py new file mode 100644 index 00000000..4721f889 --- /dev/null +++ b/py/metadata_collector/metadata_processor.py @@ -0,0 +1,171 @@ +import json + +class MetadataProcessor: + """Process and format collected metadata""" + + @staticmethod + def find_primary_sampler(metadata): + """Find the primary KSampler node (with denoise=1)""" + primary_sampler = None + primary_sampler_id = None + + for node_id, sampler_info in metadata.get("sampling", {}).items(): + parameters = sampler_info.get("parameters", {}) + denoise = parameters.get("denoise") + + # If denoise is 1.0, this is likely the primary sampler + if denoise == 1.0 or denoise == 1: + primary_sampler = sampler_info + primary_sampler_id = node_id + break + + return primary_sampler_id, primary_sampler + + @staticmethod + def trace_node_input(prompt, node_id, input_name): + """Trace an input connection from a node to find the source node""" + if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt: + return None + + node_inputs = prompt.original_prompt[node_id].get("inputs", {}) + if input_name not in node_inputs: + return None + + input_value = node_inputs[input_name] + # Input connections are formatted as [node_id, output_index] + if isinstance(input_value, list) and len(input_value) >= 2: + return input_value[0] # Return connected node_id + + return None + + @staticmethod + def find_primary_checkpoint(metadata): + """Find the primary checkpoint model in the workflow""" + if not metadata.get("models"): + return None + + # In most workflows, there's only one checkpoint, so we can just take the first one + for node_id, model_info in metadata.get("models", {}).items(): + if model_info.get("type") == "checkpoint": + return model_info.get("name") + + return None + + @staticmethod + def extract_generation_params(metadata): + """Extract generation parameters from metadata using node relationships""" + params = { + "prompt": "", + "negative_prompt": "", + "seed": None, + "steps": None, + "cfg_scale": None, + "sampler": None, + "checkpoint": None, + "loras": "", + "size": None, + "clip_skip": None + } + + # Get the prompt object for node relationship tracing + prompt = metadata.get("current_prompt") + + # Find the primary KSampler node + primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata) + + # Directly get checkpoint from metadata instead of tracing + checkpoint = MetadataProcessor.find_primary_checkpoint(metadata) + if checkpoint: + params["checkpoint"] = checkpoint + + if primary_sampler: + # Extract sampling parameters + sampling_params = primary_sampler.get("parameters", {}) + params["seed"] = sampling_params.get("seed") + params["steps"] = sampling_params.get("steps") + params["cfg_scale"] = sampling_params.get("cfg") + params["sampler"] = sampling_params.get("sampler_name") + + # Trace connections from the primary sampler + if prompt and primary_sampler_id: + # Trace positive prompt + positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive") + if positive_node_id and positive_node_id in metadata.get("prompts", {}): + params["prompt"] = metadata["prompts"][positive_node_id].get("text", "") + + # Trace negative prompt + negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative") + if negative_node_id and negative_node_id in metadata.get("prompts", {}): + params["negative_prompt"] = metadata["prompts"][negative_node_id].get("text", "") + + # Check if the sampler itself has size information (from latent_image) + if primary_sampler_id in metadata.get("size", {}): + width = metadata["size"][primary_sampler_id].get("width") + height = metadata["size"][primary_sampler_id].get("height") + if width and height: + params["size"] = f"{width}x{height}" + else: + # Fallback to the previous trace method if needed + latent_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "latent_image") + if latent_node_id: + # Follow chain to find EmptyLatentImage node + size_found = False + current_node_id = latent_node_id + + # Limit depth to avoid infinite loops in complex workflows + max_depth = 10 + for _ in range(max_depth): + if current_node_id in metadata.get("size", {}): + width = metadata["size"][current_node_id].get("width") + height = metadata["size"][current_node_id].get("height") + if width and height: + params["size"] = f"{width}x{height}" + size_found = True + break + + # Try to follow the chain + if prompt and prompt.original_prompt and current_node_id in prompt.original_prompt: + node_info = prompt.original_prompt[current_node_id] + if "inputs" in node_info: + # Look for a connection that might lead to size information + for input_name, input_value in node_info["inputs"].items(): + if isinstance(input_value, list) and len(input_value) >= 2: + current_node_id = input_value[0] + break + else: + break # No connections to follow + else: + break # No inputs to follow + else: + break # Can't follow further + + # Extract LoRAs + lora_parts = [] + for node_id, lora_info in metadata.get("loras", {}).items(): + name = lora_info.get("name", "unknown") + strength = lora_info.get("strength_model", 1.0) + lora_parts.append(f"") + params["loras"] = " ".join(lora_parts) + + # Set default clip_skip value + params["clip_skip"] = "1" # Common default + + return params + + @staticmethod + def to_comfyui_format(metadata): + """Convert extracted metadata to the ComfyUI output.json format""" + params = MetadataProcessor.extract_generation_params(metadata) + + # Convert all values to strings to match output.json format + for key in params: + if params[key] is not None: + params[key] = str(params[key]) + + return params + + @staticmethod + def to_json(metadata): + """Convert metadata to JSON string""" + params = MetadataProcessor.to_comfyui_format(metadata) + return json.dumps(params, indent=4) diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py new file mode 100644 index 00000000..cae9a1a9 --- /dev/null +++ b/py/metadata_collector/metadata_registry.py @@ -0,0 +1,171 @@ +import time +from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor + +class MetadataRegistry: + """A singleton registry to store and retrieve workflow metadata""" + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._reset() + return cls._instance + + def _reset(self): + self.current_prompt_id = None + self.current_prompt = None + self.metadata = {} + self.prompt_metadata = {} + self.executed_nodes = set() + + # Node-level cache for metadata + self.node_cache = {} + + # Categories we want to track and retrieve from cache + self.metadata_categories = ["models", "prompts", "sampling", "loras", "size"] + + def start_collection(self, prompt_id): + """Begin metadata collection for a new prompt""" + self.current_prompt_id = prompt_id + self.executed_nodes = set() + self.prompt_metadata[prompt_id] = { + "models": {}, + "prompts": {}, + "sampling": {}, + "loras": {}, + "size": {}, + "execution_order": [], + "current_prompt": None, # Will store the prompt object + "timestamp": time.time() + } + + def set_current_prompt(self, prompt): + """Set the current prompt object reference""" + self.current_prompt = prompt + if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata: + # Store the prompt in the metadata for later relationship tracing + self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt + + def get_metadata(self, prompt_id=None): + """Get collected metadata for a prompt""" + key = prompt_id if prompt_id is not None else self.current_prompt_id + if key not in self.prompt_metadata: + return {} + + metadata = self.prompt_metadata[key] + + # If we have a current prompt object, check for non-executed nodes + prompt_obj = metadata.get("current_prompt") + if prompt_obj and hasattr(prompt_obj, "original_prompt"): + original_prompt = prompt_obj.original_prompt + + # Fill in missing metadata from cache for nodes that weren't executed + self._fill_missing_metadata(key, original_prompt) + + return self.prompt_metadata.get(key, {}) + + def _fill_missing_metadata(self, prompt_id, original_prompt): + """Fill missing metadata from cache for non-executed nodes""" + if not original_prompt: + return + + executed_nodes = self.executed_nodes + metadata = self.prompt_metadata[prompt_id] + + # Iterate through nodes in the original prompt + for node_id, node_data in original_prompt.items(): + # Skip if already executed in this run + if node_id in executed_nodes: + continue + + class_type = node_data.get("class_type") + if not class_type: + continue + + # Create cache key + cache_key = f"{node_id}:{class_type}" + + # Check if this node type is relevant for metadata collection + if class_type in NODE_EXTRACTORS: + # Check if we have cached metadata for this node + if cache_key in self.node_cache: + cached_data = self.node_cache[cache_key] + + # Apply cached metadata to the current metadata + for category in self.metadata_categories: + if category in cached_data and node_id in cached_data[category]: + if node_id not in metadata[category]: + metadata[category][node_id] = cached_data[category][node_id] + + def record_node_execution(self, node_id, class_type, inputs, outputs): + """Record information about a node's execution""" + if not self.current_prompt_id: + return + + # Add to execution order and mark as executed + if node_id not in self.executed_nodes: + self.executed_nodes.add(node_id) + self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id) + + # Process inputs to simplify working with them + processed_inputs = {} + for input_name, input_values in inputs.items(): + if isinstance(input_values, list) and len(input_values) > 0: + # For single values, just use the first one (most common case) + processed_inputs[input_name] = input_values[0] + else: + processed_inputs[input_name] = input_values + + # Extract node-specific metadata + extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor) + extractor.extract( + node_id, + processed_inputs, + outputs, + self.prompt_metadata[self.current_prompt_id] + ) + + # Cache this node's metadata + self._cache_node_metadata(node_id, class_type) + + def update_node_execution(self, node_id, class_type, outputs): + """Update node metadata with output information""" + if not self.current_prompt_id: + return + + # Process outputs to make them more usable + processed_outputs = outputs + + # Use the same extractor to update with outputs + extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor) + if hasattr(extractor, 'update'): + extractor.update( + node_id, + processed_outputs, + self.prompt_metadata[self.current_prompt_id] + ) + + # Update the cached metadata for this node + self._cache_node_metadata(node_id, class_type) + + def _cache_node_metadata(self, node_id, class_type): + """Cache the metadata for a specific node""" + if not self.current_prompt_id or not node_id or not class_type: + return + + # Create a cache key combining node_id and class_type + cache_key = f"{node_id}:{class_type}" + + # Create a shallow copy of the node's metadata + node_metadata = {} + current_metadata = self.prompt_metadata[self.current_prompt_id] + + for category in self.metadata_categories: + if category in current_metadata and node_id in current_metadata[category]: + if category not in node_metadata: + node_metadata[category] = {} + node_metadata[category][node_id] = current_metadata[category][node_id] + + # Save to cache if we have any metadata for this node + if any(node_metadata.values()): + self.node_cache[cache_key] = node_metadata diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py new file mode 100644 index 00000000..78d0a81b --- /dev/null +++ b/py/metadata_collector/node_extractors.py @@ -0,0 +1,163 @@ +class NodeMetadataExtractor: + """Base class for node-specific metadata extraction""" + + @staticmethod + def extract(node_id, inputs, outputs, metadata): + """Extract metadata from node inputs/outputs""" + pass + + @staticmethod + def update(node_id, outputs, metadata): + """Update metadata with node outputs after execution""" + pass + +class GenericNodeExtractor(NodeMetadataExtractor): + """Default extractor for nodes without specific handling""" + @staticmethod + def extract(node_id, inputs, outputs, metadata): + pass + +class CheckpointLoaderExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "ckpt_name" not in inputs: + return + + model_name = inputs.get("ckpt_name") + if model_name: + metadata["models"][node_id] = { + "name": model_name, + "type": "checkpoint", + "node_id": node_id + } + +class CLIPTextEncodeExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "text" not in inputs: + return + + text = inputs.get("text", "") + metadata["prompts"][node_id] = { + "text": text, + "node_id": node_id + } + +class SamplerExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + sampling_params = {} + for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]: + if key in inputs: + sampling_params[key] = inputs[key] + + metadata["sampling"][node_id] = { + "parameters": sampling_params, + "node_id": node_id + } + + # Extract latent image dimensions if available + if "latent_image" in inputs and inputs["latent_image"] is not None: + latent = inputs["latent_image"] + if isinstance(latent, dict) and "samples" in latent: + # Extract dimensions from latent tensor + samples = latent["samples"] + if hasattr(samples, "shape") and len(samples.shape) >= 3: + # Correct shape interpretation: [batch_size, channels, height/8, width/8] + # Multiply by 8 to get actual pixel dimensions + height = int(samples.shape[2] * 8) + width = int(samples.shape[3] * 8) + + if "size" not in metadata: + metadata["size"] = {} + + metadata["size"][node_id] = { + "width": width, + "height": height, + "node_id": node_id + } + +class LoraLoaderExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "lora_name" not in inputs: + return + + lora_name = inputs.get("lora_name") + strength_model = inputs.get("strength_model", 1.0) + strength_clip = inputs.get("strength_clip", 1.0) + + metadata["loras"][node_id] = { + "name": lora_name, + "strength_model": strength_model, + "strength_clip": strength_clip, + "node_id": node_id + } + +class ImageSizeExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + width = inputs.get("width", 512) + height = inputs.get("height", 512) + + if "size" not in metadata: + metadata["size"] = {} + + metadata["size"][node_id] = { + "width": width, + "height": height, + "node_id": node_id + } + +class LoraLoaderManagerExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + # Handle LoraManager nodes which might store loras differently + if "loras" in inputs: + loras = inputs.get("loras", []) + if isinstance(loras, list): + active_loras = [] + # Filter for active loras (may be a list of dicts with 'active' flag) + for lora in loras: + if isinstance(lora, dict) and lora.get("active", True) and not lora.get("_isDummy", False): + active_loras.append({ + "name": lora.get("name", ""), + "strength": lora.get("strength", 1.0) + }) + + if active_loras: + metadata["loras"][node_id] = { + "lora_list": active_loras, + "node_id": node_id + } + + # If there's a direct text field with lora definitions + if "text" in inputs: + text = inputs.get("text", "") + if text and " Date: Wed, 16 Apr 2025 21:20:56 +0800 Subject: [PATCH 02/15] feat: Standardize LoRA extraction format and enhance input handling in node extractors --- py/metadata_collector/metadata_processor.py | 12 ++- py/metadata_collector/node_extractors.py | 86 +++++++++++++-------- 2 files changed, 61 insertions(+), 37 deletions(-) diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index 4721f889..bbd499a1 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -139,12 +139,16 @@ class MetadataProcessor: else: break # Can't follow further - # Extract LoRAs + # Extract LoRAs using the standardized format lora_parts = [] for node_id, lora_info in metadata.get("loras", {}).items(): - name = lora_info.get("name", "unknown") - strength = lora_info.get("strength_model", 1.0) - lora_parts.append(f"") + # Access the lora_list from the standardized format + lora_list = lora_info.get("lora_list", []) + for lora in lora_list: + name = lora.get("name", "unknown") + strength = lora.get("strength", 1.0) + lora_parts.append(f"") + params["loras"] = " ".join(lora_parts) # Set default clip_skip value diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 78d0a81b..a1677b31 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -1,3 +1,6 @@ +import os + + class NodeMetadataExtractor: """Base class for node-specific metadata extraction""" @@ -87,13 +90,16 @@ class LoraLoaderExtractor(NodeMetadataExtractor): return lora_name = inputs.get("lora_name") - strength_model = inputs.get("strength_model", 1.0) - strength_clip = inputs.get("strength_clip", 1.0) + strength_model = round(float(inputs.get("strength_model", 1.0)), 2) + # Use the standardized format with lora_list metadata["loras"][node_id] = { - "name": lora_name, - "strength_model": strength_model, - "strength_clip": strength_clip, + "lora_list": [ + { + "name": lora_name, + "strength": strength_model + } + ], "node_id": node_id } @@ -120,34 +126,48 @@ class LoraLoaderManagerExtractor(NodeMetadataExtractor): def extract(node_id, inputs, outputs, metadata): if not inputs: return - - # Handle LoraManager nodes which might store loras differently - if "loras" in inputs: - loras = inputs.get("loras", []) - if isinstance(loras, list): - active_loras = [] - # Filter for active loras (may be a list of dicts with 'active' flag) - for lora in loras: - if isinstance(lora, dict) and lora.get("active", True) and not lora.get("_isDummy", False): - active_loras.append({ - "name": lora.get("name", ""), - "strength": lora.get("strength", 1.0) - }) - - if active_loras: - metadata["loras"][node_id] = { - "lora_list": active_loras, - "node_id": node_id - } - # If there's a direct text field with lora definitions - if "text" in inputs: - text = inputs.get("text", "") - if text and " Date: Wed, 16 Apr 2025 21:42:54 +0800 Subject: [PATCH 03/15] refactor: Rename to_comfyui_format method to to_dict and update references in save_image.py --- py/metadata_collector/metadata_processor.py | 4 +- py/metadata_collector/node_extractors.py | 2 - py/nodes/save_image.py | 88 ++++++++++----------- 3 files changed, 45 insertions(+), 49 deletions(-) diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index bbd499a1..da55c0c7 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -157,7 +157,7 @@ class MetadataProcessor: return params @staticmethod - def to_comfyui_format(metadata): + def to_dict(metadata): """Convert extracted metadata to the ComfyUI output.json format""" params = MetadataProcessor.extract_generation_params(metadata) @@ -171,5 +171,5 @@ class MetadataProcessor: @staticmethod def to_json(metadata): """Convert metadata to JSON string""" - params = MetadataProcessor.to_comfyui_format(metadata) + params = MetadataProcessor.to_dict(metadata) return json.dumps(params, indent=4) diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index a1677b31..081dd8c1 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -167,8 +167,6 @@ class LoraLoaderManagerExtractor(NodeMetadataExtractor): "node_id": node_id } - print(f"Active LoRAs for node {node_id}: {active_loras}") - # Registry of node-specific extractors NODE_EXTRACTORS = { "CheckpointLoaderSimple": CheckpointLoaderExtractor, diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index 31c6695e..dea33a45 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -5,7 +5,8 @@ import re import numpy as np import folder_paths # type: ignore from ..services.lora_scanner import LoraScanner -from ..workflow.parser import WorkflowParser +from ..metadata_collector.metadata_processor import MetadataProcessor +from ..metadata_collector import get_metadata from PIL import Image, PngImagePlugin import piexif from io import BytesIO @@ -61,21 +62,21 @@ class SaveImage: return item.get('sha256') return None - async def format_metadata(self, parsed_workflow, custom_prompt=None): + async def format_metadata(self, metadata_dict, custom_prompt=None): """Format metadata in the requested format similar to userComment example""" - if not parsed_workflow: + if not metadata_dict: return "" # Extract the prompt and negative prompt - prompt = parsed_workflow.get('prompt', '') - negative_prompt = parsed_workflow.get('negative_prompt', '') + prompt = metadata_dict.get('prompt', '') + negative_prompt = metadata_dict.get('negative_prompt', '') # Override prompt with custom_prompt if provided if custom_prompt: prompt = custom_prompt # Extract loras from the prompt if present - loras_text = parsed_workflow.get('loras', '') + loras_text = metadata_dict.get('loras', '') lora_hashes = {} # If loras are found, add them on a new line after the prompt @@ -104,11 +105,11 @@ class SaveImage: params = [] # Add standard parameters in the correct order - if 'steps' in parsed_workflow: - params.append(f"Steps: {parsed_workflow.get('steps')}") + if 'steps' in metadata_dict: + params.append(f"Steps: {metadata_dict.get('steps')}") - if 'sampler' in parsed_workflow: - sampler = parsed_workflow.get('sampler') + if 'sampler' in metadata_dict: + sampler = metadata_dict.get('sampler') # Convert ComfyUI sampler names to user-friendly names sampler_mapping = { 'euler': 'Euler', @@ -130,8 +131,8 @@ class SaveImage: sampler_name = sampler_mapping.get(sampler, sampler) params.append(f"Sampler: {sampler_name}") - if 'scheduler' in parsed_workflow: - scheduler = parsed_workflow.get('scheduler') + if 'scheduler' in metadata_dict: + scheduler = metadata_dict.get('scheduler') scheduler_mapping = { 'normal': 'Simple', 'karras': 'Karras', @@ -142,24 +143,24 @@ class SaveImage: scheduler_name = scheduler_mapping.get(scheduler, scheduler) params.append(f"Schedule type: {scheduler_name}") - # CFG scale (cfg in parsed_workflow) - if 'cfg_scale' in parsed_workflow: - params.append(f"CFG scale: {parsed_workflow.get('cfg_scale')}") - elif 'cfg' in parsed_workflow: - params.append(f"CFG scale: {parsed_workflow.get('cfg')}") + # CFG scale (cfg_scale in metadata_dict) + if 'cfg_scale' in metadata_dict: + params.append(f"CFG scale: {metadata_dict.get('cfg_scale')}") + elif 'cfg' in metadata_dict: + params.append(f"CFG scale: {metadata_dict.get('cfg')}") # Seed - if 'seed' in parsed_workflow: - params.append(f"Seed: {parsed_workflow.get('seed')}") + if 'seed' in metadata_dict: + params.append(f"Seed: {metadata_dict.get('seed')}") # Size - if 'size' in parsed_workflow: - params.append(f"Size: {parsed_workflow.get('size')}") + if 'size' in metadata_dict: + params.append(f"Size: {metadata_dict.get('size')}") # Model info - if 'checkpoint' in parsed_workflow: + if 'checkpoint' in metadata_dict: # Extract basename without path - checkpoint = os.path.basename(parsed_workflow.get('checkpoint', '')) + checkpoint = os.path.basename(metadata_dict.get('checkpoint', '')) # Remove extension if present checkpoint = os.path.splitext(checkpoint)[0] params.append(f"Model: {checkpoint}") @@ -181,9 +182,9 @@ class SaveImage: # credit to nkchocoai # Add format_filename method to handle pattern substitution - def format_filename(self, filename, parsed_workflow): + def format_filename(self, filename, metadata_dict): """Format filename with metadata values""" - if not parsed_workflow: + if not metadata_dict: return filename result = re.findall(self.pattern_format, filename) @@ -191,30 +192,30 @@ class SaveImage: parts = segment.replace("%", "").split(":") key = parts[0] - if key == "seed" and 'seed' in parsed_workflow: - filename = filename.replace(segment, str(parsed_workflow.get('seed', ''))) - elif key == "width" and 'size' in parsed_workflow: - size = parsed_workflow.get('size', 'x') + if key == "seed" and 'seed' in metadata_dict: + filename = filename.replace(segment, str(metadata_dict.get('seed', ''))) + elif key == "width" and 'size' in metadata_dict: + size = metadata_dict.get('size', 'x') w = size.split('x')[0] if isinstance(size, str) else size[0] filename = filename.replace(segment, str(w)) - elif key == "height" and 'size' in parsed_workflow: - size = parsed_workflow.get('size', 'x') + elif key == "height" and 'size' in metadata_dict: + size = metadata_dict.get('size', 'x') h = size.split('x')[1] if isinstance(size, str) else size[1] filename = filename.replace(segment, str(h)) - elif key == "pprompt" and 'prompt' in parsed_workflow: - prompt = parsed_workflow.get('prompt', '').replace("\n", " ") + elif key == "pprompt" and 'prompt' in metadata_dict: + prompt = metadata_dict.get('prompt', '').replace("\n", " ") if len(parts) >= 2: length = int(parts[1]) prompt = prompt[:length] filename = filename.replace(segment, prompt.strip()) - elif key == "nprompt" and 'negative_prompt' in parsed_workflow: - prompt = parsed_workflow.get('negative_prompt', '').replace("\n", " ") + elif key == "nprompt" and 'negative_prompt' in metadata_dict: + prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ") if len(parts) >= 2: length = int(parts[1]) prompt = prompt[:length] filename = filename.replace(segment, prompt.strip()) - elif key == "model" and 'checkpoint' in parsed_workflow: - model = parsed_workflow.get('checkpoint', '') + elif key == "model" and 'checkpoint' in metadata_dict: + model = metadata_dict.get('checkpoint', '') model = os.path.splitext(os.path.basename(model))[0] if len(parts) >= 2: length = int(parts[1]) @@ -251,18 +252,15 @@ class SaveImage: """Save images with metadata""" results = [] - # Parse the workflow using the WorkflowParser - parser = WorkflowParser() - if prompt: - parsed_workflow = parser.parse_workflow(prompt) - else: - parsed_workflow = {} + # Get metadata using the metadata collector + raw_metadata = get_metadata() + metadata_dict = MetadataProcessor.to_dict(raw_metadata) # Get or create metadata asynchronously - metadata = asyncio.run(self.format_metadata(parsed_workflow, custom_prompt)) + metadata = asyncio.run(self.format_metadata(metadata_dict, custom_prompt)) # Process filename_prefix with pattern substitution - filename_prefix = self.format_filename(filename_prefix, parsed_workflow) + filename_prefix = self.format_filename(filename_prefix, metadata_dict) # Get initial save path info once for the batch full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path( From 4c69d8d3a8c55ae52b1e870d618096678d5d9c06 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 16 Apr 2025 22:15:46 +0800 Subject: [PATCH 04/15] feat: Integrate metadata collection in RecipeRoutes and simplify saveRecipeDirectly function --- py/routes/recipe_routes.py | 45 ++++++++++++------------------------- web/comfyui/loras_widget.js | 12 ++-------- 2 files changed, 16 insertions(+), 41 deletions(-) diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 48328537..2e13fdf8 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -11,7 +11,8 @@ from ..utils.recipe_parsers import RecipeParserFactory from ..utils.constants import CARD_PREVIEW_WIDTH from ..config import config -from ..workflow.parser import WorkflowParser +from ..metadata_collector import get_metadata # Add MetadataCollector import +from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import from ..utils.utils import download_civitai_image from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import @@ -24,7 +25,7 @@ class RecipeRoutes: # Initialize service references as None, will be set during async init self.recipe_scanner = None self.civitai_client = None - self.parser = WorkflowParser() + # Remove WorkflowParser instance # Pre-warm the cache self._init_cache_task = None @@ -786,25 +787,13 @@ class RecipeRoutes: # Ensure services are initialized await self.init_services() - reader = await request.multipart() + # Get metadata using the metadata collector instead of workflow parsing + raw_metadata = get_metadata() + metadata_dict = MetadataProcessor.to_dict(raw_metadata) - # Process form data - workflow_json = None - - while True: - field = await reader.next() - if field is None: - break - - if field.name == 'workflow_json': - workflow_text = await field.text() - try: - workflow_json = json.loads(workflow_text) - except: - return web.json_response({"error": "Invalid workflow JSON"}, status=400) - - if not workflow_json: - return web.json_response({"error": "Missing workflow JSON"}, status=400) + # Check if we have valid metadata + if not metadata_dict: + return web.json_response({"error": "No generation metadata found"}, status=400) # Find the latest image in the temp directory temp_dir = config.temp_directory @@ -822,14 +811,8 @@ class RecipeRoutes: image_files.sort(key=lambda x: x[1], reverse=True) latest_image_path = image_files[0][0] - # Parse the workflow to extract generation parameters and loras - parsed_workflow = self.parser.parse_workflow(workflow_json) - - if not parsed_workflow: - return web.json_response({"error": "Could not extract parameters from workflow"}, status=400) - - # Get the lora stack from the parsed workflow - lora_stack = parsed_workflow.get("loras", "") + # Get the lora stack from the metadata + lora_stack = metadata_dict.get("loras", "") # Parse the lora stack format: " ..." import re @@ -837,7 +820,7 @@ class RecipeRoutes: # Check if any loras were found if not lora_matches: - return web.json_response({"error": "No LoRAs found in the workflow"}, status=400) + return web.json_response({"error": "No LoRAs found in the generation metadata"}, status=400) # Generate recipe name from the first 3 loras (or less if fewer are available) loras_for_name = lora_matches[:3] # Take at most 3 loras for the name @@ -922,8 +905,8 @@ class RecipeRoutes: "created_date": time.time(), "base_model": most_common_base_model, "loras": loras_data, - "checkpoint": parsed_workflow.get("checkpoint", ""), - "gen_params": {key: value for key, value in parsed_workflow.items() + "checkpoint": metadata_dict.get("checkpoint", ""), + "gen_params": {key: value for key, value in metadata_dict.items() if key not in ['checkpoint', 'loras']}, "loras_stack": lora_stack # Include the original lora stack string } diff --git a/web/comfyui/loras_widget.js b/web/comfyui/loras_widget.js index 0166d014..a5210234 100644 --- a/web/comfyui/loras_widget.js +++ b/web/comfyui/loras_widget.js @@ -966,9 +966,6 @@ export function addLorasWidget(node, name, opts, callback) { // Function to directly save the recipe without dialog async function saveRecipeDirectly(widget) { try { - // Get the workflow data from the ComfyUI app - const prompt = await app.graphToPrompt(); - // Show loading toast if (app && app.extensionManager && app.extensionManager.toast) { app.extensionManager.toast.add({ @@ -979,14 +976,9 @@ async function saveRecipeDirectly(widget) { }); } - // Prepare the data - only send workflow JSON - const formData = new FormData(); - formData.append('workflow_json', JSON.stringify(prompt.output)); - - // Send the request + // Send the request to the backend API without workflow data const response = await fetch('/api/recipes/save-from-widget', { - method: 'POST', - body: formData + method: 'POST' }); const result = await response.json(); From 4fdc88e9e1698012ba03d8c761a7057045d6cc39 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 16 Apr 2025 22:19:38 +0800 Subject: [PATCH 05/15] feat: Enhance LoraLoaderExtractor to extract base filename from lora_name input --- py/metadata_collector/node_extractors.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 081dd8c1..612059d6 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -90,6 +90,8 @@ class LoraLoaderExtractor(NodeMetadataExtractor): return lora_name = inputs.get("lora_name") + # Extract base filename without extension from path + lora_name = os.path.splitext(os.path.basename(lora_name))[0] strength_model = round(float(inputs.get("strength_model", 1.0)), 2) # Use the standardized format with lora_list From 18eb605605b2c373f23764df7549aa54334f5a2e Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 17 Apr 2025 06:23:31 +0800 Subject: [PATCH 06/15] feat: Refactor metadata processing to use constants for category keys and improve structure --- py/metadata_collector/constants.py | 11 ++++++++ py/metadata_collector/metadata_processor.py | 30 +++++++++++---------- py/metadata_collector/metadata_registry.py | 14 +++++----- py/metadata_collector/node_extractors.py | 24 +++++++++-------- 4 files changed, 47 insertions(+), 32 deletions(-) create mode 100644 py/metadata_collector/constants.py diff --git a/py/metadata_collector/constants.py b/py/metadata_collector/constants.py new file mode 100644 index 00000000..0ad55b47 --- /dev/null +++ b/py/metadata_collector/constants.py @@ -0,0 +1,11 @@ +"""Constants used by the metadata collector""" + +# Individual category constants +MODELS = "models" +PROMPTS = "prompts" +SAMPLING = "sampling" +LORAS = "loras" +SIZE = "size" + +# Collection of categories for iteration +METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE] diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index da55c0c7..9c80e68c 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -1,5 +1,7 @@ import json +from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE + class MetadataProcessor: """Process and format collected metadata""" @@ -9,7 +11,7 @@ class MetadataProcessor: primary_sampler = None primary_sampler_id = None - for node_id, sampler_info in metadata.get("sampling", {}).items(): + for node_id, sampler_info in metadata.get(SAMPLING, {}).items(): parameters = sampler_info.get("parameters", {}) denoise = parameters.get("denoise") @@ -41,11 +43,11 @@ class MetadataProcessor: @staticmethod def find_primary_checkpoint(metadata): """Find the primary checkpoint model in the workflow""" - if not metadata.get("models"): + if not metadata.get(MODELS): return None # In most workflows, there's only one checkpoint, so we can just take the first one - for node_id, model_info in metadata.get("models", {}).items(): + for node_id, model_info in metadata.get(MODELS, {}).items(): if model_info.get("type") == "checkpoint": return model_info.get("name") @@ -90,18 +92,18 @@ class MetadataProcessor: if prompt and primary_sampler_id: # Trace positive prompt positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive") - if positive_node_id and positive_node_id in metadata.get("prompts", {}): - params["prompt"] = metadata["prompts"][positive_node_id].get("text", "") + if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}): + params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") # Trace negative prompt negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative") - if negative_node_id and negative_node_id in metadata.get("prompts", {}): - params["negative_prompt"] = metadata["prompts"][negative_node_id].get("text", "") + if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}): + params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "") # Check if the sampler itself has size information (from latent_image) - if primary_sampler_id in metadata.get("size", {}): - width = metadata["size"][primary_sampler_id].get("width") - height = metadata["size"][primary_sampler_id].get("height") + if primary_sampler_id in metadata.get(SIZE, {}): + width = metadata[SIZE][primary_sampler_id].get("width") + height = metadata[SIZE][primary_sampler_id].get("height") if width and height: params["size"] = f"{width}x{height}" else: @@ -115,9 +117,9 @@ class MetadataProcessor: # Limit depth to avoid infinite loops in complex workflows max_depth = 10 for _ in range(max_depth): - if current_node_id in metadata.get("size", {}): - width = metadata["size"][current_node_id].get("width") - height = metadata["size"][current_node_id].get("height") + if current_node_id in metadata.get(SIZE, {}): + width = metadata[SIZE][current_node_id].get("width") + height = metadata[SIZE][current_node_id].get("height") if width and height: params["size"] = f"{width}x{height}" size_found = True @@ -141,7 +143,7 @@ class MetadataProcessor: # Extract LoRAs using the standardized format lora_parts = [] - for node_id, lora_info in metadata.get("loras", {}).items(): + for node_id, lora_info in metadata.get(LORAS, {}).items(): # Access the lora_list from the standardized format lora_list = lora_info.get("lora_list", []) for lora in lora_list: diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py index cae9a1a9..9b75351f 100644 --- a/py/metadata_collector/metadata_registry.py +++ b/py/metadata_collector/metadata_registry.py @@ -1,5 +1,6 @@ import time from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor +from .constants import METADATA_CATEGORIES class MetadataRegistry: """A singleton registry to store and retrieve workflow metadata""" @@ -22,22 +23,21 @@ class MetadataRegistry: self.node_cache = {} # Categories we want to track and retrieve from cache - self.metadata_categories = ["models", "prompts", "sampling", "loras", "size"] + self.metadata_categories = METADATA_CATEGORIES def start_collection(self, prompt_id): """Begin metadata collection for a new prompt""" self.current_prompt_id = prompt_id self.executed_nodes = set() self.prompt_metadata[prompt_id] = { - "models": {}, - "prompts": {}, - "sampling": {}, - "loras": {}, - "size": {}, + category: {} for category in METADATA_CATEGORIES + } + # Add additional metadata fields + self.prompt_metadata[prompt_id].update({ "execution_order": [], "current_prompt": None, # Will store the prompt object "timestamp": time.time() - } + }) def set_current_prompt(self, prompt): """Set the current prompt object reference""" diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 612059d6..6fb018b5 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -1,5 +1,7 @@ import os +from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE + class NodeMetadataExtractor: """Base class for node-specific metadata extraction""" @@ -28,7 +30,7 @@ class CheckpointLoaderExtractor(NodeMetadataExtractor): model_name = inputs.get("ckpt_name") if model_name: - metadata["models"][node_id] = { + metadata[MODELS][node_id] = { "name": model_name, "type": "checkpoint", "node_id": node_id @@ -41,7 +43,7 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor): return text = inputs.get("text", "") - metadata["prompts"][node_id] = { + metadata[PROMPTS][node_id] = { "text": text, "node_id": node_id } @@ -57,7 +59,7 @@ class SamplerExtractor(NodeMetadataExtractor): if key in inputs: sampling_params[key] = inputs[key] - metadata["sampling"][node_id] = { + metadata[SAMPLING][node_id] = { "parameters": sampling_params, "node_id": node_id } @@ -74,10 +76,10 @@ class SamplerExtractor(NodeMetadataExtractor): height = int(samples.shape[2] * 8) width = int(samples.shape[3] * 8) - if "size" not in metadata: - metadata["size"] = {} + if SIZE not in metadata: + metadata[SIZE] = {} - metadata["size"][node_id] = { + metadata[SIZE][node_id] = { "width": width, "height": height, "node_id": node_id @@ -95,7 +97,7 @@ class LoraLoaderExtractor(NodeMetadataExtractor): strength_model = round(float(inputs.get("strength_model", 1.0)), 2) # Use the standardized format with lora_list - metadata["loras"][node_id] = { + metadata[LORAS][node_id] = { "lora_list": [ { "name": lora_name, @@ -114,10 +116,10 @@ class ImageSizeExtractor(NodeMetadataExtractor): width = inputs.get("width", 512) height = inputs.get("height", 512) - if "size" not in metadata: - metadata["size"] = {} + if SIZE not in metadata: + metadata[SIZE] = {} - metadata["size"][node_id] = { + metadata[SIZE][node_id] = { "width": width, "height": height, "node_id": node_id @@ -164,7 +166,7 @@ class LoraLoaderManagerExtractor(NodeMetadataExtractor): }) if active_loras: - metadata["loras"][node_id] = { + metadata[LORAS][node_id] = { "lora_list": active_loras, "node_id": node_id } From 32d34d17481b779ab2f140d25eced3093a159f7a Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 17 Apr 2025 08:06:21 +0800 Subject: [PATCH 07/15] feat: Enhance trace_node_input method with depth tracking and target class filtering; add FluxGuidanceExtractor for guidance parameter extraction --- py/metadata_collector/metadata_processor.py | 78 +++++++++++++++++---- py/metadata_collector/node_extractors.py | 15 ++++ 2 files changed, 80 insertions(+), 13 deletions(-) diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index 9c80e68c..9cf2cb83 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -24,20 +24,65 @@ class MetadataProcessor: return primary_sampler_id, primary_sampler @staticmethod - def trace_node_input(prompt, node_id, input_name): - """Trace an input connection from a node to find the source node""" + def trace_node_input(prompt, node_id, input_name, target_class=None, max_depth=10): + """ + Trace an input connection from a node to find the source node + + Parameters: + - prompt: The prompt object containing node connections + - node_id: ID of the starting node + - input_name: Name of the input to trace + - target_class: Optional class name to search for (e.g., "CLIPTextEncode") + - max_depth: Maximum depth to follow the node chain to prevent infinite loops + + Returns: + - node_id of the found node, or None if not found + """ if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt: return None - node_inputs = prompt.original_prompt[node_id].get("inputs", {}) - if input_name not in node_inputs: - return None + # For depth tracking + current_depth = 0 + + current_node_id = node_id + current_input = input_name + + while current_depth < max_depth: + if current_node_id not in prompt.original_prompt: + return None + + node_inputs = prompt.original_prompt[current_node_id].get("inputs", {}) + if current_input not in node_inputs: + return None + + input_value = node_inputs[current_input] + # Input connections are formatted as [node_id, output_index] + if isinstance(input_value, list) and len(input_value) >= 2: + found_node_id = input_value[0] # Connected node_id + + # If we're looking for a specific node class + if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class: + return found_node_id + + # If we're not looking for a specific class or haven't found it yet + if not target_class: + return found_node_id + + # Continue tracing through intermediate nodes + current_node_id = found_node_id + # For most conditioning nodes, the input we want to follow is named "conditioning" + if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}): + current_input = "conditioning" + else: + # If there's no "conditioning" input, we can't trace further + return found_node_id if not target_class else None + else: + # We've reached a node with no further connections + return None - input_value = node_inputs[input_name] - # Input connections are formatted as [node_id, output_index] - if isinstance(input_value, list) and len(input_value) >= 2: - return input_value[0] # Return connected node_id + current_depth += 1 + # If we've reached max depth without finding target_class return None @staticmethod @@ -62,6 +107,7 @@ class MetadataProcessor: "seed": None, "steps": None, "cfg_scale": None, + "guidance": None, # Add guidance parameter "sampler": None, "checkpoint": None, "loras": "", @@ -90,13 +136,19 @@ class MetadataProcessor: # Trace connections from the primary sampler if prompt and primary_sampler_id: - # Trace positive prompt - positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive") + # Trace positive prompt - look specifically for CLIPTextEncode + positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10) if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}): params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") - # Trace negative prompt - negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative") + # Find any FluxGuidance nodes in the positive conditioning path + flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5) + if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}): + flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {}) + params["guidance"] = flux_params.get("guidance") + + # Trace negative prompt - look specifically for CLIPTextEncode + negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10) if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}): params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "") diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 6fb018b5..0d599c93 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -170,6 +170,20 @@ class LoraLoaderManagerExtractor(NodeMetadataExtractor): "lora_list": active_loras, "node_id": node_id } + +class FluxGuidanceExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "guidance" not in inputs: + return + + guidance_value = inputs.get("guidance") + + # Store the guidance value in SAMPLING category + if node_id not in metadata[SAMPLING]: + metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id} + + metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value # Registry of node-specific extractors NODE_EXTRACTORS = { @@ -181,5 +195,6 @@ NODE_EXTRACTORS = { "LoraManagerLoader": LoraLoaderManagerExtractor, "SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced "UNETLoader": CheckpointLoaderExtractor, # Add UNETLoader + "FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance # Add other nodes as needed } From 5fd069d70d58cc934c7d446d40759222b72fb404 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 17 Apr 2025 09:38:20 +0800 Subject: [PATCH 08/15] feat: Enhance checkpoint processing in format_metadata to handle non-string types safely --- py/nodes/save_image.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index dea33a45..110e2072 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -159,11 +159,20 @@ class SaveImage: # Model info if 'checkpoint' in metadata_dict: - # Extract basename without path - checkpoint = os.path.basename(metadata_dict.get('checkpoint', '')) - # Remove extension if present - checkpoint = os.path.splitext(checkpoint)[0] - params.append(f"Model: {checkpoint}") + # Ensure checkpoint is a string before processing + checkpoint = metadata_dict.get('checkpoint') + if checkpoint is not None: + # Handle both string and other types safely + if isinstance(checkpoint, str): + # Extract basename without path + checkpoint = os.path.basename(checkpoint) + # Remove extension if present + checkpoint = os.path.splitext(checkpoint)[0] + else: + # Convert non-string to string + checkpoint = str(checkpoint) + + params.append(f"Model: {checkpoint}") # Add LoRA hashes if available if lora_hashes: From c2f599b4ff2aa634f628ed9e7d53e2cfbc4fa1d9 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 17 Apr 2025 22:05:40 +0800 Subject: [PATCH 09/15] feat: Update node extractors to include UNETLoaderExtractor and enhance metadata handling for guidance parameters --- py/metadata_collector/metadata_registry.py | 14 +++++++++++--- py/metadata_collector/node_extractors.py | 16 +++++++++++++++- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py index 9b75351f..434f1eb1 100644 --- a/py/metadata_collector/metadata_registry.py +++ b/py/metadata_collector/metadata_registry.py @@ -1,4 +1,5 @@ import time +from nodes import NODE_CLASS_MAPPINGS from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor from .constants import METADATA_CATEGORIES @@ -78,11 +79,18 @@ class MetadataRegistry: if node_id in executed_nodes: continue - class_type = node_data.get("class_type") - if not class_type: + # Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS) + prompt_class_type = node_data.get("class_type") + if not prompt_class_type: continue - # Create cache key + # Convert to actual class name (which is what we use in our cache) + class_type = prompt_class_type + if prompt_class_type in NODE_CLASS_MAPPINGS: + class_obj = NODE_CLASS_MAPPINGS[prompt_class_type] + class_type = class_obj.__name__ + + # Create cache key using the actual class name cache_key = f"{node_id}:{class_type}" # Check if this node type is relevant for metadata collection diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 0d599c93..bdda89ad 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -184,6 +184,20 @@ class FluxGuidanceExtractor(NodeMetadataExtractor): metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id} metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value + +class UNETLoaderExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "unet_name" not in inputs: + return + + model_name = inputs.get("unet_name") + if model_name: + metadata[MODELS][node_id] = { + "name": model_name, + "type": "checkpoint", + "node_id": node_id + } # Registry of node-specific extractors NODE_EXTRACTORS = { @@ -194,7 +208,7 @@ NODE_EXTRACTORS = { "EmptyLatentImage": ImageSizeExtractor, "LoraManagerLoader": LoraLoaderManagerExtractor, "SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced - "UNETLoader": CheckpointLoaderExtractor, # Add UNETLoader + "UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor "FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance # Add other nodes as needed } From bccabe40c01b8922594472b4a2b211bc640dacfe Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 18 Apr 2025 05:29:36 +0800 Subject: [PATCH 10/15] feat: Enhance KSamplerAdvancedExtractor to include additional sampling parameters and update metadata processing --- py/metadata_collector/metadata_processor.py | 24 ++++++++-- py/metadata_collector/node_extractors.py | 53 +++++++++++++++++++-- 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index 9cf2cb83..4cf72b73 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -11,15 +11,28 @@ class MetadataProcessor: primary_sampler = None primary_sampler_id = None + # First, check for KSamplerAdvanced with add_noise="enable" for node_id, sampler_info in metadata.get(SAMPLING, {}).items(): parameters = sampler_info.get("parameters", {}) - denoise = parameters.get("denoise") + add_noise = parameters.get("add_noise") - # If denoise is 1.0, this is likely the primary sampler - if denoise == 1.0 or denoise == 1: + # If add_noise is "enable", this is likely the primary sampler for KSamplerAdvanced + if add_noise == "enable": primary_sampler = sampler_info primary_sampler_id = node_id break + + # If no KSamplerAdvanced found, fall back to traditional KSampler with denoise=1 + if primary_sampler is None: + for node_id, sampler_info in metadata.get(SAMPLING, {}).items(): + parameters = sampler_info.get("parameters", {}) + denoise = parameters.get("denoise") + + # If denoise is 1.0, this is likely the primary sampler + if denoise == 1.0 or denoise == 1: + primary_sampler = sampler_info + primary_sampler_id = node_id + break return primary_sampler_id, primary_sampler @@ -109,6 +122,7 @@ class MetadataProcessor: "cfg_scale": None, "guidance": None, # Add guidance parameter "sampler": None, + "scheduler": None, "checkpoint": None, "loras": "", "size": None, @@ -129,10 +143,12 @@ class MetadataProcessor: if primary_sampler: # Extract sampling parameters sampling_params = primary_sampler.get("parameters", {}) - params["seed"] = sampling_params.get("seed") + # Handle both seed and noise_seed + params["seed"] = sampling_params.get("seed") if sampling_params.get("seed") is not None else sampling_params.get("noise_seed") params["steps"] = sampling_params.get("steps") params["cfg_scale"] = sampling_params.get("cfg") params["sampler"] = sampling_params.get("sampler_name") + params["scheduler"] = sampling_params.get("scheduler") # Trace connections from the primary sampler if prompt and primary_sampler_id: diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index bdda89ad..210ab29e 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -85,6 +85,43 @@ class SamplerExtractor(NodeMetadataExtractor): "node_id": node_id } +class KSamplerAdvancedExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + sampling_params = {} + for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]: + if key in inputs: + sampling_params[key] = inputs[key] + + metadata[SAMPLING][node_id] = { + "parameters": sampling_params, + "node_id": node_id + } + + # Extract latent image dimensions if available + if "latent_image" in inputs and inputs["latent_image"] is not None: + latent = inputs["latent_image"] + if isinstance(latent, dict) and "samples" in latent: + # Extract dimensions from latent tensor + samples = latent["samples"] + if hasattr(samples, "shape") and len(samples.shape) >= 3: + # Correct shape interpretation: [batch_size, channels, height/8, width/8] + # Multiply by 8 to get actual pixel dimensions + height = int(samples.shape[2] * 8) + width = int(samples.shape[3] * 8) + + if SIZE not in metadata: + metadata[SIZE] = {} + + metadata[SIZE][node_id] = { + "width": width, + "height": height, + "node_id": node_id + } + class LoraLoaderExtractor(NodeMetadataExtractor): @staticmethod def extract(node_id, inputs, outputs, metadata): @@ -201,14 +238,20 @@ class UNETLoaderExtractor(NodeMetadataExtractor): # Registry of node-specific extractors NODE_EXTRACTORS = { - "CheckpointLoaderSimple": CheckpointLoaderExtractor, - "CLIPTextEncode": CLIPTextEncodeExtractor, + # Sampling "KSampler": SamplerExtractor, - "LoraLoader": LoraLoaderExtractor, - "EmptyLatentImage": ImageSizeExtractor, - "LoraManagerLoader": LoraLoaderManagerExtractor, + "KSamplerAdvanced": KSamplerAdvancedExtractor, # Add KSamplerAdvanced "SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced + # Loaders + "CheckpointLoaderSimple": CheckpointLoaderExtractor, "UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor + "LoraLoader": LoraLoaderExtractor, + "LoraManagerLoader": LoraLoaderManagerExtractor, + # Conditioning + "CLIPTextEncode": CLIPTextEncodeExtractor, + # Latent + "EmptyLatentImage": ImageSizeExtractor, + # Flux "FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance # Add other nodes as needed } From f0203c96ab854c33c197a193d3cb8d72cd48e7d4 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 18 Apr 2025 05:34:42 +0800 Subject: [PATCH 11/15] feat: Simplify format_metadata method by removing custom_prompt parameter and update related function calls --- py/nodes/save_image.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index 110e2072..e88546d7 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -9,7 +9,6 @@ from ..metadata_collector.metadata_processor import MetadataProcessor from ..metadata_collector import get_metadata from PIL import Image, PngImagePlugin import piexif -from io import BytesIO class SaveImage: NAME = "Save Image (LoraManager)" @@ -35,7 +34,6 @@ class SaveImage: "file_format": (["png", "jpeg", "webp"],), }, "optional": { - "custom_prompt": ("STRING", {"default": "", "forceInput": True}), "lossless_webp": ("BOOLEAN", {"default": True}), "quality": ("INT", {"default": 100, "min": 1, "max": 100}), "embed_workflow": ("BOOLEAN", {"default": False}), @@ -62,7 +60,7 @@ class SaveImage: return item.get('sha256') return None - async def format_metadata(self, metadata_dict, custom_prompt=None): + async def format_metadata(self, metadata_dict): """Format metadata in the requested format similar to userComment example""" if not metadata_dict: return "" @@ -71,10 +69,6 @@ class SaveImage: prompt = metadata_dict.get('prompt', '') negative_prompt = metadata_dict.get('negative_prompt', '') - # Override prompt with custom_prompt if provided - if custom_prompt: - prompt = custom_prompt - # Extract loras from the prompt if present loras_text = metadata_dict.get('loras', '') lora_hashes = {} @@ -256,8 +250,7 @@ class SaveImage: return filename def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None, - lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True, - custom_prompt=None): + lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): """Save images with metadata""" results = [] @@ -266,7 +259,7 @@ class SaveImage: metadata_dict = MetadataProcessor.to_dict(raw_metadata) # Get or create metadata asynchronously - metadata = asyncio.run(self.format_metadata(metadata_dict, custom_prompt)) + metadata = asyncio.run(self.format_metadata(metadata_dict)) # Process filename_prefix with pattern substitution filename_prefix = self.format_filename(filename_prefix, metadata_dict) @@ -354,8 +347,7 @@ class SaveImage: return results def process_image(self, images, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None, - lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True, - custom_prompt=""): + lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): """Process and save image with metadata""" # Make sure the output directory exists os.makedirs(self.output_dir, exist_ok=True) @@ -376,8 +368,7 @@ class SaveImage: lossless_webp, quality, embed_workflow, - add_counter_to_filename, - custom_prompt if custom_prompt.strip() else None + add_counter_to_filename ) return (images,) \ No newline at end of file From df6d56ce667be5d734810f075deb5a5ccc253e38 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 18 Apr 2025 07:12:43 +0800 Subject: [PATCH 12/15] feat: Add IMAGES category to constants and enhance metadata handling in node extractors --- py/metadata_collector/constants.py | 3 +- py/metadata_collector/metadata_registry.py | 67 +++++++++++++++++++++- py/metadata_collector/node_extractors.py | 25 +++++++- 3 files changed, 92 insertions(+), 3 deletions(-) diff --git a/py/metadata_collector/constants.py b/py/metadata_collector/constants.py index 0ad55b47..c1109580 100644 --- a/py/metadata_collector/constants.py +++ b/py/metadata_collector/constants.py @@ -6,6 +6,7 @@ PROMPTS = "prompts" SAMPLING = "sampling" LORAS = "loras" SIZE = "size" +IMAGES = "images" # Added new category for image results # Collection of categories for iteration -METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE] +METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES] # Added IMAGES to categories diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py index 434f1eb1..bcf2284a 100644 --- a/py/metadata_collector/metadata_registry.py +++ b/py/metadata_collector/metadata_registry.py @@ -1,7 +1,7 @@ import time from nodes import NODE_CLASS_MAPPINGS from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor -from .constants import METADATA_CATEGORIES +from .constants import METADATA_CATEGORIES, IMAGES class MetadataRegistry: """A singleton registry to store and retrieve workflow metadata""" @@ -23,9 +23,28 @@ class MetadataRegistry: # Node-level cache for metadata self.node_cache = {} + # Limit the number of stored prompts + self.max_prompt_history = 3 + # Categories we want to track and retrieve from cache self.metadata_categories = METADATA_CATEGORIES + def _clean_old_prompts(self): + """Clean up old prompt metadata, keeping only recent ones""" + if len(self.prompt_metadata) <= self.max_prompt_history: + return + + # Sort all prompt_ids by timestamp + sorted_prompts = sorted( + self.prompt_metadata.keys(), + key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0) + ) + + # Remove oldest records + prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history] + for pid in prompts_to_remove: + del self.prompt_metadata[pid] + def start_collection(self, prompt_id): """Begin metadata collection for a new prompt""" self.current_prompt_id = prompt_id @@ -39,6 +58,9 @@ class MetadataRegistry: "current_prompt": None, # Will store the prompt object "timestamp": time.time() }) + + # Clean up old prompt data + self._clean_old_prompts() def set_current_prompt(self, prompt): """Set the current prompt object reference""" @@ -177,3 +199,46 @@ class MetadataRegistry: # Save to cache if we have any metadata for this node if any(node_metadata.values()): self.node_cache[cache_key] = node_metadata + + def clear_unused_cache(self): + """Clean up node_cache entries that are no longer in use""" + # Collect all node_ids currently in prompt_metadata + active_node_ids = set() + for prompt_data in self.prompt_metadata.values(): + for category in self.metadata_categories: + if category in prompt_data: + active_node_ids.update(prompt_data[category].keys()) + + # Find cache keys that are no longer needed + keys_to_remove = [] + for cache_key in self.node_cache: + node_id = cache_key.split(':')[0] + if node_id not in active_node_ids: + keys_to_remove.append(cache_key) + + # Remove cache entries that are no longer needed + for key in keys_to_remove: + del self.node_cache[key] + + def clear_metadata(self, prompt_id=None): + """Clear metadata for a specific prompt or reset all data""" + if prompt_id is not None: + if prompt_id in self.prompt_metadata: + del self.prompt_metadata[prompt_id] + # Clean up cache after removing prompt + self.clear_unused_cache() + else: + # Reset all data + self._reset() + + def get_first_decoded_image(self, prompt_id=None): + """Get the first decoded image result""" + key = prompt_id if prompt_id is not None else self.current_prompt_id + if key not in self.prompt_metadata: + return None + + metadata = self.prompt_metadata[key] + if IMAGES in metadata and "first_decode" in metadata[IMAGES]: + return metadata[IMAGES]["first_decode"]["image"] + + return None diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 210ab29e..cca73095 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -1,6 +1,6 @@ import os -from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE +from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES class NodeMetadataExtractor: @@ -235,7 +235,28 @@ class UNETLoaderExtractor(NodeMetadataExtractor): "type": "checkpoint", "node_id": node_id } + +class VAEDecodeExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + pass + @staticmethod + def update(node_id, outputs, metadata): + # Check if we already have a first VAEDecode result + if IMAGES in metadata and "first_decode" in metadata[IMAGES]: + return + + # Ensure IMAGES category exists + if IMAGES not in metadata: + metadata[IMAGES] = {} + + # Save reference to the first VAEDecode result + metadata[IMAGES]["first_decode"] = { + "node_id": node_id, + "image": outputs + } + # Registry of node-specific extractors NODE_EXTRACTORS = { # Sampling @@ -253,5 +274,7 @@ NODE_EXTRACTORS = { "EmptyLatentImage": ImageSizeExtractor, # Flux "FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance + # Image + "VAEDecode": VAEDecodeExtractor, # Added VAEDecode extractor # Add other nodes as needed } From 91b4827c1d7e77728ebcbaaba512c9b63a2018a7 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 18 Apr 2025 09:24:48 +0800 Subject: [PATCH 13/15] feat: Enhance image retrieval in MetadataRegistry and update recipe routes to process images from metadata --- py/metadata_collector/metadata_registry.py | 10 ++- py/routes/recipe_routes.py | 77 +++++++++++++++++----- 2 files changed, 68 insertions(+), 19 deletions(-) diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py index bcf2284a..e287c5b1 100644 --- a/py/metadata_collector/metadata_registry.py +++ b/py/metadata_collector/metadata_registry.py @@ -239,6 +239,14 @@ class MetadataRegistry: metadata = self.prompt_metadata[key] if IMAGES in metadata and "first_decode" in metadata[IMAGES]: - return metadata[IMAGES]["first_decode"]["image"] + image_data = metadata[IMAGES]["first_decode"]["image"] + + # If it's an image batch or tuple, handle various formats + if isinstance(image_data, (list, tuple)) and len(image_data) > 0: + # Return first element of list/tuple + return image_data[0] + + # If it's a tensor, return as is for processing in the route handler + return image_data return None diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 2e13fdf8..2b97832b 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -1,5 +1,9 @@ import os import time +import numpy as np +from PIL import Image +import torch +import io import logging from aiohttp import web from typing import Dict @@ -15,6 +19,7 @@ from ..metadata_collector import get_metadata # Add MetadataCollector import from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import from ..utils.utils import download_civitai_image from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import +from ..metadata_collector.metadata_registry import MetadataRegistry logger = logging.getLogger(__name__) @@ -657,8 +662,8 @@ class RecipeRoutes: logger.error(f"Error retrieving base models: {e}", exc_info=True) return web.json_response({ 'success': False, - 'error': str(e) - }, status=500) + 'error': str(e)} + , status=500) async def share_recipe(self, request: web.Request) -> web.Response: """Process a recipe image for sharing by adding metadata to EXIF""" @@ -795,21 +800,61 @@ class RecipeRoutes: if not metadata_dict: return web.json_response({"error": "No generation metadata found"}, status=400) - # Find the latest image in the temp directory - temp_dir = config.temp_directory - image_files = [] + # Get the most recent image from metadata registry instead of temp directory + metadata_registry = MetadataRegistry() + latest_image = metadata_registry.get_first_decoded_image() - for file in os.listdir(temp_dir): - if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')): - file_path = os.path.join(temp_dir, file) - image_files.append((file_path, os.path.getmtime(file_path))) + if not latest_image: + return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400) - if not image_files: - return web.json_response({"error": "No recent images found to use for recipe"}, status=400) + # Convert the image data to bytes - handle tuple and tensor cases + logger.debug(f"Image type: {type(latest_image)}") - # Sort by modification time (newest first) - image_files.sort(key=lambda x: x[1], reverse=True) - latest_image_path = image_files[0][0] + try: + # Handle the tuple case first + if isinstance(latest_image, tuple): + # Extract the tensor from the tuple + if len(latest_image) > 0: + tensor_image = latest_image[0] + else: + return web.json_response({"error": "Empty image tuple received"}, status=400) + else: + tensor_image = latest_image + + # Get the shape info for debugging + if hasattr(tensor_image, 'shape'): + shape_info = tensor_image.shape + logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}") + + # Convert tensor to numpy array + if isinstance(tensor_image, torch.Tensor): + image_np = tensor_image.cpu().numpy() + else: + image_np = np.array(tensor_image) + + # Handle different tensor shapes + # Case: (1, 1, H, W, 3) or (1, H, W, 3) - batch or multi-batch + if len(image_np.shape) > 3: + # Remove batch dimensions until we get to (H, W, 3) + while len(image_np.shape) > 3: + image_np = image_np[0] + + # If values are in [0, 1] range, convert to [0, 255] + if image_np.dtype == np.float32 or image_np.dtype == np.float64: + if image_np.max() <= 1.0: + image_np = (image_np * 255).astype(np.uint8) + + # Ensure image is in the right format (HWC with RGB channels) + if len(image_np.shape) == 3 and image_np.shape[2] == 3: + pil_image = Image.fromarray(image_np) + img_byte_arr = io.BytesIO() + pil_image.save(img_byte_arr, format='PNG') + image = img_byte_arr.getvalue() + else: + return web.json_response({"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, status=400) + except Exception as e: + logger.error(f"Error processing image data: {str(e)}", exc_info=True) + return web.json_response({"error": f"Error processing image: {str(e)}"}, status=400) # Get the lora stack from the metadata lora_stack = metadata_dict.get("loras", "") @@ -834,10 +879,6 @@ class RecipeRoutes: recipe_name = " ".join(recipe_name_parts) - # Read the image - with open(latest_image_path, 'rb') as f: - image = f.read() - # Create recipes directory if it doesn't exist recipes_dir = self.recipe_scanner.recipes_dir os.makedirs(recipes_dir, exist_ok=True) From 0734252e98b204c797d97615813ea7b1510add35 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 18 Apr 2025 10:03:26 +0800 Subject: [PATCH 14/15] feat: Enhance VAEDecodeExtractor to improve image caching and metadata handling --- py/metadata_collector/metadata_registry.py | 23 ++++++++++++++++++++++ py/metadata_collector/node_extractors.py | 12 +++++------ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py index e287c5b1..6e33e806 100644 --- a/py/metadata_collector/metadata_registry.py +++ b/py/metadata_collector/metadata_registry.py @@ -249,4 +249,27 @@ class MetadataRegistry: # If it's a tensor, return as is for processing in the route handler return image_data + # If no image is found in the current metadata, try to find it in the cache + # This handles the case where VAEDecode was cached by ComfyUI and not executed + prompt_obj = metadata.get("current_prompt") + if prompt_obj and hasattr(prompt_obj, "original_prompt"): + original_prompt = prompt_obj.original_prompt + for node_id, node_data in original_prompt.items(): + class_type = node_data.get("class_type") + if class_type and class_type in NODE_CLASS_MAPPINGS: + class_obj = NODE_CLASS_MAPPINGS[class_type] + class_name = class_obj.__name__ + # Check if this is a VAEDecode node + if class_name == "VAEDecode": + # Try to find this node in the cache + cache_key = f"{node_id}:{class_name}" + if cache_key in self.node_cache: + cached_data = self.node_cache[cache_key] + if IMAGES in cached_data and node_id in cached_data[IMAGES]: + image_data = cached_data[IMAGES][node_id]["image"] + # Handle different image formats + if isinstance(image_data, (list, tuple)) and len(image_data) > 0: + return image_data[0] + return image_data + return None diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index cca73095..64dda557 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -243,19 +243,19 @@ class VAEDecodeExtractor(NodeMetadataExtractor): @staticmethod def update(node_id, outputs, metadata): - # Check if we already have a first VAEDecode result - if IMAGES in metadata and "first_decode" in metadata[IMAGES]: - return - # Ensure IMAGES category exists if IMAGES not in metadata: metadata[IMAGES] = {} - # Save reference to the first VAEDecode result - metadata[IMAGES]["first_decode"] = { + # Save image data under node ID index to be captured by caching mechanism + metadata[IMAGES][node_id] = { "node_id": node_id, "image": outputs } + + # Only set first_decode if it hasn't been recorded yet + if "first_decode" not in metadata[IMAGES]: + metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id] # Registry of node-specific extractors NODE_EXTRACTORS = { From 4766b45746b5fa5a284e3279b418b102d6a1b955 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 18 Apr 2025 10:52:39 +0800 Subject: [PATCH 15/15] feat: Update SaveImage node to modify default lossless_webp setting and adjust save_kwargs for image formats --- py/nodes/save_image.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index e88546d7..003092db 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -34,7 +34,7 @@ class SaveImage: "file_format": (["png", "jpeg", "webp"],), }, "optional": { - "lossless_webp": ("BOOLEAN", {"default": True}), + "lossless_webp": ("BOOLEAN", {"default": False}), "quality": ("INT", {"default": 100, "min": 1, "max": 100}), "embed_workflow": ("BOOLEAN", {"default": False}), "add_counter_to_filename": ("BOOLEAN", {"default": True}), @@ -290,7 +290,8 @@ class SaveImage: if file_format == "png": file = base_filename + ".png" file_extension = ".png" - save_kwargs = {"optimize": True, "compress_level": self.compress_level} + # Remove "optimize": True to match built-in node behavior + save_kwargs = {"compress_level": self.compress_level} pnginfo = PngImagePlugin.PngInfo() elif file_format == "jpeg": file = base_filename + ".jpg" @@ -299,7 +300,8 @@ class SaveImage: elif file_format == "webp": file = base_filename + ".webp" file_extension = ".webp" - save_kwargs = {"quality": quality, "lossless": lossless_webp} + # Add optimization param to control performance + save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0} # Full save path file_path = os.path.join(full_output_folder, file)