Enhance node processing and error handling in workflow mappers

- Improved reference handling in NodeMapper to support integer node IDs and added error logging for reference processing failures.
- Updated LoraLoaderMapper and LoraStackerMapper to handle lora_stack as a dictionary, ensuring compatibility with new data formats.
- Refactored trace_model_path utility to perform a depth-first search for LoRA nodes, improving the accuracy of model path tracing.
- Cleaned up unused code in parser.py related to LoRA processing, streamlining the workflow parsing logic.
This commit is contained in:
Will Miao
2025-03-23 07:20:50 +08:00
parent 042153329b
commit 6aa2342be1
6 changed files with 178 additions and 91 deletions

View File

@@ -28,10 +28,25 @@ class NodeMapper:
# Check if input is a reference to another node's output
if isinstance(input_value, list) and len(input_value) == 2:
# Format is [node_id, output_slot]
ref_node_id, output_slot = input_value
# Recursively process the referenced node
ref_value = parser.process_node(str(ref_node_id), workflow)
result[input_name] = ref_value
try:
ref_node_id, output_slot = input_value
# Convert node_id to string if it's an integer
if isinstance(ref_node_id, int):
ref_node_id = str(ref_node_id)
# Recursively process the referenced node
ref_value = parser.process_node(ref_node_id, workflow)
# Store the processed value
if ref_value is not None:
result[input_name] = ref_value
else:
# If we couldn't get a value from the reference, store the raw value
result[input_name] = input_value
except Exception as e:
logger.error(f"Error processing reference in node {node_id}, input {input_name}: {e}")
# If we couldn't process the reference, store the raw value
result[input_name] = input_value
else:
# Direct value
result[input_name] = input_value
@@ -142,7 +157,7 @@ class LoraLoaderMapper(NodeMapper):
def transform(self, inputs: Dict) -> Dict:
# Fallback to loras array if text field doesn't exist or is invalid
loras_data = inputs.get("loras", [])
lora_stack = inputs.get("lora_stack", [])
lora_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
# Process loras array - filter active entries
lora_texts = []
@@ -157,18 +172,24 @@ class LoraLoaderMapper(NodeMapper):
# Process each active lora entry
for lora in loras_list:
logger.info(f"Lora: {lora}, active: {lora.get('active')}")
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
strength = lora.get("strength", 1.0)
if lora_name and not lora_name.startswith("__dummy"):
lora_texts.append(f"<lora:{lora_name}:{strength}>")
lora_texts.append(f"<lora:{lora_name}:{strength}>")
# Process lora_stack if it exists
if lora_stack:
# Format each entry from the stack
for lora_path, strength, _ in lora_stack:
lora_name = os.path.basename(lora_path).split('.')[0]
if lora_name and not lora_name.startswith("__dummy"):
# Process lora_stack if it exists and is a valid format (list of tuples)
if lora_stack and isinstance(lora_stack, list):
# If lora_stack is a reference to another node ([node_id, output_slot]),
# we don't process it here as it's already been processed recursively
if len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int):
# This is a reference to another node, already processed
pass
else:
# Format each entry from the stack (assuming it's a list of tuples)
for stack_entry in lora_stack:
lora_name = stack_entry[0]
strength = stack_entry[1]
lora_texts.append(f"<lora:{lora_name}:{strength}>")
# Join with spaces
@@ -188,12 +209,21 @@ class LoraStackerMapper(NodeMapper):
def transform(self, inputs: Dict) -> Dict:
loras_data = inputs.get("loras", [])
existing_stack = inputs.get("lora_stack", [])
existing_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
result_stack = []
# Keep existing stack entries
# Handle existing stack entries
if existing_stack:
result_stack.extend(existing_stack)
# Check if existing_stack is a reference to another node ([node_id, output_slot])
if isinstance(existing_stack, list) and len(existing_stack) == 2 and isinstance(existing_stack[0], (str, int)) and isinstance(existing_stack[1], int):
# This is a reference to another node, should already be processed
# So we'll need to extract the value from that node
if isinstance(inputs.get("lora_stack", {}), dict) and "lora_stack" in inputs["lora_stack"]:
# If we have the processed result, use it
result_stack.extend(inputs["lora_stack"]["lora_stack"])
elif isinstance(existing_stack, list):
# If it's a regular list (not a node reference), just add the entries
result_stack.extend(existing_stack)
# Process loras array - filter active entries
# Check if loras_data is a list or a dict with __value__ key (new format)
@@ -209,10 +239,7 @@ class LoraStackerMapper(NodeMapper):
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
strength = float(lora.get("strength", 1.0))
if lora_name and not lora_name.startswith("__dummy"):
# Here we would need the real path, but as a fallback use the name
# In a real implementation, this would require looking up the file path
result_stack.append((lora_name, strength, strength))
result_stack.append((lora_name, strength))
return {"lora_stack": result_stack}

View File

@@ -41,7 +41,12 @@ class WorkflowParser:
result = None
mapper = get_mapper(node_type)
if mapper:
result = mapper.process(node_id, node_data, workflow, self)
try:
result = mapper.process(node_id, node_data, workflow, self)
except Exception as e:
logger.error(f"Error processing node {node_id} of type {node_type}: {e}", exc_info=True)
# Return a partial result or None depending on how we want to handle errors
result = {}
# Remove node from processed set to allow it to be processed again in a different context
self.processed_nodes.remove(node_id)
@@ -112,23 +117,6 @@ class WorkflowParser:
if "guidance" in node_inputs:
result["gen_params"]["guidance"] = node_inputs["guidance"]
# Trace the model path to find LoRA Loader nodes
lora_node_ids = trace_model_path(workflow, ksampler_node_id)
# Process each LoRA Loader node
lora_texts = []
for lora_node_id in lora_node_ids:
# Reset the processed nodes tracker for each lora processing
self.processed_nodes = set()
lora_result = self.process_node(lora_node_id, workflow)
if lora_result and "loras" in lora_result:
lora_texts.append(lora_result["loras"])
# Combine all LoRA texts
if lora_texts:
result["loras"] = " ".join(lora_texts)
# Add clip_skip = 2 to match reference output if not already present
if "clip_skip" not in result["gen_params"]:
result["gen_params"]["clip_skip"] = "2"

View File

@@ -60,43 +60,61 @@ def get_input_node_ids(workflow: Dict, node_id: str) -> Dict[str, Tuple[str, int
return result
def trace_model_path(workflow: Dict, start_node_id: str,
visited: Optional[Set[str]] = None) -> List[str]:
def trace_model_path(workflow: Dict, start_node_id: str) -> List[str]:
"""
Trace through the workflow graph following 'model' inputs
to find all LoRA Loader nodes that affect the model
Trace the model path backward from KSampler to find all LoRA nodes
Returns a list of LoRA Loader node IDs
Args:
workflow: The workflow data
start_node_id: The starting node ID (usually KSampler)
Returns:
List of node IDs in the model path
"""
if visited is None:
visited = set()
# Prevent cycles
if start_node_id in visited:
return []
visited.add(start_node_id)
model_path_nodes = []
node_data = workflow.get(start_node_id)
if not node_data:
return []
# If this is a LoRA Loader node, add it to the result
if node_data.get("class_type") == "Lora Loader (LoraManager)":
return [start_node_id]
# Get all input nodes
input_nodes = get_input_node_ids(workflow, start_node_id)
# Get the model input from the start node
if start_node_id not in workflow:
return model_path_nodes
# Recursively trace the model input if it exists
result = []
if "model" in input_nodes:
model_node_id, _ = input_nodes["model"]
result.extend(trace_model_path(workflow, model_node_id, visited))
# Track visited nodes to avoid cycles
visited = set()
# Stack for depth-first search
stack = []
# Get model input reference if available
start_node = workflow[start_node_id]
if "inputs" in start_node and "model" in start_node["inputs"] and isinstance(start_node["inputs"]["model"], list):
model_ref = start_node["inputs"]["model"]
stack.append(str(model_ref[0]))
# Perform depth-first search
while stack:
node_id = stack.pop()
# Also trace lora_stack input if it exists
if "lora_stack" in input_nodes:
lora_stack_node_id, _ = input_nodes["lora_stack"]
result.extend(trace_model_path(workflow, lora_stack_node_id, visited))
# Skip if already visited
if node_id in visited:
continue
return result
# Mark as visited
visited.add(node_id)
# Skip if node doesn't exist
if node_id not in workflow:
continue
node = workflow[node_id]
node_type = node.get("class_type", "")
# Add current node to result list if it's a LoRA node
if "Lora" in node_type:
model_path_nodes.append(node_id)
# Add all input nodes that have a "model" or "lora_stack" output to the stack
if "inputs" in node:
for input_name, input_value in node["inputs"].items():
if input_name in ["model", "lora_stack"] and isinstance(input_value, list) and len(input_value) == 2:
stack.append(str(input_value[0]))
return model_path_nodes

View File

@@ -114,7 +114,7 @@
1
]
},
"class_type": "CLIPSetLastLayer",
"class_type": "CLIPSetLastLayer",
"_meta": {
"title": "CLIP Set Last Layer"
}
@@ -154,18 +154,6 @@
"text": "in the style of ck-rw",
"active": true
},
{
"text": "aorun, scales, makeup, bare shoulders, pointy ears",
"active": true
},
{
"text": "dress",
"active": true
},
{
"text": "claws",
"active": true
},
{
"text": "in the style of cksc",
"active": true
@@ -174,10 +162,6 @@
"text": "artist:moriimee",
"active": true
},
{
"text": "in the style of cknc",
"active": true
},
{
"text": "__dummy_item__",
"active": false,
@@ -189,7 +173,7 @@
"_isDummy": true
}
],
"orinalMessage": "in the style of ck-rw,, aorun, scales, makeup, bare shoulders, pointy ears,, dress,, claws,, in the style of cksc,, artist:moriimee,, in the style of cknc",
"orinalMessage": "in the style of ck-rw,, in the style of cksc,, artist:moriimee",
"trigger_words": [
"56",
2
@@ -217,7 +201,7 @@
{
"name": "ck-nc-cyberpunk-IL-000011",
"strength": 0.4,
"active": true
"active": false
},
{
"name": "__dummy_item1__",
@@ -257,7 +241,7 @@
{
"name": "aorunIllstrious",
"strength": "0.90",
"active": true
"active": false
},
{
"name": "__dummy_item1__",

25
simple_test.py Normal file
View File

@@ -0,0 +1,25 @@
"""Smoke test: parse refs/prompt.json with WorkflowParser and print the result."""
import json
import traceback

from py.workflow.parser import WorkflowParser


def main() -> None:
    """Load the reference workflow, parse it, and print gen params and LoRAs."""
    # Load workflow data from the reference fixture.
    with open('refs/prompt.json', 'r', encoding='utf-8') as f:
        workflow_data = json.load(f)

    # Parse workflow
    parser = WorkflowParser()
    try:
        result = parser.parse_workflow(workflow_data)
        print("Parsing successful!")

        # Print each component separately for easy manual inspection.
        print("\nGeneration Parameters:")
        for k, v in result.get("gen_params", {}).items():
            print(f"  {k}: {v}")

        print("\nLoRAs:")
        print(result.get("loras", ""))
    except Exception as e:
        # Manual test script: report the failure with a full traceback
        # rather than letting the interpreter die with a bare exception.
        print(f"Error parsing workflow: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    main()

45
test_parser.py Normal file
View File

@@ -0,0 +1,45 @@
"""Manual test: trace LoRA nodes from the KSampler and parse the workflow."""
import json
import sys
import traceback

from py.workflow.parser import WorkflowParser
from py.workflow.utils import trace_model_path


def main() -> int:
    """Run the trace-and-parse test.

    Returns:
        Process exit code: 0 on success, 1 on failure.
    """
    # Load workflow data from the reference fixture.
    with open('refs/prompt.json', 'r', encoding='utf-8') as f:
        workflow_data = json.load(f)

    parser = WorkflowParser()
    try:
        # Find the KSampler node: the starting point for the model-path trace.
        ksampler_node = next(
            (node_id for node_id, node in workflow_data.items()
             if node.get("class_type") == "KSampler"),
            None,
        )
        if not ksampler_node:
            print("KSampler node not found")
            return 1

        # Trace all Lora nodes reachable from the sampler.
        print("Finding Lora nodes in the workflow...")
        lora_nodes = trace_model_path(workflow_data, ksampler_node)
        print(f"Found Lora nodes: {lora_nodes}")

        # Print node details so a reviewer can eyeball the traced inputs.
        for node_id in lora_nodes:
            node = workflow_data[node_id]
            print(f"\nNode {node_id}: {node.get('class_type')}")
            for key, value in node.get("inputs", {}).items():
                print(f"  - {key}: {value}")

        # Parse the whole workflow and dump the structured result.
        result = parser.parse_workflow(workflow_data)
        print("\nParsing successful!")
        print(json.dumps(result, indent=2))
        return 0
    except Exception as e:
        print(f"Error parsing workflow: {e}")
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())