diff --git a/__init__.py b/__init__.py index 008219c8..375dfc13 100644 --- a/__init__.py +++ b/__init__.py @@ -3,16 +3,23 @@ from .py.nodes.lora_loader import LoraManagerLoader from .py.nodes.trigger_word_toggle import TriggerWordToggle from .py.nodes.lora_stacker import LoraStacker from .py.nodes.save_image import SaveImage +from .py.nodes.debug_metadata import DebugMetadata +# Import metadata collector to install hooks on startup +from .py.metadata_collector import init as init_metadata_collector NODE_CLASS_MAPPINGS = { LoraManagerLoader.NAME: LoraManagerLoader, TriggerWordToggle.NAME: TriggerWordToggle, LoraStacker.NAME: LoraStacker, - SaveImage.NAME: SaveImage + SaveImage.NAME: SaveImage, + DebugMetadata.NAME: DebugMetadata } WEB_DIRECTORY = "./web/comfyui" +# Initialize metadata collector +init_metadata_collector() + # Register routes on import LoraManager.add_routes() __all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY'] diff --git a/py/metadata_collector/__init__.py b/py/metadata_collector/__init__.py new file mode 100644 index 00000000..3fea3a6b --- /dev/null +++ b/py/metadata_collector/__init__.py @@ -0,0 +1,18 @@ +import os +import importlib +from .metadata_hook import MetadataHook +from .metadata_registry import MetadataRegistry + +def init(): + # Install hooks to collect metadata during execution + MetadataHook.install() + + # Initialize registry + registry = MetadataRegistry() + + print("ComfyUI Metadata Collector initialized") + +def get_metadata(prompt_id=None): + """Helper function to get metadata from the registry""" + registry = MetadataRegistry() + return registry.get_metadata(prompt_id) diff --git a/py/metadata_collector/metadata_hook.py b/py/metadata_collector/metadata_hook.py new file mode 100644 index 00000000..2b6a7b6d --- /dev/null +++ b/py/metadata_collector/metadata_hook.py @@ -0,0 +1,123 @@ +import sys +import inspect +from .metadata_registry import MetadataRegistry + +class MetadataHook: + """Install hooks for metadata collection""" + + @staticmethod + def install(): + """Install hooks to collect metadata during execution""" + try: + # Import ComfyUI's execution module + execution = None + try: + # Try direct import first + import execution # type: ignore + except ImportError: + # Try to locate from system modules + for module_name in sys.modules: + if module_name.endswith('.execution'): + execution = sys.modules[module_name] + break + + # If we can't find the execution module, we can't install hooks + if execution is None: + print("Could not locate ComfyUI execution module, metadata collection disabled") + return + + # Store the original _map_node_over_list function + original_map_node_over_list = execution._map_node_over_list + + # Define the wrapped _map_node_over_list function + def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): + # Only collect metadata when calling the main function of nodes + if func == obj.FUNCTION and hasattr(obj, '__class__'): + try: + # Get the current prompt_id from the registry + registry = MetadataRegistry() + prompt_id = registry.current_prompt_id + + if prompt_id is not None: + # Get node class type + class_type = obj.__class__.__name__ + + # Unique ID might be available through the obj if it has a unique_id field + node_id = getattr(obj, 'unique_id', None) + if node_id is None and pre_execute_cb: + # Try to extract node_id through reflection on GraphBuilder.set_default_prefix + frame = inspect.currentframe() + while frame: + if 'unique_id' in frame.f_locals: + node_id = frame.f_locals['unique_id'] + break + frame = frame.f_back + + # Record inputs before execution + if node_id is not None: + registry.record_node_execution(node_id, class_type, input_data_all, None) + except Exception as e: + print(f"Error collecting metadata (pre-execution): {str(e)}") + + # Execute the original function + results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb) + + # After execution, collect outputs for relevant nodes + if func == obj.FUNCTION and hasattr(obj, '__class__'): + try: + # Get the current prompt_id from the registry + registry = MetadataRegistry() + prompt_id = registry.current_prompt_id + + if prompt_id is not None: + # Get node class type + class_type = obj.__class__.__name__ + + # Unique ID might be available through the obj if it has a unique_id field + node_id = getattr(obj, 'unique_id', None) + if node_id is None and pre_execute_cb: + # Try to extract node_id through reflection + frame = inspect.currentframe() + while frame: + if 'unique_id' in frame.f_locals: + node_id = frame.f_locals['unique_id'] + break + frame = frame.f_back + + # Record outputs after execution + if node_id is not None: + registry.update_node_execution(node_id, class_type, results) + except Exception as e: + print(f"Error collecting metadata (post-execution): {str(e)}") + + return results + + # Also hook the execute function to track the current prompt_id + original_execute = execution.execute + + def execute_with_prompt_tracking(*args, **kwargs): + if len(args) >= 7: # Check if we have enough arguments + server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7] + registry = MetadataRegistry() + + # Start collection if this is a new prompt + if not registry.current_prompt_id or registry.current_prompt_id != prompt_id: + registry.start_collection(prompt_id) + + # Store the dynprompt reference for node lookups + if hasattr(prompt, 'original_prompt'): + registry.set_current_prompt(prompt) + + # Execute the original function + return original_execute(*args, **kwargs) + + # Replace the functions + execution._map_node_over_list = map_node_over_list_with_metadata + execution.execute = execute_with_prompt_tracking + # Make map_node_over_list public to avoid it being hidden by hooks + execution.map_node_over_list = original_map_node_over_list + + print("Metadata collection hooks installed for runtime values") + + except Exception as e: + print(f"Error installing metadata hooks: {str(e)}") diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py new file mode 100644 index 00000000..4721f889 --- /dev/null +++ b/py/metadata_collector/metadata_processor.py @@ -0,0 +1,171 @@ +import json + +class MetadataProcessor: + """Process and format collected metadata""" + + @staticmethod + def find_primary_sampler(metadata): + """Find the primary KSampler node (with denoise=1)""" + primary_sampler = None + primary_sampler_id = None + + for node_id, sampler_info in metadata.get("sampling", {}).items(): + parameters = sampler_info.get("parameters", {}) + denoise = parameters.get("denoise") + + # If denoise is 1.0, this is likely the primary sampler + if denoise == 1.0 or denoise == 1: + primary_sampler = sampler_info + primary_sampler_id = node_id + break + + return primary_sampler_id, primary_sampler + + @staticmethod + def trace_node_input(prompt, node_id, input_name): + """Trace an input connection from a node to find the source node""" + if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt: + return None + + node_inputs = prompt.original_prompt[node_id].get("inputs", {}) + if input_name not in node_inputs: + return None + + input_value = node_inputs[input_name] + # Input connections are formatted as [node_id, output_index] + if isinstance(input_value, list) and len(input_value) >= 2: + return input_value[0] # Return connected node_id + + return None + + @staticmethod + def find_primary_checkpoint(metadata): + """Find the primary checkpoint model in the workflow""" + if not metadata.get("models"): + return None + + # In most workflows, there's only one checkpoint, so we can just take the first one + for node_id, model_info in metadata.get("models", {}).items(): + if model_info.get("type") == "checkpoint": + return model_info.get("name") + + return None + + @staticmethod + def extract_generation_params(metadata): + """Extract generation parameters from metadata using node relationships""" + params = { + "prompt": "", + "negative_prompt": "", + "seed": None, + "steps": None, + "cfg_scale": None, + "sampler": None, + "checkpoint": None, + "loras": "", + "size": None, + "clip_skip": None + } + + # Get the prompt object for node relationship tracing + prompt = metadata.get("current_prompt") + + # Find the primary KSampler node + primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata) + + # Directly get checkpoint from metadata instead of tracing + checkpoint = MetadataProcessor.find_primary_checkpoint(metadata) + if checkpoint: + params["checkpoint"] = checkpoint + + if primary_sampler: + # Extract sampling parameters + sampling_params = primary_sampler.get("parameters", {}) + params["seed"] = sampling_params.get("seed") + params["steps"] = sampling_params.get("steps") + params["cfg_scale"] = sampling_params.get("cfg") + params["sampler"] = sampling_params.get("sampler_name") + + # Trace connections from the primary sampler + if prompt and primary_sampler_id: + # Trace positive prompt + positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive") + if positive_node_id and positive_node_id in metadata.get("prompts", {}): + params["prompt"] = metadata["prompts"][positive_node_id].get("text", "") + + # Trace negative prompt + negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative") + if negative_node_id and negative_node_id in metadata.get("prompts", {}): + params["negative_prompt"] = metadata["prompts"][negative_node_id].get("text", "") + + # Check if the sampler itself has size information (from latent_image) + if primary_sampler_id in metadata.get("size", {}): + width = metadata["size"][primary_sampler_id].get("width") + height = metadata["size"][primary_sampler_id].get("height") + if width and height: + params["size"] = f"{width}x{height}" + else: + # Fallback to the previous trace method if needed + latent_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "latent_image") + if latent_node_id: + # Follow chain to find EmptyLatentImage node + size_found = False + current_node_id = latent_node_id + + # Limit depth to avoid infinite loops in complex workflows + max_depth = 10 + for _ in range(max_depth): + if current_node_id in metadata.get("size", {}): + width = metadata["size"][current_node_id].get("width") + height = metadata["size"][current_node_id].get("height") + if width and height: + params["size"] = f"{width}x{height}" + size_found = True + break + + # Try to follow the chain + if prompt and prompt.original_prompt and current_node_id in prompt.original_prompt: + node_info = prompt.original_prompt[current_node_id] + if "inputs" in node_info: + # Look for a connection that might lead to size information + for input_name, input_value in node_info["inputs"].items(): + if isinstance(input_value, list) and len(input_value) >= 2: + current_node_id = input_value[0] + break + else: + break # No connections to follow + else: + break # No inputs to follow + else: + break # Can't follow further + + # Extract LoRAs + lora_parts = [] + for node_id, lora_info in metadata.get("loras", {}).items(): + name = lora_info.get("name", "unknown") + strength = lora_info.get("strength_model", 1.0) + lora_parts.append(f"") + params["loras"] = " ".join(lora_parts) + + # Set default clip_skip value + params["clip_skip"] = "1" # Common default + + return params + + @staticmethod + def to_comfyui_format(metadata): + """Convert extracted metadata to the ComfyUI output.json format""" + params = MetadataProcessor.extract_generation_params(metadata) + + # Convert all values to strings to match output.json format + for key in params: + if params[key] is not None: + params[key] = str(params[key]) + + return params + + @staticmethod + def to_json(metadata): + """Convert metadata to JSON string""" + params = MetadataProcessor.to_comfyui_format(metadata) + return json.dumps(params, indent=4) diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py new file mode 100644 index 00000000..cae9a1a9 --- /dev/null +++ b/py/metadata_collector/metadata_registry.py @@ -0,0 +1,171 @@ +import time +from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor + +class MetadataRegistry: + """A singleton registry to store and retrieve workflow metadata""" + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._reset() + return cls._instance + + def _reset(self): + self.current_prompt_id = None + self.current_prompt = None + self.metadata = {} + self.prompt_metadata = {} + self.executed_nodes = set() + + # Node-level cache for metadata + self.node_cache = {} + + # Categories we want to track and retrieve from cache + self.metadata_categories = ["models", "prompts", "sampling", "loras", "size"] + + def start_collection(self, prompt_id): + """Begin metadata collection for a new prompt""" + self.current_prompt_id = prompt_id + self.executed_nodes = set() + self.prompt_metadata[prompt_id] = { + "models": {}, + "prompts": {}, + "sampling": {}, + "loras": {}, + "size": {}, + "execution_order": [], + "current_prompt": None, # Will store the prompt object + "timestamp": time.time() + } + + def set_current_prompt(self, prompt): + """Set the current prompt object reference""" + self.current_prompt = prompt + if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata: + # Store the prompt in the metadata for later relationship tracing + self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt + + def get_metadata(self, prompt_id=None): + """Get collected metadata for a prompt""" + key = prompt_id if prompt_id is not None else self.current_prompt_id + if key not in self.prompt_metadata: + return {} + + metadata = self.prompt_metadata[key] + + # If we have a current prompt object, check for non-executed nodes + prompt_obj = metadata.get("current_prompt") + if prompt_obj and hasattr(prompt_obj, "original_prompt"): + original_prompt = prompt_obj.original_prompt + + # Fill in missing metadata from cache for nodes that weren't executed + self._fill_missing_metadata(key, original_prompt) + + return self.prompt_metadata.get(key, {}) + + def _fill_missing_metadata(self, prompt_id, original_prompt): + """Fill missing metadata from cache for non-executed nodes""" + if not original_prompt: + return + + executed_nodes = self.executed_nodes + metadata = self.prompt_metadata[prompt_id] + + # Iterate through nodes in the original prompt + for node_id, node_data in original_prompt.items(): + # Skip if already executed in this run + if node_id in executed_nodes: + continue + + class_type = node_data.get("class_type") + if not class_type: + continue + + # Create cache key + cache_key = f"{node_id}:{class_type}" + + # Check if this node type is relevant for metadata collection + if class_type in NODE_EXTRACTORS: + # Check if we have cached metadata for this node + if cache_key in self.node_cache: + cached_data = self.node_cache[cache_key] + + # Apply cached metadata to the current metadata + for category in self.metadata_categories: + if category in cached_data and node_id in cached_data[category]: + if node_id not in metadata[category]: + metadata[category][node_id] = cached_data[category][node_id] + + def record_node_execution(self, node_id, class_type, inputs, outputs): + """Record information about a node's execution""" + if not self.current_prompt_id: + return + + # Add to execution order and mark as executed + if node_id not in self.executed_nodes: + self.executed_nodes.add(node_id) + self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id) + + # Process inputs to simplify working with them + processed_inputs = {} + for input_name, input_values in inputs.items(): + if isinstance(input_values, list) and len(input_values) > 0: + # For single values, just use the first one (most common case) + processed_inputs[input_name] = input_values[0] + else: + processed_inputs[input_name] = input_values + + # Extract node-specific metadata + extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor) + extractor.extract( + node_id, + processed_inputs, + outputs, + self.prompt_metadata[self.current_prompt_id] + ) + + # Cache this node's metadata + self._cache_node_metadata(node_id, class_type) + + def update_node_execution(self, node_id, class_type, outputs): + """Update node metadata with output information""" + if not self.current_prompt_id: + return + + # Process outputs to make them more usable + processed_outputs = outputs + + # Use the same extractor to update with outputs + extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor) + if hasattr(extractor, 'update'): + extractor.update( + node_id, + processed_outputs, + self.prompt_metadata[self.current_prompt_id] + ) + + # Update the cached metadata for this node + self._cache_node_metadata(node_id, class_type) + + def _cache_node_metadata(self, node_id, class_type): + """Cache the metadata for a specific node""" + if not self.current_prompt_id or not node_id or not class_type: + return + + # Create a cache key combining node_id and class_type + cache_key = f"{node_id}:{class_type}" + + # Create a shallow copy of the node's metadata + node_metadata = {} + current_metadata = self.prompt_metadata[self.current_prompt_id] + + for category in self.metadata_categories: + if category in current_metadata and node_id in current_metadata[category]: + if category not in node_metadata: + node_metadata[category] = {} + node_metadata[category][node_id] = current_metadata[category][node_id] + + # Save to cache if we have any metadata for this node + if any(node_metadata.values()): + self.node_cache[cache_key] = node_metadata diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py new file mode 100644 index 00000000..78d0a81b --- /dev/null +++ b/py/metadata_collector/node_extractors.py @@ -0,0 +1,163 @@ +class NodeMetadataExtractor: + """Base class for node-specific metadata extraction""" + + @staticmethod + def extract(node_id, inputs, outputs, metadata): + """Extract metadata from node inputs/outputs""" + pass + + @staticmethod + def update(node_id, outputs, metadata): + """Update metadata with node outputs after execution""" + pass + +class GenericNodeExtractor(NodeMetadataExtractor): + """Default extractor for nodes without specific handling""" + @staticmethod + def extract(node_id, inputs, outputs, metadata): + pass + +class CheckpointLoaderExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "ckpt_name" not in inputs: + return + + model_name = inputs.get("ckpt_name") + if model_name: + metadata["models"][node_id] = { + "name": model_name, + "type": "checkpoint", + "node_id": node_id + } + +class CLIPTextEncodeExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "text" not in inputs: + return + + text = inputs.get("text", "") + metadata["prompts"][node_id] = { + "text": text, + "node_id": node_id + } + +class SamplerExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + sampling_params = {} + for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]: + if key in inputs: + sampling_params[key] = inputs[key] + + metadata["sampling"][node_id] = { + "parameters": sampling_params, + "node_id": node_id + } + + # Extract latent image dimensions if available + if "latent_image" in inputs and inputs["latent_image"] is not None: + latent = inputs["latent_image"] + if isinstance(latent, dict) and "samples" in latent: + # Extract dimensions from latent tensor + samples = latent["samples"] + if hasattr(samples, "shape") and len(samples.shape) >= 3: + # Correct shape interpretation: [batch_size, channels, height/8, width/8] + # Multiply by 8 to get actual pixel dimensions + height = int(samples.shape[2] * 8) + width = int(samples.shape[3] * 8) + + if "size" not in metadata: + metadata["size"] = {} + + metadata["size"][node_id] = { + "width": width, + "height": height, + "node_id": node_id + } + +class LoraLoaderExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "lora_name" not in inputs: + return + + lora_name = inputs.get("lora_name") + strength_model = inputs.get("strength_model", 1.0) + strength_clip = inputs.get("strength_clip", 1.0) + + metadata["loras"][node_id] = { + "name": lora_name, + "strength_model": strength_model, + "strength_clip": strength_clip, + "node_id": node_id + } + +class ImageSizeExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + width = inputs.get("width", 512) + height = inputs.get("height", 512) + + if "size" not in metadata: + metadata["size"] = {} + + metadata["size"][node_id] = { + "width": width, + "height": height, + "node_id": node_id + } + +class LoraLoaderManagerExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + # Handle LoraManager nodes which might store loras differently + if "loras" in inputs: + loras = inputs.get("loras", []) + if isinstance(loras, list): + active_loras = [] + # Filter for active loras (may be a list of dicts with 'active' flag) + for lora in loras: + if isinstance(lora, dict) and lora.get("active", True) and not lora.get("_isDummy", False): + active_loras.append({ + "name": lora.get("name", ""), + "strength": lora.get("strength", 1.0) + }) + + if active_loras: + metadata["loras"][node_id] = { + "lora_list": active_loras, + "node_id": node_id + } + + # If there's a direct text field with lora definitions + if "text" in inputs: + text = inputs.get("text", "") + if text and "