Merge pull request #111 from willmiao/dev

Dev
This commit is contained in:
pixelpaws
2025-04-18 10:55:07 +08:00
committed by GitHub
11 changed files with 1137 additions and 126 deletions

View File

@@ -3,16 +3,23 @@ from .py.nodes.lora_loader import LoraManagerLoader
from .py.nodes.trigger_word_toggle import TriggerWordToggle from .py.nodes.trigger_word_toggle import TriggerWordToggle
from .py.nodes.lora_stacker import LoraStacker from .py.nodes.lora_stacker import LoraStacker
from .py.nodes.save_image import SaveImage from .py.nodes.save_image import SaveImage
from .py.nodes.debug_metadata import DebugMetadata
# Import metadata collector to install hooks on startup
from .py.metadata_collector import init as init_metadata_collector
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
LoraManagerLoader.NAME: LoraManagerLoader, LoraManagerLoader.NAME: LoraManagerLoader,
TriggerWordToggle.NAME: TriggerWordToggle, TriggerWordToggle.NAME: TriggerWordToggle,
LoraStacker.NAME: LoraStacker, LoraStacker.NAME: LoraStacker,
SaveImage.NAME: SaveImage SaveImage.NAME: SaveImage,
DebugMetadata.NAME: DebugMetadata
} }
WEB_DIRECTORY = "./web/comfyui" WEB_DIRECTORY = "./web/comfyui"
# Initialize metadata collector
init_metadata_collector()
# Register routes on import # Register routes on import
LoraManager.add_routes() LoraManager.add_routes()
__all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY'] __all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY']

View File

@@ -0,0 +1,18 @@
import os
import importlib
from .metadata_hook import MetadataHook
from .metadata_registry import MetadataRegistry
def init():
# Install hooks to collect metadata during execution
MetadataHook.install()
# Initialize registry
registry = MetadataRegistry()
print("ComfyUI Metadata Collector initialized")
def get_metadata(prompt_id=None):
"""Helper function to get metadata from the registry"""
registry = MetadataRegistry()
return registry.get_metadata(prompt_id)

View File

@@ -0,0 +1,12 @@
"""Constants used by the metadata collector"""
# Individual category constants
MODELS = "models"
PROMPTS = "prompts"
SAMPLING = "sampling"
LORAS = "loras"
SIZE = "size"
IMAGES = "images" # Added new category for image results
# Collection of categories for iteration
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES] # Added IMAGES to categories

View File

@@ -0,0 +1,123 @@
import sys
import inspect
from .metadata_registry import MetadataRegistry
class MetadataHook:
"""Install hooks for metadata collection"""
@staticmethod
def install():
"""Install hooks to collect metadata during execution"""
try:
# Import ComfyUI's execution module
execution = None
try:
# Try direct import first
import execution # type: ignore
except ImportError:
# Try to locate from system modules
for module_name in sys.modules:
if module_name.endswith('.execution'):
execution = sys.modules[module_name]
break
# If we can't find the execution module, we can't install hooks
if execution is None:
print("Could not locate ComfyUI execution module, metadata collection disabled")
return
# Store the original _map_node_over_list function
original_map_node_over_list = execution._map_node_over_list
# Define the wrapped _map_node_over_list function
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# Only collect metadata when calling the main function of nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record inputs before execution
if node_id is not None:
registry.record_node_execution(node_id, class_type, input_data_all, None)
except Exception as e:
print(f"Error collecting metadata (pre-execution): {str(e)}")
# Execute the original function
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
# After execution, collect outputs for relevant nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record outputs after execution
if node_id is not None:
registry.update_node_execution(node_id, class_type, results)
except Exception as e:
print(f"Error collecting metadata (post-execution): {str(e)}")
return results
# Also hook the execute function to track the current prompt_id
original_execute = execution.execute
def execute_with_prompt_tracking(*args, **kwargs):
if len(args) >= 7: # Check if we have enough arguments
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
registry = MetadataRegistry()
# Start collection if this is a new prompt
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
registry.start_collection(prompt_id)
# Store the dynprompt reference for node lookups
if hasattr(prompt, 'original_prompt'):
registry.set_current_prompt(prompt)
# Execute the original function
return original_execute(*args, **kwargs)
# Replace the functions
execution._map_node_over_list = map_node_over_list_with_metadata
execution.execute = execute_with_prompt_tracking
# Make map_node_over_list public to avoid it being hidden by hooks
execution.map_node_over_list = original_map_node_over_list
print("Metadata collection hooks installed for runtime values")
except Exception as e:
print(f"Error installing metadata hooks: {str(e)}")

View File

@@ -0,0 +1,245 @@
import json
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE
class MetadataProcessor:
"""Process and format collected metadata"""
@staticmethod
def find_primary_sampler(metadata):
"""Find the primary KSampler node (with denoise=1)"""
primary_sampler = None
primary_sampler_id = None
# First, check for KSamplerAdvanced with add_noise="enable"
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
parameters = sampler_info.get("parameters", {})
add_noise = parameters.get("add_noise")
# If add_noise is "enable", this is likely the primary sampler for KSamplerAdvanced
if add_noise == "enable":
primary_sampler = sampler_info
primary_sampler_id = node_id
break
# If no KSamplerAdvanced found, fall back to traditional KSampler with denoise=1
if primary_sampler is None:
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
parameters = sampler_info.get("parameters", {})
denoise = parameters.get("denoise")
# If denoise is 1.0, this is likely the primary sampler
if denoise == 1.0 or denoise == 1:
primary_sampler = sampler_info
primary_sampler_id = node_id
break
return primary_sampler_id, primary_sampler
@staticmethod
def trace_node_input(prompt, node_id, input_name, target_class=None, max_depth=10):
"""
Trace an input connection from a node to find the source node
Parameters:
- prompt: The prompt object containing node connections
- node_id: ID of the starting node
- input_name: Name of the input to trace
- target_class: Optional class name to search for (e.g., "CLIPTextEncode")
- max_depth: Maximum depth to follow the node chain to prevent infinite loops
Returns:
- node_id of the found node, or None if not found
"""
if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt:
return None
# For depth tracking
current_depth = 0
current_node_id = node_id
current_input = input_name
while current_depth < max_depth:
if current_node_id not in prompt.original_prompt:
return None
node_inputs = prompt.original_prompt[current_node_id].get("inputs", {})
if current_input not in node_inputs:
return None
input_value = node_inputs[current_input]
# Input connections are formatted as [node_id, output_index]
if isinstance(input_value, list) and len(input_value) >= 2:
found_node_id = input_value[0] # Connected node_id
# If we're looking for a specific node class
if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class:
return found_node_id
# If we're not looking for a specific class or haven't found it yet
if not target_class:
return found_node_id
# Continue tracing through intermediate nodes
current_node_id = found_node_id
# For most conditioning nodes, the input we want to follow is named "conditioning"
if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}):
current_input = "conditioning"
else:
# If there's no "conditioning" input, we can't trace further
return found_node_id if not target_class else None
else:
# We've reached a node with no further connections
return None
current_depth += 1
# If we've reached max depth without finding target_class
return None
@staticmethod
def find_primary_checkpoint(metadata):
"""Find the primary checkpoint model in the workflow"""
if not metadata.get(MODELS):
return None
# In most workflows, there's only one checkpoint, so we can just take the first one
for node_id, model_info in metadata.get(MODELS, {}).items():
if model_info.get("type") == "checkpoint":
return model_info.get("name")
return None
@staticmethod
def extract_generation_params(metadata):
"""Extract generation parameters from metadata using node relationships"""
params = {
"prompt": "",
"negative_prompt": "",
"seed": None,
"steps": None,
"cfg_scale": None,
"guidance": None, # Add guidance parameter
"sampler": None,
"scheduler": None,
"checkpoint": None,
"loras": "",
"size": None,
"clip_skip": None
}
# Get the prompt object for node relationship tracing
prompt = metadata.get("current_prompt")
# Find the primary KSampler node
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata)
# Directly get checkpoint from metadata instead of tracing
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
if checkpoint:
params["checkpoint"] = checkpoint
if primary_sampler:
# Extract sampling parameters
sampling_params = primary_sampler.get("parameters", {})
# Handle both seed and noise_seed
params["seed"] = sampling_params.get("seed") if sampling_params.get("seed") is not None else sampling_params.get("noise_seed")
params["steps"] = sampling_params.get("steps")
params["cfg_scale"] = sampling_params.get("cfg")
params["sampler"] = sampling_params.get("sampler_name")
params["scheduler"] = sampling_params.get("scheduler")
# Trace connections from the primary sampler
if prompt and primary_sampler_id:
# Trace positive prompt - look specifically for CLIPTextEncode
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
# Find any FluxGuidance nodes in the positive conditioning path
flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5)
if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
params["guidance"] = flux_params.get("guidance")
# Trace negative prompt - look specifically for CLIPTextEncode
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10)
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
# Check if the sampler itself has size information (from latent_image)
if primary_sampler_id in metadata.get(SIZE, {}):
width = metadata[SIZE][primary_sampler_id].get("width")
height = metadata[SIZE][primary_sampler_id].get("height")
if width and height:
params["size"] = f"{width}x{height}"
else:
# Fallback to the previous trace method if needed
latent_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "latent_image")
if latent_node_id:
# Follow chain to find EmptyLatentImage node
size_found = False
current_node_id = latent_node_id
# Limit depth to avoid infinite loops in complex workflows
max_depth = 10
for _ in range(max_depth):
if current_node_id in metadata.get(SIZE, {}):
width = metadata[SIZE][current_node_id].get("width")
height = metadata[SIZE][current_node_id].get("height")
if width and height:
params["size"] = f"{width}x{height}"
size_found = True
break
# Try to follow the chain
if prompt and prompt.original_prompt and current_node_id in prompt.original_prompt:
node_info = prompt.original_prompt[current_node_id]
if "inputs" in node_info:
# Look for a connection that might lead to size information
for input_name, input_value in node_info["inputs"].items():
if isinstance(input_value, list) and len(input_value) >= 2:
current_node_id = input_value[0]
break
else:
break # No connections to follow
else:
break # No inputs to follow
else:
break # Can't follow further
# Extract LoRAs using the standardized format
lora_parts = []
for node_id, lora_info in metadata.get(LORAS, {}).items():
# Access the lora_list from the standardized format
lora_list = lora_info.get("lora_list", [])
for lora in lora_list:
name = lora.get("name", "unknown")
strength = lora.get("strength", 1.0)
lora_parts.append(f"<lora:{name}:{strength}>")
params["loras"] = " ".join(lora_parts)
# Set default clip_skip value
params["clip_skip"] = "1" # Common default
return params
@staticmethod
def to_dict(metadata):
"""Convert extracted metadata to the ComfyUI output.json format"""
params = MetadataProcessor.extract_generation_params(metadata)
# Convert all values to strings to match output.json format
for key in params:
if params[key] is not None:
params[key] = str(params[key])
return params
@staticmethod
def to_json(metadata):
"""Convert metadata to JSON string"""
params = MetadataProcessor.to_dict(metadata)
return json.dumps(params, indent=4)

View File

@@ -0,0 +1,275 @@
import time
from nodes import NODE_CLASS_MAPPINGS
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
from .constants import METADATA_CATEGORIES, IMAGES
class MetadataRegistry:
"""A singleton registry to store and retrieve workflow metadata"""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._reset()
return cls._instance
def _reset(self):
self.current_prompt_id = None
self.current_prompt = None
self.metadata = {}
self.prompt_metadata = {}
self.executed_nodes = set()
# Node-level cache for metadata
self.node_cache = {}
# Limit the number of stored prompts
self.max_prompt_history = 3
# Categories we want to track and retrieve from cache
self.metadata_categories = METADATA_CATEGORIES
def _clean_old_prompts(self):
"""Clean up old prompt metadata, keeping only recent ones"""
if len(self.prompt_metadata) <= self.max_prompt_history:
return
# Sort all prompt_ids by timestamp
sorted_prompts = sorted(
self.prompt_metadata.keys(),
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0)
)
# Remove oldest records
prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history]
for pid in prompts_to_remove:
del self.prompt_metadata[pid]
def start_collection(self, prompt_id):
"""Begin metadata collection for a new prompt"""
self.current_prompt_id = prompt_id
self.executed_nodes = set()
self.prompt_metadata[prompt_id] = {
category: {} for category in METADATA_CATEGORIES
}
# Add additional metadata fields
self.prompt_metadata[prompt_id].update({
"execution_order": [],
"current_prompt": None, # Will store the prompt object
"timestamp": time.time()
})
# Clean up old prompt data
self._clean_old_prompts()
def set_current_prompt(self, prompt):
"""Set the current prompt object reference"""
self.current_prompt = prompt
if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata:
# Store the prompt in the metadata for later relationship tracing
self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt
def get_metadata(self, prompt_id=None):
"""Get collected metadata for a prompt"""
key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata:
return {}
metadata = self.prompt_metadata[key]
# If we have a current prompt object, check for non-executed nodes
prompt_obj = metadata.get("current_prompt")
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
original_prompt = prompt_obj.original_prompt
# Fill in missing metadata from cache for nodes that weren't executed
self._fill_missing_metadata(key, original_prompt)
return self.prompt_metadata.get(key, {})
def _fill_missing_metadata(self, prompt_id, original_prompt):
"""Fill missing metadata from cache for non-executed nodes"""
if not original_prompt:
return
executed_nodes = self.executed_nodes
metadata = self.prompt_metadata[prompt_id]
# Iterate through nodes in the original prompt
for node_id, node_data in original_prompt.items():
# Skip if already executed in this run
if node_id in executed_nodes:
continue
# Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS)
prompt_class_type = node_data.get("class_type")
if not prompt_class_type:
continue
# Convert to actual class name (which is what we use in our cache)
class_type = prompt_class_type
if prompt_class_type in NODE_CLASS_MAPPINGS:
class_obj = NODE_CLASS_MAPPINGS[prompt_class_type]
class_type = class_obj.__name__
# Create cache key using the actual class name
cache_key = f"{node_id}:{class_type}"
# Check if this node type is relevant for metadata collection
if class_type in NODE_EXTRACTORS:
# Check if we have cached metadata for this node
if cache_key in self.node_cache:
cached_data = self.node_cache[cache_key]
# Apply cached metadata to the current metadata
for category in self.metadata_categories:
if category in cached_data and node_id in cached_data[category]:
if node_id not in metadata[category]:
metadata[category][node_id] = cached_data[category][node_id]
def record_node_execution(self, node_id, class_type, inputs, outputs):
"""Record information about a node's execution"""
if not self.current_prompt_id:
return
# Add to execution order and mark as executed
if node_id not in self.executed_nodes:
self.executed_nodes.add(node_id)
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id)
# Process inputs to simplify working with them
processed_inputs = {}
for input_name, input_values in inputs.items():
if isinstance(input_values, list) and len(input_values) > 0:
# For single values, just use the first one (most common case)
processed_inputs[input_name] = input_values[0]
else:
processed_inputs[input_name] = input_values
# Extract node-specific metadata
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
extractor.extract(
node_id,
processed_inputs,
outputs,
self.prompt_metadata[self.current_prompt_id]
)
# Cache this node's metadata
self._cache_node_metadata(node_id, class_type)
def update_node_execution(self, node_id, class_type, outputs):
"""Update node metadata with output information"""
if not self.current_prompt_id:
return
# Process outputs to make them more usable
processed_outputs = outputs
# Use the same extractor to update with outputs
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
if hasattr(extractor, 'update'):
extractor.update(
node_id,
processed_outputs,
self.prompt_metadata[self.current_prompt_id]
)
# Update the cached metadata for this node
self._cache_node_metadata(node_id, class_type)
def _cache_node_metadata(self, node_id, class_type):
"""Cache the metadata for a specific node"""
if not self.current_prompt_id or not node_id or not class_type:
return
# Create a cache key combining node_id and class_type
cache_key = f"{node_id}:{class_type}"
# Create a shallow copy of the node's metadata
node_metadata = {}
current_metadata = self.prompt_metadata[self.current_prompt_id]
for category in self.metadata_categories:
if category in current_metadata and node_id in current_metadata[category]:
if category not in node_metadata:
node_metadata[category] = {}
node_metadata[category][node_id] = current_metadata[category][node_id]
# Save to cache if we have any metadata for this node
if any(node_metadata.values()):
self.node_cache[cache_key] = node_metadata
def clear_unused_cache(self):
"""Clean up node_cache entries that are no longer in use"""
# Collect all node_ids currently in prompt_metadata
active_node_ids = set()
for prompt_data in self.prompt_metadata.values():
for category in self.metadata_categories:
if category in prompt_data:
active_node_ids.update(prompt_data[category].keys())
# Find cache keys that are no longer needed
keys_to_remove = []
for cache_key in self.node_cache:
node_id = cache_key.split(':')[0]
if node_id not in active_node_ids:
keys_to_remove.append(cache_key)
# Remove cache entries that are no longer needed
for key in keys_to_remove:
del self.node_cache[key]
def clear_metadata(self, prompt_id=None):
"""Clear metadata for a specific prompt or reset all data"""
if prompt_id is not None:
if prompt_id in self.prompt_metadata:
del self.prompt_metadata[prompt_id]
# Clean up cache after removing prompt
self.clear_unused_cache()
else:
# Reset all data
self._reset()
def get_first_decoded_image(self, prompt_id=None):
"""Get the first decoded image result"""
key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata:
return None
metadata = self.prompt_metadata[key]
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
image_data = metadata[IMAGES]["first_decode"]["image"]
# If it's an image batch or tuple, handle various formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
# Return first element of list/tuple
return image_data[0]
# If it's a tensor, return as is for processing in the route handler
return image_data
# If no image is found in the current metadata, try to find it in the cache
# This handles the case where VAEDecode was cached by ComfyUI and not executed
prompt_obj = metadata.get("current_prompt")
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
original_prompt = prompt_obj.original_prompt
for node_id, node_data in original_prompt.items():
class_type = node_data.get("class_type")
if class_type and class_type in NODE_CLASS_MAPPINGS:
class_obj = NODE_CLASS_MAPPINGS[class_type]
class_name = class_obj.__name__
# Check if this is a VAEDecode node
if class_name == "VAEDecode":
# Try to find this node in the cache
cache_key = f"{node_id}:{class_name}"
if cache_key in self.node_cache:
cached_data = self.node_cache[cache_key]
if IMAGES in cached_data and node_id in cached_data[IMAGES]:
image_data = cached_data[IMAGES][node_id]["image"]
# Handle different image formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
return image_data[0]
return image_data
return None

View File

@@ -0,0 +1,280 @@
import os
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES
class NodeMetadataExtractor:
"""Base class for node-specific metadata extraction"""
@staticmethod
def extract(node_id, inputs, outputs, metadata):
"""Extract metadata from node inputs/outputs"""
pass
@staticmethod
def update(node_id, outputs, metadata):
"""Update metadata with node outputs after execution"""
pass
class GenericNodeExtractor(NodeMetadataExtractor):
"""Default extractor for nodes without specific handling"""
@staticmethod
def extract(node_id, inputs, outputs, metadata):
pass
class CheckpointLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "ckpt_name" not in inputs:
return
model_name = inputs.get("ckpt_name")
if model_name:
metadata[MODELS][node_id] = {
"name": model_name,
"type": "checkpoint",
"node_id": node_id
}
class CLIPTextEncodeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "text" not in inputs:
return
text = inputs.get("text", "")
metadata[PROMPTS][node_id] = {
"text": text,
"node_id": node_id
}
class SamplerExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class KSamplerAdvancedExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class LoraLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "lora_name" not in inputs:
return
lora_name = inputs.get("lora_name")
# Extract base filename without extension from path
lora_name = os.path.splitext(os.path.basename(lora_name))[0]
strength_model = round(float(inputs.get("strength_model", 1.0)), 2)
# Use the standardized format with lora_list
metadata[LORAS][node_id] = {
"lora_list": [
{
"name": lora_name,
"strength": strength_model
}
],
"node_id": node_id
}
class ImageSizeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
width = inputs.get("width", 512)
height = inputs.get("height", 512)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class LoraLoaderManagerExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
active_loras = []
# Process lora_stack if available
if "lora_stack" in inputs:
lora_stack = inputs.get("lora_stack", [])
for lora_path, model_strength, clip_strength in lora_stack:
# Extract lora name from path (following the format in lora_loader.py)
lora_name = os.path.splitext(os.path.basename(lora_path))[0]
active_loras.append({
"name": lora_name,
"strength": model_strength
})
# Process loras from inputs
if "loras" in inputs:
loras_data = inputs.get("loras", [])
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
loras_list = loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
# Filter for active loras
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", True) and not lora.get("_isDummy", False):
active_loras.append({
"name": lora.get("name", ""),
"strength": float(lora.get("strength", 1.0))
})
if active_loras:
metadata[LORAS][node_id] = {
"lora_list": active_loras,
"node_id": node_id
}
class FluxGuidanceExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "guidance" not in inputs:
return
guidance_value = inputs.get("guidance")
# Store the guidance value in SAMPLING category
if node_id not in metadata[SAMPLING]:
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
class UNETLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "unet_name" not in inputs:
return
model_name = inputs.get("unet_name")
if model_name:
metadata[MODELS][node_id] = {
"name": model_name,
"type": "checkpoint",
"node_id": node_id
}
class VAEDecodeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
pass
@staticmethod
def update(node_id, outputs, metadata):
# Ensure IMAGES category exists
if IMAGES not in metadata:
metadata[IMAGES] = {}
# Save image data under node ID index to be captured by caching mechanism
metadata[IMAGES][node_id] = {
"node_id": node_id,
"image": outputs
}
# Only set first_decode if it hasn't been recorded yet
if "first_decode" not in metadata[IMAGES]:
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
# Registry of node-specific extractors
NODE_EXTRACTORS = {
# Sampling
"KSampler": SamplerExtractor,
"KSamplerAdvanced": KSamplerAdvancedExtractor, # Add KSamplerAdvanced
"SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced
# Loaders
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
"LoraLoader": LoraLoaderExtractor,
"LoraManagerLoader": LoraLoaderManagerExtractor,
# Conditioning
"CLIPTextEncode": CLIPTextEncodeExtractor,
# Latent
"EmptyLatentImage": ImageSizeExtractor,
# Flux
"FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance
# Image
"VAEDecode": VAEDecodeExtractor, # Added VAEDecode extractor
# Add other nodes as needed
}

View File

@@ -0,0 +1,35 @@
import logging
from ..metadata_collector.metadata_processor import MetadataProcessor
logger = logging.getLogger(__name__)
class DebugMetadata:
NAME = "Debug Metadata (LoraManager)"
CATEGORY = "Lora Manager/utils"
DESCRIPTION = "Debug node to verify metadata_processor functionality"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("metadata_json",)
FUNCTION = "process_metadata"
def process_metadata(self, images):
try:
# Get the current execution context's metadata
from ..metadata_collector import get_metadata
metadata = get_metadata()
# Use the MetadataProcessor to convert it to JSON string
metadata_json = MetadataProcessor.to_json(metadata)
return (metadata_json,)
except Exception as e:
logger.error(f"Error processing metadata: {e}")
return ("{}",) # Return empty JSON object in case of error

View File

@@ -5,10 +5,10 @@ import re
import numpy as np import numpy as np
import folder_paths # type: ignore import folder_paths # type: ignore
from ..services.lora_scanner import LoraScanner from ..services.lora_scanner import LoraScanner
from ..workflow.parser import WorkflowParser from ..metadata_collector.metadata_processor import MetadataProcessor
from ..metadata_collector import get_metadata
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
import piexif import piexif
from io import BytesIO
class SaveImage: class SaveImage:
NAME = "Save Image (LoraManager)" NAME = "Save Image (LoraManager)"
@@ -34,8 +34,7 @@ class SaveImage:
"file_format": (["png", "jpeg", "webp"],), "file_format": (["png", "jpeg", "webp"],),
}, },
"optional": { "optional": {
"custom_prompt": ("STRING", {"default": "", "forceInput": True}), "lossless_webp": ("BOOLEAN", {"default": False}),
"lossless_webp": ("BOOLEAN", {"default": True}),
"quality": ("INT", {"default": 100, "min": 1, "max": 100}), "quality": ("INT", {"default": 100, "min": 1, "max": 100}),
"embed_workflow": ("BOOLEAN", {"default": False}), "embed_workflow": ("BOOLEAN", {"default": False}),
"add_counter_to_filename": ("BOOLEAN", {"default": True}), "add_counter_to_filename": ("BOOLEAN", {"default": True}),
@@ -61,21 +60,17 @@ class SaveImage:
return item.get('sha256') return item.get('sha256')
return None return None
async def format_metadata(self, parsed_workflow, custom_prompt=None): async def format_metadata(self, metadata_dict):
"""Format metadata in the requested format similar to userComment example""" """Format metadata in the requested format similar to userComment example"""
if not parsed_workflow: if not metadata_dict:
return "" return ""
# Extract the prompt and negative prompt # Extract the prompt and negative prompt
prompt = parsed_workflow.get('prompt', '') prompt = metadata_dict.get('prompt', '')
negative_prompt = parsed_workflow.get('negative_prompt', '') negative_prompt = metadata_dict.get('negative_prompt', '')
# Override prompt with custom_prompt if provided
if custom_prompt:
prompt = custom_prompt
# Extract loras from the prompt if present # Extract loras from the prompt if present
loras_text = parsed_workflow.get('loras', '') loras_text = metadata_dict.get('loras', '')
lora_hashes = {} lora_hashes = {}
# If loras are found, add them on a new line after the prompt # If loras are found, add them on a new line after the prompt
@@ -104,11 +99,11 @@ class SaveImage:
params = [] params = []
# Add standard parameters in the correct order # Add standard parameters in the correct order
if 'steps' in parsed_workflow: if 'steps' in metadata_dict:
params.append(f"Steps: {parsed_workflow.get('steps')}") params.append(f"Steps: {metadata_dict.get('steps')}")
if 'sampler' in parsed_workflow: if 'sampler' in metadata_dict:
sampler = parsed_workflow.get('sampler') sampler = metadata_dict.get('sampler')
# Convert ComfyUI sampler names to user-friendly names # Convert ComfyUI sampler names to user-friendly names
sampler_mapping = { sampler_mapping = {
'euler': 'Euler', 'euler': 'Euler',
@@ -130,8 +125,8 @@ class SaveImage:
sampler_name = sampler_mapping.get(sampler, sampler) sampler_name = sampler_mapping.get(sampler, sampler)
params.append(f"Sampler: {sampler_name}") params.append(f"Sampler: {sampler_name}")
if 'scheduler' in parsed_workflow: if 'scheduler' in metadata_dict:
scheduler = parsed_workflow.get('scheduler') scheduler = metadata_dict.get('scheduler')
scheduler_mapping = { scheduler_mapping = {
'normal': 'Simple', 'normal': 'Simple',
'karras': 'Karras', 'karras': 'Karras',
@@ -142,27 +137,36 @@ class SaveImage:
scheduler_name = scheduler_mapping.get(scheduler, scheduler) scheduler_name = scheduler_mapping.get(scheduler, scheduler)
params.append(f"Schedule type: {scheduler_name}") params.append(f"Schedule type: {scheduler_name}")
# CFG scale (cfg in parsed_workflow) # CFG scale (cfg_scale in metadata_dict)
if 'cfg_scale' in parsed_workflow: if 'cfg_scale' in metadata_dict:
params.append(f"CFG scale: {parsed_workflow.get('cfg_scale')}") params.append(f"CFG scale: {metadata_dict.get('cfg_scale')}")
elif 'cfg' in parsed_workflow: elif 'cfg' in metadata_dict:
params.append(f"CFG scale: {parsed_workflow.get('cfg')}") params.append(f"CFG scale: {metadata_dict.get('cfg')}")
# Seed # Seed
if 'seed' in parsed_workflow: if 'seed' in metadata_dict:
params.append(f"Seed: {parsed_workflow.get('seed')}") params.append(f"Seed: {metadata_dict.get('seed')}")
# Size # Size
if 'size' in parsed_workflow: if 'size' in metadata_dict:
params.append(f"Size: {parsed_workflow.get('size')}") params.append(f"Size: {metadata_dict.get('size')}")
# Model info # Model info
if 'checkpoint' in parsed_workflow: if 'checkpoint' in metadata_dict:
# Extract basename without path # Ensure checkpoint is a string before processing
checkpoint = os.path.basename(parsed_workflow.get('checkpoint', '')) checkpoint = metadata_dict.get('checkpoint')
# Remove extension if present if checkpoint is not None:
checkpoint = os.path.splitext(checkpoint)[0] # Handle both string and other types safely
params.append(f"Model: {checkpoint}") if isinstance(checkpoint, str):
# Extract basename without path
checkpoint = os.path.basename(checkpoint)
# Remove extension if present
checkpoint = os.path.splitext(checkpoint)[0]
else:
# Convert non-string to string
checkpoint = str(checkpoint)
params.append(f"Model: {checkpoint}")
# Add LoRA hashes if available # Add LoRA hashes if available
if lora_hashes: if lora_hashes:
@@ -181,9 +185,9 @@ class SaveImage:
# credit to nkchocoai # credit to nkchocoai
# Add format_filename method to handle pattern substitution # Add format_filename method to handle pattern substitution
def format_filename(self, filename, parsed_workflow): def format_filename(self, filename, metadata_dict):
"""Format filename with metadata values""" """Format filename with metadata values"""
if not parsed_workflow: if not metadata_dict:
return filename return filename
result = re.findall(self.pattern_format, filename) result = re.findall(self.pattern_format, filename)
@@ -191,30 +195,30 @@ class SaveImage:
parts = segment.replace("%", "").split(":") parts = segment.replace("%", "").split(":")
key = parts[0] key = parts[0]
if key == "seed" and 'seed' in parsed_workflow: if key == "seed" and 'seed' in metadata_dict:
filename = filename.replace(segment, str(parsed_workflow.get('seed', ''))) filename = filename.replace(segment, str(metadata_dict.get('seed', '')))
elif key == "width" and 'size' in parsed_workflow: elif key == "width" and 'size' in metadata_dict:
size = parsed_workflow.get('size', 'x') size = metadata_dict.get('size', 'x')
w = size.split('x')[0] if isinstance(size, str) else size[0] w = size.split('x')[0] if isinstance(size, str) else size[0]
filename = filename.replace(segment, str(w)) filename = filename.replace(segment, str(w))
elif key == "height" and 'size' in parsed_workflow: elif key == "height" and 'size' in metadata_dict:
size = parsed_workflow.get('size', 'x') size = metadata_dict.get('size', 'x')
h = size.split('x')[1] if isinstance(size, str) else size[1] h = size.split('x')[1] if isinstance(size, str) else size[1]
filename = filename.replace(segment, str(h)) filename = filename.replace(segment, str(h))
elif key == "pprompt" and 'prompt' in parsed_workflow: elif key == "pprompt" and 'prompt' in metadata_dict:
prompt = parsed_workflow.get('prompt', '').replace("\n", " ") prompt = metadata_dict.get('prompt', '').replace("\n", " ")
if len(parts) >= 2: if len(parts) >= 2:
length = int(parts[1]) length = int(parts[1])
prompt = prompt[:length] prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip()) filename = filename.replace(segment, prompt.strip())
elif key == "nprompt" and 'negative_prompt' in parsed_workflow: elif key == "nprompt" and 'negative_prompt' in metadata_dict:
prompt = parsed_workflow.get('negative_prompt', '').replace("\n", " ") prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ")
if len(parts) >= 2: if len(parts) >= 2:
length = int(parts[1]) length = int(parts[1])
prompt = prompt[:length] prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip()) filename = filename.replace(segment, prompt.strip())
elif key == "model" and 'checkpoint' in parsed_workflow: elif key == "model" and 'checkpoint' in metadata_dict:
model = parsed_workflow.get('checkpoint', '') model = metadata_dict.get('checkpoint', '')
model = os.path.splitext(os.path.basename(model))[0] model = os.path.splitext(os.path.basename(model))[0]
if len(parts) >= 2: if len(parts) >= 2:
length = int(parts[1]) length = int(parts[1])
@@ -246,23 +250,19 @@ class SaveImage:
return filename return filename
def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None, def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True, lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
custom_prompt=None):
"""Save images with metadata""" """Save images with metadata"""
results = [] results = []
# Parse the workflow using the WorkflowParser # Get metadata using the metadata collector
parser = WorkflowParser() raw_metadata = get_metadata()
if prompt: metadata_dict = MetadataProcessor.to_dict(raw_metadata)
parsed_workflow = parser.parse_workflow(prompt)
else:
parsed_workflow = {}
# Get or create metadata asynchronously # Get or create metadata asynchronously
metadata = asyncio.run(self.format_metadata(parsed_workflow, custom_prompt)) metadata = asyncio.run(self.format_metadata(metadata_dict))
# Process filename_prefix with pattern substitution # Process filename_prefix with pattern substitution
filename_prefix = self.format_filename(filename_prefix, parsed_workflow) filename_prefix = self.format_filename(filename_prefix, metadata_dict)
# Get initial save path info once for the batch # Get initial save path info once for the batch
full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path( full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path(
@@ -290,7 +290,8 @@ class SaveImage:
if file_format == "png": if file_format == "png":
file = base_filename + ".png" file = base_filename + ".png"
file_extension = ".png" file_extension = ".png"
save_kwargs = {"optimize": True, "compress_level": self.compress_level} # Remove "optimize": True to match built-in node behavior
save_kwargs = {"compress_level": self.compress_level}
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
elif file_format == "jpeg": elif file_format == "jpeg":
file = base_filename + ".jpg" file = base_filename + ".jpg"
@@ -299,7 +300,8 @@ class SaveImage:
elif file_format == "webp": elif file_format == "webp":
file = base_filename + ".webp" file = base_filename + ".webp"
file_extension = ".webp" file_extension = ".webp"
save_kwargs = {"quality": quality, "lossless": lossless_webp} # Add optimization param to control performance
save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0}
# Full save path # Full save path
file_path = os.path.join(full_output_folder, file) file_path = os.path.join(full_output_folder, file)
@@ -347,8 +349,7 @@ class SaveImage:
return results return results
def process_image(self, images, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None, def process_image(self, images, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True, lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
custom_prompt=""):
"""Process and save image with metadata""" """Process and save image with metadata"""
# Make sure the output directory exists # Make sure the output directory exists
os.makedirs(self.output_dir, exist_ok=True) os.makedirs(self.output_dir, exist_ok=True)
@@ -369,8 +370,7 @@ class SaveImage:
lossless_webp, lossless_webp,
quality, quality,
embed_workflow, embed_workflow,
add_counter_to_filename, add_counter_to_filename
custom_prompt if custom_prompt.strip() else None
) )
return (images,) return (images,)

View File

@@ -1,5 +1,9 @@
import os import os
import time import time
import numpy as np
from PIL import Image
import torch
import io
import logging import logging
from aiohttp import web from aiohttp import web
from typing import Dict from typing import Dict
@@ -11,9 +15,11 @@ from ..utils.recipe_parsers import RecipeParserFactory
from ..utils.constants import CARD_PREVIEW_WIDTH from ..utils.constants import CARD_PREVIEW_WIDTH
from ..config import config from ..config import config
from ..workflow.parser import WorkflowParser from ..metadata_collector import get_metadata # Add MetadataCollector import
from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import
from ..utils.utils import download_civitai_image from ..utils.utils import download_civitai_image
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
from ..metadata_collector.metadata_registry import MetadataRegistry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -24,7 +30,7 @@ class RecipeRoutes:
# Initialize service references as None, will be set during async init # Initialize service references as None, will be set during async init
self.recipe_scanner = None self.recipe_scanner = None
self.civitai_client = None self.civitai_client = None
self.parser = WorkflowParser() # Remove WorkflowParser instance
# Pre-warm the cache # Pre-warm the cache
self._init_cache_task = None self._init_cache_task = None
@@ -656,8 +662,8 @@ class RecipeRoutes:
logger.error(f"Error retrieving base models: {e}", exc_info=True) logger.error(f"Error retrieving base models: {e}", exc_info=True)
return web.json_response({ return web.json_response({
'success': False, 'success': False,
'error': str(e) 'error': str(e)}
}, status=500) , status=500)
async def share_recipe(self, request: web.Request) -> web.Response: async def share_recipe(self, request: web.Request) -> web.Response:
"""Process a recipe image for sharing by adding metadata to EXIF""" """Process a recipe image for sharing by adding metadata to EXIF"""
@@ -786,50 +792,72 @@ class RecipeRoutes:
# Ensure services are initialized # Ensure services are initialized
await self.init_services() await self.init_services()
reader = await request.multipart() # Get metadata using the metadata collector instead of workflow parsing
raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata)
# Process form data # Check if we have valid metadata
workflow_json = None if not metadata_dict:
return web.json_response({"error": "No generation metadata found"}, status=400)
while True: # Get the most recent image from metadata registry instead of temp directory
field = await reader.next() metadata_registry = MetadataRegistry()
if field is None: latest_image = metadata_registry.get_first_decoded_image()
break
if not latest_image:
return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400)
# Convert the image data to bytes - handle tuple and tensor cases
logger.debug(f"Image type: {type(latest_image)}")
try:
# Handle the tuple case first
if isinstance(latest_image, tuple):
# Extract the tensor from the tuple
if len(latest_image) > 0:
tensor_image = latest_image[0]
else:
return web.json_response({"error": "Empty image tuple received"}, status=400)
else:
tensor_image = latest_image
if field.name == 'workflow_json': # Get the shape info for debugging
workflow_text = await field.text() if hasattr(tensor_image, 'shape'):
try: shape_info = tensor_image.shape
workflow_json = json.loads(workflow_text) logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
except:
return web.json_response({"error": "Invalid workflow JSON"}, status=400) # Convert tensor to numpy array
if isinstance(tensor_image, torch.Tensor):
image_np = tensor_image.cpu().numpy()
else:
image_np = np.array(tensor_image)
# Handle different tensor shapes
# Case: (1, 1, H, W, 3) or (1, H, W, 3) - batch or multi-batch
if len(image_np.shape) > 3:
# Remove batch dimensions until we get to (H, W, 3)
while len(image_np.shape) > 3:
image_np = image_np[0]
# If values are in [0, 1] range, convert to [0, 255]
if image_np.dtype == np.float32 or image_np.dtype == np.float64:
if image_np.max() <= 1.0:
image_np = (image_np * 255).astype(np.uint8)
# Ensure image is in the right format (HWC with RGB channels)
if len(image_np.shape) == 3 and image_np.shape[2] == 3:
pil_image = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
pil_image.save(img_byte_arr, format='PNG')
image = img_byte_arr.getvalue()
else:
return web.json_response({"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, status=400)
except Exception as e:
logger.error(f"Error processing image data: {str(e)}", exc_info=True)
return web.json_response({"error": f"Error processing image: {str(e)}"}, status=400)
if not workflow_json: # Get the lora stack from the metadata
return web.json_response({"error": "Missing workflow JSON"}, status=400) lora_stack = metadata_dict.get("loras", "")
# Find the latest image in the temp directory
temp_dir = config.temp_directory
image_files = []
for file in os.listdir(temp_dir):
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
file_path = os.path.join(temp_dir, file)
image_files.append((file_path, os.path.getmtime(file_path)))
if not image_files:
return web.json_response({"error": "No recent images found to use for recipe"}, status=400)
# Sort by modification time (newest first)
image_files.sort(key=lambda x: x[1], reverse=True)
latest_image_path = image_files[0][0]
# Parse the workflow to extract generation parameters and loras
parsed_workflow = self.parser.parse_workflow(workflow_json)
if not parsed_workflow:
return web.json_response({"error": "Could not extract parameters from workflow"}, status=400)
# Get the lora stack from the parsed workflow
lora_stack = parsed_workflow.get("loras", "")
# Parse the lora stack format: "<lora:name:strength> <lora:name2:strength2> ..." # Parse the lora stack format: "<lora:name:strength> <lora:name2:strength2> ..."
import re import re
@@ -837,7 +865,7 @@ class RecipeRoutes:
# Check if any loras were found # Check if any loras were found
if not lora_matches: if not lora_matches:
return web.json_response({"error": "No LoRAs found in the workflow"}, status=400) return web.json_response({"error": "No LoRAs found in the generation metadata"}, status=400)
# Generate recipe name from the first 3 loras (or less if fewer are available) # Generate recipe name from the first 3 loras (or less if fewer are available)
loras_for_name = lora_matches[:3] # Take at most 3 loras for the name loras_for_name = lora_matches[:3] # Take at most 3 loras for the name
@@ -851,10 +879,6 @@ class RecipeRoutes:
recipe_name = " ".join(recipe_name_parts) recipe_name = " ".join(recipe_name_parts)
# Read the image
with open(latest_image_path, 'rb') as f:
image = f.read()
# Create recipes directory if it doesn't exist # Create recipes directory if it doesn't exist
recipes_dir = self.recipe_scanner.recipes_dir recipes_dir = self.recipe_scanner.recipes_dir
os.makedirs(recipes_dir, exist_ok=True) os.makedirs(recipes_dir, exist_ok=True)
@@ -922,8 +946,8 @@ class RecipeRoutes:
"created_date": time.time(), "created_date": time.time(),
"base_model": most_common_base_model, "base_model": most_common_base_model,
"loras": loras_data, "loras": loras_data,
"checkpoint": parsed_workflow.get("checkpoint", ""), "checkpoint": metadata_dict.get("checkpoint", ""),
"gen_params": {key: value for key, value in parsed_workflow.items() "gen_params": {key: value for key, value in metadata_dict.items()
if key not in ['checkpoint', 'loras']}, if key not in ['checkpoint', 'loras']},
"loras_stack": lora_stack # Include the original lora stack string "loras_stack": lora_stack # Include the original lora stack string
} }

View File

@@ -966,9 +966,6 @@ export function addLorasWidget(node, name, opts, callback) {
// Function to directly save the recipe without dialog // Function to directly save the recipe without dialog
async function saveRecipeDirectly(widget) { async function saveRecipeDirectly(widget) {
try { try {
// Get the workflow data from the ComfyUI app
const prompt = await app.graphToPrompt();
// Show loading toast // Show loading toast
if (app && app.extensionManager && app.extensionManager.toast) { if (app && app.extensionManager && app.extensionManager.toast) {
app.extensionManager.toast.add({ app.extensionManager.toast.add({
@@ -979,14 +976,9 @@ async function saveRecipeDirectly(widget) {
}); });
} }
// Prepare the data - only send workflow JSON // Send the request to the backend API without workflow data
const formData = new FormData();
formData.append('workflow_json', JSON.stringify(prompt.output));
// Send the request
const response = await fetch('/api/recipes/save-from-widget', { const response = await fetch('/api/recipes/save-from-widget', {
method: 'POST', method: 'POST'
body: formData
}); });
const result = await response.json(); const result = await response.json();