mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Enhance LoraStackerMapper and WorkflowParser functionality
- Updated LoraStackerMapper to handle multiple formats for lora_stack input, improving flexibility in processing existing stacks. - Introduced caching for processed node results in WorkflowParser to optimize performance and prevent redundant processing. - Added a new method to collect loras from model inputs, enhancing the ability to extract relevant data from the workflow. - Improved handling of processed nodes to avoid cycles and ensure accurate results during workflow parsing.
This commit is contained in:
@@ -777,8 +777,6 @@ class RecipeRoutes:
|
||||
# Parse the workflow to extract generation parameters and loras
|
||||
parsed_workflow = self.parser.parse_workflow(workflow_json)
|
||||
|
||||
logger.debug(f"Parsed workflow: {parsed_workflow}")
|
||||
|
||||
if not parsed_workflow or not parsed_workflow.get("gen_params"):
|
||||
return web.json_response({"error": "Could not extract generation parameters from workflow"}, status=400)
|
||||
|
||||
|
||||
@@ -209,21 +209,28 @@ class LoraStackerMapper(NodeMapper):
|
||||
|
||||
def transform(self, inputs: Dict) -> Dict:
|
||||
loras_data = inputs.get("loras", [])
|
||||
existing_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
|
||||
result_stack = []
|
||||
|
||||
# Handle existing stack entries
|
||||
existing_stack = []
|
||||
lora_stack_input = inputs.get("lora_stack", [])
|
||||
|
||||
# Handle different formats of lora_stack
|
||||
if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input:
|
||||
# Format from another LoraStacker node
|
||||
existing_stack = lora_stack_input["lora_stack"]
|
||||
elif isinstance(lora_stack_input, list):
|
||||
# Direct list format or reference format [node_id, output_slot]
|
||||
if len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and isinstance(lora_stack_input[1], int):
|
||||
# This is likely a reference that was already processed
|
||||
pass
|
||||
else:
|
||||
# Regular list of tuples/entries
|
||||
existing_stack = lora_stack_input
|
||||
|
||||
# Add existing entries first
|
||||
if existing_stack:
|
||||
# Check if existing_stack is a reference to another node ([node_id, output_slot])
|
||||
if isinstance(existing_stack, list) and len(existing_stack) == 2 and isinstance(existing_stack[0], (str, int)) and isinstance(existing_stack[1], int):
|
||||
# This is a reference to another node, should already be processed
|
||||
# So we'll need to extract the value from that node
|
||||
if isinstance(inputs.get("lora_stack", {}), dict) and "lora_stack" in inputs["lora_stack"]:
|
||||
# If we have the processed result, use it
|
||||
result_stack.extend(inputs["lora_stack"]["lora_stack"])
|
||||
elif isinstance(existing_stack, list):
|
||||
# If it's a regular list (not a node reference), just add the entries
|
||||
result_stack.extend(existing_stack)
|
||||
result_stack.extend(existing_stack)
|
||||
|
||||
# Process loras array - filter active entries
|
||||
# Check if loras_data is a list or a dict with __value__ key (new format)
|
||||
|
||||
@@ -18,6 +18,7 @@ class WorkflowParser:
|
||||
def __init__(self, load_extensions_on_init: bool = True):
|
||||
"""Initialize the parser with mappers"""
|
||||
self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles
|
||||
self.node_results_cache: Dict[str, Any] = {} # Cache for processed node results
|
||||
|
||||
# Load extensions if requested
|
||||
if load_extensions_on_init:
|
||||
@@ -25,14 +26,19 @@ class WorkflowParser:
|
||||
|
||||
def process_node(self, node_id: str, workflow: Dict) -> Any:
|
||||
"""Process a single node and extract relevant information"""
|
||||
# Check if we've already processed this node to avoid cycles
|
||||
# Return cached result if available
|
||||
if node_id in self.node_results_cache:
|
||||
return self.node_results_cache[node_id]
|
||||
|
||||
# Check if we're in a cycle
|
||||
if node_id in self.processed_nodes:
|
||||
return None
|
||||
|
||||
# Mark this node as processed
|
||||
# Mark this node as being processed (to detect cycles)
|
||||
self.processed_nodes.add(node_id)
|
||||
|
||||
if node_id not in workflow:
|
||||
self.processed_nodes.remove(node_id)
|
||||
return None
|
||||
|
||||
node_data = workflow[node_id]
|
||||
@@ -43,6 +49,8 @@ class WorkflowParser:
|
||||
if mapper:
|
||||
try:
|
||||
result = mapper.process(node_id, node_data, workflow, self)
|
||||
# Cache the result
|
||||
self.node_results_cache[node_id] = result
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing node {node_id} of type {node_type}: {e}", exc_info=True)
|
||||
# Return a partial result or None depending on how we want to handle errors
|
||||
@@ -52,6 +60,33 @@ class WorkflowParser:
|
||||
self.processed_nodes.remove(node_id)
|
||||
return result
|
||||
|
||||
def collect_loras_from_model(self, model_input: List, workflow: Dict) -> str:
|
||||
"""Collect loras information from the model node chain"""
|
||||
if not isinstance(model_input, list) or len(model_input) != 2:
|
||||
return ""
|
||||
|
||||
model_node_id, _ = model_input
|
||||
# Convert node_id to string if it's an integer
|
||||
if isinstance(model_node_id, int):
|
||||
model_node_id = str(model_node_id)
|
||||
|
||||
# Process the model node
|
||||
model_result = self.process_node(model_node_id, workflow)
|
||||
|
||||
# If this is a Lora Loader node, return the loras text
|
||||
if model_result and isinstance(model_result, dict) and "loras" in model_result:
|
||||
return model_result["loras"]
|
||||
|
||||
# If not a lora loader, check the node's inputs for a model connection
|
||||
node_data = workflow.get(model_node_id, {})
|
||||
inputs = node_data.get("inputs", {})
|
||||
|
||||
# If this node has a model input, follow that path
|
||||
if "model" in inputs and isinstance(inputs["model"], list):
|
||||
return self.collect_loras_from_model(inputs["model"], workflow)
|
||||
|
||||
return ""
|
||||
|
||||
def parse_workflow(self, workflow_data: Union[str, Dict], output_path: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Parse the workflow and extract generation parameters
|
||||
@@ -69,8 +104,9 @@ class WorkflowParser:
|
||||
else:
|
||||
workflow = workflow_data
|
||||
|
||||
# Reset the processed nodes tracker
|
||||
# Reset the processed nodes tracker and cache
|
||||
self.processed_nodes = set()
|
||||
self.node_results_cache = {}
|
||||
|
||||
# Find the KSampler node
|
||||
ksampler_node_id = find_node_by_type(workflow, "KSampler")
|
||||
@@ -117,6 +153,18 @@ class WorkflowParser:
|
||||
if "guidance" in node_inputs:
|
||||
result["gen_params"]["guidance"] = node_inputs["guidance"]
|
||||
|
||||
# Extract loras from the model input of KSampler
|
||||
ksampler_node = workflow.get(ksampler_node_id, {})
|
||||
ksampler_inputs = ksampler_node.get("inputs", {})
|
||||
if "model" in ksampler_inputs and isinstance(ksampler_inputs["model"], list):
|
||||
loras_text = self.collect_loras_from_model(ksampler_inputs["model"], workflow)
|
||||
if loras_text:
|
||||
result["loras"] = loras_text
|
||||
|
||||
# Handle standard ComfyUI names vs our output format
|
||||
if "cfg" in result["gen_params"]:
|
||||
result["gen_params"]["cfg_scale"] = result["gen_params"].pop("cfg")
|
||||
|
||||
# Add clip_skip = 2 to match reference output if not already present
|
||||
if "clip_skip" not in result["gen_params"]:
|
||||
result["gen_params"]["clip_skip"] = "2"
|
||||
|
||||
Reference in New Issue
Block a user