Files
ComfyUI-Lora-Manager/py/workflow_params/workflow_parser.py
Will Miao 4bff17aa1a Update prompt configuration and enhance Lora management functionality
- Expanded the prompt.json file with new configurations for KSampler, CheckpointLoaderSimple, and various CLIPTextEncode nodes.
- Introduced additional Lora management features, including a new Lora Stacker and improved trigger word handling.
- Enhanced the loras_widget.js to log the generated prompt when saving recipes directly, aiding in debugging and user feedback.
- Improved overall structure and organization of the prompt configurations for better maintainability.
2025-03-21 16:35:52 +08:00

209 lines
7.3 KiB
Python

import json
import logging
from typing import Dict, Any, List, Optional, Set, Union
from .node_processors import NODE_PROCESSORS, NodeProcessor
from .extension_manager import get_extension_manager
logger = logging.getLogger(__name__)
class WorkflowParser:
    """Parser for ComfyUI workflow JSON files.

    Walks a workflow mapping (node_id -> node data) starting from known
    entry-point node types (KSampler, Lora Loader, CLIPSetLastLayer) and
    extracts generation parameters, the lora stack string and clip_skip.
    """

    def __init__(self, load_extensions: bool = True, extensions_dir: Optional[str] = None):
        """
        Initialize the workflow parser.

        Args:
            load_extensions: Whether to load extensions automatically
            extensions_dir: Optional path to extensions directory
        """
        self.workflow = None            # workflow dict being parsed (set by parse_workflow)
        self.processed_nodes = {}       # cache: node_id -> processed result
        self.processing_nodes = set()   # node ids in-flight, used to detect circular references

        # Load extensions if requested
        if load_extensions:
            self._load_extensions(extensions_dir)

    def _load_extensions(self, extensions_dir: Optional[str] = None):
        """
        Load node processor extensions.

        Args:
            extensions_dir: Optional path to extensions directory
        """
        extension_manager = get_extension_manager(extensions_dir)
        results = extension_manager.load_all_extensions()

        # Log how many extensions loaded successfully
        successful = sum(1 for status in results.values() if status)
        logger.debug(f"Loaded {successful} of {len(results)} extensions")

    def parse_workflow(self, workflow_json: Union[str, Dict]) -> Dict[str, Any]:
        """
        Parse a ComfyUI workflow JSON string or dict and extract generation parameters.

        Args:
            workflow_json: JSON string or dict containing the workflow

        Returns:
            Dict containing extracted generation parameters. Always has a
            "gen_params" key; "loras" is present only when lora nodes were found.
            Returns {} on JSON decode failure or empty input.
        """
        # Reset state for this parsing operation
        self.processed_nodes = {}
        self.processing_nodes = set()

        # Load JSON if it's a string
        if isinstance(workflow_json, str):
            try:
                self.workflow = json.loads(workflow_json)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse workflow JSON: {e}")
                return {}
        else:
            self.workflow = workflow_json

        if not self.workflow:
            return {}

        # Find KSampler nodes as entry points
        ksampler_nodes = self._find_nodes_by_class("KSampler")

        # Find LoraLoader nodes for lora information
        lora_nodes = self._find_nodes_by_class("Lora Loader (LoraManager)")

        # Check if we need to register additional node types by scanning the workflow
        self._check_for_unregistered_node_types()

        result = {
            "gen_params": {}
        }

        # Process KSampler nodes to get generation parameters
        for node_id in ksampler_nodes:
            gen_params = self.process_node(node_id)
            if gen_params:
                result["gen_params"].update(gen_params)

        # Process Lora nodes and join their stacks space-separated
        lora_parts = []
        for node_id in lora_nodes:
            lora_info = self.process_node(node_id)
            if lora_info and "lora_stack" in lora_info:
                lora_parts.append(lora_info["lora_stack"])
        if lora_parts:
            result["loras"] = " ".join(lora_parts)

        # Process CLIPSetLastLayer node for clip_skip
        clip_layer_nodes = self._find_nodes_by_class("CLIPSetLastLayer")
        for node_id in clip_layer_nodes:
            clip_info = self.process_node(node_id)
            if clip_info and "clip_skip" in clip_info:
                result["gen_params"]["clip_skip"] = clip_info["clip_skip"]

        return result

    def _check_for_unregistered_node_types(self):
        """Log node types present in the workflow that have no registered processor."""
        unknown_node_types = set()

        # Collect all unique unregistered node types in the workflow
        # (only values are needed, so iterate .values() rather than .items())
        for node_data in self.workflow.values():
            class_type = node_data.get("class_type")
            if class_type and class_type not in NODE_PROCESSORS:
                unknown_node_types.add(class_type)

        if unknown_node_types:
            logger.debug(f"Found {len(unknown_node_types)} unregistered node types: {unknown_node_types}")

    def process_node(self, node_id: str) -> Any:
        """
        Process a single node and its dependencies recursively.

        Results are memoized in self.processed_nodes; circular references
        are detected via self.processing_nodes and resolved to None.

        Args:
            node_id: The ID of the node to process

        Returns:
            Processed data from the node, or None if the node is missing,
            untyped, unprocessable, or part of a cycle.
        """
        # Check if already processed
        if node_id in self.processed_nodes:
            return self.processed_nodes[node_id]

        # Check for circular references
        if node_id in self.processing_nodes:
            logger.warning(f"Circular reference detected for node {node_id}")
            return None

        # Mark as being processed
        self.processing_nodes.add(node_id)

        # Get node data
        node_data = self.workflow.get(node_id)
        if not node_data:
            logger.warning(f"Node {node_id} not found in workflow")
            self.processing_nodes.remove(node_id)
            return None

        class_type = node_data.get("class_type")
        if not class_type:
            logger.warning(f"Node {node_id} has no class_type")
            self.processing_nodes.remove(node_id)
            return None

        # Get the appropriate node processor
        processor_class = NODE_PROCESSORS.get(class_type)
        if not processor_class:
            logger.debug(f"No processor for node type {class_type}")
            self.processing_nodes.remove(node_id)
            return None

        # Process the node; the processor may recurse back into this parser
        processor = processor_class(node_id, node_data, self.workflow)
        result = processor.process(self)

        # Cache the result
        self.processed_nodes[node_id] = result

        # Mark as processed
        self.processing_nodes.remove(node_id)

        return result

    def _find_nodes_by_class(self, class_type: str) -> List[str]:
        """
        Find all nodes of a particular class type in the workflow.

        Args:
            class_type: The node class type to find

        Returns:
            List of node IDs matching the class type
        """
        return [node_id for node_id, node_data in self.workflow.items()
                if node_data.get("class_type") == class_type]
def parse_workflow(workflow_json: Union[str, Dict],
                   load_extensions: bool = True,
                   extensions_dir: Optional[str] = None) -> Dict[str, Any]:
    """
    Helper function to parse a workflow JSON without having to create a parser instance.

    Args:
        workflow_json: JSON string or dict containing the workflow
        load_extensions: Whether to load extensions automatically
        extensions_dir: Optional path to extensions directory

    Returns:
        Dict containing extracted generation parameters
    """
    # Delegate to a throwaway parser instance
    parser = WorkflowParser(load_extensions, extensions_dir)
    return parser.parse_workflow(workflow_json)