mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 13:42:12 -03:00
- Improved reference handling in NodeMapper to support integer node IDs and added error logging for reference processing failures. - Updated LoraLoaderMapper and LoraStackerMapper to handle lora_stack as a dictionary, ensuring compatibility with new data formats. - Refactored trace_model_path utility to perform a depth-first search for LoRA nodes, improving the accuracy of model path tracing. - Cleaned up unused code in parser.py related to LoRA processing, streamlining the workflow parsing logic.
120 lines
4.0 KiB
Python
120 lines
4.0 KiB
Python
"""
Utility functions for ComfyUI workflow parsing.

Provides helpers to load and save workflow/output JSON files, find nodes by
their ``class_type``, resolve a node's linked inputs, and trace LoRA nodes
upstream along a sampler's model path.
"""
|
|
import json
|
|
import os
|
|
import logging
|
|
from typing import Dict, List, Any, Optional, Union, Set, Tuple
|
|
|
|
# Module-level logger; used to report file load/save failures before re-raising.
logger = logging.getLogger(__name__)
|
|
|
|
def load_workflow(workflow_path: str) -> Dict:
    """Read *workflow_path* as UTF-8 JSON and return the decoded data.

    Any failure (missing file, decoding error, malformed JSON) is logged and
    then re-raised unchanged so the caller can decide how to handle it.
    """
    try:
        with open(workflow_path, 'r', encoding='utf-8') as handle:
            data = json.load(handle)
    except Exception as exc:
        logger.error(f"Error loading workflow from {workflow_path}: {exc}")
        raise
    return data
|
|
|
|
def save_output(output: Dict, output_path: str) -> None:
    """Save the parsed output to a JSON file.

    The parent directory of *output_path* is created if it does not exist.
    The data is fully serialized before the destination file is opened, so a
    non-serializable value raises without truncating an existing file at
    *output_path*.

    Args:
        output: JSON-serializable data; written pretty-printed with indent=4.
        output_path: Destination file path.

    Raises:
        Exception: Any error from serialization, directory creation, or
            writing is logged and then re-raised unchanged.
    """
    try:
        # Serialize up front: json.dump writes chunk-by-chunk into a file that
        # open(..., 'w') has already truncated, so a TypeError halfway through
        # would leave a corrupt, half-written file behind.
        text = json.dumps(output, indent=4)
        # Inside the try so directory-creation failures are logged as well.
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(text)
    except Exception as e:
        logger.error(f"Error saving output to {output_path}: {e}")
        raise
|
|
|
|
def find_node_by_type(workflow: Dict, node_type: str) -> Optional[str]:
    """Return the ID of the first node whose ``class_type`` equals *node_type*.

    Nodes are examined in the workflow's iteration order; ``None`` is returned
    when no node matches.
    """
    matching = (
        candidate_id
        for candidate_id, candidate in workflow.items()
        if candidate.get("class_type") == node_type
    )
    return next(matching, None)
|
|
|
|
def find_nodes_by_type(workflow: Dict, node_type: str) -> List[str]:
    """Return the IDs of every node whose ``class_type`` equals *node_type*.

    IDs appear in the workflow's iteration order; the list is empty when
    nothing matches.
    """
    matching_ids = []
    for candidate_id, candidate in workflow.items():
        if candidate.get("class_type") != node_type:
            continue
        matching_ids.append(candidate_id)
    return matching_ids
|
|
|
|
def get_input_node_ids(workflow: Dict, node_id: str) -> Dict[str, Tuple[str, int]]:
    """
    Map each linked input of *node_id* to the node feeding it.

    An input counts as linked when its value is a two-element list, read as
    ``[source_node_id, output_slot]``; the source ID is coerced to ``str``.
    Other inputs (plain widget values) are omitted. An unknown *node_id*
    yields an empty dict.

    Returns a dictionary mapping input names to (node_id, output_slot) tuples
    """
    if node_id not in workflow:
        return {}

    inputs = workflow[node_id].get("inputs", {})
    return {
        name: (str(value[0]), value[1])
        for name, value in inputs.items()
        if isinstance(value, list) and len(value) == 2
    }
|
|
|
|
def trace_model_path(workflow: Dict, start_node_id: str) -> List[str]:
    """
    Trace the model path backward from a sampler node and collect LoRA nodes.

    Starting from the node linked to ``start_node_id``'s ``model`` input, the
    workflow is walked upstream depth-first through every ``model`` and
    ``lora_stack`` link. Each node is visited at most once, so cycles and
    shared upstream nodes are handled safely.

    Args:
        workflow: The workflow data, mapping node IDs to node dicts
        start_node_id: The starting node ID (usually KSampler)

    Returns:
        IDs of the visited nodes whose ``class_type`` contains "Lora", in
        depth-first visit order. The start node itself is never included.
        An empty list is returned when the start node is missing or its
        ``model`` input is not a link.
    """
    def _linked_id(value: Any) -> Optional[str]:
        # A link is serialized as [source_node_id, output_slot]; anything else
        # (widget values, empty or malformed lists) is not a connection.
        # The ID is coerced to str so integer references match string keys.
        if isinstance(value, list) and len(value) == 2:
            return str(value[0])
        return None

    model_path_nodes: List[str] = []

    if start_node_id not in workflow:
        return model_path_nodes

    # Depth-first search stack, seeded with the start node's model source.
    stack: List[str] = []
    start_node = workflow[start_node_id]
    if "inputs" in start_node:
        seed_id = _linked_id(start_node["inputs"].get("model"))
        if seed_id is not None:
            stack.append(seed_id)

    visited: Set[str] = set()
    while stack:
        node_id = stack.pop()
        if node_id in visited:
            continue
        visited.add(node_id)

        # Dangling references (links to nodes absent from the workflow) are skipped.
        if node_id not in workflow:
            continue
        node = workflow[node_id]

        # Treat a missing or null class_type as "not a LoRA node".
        node_type = node.get("class_type") or ""
        if "Lora" in node_type:
            model_path_nodes.append(node_id)

        # Follow only the links that carry the model / LoRA stack upstream.
        for input_name, input_value in node.get("inputs", {}).items():
            if input_name not in ("model", "lora_stack"):
                continue
            upstream_id = _linked_id(input_value)
            if upstream_id is not None:
                stack.append(upstream_id)

    return model_path_nodes