mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(metadata): enhance primary sampler detection and workflow tracing
- Add support for `basic_pipe` nodes in metadata processor to handle pipeline nodes like FromBasicPipe - Optimize `find_primary_checkpoint` by accepting optional `primary_sampler_id` to avoid redundant calculations - Update `get_workflow_trace` to pass known primary sampler ID for improved efficiency
This commit is contained in:
@@ -239,6 +239,9 @@ class MetadataProcessor:
|
|||||||
if next_input_name not in inputs:
|
if next_input_name not in inputs:
|
||||||
if "model" in inputs:
|
if "model" in inputs:
|
||||||
next_input_name = "model"
|
next_input_name = "model"
|
||||||
|
elif "basic_pipe" in inputs:
|
||||||
|
# Handle pipe nodes like FromBasicPipe by following the pipeline
|
||||||
|
next_input_name = "basic_pipe"
|
||||||
else:
|
else:
|
||||||
# Dead end - no model input to follow
|
# Dead end - no model input to follow
|
||||||
return None
|
return None
|
||||||
@@ -255,20 +258,22 @@ class MetadataProcessor:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_primary_checkpoint(metadata, downstream_id=None):
|
def find_primary_checkpoint(metadata, downstream_id=None, primary_sampler_id=None):
|
||||||
"""
|
"""
|
||||||
Find the primary checkpoint model in the workflow
|
Find the primary checkpoint model in the workflow
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- metadata: The workflow metadata
|
- metadata: The workflow metadata
|
||||||
- downstream_id: Optional ID of a downstream node to help identify the specific primary sampler
|
- downstream_id: Optional ID of a downstream node to help identify the specific primary sampler
|
||||||
|
- primary_sampler_id: Optional ID of the primary sampler if already known
|
||||||
"""
|
"""
|
||||||
if not metadata.get(MODELS):
|
if not metadata.get(MODELS):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Method 1: Topology-based tracing (More accurate for complex workflows)
|
# Method 1: Topology-based tracing (More accurate for complex workflows)
|
||||||
# First, find the primary sampler
|
# First, find the primary sampler if not provided
|
||||||
primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id)
|
if not primary_sampler_id:
|
||||||
|
primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id)
|
||||||
|
|
||||||
if primary_sampler_id:
|
if primary_sampler_id:
|
||||||
prompt = metadata.get("current_prompt")
|
prompt = metadata.get("current_prompt")
|
||||||
@@ -383,7 +388,8 @@ class MetadataProcessor:
|
|||||||
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id)
|
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id)
|
||||||
|
|
||||||
# Directly get checkpoint from metadata instead of tracing
|
# Directly get checkpoint from metadata instead of tracing
|
||||||
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id)
|
# Pass primary_sampler_id to avoid redundant calculation
|
||||||
|
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id, primary_sampler_id)
|
||||||
if checkpoint:
|
if checkpoint:
|
||||||
params["checkpoint"] = checkpoint
|
params["checkpoint"] = checkpoint
|
||||||
|
|
||||||
|
|||||||
98
tests/metadata_collector/test_pipe_tracer.py
Normal file
98
tests/metadata_collector/test_pipe_tracer.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
|
||||||
|
import pytest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from py.metadata_collector.metadata_processor import MetadataProcessor
|
||||||
|
from py.metadata_collector.constants import MODELS, SAMPLING, IS_SAMPLER
|
||||||
|
|
||||||
|
class TestPipeTracer:
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipe_workflow_metadata(self):
|
||||||
|
"""
|
||||||
|
Creates a mock metadata structure matching the one provided in refs/tmp.
|
||||||
|
Structure:
|
||||||
|
Load Checkpoint(28) -> Lora Loader(52) -> ToBasicPipe(69) -> FromBasicPipe(71) -> KSampler(32)
|
||||||
|
"""
|
||||||
|
|
||||||
|
original_prompt = {
|
||||||
|
'28': {
|
||||||
|
'inputs': {'ckpt_name': 'Illustrious\\bananaSplitzXL_vee5PointOh.safetensors'},
|
||||||
|
'class_type': 'CheckpointLoaderSimple'
|
||||||
|
},
|
||||||
|
'52': {
|
||||||
|
'inputs': {
|
||||||
|
'model': ['28', 0],
|
||||||
|
'clip': ['28', 1]
|
||||||
|
},
|
||||||
|
'class_type': 'Lora Loader (LoraManager)'
|
||||||
|
},
|
||||||
|
'69': {
|
||||||
|
'inputs': {
|
||||||
|
'model': ['52', 0],
|
||||||
|
'clip': ['52', 1],
|
||||||
|
'vae': ['28', 2],
|
||||||
|
'positive': ['75', 0],
|
||||||
|
'negative': ['30', 0]
|
||||||
|
},
|
||||||
|
'class_type': 'ToBasicPipe'
|
||||||
|
},
|
||||||
|
'71': {
|
||||||
|
'inputs': {'basic_pipe': ['69', 0]},
|
||||||
|
'class_type': 'FromBasicPipe'
|
||||||
|
},
|
||||||
|
'32': {
|
||||||
|
'inputs': {
|
||||||
|
'seed': 131755205602911,
|
||||||
|
'steps': 5,
|
||||||
|
'cfg': 8.0,
|
||||||
|
'sampler_name': 'euler_ancestral',
|
||||||
|
'scheduler': 'karras',
|
||||||
|
'denoise': 1.0,
|
||||||
|
'model': ['71', 0],
|
||||||
|
'positive': ['71', 3],
|
||||||
|
'negative': ['71', 4],
|
||||||
|
'latent_image': ['76', 0]
|
||||||
|
},
|
||||||
|
'class_type': 'KSampler'
|
||||||
|
},
|
||||||
|
'75': {'inputs': {'text': 'positive', 'clip': ['52', 1]}, 'class_type': 'CLIPTextEncode'},
|
||||||
|
'30': {'inputs': {'text': 'negative', 'clip': ['52', 1]}, 'class_type': 'CLIPTextEncode'},
|
||||||
|
'76': {'inputs': {'width': 832, 'height': 1216, 'batch_size': 1}, 'class_type': 'EmptyLatentImage'}
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"current_prompt": SimpleNamespace(original_prompt=original_prompt),
|
||||||
|
MODELS: {
|
||||||
|
"28": {
|
||||||
|
"type": "checkpoint",
|
||||||
|
"name": "bananaSplitzXL_vee5PointOh.safetensors"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
SAMPLING: {
|
||||||
|
"32": {
|
||||||
|
IS_SAMPLER: True,
|
||||||
|
"parameters": {
|
||||||
|
"sampler_name": "euler_ancestral",
|
||||||
|
"scheduler": "karras"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def test_trace_model_path_through_pipe(self, pipe_workflow_metadata):
|
||||||
|
"""Verify trace_model_path can follow: KSampler -> FromBasicPipe -> ToBasicPipe -> Lora -> Checkpoint."""
|
||||||
|
prompt = pipe_workflow_metadata["current_prompt"]
|
||||||
|
|
||||||
|
# Start trace from KSampler (32)
|
||||||
|
ckpt_id = MetadataProcessor.trace_model_path(pipe_workflow_metadata, prompt, "32")
|
||||||
|
|
||||||
|
assert ckpt_id == "28"
|
||||||
|
|
||||||
|
def test_find_primary_checkpoint_with_pipe(self, pipe_workflow_metadata):
|
||||||
|
"""Verify find_primary_checkpoint returns the correct name even with pipe nodes."""
|
||||||
|
# Providing sampler_id to test the optimization as well
|
||||||
|
name = MetadataProcessor.find_primary_checkpoint(pipe_workflow_metadata, primary_sampler_id="32")
|
||||||
|
|
||||||
|
assert name == "bananaSplitzXL_vee5PointOh.safetensors"
|
||||||
Reference in New Issue
Block a user