Enhance node processing and error handling in workflow mappers

- Improved reference handling in NodeMapper to support integer node IDs and added error logging for reference processing failures.
- Updated LoraLoaderMapper and LoraStackerMapper to handle lora_stack as a dictionary, ensuring compatibility with new data formats.
- Refactored trace_model_path utility to perform a depth-first search for LoRA nodes, improving the accuracy of model path tracing.
- Cleaned up unused code in parser.py related to LoRA processing, streamlining the workflow parsing logic.
This commit is contained in:
Will Miao
2025-03-23 07:20:50 +08:00
parent 042153329b
commit 6aa2342be1
6 changed files with 178 additions and 91 deletions

View File

@@ -60,43 +60,61 @@ def get_input_node_ids(workflow: Dict, node_id: str) -> Dict[str, Tuple[str, int
return result
def trace_model_path(workflow: Dict, start_node_id: str,
visited: Optional[Set[str]] = None) -> List[str]:
def trace_model_path(workflow: Dict, start_node_id: str) -> List[str]:
"""
Trace through the workflow graph following 'model' inputs
to find all LoRA Loader nodes that affect the model
Trace the model path backward from KSampler to find all LoRA nodes
Returns a list of LoRA Loader node IDs
Args:
workflow: The workflow data
start_node_id: The starting node ID (usually KSampler)
Returns:
List of node IDs in the model path
"""
if visited is None:
visited = set()
# Prevent cycles
if start_node_id in visited:
return []
visited.add(start_node_id)
model_path_nodes = []
node_data = workflow.get(start_node_id)
if not node_data:
return []
# If this is a LoRA Loader node, add it to the result
if node_data.get("class_type") == "Lora Loader (LoraManager)":
return [start_node_id]
# Get all input nodes
input_nodes = get_input_node_ids(workflow, start_node_id)
# Get the model input from the start node
if start_node_id not in workflow:
return model_path_nodes
# Recursively trace the model input if it exists
result = []
if "model" in input_nodes:
model_node_id, _ = input_nodes["model"]
result.extend(trace_model_path(workflow, model_node_id, visited))
# Track visited nodes to avoid cycles
visited = set()
# Stack for depth-first search
stack = []
# Get model input reference if available
start_node = workflow[start_node_id]
if "inputs" in start_node and "model" in start_node["inputs"] and isinstance(start_node["inputs"]["model"], list):
model_ref = start_node["inputs"]["model"]
stack.append(str(model_ref[0]))
# Perform depth-first search
while stack:
node_id = stack.pop()
# Also trace lora_stack input if it exists
if "lora_stack" in input_nodes:
lora_stack_node_id, _ = input_nodes["lora_stack"]
result.extend(trace_model_path(workflow, lora_stack_node_id, visited))
# Skip if already visited
if node_id in visited:
continue
return result
# Mark as visited
visited.add(node_id)
# Skip if node doesn't exist
if node_id not in workflow:
continue
node = workflow[node_id]
node_type = node.get("class_type", "")
# Add current node to result list if it's a LoRA node
if "Lora" in node_type:
model_path_nodes.append(node_id)
# Add all input nodes that have a "model" or "lora_stack" output to the stack
if "inputs" in node:
for input_name, input_value in node["inputs"].items():
if input_name in ["model", "lora_stack"] and isinstance(input_value, list) and len(input_value) == 2:
stack.append(str(input_value[0]))
return model_path_nodes