From c8a179488aa2f8723bb7bbd040ee2c8475aa4724 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 18 Dec 2025 22:30:41 +0800 Subject: [PATCH] feat(metadata): enhance primary sampler detection and workflow tracing - Add support for `basic_pipe` nodes in metadata processor to handle pipeline nodes like FromBasicPipe - Optimize `find_primary_checkpoint` by accepting optional `primary_sampler_id` to avoid redundant calculations - Update `get_workflow_trace` to pass known primary sampler ID for improved efficiency --- py/metadata_collector/metadata_processor.py | 14 ++- tests/metadata_collector/test_pipe_tracer.py | 98 ++++++++++++++++++++ 2 files changed, 108 insertions(+), 4 deletions(-) create mode 100644 tests/metadata_collector/test_pipe_tracer.py diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index c74cd23a..9dd85542 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -239,6 +239,9 @@ class MetadataProcessor: if next_input_name not in inputs: if "model" in inputs: next_input_name = "model" + elif "basic_pipe" in inputs: + # Handle pipe nodes like FromBasicPipe by following the pipeline + next_input_name = "basic_pipe" else: # Dead end - no model input to follow return None @@ -255,20 +258,22 @@ class MetadataProcessor: return None @staticmethod - def find_primary_checkpoint(metadata, downstream_id=None): + def find_primary_checkpoint(metadata, downstream_id=None, primary_sampler_id=None): """ Find the primary checkpoint model in the workflow Parameters: - metadata: The workflow metadata - downstream_id: Optional ID of a downstream node to help identify the specific primary sampler + - primary_sampler_id: Optional ID of the primary sampler if already known """ if not metadata.get(MODELS): return None # Method 1: Topology-based tracing (More accurate for complex workflows) - # First, find the primary sampler - primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id) + # First, find the primary sampler if not provided + if not primary_sampler_id: + primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id) if primary_sampler_id: prompt = metadata.get("current_prompt") @@ -383,7 +388,8 @@ class MetadataProcessor: primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id) # Directly get checkpoint from metadata instead of tracing - checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id) + # Pass primary_sampler_id to avoid redundant calculation + checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id, primary_sampler_id) if checkpoint: params["checkpoint"] = checkpoint diff --git a/tests/metadata_collector/test_pipe_tracer.py b/tests/metadata_collector/test_pipe_tracer.py new file mode 100644 index 00000000..ddad5b98 --- /dev/null +++ b/tests/metadata_collector/test_pipe_tracer.py @@ -0,0 +1,98 @@ + +import pytest +from types import SimpleNamespace +from py.metadata_collector.metadata_processor import MetadataProcessor +from py.metadata_collector.constants import MODELS, SAMPLING, IS_SAMPLER + +class TestPipeTracer: + + @pytest.fixture + def pipe_workflow_metadata(self): + """ + Creates a mock metadata structure matching the one provided in refs/tmp. + Structure: + Load Checkpoint(28) -> Lora Loader(52) -> ToBasicPipe(69) -> FromBasicPipe(71) -> KSampler(32) + """ + + original_prompt = { + '28': { + 'inputs': {'ckpt_name': 'Illustrious\\bananaSplitzXL_vee5PointOh.safetensors'}, + 'class_type': 'CheckpointLoaderSimple' + }, + '52': { + 'inputs': { + 'model': ['28', 0], + 'clip': ['28', 1] + }, + 'class_type': 'Lora Loader (LoraManager)' + }, + '69': { + 'inputs': { + 'model': ['52', 0], + 'clip': ['52', 1], + 'vae': ['28', 2], + 'positive': ['75', 0], + 'negative': ['30', 0] + }, + 'class_type': 'ToBasicPipe' + }, + '71': { + 'inputs': {'basic_pipe': ['69', 0]}, + 'class_type': 'FromBasicPipe' + }, + '32': { + 'inputs': { + 'seed': 131755205602911, + 'steps': 5, + 'cfg': 8.0, + 'sampler_name': 'euler_ancestral', + 'scheduler': 'karras', + 'denoise': 1.0, + 'model': ['71', 0], + 'positive': ['71', 3], + 'negative': ['71', 4], + 'latent_image': ['76', 0] + }, + 'class_type': 'KSampler' + }, + '75': {'inputs': {'text': 'positive', 'clip': ['52', 1]}, 'class_type': 'CLIPTextEncode'}, + '30': {'inputs': {'text': 'negative', 'clip': ['52', 1]}, 'class_type': 'CLIPTextEncode'}, + '76': {'inputs': {'width': 832, 'height': 1216, 'batch_size': 1}, 'class_type': 'EmptyLatentImage'} + } + + metadata = { + "current_prompt": SimpleNamespace(original_prompt=original_prompt), + MODELS: { + "28": { + "type": "checkpoint", + "name": "bananaSplitzXL_vee5PointOh.safetensors" + } + }, + SAMPLING: { + "32": { + IS_SAMPLER: True, + "parameters": { + "sampler_name": "euler_ancestral", + "scheduler": "karras" + } + } + } + } + + return metadata + + def test_trace_model_path_through_pipe(self, pipe_workflow_metadata): + """Verify trace_model_path can follow: KSampler -> FromBasicPipe -> ToBasicPipe -> Lora -> Checkpoint.""" + prompt = pipe_workflow_metadata["current_prompt"] + + # Start trace from KSampler (32) + ckpt_id = MetadataProcessor.trace_model_path(pipe_workflow_metadata, prompt, "32") + + assert ckpt_id == "28" + + def test_find_primary_checkpoint_with_pipe(self, pipe_workflow_metadata): + """Verify find_primary_checkpoint returns the correct name even with pipe nodes.""" + # Providing sampler_id to test the optimization as well + name = MetadataProcessor.find_primary_checkpoint(pipe_workflow_metadata, primary_sampler_id="32") + + assert name == "bananaSplitzXL_vee5PointOh.safetensors"