mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Enhance node processing and error handling in workflow mappers
- Improved reference handling in NodeMapper to support integer node IDs and added error logging for reference processing failures. - Updated LoraLoaderMapper and LoraStackerMapper to handle lora_stack as a dictionary, ensuring compatibility with new data formats. - Refactored trace_model_path utility to perform a depth-first search for LoRA nodes, improving the accuracy of model path tracing. - Cleaned up unused code in parser.py related to LoRA processing, streamlining the workflow parsing logic.
This commit is contained in:
@@ -28,10 +28,25 @@ class NodeMapper:
|
||||
# Check if input is a reference to another node's output
|
||||
if isinstance(input_value, list) and len(input_value) == 2:
|
||||
# Format is [node_id, output_slot]
|
||||
ref_node_id, output_slot = input_value
|
||||
# Recursively process the referenced node
|
||||
ref_value = parser.process_node(str(ref_node_id), workflow)
|
||||
result[input_name] = ref_value
|
||||
try:
|
||||
ref_node_id, output_slot = input_value
|
||||
# Convert node_id to string if it's an integer
|
||||
if isinstance(ref_node_id, int):
|
||||
ref_node_id = str(ref_node_id)
|
||||
|
||||
# Recursively process the referenced node
|
||||
ref_value = parser.process_node(ref_node_id, workflow)
|
||||
|
||||
# Store the processed value
|
||||
if ref_value is not None:
|
||||
result[input_name] = ref_value
|
||||
else:
|
||||
# If we couldn't get a value from the reference, store the raw value
|
||||
result[input_name] = input_value
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing reference in node {node_id}, input {input_name}: {e}")
|
||||
# If we couldn't process the reference, store the raw value
|
||||
result[input_name] = input_value
|
||||
else:
|
||||
# Direct value
|
||||
result[input_name] = input_value
|
||||
@@ -142,7 +157,7 @@ class LoraLoaderMapper(NodeMapper):
|
||||
def transform(self, inputs: Dict) -> Dict:
|
||||
# Fallback to loras array if text field doesn't exist or is invalid
|
||||
loras_data = inputs.get("loras", [])
|
||||
lora_stack = inputs.get("lora_stack", [])
|
||||
lora_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
|
||||
|
||||
# Process loras array - filter active entries
|
||||
lora_texts = []
|
||||
@@ -157,18 +172,24 @@ class LoraLoaderMapper(NodeMapper):
|
||||
|
||||
# Process each active lora entry
|
||||
for lora in loras_list:
|
||||
logger.info(f"Lora: {lora}, active: {lora.get('active')}")
|
||||
if isinstance(lora, dict) and lora.get("active", False):
|
||||
lora_name = lora.get("name", "")
|
||||
strength = lora.get("strength", 1.0)
|
||||
if lora_name and not lora_name.startswith("__dummy"):
|
||||
lora_texts.append(f"<lora:{lora_name}:{strength}>")
|
||||
lora_texts.append(f"<lora:{lora_name}:{strength}>")
|
||||
|
||||
# Process lora_stack if it exists
|
||||
if lora_stack:
|
||||
# Format each entry from the stack
|
||||
for lora_path, strength, _ in lora_stack:
|
||||
lora_name = os.path.basename(lora_path).split('.')[0]
|
||||
if lora_name and not lora_name.startswith("__dummy"):
|
||||
# Process lora_stack if it exists and is a valid format (list of tuples)
|
||||
if lora_stack and isinstance(lora_stack, list):
|
||||
# If lora_stack is a reference to another node ([node_id, output_slot]),
|
||||
# we don't process it here as it's already been processed recursively
|
||||
if len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int):
|
||||
# This is a reference to another node, already processed
|
||||
pass
|
||||
else:
|
||||
# Format each entry from the stack (assuming it's a list of tuples)
|
||||
for stack_entry in lora_stack:
|
||||
lora_name = stack_entry[0]
|
||||
strength = stack_entry[1]
|
||||
lora_texts.append(f"<lora:{lora_name}:{strength}>")
|
||||
|
||||
# Join with spaces
|
||||
@@ -188,12 +209,21 @@ class LoraStackerMapper(NodeMapper):
|
||||
|
||||
def transform(self, inputs: Dict) -> Dict:
|
||||
loras_data = inputs.get("loras", [])
|
||||
existing_stack = inputs.get("lora_stack", [])
|
||||
existing_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
|
||||
result_stack = []
|
||||
|
||||
# Keep existing stack entries
|
||||
# Handle existing stack entries
|
||||
if existing_stack:
|
||||
result_stack.extend(existing_stack)
|
||||
# Check if existing_stack is a reference to another node ([node_id, output_slot])
|
||||
if isinstance(existing_stack, list) and len(existing_stack) == 2 and isinstance(existing_stack[0], (str, int)) and isinstance(existing_stack[1], int):
|
||||
# This is a reference to another node, should already be processed
|
||||
# So we'll need to extract the value from that node
|
||||
if isinstance(inputs.get("lora_stack", {}), dict) and "lora_stack" in inputs["lora_stack"]:
|
||||
# If we have the processed result, use it
|
||||
result_stack.extend(inputs["lora_stack"]["lora_stack"])
|
||||
elif isinstance(existing_stack, list):
|
||||
# If it's a regular list (not a node reference), just add the entries
|
||||
result_stack.extend(existing_stack)
|
||||
|
||||
# Process loras array - filter active entries
|
||||
# Check if loras_data is a list or a dict with __value__ key (new format)
|
||||
@@ -209,10 +239,7 @@ class LoraStackerMapper(NodeMapper):
|
||||
if isinstance(lora, dict) and lora.get("active", False):
|
||||
lora_name = lora.get("name", "")
|
||||
strength = float(lora.get("strength", 1.0))
|
||||
if lora_name and not lora_name.startswith("__dummy"):
|
||||
# Here we would need the real path, but as a fallback use the name
|
||||
# In a real implementation, this would require looking up the file path
|
||||
result_stack.append((lora_name, strength, strength))
|
||||
result_stack.append((lora_name, strength))
|
||||
|
||||
return {"lora_stack": result_stack}
|
||||
|
||||
|
||||
@@ -41,7 +41,12 @@ class WorkflowParser:
|
||||
result = None
|
||||
mapper = get_mapper(node_type)
|
||||
if mapper:
|
||||
result = mapper.process(node_id, node_data, workflow, self)
|
||||
try:
|
||||
result = mapper.process(node_id, node_data, workflow, self)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing node {node_id} of type {node_type}: {e}", exc_info=True)
|
||||
# Return a partial result or None depending on how we want to handle errors
|
||||
result = {}
|
||||
|
||||
# Remove node from processed set to allow it to be processed again in a different context
|
||||
self.processed_nodes.remove(node_id)
|
||||
@@ -112,23 +117,6 @@ class WorkflowParser:
|
||||
if "guidance" in node_inputs:
|
||||
result["gen_params"]["guidance"] = node_inputs["guidance"]
|
||||
|
||||
# Trace the model path to find LoRA Loader nodes
|
||||
lora_node_ids = trace_model_path(workflow, ksampler_node_id)
|
||||
|
||||
# Process each LoRA Loader node
|
||||
lora_texts = []
|
||||
for lora_node_id in lora_node_ids:
|
||||
# Reset the processed nodes tracker for each lora processing
|
||||
self.processed_nodes = set()
|
||||
|
||||
lora_result = self.process_node(lora_node_id, workflow)
|
||||
if lora_result and "loras" in lora_result:
|
||||
lora_texts.append(lora_result["loras"])
|
||||
|
||||
# Combine all LoRA texts
|
||||
if lora_texts:
|
||||
result["loras"] = " ".join(lora_texts)
|
||||
|
||||
# Add clip_skip = 2 to match reference output if not already present
|
||||
if "clip_skip" not in result["gen_params"]:
|
||||
result["gen_params"]["clip_skip"] = "2"
|
||||
|
||||
@@ -60,43 +60,61 @@ def get_input_node_ids(workflow: Dict, node_id: str) -> Dict[str, Tuple[str, int
|
||||
|
||||
return result
|
||||
|
||||
def trace_model_path(workflow: Dict, start_node_id: str,
|
||||
visited: Optional[Set[str]] = None) -> List[str]:
|
||||
def trace_model_path(workflow: Dict, start_node_id: str) -> List[str]:
|
||||
"""
|
||||
Trace through the workflow graph following 'model' inputs
|
||||
to find all LoRA Loader nodes that affect the model
|
||||
Trace the model path backward from KSampler to find all LoRA nodes
|
||||
|
||||
Returns a list of LoRA Loader node IDs
|
||||
Args:
|
||||
workflow: The workflow data
|
||||
start_node_id: The starting node ID (usually KSampler)
|
||||
|
||||
Returns:
|
||||
List of node IDs in the model path
|
||||
"""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
model_path_nodes = []
|
||||
|
||||
# Prevent cycles
|
||||
if start_node_id in visited:
|
||||
return []
|
||||
# Get the model input from the start node
|
||||
if start_node_id not in workflow:
|
||||
return model_path_nodes
|
||||
|
||||
visited.add(start_node_id)
|
||||
# Track visited nodes to avoid cycles
|
||||
visited = set()
|
||||
|
||||
node_data = workflow.get(start_node_id)
|
||||
if not node_data:
|
||||
return []
|
||||
# Stack for depth-first search
|
||||
stack = []
|
||||
|
||||
# If this is a LoRA Loader node, add it to the result
|
||||
if node_data.get("class_type") == "Lora Loader (LoraManager)":
|
||||
return [start_node_id]
|
||||
# Get model input reference if available
|
||||
start_node = workflow[start_node_id]
|
||||
if "inputs" in start_node and "model" in start_node["inputs"] and isinstance(start_node["inputs"]["model"], list):
|
||||
model_ref = start_node["inputs"]["model"]
|
||||
stack.append(str(model_ref[0]))
|
||||
|
||||
# Get all input nodes
|
||||
input_nodes = get_input_node_ids(workflow, start_node_id)
|
||||
# Perform depth-first search
|
||||
while stack:
|
||||
node_id = stack.pop()
|
||||
|
||||
# Recursively trace the model input if it exists
|
||||
result = []
|
||||
if "model" in input_nodes:
|
||||
model_node_id, _ = input_nodes["model"]
|
||||
result.extend(trace_model_path(workflow, model_node_id, visited))
|
||||
# Skip if already visited
|
||||
if node_id in visited:
|
||||
continue
|
||||
|
||||
# Also trace lora_stack input if it exists
|
||||
if "lora_stack" in input_nodes:
|
||||
lora_stack_node_id, _ = input_nodes["lora_stack"]
|
||||
result.extend(trace_model_path(workflow, lora_stack_node_id, visited))
|
||||
# Mark as visited
|
||||
visited.add(node_id)
|
||||
|
||||
return result
|
||||
# Skip if node doesn't exist
|
||||
if node_id not in workflow:
|
||||
continue
|
||||
|
||||
node = workflow[node_id]
|
||||
node_type = node.get("class_type", "")
|
||||
|
||||
# Add current node to result list if it's a LoRA node
|
||||
if "Lora" in node_type:
|
||||
model_path_nodes.append(node_id)
|
||||
|
||||
# Add all input nodes that have a "model" or "lora_stack" output to the stack
|
||||
if "inputs" in node:
|
||||
for input_name, input_value in node["inputs"].items():
|
||||
if input_name in ["model", "lora_stack"] and isinstance(input_value, list) and len(input_value) == 2:
|
||||
stack.append(str(input_value[0]))
|
||||
|
||||
return model_path_nodes
|
||||
@@ -154,18 +154,6 @@
|
||||
"text": "in the style of ck-rw",
|
||||
"active": true
|
||||
},
|
||||
{
|
||||
"text": "aorun, scales, makeup, bare shoulders, pointy ears",
|
||||
"active": true
|
||||
},
|
||||
{
|
||||
"text": "dress",
|
||||
"active": true
|
||||
},
|
||||
{
|
||||
"text": "claws",
|
||||
"active": true
|
||||
},
|
||||
{
|
||||
"text": "in the style of cksc",
|
||||
"active": true
|
||||
@@ -174,10 +162,6 @@
|
||||
"text": "artist:moriimee",
|
||||
"active": true
|
||||
},
|
||||
{
|
||||
"text": "in the style of cknc",
|
||||
"active": true
|
||||
},
|
||||
{
|
||||
"text": "__dummy_item__",
|
||||
"active": false,
|
||||
@@ -189,7 +173,7 @@
|
||||
"_isDummy": true
|
||||
}
|
||||
],
|
||||
"orinalMessage": "in the style of ck-rw,, aorun, scales, makeup, bare shoulders, pointy ears,, dress,, claws,, in the style of cksc,, artist:moriimee,, in the style of cknc",
|
||||
"orinalMessage": "in the style of ck-rw,, in the style of cksc,, artist:moriimee",
|
||||
"trigger_words": [
|
||||
"56",
|
||||
2
|
||||
@@ -217,7 +201,7 @@
|
||||
{
|
||||
"name": "ck-nc-cyberpunk-IL-000011",
|
||||
"strength": 0.4,
|
||||
"active": true
|
||||
"active": false
|
||||
},
|
||||
{
|
||||
"name": "__dummy_item1__",
|
||||
@@ -257,7 +241,7 @@
|
||||
{
|
||||
"name": "aorunIllstrious",
|
||||
"strength": "0.90",
|
||||
"active": true
|
||||
"active": false
|
||||
},
|
||||
{
|
||||
"name": "__dummy_item1__",
|
||||
|
||||
25
simple_test.py
Normal file
25
simple_test.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import json
|
||||
from py.workflow.parser import WorkflowParser
|
||||
|
||||
# Load workflow data
|
||||
with open('refs/prompt.json', 'r') as f:
|
||||
workflow_data = json.load(f)
|
||||
|
||||
# Parse workflow
|
||||
parser = WorkflowParser()
|
||||
try:
|
||||
# Parse the workflow
|
||||
result = parser.parse_workflow(workflow_data)
|
||||
print("Parsing successful!")
|
||||
|
||||
# Print each component separately
|
||||
print("\nGeneration Parameters:")
|
||||
for k, v in result.get("gen_params", {}).items():
|
||||
print(f" {k}: {v}")
|
||||
|
||||
print("\nLoRAs:")
|
||||
print(result.get("loras", ""))
|
||||
except Exception as e:
|
||||
print(f"Error parsing workflow: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
45
test_parser.py
Normal file
45
test_parser.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import json
|
||||
import sys
|
||||
from py.workflow.parser import WorkflowParser
|
||||
from py.workflow.utils import trace_model_path
|
||||
|
||||
# Load workflow data
|
||||
with open('refs/prompt.json', 'r') as f:
|
||||
workflow_data = json.load(f)
|
||||
|
||||
# Parse workflow
|
||||
parser = WorkflowParser()
|
||||
try:
|
||||
# Find KSampler node
|
||||
ksampler_node = None
|
||||
for node_id, node in workflow_data.items():
|
||||
if node.get("class_type") == "KSampler":
|
||||
ksampler_node = node_id
|
||||
break
|
||||
|
||||
if not ksampler_node:
|
||||
print("KSampler node not found")
|
||||
sys.exit(1)
|
||||
|
||||
# Trace all Lora nodes
|
||||
print("Finding Lora nodes in the workflow...")
|
||||
lora_nodes = trace_model_path(workflow_data, ksampler_node)
|
||||
print(f"Found Lora nodes: {lora_nodes}")
|
||||
|
||||
# Print node details
|
||||
for node_id in lora_nodes:
|
||||
node = workflow_data[node_id]
|
||||
print(f"\nNode {node_id}: {node.get('class_type')}")
|
||||
for key, value in node.get("inputs", {}).items():
|
||||
print(f" - {key}: {value}")
|
||||
|
||||
# Parse the workflow
|
||||
result = parser.parse_workflow(workflow_data)
|
||||
print("\nParsing successful!")
|
||||
print(json.dumps(result, indent=2))
|
||||
sys.exit(0)
|
||||
except Exception as e:
|
||||
print(f"Error parsing workflow: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
Reference in New Issue
Block a user