mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 22:22:11 -03:00
feat: Add model path tracing to accurately identify the primary checkpoint in workflows and include new tests.
This commit is contained in:
@@ -202,12 +202,84 @@ class MetadataProcessor:
|
||||
return last_valid_node if not target_class else None
|
||||
|
||||
@staticmethod
|
||||
def find_primary_checkpoint(metadata):
|
||||
"""Find the primary checkpoint model in the workflow"""
|
||||
if not metadata.get(MODELS):
|
||||
def trace_model_path(metadata, prompt, start_node_id):
|
||||
"""
|
||||
Trace the model connection path upstream to find the checkpoint
|
||||
"""
|
||||
if not prompt or not prompt.original_prompt:
|
||||
return None
|
||||
|
||||
# In most workflows, there's only one checkpoint, so we can just take the first one
|
||||
current_node_id = start_node_id
|
||||
depth = 0
|
||||
max_depth = 50
|
||||
|
||||
while depth < max_depth:
|
||||
# Check if current node is a registered checkpoint in our metadata
|
||||
# This handles cached nodes correctly because metadata contains info for all nodes in the graph
|
||||
if current_node_id in metadata.get(MODELS, {}):
|
||||
if metadata[MODELS][current_node_id].get("type") == "checkpoint":
|
||||
return current_node_id
|
||||
|
||||
if current_node_id not in prompt.original_prompt:
|
||||
return None
|
||||
|
||||
node = prompt.original_prompt[current_node_id]
|
||||
inputs = node.get("inputs", {})
|
||||
class_type = node.get("class_type", "")
|
||||
|
||||
# Determine which input to follow next
|
||||
next_input_name = "model"
|
||||
|
||||
# Special handling for initial node
|
||||
if depth == 0:
|
||||
if class_type == "SamplerCustomAdvanced":
|
||||
next_input_name = "guider"
|
||||
|
||||
# If the specific input doesn't exist, try generic 'model'
|
||||
if next_input_name not in inputs:
|
||||
if "model" in inputs:
|
||||
next_input_name = "model"
|
||||
else:
|
||||
# Dead end - no model input to follow
|
||||
return None
|
||||
|
||||
# Get connected node
|
||||
input_val = inputs[next_input_name]
|
||||
if isinstance(input_val, list) and len(input_val) > 0:
|
||||
current_node_id = input_val[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
depth += 1
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def find_primary_checkpoint(metadata, downstream_id=None):
|
||||
"""
|
||||
Find the primary checkpoint model in the workflow
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- downstream_id: Optional ID of a downstream node to help identify the specific primary sampler
|
||||
"""
|
||||
if not metadata.get(MODELS):
|
||||
return None
|
||||
|
||||
# Method 1: Topology-based tracing (More accurate for complex workflows)
|
||||
# First, find the primary sampler
|
||||
primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id)
|
||||
|
||||
if primary_sampler_id:
|
||||
prompt = metadata.get("current_prompt")
|
||||
if prompt:
|
||||
# Trace back from the sampler to find the checkpoint
|
||||
checkpoint_id = MetadataProcessor.trace_model_path(metadata, prompt, primary_sampler_id)
|
||||
if checkpoint_id and checkpoint_id in metadata.get(MODELS, {}):
|
||||
return metadata[MODELS][checkpoint_id].get("name")
|
||||
|
||||
# Method 2: Fallback to the first available checkpoint (Original behavior)
|
||||
# In most simple workflows, there's only one checkpoint, so we can just take the first one
|
||||
for node_id, model_info in metadata.get(MODELS, {}).items():
|
||||
if model_info.get("type") == "checkpoint":
|
||||
return model_info.get("name")
|
||||
@@ -311,7 +383,7 @@ class MetadataProcessor:
|
||||
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id)
|
||||
|
||||
# Directly get checkpoint from metadata instead of tracing
|
||||
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
|
||||
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id)
|
||||
if checkpoint:
|
||||
params["checkpoint"] = checkpoint
|
||||
|
||||
|
||||
Reference in New Issue
Block a user