mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
Enhance workflow parsing and node mapper registration
- Introduced a new WorkflowParser class to streamline workflow parsing and manage node mappers. - Added functionality to load external mappers dynamically from a specified directory. - Refactored LoraLoaderMapper and LoraStackerMapper to handle new data formats for loras and trigger words. - Updated recipe routes to utilize the new WorkflowParser for parsing workflows. - Made adjustments to the flux_prompt.json to reflect changes in active states and class types.
This commit is contained in:
@@ -2,10 +2,16 @@
|
||||
Node mappers for ComfyUI workflow parsing
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
import os
|
||||
import importlib.util
|
||||
import inspect
|
||||
from typing import Dict, List, Any, Optional, Union, Type, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global mapper registry
|
||||
_MAPPER_REGISTRY: Dict[str, 'NodeMapper'] = {}
|
||||
|
||||
class NodeMapper:
|
||||
"""Base class for node mappers that define how to extract information from a specific node type"""
|
||||
|
||||
@@ -130,32 +136,43 @@ class LoraLoaderMapper(NodeMapper):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
node_type="Lora Loader (LoraManager)",
|
||||
inputs_to_track=["text", "loras", "lora_stack"]
|
||||
inputs_to_track=["loras", "lora_stack"]
|
||||
)
|
||||
|
||||
def transform(self, inputs: Dict) -> Dict:
|
||||
lora_text = inputs.get("text", "")
|
||||
# Fallback to loras array if text field doesn't exist or is invalid
|
||||
loras_data = inputs.get("loras", [])
|
||||
lora_stack = inputs.get("lora_stack", [])
|
||||
|
||||
# Process lora_stack if it exists
|
||||
stack_text = ""
|
||||
if lora_stack:
|
||||
# Handle the formatted lora_stack info if available
|
||||
stack_loras = []
|
||||
for lora_path, strength, _ in lora_stack:
|
||||
lora_name = lora_path.split(os.sep)[-1].split('.')[0]
|
||||
stack_loras.append(f"<lora:{lora_name}:{strength}>")
|
||||
stack_text = " ".join(stack_loras)
|
||||
|
||||
# Combine lora_text and stack_text
|
||||
combined_text = lora_text
|
||||
if stack_text:
|
||||
combined_text = f"{combined_text} {stack_text}" if combined_text else stack_text
|
||||
# Process loras array - filter active entries
|
||||
lora_texts = []
|
||||
|
||||
# Format loras with spaces between them
|
||||
if combined_text:
|
||||
# Replace consecutive closing and opening tags with a space
|
||||
combined_text = combined_text.replace("><", "> <")
|
||||
# Check if loras_data is a list or a dict with __value__ key (new format)
|
||||
if isinstance(loras_data, dict) and "__value__" in loras_data:
|
||||
loras_list = loras_data["__value__"]
|
||||
elif isinstance(loras_data, list):
|
||||
loras_list = loras_data
|
||||
else:
|
||||
loras_list = []
|
||||
|
||||
# Process each active lora entry
|
||||
for lora in loras_list:
|
||||
if isinstance(lora, dict) and lora.get("active", False):
|
||||
lora_name = lora.get("name", "")
|
||||
strength = lora.get("strength", 1.0)
|
||||
if lora_name and not lora_name.startswith("__dummy"):
|
||||
lora_texts.append(f"<lora:{lora_name}:{strength}>")
|
||||
|
||||
# Process lora_stack if it exists
|
||||
if lora_stack:
|
||||
# Format each entry from the stack
|
||||
for lora_path, strength, _ in lora_stack:
|
||||
lora_name = os.path.basename(lora_path).split('.')[0]
|
||||
if lora_name and not lora_name.startswith("__dummy"):
|
||||
lora_texts.append(f"<lora:{lora_name}:{strength}>")
|
||||
|
||||
# Join with spaces
|
||||
combined_text = " ".join(lora_texts)
|
||||
|
||||
return {"loras": combined_text}
|
||||
|
||||
@@ -170,8 +187,34 @@ class LoraStackerMapper(NodeMapper):
|
||||
)
|
||||
|
||||
def transform(self, inputs: Dict) -> Dict:
|
||||
# Return the lora_stack information
|
||||
return inputs.get("lora_stack", [])
|
||||
loras_data = inputs.get("loras", [])
|
||||
existing_stack = inputs.get("lora_stack", [])
|
||||
result_stack = []
|
||||
|
||||
# Keep existing stack entries
|
||||
if existing_stack:
|
||||
result_stack.extend(existing_stack)
|
||||
|
||||
# Process loras array - filter active entries
|
||||
# Check if loras_data is a list or a dict with __value__ key (new format)
|
||||
if isinstance(loras_data, dict) and "__value__" in loras_data:
|
||||
loras_list = loras_data["__value__"]
|
||||
elif isinstance(loras_data, list):
|
||||
loras_list = loras_data
|
||||
else:
|
||||
loras_list = []
|
||||
|
||||
# Process each active lora entry
|
||||
for lora in loras_list:
|
||||
if isinstance(lora, dict) and lora.get("active", False):
|
||||
lora_name = lora.get("name", "")
|
||||
strength = float(lora.get("strength", 1.0))
|
||||
if lora_name and not lora_name.startswith("__dummy"):
|
||||
# Here we would need the real path, but as a fallback use the name
|
||||
# In a real implementation, this would require looking up the file path
|
||||
result_stack.append((lora_name, strength, strength))
|
||||
|
||||
return {"lora_stack": result_stack}
|
||||
|
||||
|
||||
class JoinStringsMapper(NodeMapper):
|
||||
@@ -209,19 +252,31 @@ class TriggerWordToggleMapper(NodeMapper):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
node_type="TriggerWord Toggle (LoraManager)",
|
||||
inputs_to_track=["toggle_trigger_words", "orinalMessage", "trigger_words"]
|
||||
inputs_to_track=["toggle_trigger_words"]
|
||||
)
|
||||
|
||||
def transform(self, inputs: Dict) -> str:
|
||||
# Get the original message or toggled trigger words
|
||||
original_message = inputs.get("orinalMessage", "") or inputs.get("trigger_words", "")
|
||||
toggle_data = inputs.get("toggle_trigger_words", [])
|
||||
|
||||
# check if toggle_words is a list or a dict with __value__ key (new format)
|
||||
if isinstance(toggle_data, dict) and "__value__" in toggle_data:
|
||||
toggle_words = toggle_data["__value__"]
|
||||
elif isinstance(toggle_data, list):
|
||||
toggle_words = toggle_data
|
||||
else:
|
||||
toggle_words = []
|
||||
|
||||
# Fix double commas to match the reference output format
|
||||
if original_message:
|
||||
# Replace double commas with single commas
|
||||
original_message = original_message.replace(",, ", ", ")
|
||||
|
||||
return original_message
|
||||
# Filter active trigger words
|
||||
active_words = []
|
||||
for item in toggle_words:
|
||||
if isinstance(item, dict) and item.get("active", False):
|
||||
word = item.get("text", "")
|
||||
if word and not word.startswith("__dummy"):
|
||||
active_words.append(word)
|
||||
|
||||
# Join with commas
|
||||
result = ", ".join(active_words)
|
||||
return result
|
||||
|
||||
|
||||
class FluxGuidanceMapper(NodeMapper):
|
||||
@@ -251,5 +306,89 @@ class FluxGuidanceMapper(NodeMapper):
|
||||
return result
|
||||
|
||||
|
||||
# Add import os for LoraLoaderMapper to work properly
|
||||
import os
|
||||
# =============================================================================
|
||||
# Mapper Registry Functions
|
||||
# =============================================================================
|
||||
|
||||
def register_mapper(mapper: NodeMapper) -> None:
|
||||
"""Register a node mapper in the global registry"""
|
||||
_MAPPER_REGISTRY[mapper.node_type] = mapper
|
||||
logger.debug(f"Registered mapper for node type: {mapper.node_type}")
|
||||
|
||||
def get_mapper(node_type: str) -> Optional[NodeMapper]:
|
||||
"""Get a mapper for the specified node type"""
|
||||
return _MAPPER_REGISTRY.get(node_type)
|
||||
|
||||
def get_all_mappers() -> Dict[str, NodeMapper]:
|
||||
"""Get all registered mappers"""
|
||||
return _MAPPER_REGISTRY.copy()
|
||||
|
||||
def register_default_mappers() -> None:
|
||||
"""Register all default mappers"""
|
||||
default_mappers = [
|
||||
KSamplerMapper(),
|
||||
EmptyLatentImageMapper(),
|
||||
EmptySD3LatentImageMapper(),
|
||||
CLIPTextEncodeMapper(),
|
||||
LoraLoaderMapper(),
|
||||
LoraStackerMapper(),
|
||||
JoinStringsMapper(),
|
||||
StringConstantMapper(),
|
||||
TriggerWordToggleMapper(),
|
||||
FluxGuidanceMapper()
|
||||
]
|
||||
|
||||
for mapper in default_mappers:
|
||||
register_mapper(mapper)
|
||||
|
||||
# =============================================================================
|
||||
# Extension Loading
|
||||
# =============================================================================
|
||||
|
||||
def load_extensions(ext_dir: str = None) -> None:
|
||||
"""
|
||||
Load mapper extensions from the specified directory
|
||||
|
||||
Each Python file in the directory will be loaded, and any NodeMapper subclasses
|
||||
defined in those files will be automatically registered.
|
||||
"""
|
||||
# Use default path if none provided
|
||||
if ext_dir is None:
|
||||
# Get the directory of this file
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
ext_dir = os.path.join(current_dir, 'ext')
|
||||
|
||||
# Ensure the extension directory exists
|
||||
if not os.path.exists(ext_dir):
|
||||
os.makedirs(ext_dir, exist_ok=True)
|
||||
logger.info(f"Created extension directory: {ext_dir}")
|
||||
return
|
||||
|
||||
# Load each Python file in the extension directory
|
||||
for filename in os.listdir(ext_dir):
|
||||
if filename.endswith('.py') and not filename.startswith('_'):
|
||||
module_path = os.path.join(ext_dir, filename)
|
||||
module_name = f"workflow.ext.{filename[:-3]}" # Remove .py
|
||||
|
||||
try:
|
||||
# Load the module
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
if spec and spec.loader:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Find all NodeMapper subclasses in the module
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if (inspect.isclass(obj) and issubclass(obj, NodeMapper)
|
||||
and obj != NodeMapper and hasattr(obj, 'node_type')):
|
||||
# Instantiate and register the mapper
|
||||
mapper = obj()
|
||||
register_mapper(mapper)
|
||||
logger.info(f"Loaded extension mapper: {mapper.node_type} from {filename}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading extension {filename}: {e}")
|
||||
|
||||
|
||||
# Initialize the registry with default mappers
|
||||
register_default_mappers()
|
||||
Reference in New Issue
Block a user