mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
Enhance LoraStackerMapper and WorkflowParser functionality
- Updated LoraStackerMapper to handle multiple formats for lora_stack input, improving flexibility in processing existing stacks. - Introduced caching for processed node results in WorkflowParser to optimize performance and prevent redundant processing. - Added a new method to collect loras from model inputs, enhancing the ability to extract relevant data from the workflow. - Improved handling of processed nodes to avoid cycles and ensure accurate results during workflow parsing.
This commit is contained in:
@@ -777,8 +777,6 @@ class RecipeRoutes:
|
|||||||
# Parse the workflow to extract generation parameters and loras
|
# Parse the workflow to extract generation parameters and loras
|
||||||
parsed_workflow = self.parser.parse_workflow(workflow_json)
|
parsed_workflow = self.parser.parse_workflow(workflow_json)
|
||||||
|
|
||||||
logger.debug(f"Parsed workflow: {parsed_workflow}")
|
|
||||||
|
|
||||||
if not parsed_workflow or not parsed_workflow.get("gen_params"):
|
if not parsed_workflow or not parsed_workflow.get("gen_params"):
|
||||||
return web.json_response({"error": "Could not extract generation parameters from workflow"}, status=400)
|
return web.json_response({"error": "Could not extract generation parameters from workflow"}, status=400)
|
||||||
|
|
||||||
|
|||||||
@@ -209,21 +209,28 @@ class LoraStackerMapper(NodeMapper):
|
|||||||
|
|
||||||
def transform(self, inputs: Dict) -> Dict:
|
def transform(self, inputs: Dict) -> Dict:
|
||||||
loras_data = inputs.get("loras", [])
|
loras_data = inputs.get("loras", [])
|
||||||
existing_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
|
|
||||||
result_stack = []
|
result_stack = []
|
||||||
|
|
||||||
# Handle existing stack entries
|
# Handle existing stack entries
|
||||||
|
existing_stack = []
|
||||||
|
lora_stack_input = inputs.get("lora_stack", [])
|
||||||
|
|
||||||
|
# Handle different formats of lora_stack
|
||||||
|
if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input:
|
||||||
|
# Format from another LoraStacker node
|
||||||
|
existing_stack = lora_stack_input["lora_stack"]
|
||||||
|
elif isinstance(lora_stack_input, list):
|
||||||
|
# Direct list format or reference format [node_id, output_slot]
|
||||||
|
if len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and isinstance(lora_stack_input[1], int):
|
||||||
|
# This is likely a reference that was already processed
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Regular list of tuples/entries
|
||||||
|
existing_stack = lora_stack_input
|
||||||
|
|
||||||
|
# Add existing entries first
|
||||||
if existing_stack:
|
if existing_stack:
|
||||||
# Check if existing_stack is a reference to another node ([node_id, output_slot])
|
result_stack.extend(existing_stack)
|
||||||
if isinstance(existing_stack, list) and len(existing_stack) == 2 and isinstance(existing_stack[0], (str, int)) and isinstance(existing_stack[1], int):
|
|
||||||
# This is a reference to another node, should already be processed
|
|
||||||
# So we'll need to extract the value from that node
|
|
||||||
if isinstance(inputs.get("lora_stack", {}), dict) and "lora_stack" in inputs["lora_stack"]:
|
|
||||||
# If we have the processed result, use it
|
|
||||||
result_stack.extend(inputs["lora_stack"]["lora_stack"])
|
|
||||||
elif isinstance(existing_stack, list):
|
|
||||||
# If it's a regular list (not a node reference), just add the entries
|
|
||||||
result_stack.extend(existing_stack)
|
|
||||||
|
|
||||||
# Process loras array - filter active entries
|
# Process loras array - filter active entries
|
||||||
# Check if loras_data is a list or a dict with __value__ key (new format)
|
# Check if loras_data is a list or a dict with __value__ key (new format)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ class WorkflowParser:
|
|||||||
def __init__(self, load_extensions_on_init: bool = True):
|
def __init__(self, load_extensions_on_init: bool = True):
|
||||||
"""Initialize the parser with mappers"""
|
"""Initialize the parser with mappers"""
|
||||||
self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles
|
self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles
|
||||||
|
self.node_results_cache: Dict[str, Any] = {} # Cache for processed node results
|
||||||
|
|
||||||
# Load extensions if requested
|
# Load extensions if requested
|
||||||
if load_extensions_on_init:
|
if load_extensions_on_init:
|
||||||
@@ -25,14 +26,19 @@ class WorkflowParser:
|
|||||||
|
|
||||||
def process_node(self, node_id: str, workflow: Dict) -> Any:
|
def process_node(self, node_id: str, workflow: Dict) -> Any:
|
||||||
"""Process a single node and extract relevant information"""
|
"""Process a single node and extract relevant information"""
|
||||||
# Check if we've already processed this node to avoid cycles
|
# Return cached result if available
|
||||||
|
if node_id in self.node_results_cache:
|
||||||
|
return self.node_results_cache[node_id]
|
||||||
|
|
||||||
|
# Check if we're in a cycle
|
||||||
if node_id in self.processed_nodes:
|
if node_id in self.processed_nodes:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Mark this node as processed
|
# Mark this node as being processed (to detect cycles)
|
||||||
self.processed_nodes.add(node_id)
|
self.processed_nodes.add(node_id)
|
||||||
|
|
||||||
if node_id not in workflow:
|
if node_id not in workflow:
|
||||||
|
self.processed_nodes.remove(node_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
node_data = workflow[node_id]
|
node_data = workflow[node_id]
|
||||||
@@ -43,6 +49,8 @@ class WorkflowParser:
|
|||||||
if mapper:
|
if mapper:
|
||||||
try:
|
try:
|
||||||
result = mapper.process(node_id, node_data, workflow, self)
|
result = mapper.process(node_id, node_data, workflow, self)
|
||||||
|
# Cache the result
|
||||||
|
self.node_results_cache[node_id] = result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing node {node_id} of type {node_type}: {e}", exc_info=True)
|
logger.error(f"Error processing node {node_id} of type {node_type}: {e}", exc_info=True)
|
||||||
# Return a partial result or None depending on how we want to handle errors
|
# Return a partial result or None depending on how we want to handle errors
|
||||||
@@ -52,6 +60,33 @@ class WorkflowParser:
|
|||||||
self.processed_nodes.remove(node_id)
|
self.processed_nodes.remove(node_id)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def collect_loras_from_model(self, model_input: List, workflow: Dict) -> str:
|
||||||
|
"""Collect loras information from the model node chain"""
|
||||||
|
if not isinstance(model_input, list) or len(model_input) != 2:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
model_node_id, _ = model_input
|
||||||
|
# Convert node_id to string if it's an integer
|
||||||
|
if isinstance(model_node_id, int):
|
||||||
|
model_node_id = str(model_node_id)
|
||||||
|
|
||||||
|
# Process the model node
|
||||||
|
model_result = self.process_node(model_node_id, workflow)
|
||||||
|
|
||||||
|
# If this is a Lora Loader node, return the loras text
|
||||||
|
if model_result and isinstance(model_result, dict) and "loras" in model_result:
|
||||||
|
return model_result["loras"]
|
||||||
|
|
||||||
|
# If not a lora loader, check the node's inputs for a model connection
|
||||||
|
node_data = workflow.get(model_node_id, {})
|
||||||
|
inputs = node_data.get("inputs", {})
|
||||||
|
|
||||||
|
# If this node has a model input, follow that path
|
||||||
|
if "model" in inputs and isinstance(inputs["model"], list):
|
||||||
|
return self.collect_loras_from_model(inputs["model"], workflow)
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
def parse_workflow(self, workflow_data: Union[str, Dict], output_path: Optional[str] = None) -> Dict:
|
def parse_workflow(self, workflow_data: Union[str, Dict], output_path: Optional[str] = None) -> Dict:
|
||||||
"""
|
"""
|
||||||
Parse the workflow and extract generation parameters
|
Parse the workflow and extract generation parameters
|
||||||
@@ -69,8 +104,9 @@ class WorkflowParser:
|
|||||||
else:
|
else:
|
||||||
workflow = workflow_data
|
workflow = workflow_data
|
||||||
|
|
||||||
# Reset the processed nodes tracker
|
# Reset the processed nodes tracker and cache
|
||||||
self.processed_nodes = set()
|
self.processed_nodes = set()
|
||||||
|
self.node_results_cache = {}
|
||||||
|
|
||||||
# Find the KSampler node
|
# Find the KSampler node
|
||||||
ksampler_node_id = find_node_by_type(workflow, "KSampler")
|
ksampler_node_id = find_node_by_type(workflow, "KSampler")
|
||||||
@@ -117,6 +153,18 @@ class WorkflowParser:
|
|||||||
if "guidance" in node_inputs:
|
if "guidance" in node_inputs:
|
||||||
result["gen_params"]["guidance"] = node_inputs["guidance"]
|
result["gen_params"]["guidance"] = node_inputs["guidance"]
|
||||||
|
|
||||||
|
# Extract loras from the model input of KSampler
|
||||||
|
ksampler_node = workflow.get(ksampler_node_id, {})
|
||||||
|
ksampler_inputs = ksampler_node.get("inputs", {})
|
||||||
|
if "model" in ksampler_inputs and isinstance(ksampler_inputs["model"], list):
|
||||||
|
loras_text = self.collect_loras_from_model(ksampler_inputs["model"], workflow)
|
||||||
|
if loras_text:
|
||||||
|
result["loras"] = loras_text
|
||||||
|
|
||||||
|
# Handle standard ComfyUI names vs our output format
|
||||||
|
if "cfg" in result["gen_params"]:
|
||||||
|
result["gen_params"]["cfg_scale"] = result["gen_params"].pop("cfg")
|
||||||
|
|
||||||
# Add clip_skip = 2 to match reference output if not already present
|
# Add clip_skip = 2 to match reference output if not already present
|
||||||
if "clip_skip" not in result["gen_params"]:
|
if "clip_skip" not in result["gen_params"]:
|
||||||
result["gen_params"]["clip_skip"] = "2"
|
result["gen_params"]["clip_skip"] = "2"
|
||||||
|
|||||||
Reference in New Issue
Block a user