mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
Enhance node processing and error handling in workflow mappers
- Improved reference handling in NodeMapper to support integer node IDs and added error logging for reference processing failures. - Updated LoraLoaderMapper and LoraStackerMapper to handle lora_stack as a dictionary, ensuring compatibility with new data formats. - Refactored trace_model_path utility to perform a depth-first search for LoRA nodes, improving the accuracy of model path tracing. - Cleaned up unused code in parser.py related to LoRA processing, streamlining the workflow parsing logic.
This commit is contained in:
@@ -28,10 +28,25 @@ class NodeMapper:
|
||||
# Check if input is a reference to another node's output
|
||||
if isinstance(input_value, list) and len(input_value) == 2:
|
||||
# Format is [node_id, output_slot]
|
||||
ref_node_id, output_slot = input_value
|
||||
# Recursively process the referenced node
|
||||
ref_value = parser.process_node(str(ref_node_id), workflow)
|
||||
result[input_name] = ref_value
|
||||
try:
|
||||
ref_node_id, output_slot = input_value
|
||||
# Convert node_id to string if it's an integer
|
||||
if isinstance(ref_node_id, int):
|
||||
ref_node_id = str(ref_node_id)
|
||||
|
||||
# Recursively process the referenced node
|
||||
ref_value = parser.process_node(ref_node_id, workflow)
|
||||
|
||||
# Store the processed value
|
||||
if ref_value is not None:
|
||||
result[input_name] = ref_value
|
||||
else:
|
||||
# If we couldn't get a value from the reference, store the raw value
|
||||
result[input_name] = input_value
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing reference in node {node_id}, input {input_name}: {e}")
|
||||
# If we couldn't process the reference, store the raw value
|
||||
result[input_name] = input_value
|
||||
else:
|
||||
# Direct value
|
||||
result[input_name] = input_value
|
||||
@@ -142,7 +157,7 @@ class LoraLoaderMapper(NodeMapper):
|
||||
def transform(self, inputs: Dict) -> Dict:
|
||||
# Fallback to loras array if text field doesn't exist or is invalid
|
||||
loras_data = inputs.get("loras", [])
|
||||
lora_stack = inputs.get("lora_stack", [])
|
||||
lora_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
|
||||
|
||||
# Process loras array - filter active entries
|
||||
lora_texts = []
|
||||
@@ -157,18 +172,24 @@ class LoraLoaderMapper(NodeMapper):
|
||||
|
||||
# Process each active lora entry
|
||||
for lora in loras_list:
|
||||
logger.info(f"Lora: {lora}, active: {lora.get('active')}")
|
||||
if isinstance(lora, dict) and lora.get("active", False):
|
||||
lora_name = lora.get("name", "")
|
||||
strength = lora.get("strength", 1.0)
|
||||
if lora_name and not lora_name.startswith("__dummy"):
|
||||
lora_texts.append(f"<lora:{lora_name}:{strength}>")
|
||||
lora_texts.append(f"<lora:{lora_name}:{strength}>")
|
||||
|
||||
# Process lora_stack if it exists
|
||||
if lora_stack:
|
||||
# Format each entry from the stack
|
||||
for lora_path, strength, _ in lora_stack:
|
||||
lora_name = os.path.basename(lora_path).split('.')[0]
|
||||
if lora_name and not lora_name.startswith("__dummy"):
|
||||
# Process lora_stack if it exists and is a valid format (list of tuples)
|
||||
if lora_stack and isinstance(lora_stack, list):
|
||||
# If lora_stack is a reference to another node ([node_id, output_slot]),
|
||||
# we don't process it here as it's already been processed recursively
|
||||
if len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int):
|
||||
# This is a reference to another node, already processed
|
||||
pass
|
||||
else:
|
||||
# Format each entry from the stack (assuming it's a list of tuples)
|
||||
for stack_entry in lora_stack:
|
||||
lora_name = stack_entry[0]
|
||||
strength = stack_entry[1]
|
||||
lora_texts.append(f"<lora:{lora_name}:{strength}>")
|
||||
|
||||
# Join with spaces
|
||||
@@ -188,12 +209,21 @@ class LoraStackerMapper(NodeMapper):
|
||||
|
||||
def transform(self, inputs: Dict) -> Dict:
|
||||
loras_data = inputs.get("loras", [])
|
||||
existing_stack = inputs.get("lora_stack", [])
|
||||
existing_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
|
||||
result_stack = []
|
||||
|
||||
# Keep existing stack entries
|
||||
# Handle existing stack entries
|
||||
if existing_stack:
|
||||
result_stack.extend(existing_stack)
|
||||
# Check if existing_stack is a reference to another node ([node_id, output_slot])
|
||||
if isinstance(existing_stack, list) and len(existing_stack) == 2 and isinstance(existing_stack[0], (str, int)) and isinstance(existing_stack[1], int):
|
||||
# This is a reference to another node, should already be processed
|
||||
# So we'll need to extract the value from that node
|
||||
if isinstance(inputs.get("lora_stack", {}), dict) and "lora_stack" in inputs["lora_stack"]:
|
||||
# If we have the processed result, use it
|
||||
result_stack.extend(inputs["lora_stack"]["lora_stack"])
|
||||
elif isinstance(existing_stack, list):
|
||||
# If it's a regular list (not a node reference), just add the entries
|
||||
result_stack.extend(existing_stack)
|
||||
|
||||
# Process loras array - filter active entries
|
||||
# Check if loras_data is a list or a dict with __value__ key (new format)
|
||||
@@ -209,10 +239,7 @@ class LoraStackerMapper(NodeMapper):
|
||||
if isinstance(lora, dict) and lora.get("active", False):
|
||||
lora_name = lora.get("name", "")
|
||||
strength = float(lora.get("strength", 1.0))
|
||||
if lora_name and not lora_name.startswith("__dummy"):
|
||||
# Here we would need the real path, but as a fallback use the name
|
||||
# In a real implementation, this would require looking up the file path
|
||||
result_stack.append((lora_name, strength, strength))
|
||||
result_stack.append((lora_name, strength))
|
||||
|
||||
return {"lora_stack": result_stack}
|
||||
|
||||
|
||||
@@ -41,7 +41,12 @@ class WorkflowParser:
|
||||
result = None
|
||||
mapper = get_mapper(node_type)
|
||||
if mapper:
|
||||
result = mapper.process(node_id, node_data, workflow, self)
|
||||
try:
|
||||
result = mapper.process(node_id, node_data, workflow, self)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing node {node_id} of type {node_type}: {e}", exc_info=True)
|
||||
# Return a partial result or None depending on how we want to handle errors
|
||||
result = {}
|
||||
|
||||
# Remove node from processed set to allow it to be processed again in a different context
|
||||
self.processed_nodes.remove(node_id)
|
||||
@@ -112,23 +117,6 @@ class WorkflowParser:
|
||||
if "guidance" in node_inputs:
|
||||
result["gen_params"]["guidance"] = node_inputs["guidance"]
|
||||
|
||||
# Trace the model path to find LoRA Loader nodes
|
||||
lora_node_ids = trace_model_path(workflow, ksampler_node_id)
|
||||
|
||||
# Process each LoRA Loader node
|
||||
lora_texts = []
|
||||
for lora_node_id in lora_node_ids:
|
||||
# Reset the processed nodes tracker for each lora processing
|
||||
self.processed_nodes = set()
|
||||
|
||||
lora_result = self.process_node(lora_node_id, workflow)
|
||||
if lora_result and "loras" in lora_result:
|
||||
lora_texts.append(lora_result["loras"])
|
||||
|
||||
# Combine all LoRA texts
|
||||
if lora_texts:
|
||||
result["loras"] = " ".join(lora_texts)
|
||||
|
||||
# Add clip_skip = 2 to match reference output if not already present
|
||||
if "clip_skip" not in result["gen_params"]:
|
||||
result["gen_params"]["clip_skip"] = "2"
|
||||
|
||||
@@ -60,43 +60,61 @@ def get_input_node_ids(workflow: Dict, node_id: str) -> Dict[str, Tuple[str, int
|
||||
|
||||
return result
|
||||
|
||||
def trace_model_path(workflow: Dict, start_node_id: str,
|
||||
visited: Optional[Set[str]] = None) -> List[str]:
|
||||
def trace_model_path(workflow: Dict, start_node_id: str) -> List[str]:
|
||||
"""
|
||||
Trace through the workflow graph following 'model' inputs
|
||||
to find all LoRA Loader nodes that affect the model
|
||||
Trace the model path backward from KSampler to find all LoRA nodes
|
||||
|
||||
Returns a list of LoRA Loader node IDs
|
||||
Args:
|
||||
workflow: The workflow data
|
||||
start_node_id: The starting node ID (usually KSampler)
|
||||
|
||||
Returns:
|
||||
List of node IDs in the model path
|
||||
"""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
# Prevent cycles
|
||||
if start_node_id in visited:
|
||||
return []
|
||||
|
||||
visited.add(start_node_id)
|
||||
model_path_nodes = []
|
||||
|
||||
node_data = workflow.get(start_node_id)
|
||||
if not node_data:
|
||||
return []
|
||||
|
||||
# If this is a LoRA Loader node, add it to the result
|
||||
if node_data.get("class_type") == "Lora Loader (LoraManager)":
|
||||
return [start_node_id]
|
||||
|
||||
# Get all input nodes
|
||||
input_nodes = get_input_node_ids(workflow, start_node_id)
|
||||
# Get the model input from the start node
|
||||
if start_node_id not in workflow:
|
||||
return model_path_nodes
|
||||
|
||||
# Recursively trace the model input if it exists
|
||||
result = []
|
||||
if "model" in input_nodes:
|
||||
model_node_id, _ = input_nodes["model"]
|
||||
result.extend(trace_model_path(workflow, model_node_id, visited))
|
||||
# Track visited nodes to avoid cycles
|
||||
visited = set()
|
||||
|
||||
# Stack for depth-first search
|
||||
stack = []
|
||||
|
||||
# Get model input reference if available
|
||||
start_node = workflow[start_node_id]
|
||||
if "inputs" in start_node and "model" in start_node["inputs"] and isinstance(start_node["inputs"]["model"], list):
|
||||
model_ref = start_node["inputs"]["model"]
|
||||
stack.append(str(model_ref[0]))
|
||||
|
||||
# Perform depth-first search
|
||||
while stack:
|
||||
node_id = stack.pop()
|
||||
|
||||
# Also trace lora_stack input if it exists
|
||||
if "lora_stack" in input_nodes:
|
||||
lora_stack_node_id, _ = input_nodes["lora_stack"]
|
||||
result.extend(trace_model_path(workflow, lora_stack_node_id, visited))
|
||||
# Skip if already visited
|
||||
if node_id in visited:
|
||||
continue
|
||||
|
||||
return result
|
||||
# Mark as visited
|
||||
visited.add(node_id)
|
||||
|
||||
# Skip if node doesn't exist
|
||||
if node_id not in workflow:
|
||||
continue
|
||||
|
||||
node = workflow[node_id]
|
||||
node_type = node.get("class_type", "")
|
||||
|
||||
# Add current node to result list if it's a LoRA node
|
||||
if "Lora" in node_type:
|
||||
model_path_nodes.append(node_id)
|
||||
|
||||
# Add all input nodes that have a "model" or "lora_stack" output to the stack
|
||||
if "inputs" in node:
|
||||
for input_name, input_value in node["inputs"].items():
|
||||
if input_name in ["model", "lora_stack"] and isinstance(input_value, list) and len(input_value) == 2:
|
||||
stack.append(str(input_value[0]))
|
||||
|
||||
return model_path_nodes
|
||||
Reference in New Issue
Block a user