mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
feat: Add model path tracing to accurately identify the primary checkpoint in workflows and include new tests.
This commit is contained in:
@@ -202,12 +202,84 @@ class MetadataProcessor:
|
|||||||
return last_valid_node if not target_class else None
|
return last_valid_node if not target_class else None
|
||||||
|
|
||||||
@staticmethod
def trace_model_path(metadata, prompt, start_node_id):
    """
    Trace the model connection path upstream to find the checkpoint.

    Starting from ``start_node_id`` (typically the primary sampler), follow
    each node's "model" input link (or "guider" for a SamplerCustomAdvanced
    start node) until reaching a node registered as a checkpoint in
    ``metadata[MODELS]``.

    Parameters:
    - metadata: collected workflow metadata; ``metadata[MODELS]`` maps node
      ids to model-info dicts carrying a "type" key.
    - prompt: object exposing ``original_prompt``, a dict of
      node_id -> {"class_type": ..., "inputs": {...}} describing the graph.
    - start_node_id: node id (string) to begin the upstream trace from.

    Returns:
    - The checkpoint node id, or None when no checkpoint is reachable
      (dead end, node missing from the graph, cycle, or depth limit hit).
    """
    if not prompt or not prompt.original_prompt:
        return None

    current_node_id = start_node_id
    depth = 0
    max_depth = 50  # hard cap kept as a safety net for pathological graphs
    visited = set()  # detect cycles immediately instead of walking 50 hops

    while depth < max_depth:
        # A revisited node means the graph cycles along this path and no
        # checkpoint can be reached — bail out right away.
        if current_node_id in visited:
            return None
        visited.add(current_node_id)

        # Check if current node is a registered checkpoint in our metadata.
        # This handles cached nodes correctly because metadata contains
        # info for all nodes in the graph.
        if current_node_id in metadata.get(MODELS, {}):
            if metadata[MODELS][current_node_id].get("type") == "checkpoint":
                return current_node_id

        if current_node_id not in prompt.original_prompt:
            return None

        node = prompt.original_prompt[current_node_id]
        inputs = node.get("inputs", {})
        class_type = node.get("class_type", "")

        # Determine which input to follow next
        next_input_name = "model"

        # Special handling for the initial node: SamplerCustomAdvanced
        # receives its model indirectly through a guider node.
        if depth == 0:
            if class_type == "SamplerCustomAdvanced":
                next_input_name = "guider"

        # If the specific input doesn't exist, try generic 'model'
        if next_input_name not in inputs:
            if "model" in inputs:
                next_input_name = "model"
            else:
                # Dead end - no model input to follow
                return None

        # Get connected node; graph links are encoded as [node_id, slot]
        input_val = inputs[next_input_name]
        if isinstance(input_val, list) and len(input_val) > 0:
            current_node_id = input_val[0]
        else:
            return None

        depth += 1

    return None
|
||||||
|
@staticmethod
def find_primary_checkpoint(metadata, downstream_id=None):
    """
    Find the primary checkpoint model in the workflow.

    Parameters:
    - metadata: The workflow metadata
    - downstream_id: Optional ID of a downstream node used to pin down the
      specific primary sampler when several candidates exist.

    Returns:
    - The checkpoint's name, or None when no checkpoint is registered.
    """
    if not metadata.get(MODELS):
        return None

    # Preferred strategy: topology-based tracing (more accurate for
    # complex workflows). Locate the primary sampler, then walk upstream
    # from it to the checkpoint node it is actually wired to.
    sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id)
    prompt = metadata.get("current_prompt")

    if sampler_id and prompt:
        # Trace back from the sampler to find the checkpoint
        traced_id = MetadataProcessor.trace_model_path(metadata, prompt, sampler_id)
        if traced_id and traced_id in metadata.get(MODELS, {}):
            return metadata[MODELS][traced_id].get("name")

    # Fallback strategy (original behavior): most simple workflows hold a
    # single checkpoint, so just return the first one registered.
    return next(
        (
            model_info.get("name")
            for model_info in metadata.get(MODELS, {}).values()
            if model_info.get("type") == "checkpoint"
        ),
        None,
    )
@@ -311,7 +383,7 @@ class MetadataProcessor:
|
|||||||
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id)
|
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id)
|
||||||
|
|
||||||
# Directly get checkpoint from metadata instead of tracing
|
# Directly get checkpoint from metadata instead of tracing
|
||||||
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
|
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id)
|
||||||
if checkpoint:
|
if checkpoint:
|
||||||
params["checkpoint"] = checkpoint
|
params["checkpoint"] = checkpoint
|
||||||
|
|
||||||
|
|||||||
172
tests/metadata_collector/test_tracer.py
Normal file
172
tests/metadata_collector/test_tracer.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
|
||||||
|
import pytest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from py.metadata_collector.metadata_processor import MetadataProcessor
|
||||||
|
from py.metadata_collector.constants import MODELS, SAMPLING, IS_SAMPLER
|
||||||
|
|
||||||
|
class TestMetadataTracer:
    """Tests for MetadataProcessor's topology tracing helpers.

    Covers primary-sampler detection, upstream model-path tracing, and the
    connected-checkpoint selection (including a decoy checkpoint that must
    NOT be picked), plus cycle/depth-limit termination.
    """

    @pytest.fixture
    def mock_workflow_metadata(self):
        """
        Creates a mock metadata structure with a complex workflow graph.
        Structure:
        Sampler(246) -> Guider(241) -> LoraLoader(264) -> CheckpointLoader(238)

        Also includes a "Decoy" checkpoint (ID 999) that is NOT connected,
        to verify we found the *connected* one, not just *any* one.
        """

        # 1. Define the Graph (Original Prompt)
        # Using IDs as strings to match typical ComfyUI behavior in metadata
        original_prompt = {
            "246": {
                "class_type": "SamplerCustomAdvanced",
                "inputs": {
                    "guider": ["241", 0],
                    "noise": ["255", 0],
                    "sampler": ["247", 0],
                    "sigmas": ["248", 0],
                    "latent_image": ["153", 0]
                }
            },
            "241": {
                "class_type": "CFGGuider",
                "inputs": {
                    "model": ["264", 0],
                    "positive": ["239", 0],
                    "negative": ["240", 0]
                }
            },
            "264": {
                "class_type": "LoraLoader",  # Simplified name
                "inputs": {
                    "model": ["238", 0],
                    "lora_name": "some_style_lora.safetensors"
                }
            },
            "238": {
                "class_type": "CheckpointLoaderSimple",
                "inputs": {
                    "ckpt_name": "Correct_Model.safetensors"
                }
            },

            # unconnected / decoy nodes
            "999": {
                "class_type": "CheckpointLoaderSimple",
                "inputs": {
                    "ckpt_name": "Decoy_Model.safetensors"
                }
            },
            "154": {  # Downstream VAE Decode
                "class_type": "VAEDecode",
                "inputs": {
                    "samples": ["246", 0]
                }
            }
        }

        # 2. Define the Metadata (Collected execution data)
        metadata = {
            # SimpleNamespace mimics the prompt object's .original_prompt attribute
            "current_prompt": SimpleNamespace(original_prompt=original_prompt),
            "execution_order": ["238", "264", "241", "246", "154", "999"],  # 999 execs last or separately

            # Models Registry
            MODELS: {
                "238": {
                    "type": "checkpoint",
                    "name": "Correct_Model.safetensors"
                },
                "999": {
                    "type": "checkpoint",
                    "name": "Decoy_Model.safetensors"
                }
            },

            # Sampling Registry
            SAMPLING: {
                "246": {
                    IS_SAMPLER: True,
                    "parameters": {
                        "sampler_name": "euler",
                        "scheduler": "normal"
                    }
                }
            },
            "images": {
                "first_decode": {
                    "node_id": "154"
                }
            }
        }

        return metadata

    def test_find_primary_sampler_identifies_correct_node(self, mock_workflow_metadata):
        """Verify find_primary_sampler correctly identifies the sampler connected to the downstream decode."""
        sampler_id, sampler_info = MetadataProcessor.find_primary_sampler(mock_workflow_metadata, downstream_id="154")

        assert sampler_id == "246"
        assert sampler_info is not None
        assert sampler_info["parameters"]["sampler_name"] == "euler"

    def test_trace_model_path_follows_topology(self, mock_workflow_metadata):
        """Verify trace_model_path follows: Sampler -> Guider -> Lora -> Checkpoint."""
        prompt = mock_workflow_metadata["current_prompt"]

        # Start trace from Sampler (246)
        # Should find Checkpoint (238)
        ckpt_id = MetadataProcessor.trace_model_path(mock_workflow_metadata, prompt, "246")

        assert ckpt_id == "238"  # Should be the ID of the connected checkpoint

    def test_find_primary_checkpoint_prioritizes_connected_model(self, mock_workflow_metadata):
        """Verify find_primary_checkpoint returns the NAME of the topologically connected checkpoint, honoring the graph."""
        name = MetadataProcessor.find_primary_checkpoint(mock_workflow_metadata, downstream_id="154")

        assert name == "Correct_Model.safetensors"
        assert name != "Decoy_Model.safetensors"

    def test_trace_model_path_simple_direct_connection(self):
        """Verify it works for a simple Sampler -> Checkpoint connection."""
        original_prompt = {
            "100": {  # Sampler
                "class_type": "KSampler",
                "inputs": {
                    "model": ["101", 0]
                }
            },
            "101": {  # Checkpoint
                "class_type": "CheckpointLoaderSimple",
                "inputs": {}
            }
        }

        metadata = {
            "current_prompt": SimpleNamespace(original_prompt=original_prompt),
            MODELS: {
                "101": {"type": "checkpoint", "name": "Simple_Model.safetensors"}
            }
        }

        ckpt_id = MetadataProcessor.trace_model_path(metadata, metadata["current_prompt"], "100")
        assert ckpt_id == "101"

    def test_trace_stops_at_max_depth(self):
        """Verify logic halts if graph is infinitely cyclic or too deep."""
        # Create a cycle: Node 1 -> Node 2 -> Node 1
        original_prompt = {
            "1": {"inputs": {"model": ["2", 0]}},
            "2": {"inputs": {"model": ["1", 0]}}
        }

        metadata = {
            "current_prompt": SimpleNamespace(original_prompt=original_prompt),
            MODELS: {}  # No checkpoints registered
        }

        # Should return None, not hang forever
        ckpt_id = MetadataProcessor.trace_model_path(metadata, metadata["current_prompt"], "1")
        assert ckpt_id is None
|
||||||
|
|
||||||
Reference in New Issue
Block a user