From a07720a3bf1eb1be41ff10e3b85a1eb3428093dd Mon Sep 17 00:00:00 2001 From: Will Miao Date: Wed, 17 Dec 2025 12:52:52 +0800 Subject: [PATCH] feat: Add model path tracing to accurately identify the primary checkpoint in workflows and include new tests. --- py/metadata_collector/metadata_processor.py | 82 +++++++++- tests/metadata_collector/test_tracer.py | 172 ++++++++++++++++++++ 2 files changed, 249 insertions(+), 5 deletions(-) create mode 100644 tests/metadata_collector/test_tracer.py diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index 587bcf12..c74cd23a 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -202,12 +202,84 @@ class MetadataProcessor: return last_valid_node if not target_class else None @staticmethod - def find_primary_checkpoint(metadata): - """Find the primary checkpoint model in the workflow""" - if not metadata.get(MODELS): + def trace_model_path(metadata, prompt, start_node_id): + """ + Trace the model connection path upstream to find the checkpoint + """ + if not prompt or not prompt.original_prompt: return None - # In most workflows, there's only one checkpoint, so we can just take the first one + current_node_id = start_node_id + depth = 0 + max_depth = 50 + + while depth < max_depth: + # Check if current node is a registered checkpoint in our metadata + # This handles cached nodes correctly because metadata contains info for all nodes in the graph + if current_node_id in metadata.get(MODELS, {}): + if metadata[MODELS][current_node_id].get("type") == "checkpoint": + return current_node_id + + if current_node_id not in prompt.original_prompt: + return None + + node = prompt.original_prompt[current_node_id] + inputs = node.get("inputs", {}) + class_type = node.get("class_type", "") + + # Determine which input to follow next + next_input_name = "model" + + # Special handling for initial node + if depth == 0: + if class_type == "SamplerCustomAdvanced": + next_input_name = "guider" + + # If the specific input doesn't exist, try generic 'model' + if next_input_name not in inputs: + if "model" in inputs: + next_input_name = "model" + else: + # Dead end - no model input to follow + return None + + # Get connected node + input_val = inputs[next_input_name] + if isinstance(input_val, list) and len(input_val) > 0: + current_node_id = input_val[0] + else: + return None + + depth += 1 + + return None + + @staticmethod + def find_primary_checkpoint(metadata, downstream_id=None): + """ + Find the primary checkpoint model in the workflow + + Parameters: + - metadata: The workflow metadata + - downstream_id: Optional ID of a downstream node to help identify the specific primary sampler + """ + if not metadata.get(MODELS): + return None + + # Method 1: Topology-based tracing (More accurate for complex workflows) + # First, find the primary sampler + primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id) + + if primary_sampler_id: + prompt = metadata.get("current_prompt") + if prompt: + # Trace back from the sampler to find the checkpoint + checkpoint_id = MetadataProcessor.trace_model_path(metadata, prompt, primary_sampler_id) + if checkpoint_id and checkpoint_id in metadata.get(MODELS, {}): + return metadata[MODELS][checkpoint_id].get("name") + + # Method 2: Fallback to the first available checkpoint (Original behavior) + # In most simple workflows, there's only one checkpoint, so we can just take the first one for node_id, model_info in metadata.get(MODELS, {}).items(): if model_info.get("type") == "checkpoint": return model_info.get("name") @@ -311,7 +383,7 @@ class MetadataProcessor: primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id) # Directly get checkpoint from metadata instead of tracing - checkpoint = MetadataProcessor.find_primary_checkpoint(metadata) + checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id) if checkpoint: params["checkpoint"] = checkpoint diff --git a/tests/metadata_collector/test_tracer.py b/tests/metadata_collector/test_tracer.py new file mode 100644 index 00000000..5fca8c5c --- /dev/null +++ b/tests/metadata_collector/test_tracer.py @@ -0,0 +1,172 @@ + +import pytest +from types import SimpleNamespace +from py.metadata_collector.metadata_processor import MetadataProcessor +from py.metadata_collector.constants import MODELS, SAMPLING, IS_SAMPLER + +class TestMetadataTracer: + + @pytest.fixture + def mock_workflow_metadata(self): + """ + Creates a mock metadata structure with a complex workflow graph. + Structure: + Sampler(246) -> Guider(241) -> LoraLoader(264) -> CheckpointLoader(238) + + Also includes a "Decoy" checkpoint (ID 999) that is NOT connected, + to verify we found the *connected* one, not just *any* one. + """ + + # 1. Define the Graph (Original Prompt) + # Using IDs as strings to match typical ComfyUI behavior in metadata + original_prompt = { + "246": { + "class_type": "SamplerCustomAdvanced", + "inputs": { + "guider": ["241", 0], + "noise": ["255", 0], + "sampler": ["247", 0], + "sigmas": ["248", 0], + "latent_image": ["153", 0] + } + }, + "241": { + "class_type": "CFGGuider", + "inputs": { + "model": ["264", 0], + "positive": ["239", 0], + "negative": ["240", 0] + } + }, + "264": { + "class_type": "LoraLoader", # Simplified name + "inputs": { + "model": ["238", 0], + "lora_name": "some_style_lora.safetensors" + } + }, + "238": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "Correct_Model.safetensors" + } + }, + + # unconnected / decoy nodes + "999": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "Decoy_Model.safetensors" + } + }, + "154": { # Downstream VAE Decode + "class_type": "VAEDecode", + "inputs": { + "samples": ["246", 0] + } + } + } + + # 2. Define the Metadata (Collected execution data) + metadata = { + "current_prompt": SimpleNamespace(original_prompt=original_prompt), + "execution_order": ["238", "264", "241", "246", "154", "999"], # 999 execs last or separately + + # Models Registry + MODELS: { + "238": { + "type": "checkpoint", + "name": "Correct_Model.safetensors" + }, + "999": { + "type": "checkpoint", + "name": "Decoy_Model.safetensors" + } + }, + + # Sampling Registry + SAMPLING: { + "246": { + IS_SAMPLER: True, + "parameters": { + "sampler_name": "euler", + "scheduler": "normal" + } + } + }, + "images": { + "first_decode": { + "node_id": "154" + } + } + } + + return metadata + + def test_find_primary_sampler_identifies_correct_node(self, mock_workflow_metadata): + """Verify find_primary_sampler correctly identifies the sampler connected to the downstream decode.""" + sampler_id, sampler_info = MetadataProcessor.find_primary_sampler(mock_workflow_metadata, downstream_id="154") + + assert sampler_id == "246" + assert sampler_info is not None + assert sampler_info["parameters"]["sampler_name"] == "euler" + + def test_trace_model_path_follows_topology(self, mock_workflow_metadata): + """Verify trace_model_path follows: Sampler -> Guider -> Lora -> Checkpoint.""" + prompt = mock_workflow_metadata["current_prompt"] + + # Start trace from Sampler (246) + # Should find Checkpoint (238) + ckpt_id = MetadataProcessor.trace_model_path(mock_workflow_metadata, prompt, "246") + + assert ckpt_id == "238" # Should be the ID of the connected checkpoint + + def test_find_primary_checkpoint_prioritizes_connected_model(self, mock_workflow_metadata): + """Verify find_primary_checkpoint returns the NAME of the topologically connected checkpoint, honoring the graph.""" + name = MetadataProcessor.find_primary_checkpoint(mock_workflow_metadata, downstream_id="154") + + assert name == "Correct_Model.safetensors" + assert name != "Decoy_Model.safetensors" + + def test_trace_model_path_simple_direct_connection(self): + """Verify it works for a simple Sampler -> Checkpoint connection.""" + original_prompt = { + "100": { # Sampler + "class_type": "KSampler", + "inputs": { + "model": ["101", 0] + } + }, + "101": { # Checkpoint + "class_type": "CheckpointLoaderSimple", + "inputs": {} + } + } + + metadata = { + "current_prompt": SimpleNamespace(original_prompt=original_prompt), + MODELS: { + "101": {"type": "checkpoint", "name": "Simple_Model.safetensors"} + } + } + + ckpt_id = MetadataProcessor.trace_model_path(metadata, metadata["current_prompt"], "100") + assert ckpt_id == "101" + + def test_trace_stops_at_max_depth(self): + """Verify logic halts if graph is infinitely cyclic or too deep.""" + # Create a cycle: Node 1 -> Node 2 -> Node 1 + original_prompt = { + "1": {"inputs": {"model": ["2", 0]}}, + "2": {"inputs": {"model": ["1", 0]}} + } + + metadata = { + "current_prompt": SimpleNamespace(original_prompt=original_prompt), + MODELS: {} # No checkpoints registered + } + + # Should return None, not hang forever + ckpt_id = MetadataProcessor.trace_model_path(metadata, metadata["current_prompt"], "1") + assert ckpt_id is None +