mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: Implement metadata collection and processing framework with debug node for verification
This commit is contained in:
@@ -3,16 +3,23 @@ from .py.nodes.lora_loader import LoraManagerLoader
|
|||||||
from .py.nodes.trigger_word_toggle import TriggerWordToggle
|
from .py.nodes.trigger_word_toggle import TriggerWordToggle
|
||||||
from .py.nodes.lora_stacker import LoraStacker
|
from .py.nodes.lora_stacker import LoraStacker
|
||||||
from .py.nodes.save_image import SaveImage
|
from .py.nodes.save_image import SaveImage
|
||||||
|
from .py.nodes.debug_metadata import DebugMetadata
|
||||||
|
# Import metadata collector to install hooks on startup
|
||||||
|
from .py.metadata_collector import init as init_metadata_collector
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
LoraManagerLoader.NAME: LoraManagerLoader,
|
LoraManagerLoader.NAME: LoraManagerLoader,
|
||||||
TriggerWordToggle.NAME: TriggerWordToggle,
|
TriggerWordToggle.NAME: TriggerWordToggle,
|
||||||
LoraStacker.NAME: LoraStacker,
|
LoraStacker.NAME: LoraStacker,
|
||||||
SaveImage.NAME: SaveImage
|
SaveImage.NAME: SaveImage,
|
||||||
|
DebugMetadata.NAME: DebugMetadata
|
||||||
}
|
}
|
||||||
|
|
||||||
WEB_DIRECTORY = "./web/comfyui"
|
WEB_DIRECTORY = "./web/comfyui"
|
||||||
|
|
||||||
|
# Initialize metadata collector
|
||||||
|
init_metadata_collector()
|
||||||
|
|
||||||
# Register routes on import
|
# Register routes on import
|
||||||
LoraManager.add_routes()
|
LoraManager.add_routes()
|
||||||
__all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY']
|
__all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY']
|
||||||
|
|||||||
18
py/metadata_collector/__init__.py
Normal file
18
py/metadata_collector/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
import os
|
||||||
|
import importlib
|
||||||
|
from .metadata_hook import MetadataHook
|
||||||
|
from .metadata_registry import MetadataRegistry
|
||||||
|
|
||||||
|
def init():
|
||||||
|
# Install hooks to collect metadata during execution
|
||||||
|
MetadataHook.install()
|
||||||
|
|
||||||
|
# Initialize registry
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
|
||||||
|
print("ComfyUI Metadata Collector initialized")
|
||||||
|
|
||||||
|
def get_metadata(prompt_id=None):
|
||||||
|
"""Helper function to get metadata from the registry"""
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
return registry.get_metadata(prompt_id)
|
||||||
123
py/metadata_collector/metadata_hook.py
Normal file
123
py/metadata_collector/metadata_hook.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import sys
|
||||||
|
import inspect
|
||||||
|
from .metadata_registry import MetadataRegistry
|
||||||
|
|
||||||
|
class MetadataHook:
|
||||||
|
"""Install hooks for metadata collection"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def install():
|
||||||
|
"""Install hooks to collect metadata during execution"""
|
||||||
|
try:
|
||||||
|
# Import ComfyUI's execution module
|
||||||
|
execution = None
|
||||||
|
try:
|
||||||
|
# Try direct import first
|
||||||
|
import execution # type: ignore
|
||||||
|
except ImportError:
|
||||||
|
# Try to locate from system modules
|
||||||
|
for module_name in sys.modules:
|
||||||
|
if module_name.endswith('.execution'):
|
||||||
|
execution = sys.modules[module_name]
|
||||||
|
break
|
||||||
|
|
||||||
|
# If we can't find the execution module, we can't install hooks
|
||||||
|
if execution is None:
|
||||||
|
print("Could not locate ComfyUI execution module, metadata collection disabled")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Store the original _map_node_over_list function
|
||||||
|
original_map_node_over_list = execution._map_node_over_list
|
||||||
|
|
||||||
|
# Define the wrapped _map_node_over_list function
|
||||||
|
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||||
|
# Only collect metadata when calling the main function of nodes
|
||||||
|
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||||
|
try:
|
||||||
|
# Get the current prompt_id from the registry
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
prompt_id = registry.current_prompt_id
|
||||||
|
|
||||||
|
if prompt_id is not None:
|
||||||
|
# Get node class type
|
||||||
|
class_type = obj.__class__.__name__
|
||||||
|
|
||||||
|
# Unique ID might be available through the obj if it has a unique_id field
|
||||||
|
node_id = getattr(obj, 'unique_id', None)
|
||||||
|
if node_id is None and pre_execute_cb:
|
||||||
|
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
|
||||||
|
frame = inspect.currentframe()
|
||||||
|
while frame:
|
||||||
|
if 'unique_id' in frame.f_locals:
|
||||||
|
node_id = frame.f_locals['unique_id']
|
||||||
|
break
|
||||||
|
frame = frame.f_back
|
||||||
|
|
||||||
|
# Record inputs before execution
|
||||||
|
if node_id is not None:
|
||||||
|
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||||
|
|
||||||
|
# Execute the original function
|
||||||
|
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||||
|
|
||||||
|
# After execution, collect outputs for relevant nodes
|
||||||
|
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||||
|
try:
|
||||||
|
# Get the current prompt_id from the registry
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
prompt_id = registry.current_prompt_id
|
||||||
|
|
||||||
|
if prompt_id is not None:
|
||||||
|
# Get node class type
|
||||||
|
class_type = obj.__class__.__name__
|
||||||
|
|
||||||
|
# Unique ID might be available through the obj if it has a unique_id field
|
||||||
|
node_id = getattr(obj, 'unique_id', None)
|
||||||
|
if node_id is None and pre_execute_cb:
|
||||||
|
# Try to extract node_id through reflection
|
||||||
|
frame = inspect.currentframe()
|
||||||
|
while frame:
|
||||||
|
if 'unique_id' in frame.f_locals:
|
||||||
|
node_id = frame.f_locals['unique_id']
|
||||||
|
break
|
||||||
|
frame = frame.f_back
|
||||||
|
|
||||||
|
# Record outputs after execution
|
||||||
|
if node_id is not None:
|
||||||
|
registry.update_node_execution(node_id, class_type, results)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
# Also hook the execute function to track the current prompt_id
|
||||||
|
original_execute = execution.execute
|
||||||
|
|
||||||
|
def execute_with_prompt_tracking(*args, **kwargs):
|
||||||
|
if len(args) >= 7: # Check if we have enough arguments
|
||||||
|
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
|
||||||
|
# Start collection if this is a new prompt
|
||||||
|
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||||
|
registry.start_collection(prompt_id)
|
||||||
|
|
||||||
|
# Store the dynprompt reference for node lookups
|
||||||
|
if hasattr(prompt, 'original_prompt'):
|
||||||
|
registry.set_current_prompt(prompt)
|
||||||
|
|
||||||
|
# Execute the original function
|
||||||
|
return original_execute(*args, **kwargs)
|
||||||
|
|
||||||
|
# Replace the functions
|
||||||
|
execution._map_node_over_list = map_node_over_list_with_metadata
|
||||||
|
execution.execute = execute_with_prompt_tracking
|
||||||
|
# Make map_node_over_list public to avoid it being hidden by hooks
|
||||||
|
execution.map_node_over_list = original_map_node_over_list
|
||||||
|
|
||||||
|
print("Metadata collection hooks installed for runtime values")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error installing metadata hooks: {str(e)}")
|
||||||
171
py/metadata_collector/metadata_processor.py
Normal file
171
py/metadata_collector/metadata_processor.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
class MetadataProcessor:
|
||||||
|
"""Process and format collected metadata"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_primary_sampler(metadata):
|
||||||
|
"""Find the primary KSampler node (with denoise=1)"""
|
||||||
|
primary_sampler = None
|
||||||
|
primary_sampler_id = None
|
||||||
|
|
||||||
|
for node_id, sampler_info in metadata.get("sampling", {}).items():
|
||||||
|
parameters = sampler_info.get("parameters", {})
|
||||||
|
denoise = parameters.get("denoise")
|
||||||
|
|
||||||
|
# If denoise is 1.0, this is likely the primary sampler
|
||||||
|
if denoise == 1.0 or denoise == 1:
|
||||||
|
primary_sampler = sampler_info
|
||||||
|
primary_sampler_id = node_id
|
||||||
|
break
|
||||||
|
|
||||||
|
return primary_sampler_id, primary_sampler
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def trace_node_input(prompt, node_id, input_name):
|
||||||
|
"""Trace an input connection from a node to find the source node"""
|
||||||
|
if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt:
|
||||||
|
return None
|
||||||
|
|
||||||
|
node_inputs = prompt.original_prompt[node_id].get("inputs", {})
|
||||||
|
if input_name not in node_inputs:
|
||||||
|
return None
|
||||||
|
|
||||||
|
input_value = node_inputs[input_name]
|
||||||
|
# Input connections are formatted as [node_id, output_index]
|
||||||
|
if isinstance(input_value, list) and len(input_value) >= 2:
|
||||||
|
return input_value[0] # Return connected node_id
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_primary_checkpoint(metadata):
|
||||||
|
"""Find the primary checkpoint model in the workflow"""
|
||||||
|
if not metadata.get("models"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# In most workflows, there's only one checkpoint, so we can just take the first one
|
||||||
|
for node_id, model_info in metadata.get("models", {}).items():
|
||||||
|
if model_info.get("type") == "checkpoint":
|
||||||
|
return model_info.get("name")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_generation_params(metadata):
|
||||||
|
"""Extract generation parameters from metadata using node relationships"""
|
||||||
|
params = {
|
||||||
|
"prompt": "",
|
||||||
|
"negative_prompt": "",
|
||||||
|
"seed": None,
|
||||||
|
"steps": None,
|
||||||
|
"cfg_scale": None,
|
||||||
|
"sampler": None,
|
||||||
|
"checkpoint": None,
|
||||||
|
"loras": "",
|
||||||
|
"size": None,
|
||||||
|
"clip_skip": None
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get the prompt object for node relationship tracing
|
||||||
|
prompt = metadata.get("current_prompt")
|
||||||
|
|
||||||
|
# Find the primary KSampler node
|
||||||
|
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata)
|
||||||
|
|
||||||
|
# Directly get checkpoint from metadata instead of tracing
|
||||||
|
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
|
||||||
|
if checkpoint:
|
||||||
|
params["checkpoint"] = checkpoint
|
||||||
|
|
||||||
|
if primary_sampler:
|
||||||
|
# Extract sampling parameters
|
||||||
|
sampling_params = primary_sampler.get("parameters", {})
|
||||||
|
params["seed"] = sampling_params.get("seed")
|
||||||
|
params["steps"] = sampling_params.get("steps")
|
||||||
|
params["cfg_scale"] = sampling_params.get("cfg")
|
||||||
|
params["sampler"] = sampling_params.get("sampler_name")
|
||||||
|
|
||||||
|
# Trace connections from the primary sampler
|
||||||
|
if prompt and primary_sampler_id:
|
||||||
|
# Trace positive prompt
|
||||||
|
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive")
|
||||||
|
if positive_node_id and positive_node_id in metadata.get("prompts", {}):
|
||||||
|
params["prompt"] = metadata["prompts"][positive_node_id].get("text", "")
|
||||||
|
|
||||||
|
# Trace negative prompt
|
||||||
|
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative")
|
||||||
|
if negative_node_id and negative_node_id in metadata.get("prompts", {}):
|
||||||
|
params["negative_prompt"] = metadata["prompts"][negative_node_id].get("text", "")
|
||||||
|
|
||||||
|
# Check if the sampler itself has size information (from latent_image)
|
||||||
|
if primary_sampler_id in metadata.get("size", {}):
|
||||||
|
width = metadata["size"][primary_sampler_id].get("width")
|
||||||
|
height = metadata["size"][primary_sampler_id].get("height")
|
||||||
|
if width and height:
|
||||||
|
params["size"] = f"{width}x{height}"
|
||||||
|
else:
|
||||||
|
# Fallback to the previous trace method if needed
|
||||||
|
latent_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "latent_image")
|
||||||
|
if latent_node_id:
|
||||||
|
# Follow chain to find EmptyLatentImage node
|
||||||
|
size_found = False
|
||||||
|
current_node_id = latent_node_id
|
||||||
|
|
||||||
|
# Limit depth to avoid infinite loops in complex workflows
|
||||||
|
max_depth = 10
|
||||||
|
for _ in range(max_depth):
|
||||||
|
if current_node_id in metadata.get("size", {}):
|
||||||
|
width = metadata["size"][current_node_id].get("width")
|
||||||
|
height = metadata["size"][current_node_id].get("height")
|
||||||
|
if width and height:
|
||||||
|
params["size"] = f"{width}x{height}"
|
||||||
|
size_found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# Try to follow the chain
|
||||||
|
if prompt and prompt.original_prompt and current_node_id in prompt.original_prompt:
|
||||||
|
node_info = prompt.original_prompt[current_node_id]
|
||||||
|
if "inputs" in node_info:
|
||||||
|
# Look for a connection that might lead to size information
|
||||||
|
for input_name, input_value in node_info["inputs"].items():
|
||||||
|
if isinstance(input_value, list) and len(input_value) >= 2:
|
||||||
|
current_node_id = input_value[0]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
break # No connections to follow
|
||||||
|
else:
|
||||||
|
break # No inputs to follow
|
||||||
|
else:
|
||||||
|
break # Can't follow further
|
||||||
|
|
||||||
|
# Extract LoRAs
|
||||||
|
lora_parts = []
|
||||||
|
for node_id, lora_info in metadata.get("loras", {}).items():
|
||||||
|
name = lora_info.get("name", "unknown")
|
||||||
|
strength = lora_info.get("strength_model", 1.0)
|
||||||
|
lora_parts.append(f"<lora:{name}:{strength}>")
|
||||||
|
params["loras"] = " ".join(lora_parts)
|
||||||
|
|
||||||
|
# Set default clip_skip value
|
||||||
|
params["clip_skip"] = "1" # Common default
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_comfyui_format(metadata):
|
||||||
|
"""Convert extracted metadata to the ComfyUI output.json format"""
|
||||||
|
params = MetadataProcessor.extract_generation_params(metadata)
|
||||||
|
|
||||||
|
# Convert all values to strings to match output.json format
|
||||||
|
for key in params:
|
||||||
|
if params[key] is not None:
|
||||||
|
params[key] = str(params[key])
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_json(metadata):
|
||||||
|
"""Convert metadata to JSON string"""
|
||||||
|
params = MetadataProcessor.to_comfyui_format(metadata)
|
||||||
|
return json.dumps(params, indent=4)
|
||||||
171
py/metadata_collector/metadata_registry.py
Normal file
171
py/metadata_collector/metadata_registry.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
import time
|
||||||
|
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
|
||||||
|
|
||||||
|
class MetadataRegistry:
|
||||||
|
"""A singleton registry to store and retrieve workflow metadata"""
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._reset()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def _reset(self):
|
||||||
|
self.current_prompt_id = None
|
||||||
|
self.current_prompt = None
|
||||||
|
self.metadata = {}
|
||||||
|
self.prompt_metadata = {}
|
||||||
|
self.executed_nodes = set()
|
||||||
|
|
||||||
|
# Node-level cache for metadata
|
||||||
|
self.node_cache = {}
|
||||||
|
|
||||||
|
# Categories we want to track and retrieve from cache
|
||||||
|
self.metadata_categories = ["models", "prompts", "sampling", "loras", "size"]
|
||||||
|
|
||||||
|
def start_collection(self, prompt_id):
|
||||||
|
"""Begin metadata collection for a new prompt"""
|
||||||
|
self.current_prompt_id = prompt_id
|
||||||
|
self.executed_nodes = set()
|
||||||
|
self.prompt_metadata[prompt_id] = {
|
||||||
|
"models": {},
|
||||||
|
"prompts": {},
|
||||||
|
"sampling": {},
|
||||||
|
"loras": {},
|
||||||
|
"size": {},
|
||||||
|
"execution_order": [],
|
||||||
|
"current_prompt": None, # Will store the prompt object
|
||||||
|
"timestamp": time.time()
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_current_prompt(self, prompt):
|
||||||
|
"""Set the current prompt object reference"""
|
||||||
|
self.current_prompt = prompt
|
||||||
|
if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata:
|
||||||
|
# Store the prompt in the metadata for later relationship tracing
|
||||||
|
self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt
|
||||||
|
|
||||||
|
def get_metadata(self, prompt_id=None):
|
||||||
|
"""Get collected metadata for a prompt"""
|
||||||
|
key = prompt_id if prompt_id is not None else self.current_prompt_id
|
||||||
|
if key not in self.prompt_metadata:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
metadata = self.prompt_metadata[key]
|
||||||
|
|
||||||
|
# If we have a current prompt object, check for non-executed nodes
|
||||||
|
prompt_obj = metadata.get("current_prompt")
|
||||||
|
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
|
||||||
|
original_prompt = prompt_obj.original_prompt
|
||||||
|
|
||||||
|
# Fill in missing metadata from cache for nodes that weren't executed
|
||||||
|
self._fill_missing_metadata(key, original_prompt)
|
||||||
|
|
||||||
|
return self.prompt_metadata.get(key, {})
|
||||||
|
|
||||||
|
def _fill_missing_metadata(self, prompt_id, original_prompt):
|
||||||
|
"""Fill missing metadata from cache for non-executed nodes"""
|
||||||
|
if not original_prompt:
|
||||||
|
return
|
||||||
|
|
||||||
|
executed_nodes = self.executed_nodes
|
||||||
|
metadata = self.prompt_metadata[prompt_id]
|
||||||
|
|
||||||
|
# Iterate through nodes in the original prompt
|
||||||
|
for node_id, node_data in original_prompt.items():
|
||||||
|
# Skip if already executed in this run
|
||||||
|
if node_id in executed_nodes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
class_type = node_data.get("class_type")
|
||||||
|
if not class_type:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create cache key
|
||||||
|
cache_key = f"{node_id}:{class_type}"
|
||||||
|
|
||||||
|
# Check if this node type is relevant for metadata collection
|
||||||
|
if class_type in NODE_EXTRACTORS:
|
||||||
|
# Check if we have cached metadata for this node
|
||||||
|
if cache_key in self.node_cache:
|
||||||
|
cached_data = self.node_cache[cache_key]
|
||||||
|
|
||||||
|
# Apply cached metadata to the current metadata
|
||||||
|
for category in self.metadata_categories:
|
||||||
|
if category in cached_data and node_id in cached_data[category]:
|
||||||
|
if node_id not in metadata[category]:
|
||||||
|
metadata[category][node_id] = cached_data[category][node_id]
|
||||||
|
|
||||||
|
def record_node_execution(self, node_id, class_type, inputs, outputs):
|
||||||
|
"""Record information about a node's execution"""
|
||||||
|
if not self.current_prompt_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Add to execution order and mark as executed
|
||||||
|
if node_id not in self.executed_nodes:
|
||||||
|
self.executed_nodes.add(node_id)
|
||||||
|
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id)
|
||||||
|
|
||||||
|
# Process inputs to simplify working with them
|
||||||
|
processed_inputs = {}
|
||||||
|
for input_name, input_values in inputs.items():
|
||||||
|
if isinstance(input_values, list) and len(input_values) > 0:
|
||||||
|
# For single values, just use the first one (most common case)
|
||||||
|
processed_inputs[input_name] = input_values[0]
|
||||||
|
else:
|
||||||
|
processed_inputs[input_name] = input_values
|
||||||
|
|
||||||
|
# Extract node-specific metadata
|
||||||
|
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
|
||||||
|
extractor.extract(
|
||||||
|
node_id,
|
||||||
|
processed_inputs,
|
||||||
|
outputs,
|
||||||
|
self.prompt_metadata[self.current_prompt_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache this node's metadata
|
||||||
|
self._cache_node_metadata(node_id, class_type)
|
||||||
|
|
||||||
|
def update_node_execution(self, node_id, class_type, outputs):
|
||||||
|
"""Update node metadata with output information"""
|
||||||
|
if not self.current_prompt_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Process outputs to make them more usable
|
||||||
|
processed_outputs = outputs
|
||||||
|
|
||||||
|
# Use the same extractor to update with outputs
|
||||||
|
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
|
||||||
|
if hasattr(extractor, 'update'):
|
||||||
|
extractor.update(
|
||||||
|
node_id,
|
||||||
|
processed_outputs,
|
||||||
|
self.prompt_metadata[self.current_prompt_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the cached metadata for this node
|
||||||
|
self._cache_node_metadata(node_id, class_type)
|
||||||
|
|
||||||
|
def _cache_node_metadata(self, node_id, class_type):
|
||||||
|
"""Cache the metadata for a specific node"""
|
||||||
|
if not self.current_prompt_id or not node_id or not class_type:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create a cache key combining node_id and class_type
|
||||||
|
cache_key = f"{node_id}:{class_type}"
|
||||||
|
|
||||||
|
# Create a shallow copy of the node's metadata
|
||||||
|
node_metadata = {}
|
||||||
|
current_metadata = self.prompt_metadata[self.current_prompt_id]
|
||||||
|
|
||||||
|
for category in self.metadata_categories:
|
||||||
|
if category in current_metadata and node_id in current_metadata[category]:
|
||||||
|
if category not in node_metadata:
|
||||||
|
node_metadata[category] = {}
|
||||||
|
node_metadata[category][node_id] = current_metadata[category][node_id]
|
||||||
|
|
||||||
|
# Save to cache if we have any metadata for this node
|
||||||
|
if any(node_metadata.values()):
|
||||||
|
self.node_cache[cache_key] = node_metadata
|
||||||
163
py/metadata_collector/node_extractors.py
Normal file
163
py/metadata_collector/node_extractors.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
class NodeMetadataExtractor:
|
||||||
|
"""Base class for node-specific metadata extraction"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract(node_id, inputs, outputs, metadata):
|
||||||
|
"""Extract metadata from node inputs/outputs"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update(node_id, outputs, metadata):
|
||||||
|
"""Update metadata with node outputs after execution"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class GenericNodeExtractor(NodeMetadataExtractor):
|
||||||
|
"""Default extractor for nodes without specific handling"""
|
||||||
|
@staticmethod
|
||||||
|
def extract(node_id, inputs, outputs, metadata):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class CheckpointLoaderExtractor(NodeMetadataExtractor):
|
||||||
|
@staticmethod
|
||||||
|
def extract(node_id, inputs, outputs, metadata):
|
||||||
|
if not inputs or "ckpt_name" not in inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
model_name = inputs.get("ckpt_name")
|
||||||
|
if model_name:
|
||||||
|
metadata["models"][node_id] = {
|
||||||
|
"name": model_name,
|
||||||
|
"type": "checkpoint",
|
||||||
|
"node_id": node_id
|
||||||
|
}
|
||||||
|
|
||||||
|
class CLIPTextEncodeExtractor(NodeMetadataExtractor):
|
||||||
|
@staticmethod
|
||||||
|
def extract(node_id, inputs, outputs, metadata):
|
||||||
|
if not inputs or "text" not in inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
text = inputs.get("text", "")
|
||||||
|
metadata["prompts"][node_id] = {
|
||||||
|
"text": text,
|
||||||
|
"node_id": node_id
|
||||||
|
}
|
||||||
|
|
||||||
|
class SamplerExtractor(NodeMetadataExtractor):
|
||||||
|
@staticmethod
|
||||||
|
def extract(node_id, inputs, outputs, metadata):
|
||||||
|
if not inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
sampling_params = {}
|
||||||
|
for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]:
|
||||||
|
if key in inputs:
|
||||||
|
sampling_params[key] = inputs[key]
|
||||||
|
|
||||||
|
metadata["sampling"][node_id] = {
|
||||||
|
"parameters": sampling_params,
|
||||||
|
"node_id": node_id
|
||||||
|
}
|
||||||
|
|
||||||
|
# Extract latent image dimensions if available
|
||||||
|
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||||
|
latent = inputs["latent_image"]
|
||||||
|
if isinstance(latent, dict) and "samples" in latent:
|
||||||
|
# Extract dimensions from latent tensor
|
||||||
|
samples = latent["samples"]
|
||||||
|
if hasattr(samples, "shape") and len(samples.shape) >= 3:
|
||||||
|
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
|
||||||
|
# Multiply by 8 to get actual pixel dimensions
|
||||||
|
height = int(samples.shape[2] * 8)
|
||||||
|
width = int(samples.shape[3] * 8)
|
||||||
|
|
||||||
|
if "size" not in metadata:
|
||||||
|
metadata["size"] = {}
|
||||||
|
|
||||||
|
metadata["size"][node_id] = {
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"node_id": node_id
|
||||||
|
}
|
||||||
|
|
||||||
|
class LoraLoaderExtractor(NodeMetadataExtractor):
|
||||||
|
@staticmethod
|
||||||
|
def extract(node_id, inputs, outputs, metadata):
|
||||||
|
if not inputs or "lora_name" not in inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
lora_name = inputs.get("lora_name")
|
||||||
|
strength_model = inputs.get("strength_model", 1.0)
|
||||||
|
strength_clip = inputs.get("strength_clip", 1.0)
|
||||||
|
|
||||||
|
metadata["loras"][node_id] = {
|
||||||
|
"name": lora_name,
|
||||||
|
"strength_model": strength_model,
|
||||||
|
"strength_clip": strength_clip,
|
||||||
|
"node_id": node_id
|
||||||
|
}
|
||||||
|
|
||||||
|
class ImageSizeExtractor(NodeMetadataExtractor):
|
||||||
|
@staticmethod
|
||||||
|
def extract(node_id, inputs, outputs, metadata):
|
||||||
|
if not inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
width = inputs.get("width", 512)
|
||||||
|
height = inputs.get("height", 512)
|
||||||
|
|
||||||
|
if "size" not in metadata:
|
||||||
|
metadata["size"] = {}
|
||||||
|
|
||||||
|
metadata["size"][node_id] = {
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"node_id": node_id
|
||||||
|
}
|
||||||
|
|
||||||
|
class LoraLoaderManagerExtractor(NodeMetadataExtractor):
|
||||||
|
@staticmethod
|
||||||
|
def extract(node_id, inputs, outputs, metadata):
|
||||||
|
if not inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle LoraManager nodes which might store loras differently
|
||||||
|
if "loras" in inputs:
|
||||||
|
loras = inputs.get("loras", [])
|
||||||
|
if isinstance(loras, list):
|
||||||
|
active_loras = []
|
||||||
|
# Filter for active loras (may be a list of dicts with 'active' flag)
|
||||||
|
for lora in loras:
|
||||||
|
if isinstance(lora, dict) and lora.get("active", True) and not lora.get("_isDummy", False):
|
||||||
|
active_loras.append({
|
||||||
|
"name": lora.get("name", ""),
|
||||||
|
"strength": lora.get("strength", 1.0)
|
||||||
|
})
|
||||||
|
|
||||||
|
if active_loras:
|
||||||
|
metadata["loras"][node_id] = {
|
||||||
|
"lora_list": active_loras,
|
||||||
|
"node_id": node_id
|
||||||
|
}
|
||||||
|
|
||||||
|
# If there's a direct text field with lora definitions
|
||||||
|
if "text" in inputs:
|
||||||
|
text = inputs.get("text", "")
|
||||||
|
if text and "<lora:" in text:
|
||||||
|
metadata["loras"][node_id] = {
|
||||||
|
"raw_text": text,
|
||||||
|
"node_id": node_id
|
||||||
|
}
|
||||||
|
|
||||||
|
# Registry of node-specific extractors
|
||||||
|
NODE_EXTRACTORS = {
|
||||||
|
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
|
||||||
|
"CLIPTextEncode": CLIPTextEncodeExtractor,
|
||||||
|
"KSampler": SamplerExtractor,
|
||||||
|
"LoraLoader": LoraLoaderExtractor,
|
||||||
|
"EmptyLatentImage": ImageSizeExtractor,
|
||||||
|
"Lora Loader (LoraManager)": LoraLoaderManagerExtractor,
|
||||||
|
"SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced
|
||||||
|
"UNETLoader": CheckpointLoaderExtractor, # Add UNETLoader
|
||||||
|
# Add other nodes as needed
|
||||||
|
}
|
||||||
35
py/nodes/debug_metadata.py
Normal file
35
py/nodes/debug_metadata.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import logging
|
||||||
|
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class DebugMetadata:
|
||||||
|
NAME = "Debug Metadata (LoraManager)"
|
||||||
|
CATEGORY = "Lora Manager/utils"
|
||||||
|
DESCRIPTION = "Debug node to verify metadata_processor functionality"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"images": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
RETURN_NAMES = ("metadata_json",)
|
||||||
|
FUNCTION = "process_metadata"
|
||||||
|
|
||||||
|
def process_metadata(self, images):
|
||||||
|
try:
|
||||||
|
# Get the current execution context's metadata
|
||||||
|
from ..metadata_collector import get_metadata
|
||||||
|
metadata = get_metadata()
|
||||||
|
|
||||||
|
# Use the MetadataProcessor to convert it to JSON string
|
||||||
|
metadata_json = MetadataProcessor.to_json(metadata)
|
||||||
|
|
||||||
|
return (metadata_json,)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing metadata: {e}")
|
||||||
|
return ("{}",) # Return empty JSON object in case of error
|
||||||
Reference in New Issue
Block a user