feat(metadata): enhance primary sampler detection and workflow tracing

- Add support for `basic_pipe` nodes in metadata processor to handle pipeline nodes like FromBasicPipe
- Optimize `find_primary_checkpoint` by accepting optional `primary_sampler_id` to avoid redundant calculations
- Update `get_workflow_trace` to pass known primary sampler ID for improved efficiency
This commit is contained in:
Will Miao
2025-12-18 22:30:41 +08:00
parent ca6bb43406
commit c8a179488a
2 changed files with 108 additions and 4 deletions

View File

@@ -239,6 +239,9 @@ class MetadataProcessor:
if next_input_name not in inputs:
if "model" in inputs:
next_input_name = "model"
elif "basic_pipe" in inputs:
# Handle pipe nodes like FromBasicPipe by following the pipeline
next_input_name = "basic_pipe"
else:
# Dead end - no model input to follow
return None
@@ -255,20 +258,22 @@ class MetadataProcessor:
return None
@staticmethod
def find_primary_checkpoint(metadata, downstream_id=None):
def find_primary_checkpoint(metadata, downstream_id=None, primary_sampler_id=None):
"""
Find the primary checkpoint model in the workflow
Parameters:
- metadata: The workflow metadata
- downstream_id: Optional ID of a downstream node to help identify the specific primary sampler
- primary_sampler_id: Optional ID of the primary sampler if already known
"""
if not metadata.get(MODELS):
return None
# Method 1: Topology-based tracing (More accurate for complex workflows)
# First, find the primary sampler
primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id)
# First, find the primary sampler if not provided
if not primary_sampler_id:
primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id)
if primary_sampler_id:
prompt = metadata.get("current_prompt")
@@ -383,7 +388,8 @@ class MetadataProcessor:
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id)
# Directly get checkpoint from metadata instead of tracing
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id)
# Pass primary_sampler_id to avoid redundant calculation
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id, primary_sampler_id)
if checkpoint:
params["checkpoint"] = checkpoint