diff --git a/__init__.py b/__init__.py index 008219c8..375dfc13 100644 --- a/__init__.py +++ b/__init__.py @@ -3,16 +3,23 @@ from .py.nodes.lora_loader import LoraManagerLoader from .py.nodes.trigger_word_toggle import TriggerWordToggle from .py.nodes.lora_stacker import LoraStacker from .py.nodes.save_image import SaveImage +from .py.nodes.debug_metadata import DebugMetadata +# Import metadata collector to install hooks on startup +from .py.metadata_collector import init as init_metadata_collector NODE_CLASS_MAPPINGS = { LoraManagerLoader.NAME: LoraManagerLoader, TriggerWordToggle.NAME: TriggerWordToggle, LoraStacker.NAME: LoraStacker, - SaveImage.NAME: SaveImage + SaveImage.NAME: SaveImage, + DebugMetadata.NAME: DebugMetadata } WEB_DIRECTORY = "./web/comfyui" +# Initialize metadata collector +init_metadata_collector() + # Register routes on import LoraManager.add_routes() __all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY'] diff --git a/py/metadata_collector/__init__.py b/py/metadata_collector/__init__.py new file mode 100644 index 00000000..3fea3a6b --- /dev/null +++ b/py/metadata_collector/__init__.py @@ -0,0 +1,18 @@ +import os +import importlib +from .metadata_hook import MetadataHook +from .metadata_registry import MetadataRegistry + +def init(): + # Install hooks to collect metadata during execution + MetadataHook.install() + + # Initialize registry + registry = MetadataRegistry() + + print("ComfyUI Metadata Collector initialized") + +def get_metadata(prompt_id=None): + """Helper function to get metadata from the registry""" + registry = MetadataRegistry() + return registry.get_metadata(prompt_id) diff --git a/py/metadata_collector/constants.py b/py/metadata_collector/constants.py new file mode 100644 index 00000000..c1109580 --- /dev/null +++ b/py/metadata_collector/constants.py @@ -0,0 +1,12 @@ +"""Constants used by the metadata collector""" + +# Individual category constants +MODELS = "models" +PROMPTS = "prompts" +SAMPLING = "sampling" +LORAS = "loras" +SIZE = "size" +IMAGES = "images" # Added new category for image results + +# Collection of categories for iteration +METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES] # Added IMAGES to categories diff --git a/py/metadata_collector/metadata_hook.py b/py/metadata_collector/metadata_hook.py new file mode 100644 index 00000000..2b6a7b6d --- /dev/null +++ b/py/metadata_collector/metadata_hook.py @@ -0,0 +1,123 @@ +import sys +import inspect +from .metadata_registry import MetadataRegistry + +class MetadataHook: + """Install hooks for metadata collection""" + + @staticmethod + def install(): + """Install hooks to collect metadata during execution""" + try: + # Import ComfyUI's execution module + execution = None + try: + # Try direct import first + import execution # type: ignore + except ImportError: + # Try to locate from system modules + for module_name in sys.modules: + if module_name.endswith('.execution'): + execution = sys.modules[module_name] + break + + # If we can't find the execution module, we can't install hooks + if execution is None: + print("Could not locate ComfyUI execution module, metadata collection disabled") + return + + # Store the original _map_node_over_list function + original_map_node_over_list = execution._map_node_over_list + + # Define the wrapped _map_node_over_list function + def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): + # Only collect metadata when calling the main function of nodes + if func == obj.FUNCTION and hasattr(obj, '__class__'): + try: + # Get the current prompt_id from the registry + registry = MetadataRegistry() + prompt_id = registry.current_prompt_id + + if prompt_id is not None: + # Get node class type + class_type = obj.__class__.__name__ + + # Unique ID might be available through the obj if it has a unique_id field + node_id = getattr(obj, 'unique_id', None) + if node_id is None and pre_execute_cb: + # Try to extract node_id through reflection on GraphBuilder.set_default_prefix + frame = inspect.currentframe() + while frame: + if 'unique_id' in frame.f_locals: + node_id = frame.f_locals['unique_id'] + break + frame = frame.f_back + + # Record inputs before execution + if node_id is not None: + registry.record_node_execution(node_id, class_type, input_data_all, None) + except Exception as e: + print(f"Error collecting metadata (pre-execution): {str(e)}") + + # Execute the original function + results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb) + + # After execution, collect outputs for relevant nodes + if func == obj.FUNCTION and hasattr(obj, '__class__'): + try: + # Get the current prompt_id from the registry + registry = MetadataRegistry() + prompt_id = registry.current_prompt_id + + if prompt_id is not None: + # Get node class type + class_type = obj.__class__.__name__ + + # Unique ID might be available through the obj if it has a unique_id field + node_id = getattr(obj, 'unique_id', None) + if node_id is None and pre_execute_cb: + # Try to extract node_id through reflection + frame = inspect.currentframe() + while frame: + if 'unique_id' in frame.f_locals: + node_id = frame.f_locals['unique_id'] + break + frame = frame.f_back + + # Record outputs after execution + if node_id is not None: + registry.update_node_execution(node_id, class_type, results) + except Exception as e: + print(f"Error collecting metadata (post-execution): {str(e)}") + + return results + + # Also hook the execute function to track the current prompt_id + original_execute = execution.execute + + def execute_with_prompt_tracking(*args, **kwargs): + if len(args) >= 7: # Check if we have enough arguments + server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7] + registry = MetadataRegistry() + + # Start collection if this is a new prompt + if not registry.current_prompt_id or registry.current_prompt_id != prompt_id: + registry.start_collection(prompt_id) + + # Store the dynprompt reference for node lookups + if hasattr(prompt, 'original_prompt'): + registry.set_current_prompt(prompt) + + # Execute the original function + return original_execute(*args, **kwargs) + + # Replace the functions + execution._map_node_over_list = map_node_over_list_with_metadata + execution.execute = execute_with_prompt_tracking + # Make map_node_over_list public to avoid it being hidden by hooks + execution.map_node_over_list = original_map_node_over_list + + print("Metadata collection hooks installed for runtime values") + + except Exception as e: + print(f"Error installing metadata hooks: {str(e)}") diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py new file mode 100644 index 00000000..4cf72b73 --- /dev/null +++ b/py/metadata_collector/metadata_processor.py @@ -0,0 +1,245 @@ +import json + +from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE + +class MetadataProcessor: + """Process and format collected metadata""" + + @staticmethod + def find_primary_sampler(metadata): + """Find the primary KSampler node (with denoise=1)""" + primary_sampler = None + primary_sampler_id = None + + # First, check for KSamplerAdvanced with add_noise="enable" + for node_id, sampler_info in metadata.get(SAMPLING, {}).items(): + parameters = sampler_info.get("parameters", {}) + add_noise = parameters.get("add_noise") + + # If add_noise is "enable", this is likely the primary sampler for KSamplerAdvanced + if add_noise == "enable": + primary_sampler = sampler_info + primary_sampler_id = node_id + break + + # If no KSamplerAdvanced found, fall back to traditional KSampler with denoise=1 + if primary_sampler is None: + for node_id, sampler_info in metadata.get(SAMPLING, {}).items(): + parameters = sampler_info.get("parameters", {}) + denoise = parameters.get("denoise") + + # If denoise is 1.0, this is likely the primary sampler + if denoise == 1.0 or denoise == 1: + primary_sampler = sampler_info + primary_sampler_id = node_id + break + + return primary_sampler_id, primary_sampler + + @staticmethod + def trace_node_input(prompt, node_id, input_name, target_class=None, max_depth=10): + """ + Trace an input connection from a node to find the source node + + Parameters: + - prompt: The prompt object containing node connections + - node_id: ID of the starting node + - input_name: Name of the input to trace + - target_class: Optional class name to search for (e.g., "CLIPTextEncode") + - max_depth: Maximum depth to follow the node chain to prevent infinite loops + + Returns: + - node_id of the found node, or None if not found + """ + if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt: + return None + + # For depth tracking + current_depth = 0 + + current_node_id = node_id + current_input = input_name + + while current_depth < max_depth: + if current_node_id not in prompt.original_prompt: + return None + + node_inputs = prompt.original_prompt[current_node_id].get("inputs", {}) + if current_input not in node_inputs: + return None + + input_value = node_inputs[current_input] + # Input connections are formatted as [node_id, output_index] + if isinstance(input_value, list) and len(input_value) >= 2: + found_node_id = input_value[0] # Connected node_id + + # If we're looking for a specific node class + if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class: + return found_node_id + + # If we're not looking for a specific class or haven't found it yet + if not target_class: + return found_node_id + + # Continue tracing through intermediate nodes + current_node_id = found_node_id + # For most conditioning nodes, the input we want to follow is named "conditioning" + if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}): + current_input = "conditioning" + else: + # If there's no "conditioning" input, we can't trace further + return found_node_id if not target_class else None + else: + # We've reached a node with no further connections + return None + + current_depth += 1 + + # If we've reached max depth without finding target_class + return None + + @staticmethod + def find_primary_checkpoint(metadata): + """Find the primary checkpoint model in the workflow""" + if not metadata.get(MODELS): + return None + + # In most workflows, there's only one checkpoint, so we can just take the first one + for node_id, model_info in metadata.get(MODELS, {}).items(): + if model_info.get("type") == "checkpoint": + return model_info.get("name") + + return None + + @staticmethod + def extract_generation_params(metadata): + """Extract generation parameters from metadata using node relationships""" + params = { + "prompt": "", + "negative_prompt": "", + "seed": None, + "steps": None, + "cfg_scale": None, + "guidance": None, # Add guidance parameter + "sampler": None, + "scheduler": None, + "checkpoint": None, + "loras": "", + "size": None, + "clip_skip": None + } + + # Get the prompt object for node relationship tracing + prompt = metadata.get("current_prompt") + + # Find the primary KSampler node + primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata) + + # Directly get checkpoint from metadata instead of tracing + checkpoint = MetadataProcessor.find_primary_checkpoint(metadata) + if checkpoint: + params["checkpoint"] = checkpoint + + if primary_sampler: + # Extract sampling parameters + sampling_params = primary_sampler.get("parameters", {}) + # Handle both seed and noise_seed + params["seed"] = sampling_params.get("seed") if sampling_params.get("seed") is not None else sampling_params.get("noise_seed") + params["steps"] = sampling_params.get("steps") + params["cfg_scale"] = sampling_params.get("cfg") + params["sampler"] = sampling_params.get("sampler_name") + params["scheduler"] = sampling_params.get("scheduler") + + # Trace connections from the primary sampler + if prompt and primary_sampler_id: + # Trace positive prompt - look specifically for CLIPTextEncode + positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10) + if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}): + params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") + + # Find any FluxGuidance nodes in the positive conditioning path + flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5) + if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}): + flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {}) + params["guidance"] = flux_params.get("guidance") + + # Trace negative prompt - look specifically for CLIPTextEncode + negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10) + if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}): + params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "") + + # Check if the sampler itself has size information (from latent_image) + if primary_sampler_id in metadata.get(SIZE, {}): + width = metadata[SIZE][primary_sampler_id].get("width") + height = metadata[SIZE][primary_sampler_id].get("height") + if width and height: + params["size"] = f"{width}x{height}" + else: + # Fallback to the previous trace method if needed + latent_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "latent_image") + if latent_node_id: + # Follow chain to find EmptyLatentImage node + size_found = False + current_node_id = latent_node_id + + # Limit depth to avoid infinite loops in complex workflows + max_depth = 10 + for _ in range(max_depth): + if current_node_id in metadata.get(SIZE, {}): + width = metadata[SIZE][current_node_id].get("width") + height = metadata[SIZE][current_node_id].get("height") + if width and height: + params["size"] = f"{width}x{height}" + size_found = True + break + + # Try to follow the chain + if prompt and prompt.original_prompt and current_node_id in prompt.original_prompt: + node_info = prompt.original_prompt[current_node_id] + if "inputs" in node_info: + # Look for a connection that might lead to size information + for input_name, input_value in node_info["inputs"].items(): + if isinstance(input_value, list) and len(input_value) >= 2: + current_node_id = input_value[0] + break + else: + break # No connections to follow + else: + break # No inputs to follow + else: + break # Can't follow further + + # Extract LoRAs using the standardized format + lora_parts = [] + for node_id, lora_info in metadata.get(LORAS, {}).items(): + # Access the lora_list from the standardized format + lora_list = lora_info.get("lora_list", []) + for lora in lora_list: + name = lora.get("name", "unknown") + strength = lora.get("strength", 1.0) + lora_parts.append(f"") + + params["loras"] = " ".join(lora_parts) + + # Set default clip_skip value + params["clip_skip"] = "1" # Common default + + return params + + @staticmethod + def to_dict(metadata): + """Convert extracted metadata to the ComfyUI output.json format""" + params = MetadataProcessor.extract_generation_params(metadata) + + # Convert all values to strings to match output.json format + for key in params: + if params[key] is not None: + params[key] = str(params[key]) + + return params + + @staticmethod + def to_json(metadata): + """Convert metadata to JSON string""" + params = MetadataProcessor.to_dict(metadata) + return json.dumps(params, indent=4) diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py new file mode 100644 index 00000000..6e33e806 --- /dev/null +++ b/py/metadata_collector/metadata_registry.py @@ -0,0 +1,275 @@ +import time +from nodes import NODE_CLASS_MAPPINGS +from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor +from .constants import METADATA_CATEGORIES, IMAGES + +class MetadataRegistry: + """A singleton registry to store and retrieve workflow metadata""" + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._reset() + return cls._instance + + def _reset(self): + self.current_prompt_id = None + self.current_prompt = None + self.metadata = {} + self.prompt_metadata = {} + self.executed_nodes = set() + + # Node-level cache for metadata + self.node_cache = {} + + # Limit the number of stored prompts + self.max_prompt_history = 3 + + # Categories we want to track and retrieve from cache + self.metadata_categories = METADATA_CATEGORIES + + def _clean_old_prompts(self): + """Clean up old prompt metadata, keeping only recent ones""" + if len(self.prompt_metadata) <= self.max_prompt_history: + return + + # Sort all prompt_ids by timestamp + sorted_prompts = sorted( + self.prompt_metadata.keys(), + key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0) + ) + + # Remove oldest records + prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history] + for pid in prompts_to_remove: + del self.prompt_metadata[pid] + + def start_collection(self, prompt_id): + """Begin metadata collection for a new prompt""" + self.current_prompt_id = prompt_id + self.executed_nodes = set() + self.prompt_metadata[prompt_id] = { + category: {} for category in METADATA_CATEGORIES + } + # Add additional metadata fields + self.prompt_metadata[prompt_id].update({ + "execution_order": [], + "current_prompt": None, # Will store the prompt object + "timestamp": time.time() + }) + + # Clean up old prompt data + self._clean_old_prompts() + + def set_current_prompt(self, prompt): + """Set the current prompt object reference""" + self.current_prompt = prompt + if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata: + # Store the prompt in the metadata for later relationship tracing + self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt + + def get_metadata(self, prompt_id=None): + """Get collected metadata for a prompt""" + key = prompt_id if prompt_id is not None else self.current_prompt_id + if key not in self.prompt_metadata: + return {} + + metadata = self.prompt_metadata[key] + + # If we have a current prompt object, check for non-executed nodes + prompt_obj = metadata.get("current_prompt") + if prompt_obj and hasattr(prompt_obj, "original_prompt"): + original_prompt = prompt_obj.original_prompt + + # Fill in missing metadata from cache for nodes that weren't executed + self._fill_missing_metadata(key, original_prompt) + + return self.prompt_metadata.get(key, {}) + + def _fill_missing_metadata(self, prompt_id, original_prompt): + """Fill missing metadata from cache for non-executed nodes""" + if not original_prompt: + return + + executed_nodes = self.executed_nodes + metadata = self.prompt_metadata[prompt_id] + + # Iterate through nodes in the original prompt + for node_id, node_data in original_prompt.items(): + # Skip if already executed in this run + if node_id in executed_nodes: + continue + + # Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS) + prompt_class_type = node_data.get("class_type") + if not prompt_class_type: + continue + + # Convert to actual class name (which is what we use in our cache) + class_type = prompt_class_type + if prompt_class_type in NODE_CLASS_MAPPINGS: + class_obj = NODE_CLASS_MAPPINGS[prompt_class_type] + class_type = class_obj.__name__ + + # Create cache key using the actual class name + cache_key = f"{node_id}:{class_type}" + + # Check if this node type is relevant for metadata collection + if class_type in NODE_EXTRACTORS: + # Check if we have cached metadata for this node + if cache_key in self.node_cache: + cached_data = self.node_cache[cache_key] + + # Apply cached metadata to the current metadata + for category in self.metadata_categories: + if category in cached_data and node_id in cached_data[category]: + if node_id not in metadata[category]: + metadata[category][node_id] = cached_data[category][node_id] + + def record_node_execution(self, node_id, class_type, inputs, outputs): + """Record information about a node's execution""" + if not self.current_prompt_id: + return + + # Add to execution order and mark as executed + if node_id not in self.executed_nodes: + self.executed_nodes.add(node_id) + self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id) + + # Process inputs to simplify working with them + processed_inputs = {} + for input_name, input_values in inputs.items(): + if isinstance(input_values, list) and len(input_values) > 0: + # For single values, just use the first one (most common case) + processed_inputs[input_name] = input_values[0] + else: + processed_inputs[input_name] = input_values + + # Extract node-specific metadata + extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor) + extractor.extract( + node_id, + processed_inputs, + outputs, + self.prompt_metadata[self.current_prompt_id] + ) + + # Cache this node's metadata + self._cache_node_metadata(node_id, class_type) + + def update_node_execution(self, node_id, class_type, outputs): + """Update node metadata with output information""" + if not self.current_prompt_id: + return + + # Process outputs to make them more usable + processed_outputs = outputs + + # Use the same extractor to update with outputs + extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor) + if hasattr(extractor, 'update'): + extractor.update( + node_id, + processed_outputs, + self.prompt_metadata[self.current_prompt_id] + ) + + # Update the cached metadata for this node + self._cache_node_metadata(node_id, class_type) + + def _cache_node_metadata(self, node_id, class_type): + """Cache the metadata for a specific node""" + if not self.current_prompt_id or not node_id or not class_type: + return + + # Create a cache key combining node_id and class_type + cache_key = f"{node_id}:{class_type}" + + # Create a shallow copy of the node's metadata + node_metadata = {} + current_metadata = self.prompt_metadata[self.current_prompt_id] + + for category in self.metadata_categories: + if category in current_metadata and node_id in current_metadata[category]: + if category not in node_metadata: + node_metadata[category] = {} + node_metadata[category][node_id] = current_metadata[category][node_id] + + # Save to cache if we have any metadata for this node + if any(node_metadata.values()): + self.node_cache[cache_key] = node_metadata + + def clear_unused_cache(self): + """Clean up node_cache entries that are no longer in use""" + # Collect all node_ids currently in prompt_metadata + active_node_ids = set() + for prompt_data in self.prompt_metadata.values(): + for category in self.metadata_categories: + if category in prompt_data: + active_node_ids.update(prompt_data[category].keys()) + + # Find cache keys that are no longer needed + keys_to_remove = [] + for cache_key in self.node_cache: + node_id = cache_key.split(':')[0] + if node_id not in active_node_ids: + keys_to_remove.append(cache_key) + + # Remove cache entries that are no longer needed + for key in keys_to_remove: + del self.node_cache[key] + + def clear_metadata(self, prompt_id=None): + """Clear metadata for a specific prompt or reset all data""" + if prompt_id is not None: + if prompt_id in self.prompt_metadata: + del self.prompt_metadata[prompt_id] + # Clean up cache after removing prompt + self.clear_unused_cache() + else: + # Reset all data + self._reset() + + def get_first_decoded_image(self, prompt_id=None): + """Get the first decoded image result""" + key = prompt_id if prompt_id is not None else self.current_prompt_id + if key not in self.prompt_metadata: + return None + + metadata = self.prompt_metadata[key] + if IMAGES in metadata and "first_decode" in metadata[IMAGES]: + image_data = metadata[IMAGES]["first_decode"]["image"] + + # If it's an image batch or tuple, handle various formats + if isinstance(image_data, (list, tuple)) and len(image_data) > 0: + # Return first element of list/tuple + return image_data[0] + + # If it's a tensor, return as is for processing in the route handler + return image_data + + # If no image is found in the current metadata, try to find it in the cache + # This handles the case where VAEDecode was cached by ComfyUI and not executed + prompt_obj = metadata.get("current_prompt") + if prompt_obj and hasattr(prompt_obj, "original_prompt"): + original_prompt = prompt_obj.original_prompt + for node_id, node_data in original_prompt.items(): + class_type = node_data.get("class_type") + if class_type and class_type in NODE_CLASS_MAPPINGS: + class_obj = NODE_CLASS_MAPPINGS[class_type] + class_name = class_obj.__name__ + # Check if this is a VAEDecode node + if class_name == "VAEDecode": + # Try to find this node in the cache + cache_key = f"{node_id}:{class_name}" + if cache_key in self.node_cache: + cached_data = self.node_cache[cache_key] + if IMAGES in cached_data and node_id in cached_data[IMAGES]: + image_data = cached_data[IMAGES][node_id]["image"] + # Handle different image formats + if isinstance(image_data, (list, tuple)) and len(image_data) > 0: + return image_data[0] + return image_data + + return None diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py new file mode 100644 index 00000000..64dda557 --- /dev/null +++ b/py/metadata_collector/node_extractors.py @@ -0,0 +1,280 @@ +import os + +from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES + + +class NodeMetadataExtractor: + """Base class for node-specific metadata extraction""" + + @staticmethod + def extract(node_id, inputs, outputs, metadata): + """Extract metadata from node inputs/outputs""" + pass + + @staticmethod + def update(node_id, outputs, metadata): + """Update metadata with node outputs after execution""" + pass + +class GenericNodeExtractor(NodeMetadataExtractor): + """Default extractor for nodes without specific handling""" + @staticmethod + def extract(node_id, inputs, outputs, metadata): + pass + +class CheckpointLoaderExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "ckpt_name" not in inputs: + return + + model_name = inputs.get("ckpt_name") + if model_name: + metadata[MODELS][node_id] = { + "name": model_name, + "type": "checkpoint", + "node_id": node_id + } + +class CLIPTextEncodeExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "text" not in inputs: + return + + text = inputs.get("text", "") + metadata[PROMPTS][node_id] = { + "text": text, + "node_id": node_id + } + +class SamplerExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + sampling_params = {} + for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]: + if key in inputs: + sampling_params[key] = inputs[key] + + metadata[SAMPLING][node_id] = { + "parameters": sampling_params, + "node_id": node_id + } + + # Extract latent image dimensions if available + if "latent_image" in inputs and inputs["latent_image"] is not None: + latent = inputs["latent_image"] + if isinstance(latent, dict) and "samples" in latent: + # Extract dimensions from latent tensor + samples = latent["samples"] + if hasattr(samples, "shape") and len(samples.shape) >= 3: + # Correct shape interpretation: [batch_size, channels, height/8, width/8] + # Multiply by 8 to get actual pixel dimensions + height = int(samples.shape[2] * 8) + width = int(samples.shape[3] * 8) + + if SIZE not in metadata: + metadata[SIZE] = {} + + metadata[SIZE][node_id] = { + "width": width, + "height": height, + "node_id": node_id + } + +class KSamplerAdvancedExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + sampling_params = {} + for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]: + if key in inputs: + sampling_params[key] = inputs[key] + + metadata[SAMPLING][node_id] = { + "parameters": sampling_params, + "node_id": node_id + } + + # Extract latent image dimensions if available + if "latent_image" in inputs and inputs["latent_image"] is not None: + latent = inputs["latent_image"] + if isinstance(latent, dict) and "samples" in latent: + # Extract dimensions from latent tensor + samples = latent["samples"] + if hasattr(samples, "shape") and len(samples.shape) >= 3: + # Correct shape interpretation: [batch_size, channels, height/8, width/8] + # Multiply by 8 to get actual pixel dimensions + height = int(samples.shape[2] * 8) + width = int(samples.shape[3] * 8) + + if SIZE not in metadata: + metadata[SIZE] = {} + + metadata[SIZE][node_id] = { + "width": width, + "height": height, + "node_id": node_id + } + +class LoraLoaderExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "lora_name" not in inputs: + return + + lora_name = inputs.get("lora_name") + # Extract base filename without extension from path + lora_name = os.path.splitext(os.path.basename(lora_name))[0] + strength_model = round(float(inputs.get("strength_model", 1.0)), 2) + + # Use the standardized format with lora_list + metadata[LORAS][node_id] = { + "lora_list": [ + { + "name": lora_name, + "strength": strength_model + } + ], + "node_id": node_id + } + +class ImageSizeExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + width = inputs.get("width", 512) + height = inputs.get("height", 512) + + if SIZE not in metadata: + metadata[SIZE] = {} + + metadata[SIZE][node_id] = { + "width": width, + "height": height, + "node_id": node_id + } + +class LoraLoaderManagerExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs: + return + + active_loras = [] + + # Process lora_stack if available + if "lora_stack" in inputs: + lora_stack = inputs.get("lora_stack", []) + for lora_path, model_strength, clip_strength in lora_stack: + # Extract lora name from path (following the format in lora_loader.py) + lora_name = os.path.splitext(os.path.basename(lora_path))[0] + active_loras.append({ + "name": lora_name, + "strength": model_strength + }) + + # Process loras from inputs + if "loras" in inputs: + loras_data = inputs.get("loras", []) + + # Handle new format: {'loras': {'__value__': [...]}} + if isinstance(loras_data, dict) and '__value__' in loras_data: + loras_list = loras_data['__value__'] + # Handle old format: {'loras': [...]} + elif isinstance(loras_data, list): + loras_list = loras_data + else: + loras_list = [] + + # Filter for active loras + for lora in loras_list: + if isinstance(lora, dict) and lora.get("active", True) and not lora.get("_isDummy", False): + active_loras.append({ + "name": lora.get("name", ""), + "strength": float(lora.get("strength", 1.0)) + }) + + if active_loras: + metadata[LORAS][node_id] = { + "lora_list": active_loras, + "node_id": node_id + } + +class FluxGuidanceExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "guidance" not in inputs: + return + + guidance_value = inputs.get("guidance") + + # Store the guidance value in SAMPLING category + if node_id not in metadata[SAMPLING]: + metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id} + + metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value + +class UNETLoaderExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "unet_name" not in inputs: + return + + model_name = inputs.get("unet_name") + if model_name: + metadata[MODELS][node_id] = { + "name": model_name, + "type": "checkpoint", + "node_id": node_id + } + +class VAEDecodeExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + pass + + @staticmethod + def update(node_id, outputs, metadata): + # Ensure IMAGES category exists + if IMAGES not in metadata: + metadata[IMAGES] = {} + + # Save image data under node ID index to be captured by caching mechanism + metadata[IMAGES][node_id] = { + "node_id": node_id, + "image": outputs + } + + # Only set first_decode if it hasn't been recorded yet + if "first_decode" not in metadata[IMAGES]: + metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id] + +# Registry of node-specific extractors +NODE_EXTRACTORS = { + # Sampling + "KSampler": SamplerExtractor, + "KSamplerAdvanced": KSamplerAdvancedExtractor, # Add KSamplerAdvanced + "SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced + # Loaders + "CheckpointLoaderSimple": CheckpointLoaderExtractor, + "UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor + "LoraLoader": LoraLoaderExtractor, + "LoraManagerLoader": LoraLoaderManagerExtractor, + # Conditioning + "CLIPTextEncode": CLIPTextEncodeExtractor, + # Latent + "EmptyLatentImage": ImageSizeExtractor, + # Flux + "FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance + # Image + "VAEDecode": VAEDecodeExtractor, # Added VAEDecode extractor + # Add other nodes as needed +} diff --git a/py/nodes/debug_metadata.py b/py/nodes/debug_metadata.py new file mode 100644 index 00000000..ee13e3d8 --- /dev/null +++ b/py/nodes/debug_metadata.py @@ -0,0 +1,35 @@ +import logging +from ..metadata_collector.metadata_processor import MetadataProcessor + +logger = logging.getLogger(__name__) + +class DebugMetadata: + NAME = "Debug Metadata (LoraManager)" + CATEGORY = "Lora Manager/utils" + DESCRIPTION = "Debug node to verify metadata_processor functionality" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ("IMAGE",), + }, + } + + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("metadata_json",) + FUNCTION = "process_metadata" + + def process_metadata(self, images): + try: + # Get the current execution context's metadata + from ..metadata_collector import get_metadata + metadata = get_metadata() + + # Use the MetadataProcessor to convert it to JSON string + metadata_json = MetadataProcessor.to_json(metadata) + + return (metadata_json,) + except Exception as e: + logger.error(f"Error processing metadata: {e}") + return ("{}",) # Return empty JSON object in case of error diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index 31c6695e..003092db 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -5,10 +5,10 @@ import re import numpy as np import folder_paths # type: ignore from ..services.lora_scanner import LoraScanner -from ..workflow.parser import WorkflowParser +from ..metadata_collector.metadata_processor import MetadataProcessor +from ..metadata_collector import get_metadata from PIL import Image, PngImagePlugin import piexif -from io import BytesIO class SaveImage: NAME = "Save Image (LoraManager)" @@ -34,8 +34,7 @@ class SaveImage: "file_format": (["png", "jpeg", "webp"],), }, "optional": { - "custom_prompt": ("STRING", {"default": "", "forceInput": True}), - "lossless_webp": ("BOOLEAN", {"default": True}), + "lossless_webp": ("BOOLEAN", {"default": False}), "quality": ("INT", {"default": 100, "min": 1, "max": 100}), "embed_workflow": ("BOOLEAN", {"default": False}), "add_counter_to_filename": ("BOOLEAN", {"default": True}), @@ -61,21 +60,17 @@ class SaveImage: return item.get('sha256') return None - async def format_metadata(self, parsed_workflow, custom_prompt=None): + async def format_metadata(self, metadata_dict): """Format metadata in the requested format similar to userComment example""" - if not parsed_workflow: + if not metadata_dict: return "" # Extract the prompt and negative prompt - prompt = parsed_workflow.get('prompt', '') - negative_prompt = parsed_workflow.get('negative_prompt', '') - - # Override prompt with custom_prompt if provided - if custom_prompt: - prompt = custom_prompt + prompt = metadata_dict.get('prompt', '') + negative_prompt = metadata_dict.get('negative_prompt', '') # Extract loras from the prompt if present - loras_text = parsed_workflow.get('loras', '') + loras_text = metadata_dict.get('loras', '') lora_hashes = {} # If loras are found, add them on a new line after the prompt @@ -104,11 +99,11 @@ class SaveImage: params = [] # Add standard parameters in the correct order - if 'steps' in parsed_workflow: - params.append(f"Steps: {parsed_workflow.get('steps')}") + if 'steps' in metadata_dict: + params.append(f"Steps: {metadata_dict.get('steps')}") - if 'sampler' in parsed_workflow: - sampler = parsed_workflow.get('sampler') + if 'sampler' in metadata_dict: + sampler = metadata_dict.get('sampler') # Convert ComfyUI sampler names to user-friendly names sampler_mapping = { 'euler': 'Euler', @@ -130,8 +125,8 @@ class SaveImage: sampler_name = sampler_mapping.get(sampler, sampler) params.append(f"Sampler: {sampler_name}") - if 'scheduler' in parsed_workflow: - scheduler = parsed_workflow.get('scheduler') + if 'scheduler' in metadata_dict: + scheduler = metadata_dict.get('scheduler') scheduler_mapping = { 'normal': 'Simple', 'karras': 'Karras', @@ -142,27 +137,36 @@ class SaveImage: scheduler_name = scheduler_mapping.get(scheduler, scheduler) params.append(f"Schedule type: {scheduler_name}") - # CFG scale (cfg in parsed_workflow) - if 'cfg_scale' in parsed_workflow: - params.append(f"CFG scale: {parsed_workflow.get('cfg_scale')}") - elif 'cfg' in parsed_workflow: - params.append(f"CFG scale: {parsed_workflow.get('cfg')}") + # CFG scale (cfg_scale in metadata_dict) + if 'cfg_scale' in metadata_dict: + params.append(f"CFG scale: {metadata_dict.get('cfg_scale')}") + elif 'cfg' in metadata_dict: + params.append(f"CFG scale: {metadata_dict.get('cfg')}") # Seed - if 'seed' in parsed_workflow: - params.append(f"Seed: {parsed_workflow.get('seed')}") + if 'seed' in metadata_dict: + params.append(f"Seed: {metadata_dict.get('seed')}") # Size - if 'size' in parsed_workflow: - params.append(f"Size: {parsed_workflow.get('size')}") + if 'size' in metadata_dict: + params.append(f"Size: {metadata_dict.get('size')}") # Model info - if 'checkpoint' in parsed_workflow: - # Extract basename without path - checkpoint = os.path.basename(parsed_workflow.get('checkpoint', '')) - # Remove extension if present - checkpoint = os.path.splitext(checkpoint)[0] - params.append(f"Model: {checkpoint}") + if 'checkpoint' in metadata_dict: + # Ensure checkpoint is a string before processing + checkpoint = metadata_dict.get('checkpoint') + if checkpoint is not None: + # Handle both string and other types safely + if isinstance(checkpoint, str): + # Extract basename without path + checkpoint = os.path.basename(checkpoint) + # Remove extension if present + checkpoint = os.path.splitext(checkpoint)[0] + else: + # Convert non-string to string + checkpoint = str(checkpoint) + + params.append(f"Model: {checkpoint}") # Add LoRA hashes if available if lora_hashes: @@ -181,9 +185,9 @@ class SaveImage: # credit to nkchocoai # Add format_filename method to handle pattern substitution - def format_filename(self, filename, parsed_workflow): + def format_filename(self, filename, metadata_dict): """Format filename with metadata values""" - if not parsed_workflow: + if not metadata_dict: return filename result = re.findall(self.pattern_format, filename) @@ -191,30 +195,30 @@ class SaveImage: parts = segment.replace("%", "").split(":") key = parts[0] - if key == "seed" and 'seed' in parsed_workflow: - filename = filename.replace(segment, str(parsed_workflow.get('seed', ''))) - elif key == "width" and 'size' in parsed_workflow: - size = parsed_workflow.get('size', 'x') + if key == "seed" and 'seed' in metadata_dict: + filename = filename.replace(segment, str(metadata_dict.get('seed', ''))) + elif key == "width" and 'size' in metadata_dict: + size = metadata_dict.get('size', 'x') w = size.split('x')[0] if isinstance(size, str) else size[0] filename = filename.replace(segment, str(w)) - elif key == "height" and 'size' in parsed_workflow: - size = parsed_workflow.get('size', 'x') + elif key == "height" and 'size' in metadata_dict: + size = metadata_dict.get('size', 'x') h = size.split('x')[1] if isinstance(size, str) else size[1] filename = filename.replace(segment, str(h)) - elif key == "pprompt" and 'prompt' in parsed_workflow: - prompt = parsed_workflow.get('prompt', '').replace("\n", " ") + elif key == "pprompt" and 'prompt' in metadata_dict: + prompt = metadata_dict.get('prompt', '').replace("\n", " ") if len(parts) >= 2: length = int(parts[1]) prompt = prompt[:length] filename = filename.replace(segment, prompt.strip()) - elif key == "nprompt" and 'negative_prompt' in parsed_workflow: - prompt = parsed_workflow.get('negative_prompt', '').replace("\n", " ") + elif key == "nprompt" and 'negative_prompt' in metadata_dict: + prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ") if len(parts) >= 2: length = int(parts[1]) prompt = prompt[:length] filename = filename.replace(segment, prompt.strip()) - elif key == "model" and 'checkpoint' in parsed_workflow: - model = parsed_workflow.get('checkpoint', '') + elif key == "model" and 'checkpoint' in metadata_dict: + model = metadata_dict.get('checkpoint', '') model = os.path.splitext(os.path.basename(model))[0] if len(parts) >= 2: length = int(parts[1]) @@ -246,23 +250,19 @@ class SaveImage: return filename def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None, - lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True, - custom_prompt=None): + lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): """Save images with metadata""" results = [] - # Parse the workflow using the WorkflowParser - parser = WorkflowParser() - if prompt: - parsed_workflow = parser.parse_workflow(prompt) - else: - parsed_workflow = {} + # Get metadata using the metadata collector + raw_metadata = get_metadata() + metadata_dict = MetadataProcessor.to_dict(raw_metadata) # Get or create metadata asynchronously - metadata = asyncio.run(self.format_metadata(parsed_workflow, custom_prompt)) + metadata = asyncio.run(self.format_metadata(metadata_dict)) # Process filename_prefix with pattern substitution - filename_prefix = self.format_filename(filename_prefix, parsed_workflow) + filename_prefix = self.format_filename(filename_prefix, metadata_dict) # Get initial save path info once for the batch full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path( @@ -290,7 +290,8 @@ class SaveImage: if file_format == "png": file = base_filename + ".png" file_extension = ".png" - save_kwargs = {"optimize": True, "compress_level": self.compress_level} + # Remove "optimize": True to match built-in node behavior + save_kwargs = {"compress_level": self.compress_level} pnginfo = PngImagePlugin.PngInfo() elif file_format == "jpeg": file = base_filename + ".jpg" @@ -299,7 +300,8 @@ class SaveImage: elif file_format == "webp": file = base_filename + ".webp" file_extension = ".webp" - save_kwargs = {"quality": quality, "lossless": lossless_webp} + # Add optimization param to control performance + save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0} # Full save path file_path = os.path.join(full_output_folder, file) @@ -347,8 +349,7 @@ class SaveImage: return results def process_image(self, images, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None, - lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True, - custom_prompt=""): + lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): """Process and save image with metadata""" # Make sure the output directory exists os.makedirs(self.output_dir, exist_ok=True) @@ -369,8 +370,7 @@ class SaveImage: lossless_webp, quality, embed_workflow, - add_counter_to_filename, - custom_prompt if custom_prompt.strip() else None + add_counter_to_filename ) return (images,) \ No newline at end of file diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 48328537..2b97832b 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -1,5 +1,9 @@ import os import time +import numpy as np +from PIL import Image +import torch +import io import logging from aiohttp import web from typing import Dict @@ -11,9 +15,11 @@ from ..utils.recipe_parsers import RecipeParserFactory from ..utils.constants import CARD_PREVIEW_WIDTH from ..config import config -from ..workflow.parser import WorkflowParser +from ..metadata_collector import get_metadata # Add MetadataCollector import +from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import from ..utils.utils import download_civitai_image from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import +from ..metadata_collector.metadata_registry import MetadataRegistry logger = logging.getLogger(__name__) @@ -24,7 +30,7 @@ class RecipeRoutes: # Initialize service references as None, will be set during async init self.recipe_scanner = None self.civitai_client = None - self.parser = WorkflowParser() + # Remove WorkflowParser instance # Pre-warm the cache self._init_cache_task = None @@ -656,8 +662,8 @@ class RecipeRoutes: logger.error(f"Error retrieving base models: {e}", exc_info=True) return web.json_response({ 'success': False, - 'error': str(e) - }, status=500) + 'error': str(e)} + , status=500) async def share_recipe(self, request: web.Request) -> web.Response: """Process a recipe image for sharing by adding metadata to EXIF""" @@ -786,50 +792,72 @@ class RecipeRoutes: # Ensure services are initialized await self.init_services() - reader = await request.multipart() + # Get metadata using the metadata collector instead of workflow parsing + raw_metadata = get_metadata() + metadata_dict = MetadataProcessor.to_dict(raw_metadata) - # Process form data - workflow_json = None + # Check if we have valid metadata + if not metadata_dict: + return web.json_response({"error": "No generation metadata found"}, status=400) - while True: - field = await reader.next() - if field is None: - break + # Get the most recent image from metadata registry instead of temp directory + metadata_registry = MetadataRegistry() + latest_image = metadata_registry.get_first_decoded_image() + + if not latest_image: + return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400) + + # Convert the image data to bytes - handle tuple and tensor cases + logger.debug(f"Image type: {type(latest_image)}") + + try: + # Handle the tuple case first + if isinstance(latest_image, tuple): + # Extract the tensor from the tuple + if len(latest_image) > 0: + tensor_image = latest_image[0] + else: + return web.json_response({"error": "Empty image tuple received"}, status=400) + else: + tensor_image = latest_image - if field.name == 'workflow_json': - workflow_text = await field.text() - try: - workflow_json = json.loads(workflow_text) - except: - return web.json_response({"error": "Invalid workflow JSON"}, status=400) + # Get the shape info for debugging + if hasattr(tensor_image, 'shape'): + shape_info = tensor_image.shape + logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}") + + # Convert tensor to numpy array + if isinstance(tensor_image, torch.Tensor): + image_np = tensor_image.cpu().numpy() + else: + image_np = np.array(tensor_image) + + # Handle different tensor shapes + # Case: (1, 1, H, W, 3) or (1, H, W, 3) - batch or multi-batch + if len(image_np.shape) > 3: + # Remove batch dimensions until we get to (H, W, 3) + while len(image_np.shape) > 3: + image_np = image_np[0] + + # If values are in [0, 1] range, convert to [0, 255] + if image_np.dtype == np.float32 or image_np.dtype == np.float64: + if image_np.max() <= 1.0: + image_np = (image_np * 255).astype(np.uint8) + + # Ensure image is in the right format (HWC with RGB channels) + if len(image_np.shape) == 3 and image_np.shape[2] == 3: + pil_image = Image.fromarray(image_np) + img_byte_arr = io.BytesIO() + pil_image.save(img_byte_arr, format='PNG') + image = img_byte_arr.getvalue() + else: + return web.json_response({"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, status=400) + except Exception as e: + logger.error(f"Error processing image data: {str(e)}", exc_info=True) + return web.json_response({"error": f"Error processing image: {str(e)}"}, status=400) - if not workflow_json: - return web.json_response({"error": "Missing workflow JSON"}, status=400) - - # Find the latest image in the temp directory - temp_dir = config.temp_directory - image_files = [] - - for file in os.listdir(temp_dir): - if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')): - file_path = os.path.join(temp_dir, file) - image_files.append((file_path, os.path.getmtime(file_path))) - - if not image_files: - return web.json_response({"error": "No recent images found to use for recipe"}, status=400) - - # Sort by modification time (newest first) - image_files.sort(key=lambda x: x[1], reverse=True) - latest_image_path = image_files[0][0] - - # Parse the workflow to extract generation parameters and loras - parsed_workflow = self.parser.parse_workflow(workflow_json) - - if not parsed_workflow: - return web.json_response({"error": "Could not extract parameters from workflow"}, status=400) - - # Get the lora stack from the parsed workflow - lora_stack = parsed_workflow.get("loras", "") + # Get the lora stack from the metadata + lora_stack = metadata_dict.get("loras", "") # Parse the lora stack format: " ..." import re @@ -837,7 +865,7 @@ class RecipeRoutes: # Check if any loras were found if not lora_matches: - return web.json_response({"error": "No LoRAs found in the workflow"}, status=400) + return web.json_response({"error": "No LoRAs found in the generation metadata"}, status=400) # Generate recipe name from the first 3 loras (or less if fewer are available) loras_for_name = lora_matches[:3] # Take at most 3 loras for the name @@ -851,10 +879,6 @@ class RecipeRoutes: recipe_name = " ".join(recipe_name_parts) - # Read the image - with open(latest_image_path, 'rb') as f: - image = f.read() - # Create recipes directory if it doesn't exist recipes_dir = self.recipe_scanner.recipes_dir os.makedirs(recipes_dir, exist_ok=True) @@ -922,8 +946,8 @@ class RecipeRoutes: "created_date": time.time(), "base_model": most_common_base_model, "loras": loras_data, - "checkpoint": parsed_workflow.get("checkpoint", ""), - "gen_params": {key: value for key, value in parsed_workflow.items() + "checkpoint": metadata_dict.get("checkpoint", ""), + "gen_params": {key: value for key, value in metadata_dict.items() if key not in ['checkpoint', 'loras']}, "loras_stack": lora_stack # Include the original lora stack string } diff --git a/web/comfyui/loras_widget.js b/web/comfyui/loras_widget.js index 0166d014..a5210234 100644 --- a/web/comfyui/loras_widget.js +++ b/web/comfyui/loras_widget.js @@ -966,9 +966,6 @@ export function addLorasWidget(node, name, opts, callback) { // Function to directly save the recipe without dialog async function saveRecipeDirectly(widget) { try { - // Get the workflow data from the ComfyUI app - const prompt = await app.graphToPrompt(); - // Show loading toast if (app && app.extensionManager && app.extensionManager.toast) { app.extensionManager.toast.add({ @@ -979,14 +976,9 @@ async function saveRecipeDirectly(widget) { }); } - // Prepare the data - only send workflow JSON - const formData = new FormData(); - formData.append('workflow_json', JSON.stringify(prompt.output)); - - // Send the request + // Send the request to the backend API without workflow data const response = await fetch('/api/recipes/save-from-widget', { - method: 'POST', - body: formData + method: 'POST' }); const result = await response.json();