Implement KJNodes extension with new mappers and transform functions

- Added KJNodes mappers for JoinStrings, StringConstantMultiline, and EmptyLatentImagePresets.
- Introduced transform functions to handle string joining, string constants, and dimension extraction with optional inversion.
- Registered new mappers and logged successful registration for better traceability.
This commit is contained in:
Will Miao
2025-04-01 16:22:57 +08:00
parent 60575b6546
commit 195866b00d
2 changed files with 280 additions and 218 deletions

View File

@@ -0,0 +1,81 @@
"""
KJNodes mappers extension for ComfyUI workflow parsing
"""
import logging
import re
from typing import Dict, Any
logger = logging.getLogger(__name__)
# Import the mapper registration functions from the parent module
from workflow.mappers import create_mapper, register_mapper
# =============================================================================
# Transform Functions
# =============================================================================
def transform_join_strings(inputs: Dict) -> str:
"""Transform function for JoinStrings nodes"""
string1 = inputs.get("string1", "")
string2 = inputs.get("string2", "")
delimiter = inputs.get("delimiter", "")
return f"{string1}{delimiter}{string2}"
def transform_string_constant(inputs: Dict) -> str:
"""Transform function for StringConstant nodes"""
return inputs.get("string", "")
def transform_empty_latent_presets(inputs: Dict) -> Dict:
"""Transform function for EmptyLatentImagePresets nodes"""
dimensions = inputs.get("dimensions", "")
invert = inputs.get("invert", False)
# Extract width and height from dimensions string
# Expected format: "width x height (ratio)" or similar
width = 0
height = 0
if dimensions:
# Try to extract dimensions using regex
match = re.search(r'(\d+)\s*x\s*(\d+)', dimensions)
if match:
width = int(match.group(1))
height = int(match.group(2))
# If invert is True, swap width and height
if invert and width and height:
width, height = height, width
return {"width": width, "height": height, "size": f"{width}x{height}"}
# =============================================================================
# Register Mappers
# =============================================================================
# Define the mappers for KJNodes
KJNODES_MAPPERS = {
"JoinStrings": {
"inputs_to_track": ["string1", "string2", "delimiter"],
"transform_func": transform_join_strings
},
"StringConstantMultiline": {
"inputs_to_track": ["string"],
"transform_func": transform_string_constant
},
"EmptyLatentImagePresets": {
"inputs_to_track": ["dimensions", "invert", "batch_size"],
"transform_func": transform_empty_latent_presets
}
}
# Register all KJNodes mappers
for node_type, config in KJNODES_MAPPERS.items():
mapper = create_mapper(
node_type=node_type,
inputs_to_track=config["inputs_to_track"],
transform_func=config["transform_func"]
)
register_mapper(mapper)
logger.info(f"Registered KJNodes mapper for node type: {node_type}")
logger.info(f"Loaded KJNodes extension with {len(KJNODES_MAPPERS)} mappers")

View File

@@ -52,6 +52,7 @@ def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'Workflo
mapper = get_mapper(node_type) mapper = get_mapper(node_type)
if not mapper: if not mapper:
logger.warning(f"No mapper found for node type: {node_type}")
return None return None
result = {} result = {}
@@ -93,14 +94,11 @@ def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'Workflo
return result return result
# ============================================================================= # =============================================================================
# Default Mapper Definitions # Transform Functions
# ============================================================================= # =============================================================================
def register_default_mappers() -> None: def transform_ksampler(inputs: Dict) -> Dict:
"""Register all default mappers""" """Transform function for KSampler nodes"""
# KSampler mapper
def transform_ksampler(inputs: Dict) -> Dict:
result = { result = {
"seed": str(inputs.get("seed", "")), "seed": str(inputs.get("seed", "")),
"steps": str(inputs.get("steps", "")), "steps": str(inputs.get("steps", "")),
@@ -130,45 +128,18 @@ def register_default_mappers() -> None:
return result return result
register_mapper(create_mapper( def transform_empty_latent(inputs: Dict) -> Dict:
node_type="KSampler", """Transform function for EmptyLatentImage nodes"""
inputs_to_track=["seed", "steps", "cfg", "sampler_name", "scheduler",
"denoise", "positive", "negative", "latent_image",
"model", "clip_skip"],
transform_func=transform_ksampler
))
# EmptyLatentImage mapper
def transform_empty_latent(inputs: Dict) -> Dict:
width = inputs.get("width", 0) width = inputs.get("width", 0)
height = inputs.get("height", 0) height = inputs.get("height", 0)
return {"width": width, "height": height, "size": f"{width}x{height}"} return {"width": width, "height": height, "size": f"{width}x{height}"}
register_mapper(create_mapper( def transform_clip_text(inputs: Dict) -> Any:
node_type="EmptyLatentImage", """Transform function for CLIPTextEncode nodes"""
inputs_to_track=["width", "height", "batch_size"],
transform_func=transform_empty_latent
))
# SD3LatentImage mapper - reuses same transform function as EmptyLatentImage
register_mapper(create_mapper(
node_type="EmptySD3LatentImage",
inputs_to_track=["width", "height", "batch_size"],
transform_func=transform_empty_latent
))
# CLIPTextEncode mapper
def transform_clip_text(inputs: Dict) -> Any:
return inputs.get("text", "") return inputs.get("text", "")
register_mapper(create_mapper( def transform_lora_loader(inputs: Dict) -> Dict:
node_type="CLIPTextEncode", """Transform function for LoraLoader nodes"""
inputs_to_track=["text", "clip"],
transform_func=transform_clip_text
))
# LoraLoader mapper
def transform_lora_loader(inputs: Dict) -> Dict:
loras_data = inputs.get("loras", []) loras_data = inputs.get("loras", [])
lora_stack = inputs.get("lora_stack", {}).get("lora_stack", []) lora_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
@@ -199,14 +170,8 @@ def register_default_mappers() -> None:
return {"loras": " ".join(lora_texts)} return {"loras": " ".join(lora_texts)}
register_mapper(create_mapper( def transform_lora_stacker(inputs: Dict) -> Dict:
node_type="Lora Loader (LoraManager)", """Transform function for LoraStacker nodes"""
inputs_to_track=["loras", "lora_stack"],
transform_func=transform_lora_loader
))
# LoraStacker mapper
def transform_lora_stacker(inputs: Dict) -> Dict:
loras_data = inputs.get("loras", []) loras_data = inputs.get("loras", [])
result_stack = [] result_stack = []
@@ -241,37 +206,8 @@ def register_default_mappers() -> None:
return {"lora_stack": result_stack} return {"lora_stack": result_stack}
register_mapper(create_mapper( def transform_trigger_word_toggle(inputs: Dict) -> str:
node_type="Lora Stacker (LoraManager)", """Transform function for TriggerWordToggle nodes"""
inputs_to_track=["loras", "lora_stack"],
transform_func=transform_lora_stacker
))
# JoinStrings mapper
def transform_join_strings(inputs: Dict) -> str:
string1 = inputs.get("string1", "")
string2 = inputs.get("string2", "")
delimiter = inputs.get("delimiter", "")
return f"{string1}{delimiter}{string2}"
register_mapper(create_mapper(
node_type="JoinStrings",
inputs_to_track=["string1", "string2", "delimiter"],
transform_func=transform_join_strings
))
# StringConstant mapper
def transform_string_constant(inputs: Dict) -> str:
return inputs.get("string", "")
register_mapper(create_mapper(
node_type="StringConstantMultiline",
inputs_to_track=["string"],
transform_func=transform_string_constant
))
# TriggerWordToggle mapper
def transform_trigger_word_toggle(inputs: Dict) -> str:
toggle_data = inputs.get("toggle_trigger_words", []) toggle_data = inputs.get("toggle_trigger_words", [])
if isinstance(toggle_data, dict) and "__value__" in toggle_data: if isinstance(toggle_data, dict) and "__value__" in toggle_data:
@@ -291,14 +227,8 @@ def register_default_mappers() -> None:
return ", ".join(active_words) return ", ".join(active_words)
register_mapper(create_mapper( def transform_flux_guidance(inputs: Dict) -> Dict:
node_type="TriggerWord Toggle (LoraManager)", """Transform function for FluxGuidance nodes"""
inputs_to_track=["toggle_trigger_words"],
transform_func=transform_trigger_word_toggle
))
# FluxGuidance mapper
def transform_flux_guidance(inputs: Dict) -> Dict:
result = {} result = {}
if "guidance" in inputs: if "guidance" in inputs:
@@ -313,11 +243,62 @@ def register_default_mappers() -> None:
return result return result
register_mapper(create_mapper( # =============================================================================
node_type="FluxGuidance", # Node Mapper Definitions
inputs_to_track=["guidance", "conditioning"], # =============================================================================
transform_func=transform_flux_guidance
)) # Central definition of all supported node types and their configurations
NODE_MAPPERS = {
# ComfyUI core nodes
"KSampler": {
"inputs_to_track": [
"seed", "steps", "cfg", "sampler_name", "scheduler",
"denoise", "positive", "negative", "latent_image",
"model", "clip_skip"
],
"transform_func": transform_ksampler
},
"EmptyLatentImage": {
"inputs_to_track": ["width", "height", "batch_size"],
"transform_func": transform_empty_latent
},
"EmptySD3LatentImage": {
"inputs_to_track": ["width", "height", "batch_size"],
"transform_func": transform_empty_latent
},
"CLIPTextEncode": {
"inputs_to_track": ["text", "clip"],
"transform_func": transform_clip_text
},
"FluxGuidance": {
"inputs_to_track": ["guidance", "conditioning"],
"transform_func": transform_flux_guidance
},
# LoraManager nodes
"Lora Loader (LoraManager)": {
"inputs_to_track": ["loras", "lora_stack"],
"transform_func": transform_lora_loader
},
"Lora Stacker (LoraManager)": {
"inputs_to_track": ["loras", "lora_stack"],
"transform_func": transform_lora_stacker
},
"TriggerWord Toggle (LoraManager)": {
"inputs_to_track": ["toggle_trigger_words"],
"transform_func": transform_trigger_word_toggle
}
}
def register_default_mappers() -> None:
"""Register all default mappers from the NODE_MAPPERS dictionary"""
for node_type, config in NODE_MAPPERS.items():
mapper = create_mapper(
node_type=node_type,
inputs_to_track=config["inputs_to_track"],
transform_func=config["transform_func"]
)
register_mapper(mapper)
logger.info(f"Registered {len(NODE_MAPPERS)} default node mappers")
# ============================================================================= # =============================================================================
# Extension Loading # Extension Loading