mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
feat: Enhance metadata collection by adding support for async execution hooks and improving error handling. See #291 #298
This commit is contained in:
@@ -26,98 +26,191 @@ class MetadataHook:
|
|||||||
print("Could not locate ComfyUI execution module, metadata collection disabled")
|
print("Could not locate ComfyUI execution module, metadata collection disabled")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Store the original _map_node_over_list function
|
# Detect whether we're using the new async version of ComfyUI
|
||||||
original_map_node_over_list = execution._map_node_over_list
|
is_async = False
|
||||||
|
map_node_func_name = '_map_node_over_list'
|
||||||
|
|
||||||
# Define the wrapped _map_node_over_list function
|
if hasattr(execution, '_async_map_node_over_list'):
|
||||||
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
is_async = inspect.iscoroutinefunction(execution._async_map_node_over_list)
|
||||||
# Only collect metadata when calling the main function of nodes
|
map_node_func_name = '_async_map_node_over_list'
|
||||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
elif hasattr(execution, '_map_node_over_list'):
|
||||||
try:
|
is_async = inspect.iscoroutinefunction(execution._map_node_over_list)
|
||||||
# Get the current prompt_id from the registry
|
|
||||||
registry = MetadataRegistry()
|
|
||||||
prompt_id = registry.current_prompt_id
|
|
||||||
|
|
||||||
if prompt_id is not None:
|
if is_async:
|
||||||
# Get node class type
|
print("Detected async ComfyUI execution, installing async metadata hooks")
|
||||||
class_type = obj.__class__.__name__
|
MetadataHook._install_async_hooks(execution, map_node_func_name)
|
||||||
|
else:
|
||||||
# Unique ID might be available through the obj if it has a unique_id field
|
print("Detected sync ComfyUI execution, installing sync metadata hooks")
|
||||||
node_id = getattr(obj, 'unique_id', None)
|
MetadataHook._install_sync_hooks(execution)
|
||||||
if node_id is None and pre_execute_cb:
|
|
||||||
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
|
|
||||||
frame = inspect.currentframe()
|
|
||||||
while frame:
|
|
||||||
if 'unique_id' in frame.f_locals:
|
|
||||||
node_id = frame.f_locals['unique_id']
|
|
||||||
break
|
|
||||||
frame = frame.f_back
|
|
||||||
|
|
||||||
# Record inputs before execution
|
|
||||||
if node_id is not None:
|
|
||||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
|
||||||
|
|
||||||
# Execute the original function
|
|
||||||
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
|
||||||
|
|
||||||
# After execution, collect outputs for relevant nodes
|
|
||||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
|
||||||
try:
|
|
||||||
# Get the current prompt_id from the registry
|
|
||||||
registry = MetadataRegistry()
|
|
||||||
prompt_id = registry.current_prompt_id
|
|
||||||
|
|
||||||
if prompt_id is not None:
|
|
||||||
# Get node class type
|
|
||||||
class_type = obj.__class__.__name__
|
|
||||||
|
|
||||||
# Unique ID might be available through the obj if it has a unique_id field
|
|
||||||
node_id = getattr(obj, 'unique_id', None)
|
|
||||||
if node_id is None and pre_execute_cb:
|
|
||||||
# Try to extract node_id through reflection
|
|
||||||
frame = inspect.currentframe()
|
|
||||||
while frame:
|
|
||||||
if 'unique_id' in frame.f_locals:
|
|
||||||
node_id = frame.f_locals['unique_id']
|
|
||||||
break
|
|
||||||
frame = frame.f_back
|
|
||||||
|
|
||||||
# Record outputs after execution
|
|
||||||
if node_id is not None:
|
|
||||||
registry.update_node_execution(node_id, class_type, results)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
# Also hook the execute function to track the current prompt_id
|
|
||||||
original_execute = execution.execute
|
|
||||||
|
|
||||||
def execute_with_prompt_tracking(*args, **kwargs):
|
|
||||||
if len(args) >= 7: # Check if we have enough arguments
|
|
||||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
|
||||||
registry = MetadataRegistry()
|
|
||||||
|
|
||||||
# Start collection if this is a new prompt
|
|
||||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
|
||||||
registry.start_collection(prompt_id)
|
|
||||||
|
|
||||||
# Store the dynprompt reference for node lookups
|
|
||||||
if hasattr(prompt, 'original_prompt'):
|
|
||||||
registry.set_current_prompt(prompt)
|
|
||||||
|
|
||||||
# Execute the original function
|
|
||||||
return original_execute(*args, **kwargs)
|
|
||||||
|
|
||||||
# Replace the functions
|
|
||||||
execution._map_node_over_list = map_node_over_list_with_metadata
|
|
||||||
execution.execute = execute_with_prompt_tracking
|
|
||||||
# Make map_node_over_list public to avoid it being hidden by hooks
|
|
||||||
execution.map_node_over_list = original_map_node_over_list
|
|
||||||
|
|
||||||
print("Metadata collection hooks installed for runtime values")
|
print("Metadata collection hooks installed for runtime values")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error installing metadata hooks: {str(e)}")
|
print(f"Error installing metadata hooks: {str(e)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _install_sync_hooks(execution):
|
||||||
|
"""Install hooks for synchronous execution model"""
|
||||||
|
# Store the original _map_node_over_list function
|
||||||
|
original_map_node_over_list = execution._map_node_over_list
|
||||||
|
|
||||||
|
# Define the wrapped _map_node_over_list function
|
||||||
|
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||||
|
# Only collect metadata when calling the main function of nodes
|
||||||
|
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||||
|
try:
|
||||||
|
# Get the current prompt_id from the registry
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
prompt_id = registry.current_prompt_id
|
||||||
|
|
||||||
|
if prompt_id is not None:
|
||||||
|
# Get node class type
|
||||||
|
class_type = obj.__class__.__name__
|
||||||
|
|
||||||
|
# Unique ID might be available through the obj if it has a unique_id field
|
||||||
|
node_id = getattr(obj, 'unique_id', None)
|
||||||
|
if node_id is None and pre_execute_cb:
|
||||||
|
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
|
||||||
|
frame = inspect.currentframe()
|
||||||
|
while frame:
|
||||||
|
if 'unique_id' in frame.f_locals:
|
||||||
|
node_id = frame.f_locals['unique_id']
|
||||||
|
break
|
||||||
|
frame = frame.f_back
|
||||||
|
|
||||||
|
# Record inputs before execution
|
||||||
|
if node_id is not None:
|
||||||
|
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||||
|
|
||||||
|
# Execute the original function
|
||||||
|
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||||
|
|
||||||
|
# After execution, collect outputs for relevant nodes
|
||||||
|
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||||
|
try:
|
||||||
|
# Get the current prompt_id from the registry
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
prompt_id = registry.current_prompt_id
|
||||||
|
|
||||||
|
if prompt_id is not None:
|
||||||
|
# Get node class type
|
||||||
|
class_type = obj.__class__.__name__
|
||||||
|
|
||||||
|
# Unique ID might be available through the obj if it has a unique_id field
|
||||||
|
node_id = getattr(obj, 'unique_id', None)
|
||||||
|
if node_id is None and pre_execute_cb:
|
||||||
|
# Try to extract node_id through reflection
|
||||||
|
frame = inspect.currentframe()
|
||||||
|
while frame:
|
||||||
|
if 'unique_id' in frame.f_locals:
|
||||||
|
node_id = frame.f_locals['unique_id']
|
||||||
|
break
|
||||||
|
frame = frame.f_back
|
||||||
|
|
||||||
|
# Record outputs after execution
|
||||||
|
if node_id is not None:
|
||||||
|
registry.update_node_execution(node_id, class_type, results)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
# Also hook the execute function to track the current prompt_id
|
||||||
|
original_execute = execution.execute
|
||||||
|
|
||||||
|
def execute_with_prompt_tracking(*args, **kwargs):
|
||||||
|
if len(args) >= 7: # Check if we have enough arguments
|
||||||
|
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
|
||||||
|
# Start collection if this is a new prompt
|
||||||
|
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||||
|
registry.start_collection(prompt_id)
|
||||||
|
|
||||||
|
# Store the dynprompt reference for node lookups
|
||||||
|
if hasattr(prompt, 'original_prompt'):
|
||||||
|
registry.set_current_prompt(prompt)
|
||||||
|
|
||||||
|
# Execute the original function
|
||||||
|
return original_execute(*args, **kwargs)
|
||||||
|
|
||||||
|
# Replace the functions
|
||||||
|
execution._map_node_over_list = map_node_over_list_with_metadata
|
||||||
|
execution.execute = execute_with_prompt_tracking
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _install_async_hooks(execution, map_node_func_name='_async_map_node_over_list'):
|
||||||
|
"""Install hooks for asynchronous execution model"""
|
||||||
|
# Store the original _async_map_node_over_list function
|
||||||
|
original_map_node_over_list = getattr(execution, map_node_func_name)
|
||||||
|
|
||||||
|
# Define the wrapped async function - NOTE: Updated signature with prompt_id and unique_id!
|
||||||
|
async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||||
|
# Only collect metadata when calling the main function of nodes
|
||||||
|
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||||
|
try:
|
||||||
|
# Get the current prompt_id from the registry
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
# We now have prompt_id directly from the function parameters
|
||||||
|
|
||||||
|
if prompt_id is not None:
|
||||||
|
# Get node class type
|
||||||
|
class_type = obj.__class__.__name__
|
||||||
|
|
||||||
|
# Use the passed unique_id parameter instead of trying to extract it
|
||||||
|
node_id = unique_id
|
||||||
|
|
||||||
|
# Record inputs before execution
|
||||||
|
if node_id is not None:
|
||||||
|
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||||
|
|
||||||
|
# Execute the original async function with ALL parameters in the correct order
|
||||||
|
results = await original_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||||
|
|
||||||
|
# After execution, collect outputs for relevant nodes
|
||||||
|
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||||
|
try:
|
||||||
|
# Get the current prompt_id from the registry
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
|
||||||
|
if prompt_id is not None:
|
||||||
|
# Get node class type
|
||||||
|
class_type = obj.__class__.__name__
|
||||||
|
|
||||||
|
# Use the passed unique_id parameter
|
||||||
|
node_id = unique_id
|
||||||
|
|
||||||
|
# Record outputs after execution
|
||||||
|
if node_id is not None:
|
||||||
|
registry.update_node_execution(node_id, class_type, results)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
# Also hook the execute function to track the current prompt_id
|
||||||
|
original_execute = execution.execute
|
||||||
|
|
||||||
|
async def async_execute_with_prompt_tracking(*args, **kwargs):
|
||||||
|
if len(args) >= 7: # Check if we have enough arguments
|
||||||
|
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
|
||||||
|
# Start collection if this is a new prompt
|
||||||
|
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||||
|
registry.start_collection(prompt_id)
|
||||||
|
|
||||||
|
# Store the dynprompt reference for node lookups
|
||||||
|
if hasattr(prompt, 'original_prompt'):
|
||||||
|
registry.set_current_prompt(prompt)
|
||||||
|
|
||||||
|
# Execute the original function
|
||||||
|
return await original_execute(*args, **kwargs)
|
||||||
|
|
||||||
|
# Replace the functions with async versions
|
||||||
|
setattr(execution, map_node_func_name, async_map_node_over_list_with_metadata)
|
||||||
|
execution.execute = async_execute_with_prompt_tracking
|
||||||
|
|||||||
Reference in New Issue
Block a user