feat: enhance async metadata collection by updating function signature and preserving all parameters. Fixes #328 #327

This commit is contained in:
Will Miao
2025-08-01 21:47:52 +08:00
parent 848c1741fe
commit c0eff2bb5e

View File

@@ -146,52 +146,40 @@ class MetadataHook:
# Store the original _async_map_node_over_list function
original_map_node_over_list = getattr(execution, map_node_func_name)
# Define the wrapped async function - NOTE: Updated signature with prompt_id and unique_id!
async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# Wrapped async function, compatible with both stable and nightly
async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, *args, **kwargs):
hidden_inputs = kwargs.get('hidden_inputs', None)
# Only collect metadata when calling the main function of nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
# We now have prompt_id directly from the function parameters
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Use the passed unique_id parameter instead of trying to extract it
node_id = unique_id
# Record inputs before execution
if node_id is not None:
registry.record_node_execution(node_id, class_type, input_data_all, None)
except Exception as e:
print(f"Error collecting metadata (pre-execution): {str(e)}")
# Execute the original async function with ALL parameters in the correct order
results = await original_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
# Call original function with all args/kwargs
results = await original_map_node_over_list(
prompt_id, unique_id, obj, input_data_all, func,
allow_interrupt, execution_block_cb, pre_execute_cb, *args, **kwargs
)
# After execution, collect outputs for relevant nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Use the passed unique_id parameter
node_id = unique_id
# Record outputs after execution
if node_id is not None:
registry.update_node_execution(node_id, class_type, results)
except Exception as e:
print(f"Error collecting metadata (post-execution): {str(e)}")
return results
# Also hook the execute function to track the current prompt_id
original_execute = execution.execute