mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
Enhance node processing and error handling in workflow mappers
- Improved reference handling in NodeMapper to support integer node IDs and added error logging for reference processing failures. - Updated LoraLoaderMapper and LoraStackerMapper to handle lora_stack as a dictionary, ensuring compatibility with new data formats. - Refactored trace_model_path utility to perform a depth-first search for LoRA nodes, improving the accuracy of model path tracing. - Cleaned up unused code in parser.py related to LoRA processing, streamlining the workflow parsing logic.
This commit is contained in:
@@ -60,43 +60,61 @@ def get_input_node_ids(workflow: Dict, node_id: str) -> Dict[str, Tuple[str, int
|
||||
|
||||
return result
|
||||
|
||||
def trace_model_path(workflow: Dict, start_node_id: str,
|
||||
visited: Optional[Set[str]] = None) -> List[str]:
|
||||
def trace_model_path(workflow: Dict, start_node_id: str) -> List[str]:
|
||||
"""
|
||||
Trace through the workflow graph following 'model' inputs
|
||||
to find all LoRA Loader nodes that affect the model
|
||||
Trace the model path backward from KSampler to find all LoRA nodes
|
||||
|
||||
Returns a list of LoRA Loader node IDs
|
||||
Args:
|
||||
workflow: The workflow data
|
||||
start_node_id: The starting node ID (usually KSampler)
|
||||
|
||||
Returns:
|
||||
List of node IDs in the model path
|
||||
"""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
# Prevent cycles
|
||||
if start_node_id in visited:
|
||||
return []
|
||||
|
||||
visited.add(start_node_id)
|
||||
model_path_nodes = []
|
||||
|
||||
node_data = workflow.get(start_node_id)
|
||||
if not node_data:
|
||||
return []
|
||||
|
||||
# If this is a LoRA Loader node, add it to the result
|
||||
if node_data.get("class_type") == "Lora Loader (LoraManager)":
|
||||
return [start_node_id]
|
||||
|
||||
# Get all input nodes
|
||||
input_nodes = get_input_node_ids(workflow, start_node_id)
|
||||
# Get the model input from the start node
|
||||
if start_node_id not in workflow:
|
||||
return model_path_nodes
|
||||
|
||||
# Recursively trace the model input if it exists
|
||||
result = []
|
||||
if "model" in input_nodes:
|
||||
model_node_id, _ = input_nodes["model"]
|
||||
result.extend(trace_model_path(workflow, model_node_id, visited))
|
||||
# Track visited nodes to avoid cycles
|
||||
visited = set()
|
||||
|
||||
# Stack for depth-first search
|
||||
stack = []
|
||||
|
||||
# Get model input reference if available
|
||||
start_node = workflow[start_node_id]
|
||||
if "inputs" in start_node and "model" in start_node["inputs"] and isinstance(start_node["inputs"]["model"], list):
|
||||
model_ref = start_node["inputs"]["model"]
|
||||
stack.append(str(model_ref[0]))
|
||||
|
||||
# Perform depth-first search
|
||||
while stack:
|
||||
node_id = stack.pop()
|
||||
|
||||
# Also trace lora_stack input if it exists
|
||||
if "lora_stack" in input_nodes:
|
||||
lora_stack_node_id, _ = input_nodes["lora_stack"]
|
||||
result.extend(trace_model_path(workflow, lora_stack_node_id, visited))
|
||||
# Skip if already visited
|
||||
if node_id in visited:
|
||||
continue
|
||||
|
||||
return result
|
||||
# Mark as visited
|
||||
visited.add(node_id)
|
||||
|
||||
# Skip if node doesn't exist
|
||||
if node_id not in workflow:
|
||||
continue
|
||||
|
||||
node = workflow[node_id]
|
||||
node_type = node.get("class_type", "")
|
||||
|
||||
# Add current node to result list if it's a LoRA node
|
||||
if "Lora" in node_type:
|
||||
model_path_nodes.append(node_id)
|
||||
|
||||
# Add all input nodes that have a "model" or "lora_stack" output to the stack
|
||||
if "inputs" in node:
|
||||
for input_name, input_value in node["inputs"].items():
|
||||
if input_name in ["model", "lora_stack"] and isinstance(input_value, list) and len(input_value) == 2:
|
||||
stack.append(str(input_value[0]))
|
||||
|
||||
return model_path_nodes
|
||||
Reference in New Issue
Block a user