Merge pull request #111 from willmiao/dev

Dev
This commit is contained in:
pixelpaws
2025-04-18 10:55:07 +08:00
committed by GitHub
11 changed files with 1137 additions and 126 deletions

View File

@@ -3,16 +3,23 @@ from .py.nodes.lora_loader import LoraManagerLoader
from .py.nodes.trigger_word_toggle import TriggerWordToggle
from .py.nodes.lora_stacker import LoraStacker
from .py.nodes.save_image import SaveImage
from .py.nodes.debug_metadata import DebugMetadata
# Import metadata collector to install hooks on startup
from .py.metadata_collector import init as init_metadata_collector
NODE_CLASS_MAPPINGS = {
LoraManagerLoader.NAME: LoraManagerLoader,
TriggerWordToggle.NAME: TriggerWordToggle,
LoraStacker.NAME: LoraStacker,
SaveImage.NAME: SaveImage
SaveImage.NAME: SaveImage,
DebugMetadata.NAME: DebugMetadata
}
WEB_DIRECTORY = "./web/comfyui"
# Initialize metadata collector
init_metadata_collector()
# Register routes on import
LoraManager.add_routes()
__all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY']

View File

@@ -0,0 +1,18 @@
import os
import importlib
from .metadata_hook import MetadataHook
from .metadata_registry import MetadataRegistry
def init():
# Install hooks to collect metadata during execution
MetadataHook.install()
# Initialize registry
registry = MetadataRegistry()
print("ComfyUI Metadata Collector initialized")
def get_metadata(prompt_id=None):
"""Helper function to get metadata from the registry"""
registry = MetadataRegistry()
return registry.get_metadata(prompt_id)

View File

@@ -0,0 +1,12 @@
"""Constants used by the metadata collector"""
# Individual category constants
MODELS = "models"
PROMPTS = "prompts"
SAMPLING = "sampling"
LORAS = "loras"
SIZE = "size"
IMAGES = "images" # Added new category for image results
# Collection of categories for iteration
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES] # Added IMAGES to categories

View File

@@ -0,0 +1,123 @@
import sys
import inspect
from .metadata_registry import MetadataRegistry
class MetadataHook:
"""Install hooks for metadata collection"""
@staticmethod
def install():
"""Install hooks to collect metadata during execution"""
try:
# Import ComfyUI's execution module
execution = None
try:
# Try direct import first
import execution # type: ignore
except ImportError:
# Try to locate from system modules
for module_name in sys.modules:
if module_name.endswith('.execution'):
execution = sys.modules[module_name]
break
# If we can't find the execution module, we can't install hooks
if execution is None:
print("Could not locate ComfyUI execution module, metadata collection disabled")
return
# Store the original _map_node_over_list function
original_map_node_over_list = execution._map_node_over_list
# Define the wrapped _map_node_over_list function
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# Only collect metadata when calling the main function of nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record inputs before execution
if node_id is not None:
registry.record_node_execution(node_id, class_type, input_data_all, None)
except Exception as e:
print(f"Error collecting metadata (pre-execution): {str(e)}")
# Execute the original function
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
# After execution, collect outputs for relevant nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record outputs after execution
if node_id is not None:
registry.update_node_execution(node_id, class_type, results)
except Exception as e:
print(f"Error collecting metadata (post-execution): {str(e)}")
return results
# Also hook the execute function to track the current prompt_id
original_execute = execution.execute
def execute_with_prompt_tracking(*args, **kwargs):
if len(args) >= 7: # Check if we have enough arguments
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
registry = MetadataRegistry()
# Start collection if this is a new prompt
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
registry.start_collection(prompt_id)
# Store the dynprompt reference for node lookups
if hasattr(prompt, 'original_prompt'):
registry.set_current_prompt(prompt)
# Execute the original function
return original_execute(*args, **kwargs)
# Replace the functions
execution._map_node_over_list = map_node_over_list_with_metadata
execution.execute = execute_with_prompt_tracking
# Make map_node_over_list public to avoid it being hidden by hooks
execution.map_node_over_list = original_map_node_over_list
print("Metadata collection hooks installed for runtime values")
except Exception as e:
print(f"Error installing metadata hooks: {str(e)}")

View File

@@ -0,0 +1,245 @@
import json
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE
class MetadataProcessor:
"""Process and format collected metadata"""
@staticmethod
def find_primary_sampler(metadata):
"""Find the primary KSampler node (with denoise=1)"""
primary_sampler = None
primary_sampler_id = None
# First, check for KSamplerAdvanced with add_noise="enable"
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
parameters = sampler_info.get("parameters", {})
add_noise = parameters.get("add_noise")
# If add_noise is "enable", this is likely the primary sampler for KSamplerAdvanced
if add_noise == "enable":
primary_sampler = sampler_info
primary_sampler_id = node_id
break
# If no KSamplerAdvanced found, fall back to traditional KSampler with denoise=1
if primary_sampler is None:
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
parameters = sampler_info.get("parameters", {})
denoise = parameters.get("denoise")
# If denoise is 1.0, this is likely the primary sampler
if denoise == 1.0 or denoise == 1:
primary_sampler = sampler_info
primary_sampler_id = node_id
break
return primary_sampler_id, primary_sampler
@staticmethod
def trace_node_input(prompt, node_id, input_name, target_class=None, max_depth=10):
"""
Trace an input connection from a node to find the source node
Parameters:
- prompt: The prompt object containing node connections
- node_id: ID of the starting node
- input_name: Name of the input to trace
- target_class: Optional class name to search for (e.g., "CLIPTextEncode")
- max_depth: Maximum depth to follow the node chain to prevent infinite loops
Returns:
- node_id of the found node, or None if not found
"""
if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt:
return None
# For depth tracking
current_depth = 0
current_node_id = node_id
current_input = input_name
while current_depth < max_depth:
if current_node_id not in prompt.original_prompt:
return None
node_inputs = prompt.original_prompt[current_node_id].get("inputs", {})
if current_input not in node_inputs:
return None
input_value = node_inputs[current_input]
# Input connections are formatted as [node_id, output_index]
if isinstance(input_value, list) and len(input_value) >= 2:
found_node_id = input_value[0] # Connected node_id
# If we're looking for a specific node class
if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class:
return found_node_id
# If we're not looking for a specific class or haven't found it yet
if not target_class:
return found_node_id
# Continue tracing through intermediate nodes
current_node_id = found_node_id
# For most conditioning nodes, the input we want to follow is named "conditioning"
if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}):
current_input = "conditioning"
else:
# If there's no "conditioning" input, we can't trace further
return found_node_id if not target_class else None
else:
# We've reached a node with no further connections
return None
current_depth += 1
# If we've reached max depth without finding target_class
return None
@staticmethod
def find_primary_checkpoint(metadata):
"""Find the primary checkpoint model in the workflow"""
if not metadata.get(MODELS):
return None
# In most workflows, there's only one checkpoint, so we can just take the first one
for node_id, model_info in metadata.get(MODELS, {}).items():
if model_info.get("type") == "checkpoint":
return model_info.get("name")
return None
@staticmethod
def extract_generation_params(metadata):
"""Extract generation parameters from metadata using node relationships"""
params = {
"prompt": "",
"negative_prompt": "",
"seed": None,
"steps": None,
"cfg_scale": None,
"guidance": None, # Add guidance parameter
"sampler": None,
"scheduler": None,
"checkpoint": None,
"loras": "",
"size": None,
"clip_skip": None
}
# Get the prompt object for node relationship tracing
prompt = metadata.get("current_prompt")
# Find the primary KSampler node
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata)
# Directly get checkpoint from metadata instead of tracing
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
if checkpoint:
params["checkpoint"] = checkpoint
if primary_sampler:
# Extract sampling parameters
sampling_params = primary_sampler.get("parameters", {})
# Handle both seed and noise_seed
params["seed"] = sampling_params.get("seed") if sampling_params.get("seed") is not None else sampling_params.get("noise_seed")
params["steps"] = sampling_params.get("steps")
params["cfg_scale"] = sampling_params.get("cfg")
params["sampler"] = sampling_params.get("sampler_name")
params["scheduler"] = sampling_params.get("scheduler")
# Trace connections from the primary sampler
if prompt and primary_sampler_id:
# Trace positive prompt - look specifically for CLIPTextEncode
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
# Find any FluxGuidance nodes in the positive conditioning path
flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5)
if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
params["guidance"] = flux_params.get("guidance")
# Trace negative prompt - look specifically for CLIPTextEncode
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10)
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
# Check if the sampler itself has size information (from latent_image)
if primary_sampler_id in metadata.get(SIZE, {}):
width = metadata[SIZE][primary_sampler_id].get("width")
height = metadata[SIZE][primary_sampler_id].get("height")
if width and height:
params["size"] = f"{width}x{height}"
else:
# Fallback to the previous trace method if needed
latent_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "latent_image")
if latent_node_id:
# Follow chain to find EmptyLatentImage node
size_found = False
current_node_id = latent_node_id
# Limit depth to avoid infinite loops in complex workflows
max_depth = 10
for _ in range(max_depth):
if current_node_id in metadata.get(SIZE, {}):
width = metadata[SIZE][current_node_id].get("width")
height = metadata[SIZE][current_node_id].get("height")
if width and height:
params["size"] = f"{width}x{height}"
size_found = True
break
# Try to follow the chain
if prompt and prompt.original_prompt and current_node_id in prompt.original_prompt:
node_info = prompt.original_prompt[current_node_id]
if "inputs" in node_info:
# Look for a connection that might lead to size information
for input_name, input_value in node_info["inputs"].items():
if isinstance(input_value, list) and len(input_value) >= 2:
current_node_id = input_value[0]
break
else:
break # No connections to follow
else:
break # No inputs to follow
else:
break # Can't follow further
# Extract LoRAs using the standardized format
lora_parts = []
for node_id, lora_info in metadata.get(LORAS, {}).items():
# Access the lora_list from the standardized format
lora_list = lora_info.get("lora_list", [])
for lora in lora_list:
name = lora.get("name", "unknown")
strength = lora.get("strength", 1.0)
lora_parts.append(f"<lora:{name}:{strength}>")
params["loras"] = " ".join(lora_parts)
# Set default clip_skip value
params["clip_skip"] = "1" # Common default
return params
@staticmethod
def to_dict(metadata):
"""Convert extracted metadata to the ComfyUI output.json format"""
params = MetadataProcessor.extract_generation_params(metadata)
# Convert all values to strings to match output.json format
for key in params:
if params[key] is not None:
params[key] = str(params[key])
return params
@staticmethod
def to_json(metadata):
"""Convert metadata to JSON string"""
params = MetadataProcessor.to_dict(metadata)
return json.dumps(params, indent=4)

View File

@@ -0,0 +1,275 @@
import time
from nodes import NODE_CLASS_MAPPINGS
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
from .constants import METADATA_CATEGORIES, IMAGES
class MetadataRegistry:
"""A singleton registry to store and retrieve workflow metadata"""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._reset()
return cls._instance
def _reset(self):
self.current_prompt_id = None
self.current_prompt = None
self.metadata = {}
self.prompt_metadata = {}
self.executed_nodes = set()
# Node-level cache for metadata
self.node_cache = {}
# Limit the number of stored prompts
self.max_prompt_history = 3
# Categories we want to track and retrieve from cache
self.metadata_categories = METADATA_CATEGORIES
def _clean_old_prompts(self):
"""Clean up old prompt metadata, keeping only recent ones"""
if len(self.prompt_metadata) <= self.max_prompt_history:
return
# Sort all prompt_ids by timestamp
sorted_prompts = sorted(
self.prompt_metadata.keys(),
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0)
)
# Remove oldest records
prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history]
for pid in prompts_to_remove:
del self.prompt_metadata[pid]
def start_collection(self, prompt_id):
"""Begin metadata collection for a new prompt"""
self.current_prompt_id = prompt_id
self.executed_nodes = set()
self.prompt_metadata[prompt_id] = {
category: {} for category in METADATA_CATEGORIES
}
# Add additional metadata fields
self.prompt_metadata[prompt_id].update({
"execution_order": [],
"current_prompt": None, # Will store the prompt object
"timestamp": time.time()
})
# Clean up old prompt data
self._clean_old_prompts()
def set_current_prompt(self, prompt):
"""Set the current prompt object reference"""
self.current_prompt = prompt
if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata:
# Store the prompt in the metadata for later relationship tracing
self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt
def get_metadata(self, prompt_id=None):
"""Get collected metadata for a prompt"""
key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata:
return {}
metadata = self.prompt_metadata[key]
# If we have a current prompt object, check for non-executed nodes
prompt_obj = metadata.get("current_prompt")
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
original_prompt = prompt_obj.original_prompt
# Fill in missing metadata from cache for nodes that weren't executed
self._fill_missing_metadata(key, original_prompt)
return self.prompt_metadata.get(key, {})
def _fill_missing_metadata(self, prompt_id, original_prompt):
"""Fill missing metadata from cache for non-executed nodes"""
if not original_prompt:
return
executed_nodes = self.executed_nodes
metadata = self.prompt_metadata[prompt_id]
# Iterate through nodes in the original prompt
for node_id, node_data in original_prompt.items():
# Skip if already executed in this run
if node_id in executed_nodes:
continue
# Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS)
prompt_class_type = node_data.get("class_type")
if not prompt_class_type:
continue
# Convert to actual class name (which is what we use in our cache)
class_type = prompt_class_type
if prompt_class_type in NODE_CLASS_MAPPINGS:
class_obj = NODE_CLASS_MAPPINGS[prompt_class_type]
class_type = class_obj.__name__
# Create cache key using the actual class name
cache_key = f"{node_id}:{class_type}"
# Check if this node type is relevant for metadata collection
if class_type in NODE_EXTRACTORS:
# Check if we have cached metadata for this node
if cache_key in self.node_cache:
cached_data = self.node_cache[cache_key]
# Apply cached metadata to the current metadata
for category in self.metadata_categories:
if category in cached_data and node_id in cached_data[category]:
if node_id not in metadata[category]:
metadata[category][node_id] = cached_data[category][node_id]
def record_node_execution(self, node_id, class_type, inputs, outputs):
"""Record information about a node's execution"""
if not self.current_prompt_id:
return
# Add to execution order and mark as executed
if node_id not in self.executed_nodes:
self.executed_nodes.add(node_id)
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id)
# Process inputs to simplify working with them
processed_inputs = {}
for input_name, input_values in inputs.items():
if isinstance(input_values, list) and len(input_values) > 0:
# For single values, just use the first one (most common case)
processed_inputs[input_name] = input_values[0]
else:
processed_inputs[input_name] = input_values
# Extract node-specific metadata
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
extractor.extract(
node_id,
processed_inputs,
outputs,
self.prompt_metadata[self.current_prompt_id]
)
# Cache this node's metadata
self._cache_node_metadata(node_id, class_type)
def update_node_execution(self, node_id, class_type, outputs):
"""Update node metadata with output information"""
if not self.current_prompt_id:
return
# Process outputs to make them more usable
processed_outputs = outputs
# Use the same extractor to update with outputs
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
if hasattr(extractor, 'update'):
extractor.update(
node_id,
processed_outputs,
self.prompt_metadata[self.current_prompt_id]
)
# Update the cached metadata for this node
self._cache_node_metadata(node_id, class_type)
def _cache_node_metadata(self, node_id, class_type):
"""Cache the metadata for a specific node"""
if not self.current_prompt_id or not node_id or not class_type:
return
# Create a cache key combining node_id and class_type
cache_key = f"{node_id}:{class_type}"
# Create a shallow copy of the node's metadata
node_metadata = {}
current_metadata = self.prompt_metadata[self.current_prompt_id]
for category in self.metadata_categories:
if category in current_metadata and node_id in current_metadata[category]:
if category not in node_metadata:
node_metadata[category] = {}
node_metadata[category][node_id] = current_metadata[category][node_id]
# Save to cache if we have any metadata for this node
if any(node_metadata.values()):
self.node_cache[cache_key] = node_metadata
def clear_unused_cache(self):
"""Clean up node_cache entries that are no longer in use"""
# Collect all node_ids currently in prompt_metadata
active_node_ids = set()
for prompt_data in self.prompt_metadata.values():
for category in self.metadata_categories:
if category in prompt_data:
active_node_ids.update(prompt_data[category].keys())
# Find cache keys that are no longer needed
keys_to_remove = []
for cache_key in self.node_cache:
node_id = cache_key.split(':')[0]
if node_id not in active_node_ids:
keys_to_remove.append(cache_key)
# Remove cache entries that are no longer needed
for key in keys_to_remove:
del self.node_cache[key]
def clear_metadata(self, prompt_id=None):
"""Clear metadata for a specific prompt or reset all data"""
if prompt_id is not None:
if prompt_id in self.prompt_metadata:
del self.prompt_metadata[prompt_id]
# Clean up cache after removing prompt
self.clear_unused_cache()
else:
# Reset all data
self._reset()
def get_first_decoded_image(self, prompt_id=None):
"""Get the first decoded image result"""
key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata:
return None
metadata = self.prompt_metadata[key]
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
image_data = metadata[IMAGES]["first_decode"]["image"]
# If it's an image batch or tuple, handle various formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
# Return first element of list/tuple
return image_data[0]
# If it's a tensor, return as is for processing in the route handler
return image_data
# If no image is found in the current metadata, try to find it in the cache
# This handles the case where VAEDecode was cached by ComfyUI and not executed
prompt_obj = metadata.get("current_prompt")
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
original_prompt = prompt_obj.original_prompt
for node_id, node_data in original_prompt.items():
class_type = node_data.get("class_type")
if class_type and class_type in NODE_CLASS_MAPPINGS:
class_obj = NODE_CLASS_MAPPINGS[class_type]
class_name = class_obj.__name__
# Check if this is a VAEDecode node
if class_name == "VAEDecode":
# Try to find this node in the cache
cache_key = f"{node_id}:{class_name}"
if cache_key in self.node_cache:
cached_data = self.node_cache[cache_key]
if IMAGES in cached_data and node_id in cached_data[IMAGES]:
image_data = cached_data[IMAGES][node_id]["image"]
# Handle different image formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
return image_data[0]
return image_data
return None

View File

@@ -0,0 +1,280 @@
import os
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES
class NodeMetadataExtractor:
"""Base class for node-specific metadata extraction"""
@staticmethod
def extract(node_id, inputs, outputs, metadata):
"""Extract metadata from node inputs/outputs"""
pass
@staticmethod
def update(node_id, outputs, metadata):
"""Update metadata with node outputs after execution"""
pass
class GenericNodeExtractor(NodeMetadataExtractor):
"""Default extractor for nodes without specific handling"""
@staticmethod
def extract(node_id, inputs, outputs, metadata):
pass
class CheckpointLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "ckpt_name" not in inputs:
return
model_name = inputs.get("ckpt_name")
if model_name:
metadata[MODELS][node_id] = {
"name": model_name,
"type": "checkpoint",
"node_id": node_id
}
class CLIPTextEncodeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "text" not in inputs:
return
text = inputs.get("text", "")
metadata[PROMPTS][node_id] = {
"text": text,
"node_id": node_id
}
class SamplerExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class KSamplerAdvancedExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class LoraLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "lora_name" not in inputs:
return
lora_name = inputs.get("lora_name")
# Extract base filename without extension from path
lora_name = os.path.splitext(os.path.basename(lora_name))[0]
strength_model = round(float(inputs.get("strength_model", 1.0)), 2)
# Use the standardized format with lora_list
metadata[LORAS][node_id] = {
"lora_list": [
{
"name": lora_name,
"strength": strength_model
}
],
"node_id": node_id
}
class ImageSizeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
width = inputs.get("width", 512)
height = inputs.get("height", 512)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class LoraLoaderManagerExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
active_loras = []
# Process lora_stack if available
if "lora_stack" in inputs:
lora_stack = inputs.get("lora_stack", [])
for lora_path, model_strength, clip_strength in lora_stack:
# Extract lora name from path (following the format in lora_loader.py)
lora_name = os.path.splitext(os.path.basename(lora_path))[0]
active_loras.append({
"name": lora_name,
"strength": model_strength
})
# Process loras from inputs
if "loras" in inputs:
loras_data = inputs.get("loras", [])
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
loras_list = loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
# Filter for active loras
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", True) and not lora.get("_isDummy", False):
active_loras.append({
"name": lora.get("name", ""),
"strength": float(lora.get("strength", 1.0))
})
if active_loras:
metadata[LORAS][node_id] = {
"lora_list": active_loras,
"node_id": node_id
}
class FluxGuidanceExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "guidance" not in inputs:
return
guidance_value = inputs.get("guidance")
# Store the guidance value in SAMPLING category
if node_id not in metadata[SAMPLING]:
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
class UNETLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "unet_name" not in inputs:
return
model_name = inputs.get("unet_name")
if model_name:
metadata[MODELS][node_id] = {
"name": model_name,
"type": "checkpoint",
"node_id": node_id
}
class VAEDecodeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
pass
@staticmethod
def update(node_id, outputs, metadata):
# Ensure IMAGES category exists
if IMAGES not in metadata:
metadata[IMAGES] = {}
# Save image data under node ID index to be captured by caching mechanism
metadata[IMAGES][node_id] = {
"node_id": node_id,
"image": outputs
}
# Only set first_decode if it hasn't been recorded yet
if "first_decode" not in metadata[IMAGES]:
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
# Registry of node-specific extractors
NODE_EXTRACTORS = {
# Sampling
"KSampler": SamplerExtractor,
"KSamplerAdvanced": KSamplerAdvancedExtractor, # Add KSamplerAdvanced
"SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced
# Loaders
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
"LoraLoader": LoraLoaderExtractor,
"LoraManagerLoader": LoraLoaderManagerExtractor,
# Conditioning
"CLIPTextEncode": CLIPTextEncodeExtractor,
# Latent
"EmptyLatentImage": ImageSizeExtractor,
# Flux
"FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance
# Image
"VAEDecode": VAEDecodeExtractor, # Added VAEDecode extractor
# Add other nodes as needed
}

View File

@@ -0,0 +1,35 @@
import logging
from ..metadata_collector.metadata_processor import MetadataProcessor
logger = logging.getLogger(__name__)
class DebugMetadata:
NAME = "Debug Metadata (LoraManager)"
CATEGORY = "Lora Manager/utils"
DESCRIPTION = "Debug node to verify metadata_processor functionality"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("metadata_json",)
FUNCTION = "process_metadata"
def process_metadata(self, images):
try:
# Get the current execution context's metadata
from ..metadata_collector import get_metadata
metadata = get_metadata()
# Use the MetadataProcessor to convert it to JSON string
metadata_json = MetadataProcessor.to_json(metadata)
return (metadata_json,)
except Exception as e:
logger.error(f"Error processing metadata: {e}")
return ("{}",) # Return empty JSON object in case of error

View File

@@ -5,10 +5,10 @@ import re
import numpy as np
import folder_paths # type: ignore
from ..services.lora_scanner import LoraScanner
from ..workflow.parser import WorkflowParser
from ..metadata_collector.metadata_processor import MetadataProcessor
from ..metadata_collector import get_metadata
from PIL import Image, PngImagePlugin
import piexif
from io import BytesIO
class SaveImage:
NAME = "Save Image (LoraManager)"
@@ -34,8 +34,7 @@ class SaveImage:
"file_format": (["png", "jpeg", "webp"],),
},
"optional": {
"custom_prompt": ("STRING", {"default": "", "forceInput": True}),
"lossless_webp": ("BOOLEAN", {"default": True}),
"lossless_webp": ("BOOLEAN", {"default": False}),
"quality": ("INT", {"default": 100, "min": 1, "max": 100}),
"embed_workflow": ("BOOLEAN", {"default": False}),
"add_counter_to_filename": ("BOOLEAN", {"default": True}),
@@ -61,21 +60,17 @@ class SaveImage:
return item.get('sha256')
return None
async def format_metadata(self, parsed_workflow, custom_prompt=None):
async def format_metadata(self, metadata_dict):
"""Format metadata in the requested format similar to userComment example"""
if not parsed_workflow:
if not metadata_dict:
return ""
# Extract the prompt and negative prompt
prompt = parsed_workflow.get('prompt', '')
negative_prompt = parsed_workflow.get('negative_prompt', '')
# Override prompt with custom_prompt if provided
if custom_prompt:
prompt = custom_prompt
prompt = metadata_dict.get('prompt', '')
negative_prompt = metadata_dict.get('negative_prompt', '')
# Extract loras from the prompt if present
loras_text = parsed_workflow.get('loras', '')
loras_text = metadata_dict.get('loras', '')
lora_hashes = {}
# If loras are found, add them on a new line after the prompt
@@ -104,11 +99,11 @@ class SaveImage:
params = []
# Add standard parameters in the correct order
if 'steps' in parsed_workflow:
params.append(f"Steps: {parsed_workflow.get('steps')}")
if 'steps' in metadata_dict:
params.append(f"Steps: {metadata_dict.get('steps')}")
if 'sampler' in parsed_workflow:
sampler = parsed_workflow.get('sampler')
if 'sampler' in metadata_dict:
sampler = metadata_dict.get('sampler')
# Convert ComfyUI sampler names to user-friendly names
sampler_mapping = {
'euler': 'Euler',
@@ -130,8 +125,8 @@ class SaveImage:
sampler_name = sampler_mapping.get(sampler, sampler)
params.append(f"Sampler: {sampler_name}")
if 'scheduler' in parsed_workflow:
scheduler = parsed_workflow.get('scheduler')
if 'scheduler' in metadata_dict:
scheduler = metadata_dict.get('scheduler')
scheduler_mapping = {
'normal': 'Simple',
'karras': 'Karras',
@@ -142,27 +137,36 @@ class SaveImage:
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
params.append(f"Schedule type: {scheduler_name}")
# CFG scale (cfg in parsed_workflow)
if 'cfg_scale' in parsed_workflow:
params.append(f"CFG scale: {parsed_workflow.get('cfg_scale')}")
elif 'cfg' in parsed_workflow:
params.append(f"CFG scale: {parsed_workflow.get('cfg')}")
# CFG scale (cfg_scale in metadata_dict)
if 'cfg_scale' in metadata_dict:
params.append(f"CFG scale: {metadata_dict.get('cfg_scale')}")
elif 'cfg' in metadata_dict:
params.append(f"CFG scale: {metadata_dict.get('cfg')}")
# Seed
if 'seed' in parsed_workflow:
params.append(f"Seed: {parsed_workflow.get('seed')}")
if 'seed' in metadata_dict:
params.append(f"Seed: {metadata_dict.get('seed')}")
# Size
if 'size' in parsed_workflow:
params.append(f"Size: {parsed_workflow.get('size')}")
if 'size' in metadata_dict:
params.append(f"Size: {metadata_dict.get('size')}")
# Model info
if 'checkpoint' in parsed_workflow:
# Extract basename without path
checkpoint = os.path.basename(parsed_workflow.get('checkpoint', ''))
# Remove extension if present
checkpoint = os.path.splitext(checkpoint)[0]
params.append(f"Model: {checkpoint}")
if 'checkpoint' in metadata_dict:
# Ensure checkpoint is a string before processing
checkpoint = metadata_dict.get('checkpoint')
if checkpoint is not None:
# Handle both string and other types safely
if isinstance(checkpoint, str):
# Extract basename without path
checkpoint = os.path.basename(checkpoint)
# Remove extension if present
checkpoint = os.path.splitext(checkpoint)[0]
else:
# Convert non-string to string
checkpoint = str(checkpoint)
params.append(f"Model: {checkpoint}")
# Add LoRA hashes if available
if lora_hashes:
@@ -181,9 +185,9 @@ class SaveImage:
# credit to nkchocoai
# Add format_filename method to handle pattern substitution
def format_filename(self, filename, parsed_workflow):
def format_filename(self, filename, metadata_dict):
"""Format filename with metadata values"""
if not parsed_workflow:
if not metadata_dict:
return filename
result = re.findall(self.pattern_format, filename)
@@ -191,30 +195,30 @@ class SaveImage:
parts = segment.replace("%", "").split(":")
key = parts[0]
if key == "seed" and 'seed' in parsed_workflow:
filename = filename.replace(segment, str(parsed_workflow.get('seed', '')))
elif key == "width" and 'size' in parsed_workflow:
size = parsed_workflow.get('size', 'x')
if key == "seed" and 'seed' in metadata_dict:
filename = filename.replace(segment, str(metadata_dict.get('seed', '')))
elif key == "width" and 'size' in metadata_dict:
size = metadata_dict.get('size', 'x')
w = size.split('x')[0] if isinstance(size, str) else size[0]
filename = filename.replace(segment, str(w))
elif key == "height" and 'size' in parsed_workflow:
size = parsed_workflow.get('size', 'x')
elif key == "height" and 'size' in metadata_dict:
size = metadata_dict.get('size', 'x')
h = size.split('x')[1] if isinstance(size, str) else size[1]
filename = filename.replace(segment, str(h))
elif key == "pprompt" and 'prompt' in parsed_workflow:
prompt = parsed_workflow.get('prompt', '').replace("\n", " ")
elif key == "pprompt" and 'prompt' in metadata_dict:
prompt = metadata_dict.get('prompt', '').replace("\n", " ")
if len(parts) >= 2:
length = int(parts[1])
prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip())
elif key == "nprompt" and 'negative_prompt' in parsed_workflow:
prompt = parsed_workflow.get('negative_prompt', '').replace("\n", " ")
elif key == "nprompt" and 'negative_prompt' in metadata_dict:
prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ")
if len(parts) >= 2:
length = int(parts[1])
prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip())
elif key == "model" and 'checkpoint' in parsed_workflow:
model = parsed_workflow.get('checkpoint', '')
elif key == "model" and 'checkpoint' in metadata_dict:
model = metadata_dict.get('checkpoint', '')
model = os.path.splitext(os.path.basename(model))[0]
if len(parts) >= 2:
length = int(parts[1])
@@ -246,23 +250,19 @@ class SaveImage:
return filename
def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True,
custom_prompt=None):
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
"""Save images with metadata"""
results = []
# Parse the workflow using the WorkflowParser
parser = WorkflowParser()
if prompt:
parsed_workflow = parser.parse_workflow(prompt)
else:
parsed_workflow = {}
# Get metadata using the metadata collector
raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata)
# Get or create metadata asynchronously
metadata = asyncio.run(self.format_metadata(parsed_workflow, custom_prompt))
metadata = asyncio.run(self.format_metadata(metadata_dict))
# Process filename_prefix with pattern substitution
filename_prefix = self.format_filename(filename_prefix, parsed_workflow)
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
# Get initial save path info once for the batch
full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path(
@@ -290,7 +290,8 @@ class SaveImage:
if file_format == "png":
file = base_filename + ".png"
file_extension = ".png"
save_kwargs = {"optimize": True, "compress_level": self.compress_level}
# Remove "optimize": True to match built-in node behavior
save_kwargs = {"compress_level": self.compress_level}
pnginfo = PngImagePlugin.PngInfo()
elif file_format == "jpeg":
file = base_filename + ".jpg"
@@ -299,7 +300,8 @@ class SaveImage:
elif file_format == "webp":
file = base_filename + ".webp"
file_extension = ".webp"
save_kwargs = {"quality": quality, "lossless": lossless_webp}
# Add optimization param to control performance
save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0}
# Full save path
file_path = os.path.join(full_output_folder, file)
@@ -347,8 +349,7 @@ class SaveImage:
return results
def process_image(self, images, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True,
custom_prompt=""):
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
"""Process and save image with metadata"""
# Make sure the output directory exists
os.makedirs(self.output_dir, exist_ok=True)
@@ -369,8 +370,7 @@ class SaveImage:
lossless_webp,
quality,
embed_workflow,
add_counter_to_filename,
custom_prompt if custom_prompt.strip() else None
add_counter_to_filename
)
return (images,)

View File

@@ -1,5 +1,9 @@
import os
import time
import numpy as np
from PIL import Image
import torch
import io
import logging
from aiohttp import web
from typing import Dict
@@ -11,9 +15,11 @@ from ..utils.recipe_parsers import RecipeParserFactory
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..config import config
from ..workflow.parser import WorkflowParser
from ..metadata_collector import get_metadata # Add MetadataCollector import
from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import
from ..utils.utils import download_civitai_image
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
from ..metadata_collector.metadata_registry import MetadataRegistry
logger = logging.getLogger(__name__)
@@ -24,7 +30,7 @@ class RecipeRoutes:
# Initialize service references as None, will be set during async init
self.recipe_scanner = None
self.civitai_client = None
self.parser = WorkflowParser()
# Remove WorkflowParser instance
# Pre-warm the cache
self._init_cache_task = None
@@ -656,8 +662,8 @@ class RecipeRoutes:
logger.error(f"Error retrieving base models: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
'error': str(e)}
, status=500)
async def share_recipe(self, request: web.Request) -> web.Response:
"""Process a recipe image for sharing by adding metadata to EXIF"""
@@ -786,50 +792,72 @@ class RecipeRoutes:
# Ensure services are initialized
await self.init_services()
reader = await request.multipart()
# Get metadata using the metadata collector instead of workflow parsing
raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata)
# Process form data
workflow_json = None
# Check if we have valid metadata
if not metadata_dict:
return web.json_response({"error": "No generation metadata found"}, status=400)
while True:
field = await reader.next()
if field is None:
break
# Get the most recent image from metadata registry instead of temp directory
metadata_registry = MetadataRegistry()
latest_image = metadata_registry.get_first_decoded_image()
if not latest_image:
return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400)
# Convert the image data to bytes - handle tuple and tensor cases
logger.debug(f"Image type: {type(latest_image)}")
try:
# Handle the tuple case first
if isinstance(latest_image, tuple):
# Extract the tensor from the tuple
if len(latest_image) > 0:
tensor_image = latest_image[0]
else:
return web.json_response({"error": "Empty image tuple received"}, status=400)
else:
tensor_image = latest_image
if field.name == 'workflow_json':
workflow_text = await field.text()
try:
workflow_json = json.loads(workflow_text)
except:
return web.json_response({"error": "Invalid workflow JSON"}, status=400)
# Get the shape info for debugging
if hasattr(tensor_image, 'shape'):
shape_info = tensor_image.shape
logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
# Convert tensor to numpy array
if isinstance(tensor_image, torch.Tensor):
image_np = tensor_image.cpu().numpy()
else:
image_np = np.array(tensor_image)
# Handle different tensor shapes
# Case: (1, 1, H, W, 3) or (1, H, W, 3) - batch or multi-batch
if len(image_np.shape) > 3:
# Remove batch dimensions until we get to (H, W, 3)
while len(image_np.shape) > 3:
image_np = image_np[0]
# If values are in [0, 1] range, convert to [0, 255]
if image_np.dtype == np.float32 or image_np.dtype == np.float64:
if image_np.max() <= 1.0:
image_np = (image_np * 255).astype(np.uint8)
# Ensure image is in the right format (HWC with RGB channels)
if len(image_np.shape) == 3 and image_np.shape[2] == 3:
pil_image = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
pil_image.save(img_byte_arr, format='PNG')
image = img_byte_arr.getvalue()
else:
return web.json_response({"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, status=400)
except Exception as e:
logger.error(f"Error processing image data: {str(e)}", exc_info=True)
return web.json_response({"error": f"Error processing image: {str(e)}"}, status=400)
if not workflow_json:
return web.json_response({"error": "Missing workflow JSON"}, status=400)
# Find the latest image in the temp directory
temp_dir = config.temp_directory
image_files = []
for file in os.listdir(temp_dir):
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
file_path = os.path.join(temp_dir, file)
image_files.append((file_path, os.path.getmtime(file_path)))
if not image_files:
return web.json_response({"error": "No recent images found to use for recipe"}, status=400)
# Sort by modification time (newest first)
image_files.sort(key=lambda x: x[1], reverse=True)
latest_image_path = image_files[0][0]
# Parse the workflow to extract generation parameters and loras
parsed_workflow = self.parser.parse_workflow(workflow_json)
if not parsed_workflow:
return web.json_response({"error": "Could not extract parameters from workflow"}, status=400)
# Get the lora stack from the parsed workflow
lora_stack = parsed_workflow.get("loras", "")
# Get the lora stack from the metadata
lora_stack = metadata_dict.get("loras", "")
# Parse the lora stack format: "<lora:name:strength> <lora:name2:strength2> ..."
import re
@@ -837,7 +865,7 @@ class RecipeRoutes:
# Check if any loras were found
if not lora_matches:
return web.json_response({"error": "No LoRAs found in the workflow"}, status=400)
return web.json_response({"error": "No LoRAs found in the generation metadata"}, status=400)
# Generate recipe name from the first 3 loras (or less if fewer are available)
loras_for_name = lora_matches[:3] # Take at most 3 loras for the name
@@ -851,10 +879,6 @@ class RecipeRoutes:
recipe_name = " ".join(recipe_name_parts)
# Read the image
with open(latest_image_path, 'rb') as f:
image = f.read()
# Create recipes directory if it doesn't exist
recipes_dir = self.recipe_scanner.recipes_dir
os.makedirs(recipes_dir, exist_ok=True)
@@ -922,8 +946,8 @@ class RecipeRoutes:
"created_date": time.time(),
"base_model": most_common_base_model,
"loras": loras_data,
"checkpoint": parsed_workflow.get("checkpoint", ""),
"gen_params": {key: value for key, value in parsed_workflow.items()
"checkpoint": metadata_dict.get("checkpoint", ""),
"gen_params": {key: value for key, value in metadata_dict.items()
if key not in ['checkpoint', 'loras']},
"loras_stack": lora_stack # Include the original lora stack string
}

View File

@@ -966,9 +966,6 @@ export function addLorasWidget(node, name, opts, callback) {
// Function to directly save the recipe without dialog
async function saveRecipeDirectly(widget) {
try {
// Get the workflow data from the ComfyUI app
const prompt = await app.graphToPrompt();
// Show loading toast
if (app && app.extensionManager && app.extensionManager.toast) {
app.extensionManager.toast.add({
@@ -979,14 +976,9 @@ async function saveRecipeDirectly(widget) {
});
}
// Prepare the data - only send workflow JSON
const formData = new FormData();
formData.append('workflow_json', JSON.stringify(prompt.output));
// Send the request
// Send the request to the backend API without workflow data
const response = await fetch('/api/recipes/save-from-widget', {
method: 'POST',
body: formData
method: 'POST'
});
const result = await response.json();