Enhance workflow parsing and node mapper registration

- Introduced a new WorkflowParser class to streamline workflow parsing and manage node mappers.
- Added functionality to load external mappers dynamically from a specified directory.
- Refactored LoraLoaderMapper and LoraStackerMapper to handle new data formats for loras and trigger words.
- Updated recipe routes to utilize the new WorkflowParser for parsing workflows.
- Made adjustments to the flux_prompt.json to reflect changes in active states and class types.
This commit is contained in:
Will Miao
2025-03-23 05:21:43 +08:00
parent 3da35cf0db
commit 2b67091986
7 changed files with 463 additions and 84 deletions

View File

@@ -2,10 +2,16 @@
Node mappers for ComfyUI workflow parsing
"""
import logging
from typing import Dict, List, Any, Optional, Union
import os
import importlib.util
import inspect
from typing import Dict, List, Any, Optional, Union, Type, Callable
logger = logging.getLogger(__name__)
# Global mapper registry
_MAPPER_REGISTRY: Dict[str, 'NodeMapper'] = {}
class NodeMapper:
"""Base class for node mappers that define how to extract information from a specific node type"""
@@ -130,32 +136,43 @@ class LoraLoaderMapper(NodeMapper):
def __init__(self):
super().__init__(
node_type="Lora Loader (LoraManager)",
inputs_to_track=["text", "loras", "lora_stack"]
inputs_to_track=["loras", "lora_stack"]
)
def transform(self, inputs: Dict) -> Dict:
lora_text = inputs.get("text", "")
# Fallback to loras array if text field doesn't exist or is invalid
loras_data = inputs.get("loras", [])
lora_stack = inputs.get("lora_stack", [])
# Process lora_stack if it exists
stack_text = ""
if lora_stack:
# Handle the formatted lora_stack info if available
stack_loras = []
for lora_path, strength, _ in lora_stack:
lora_name = lora_path.split(os.sep)[-1].split('.')[0]
stack_loras.append(f"<lora:{lora_name}:{strength}>")
stack_text = " ".join(stack_loras)
# Combine lora_text and stack_text
combined_text = lora_text
if stack_text:
combined_text = f"{combined_text} {stack_text}" if combined_text else stack_text
# Process loras array - filter active entries
lora_texts = []
# Format loras with spaces between them
if combined_text:
# Replace consecutive closing and opening tags with a space
combined_text = combined_text.replace("><", "> <")
# Check if loras_data is a list or a dict with __value__ key (new format)
if isinstance(loras_data, dict) and "__value__" in loras_data:
loras_list = loras_data["__value__"]
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
# Process each active lora entry
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
strength = lora.get("strength", 1.0)
if lora_name and not lora_name.startswith("__dummy"):
lora_texts.append(f"<lora:{lora_name}:{strength}>")
# Process lora_stack if it exists
if lora_stack:
# Format each entry from the stack
for lora_path, strength, _ in lora_stack:
lora_name = os.path.basename(lora_path).split('.')[0]
if lora_name and not lora_name.startswith("__dummy"):
lora_texts.append(f"<lora:{lora_name}:{strength}>")
# Join with spaces
combined_text = " ".join(lora_texts)
return {"loras": combined_text}
@@ -170,8 +187,34 @@ class LoraStackerMapper(NodeMapper):
)
def transform(self, inputs: Dict) -> Dict:
# Return the lora_stack information
return inputs.get("lora_stack", [])
loras_data = inputs.get("loras", [])
existing_stack = inputs.get("lora_stack", [])
result_stack = []
# Keep existing stack entries
if existing_stack:
result_stack.extend(existing_stack)
# Process loras array - filter active entries
# Check if loras_data is a list or a dict with __value__ key (new format)
if isinstance(loras_data, dict) and "__value__" in loras_data:
loras_list = loras_data["__value__"]
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
# Process each active lora entry
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
strength = float(lora.get("strength", 1.0))
if lora_name and not lora_name.startswith("__dummy"):
# Here we would need the real path, but as a fallback use the name
# In a real implementation, this would require looking up the file path
result_stack.append((lora_name, strength, strength))
return {"lora_stack": result_stack}
class JoinStringsMapper(NodeMapper):
@@ -209,19 +252,31 @@ class TriggerWordToggleMapper(NodeMapper):
def __init__(self):
super().__init__(
node_type="TriggerWord Toggle (LoraManager)",
inputs_to_track=["toggle_trigger_words", "orinalMessage", "trigger_words"]
inputs_to_track=["toggle_trigger_words"]
)
def transform(self, inputs: Dict) -> str:
# Get the original message or toggled trigger words
original_message = inputs.get("orinalMessage", "") or inputs.get("trigger_words", "")
toggle_data = inputs.get("toggle_trigger_words", [])
# check if toggle_words is a list or a dict with __value__ key (new format)
if isinstance(toggle_data, dict) and "__value__" in toggle_data:
toggle_words = toggle_data["__value__"]
elif isinstance(toggle_data, list):
toggle_words = toggle_data
else:
toggle_words = []
# Fix double commas to match the reference output format
if original_message:
# Replace double commas with single commas
original_message = original_message.replace(",, ", ", ")
return original_message
# Filter active trigger words
active_words = []
for item in toggle_words:
if isinstance(item, dict) and item.get("active", False):
word = item.get("text", "")
if word and not word.startswith("__dummy"):
active_words.append(word)
# Join with commas
result = ", ".join(active_words)
return result
class FluxGuidanceMapper(NodeMapper):
@@ -251,5 +306,89 @@ class FluxGuidanceMapper(NodeMapper):
return result
# Add import os for LoraLoaderMapper to work properly
import os
# =============================================================================
# Mapper Registry Functions
# =============================================================================
def register_mapper(mapper: NodeMapper) -> None:
"""Register a node mapper in the global registry"""
_MAPPER_REGISTRY[mapper.node_type] = mapper
logger.debug(f"Registered mapper for node type: {mapper.node_type}")
def get_mapper(node_type: str) -> Optional[NodeMapper]:
"""Get a mapper for the specified node type"""
return _MAPPER_REGISTRY.get(node_type)
def get_all_mappers() -> Dict[str, NodeMapper]:
"""Get all registered mappers"""
return _MAPPER_REGISTRY.copy()
def register_default_mappers() -> None:
"""Register all default mappers"""
default_mappers = [
KSamplerMapper(),
EmptyLatentImageMapper(),
EmptySD3LatentImageMapper(),
CLIPTextEncodeMapper(),
LoraLoaderMapper(),
LoraStackerMapper(),
JoinStringsMapper(),
StringConstantMapper(),
TriggerWordToggleMapper(),
FluxGuidanceMapper()
]
for mapper in default_mappers:
register_mapper(mapper)
# =============================================================================
# Extension Loading
# =============================================================================
def load_extensions(ext_dir: str = None) -> None:
"""
Load mapper extensions from the specified directory
Each Python file in the directory will be loaded, and any NodeMapper subclasses
defined in those files will be automatically registered.
"""
# Use default path if none provided
if ext_dir is None:
# Get the directory of this file
current_dir = os.path.dirname(os.path.abspath(__file__))
ext_dir = os.path.join(current_dir, 'ext')
# Ensure the extension directory exists
if not os.path.exists(ext_dir):
os.makedirs(ext_dir, exist_ok=True)
logger.info(f"Created extension directory: {ext_dir}")
return
# Load each Python file in the extension directory
for filename in os.listdir(ext_dir):
if filename.endswith('.py') and not filename.startswith('_'):
module_path = os.path.join(ext_dir, filename)
module_name = f"workflow.ext.{filename[:-3]}" # Remove .py
try:
# Load the module
spec = importlib.util.spec_from_file_location(module_name, module_path)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Find all NodeMapper subclasses in the module
for name, obj in inspect.getmembers(module):
if (inspect.isclass(obj) and issubclass(obj, NodeMapper)
and obj != NodeMapper and hasattr(obj, 'node_type')):
# Instantiate and register the mapper
mapper = obj()
register_mapper(mapper)
logger.info(f"Loaded extension mapper: {mapper.node_type} from {filename}")
except Exception as e:
logger.error(f"Error loading extension {filename}: {e}")
# Initialize the registry with default mappers
register_default_mappers()