Files
ComfyUI-Lora-Manager/tests/metadata_collector/test_tracer.py

173 lines
6.2 KiB
Python

import pytest
from types import SimpleNamespace
from py.metadata_collector.metadata_processor import MetadataProcessor
from py.metadata_collector.constants import MODELS, SAMPLING, IS_SAMPLER
class TestMetadataTracer:
    """Tests for MetadataProcessor graph tracing, driven by hand-built prompt graphs."""

    @pytest.fixture
    def mock_workflow_metadata(self):
        """
        Build mock collected metadata around a multi-hop workflow graph.

        Model chain wired into the graph:
            Sampler(246) -> Guider(241) -> LoraLoader(264) -> CheckpointLoader(238)

        A second checkpoint loader (ID 999) is registered but no node links
        to it, so tests can tell the *connected* checkpoint apart from merely
        *some* checkpoint.
        """
        # --- Graph nodes (the original prompt) ---
        # Node IDs are strings, matching typical ComfyUI behavior in metadata.
        sampler_node = {
            "class_type": "SamplerCustomAdvanced",
            "inputs": {
                "guider": ["241", 0],
                "noise": ["255", 0],
                "sampler": ["247", 0],
                "sigmas": ["248", 0],
                "latent_image": ["153", 0],
            },
        }
        guider_node = {
            "class_type": "CFGGuider",
            "inputs": {
                "model": ["264", 0],
                "positive": ["239", 0],
                "negative": ["240", 0],
            },
        }
        lora_node = {
            "class_type": "LoraLoader",  # simplified class name
            "inputs": {
                "model": ["238", 0],
                "lora_name": "some_style_lora.safetensors",
            },
        }
        connected_checkpoint_node = {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {"ckpt_name": "Correct_Model.safetensors"},
        }
        # Decoy: same class as the real checkpoint, referenced by no other node.
        decoy_checkpoint_node = {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {"ckpt_name": "Decoy_Model.safetensors"},
        }
        # VAE decode sitting downstream of the sampler.
        decode_node = {
            "class_type": "VAEDecode",
            "inputs": {"samples": ["246", 0]},
        }

        graph = {
            "246": sampler_node,
            "241": guider_node,
            "264": lora_node,
            "238": connected_checkpoint_node,
            "999": decoy_checkpoint_node,
            "154": decode_node,
        }

        # --- Collected execution data ---
        model_registry = {
            "238": {"type": "checkpoint", "name": "Correct_Model.safetensors"},
            "999": {"type": "checkpoint", "name": "Decoy_Model.safetensors"},
        }
        sampling_registry = {
            "246": {
                IS_SAMPLER: True,
                "parameters": {
                    "sampler_name": "euler",
                    "scheduler": "normal",
                },
            },
        }

        return {
            "current_prompt": SimpleNamespace(original_prompt=graph),
            # The decoy (999) is listed last in the execution order.
            "execution_order": ["238", "264", "241", "246", "154", "999"],
            MODELS: model_registry,
            SAMPLING: sampling_registry,
            "images": {"first_decode": {"node_id": "154"}},
        }

    def test_find_primary_sampler_identifies_correct_node(self, mock_workflow_metadata):
        """find_primary_sampler should return the sampler feeding the given decode node."""
        found = MetadataProcessor.find_primary_sampler(
            mock_workflow_metadata, downstream_id="154"
        )
        node_id, info = found

        assert node_id == "246"
        assert info is not None
        assert info["parameters"]["sampler_name"] == "euler"

    def test_trace_model_path_follows_topology(self, mock_workflow_metadata):
        """trace_model_path should walk Sampler -> Guider -> Lora -> Checkpoint."""
        current_prompt = mock_workflow_metadata["current_prompt"]

        # Tracing starts at the sampler (246) and is expected to end at the
        # ID of the checkpoint that is actually wired in (238).
        traced_id = MetadataProcessor.trace_model_path(
            mock_workflow_metadata, current_prompt, "246"
        )

        assert traced_id == "238"

    def test_find_primary_checkpoint_prioritizes_connected_model(self, mock_workflow_metadata):
        """find_primary_checkpoint should return the NAME of the graph-connected checkpoint."""
        checkpoint_name = MetadataProcessor.find_primary_checkpoint(
            mock_workflow_metadata, downstream_id="154"
        )

        assert checkpoint_name == "Correct_Model.safetensors"
        assert checkpoint_name != "Decoy_Model.safetensors"

    def test_trace_model_path_simple_direct_connection(self):
        """A sampler linked straight to a checkpoint should resolve to that checkpoint."""
        sampler = {
            "class_type": "KSampler",
            "inputs": {"model": ["101", 0]},
        }
        checkpoint = {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {},
        }
        current_prompt = SimpleNamespace(
            original_prompt={"100": sampler, "101": checkpoint}
        )
        collected = {
            "current_prompt": current_prompt,
            MODELS: {
                "101": {"type": "checkpoint", "name": "Simple_Model.safetensors"},
            },
        }

        traced_id = MetadataProcessor.trace_model_path(
            collected, collected["current_prompt"], "100"
        )

        assert traced_id == "101"

    def test_trace_stops_at_max_depth(self):
        """A cyclic graph with no registered checkpoint must yield None instead of hanging."""
        # Two nodes whose "model" inputs point at each other: 1 -> 2 -> 1.
        looped_graph = {
            "1": {"inputs": {"model": ["2", 0]}},
            "2": {"inputs": {"model": ["1", 0]}},
        }
        collected = {
            "current_prompt": SimpleNamespace(original_prompt=looped_graph),
            MODELS: {},  # nothing registered, so there is no checkpoint to land on
        }

        traced_id = MetadataProcessor.trace_model_path(
            collected, collected["current_prompt"], "1"
        )

        assert traced_id is None