From 8690a8f11aaa726843568ae06511a62c6885f199 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sun, 23 Mar 2025 07:41:55 +0800 Subject: [PATCH] Enhance LoraStackerMapper and WorkflowParser functionality - Updated LoraStackerMapper to handle multiple formats for lora_stack input, improving flexibility in processing existing stacks. - Introduced caching for processed node results in WorkflowParser to optimize performance and prevent redundant processing. - Added a new method to collect loras from model inputs, enhancing the ability to extract relevant data from the workflow. - Improved handling of processed nodes to avoid cycles and ensure accurate results during workflow parsing. --- py/routes/recipe_routes.py | 2 -- py/workflow/mappers.py | 29 ++++++++++++-------- py/workflow/parser.py | 54 +++++++++++++++++++++++++++++++++++--- 3 files changed, 69 insertions(+), 16 deletions(-) diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 6604e51b..dc56c7b1 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -777,8 +777,6 @@ class RecipeRoutes: # Parse the workflow to extract generation parameters and loras parsed_workflow = self.parser.parse_workflow(workflow_json) - logger.debug(f"Parsed workflow: {parsed_workflow}") - if not parsed_workflow or not parsed_workflow.get("gen_params"): return web.json_response({"error": "Could not extract generation parameters from workflow"}, status=400) diff --git a/py/workflow/mappers.py b/py/workflow/mappers.py index aca0285f..0cdf0e0a 100644 --- a/py/workflow/mappers.py +++ b/py/workflow/mappers.py @@ -209,21 +209,28 @@ class LoraStackerMapper(NodeMapper): def transform(self, inputs: Dict) -> Dict: loras_data = inputs.get("loras", []) - existing_stack = inputs.get("lora_stack", {}).get("lora_stack", []) result_stack = [] # Handle existing stack entries + existing_stack = [] + lora_stack_input = inputs.get("lora_stack", []) + + # Handle different formats of lora_stack + if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input: + # Format from another LoraStacker node + existing_stack = lora_stack_input["lora_stack"] + elif isinstance(lora_stack_input, list): + # Direct list format or reference format [node_id, output_slot] + if len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and isinstance(lora_stack_input[1], int): + # This is likely a reference that was already processed + pass + else: + # Regular list of tuples/entries + existing_stack = lora_stack_input + + # Add existing entries first if existing_stack: - # Check if existing_stack is a reference to another node ([node_id, output_slot]) - if isinstance(existing_stack, list) and len(existing_stack) == 2 and isinstance(existing_stack[0], (str, int)) and isinstance(existing_stack[1], int): - # This is a reference to another node, should already be processed - # So we'll need to extract the value from that node - if isinstance(inputs.get("lora_stack", {}), dict) and "lora_stack" in inputs["lora_stack"]: - # If we have the processed result, use it - result_stack.extend(inputs["lora_stack"]["lora_stack"]) - elif isinstance(existing_stack, list): - # If it's a regular list (not a node reference), just add the entries - result_stack.extend(existing_stack) + result_stack.extend(existing_stack) # Process loras array - filter active entries # Check if loras_data is a list or a dict with __value__ key (new format) diff --git a/py/workflow/parser.py b/py/workflow/parser.py index 92ca8323..70b40edd 100644 --- a/py/workflow/parser.py +++ b/py/workflow/parser.py @@ -18,6 +18,7 @@ class WorkflowParser: def __init__(self, load_extensions_on_init: bool = True): """Initialize the parser with mappers""" self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles + self.node_results_cache: Dict[str, Any] = {} # Cache for processed node results # Load extensions if requested if load_extensions_on_init: @@ -25,14 +26,19 @@ class WorkflowParser: def process_node(self, node_id: str, workflow: Dict) -> Any: """Process a single node and extract relevant information""" - # Check if we've already processed this node to avoid cycles + # Return cached result if available + if node_id in self.node_results_cache: + return self.node_results_cache[node_id] + + # Check if we're in a cycle if node_id in self.processed_nodes: return None - # Mark this node as processed + # Mark this node as being processed (to detect cycles) self.processed_nodes.add(node_id) if node_id not in workflow: + self.processed_nodes.remove(node_id) return None node_data = workflow[node_id] @@ -43,6 +49,8 @@ class WorkflowParser: if mapper: try: result = mapper.process(node_id, node_data, workflow, self) + # Cache the result + self.node_results_cache[node_id] = result except Exception as e: logger.error(f"Error processing node {node_id} of type {node_type}: {e}", exc_info=True) # Return a partial result or None depending on how we want to handle errors @@ -52,6 +60,33 @@ class WorkflowParser: self.processed_nodes.remove(node_id) return result + def collect_loras_from_model(self, model_input: List, workflow: Dict) -> str: + """Collect loras information from the model node chain""" + if not isinstance(model_input, list) or len(model_input) != 2: + return "" + + model_node_id, _ = model_input + # Convert node_id to string if it's an integer + if isinstance(model_node_id, int): + model_node_id = str(model_node_id) + + # Process the model node + model_result = self.process_node(model_node_id, workflow) + + # If this is a Lora Loader node, return the loras text + if model_result and isinstance(model_result, dict) and "loras" in model_result: + return model_result["loras"] + + # If not a lora loader, check the node's inputs for a model connection + node_data = workflow.get(model_node_id, {}) + inputs = node_data.get("inputs", {}) + + # If this node has a model input, follow that path + if "model" in inputs and isinstance(inputs["model"], list): + return self.collect_loras_from_model(inputs["model"], workflow) + + return "" + def parse_workflow(self, workflow_data: Union[str, Dict], output_path: Optional[str] = None) -> Dict: """ Parse the workflow and extract generation parameters @@ -69,8 +104,9 @@ class WorkflowParser: else: workflow = workflow_data - # Reset the processed nodes tracker + # Reset the processed nodes tracker and cache self.processed_nodes = set() + self.node_results_cache = {} # Find the KSampler node ksampler_node_id = find_node_by_type(workflow, "KSampler") @@ -117,6 +153,18 @@ class WorkflowParser: if "guidance" in node_inputs: result["gen_params"]["guidance"] = node_inputs["guidance"] + # Extract loras from the model input of KSampler + ksampler_node = workflow.get(ksampler_node_id, {}) + ksampler_inputs = ksampler_node.get("inputs", {}) + if "model" in ksampler_inputs and isinstance(ksampler_inputs["model"], list): + loras_text = self.collect_loras_from_model(ksampler_inputs["model"], workflow) + if loras_text: + result["loras"] = loras_text + + # Handle standard ComfyUI names vs our output format + if "cfg" in result["gen_params"]: + result["gen_params"]["cfg_scale"] = result["gen_params"].pop("cfg") + # Add clip_skip = 2 to match reference output if not already present if "clip_skip" not in result["gen_params"]: result["gen_params"]["clip_skip"] = "2"