Enhance LoraStackerMapper and WorkflowParser functionality

- Updated LoraStackerMapper to handle multiple formats for lora_stack input, improving flexibility in processing existing stacks.
- Introduced caching for processed node results in WorkflowParser to optimize performance and prevent redundant processing.
- Added a new method to collect loras from model inputs, enhancing the ability to extract relevant data from the workflow.
- Improved handling of processed nodes to avoid cycles and ensure accurate results during workflow parsing.
This commit is contained in:
Will Miao
2025-03-23 07:41:55 +08:00
parent 6aa2342be1
commit 8690a8f11a
3 changed files with 69 additions and 16 deletions

View File

@@ -777,8 +777,6 @@ class RecipeRoutes:
# Parse the workflow to extract generation parameters and loras
parsed_workflow = self.parser.parse_workflow(workflow_json)
logger.debug(f"Parsed workflow: {parsed_workflow}")
if not parsed_workflow or not parsed_workflow.get("gen_params"):
return web.json_response({"error": "Could not extract generation parameters from workflow"}, status=400)

View File

@@ -209,21 +209,28 @@ class LoraStackerMapper(NodeMapper):
def transform(self, inputs: Dict) -> Dict:
loras_data = inputs.get("loras", [])
existing_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
result_stack = []
# Handle existing stack entries
existing_stack = []
lora_stack_input = inputs.get("lora_stack", [])
# Handle different formats of lora_stack
if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input:
# Format from another LoraStacker node
existing_stack = lora_stack_input["lora_stack"]
elif isinstance(lora_stack_input, list):
# Direct list format or reference format [node_id, output_slot]
if len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and isinstance(lora_stack_input[1], int):
# This is likely a reference that was already processed
pass
else:
# Regular list of tuples/entries
existing_stack = lora_stack_input
# Add existing entries first
if existing_stack:
# Check if existing_stack is a reference to another node ([node_id, output_slot])
if isinstance(existing_stack, list) and len(existing_stack) == 2 and isinstance(existing_stack[0], (str, int)) and isinstance(existing_stack[1], int):
# This is a reference to another node, should already be processed
# So we'll need to extract the value from that node
if isinstance(inputs.get("lora_stack", {}), dict) and "lora_stack" in inputs["lora_stack"]:
# If we have the processed result, use it
result_stack.extend(inputs["lora_stack"]["lora_stack"])
elif isinstance(existing_stack, list):
# If it's a regular list (not a node reference), just add the entries
result_stack.extend(existing_stack)
result_stack.extend(existing_stack)
# Process loras array - filter active entries
# Check if loras_data is a list or a dict with __value__ key (new format)

View File

@@ -18,6 +18,7 @@ class WorkflowParser:
def __init__(self, load_extensions_on_init: bool = True):
"""Initialize the parser with mappers"""
self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles
self.node_results_cache: Dict[str, Any] = {} # Cache for processed node results
# Load extensions if requested
if load_extensions_on_init:
@@ -25,14 +26,19 @@ class WorkflowParser:
def process_node(self, node_id: str, workflow: Dict) -> Any:
"""Process a single node and extract relevant information"""
# Check if we've already processed this node to avoid cycles
# Return cached result if available
if node_id in self.node_results_cache:
return self.node_results_cache[node_id]
# Check if we're in a cycle
if node_id in self.processed_nodes:
return None
# Mark this node as processed
# Mark this node as being processed (to detect cycles)
self.processed_nodes.add(node_id)
if node_id not in workflow:
self.processed_nodes.remove(node_id)
return None
node_data = workflow[node_id]
@@ -43,6 +49,8 @@ class WorkflowParser:
if mapper:
try:
result = mapper.process(node_id, node_data, workflow, self)
# Cache the result
self.node_results_cache[node_id] = result
except Exception as e:
logger.error(f"Error processing node {node_id} of type {node_type}: {e}", exc_info=True)
# Return a partial result or None depending on how we want to handle errors
@@ -52,6 +60,33 @@ class WorkflowParser:
self.processed_nodes.remove(node_id)
return result
def collect_loras_from_model(self, model_input: List, workflow: Dict) -> str:
"""Collect loras information from the model node chain"""
if not isinstance(model_input, list) or len(model_input) != 2:
return ""
model_node_id, _ = model_input
# Convert node_id to string if it's an integer
if isinstance(model_node_id, int):
model_node_id = str(model_node_id)
# Process the model node
model_result = self.process_node(model_node_id, workflow)
# If this is a Lora Loader node, return the loras text
if model_result and isinstance(model_result, dict) and "loras" in model_result:
return model_result["loras"]
# If not a lora loader, check the node's inputs for a model connection
node_data = workflow.get(model_node_id, {})
inputs = node_data.get("inputs", {})
# If this node has a model input, follow that path
if "model" in inputs and isinstance(inputs["model"], list):
return self.collect_loras_from_model(inputs["model"], workflow)
return ""
def parse_workflow(self, workflow_data: Union[str, Dict], output_path: Optional[str] = None) -> Dict:
"""
Parse the workflow and extract generation parameters
@@ -69,8 +104,9 @@ class WorkflowParser:
else:
workflow = workflow_data
# Reset the processed nodes tracker
# Reset the processed nodes tracker and cache
self.processed_nodes = set()
self.node_results_cache = {}
# Find the KSampler node
ksampler_node_id = find_node_by_type(workflow, "KSampler")
@@ -117,6 +153,18 @@ class WorkflowParser:
if "guidance" in node_inputs:
result["gen_params"]["guidance"] = node_inputs["guidance"]
# Extract loras from the model input of KSampler
ksampler_node = workflow.get(ksampler_node_id, {})
ksampler_inputs = ksampler_node.get("inputs", {})
if "model" in ksampler_inputs and isinstance(ksampler_inputs["model"], list):
loras_text = self.collect_loras_from_model(ksampler_inputs["model"], workflow)
if loras_text:
result["loras"] = loras_text
# Handle standard ComfyUI names vs our output format
if "cfg" in result["gen_params"]:
result["gen_params"]["cfg_scale"] = result["gen_params"].pop("cfg")
# Add clip_skip = 2 to match reference output if not already present
if "clip_skip" not in result["gen_params"]:
result["gen_params"]["clip_skip"] = "2"