From d766dac3416d0c7de759628772ce0510135fc012 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 17 Jul 2025 14:44:30 +0800 Subject: [PATCH] feat: Enhance metadata collection by adding support for async execution hooks and improving error handling. See #291 #298 --- py/metadata_collector/metadata_hook.py | 269 +++++++++++++++++-------- 1 file changed, 181 insertions(+), 88 deletions(-) diff --git a/py/metadata_collector/metadata_hook.py b/py/metadata_collector/metadata_hook.py index 2b6a7b6d..f0beda7c 100644 --- a/py/metadata_collector/metadata_hook.py +++ b/py/metadata_collector/metadata_hook.py @@ -26,98 +26,191 @@ class MetadataHook: print("Could not locate ComfyUI execution module, metadata collection disabled") return - # Store the original _map_node_over_list function - original_map_node_over_list = execution._map_node_over_list + # Detect whether we're using the new async version of ComfyUI + is_async = False + map_node_func_name = '_map_node_over_list' - # Define the wrapped _map_node_over_list function - def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): - # Only collect metadata when calling the main function of nodes - if func == obj.FUNCTION and hasattr(obj, '__class__'): - try: - # Get the current prompt_id from the registry - registry = MetadataRegistry() - prompt_id = registry.current_prompt_id - - if prompt_id is not None: - # Get node class type - class_type = obj.__class__.__name__ - - # Unique ID might be available through the obj if it has a unique_id field - node_id = getattr(obj, 'unique_id', None) - if node_id is None and pre_execute_cb: - # Try to extract node_id through reflection on GraphBuilder.set_default_prefix - frame = inspect.currentframe() - while frame: - if 'unique_id' in frame.f_locals: - node_id = frame.f_locals['unique_id'] - break - frame = frame.f_back - - # Record inputs before execution - if node_id is not None: - registry.record_node_execution(node_id, class_type, input_data_all, None) - except Exception as e: - print(f"Error collecting metadata (pre-execution): {str(e)}") - - # Execute the original function - results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb) - - # After execution, collect outputs for relevant nodes - if func == obj.FUNCTION and hasattr(obj, '__class__'): - try: - # Get the current prompt_id from the registry - registry = MetadataRegistry() - prompt_id = registry.current_prompt_id - - if prompt_id is not None: - # Get node class type - class_type = obj.__class__.__name__ - - # Unique ID might be available through the obj if it has a unique_id field - node_id = getattr(obj, 'unique_id', None) - if node_id is None and pre_execute_cb: - # Try to extract node_id through reflection - frame = inspect.currentframe() - while frame: - if 'unique_id' in frame.f_locals: - node_id = frame.f_locals['unique_id'] - break - frame = frame.f_back - - # Record outputs after execution - if node_id is not None: - registry.update_node_execution(node_id, class_type, results) - except Exception as e: - print(f"Error collecting metadata (post-execution): {str(e)}") - - return results - - # Also hook the execute function to track the current prompt_id - original_execute = execution.execute + if hasattr(execution, '_async_map_node_over_list'): + is_async = inspect.iscoroutinefunction(execution._async_map_node_over_list) + map_node_func_name = '_async_map_node_over_list' + elif hasattr(execution, '_map_node_over_list'): + is_async = inspect.iscoroutinefunction(execution._map_node_over_list) - def execute_with_prompt_tracking(*args, **kwargs): - if len(args) >= 7: # Check if we have enough arguments - server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7] - registry = MetadataRegistry() - - # Start collection if this is a new prompt - if not registry.current_prompt_id or registry.current_prompt_id != prompt_id: - registry.start_collection(prompt_id) - - # Store the dynprompt reference for node lookups - if hasattr(prompt, 'original_prompt'): - registry.set_current_prompt(prompt) - - # Execute the original function - return original_execute(*args, **kwargs) - - # Replace the functions - execution._map_node_over_list = map_node_over_list_with_metadata - execution.execute = execute_with_prompt_tracking - # Make map_node_over_list public to avoid it being hidden by hooks - execution.map_node_over_list = original_map_node_over_list + if is_async: + print("Detected async ComfyUI execution, installing async metadata hooks") + MetadataHook._install_async_hooks(execution, map_node_func_name) + else: + print("Detected sync ComfyUI execution, installing sync metadata hooks") + MetadataHook._install_sync_hooks(execution) print("Metadata collection hooks installed for runtime values") except Exception as e: print(f"Error installing metadata hooks: {str(e)}") + + @staticmethod + def _install_sync_hooks(execution): + """Install hooks for synchronous execution model""" + # Store the original _map_node_over_list function + original_map_node_over_list = execution._map_node_over_list + + # Define the wrapped _map_node_over_list function + def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): + # Only collect metadata when calling the main function of nodes + if func == obj.FUNCTION and hasattr(obj, '__class__'): + try: + # Get the current prompt_id from the registry + registry = MetadataRegistry() + prompt_id = registry.current_prompt_id + + if prompt_id is not None: + # Get node class type + class_type = obj.__class__.__name__ + + # Unique ID might be available through the obj if it has a unique_id field + node_id = getattr(obj, 'unique_id', None) + if node_id is None and pre_execute_cb: + # Try to extract node_id through reflection on GraphBuilder.set_default_prefix + frame = inspect.currentframe() + while frame: + if 'unique_id' in frame.f_locals: + node_id = frame.f_locals['unique_id'] + break + frame = frame.f_back + + # Record inputs before execution + if node_id is not None: + registry.record_node_execution(node_id, class_type, input_data_all, None) + except Exception as e: + print(f"Error collecting metadata (pre-execution): {str(e)}") + + # Execute the original function + results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb) + + # After execution, collect outputs for relevant nodes + if func == obj.FUNCTION and hasattr(obj, '__class__'): + try: + # Get the current prompt_id from the registry + registry = MetadataRegistry() + prompt_id = registry.current_prompt_id + + if prompt_id is not None: + # Get node class type + class_type = obj.__class__.__name__ + + # Unique ID might be available through the obj if it has a unique_id field + node_id = getattr(obj, 'unique_id', None) + if node_id is None and pre_execute_cb: + # Try to extract node_id through reflection + frame = inspect.currentframe() + while frame: + if 'unique_id' in frame.f_locals: + node_id = frame.f_locals['unique_id'] + break + frame = frame.f_back + + # Record outputs after execution + if node_id is not None: + registry.update_node_execution(node_id, class_type, results) + except Exception as e: + print(f"Error collecting metadata (post-execution): {str(e)}") + + return results + + # Also hook the execute function to track the current prompt_id + original_execute = execution.execute + + def execute_with_prompt_tracking(*args, **kwargs): + if len(args) >= 7: # Check if we have enough arguments + server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7] + registry = MetadataRegistry() + + # Start collection if this is a new prompt + if not registry.current_prompt_id or registry.current_prompt_id != prompt_id: + registry.start_collection(prompt_id) + + # Store the dynprompt reference for node lookups + if hasattr(prompt, 'original_prompt'): + registry.set_current_prompt(prompt) + + # Execute the original function + return original_execute(*args, **kwargs) + + # Replace the functions + execution._map_node_over_list = map_node_over_list_with_metadata + execution.execute = execute_with_prompt_tracking + + @staticmethod + def _install_async_hooks(execution, map_node_func_name='_async_map_node_over_list'): + """Install hooks for asynchronous execution model""" + # Store the original _async_map_node_over_list function + original_map_node_over_list = getattr(execution, map_node_func_name) + + # Define the wrapped async function - NOTE: Updated signature with prompt_id and unique_id! + async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): + # Only collect metadata when calling the main function of nodes + if func == obj.FUNCTION and hasattr(obj, '__class__'): + try: + # Get the current prompt_id from the registry + registry = MetadataRegistry() + # We now have prompt_id directly from the function parameters + + if prompt_id is not None: + # Get node class type + class_type = obj.__class__.__name__ + + # Use the passed unique_id parameter instead of trying to extract it + node_id = unique_id + + # Record inputs before execution + if node_id is not None: + registry.record_node_execution(node_id, class_type, input_data_all, None) + except Exception as e: + print(f"Error collecting metadata (pre-execution): {str(e)}") + + # Execute the original async function with ALL parameters in the correct order + results = await original_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb) + + # After execution, collect outputs for relevant nodes + if func == obj.FUNCTION and hasattr(obj, '__class__'): + try: + # Get the current prompt_id from the registry + registry = MetadataRegistry() + + if prompt_id is not None: + # Get node class type + class_type = obj.__class__.__name__ + + # Use the passed unique_id parameter + node_id = unique_id + + # Record outputs after execution + if node_id is not None: + registry.update_node_execution(node_id, class_type, results) + except Exception as e: + print(f"Error collecting metadata (post-execution): {str(e)}") + + return results + + # Also hook the execute function to track the current prompt_id + original_execute = execution.execute + + async def async_execute_with_prompt_tracking(*args, **kwargs): + if len(args) >= 7: # Check if we have enough arguments + server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7] + registry = MetadataRegistry() + + # Start collection if this is a new prompt + if not registry.current_prompt_id or registry.current_prompt_id != prompt_id: + registry.start_collection(prompt_id) + + # Store the dynprompt reference for node lookups + if hasattr(prompt, 'original_prompt'): + registry.set_current_prompt(prompt) + + # Execute the original function + return await original_execute(*args, **kwargs) + + # Replace the functions with async versions + setattr(execution, map_node_func_name, async_map_node_over_list_with_metadata) + execution.execute = async_execute_with_prompt_tracking