mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 14:12:11 -03:00
Update prompt configuration and enhance Lora management functionality
- Expanded the prompt.json file with new configurations for KSampler, CheckpointLoaderSimple, and various CLIPTextEncode nodes.
- Introduced additional Lora management features, including a new Lora Stacker and improved trigger-word handling.
- Enhanced loras_widget.js to log the generated prompt when saving recipes directly, aiding debugging and user feedback.
- Improved the overall structure and organization of the prompt configurations for better maintainability.
This commit is contained in:
116
py/workflow_params/README.md
Normal file
116
py/workflow_params/README.md
Normal file
@@ -0,0 +1,116 @@
|
||||
# ComfyUI Workflow Parser
|
||||
|
||||
A module for parsing ComfyUI workflow JSON and extracting generation parameters.
|
||||
|
||||
## Features
|
||||
|
||||
- Parse ComfyUI workflow JSON files to extract generation parameters
|
||||
- Extract lora information from workflows
|
||||
- Support for node traversal and parameter resolution
|
||||
- Extensible architecture for supporting custom node types
|
||||
- Dynamic loading of node processor extensions
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from workflow_params import parse_workflow
|
||||
|
||||
# Parse from a file
|
||||
with open('my_workflow.json', 'r') as f:
|
||||
workflow_json = f.read()
|
||||
|
||||
result = parse_workflow(workflow_json)
|
||||
print(result)
|
||||
```
|
||||
|
||||
### Using the WorkflowParser directly
|
||||
|
||||
```python
|
||||
from workflow_params import WorkflowParser
|
||||
|
||||
parser = WorkflowParser()
|
||||
result = parser.parse_workflow(workflow_json)
|
||||
```
|
||||
|
||||
### Loading Extensions
|
||||
|
||||
Extensions are loaded automatically by default, but you can also control this behavior:
|
||||
|
||||
```python
|
||||
from workflow_params import WorkflowParser
|
||||
|
||||
# Don't load extensions
|
||||
parser = WorkflowParser(load_extensions=False)
|
||||
|
||||
# Load extensions from a custom directory
|
||||
parser = WorkflowParser(extensions_dir='/path/to/extensions')
|
||||
```
|
||||
|
||||
### Creating Custom Node Processors
|
||||
|
||||
To support a custom node type, create a processor class:
|
||||
|
||||
```python
|
||||
from workflow_params import NodeProcessor, register_processor
|
||||
|
||||
@register_processor
|
||||
class CustomNodeProcessor(NodeProcessor):
|
||||
"""Processor for CustomNode nodes"""
|
||||
|
||||
NODE_CLASS_TYPE = "CustomNode"
|
||||
REQUIRED_FIELDS = {"param1", "param2"}
|
||||
|
||||
def process(self, workflow_parser):
|
||||
result = {}
|
||||
|
||||
# Extract direct values
|
||||
if "param1" in self.inputs:
|
||||
result["value1"] = self.inputs["param1"]
|
||||
|
||||
# Resolve referenced inputs
|
||||
if "param2" in self.inputs:
|
||||
result["value2"] = self.resolve_input("param2", workflow_parser)
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
## Command Line Interface
|
||||
|
||||
A command-line interface is available for testing:
|
||||
|
||||
```bash
|
||||
python -m workflow_params.cli input_workflow.json -o output.json
|
||||
```
|
||||
|
||||
## Extension System
|
||||
|
||||
The module includes an extension system for dynamically loading node processors:
|
||||
|
||||
```python
|
||||
from workflow_params import get_extension_manager
|
||||
|
||||
# Get the extension manager
|
||||
manager = get_extension_manager()
|
||||
|
||||
# Load all extensions
|
||||
manager.load_all_extensions()
|
||||
|
||||
# Load a specific extension
|
||||
manager.load_extension('path/to/extension.py')
|
||||
```
|
||||
|
||||
Extensions should be placed in the `workflow_params/extensions` directory by default, or a custom directory can be specified.
|
||||
|
||||
## Supported Node Types
|
||||
|
||||
- KSampler
|
||||
- CLIPTextEncode
|
||||
- EmptyLatentImage
|
||||
- JoinStrings
|
||||
- StringConstantMultiline
|
||||
- CLIPSetLastLayer
|
||||
- TriggerWord Toggle (LoraManager)
|
||||
- Lora Loader (LoraManager)
|
||||
- Lora Stacker (LoraManager)
|
||||
14
py/workflow_params/__init__.py
Normal file
14
py/workflow_params/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# This package contains modules for workflow parameter extraction and processing
|
||||
from .workflow_parser import WorkflowParser, parse_workflow
|
||||
from .extension_manager import ExtensionManager, get_extension_manager
|
||||
from .node_processors import NodeProcessor, NODE_PROCESSORS, register_processor
|
||||
|
||||
__all__ = [
|
||||
"WorkflowParser",
|
||||
"parse_workflow",
|
||||
"ExtensionManager",
|
||||
"get_extension_manager",
|
||||
"NodeProcessor",
|
||||
"NODE_PROCESSORS",
|
||||
"register_processor"
|
||||
]
|
||||
68
py/workflow_params/cli.py
Normal file
68
py/workflow_params/cli.py
Normal file
@@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Command-line interface for testing the workflow parser"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from .workflow_parser import WorkflowParser
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main():
    """Entry point for the workflow-parser CLI.

    Reads a workflow JSON file, parses it with WorkflowParser, and writes
    the extracted parameters as JSON either to a file or to stdout.

    Returns:
        Process exit code: 0 on success, 1 on any failure.
    """
    arg_parser = argparse.ArgumentParser(description="Parse ComfyUI workflow JSON files")
    arg_parser.add_argument("input_file", type=str, help="Path to input workflow JSON file")
    arg_parser.add_argument("-o", "--output", type=str, help="Path to output JSON file (defaults to stdout)")
    arg_parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")
    args = arg_parser.parse_args()

    # Verbose flag raises the root logger to DEBUG for the whole run.
    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    input_path = Path(args.input_file)
    if not input_path.exists():
        logger.error(f"Input file {input_path} does not exist")
        return 1

    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            workflow_json = f.read()
    except Exception as e:
        logger.error(f"Failed to read input file: {e}")
        return 1

    try:
        workflow_parser = WorkflowParser()
        result = workflow_parser.parse_workflow(workflow_json)
    except Exception as e:
        logger.error(f"Failed to parse workflow: {e}")
        return 1

    output_json = json.dumps(result, indent=4)

    if args.output:
        try:
            with open(args.output, 'w', encoding='utf-8') as f:
                f.write(output_json)
            logger.info(f"Output written to {args.output}")
        except Exception as e:
            logger.error(f"Failed to write output file: {e}")
            return 1
    else:
        print(output_json)

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
163
py/workflow_params/extension_manager.py
Normal file
163
py/workflow_params/extension_manager.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Module for dynamically loading node processor extensions"""
|
||||
|
||||
import os
|
||||
import importlib
|
||||
import importlib.util
|
||||
import logging
|
||||
import inspect
|
||||
from typing import Dict, Any, List, Set, Type
|
||||
from pathlib import Path
|
||||
|
||||
from .node_processors import NodeProcessor, NODE_PROCESSORS
|
||||
|
||||
logger = logging.getLogger(__name__)


class ExtensionManager:
    """Discovers and loads node-processor extension modules at runtime.

    Extensions are plain Python files containing NodeProcessor subclasses;
    loading a file registers every subclass it defines.
    """

    def __init__(self, extensions_dir: str = None):
        """Initialize the extension manager.

        Args:
            extensions_dir: Directory to scan for extension files. When None,
                the packaged ``extensions`` directory next to this module is used.
        """
        if extensions_dir is not None:
            self.extensions_dir = extensions_dir
        else:
            module_dir = os.path.dirname(os.path.abspath(__file__))
            self.extensions_dir = os.path.join(module_dir, "extensions")

        # Maps extension file path -> loaded module object.
        self.loaded_extensions: Dict[str, Any] = {}

    def discover_extensions(self) -> List[str]:
        """Return the paths of all loadable extension files.

        Scans the extensions directory recursively for ``*.py`` files,
        skipping dunder/private files such as ``__init__.py``.
        """
        if not os.path.exists(self.extensions_dir):
            logger.warning(f"Extensions directory not found: {self.extensions_dir}")
            return []

        discovered = []
        for root, _, files in os.walk(self.extensions_dir):
            discovered.extend(
                os.path.join(root, name)
                for name in files
                if name.endswith('.py') and not name.startswith('__')
            )
        return discovered

    def load_extension(self, extension_path: str) -> bool:
        """Load one extension file and register its processors.

        Args:
            extension_path: Path to the extension file

        Returns:
            True on success (or if already loaded), False otherwise.
        """
        if extension_path in self.loaded_extensions:
            logger.debug(f"Extension already loaded: {extension_path}")
            return True

        try:
            module_name = os.path.basename(extension_path).replace(".py", "")

            spec = importlib.util.spec_from_file_location(module_name, extension_path)
            if spec is None or spec.loader is None:
                logger.error(f"Failed to load extension spec: {extension_path}")
                return False

            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            # Collect every NodeProcessor subclass defined (or re-exported)
            # by the module, excluding the abstract base itself.
            processor_classes = [
                member for _, member in inspect.getmembers(module)
                if inspect.isclass(member)
                and issubclass(member, NodeProcessor)
                and member is not NodeProcessor
            ]

            if not processor_classes:
                logger.warning(f"No NodeProcessor subclasses found in {extension_path}")
                return False

            for processor_cls in processor_classes:
                processor_cls.register()

            self.loaded_extensions[extension_path] = module
            logger.info(f"Loaded extension: {extension_path} with {len(processor_classes)} processors")

            return True

        except Exception as e:
            logger.error(f"Failed to load extension {extension_path}: {e}")
            return False

    def load_all_extensions(self) -> Dict[str, bool]:
        """Load every discovered extension.

        Returns:
            Mapping of extension path -> whether it loaded successfully.
        """
        return {path: self.load_extension(path) for path in self.discover_extensions()}

    def get_loaded_processor_types(self) -> Set[str]:
        """Return the class_type names of every registered processor."""
        return set(NODE_PROCESSORS.keys())

    def get_loaded_extension_count(self) -> int:
        """Return how many extensions have been loaded."""
        return len(self.loaded_extensions)
|
||||
|
||||
|
||||
# Create a singleton instance (module-level, shared process-wide)
_extension_manager = None

def get_extension_manager(extensions_dir: str = None) -> ExtensionManager:
    """
    Get the singleton ExtensionManager instance

    Args:
        extensions_dir: Optional path to extensions directory

    Returns:
        ExtensionManager instance
    """
    # NOTE(review): extensions_dir only takes effect on the FIRST call; later
    # calls return the existing singleton and silently ignore the argument.
    # Confirm callers never pass differing directories.
    global _extension_manager

    if _extension_manager is None:
        _extension_manager = ExtensionManager(extensions_dir)

    return _extension_manager
|
||||
2
py/workflow_params/extensions/__init__.py
Normal file
2
py/workflow_params/extensions/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Extensions module for workflow parameter parsing
|
||||
# This module contains extensions for specific node types that may be loaded dynamically
|
||||
43
py/workflow_params/extensions/custom_node_example.py
Normal file
43
py/workflow_params/extensions/custom_node_example.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Example of how to extend the workflow parser with custom node processors
|
||||
This file is not imported automatically - it serves as a template for creating extensions
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from ..node_processors import NodeProcessor, register_processor
|
||||
|
||||
@register_processor
class CustomNodeProcessor(NodeProcessor):
    """Example processor for a custom node type.

    Template showing the two common extraction patterns: copying a literal
    input value and resolving an input that references another node.
    """

    NODE_CLASS_TYPE = "CustomNodeType"
    REQUIRED_FIELDS = {"custom_field1", "custom_field2"}

    def process(self, workflow_parser) -> Dict[str, Any]:
        """Process a custom node"""
        extracted: Dict[str, Any] = {}

        # Literal value, copied straight from the node's inputs.
        if "custom_field1" in self.inputs:
            extracted["custom_value1"] = self.inputs["custom_field1"]

        # May be wired from another node; resolve it through the parser.
        if "custom_field2" in self.inputs:
            resolved = self.resolve_input("custom_field2", workflow_parser)
            if resolved:
                extracted["custom_value2"] = resolved

        return extracted
|
||||
|
||||
# To use this extension, you would need to:
|
||||
# 1. Save this file in the extensions directory
|
||||
# 2. Import it in your code before using the WorkflowParser
|
||||
#
|
||||
# For example:
|
||||
#
|
||||
# from workflow_params.extensions import custom_node_example
|
||||
# from workflow_params import WorkflowParser
|
||||
#
|
||||
# parser = WorkflowParser()
|
||||
# result = parser.parse_workflow(workflow_json)
|
||||
116
py/workflow_params/integration_example.py
Normal file
116
py/workflow_params/integration_example.py
Normal file
@@ -0,0 +1,116 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Example of integrating the workflow parser with other modules"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
# Add the parent directory to the Python path if needed
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
from py.workflow_params import WorkflowParser
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def extract_and_save_workflow_params(workflow_path, output_path=None):
    """
    Extract parameters from a workflow and save them to a file

    Args:
        workflow_path: Path to the workflow JSON file
        output_path: Optional path to save the extracted parameters
            If None, prints the parameters to stdout

    Returns:
        The extracted parameters, or None on any failure
    """
    if not os.path.exists(workflow_path):
        logger.error(f"Workflow file not found: {workflow_path}")
        return None

    try:
        with open(workflow_path, 'r', encoding='utf-8') as f:
            workflow_json = f.read()
    except Exception as e:
        logger.error(f"Failed to read workflow file: {e}")
        return None

    try:
        params = WorkflowParser().parse_workflow(workflow_json)
    except Exception as e:
        logger.error(f"Failed to parse workflow: {e}")
        return None

    output_json = json.dumps(params, indent=4)

    # No output path: dump to stdout instead of a file.
    if not output_path:
        print(output_json)
        return params

    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(output_json)
        logger.info(f"Parameters saved to {output_path}")
    except Exception as e:
        # Write failures are logged but the parsed params are still returned.
        logger.error(f"Failed to write output file: {e}")
    return params
|
||||
|
||||
def get_workflow_loras(workflow_path):
    """
    Extract just the lora names used by a workflow

    Args:
        workflow_path: Path to the workflow JSON file

    Returns:
        List of lora names found in the workflow's "loras" string
        (empty when parsing fails or no loras are present)
    """
    params = extract_and_save_workflow_params(workflow_path)
    if not params or "loras" not in params:
        return []

    # Entries look like <lora:name:strength>; capture just the name part.
    return re.findall(r'<lora:([^:]+):[^>]+>', params["loras"])
|
||||
|
||||
def main():
    """Run both usage examples against a workflow given on the command line.

    Returns:
        Process exit code: 0 on success, 1 on bad usage or parse failure.
    """
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} <workflow_json_file> [output_file]")
        return 1

    workflow_path = sys.argv[1]
    output_path = sys.argv[2] if len(sys.argv) > 2 else None

    # Example 1: extract every parameter (and optionally save them).
    params = extract_and_save_workflow_params(workflow_path, output_path)
    if not params:
        return 1

    # Example 2: pull out only the lora names.
    loras = get_workflow_loras(workflow_path)
    print(f"Loras used in the workflow: {loras}")

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
6
py/workflow_params/node_processors/__init__.py
Normal file
6
py/workflow_params/node_processors/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# This module contains processors for different node types in a ComfyUI workflow
|
||||
|
||||
from .base_processor import NodeProcessor, NODE_PROCESSORS, register_processor
|
||||
from . import load_processors
|
||||
|
||||
__all__ = ["NodeProcessor", "NODE_PROCESSORS", "register_processor"]
|
||||
77
py/workflow_params/node_processors/base_processor.py
Normal file
77
py/workflow_params/node_processors/base_processor.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List, Optional, Set, Callable, Type
|
||||
|
||||
# Registry to store node processors by class_type
NODE_PROCESSORS: Dict[str, Type['NodeProcessor']] = {}


class NodeProcessor(ABC):
    """Base class for processors that pull parameters out of workflow nodes.

    Subclasses set NODE_CLASS_TYPE to the ComfyUI class_type they handle and
    implement process() to return that node's contribution.
    """

    # class_type this processor handles; None on the abstract base.
    NODE_CLASS_TYPE: str = None
    # Input keys the subclass cares about.
    REQUIRED_FIELDS: Set[str] = set()

    def __init__(self, node_id: str, node_data: Dict[str, Any], workflow: Dict[str, Any]):
        """Bind the processor to one node of a workflow.

        Args:
            node_id: The ID of the node in the workflow
            node_data: That node's data dict
            workflow: The complete workflow data
        """
        self.node_id = node_id
        self.node_data = node_data
        self.workflow = workflow
        # Missing 'inputs' is treated as an empty mapping.
        self.inputs = node_data.get('inputs', {})

    @classmethod
    def register(cls):
        """Add this processor to the global registry, keyed by class_type."""
        if cls.NODE_CLASS_TYPE:
            NODE_PROCESSORS[cls.NODE_CLASS_TYPE] = cls

    @abstractmethod
    def process(self, workflow_parser) -> Dict[str, Any]:
        """Extract this node's contribution to the generation parameters.

        Args:
            workflow_parser: Parser used to resolve references to other nodes

        Returns:
            Dict of extracted values (subclass-specific)
        """

    def resolve_input(self, input_key: str, workflow_parser) -> Any:
        """Return an input's value, following a node reference if needed.

        A referenced input is encoded as a two-element list
        ``[node_id, slot_index]``; in that case the referenced node is
        processed and its result returned (the slot index is ignored).

        Args:
            input_key: Which input to resolve
            workflow_parser: Parser used to process referenced nodes

        Returns:
            The literal value, the referenced node's result, or None when
            the input is absent.
        """
        value = self.inputs.get(input_key)
        if value is None:
            return None

        if isinstance(value, list) and len(value) == 2:
            ref_node_id, _slot = value
            return workflow_parser.process_node(ref_node_id)

        return value


def register_processor(cls):
    """Class decorator: register a NodeProcessor subclass and return it unchanged."""
    cls.register()
    return cls
|
||||
@@ -0,0 +1,21 @@
|
||||
from typing import Dict, Any
|
||||
from .base_processor import NodeProcessor, register_processor
|
||||
|
||||
@register_processor
class CLIPSetLastLayerProcessor(NodeProcessor):
    """Processor for CLIPSetLastLayer nodes"""

    NODE_CLASS_TYPE = "CLIPSetLastLayer"
    REQUIRED_FIELDS = {"stop_at_clip_layer", "clip"}

    def process(self, workflow_parser) -> Dict[str, Any]:
        """Extract the clip-skip value from a CLIPSetLastLayer node.

        ComfyUI stores the layer as a negative index; clip_skip is reported
        as the corresponding positive number. Returns None when the input is
        missing or not a negative number.
        """
        layer = self.inputs.get("stop_at_clip_layer")
        if isinstance(layer, (int, float)) and layer < 0:
            return {"clip_skip": str(abs(layer))}
        return None
|
||||
@@ -0,0 +1,18 @@
|
||||
from typing import Dict, Any
|
||||
from .base_processor import NodeProcessor, register_processor
|
||||
|
||||
@register_processor
class CLIPTextEncodeProcessor(NodeProcessor):
    """Processor for CLIPTextEncode nodes"""

    NODE_CLASS_TYPE = "CLIPTextEncode"
    REQUIRED_FIELDS = {"text", "clip"}

    def process(self, workflow_parser) -> Dict[str, Any]:
        """Return the node's text prompt, or None when absent."""
        if "text" not in self.inputs:
            return None
        # "text" may be typed directly or wired from another node.
        return self.resolve_input("text", workflow_parser)
|
||||
@@ -0,0 +1,21 @@
|
||||
from typing import Dict, Any
|
||||
from .base_processor import NodeProcessor, register_processor
|
||||
|
||||
@register_processor
class EmptyLatentImageProcessor(NodeProcessor):
    """Processor for EmptyLatentImage nodes"""

    NODE_CLASS_TYPE = "EmptyLatentImage"
    REQUIRED_FIELDS = {"width", "height", "batch_size"}

    def process(self, workflow_parser) -> Dict[str, Any]:
        """Return the image dimensions (empty dict when either is missing)."""
        dims = {}
        if "width" in self.inputs and "height" in self.inputs:
            dims["width"] = self.inputs["width"]
            dims["height"] = self.inputs["height"]
        return dims
|
||||
27
py/workflow_params/node_processors/join_strings_processor.py
Normal file
27
py/workflow_params/node_processors/join_strings_processor.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import Dict, Any
|
||||
from .base_processor import NodeProcessor, register_processor
|
||||
|
||||
@register_processor
class JoinStringsProcessor(NodeProcessor):
    """Processor for JoinStrings nodes"""

    NODE_CLASS_TYPE = "JoinStrings"
    REQUIRED_FIELDS = {"string1", "string2", "delimiter"}

    def process(self, workflow_parser) -> Dict[str, Any]:
        """Join the two (possibly node-referenced) strings with the delimiter.

        Returns whichever string is present when the other is missing, or
        None when both are missing.
        """
        first = self.resolve_input("string1", workflow_parser)
        second = self.resolve_input("string2", workflow_parser)
        # NOTE(review): delimiter is read literally, not via resolve_input —
        # presumably it is always a widget value; confirm it is never wired.
        sep = self.inputs.get("delimiter", ", ")

        if first is None:
            # Covers both-missing (returns None) and first-only-missing.
            return second
        if second is None:
            return first
        return f"{first}{sep}{second}"
|
||||
46
py/workflow_params/node_processors/ksampler_processor.py
Normal file
46
py/workflow_params/node_processors/ksampler_processor.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import Dict, Any, Set
|
||||
from .base_processor import NodeProcessor, register_processor
|
||||
|
||||
@register_processor
class KSamplerProcessor(NodeProcessor):
    """Processor for KSampler nodes"""

    NODE_CLASS_TYPE = "KSampler"
    REQUIRED_FIELDS = {"seed", "steps", "cfg", "sampler_name", "scheduler", "denoise",
                       "positive", "negative", "latent_image"}

    def process(self, workflow_parser) -> Dict[str, Any]:
        """Process a KSampler node to extract generation parameters.

        Args:
            workflow_parser: Parser used to resolve referenced inputs

        Returns:
            Dict with any of: seed, steps, cfg_scale, sampler, prompt,
            negative_prompt, size.
        """
        result = {}

        # Numeric widget values are reported as strings.
        if "seed" in self.inputs:
            result["seed"] = str(self.inputs["seed"])

        if "steps" in self.inputs:
            result["steps"] = str(self.inputs["steps"])

        if "cfg" in self.inputs:
            result["cfg_scale"] = str(self.inputs["cfg"])

        if "sampler_name" in self.inputs:
            result["sampler"] = self.inputs["sampler_name"]

        # Prompts come from referenced text-encoding nodes.
        if "positive" in self.inputs:
            positive_text = self.resolve_input("positive", workflow_parser)
            if positive_text:
                result["prompt"] = positive_text

        if "negative" in self.inputs:
            negative_text = self.resolve_input("negative", workflow_parser)
            if negative_text:
                result["negative_prompt"] = negative_text

        # Image size comes from the referenced latent-image node.
        if "latent_image" in self.inputs:
            latent_info = self.resolve_input("latent_image", workflow_parser)
            # FIX: require a dict explicitly. Without the isinstance check a
            # string result (e.g. from a text node chain) would pass the
            # `in` tests via substring matching and then fail on indexing.
            if (isinstance(latent_info, dict)
                    and "width" in latent_info and "height" in latent_info):
                result["size"] = f"{latent_info['width']}x{latent_info['height']}"

        return result
|
||||
15
py/workflow_params/node_processors/load_processors.py
Normal file
15
py/workflow_params/node_processors/load_processors.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Module to load all node processors"""
|
||||
|
||||
# Import all processor types to register them
|
||||
from .ksampler_processor import KSamplerProcessor
|
||||
from .clip_text_encode_processor import CLIPTextEncodeProcessor
|
||||
from .empty_latent_image_processor import EmptyLatentImageProcessor
|
||||
from .join_strings_processor import JoinStringsProcessor
|
||||
from .string_constant_processor import StringConstantProcessor
|
||||
from .clip_set_last_layer_processor import CLIPSetLastLayerProcessor
|
||||
from .trigger_word_toggle_processor import TriggerWordToggleProcessor
|
||||
from .lora_loader_processor import LoraLoaderProcessor
|
||||
from .lora_stacker_processor import LoraStackerProcessor
|
||||
|
||||
# Update the node_processors/__init__.py to include this import
|
||||
# This ensures all processors are registered when the package is imported
|
||||
50
py/workflow_params/node_processors/lora_loader_processor.py
Normal file
50
py/workflow_params/node_processors/lora_loader_processor.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import Dict, Any, List
|
||||
from .base_processor import NodeProcessor, register_processor
|
||||
|
||||
@register_processor
class LoraLoaderProcessor(NodeProcessor):
    """Processor for Lora Loader (LoraManager) nodes"""

    NODE_CLASS_TYPE = "Lora Loader (LoraManager)"
    REQUIRED_FIELDS = {"loras", "text", "lora_stack"}

    @staticmethod
    def _format_active_loras(loras) -> str:
        """Render active, non-dummy lora entries as '<lora:name:strength>' text."""
        parts = []
        if isinstance(loras, list):
            for entry in loras:
                if (isinstance(entry, dict)
                        and entry.get("active", False)
                        and not entry.get("_isDummy", False)
                        and "name" in entry and "strength" in entry):
                    parts.append(f"<lora:{entry['name']}:{entry['strength']}>")
        return " ".join(parts)

    def process(self, workflow_parser) -> Dict[str, Any]:
        """Process a Lora Loader node to extract lora text and stack."""
        result = {}

        # Raw lora text as typed on the node.
        if "text" in self.inputs:
            result["lora_text"] = self.inputs.get("text", "")

        # Active loras selected on this node.
        if "loras" in self.inputs:
            formatted = self._format_active_loras(self.inputs["loras"])
            if formatted:
                result["active_loras"] = formatted

        # Merge in any stack coming from an upstream node; upstream entries
        # are placed BEFORE this node's own loras.
        if "lora_stack" in self.inputs:
            stack_result = self.resolve_input("lora_stack", workflow_parser)
            if isinstance(stack_result, dict) and "lora_stack" in stack_result:
                upstream = stack_result["lora_stack"]
                if "active_loras" in result:
                    result["active_loras"] = f"{upstream} {result['active_loras']}"
                else:
                    result["active_loras"] = upstream

        # Expose the combined value under the key downstream nodes consume.
        if "active_loras" in result:
            result["lora_stack"] = result["active_loras"]

        return result
|
||||
52
py/workflow_params/node_processors/lora_stacker_processor.py
Normal file
52
py/workflow_params/node_processors/lora_stacker_processor.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Dict, Any, List
|
||||
from .base_processor import NodeProcessor, register_processor
|
||||
|
||||
@register_processor
class LoraStackerProcessor(NodeProcessor):
    """Processor for Lora Stacker (LoraManager) nodes"""

    NODE_CLASS_TYPE = "Lora Stacker (LoraManager)"
    REQUIRED_FIELDS = {"loras", "text", "lora_stack"}

    @staticmethod
    def _format_active_loras(loras) -> str:
        """Render active, non-dummy lora entries as '<lora:name:strength>' text."""
        parts = []
        if isinstance(loras, list):
            for entry in loras:
                if (isinstance(entry, dict)
                        and entry.get("active", False)
                        and not entry.get("_isDummy", False)
                        and "name" in entry and "strength" in entry):
                    parts.append(f"<lora:{entry['name']}:{entry['strength']}>")
        return " ".join(parts)

    def process(self, workflow_parser) -> Dict[str, Any]:
        """Process a Lora Stacker node to extract lora stack."""
        result = {}

        # Raw lora text as typed on the node.
        if "text" in self.inputs:
            result["lora_text"] = self.inputs.get("text", "")

        # Active loras selected on this node.
        if "loras" in self.inputs:
            formatted = self._format_active_loras(self.inputs["loras"])
            if formatted:
                result["active_loras"] = formatted

        # Combine with any upstream stack. Note the ordering: this node's
        # own loras come BEFORE the upstream ones (the loader does the reverse).
        if "lora_stack" in self.inputs:
            stack_result = self.resolve_input("lora_stack", workflow_parser)
            if isinstance(stack_result, dict) and "lora_stack" in stack_result:
                upstream = stack_result["lora_stack"]
                if "active_loras" in result:
                    result["lora_stack"] = f"{result['active_loras']} {upstream}"
                else:
                    result["lora_stack"] = upstream
            elif "active_loras" in result:
                # Upstream gave no stack, but this node has active loras.
                result["lora_stack"] = result["active_loras"]
        elif "active_loras" in result:
            # No lora_stack input at all; promote this node's own loras.
            result["lora_stack"] = result["active_loras"]

        return result
|
||||
@@ -0,0 +1,22 @@
|
||||
from typing import Dict, Any
|
||||
from .base_processor import NodeProcessor, register_processor
|
||||
|
||||
@register_processor
class StringConstantProcessor(NodeProcessor):
    """Processor for StringConstantMultiline nodes."""

    NODE_CLASS_TYPE = "StringConstantMultiline"
    REQUIRED_FIELDS = {"string", "strip_newlines"}

    # NOTE: the return annotation was Dict[str, Any] but this processor
    # actually returns the raw string value (or None); annotate as Any.
    def process(self, workflow_parser) -> Any:
        """Process a StringConstantMultiline node to extract the string content.

        Args:
            workflow_parser: The parser driving the traversal (unused here;
                kept for the common processor interface).

        Returns:
            The node's string value, with newlines replaced by spaces when
            strip_newlines is set, or None when no "string" input exists.
        """
        if "string" in self.inputs:
            string_value = self.inputs["string"]
            strip_newlines = self.inputs.get("strip_newlines", False)

            # NOTE(review): only "\n" is replaced; a Windows "\r\n" would leave
            # a stray "\r" — confirm inputs are newline-normalized upstream.
            if strip_newlines and isinstance(string_value, str):
                string_value = string_value.replace("\n", " ")

            return string_value

        return None
|
||||
@@ -0,0 +1,33 @@
|
||||
from typing import Dict, Any, List
|
||||
from .base_processor import NodeProcessor, register_processor
|
||||
|
||||
@register_processor
class TriggerWordToggleProcessor(NodeProcessor):
    """Processor for TriggerWord Toggle (LoraManager) nodes."""

    NODE_CLASS_TYPE = "TriggerWord Toggle (LoraManager)"
    REQUIRED_FIELDS = {"toggle_trigger_words", "group_mode"}

    # NOTE: the return annotation was Dict[str, Any] but this processor
    # actually returns a joined string (or None); annotate as Any.
    def process(self, workflow_parser) -> Any:
        """Process a TriggerWord Toggle node to extract active trigger words.

        Args:
            workflow_parser: The parser driving the traversal (unused here;
                kept for the common processor interface).

        Returns:
            A comma-separated string of active, non-dummy trigger words,
            or None when the input is missing, malformed, or has no
            active entries.
        """
        if "toggle_trigger_words" not in self.inputs:
            return None

        toggle_words = self.inputs["toggle_trigger_words"]
        if not isinstance(toggle_words, list):
            return None

        # Filter active trigger words that aren't dummy items
        active_words = []
        for word_entry in toggle_words:
            if (isinstance(word_entry, dict) and
                    word_entry.get("active", False) and
                    not word_entry.get("_isDummy", False) and
                    "text" in word_entry):
                active_words.append(word_entry["text"])

        if not active_words:
            return None

        # Join all active trigger words with a comma
        return ", ".join(active_words)
|
||||
63
py/workflow_params/simple_test.py
Normal file
63
py/workflow_params/simple_test.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
"""Simple test script for the workflow parser"""

import json
import os
import sys
from pathlib import Path

# Resolve the refs/ directory relative to this file.
project_path = Path(__file__).parent.parent.parent
refs_path = project_path / "refs"
prompt_path = refs_path / "prompt.json"
output_path = refs_path / "output.json"

print(f"Loading workflow from {prompt_path}")
print(f"Expected output from {output_path}")

# Read both JSON documents.
workflow_json = json.loads(prompt_path.read_text(encoding='utf-8'))
expected_output = json.loads(output_path.read_text(encoding='utf-8'))

print("\nExpected output:")
print(json.dumps(expected_output, indent=2))

# Manually extract important parameters to verify our understanding
sampler_node_id = "3"
sampler_node = workflow_json.get(sampler_node_id, {})
print("\nSampler node:")
print(json.dumps(sampler_node, indent=2))

# Pull seed, steps and cfg straight off the sampler's inputs.
sampler_inputs = sampler_node.get("inputs", {})
seed = sampler_inputs.get("seed")
steps = sampler_inputs.get("steps")
cfg = sampler_inputs.get("cfg")

print(f"\nExtracted parameters:")
print(f"seed: {seed}")
print(f"steps: {steps}")
print(f"cfg_scale: {cfg}")

# Extract positive prompt - this requires following node references
positive_ref = sampler_inputs.get("positive", [])
if isinstance(positive_ref, list) and len(positive_ref) == 2:
    positive_node_id, slot_index = positive_ref
    positive_node = workflow_json.get(positive_node_id, {})

    print(f"\nPositive node ({positive_node_id}):")
    print(json.dumps(positive_node, indent=2))

    # Follow the reference to the text value
    text_ref = positive_node.get("inputs", {}).get("text", [])
    if isinstance(text_ref, list) and len(text_ref) == 2:
        text_node_id, slot_index = text_ref
        text_node = workflow_json.get(text_node_id, {})

        print(f"\nText node ({text_node_id}):")
        print(json.dumps(text_node, indent=2))

print("\nTest completed.")
|
||||
80
py/workflow_params/test_parser.py
Normal file
80
py/workflow_params/test_parser.py
Normal file
@@ -0,0 +1,80 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test script for the workflow parser"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from .workflow_parser import WorkflowParser
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_parse_example():
|
||||
"""Test parsing the example prompt.json file and compare with expected output"""
|
||||
# Get the project root directory
|
||||
project_root = Path(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
# Path to the example files
|
||||
prompt_path = project_root / "refs" / "prompt.json"
|
||||
output_path = project_root / "refs" / "output.json"
|
||||
|
||||
# Ensure the files exist
|
||||
if not prompt_path.exists():
|
||||
logger.error(f"Example prompt file not found: {prompt_path}")
|
||||
return False
|
||||
|
||||
if not output_path.exists():
|
||||
logger.error(f"Example output file not found: {output_path}")
|
||||
return False
|
||||
|
||||
# Load the files
|
||||
try:
|
||||
with open(prompt_path, 'r', encoding='utf-8') as f:
|
||||
prompt_json = f.read()
|
||||
|
||||
with open(output_path, 'r', encoding='utf-8') as f:
|
||||
expected_output = json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read example files: {e}")
|
||||
return False
|
||||
|
||||
# Parse the workflow
|
||||
parser = WorkflowParser()
|
||||
result = parser.parse_workflow(prompt_json)
|
||||
|
||||
# Display the result
|
||||
logger.info("Parsed workflow:")
|
||||
logger.info(json.dumps(result, indent=4))
|
||||
|
||||
# Compare with expected output
|
||||
logger.info("Expected output:")
|
||||
logger.info(json.dumps(expected_output, indent=4))
|
||||
|
||||
# Basic validation
|
||||
if "loras" not in result:
|
||||
logger.error("Missing 'loras' field in result")
|
||||
return False
|
||||
|
||||
if "gen_params" not in result:
|
||||
logger.error("Missing 'gen_params' field in result")
|
||||
return False
|
||||
|
||||
required_params = [
|
||||
"prompt", "negative_prompt", "steps", "sampler",
|
||||
"cfg_scale", "seed", "size", "clip_skip"
|
||||
]
|
||||
|
||||
for param in required_params:
|
||||
if param not in result["gen_params"]:
|
||||
logger.error(f"Missing '{param}' in gen_params")
|
||||
return False
|
||||
|
||||
logger.info("Test completed successfully!")
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_parse_example()
|
||||
106
py/workflow_params/verify_workflow.py
Normal file
106
py/workflow_params/verify_workflow.py
Normal file
@@ -0,0 +1,106 @@
|
||||
#!/usr/bin/env python3
"""Script to verify the workflow structure and save the output to a file"""

import json
import os
from pathlib import Path

# Get project path
# Fixture locations: refs/prompt.json (workflow), refs/output.json (expected),
# refs/test_output.txt (human-readable dump written by this script).
project_path = Path(__file__).parent.parent.parent
refs_path = project_path / "refs"
prompt_path = refs_path / "prompt.json"
output_path = refs_path / "output.json"
test_output_path = refs_path / "test_output.txt"

# Load the workflow JSON
with open(prompt_path, 'r', encoding='utf-8') as f:
    workflow_json = json.load(f)

# Load the expected output
with open(output_path, 'r', encoding='utf-8') as f:
    expected_output = json.load(f)

# Open the output file
# Everything below writes a step-by-step trace of manually following
# node references, mirroring what the parser is expected to do.
with open(test_output_path, 'w', encoding='utf-8') as f:
    f.write(f"Loading workflow from {prompt_path}\n")
    f.write(f"Expected output from {output_path}\n\n")

    f.write("Expected output:\n")
    f.write(json.dumps(expected_output, indent=2) + "\n\n")

    # Manually extract important parameters
    # Node "3" is the KSampler in the reference prompt.json fixture.
    sampler_node_id = "3"
    sampler_node = workflow_json.get(sampler_node_id, {})
    f.write("Sampler node:\n")
    f.write(json.dumps(sampler_node, indent=2) + "\n\n")

    # Extract seed, steps, cfg
    seed = sampler_node.get("inputs", {}).get("seed")
    steps = sampler_node.get("inputs", {}).get("steps")
    cfg = sampler_node.get("inputs", {}).get("cfg")

    f.write(f"Extracted parameters:\n")
    f.write(f"seed: {seed}\n")
    f.write(f"steps: {steps}\n")
    f.write(f"cfg_scale: {cfg}\n\n")

    # Extract positive prompt - this requires following node references
    # A node reference in ComfyUI prompt JSON is a [node_id, output_slot] pair.
    positive_ref = sampler_node.get("inputs", {}).get("positive", [])
    if isinstance(positive_ref, list) and len(positive_ref) == 2:
        positive_node_id, slot_index = positive_ref
        positive_node = workflow_json.get(positive_node_id, {})

        f.write(f"Positive node ({positive_node_id}):\n")
        f.write(json.dumps(positive_node, indent=2) + "\n\n")

        # Follow the reference to the text value
        text_ref = positive_node.get("inputs", {}).get("text", [])
        if isinstance(text_ref, list) and len(text_ref) == 2:
            text_node_id, slot_index = text_ref
            text_node = workflow_json.get(text_node_id, {})

            f.write(f"Text node ({text_node_id}):\n")
            f.write(json.dumps(text_node, indent=2) + "\n\n")

            # If the text node is a JoinStrings node, follow its inputs
            if text_node.get("class_type") == "JoinStrings":
                string1_ref = text_node.get("inputs", {}).get("string1", [])
                string2_ref = text_node.get("inputs", {}).get("string2", [])

                if isinstance(string1_ref, list) and len(string1_ref) == 2:
                    string1_node_id, slot_index = string1_ref
                    string1_node = workflow_json.get(string1_node_id, {})

                    f.write(f"String1 node ({string1_node_id}):\n")
                    f.write(json.dumps(string1_node, indent=2) + "\n\n")

                if isinstance(string2_ref, list) and len(string2_ref) == 2:
                    string2_node_id, slot_index = string2_ref
                    string2_node = workflow_json.get(string2_node_id, {})

                    f.write(f"String2 node ({string2_node_id}):\n")
                    f.write(json.dumps(string2_node, indent=2) + "\n\n")

    # Extract negative prompt
    negative_ref = sampler_node.get("inputs", {}).get("negative", [])
    if isinstance(negative_ref, list) and len(negative_ref) == 2:
        negative_node_id, slot_index = negative_ref
        negative_node = workflow_json.get(negative_node_id, {})

        f.write(f"Negative node ({negative_node_id}):\n")
        f.write(json.dumps(negative_node, indent=2) + "\n\n")

    # Extract LoRA information
    # Collect every LoraManager loader/stacker node in the workflow.
    lora_nodes = []
    for node_id, node_data in workflow_json.items():
        if node_data.get("class_type") in ["Lora Loader (LoraManager)", "Lora Stacker (LoraManager)"]:
            lora_nodes.append((node_id, node_data))

    f.write(f"LoRA nodes ({len(lora_nodes)}):\n")
    for node_id, node_data in lora_nodes:
        f.write(f"\nLoRA node {node_id}:\n")
        f.write(json.dumps(node_data, indent=2) + "\n")

    f.write("\nTest completed.\n")

print(f"Test output written to {test_output_path}")
|
||||
209
py/workflow_params/workflow_parser.py
Normal file
209
py/workflow_params/workflow_parser.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Set, Union
|
||||
from .node_processors import NODE_PROCESSORS, NodeProcessor
|
||||
from .extension_manager import get_extension_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WorkflowParser:
    """Parser for ComfyUI workflow JSON files.

    Walks a workflow (prompt) graph starting from KSampler nodes and
    recursively resolves node inputs through registered node processors,
    producing a dict of generation parameters and lora information.
    """

    def __init__(self, load_extensions: bool = True, extensions_dir: str = None):
        """
        Initialize the workflow parser

        Args:
            load_extensions: Whether to load extensions automatically
            extensions_dir: Optional path to extensions directory
        """
        self.workflow = None
        self.processed_nodes = {}  # Cache for processed nodes (node_id -> result)
        self.processing_nodes = set()  # Nodes currently being processed, to detect circular references

        # Load extensions if requested
        if load_extensions:
            self._load_extensions(extensions_dir)

    def _load_extensions(self, extensions_dir: str = None):
        """
        Load node processor extensions

        Args:
            extensions_dir: Optional path to extensions directory
        """
        extension_manager = get_extension_manager(extensions_dir)
        results = extension_manager.load_all_extensions()

        # Log the results
        successful = sum(1 for status in results.values() if status)
        logger.debug(f"Loaded {successful} of {len(results)} extensions")

    def parse_workflow(self, workflow_json: Union[str, Dict]) -> Dict[str, Any]:
        """
        Parse a ComfyUI workflow JSON string or dict and extract generation parameters

        Args:
            workflow_json: JSON string or dict containing the workflow

        Returns:
            Dict containing extracted generation parameters; empty dict when
            the JSON cannot be parsed or the workflow is empty.
        """
        # Reset state for this parsing operation
        self.processed_nodes = {}
        self.processing_nodes = set()

        # Load JSON if it's a string
        if isinstance(workflow_json, str):
            try:
                self.workflow = json.loads(workflow_json)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse workflow JSON: {e}")
                return {}
        else:
            self.workflow = workflow_json

        if not self.workflow:
            return {}

        # Find KSampler nodes as entry points
        ksampler_nodes = self._find_nodes_by_class("KSampler")

        # Find LoraLoader nodes for lora information
        lora_nodes = self._find_nodes_by_class("Lora Loader (LoraManager)")

        # Check if we need to register additional node types by scanning the workflow
        self._check_for_unregistered_node_types()

        result = {
            "gen_params": {}
        }

        # Process KSampler nodes to get generation parameters
        for node_id in ksampler_nodes:
            gen_params = self.process_node(node_id)
            if gen_params:
                result["gen_params"].update(gen_params)

        # Process Lora nodes to get lora stack
        lora_stack = ""
        for node_id in lora_nodes:
            lora_info = self.process_node(node_id)
            if lora_info and "lora_stack" in lora_info:
                if lora_stack:
                    lora_stack = f"{lora_stack} {lora_info['lora_stack']}"
                else:
                    lora_stack = lora_info["lora_stack"]

        if lora_stack:
            result["loras"] = lora_stack

        # Process CLIPSetLastLayer node for clip_skip
        clip_layer_nodes = self._find_nodes_by_class("CLIPSetLastLayer")
        for node_id in clip_layer_nodes:
            clip_info = self.process_node(node_id)
            if clip_info and "clip_skip" in clip_info:
                result["gen_params"]["clip_skip"] = clip_info["clip_skip"]

        return result

    def _check_for_unregistered_node_types(self):
        """Check for node types in the workflow that aren't registered yet (debug aid only)"""
        unknown_node_types = set()

        # Collect all unique node types in the workflow
        for node_id, node_data in self.workflow.items():
            # Defensive: skip entries that aren't node dicts (e.g. metadata keys).
            if not isinstance(node_data, dict):
                continue
            class_type = node_data.get("class_type")
            if class_type and class_type not in NODE_PROCESSORS:
                unknown_node_types.add(class_type)

        if unknown_node_types:
            logger.debug(f"Found {len(unknown_node_types)} unregistered node types: {unknown_node_types}")

    def process_node(self, node_id: str) -> Any:
        """
        Process a single node and its dependencies recursively

        Args:
            node_id: The ID of the node to process

        Returns:
            Processed data from the node, or None when the node is missing,
            has no class_type/processor, or a circular reference is detected.
        """
        # Check if already processed
        if node_id in self.processed_nodes:
            return self.processed_nodes[node_id]

        # Check for circular references
        if node_id in self.processing_nodes:
            logger.warning(f"Circular reference detected for node {node_id}")
            return None

        # Mark as being processed. The finally block guarantees the marker
        # is removed even if a processor raises; without it a failing node
        # would stay marked forever and trigger spurious circular-reference
        # warnings on any later attempt to process it.
        self.processing_nodes.add(node_id)
        try:
            # Get node data
            node_data = self.workflow.get(node_id)
            if not node_data:
                logger.warning(f"Node {node_id} not found in workflow")
                return None

            class_type = node_data.get("class_type")
            if not class_type:
                logger.warning(f"Node {node_id} has no class_type")
                return None

            # Get the appropriate node processor
            processor_class = NODE_PROCESSORS.get(class_type)
            if not processor_class:
                logger.debug(f"No processor for node type {class_type}")
                return None

            # Process the node and cache the result
            processor = processor_class(node_id, node_data, self.workflow)
            result = processor.process(self)
            self.processed_nodes[node_id] = result
            return result
        finally:
            self.processing_nodes.remove(node_id)

    def _find_nodes_by_class(self, class_type: str) -> List[str]:
        """
        Find all nodes of a particular class type in the workflow

        Args:
            class_type: The node class type to find

        Returns:
            List of node IDs matching the class type
        """
        nodes = []
        for node_id, node_data in self.workflow.items():
            # Defensive: skip entries that aren't node dicts (e.g. metadata keys).
            if isinstance(node_data, dict) and node_data.get("class_type") == class_type:
                nodes.append(node_id)
        return nodes
|
||||
|
||||
|
||||
def parse_workflow(workflow_json: Union[str, Dict],
                   load_extensions: bool = True,
                   extensions_dir: str = None) -> Dict[str, Any]:
    """Convenience wrapper around WorkflowParser.

    Builds a one-shot parser and runs it over the given workflow, so callers
    don't need to manage a parser instance themselves.

    Args:
        workflow_json: JSON string or dict containing the workflow
        load_extensions: Whether to load extensions automatically
        extensions_dir: Optional path to extensions directory

    Returns:
        Dict containing extracted generation parameters
    """
    one_shot = WorkflowParser(
        load_extensions=load_extensions,
        extensions_dir=extensions_dir,
    )
    return one_shot.parse_workflow(workflow_json)
|
||||
Reference in New Issue
Block a user