From 195866b00d2c002fb9c7e578e2e6d798300f06e6 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Tue, 1 Apr 2025 16:22:57 +0800 Subject: [PATCH] Implement KJNodes extension with new mappers and transform functions - Added KJNodes mappers for JoinStrings, StringConstantMultiline, and EmptyLatentImagePresets. - Introduced transform functions to handle string joining, string constants, and dimension extraction with optional inversion. - Registered new mappers and logged successful registration for better traceability. --- py/workflow/ext/kjnodes.py | 81 +++++++ py/workflow/mappers.py | 417 ++++++++++++++++++------------------- 2 files changed, 280 insertions(+), 218 deletions(-) create mode 100644 py/workflow/ext/kjnodes.py diff --git a/py/workflow/ext/kjnodes.py b/py/workflow/ext/kjnodes.py new file mode 100644 index 00000000..76e45edf --- /dev/null +++ b/py/workflow/ext/kjnodes.py @@ -0,0 +1,81 @@ +""" +KJNodes mappers extension for ComfyUI workflow parsing +""" +import logging +import re +from typing import Dict, Any + +logger = logging.getLogger(__name__) + +# Import the mapper registration functions from the parent module +from workflow.mappers import create_mapper, register_mapper + +# ============================================================================= +# Transform Functions +# ============================================================================= + +def transform_join_strings(inputs: Dict) -> str: + """Transform function for JoinStrings nodes""" + string1 = inputs.get("string1", "") + string2 = inputs.get("string2", "") + delimiter = inputs.get("delimiter", "") + return f"{string1}{delimiter}{string2}" + +def transform_string_constant(inputs: Dict) -> str: + """Transform function for StringConstant nodes""" + return inputs.get("string", "") + +def transform_empty_latent_presets(inputs: Dict) -> Dict: + """Transform function for EmptyLatentImagePresets nodes""" + dimensions = inputs.get("dimensions", "") + invert = inputs.get("invert", False) + + # Extract width and height from dimensions string + # Expected format: "width x height (ratio)" or similar + width = 0 + height = 0 + + if dimensions: + # Try to extract dimensions using regex + match = re.search(r'(\d+)\s*x\s*(\d+)', dimensions) + if match: + width = int(match.group(1)) + height = int(match.group(2)) + + # If invert is True, swap width and height + if invert and width and height: + width, height = height, width + + return {"width": width, "height": height, "size": f"{width}x{height}"} + +# ============================================================================= +# Register Mappers +# ============================================================================= + +# Define the mappers for KJNodes +KJNODES_MAPPERS = { + "JoinStrings": { + "inputs_to_track": ["string1", "string2", "delimiter"], + "transform_func": transform_join_strings + }, + "StringConstantMultiline": { + "inputs_to_track": ["string"], + "transform_func": transform_string_constant + }, + "EmptyLatentImagePresets": { + "inputs_to_track": ["dimensions", "invert", "batch_size"], + "transform_func": transform_empty_latent_presets + } +} + +# Register all KJNodes mappers +for node_type, config in KJNODES_MAPPERS.items(): + mapper = create_mapper( + node_type=node_type, + inputs_to_track=config["inputs_to_track"], + transform_func=config["transform_func"] + ) + register_mapper(mapper) + logger.info(f"Registered KJNodes mapper for node type: {node_type}") + +logger.info(f"Loaded KJNodes extension with {len(KJNODES_MAPPERS)} mappers") \ No newline at end of file diff --git a/py/workflow/mappers.py b/py/workflow/mappers.py index 383b1b4a..96fb02a2 100644 --- a/py/workflow/mappers.py +++ b/py/workflow/mappers.py @@ -52,6 +52,7 @@ def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'Workflo mapper = get_mapper(node_type) if not mapper: + logger.warning(f"No mapper found for node type: {node_type}") return None result = {} @@ -93,231 +94,211 @@ def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'Workflo return result # ============================================================================= -# Default Mapper Definitions +# Transform Functions # ============================================================================= -def register_default_mappers() -> None: - """Register all default mappers""" +def transform_ksampler(inputs: Dict) -> Dict: + """Transform function for KSampler nodes""" + result = { + "seed": str(inputs.get("seed", "")), + "steps": str(inputs.get("steps", "")), + "cfg": str(inputs.get("cfg", "")), + "sampler": inputs.get("sampler_name", ""), + "scheduler": inputs.get("scheduler", ""), + } - # KSampler mapper - def transform_ksampler(inputs: Dict) -> Dict: - result = { - "seed": str(inputs.get("seed", "")), - "steps": str(inputs.get("steps", "")), - "cfg": str(inputs.get("cfg", "")), - "sampler": inputs.get("sampler_name", ""), - "scheduler": inputs.get("scheduler", ""), - } + # Process positive prompt + if "positive" in inputs: + result["prompt"] = inputs["positive"] + + # Process negative prompt + if "negative" in inputs: + result["negative_prompt"] = inputs["negative"] + + # Get dimensions from latent image + if "latent_image" in inputs and isinstance(inputs["latent_image"], dict): + width = inputs["latent_image"].get("width", 0) + height = inputs["latent_image"].get("height", 0) + if width and height: + result["size"] = f"{width}x{height}" + + # Add clip_skip if present + if "clip_skip" in inputs: + result["clip_skip"] = str(inputs.get("clip_skip", "")) - # Process positive prompt - if "positive" in inputs: - result["prompt"] = inputs["positive"] - - # Process negative prompt - if "negative" in inputs: - result["negative_prompt"] = inputs["negative"] - - # Get dimensions from latent image - if "latent_image" in inputs and isinstance(inputs["latent_image"], dict): - width = inputs["latent_image"].get("width", 0) - height = inputs["latent_image"].get("height", 0) - if width and height: - result["size"] = f"{width}x{height}" - - # Add clip_skip if present - if "clip_skip" in inputs: - result["clip_skip"] = str(inputs.get("clip_skip", "")) - - return result - - register_mapper(create_mapper( - node_type="KSampler", - inputs_to_track=["seed", "steps", "cfg", "sampler_name", "scheduler", - "denoise", "positive", "negative", "latent_image", - "model", "clip_skip"], - transform_func=transform_ksampler - )) - - # EmptyLatentImage mapper - def transform_empty_latent(inputs: Dict) -> Dict: - width = inputs.get("width", 0) - height = inputs.get("height", 0) - return {"width": width, "height": height, "size": f"{width}x{height}"} - - register_mapper(create_mapper( - node_type="EmptyLatentImage", - inputs_to_track=["width", "height", "batch_size"], - transform_func=transform_empty_latent - )) - - # SD3LatentImage mapper - reuses same transform function as EmptyLatentImage - register_mapper(create_mapper( - node_type="EmptySD3LatentImage", - inputs_to_track=["width", "height", "batch_size"], - transform_func=transform_empty_latent - )) - - # CLIPTextEncode mapper - def transform_clip_text(inputs: Dict) -> Any: - return inputs.get("text", "") - - register_mapper(create_mapper( - node_type="CLIPTextEncode", - inputs_to_track=["text", "clip"], - transform_func=transform_clip_text - )) - - # LoraLoader mapper - def transform_lora_loader(inputs: Dict) -> Dict: - loras_data = inputs.get("loras", []) - lora_stack = inputs.get("lora_stack", {}).get("lora_stack", []) - - lora_texts = [] - - # Process loras array - if isinstance(loras_data, dict) and "__value__" in loras_data: - loras_list = loras_data["__value__"] - elif isinstance(loras_data, list): - loras_list = loras_data - else: - loras_list = [] - - # Process each active lora entry - for lora in loras_list: - if isinstance(lora, dict) and lora.get("active", False): - lora_name = lora.get("name", "") - strength = lora.get("strength", 1.0) - lora_texts.append(f"") - - # Process lora_stack if valid - if lora_stack and isinstance(lora_stack, list): - if not (len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int)): - for stack_entry in lora_stack: - lora_name = stack_entry[0] - strength = stack_entry[1] - lora_texts.append(f"") - - return {"loras": " ".join(lora_texts)} - - register_mapper(create_mapper( - node_type="Lora Loader (LoraManager)", - inputs_to_track=["loras", "lora_stack"], - transform_func=transform_lora_loader - )) - - # LoraStacker mapper - def transform_lora_stacker(inputs: Dict) -> Dict: - loras_data = inputs.get("loras", []) - result_stack = [] - - # Handle existing stack entries - existing_stack = [] - lora_stack_input = inputs.get("lora_stack", []) - - if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input: - existing_stack = lora_stack_input["lora_stack"] - elif isinstance(lora_stack_input, list): - if not (len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and - isinstance(lora_stack_input[1], int)): - existing_stack = lora_stack_input - - # Add existing entries - if existing_stack: - result_stack.extend(existing_stack) - - # Process new loras - if isinstance(loras_data, dict) and "__value__" in loras_data: - loras_list = loras_data["__value__"] - elif isinstance(loras_data, list): - loras_list = loras_data - else: - loras_list = [] - - for lora in loras_list: - if isinstance(lora, dict) and lora.get("active", False): - lora_name = lora.get("name", "") - strength = float(lora.get("strength", 1.0)) - result_stack.append((lora_name, strength)) - - return {"lora_stack": result_stack} - - register_mapper(create_mapper( - node_type="Lora Stacker (LoraManager)", - inputs_to_track=["loras", "lora_stack"], - transform_func=transform_lora_stacker - )) - - # JoinStrings mapper - def transform_join_strings(inputs: Dict) -> str: - string1 = inputs.get("string1", "") - string2 = inputs.get("string2", "") - delimiter = inputs.get("delimiter", "") - return f"{string1}{delimiter}{string2}" - - register_mapper(create_mapper( - node_type="JoinStrings", - inputs_to_track=["string1", "string2", "delimiter"], - transform_func=transform_join_strings - )) - - # StringConstant mapper - def transform_string_constant(inputs: Dict) -> str: - return inputs.get("string", "") - - register_mapper(create_mapper( - node_type="StringConstantMultiline", - inputs_to_track=["string"], - transform_func=transform_string_constant - )) - - # TriggerWordToggle mapper - def transform_trigger_word_toggle(inputs: Dict) -> str: - toggle_data = inputs.get("toggle_trigger_words", []) + return result - if isinstance(toggle_data, dict) and "__value__" in toggle_data: - toggle_words = toggle_data["__value__"] - elif isinstance(toggle_data, list): - toggle_words = toggle_data +def transform_empty_latent(inputs: Dict) -> Dict: + """Transform function for EmptyLatentImage nodes""" + width = inputs.get("width", 0) + height = inputs.get("height", 0) + return {"width": width, "height": height, "size": f"{width}x{height}"} + +def transform_clip_text(inputs: Dict) -> Any: + """Transform function for CLIPTextEncode nodes""" + return inputs.get("text", "") + +def transform_lora_loader(inputs: Dict) -> Dict: + """Transform function for LoraLoader nodes""" + loras_data = inputs.get("loras", []) + lora_stack = inputs.get("lora_stack", {}).get("lora_stack", []) + + lora_texts = [] + + # Process loras array + if isinstance(loras_data, dict) and "__value__" in loras_data: + loras_list = loras_data["__value__"] + elif isinstance(loras_data, list): + loras_list = loras_data + else: + loras_list = [] + + # Process each active lora entry + for lora in loras_list: + if isinstance(lora, dict) and lora.get("active", False): + lora_name = lora.get("name", "") + strength = lora.get("strength", 1.0) + lora_texts.append(f"") + + # Process lora_stack if valid + if lora_stack and isinstance(lora_stack, list): + if not (len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int)): + for stack_entry in lora_stack: + lora_name = stack_entry[0] + strength = stack_entry[1] + lora_texts.append(f"") + + return {"loras": " ".join(lora_texts)} + +def transform_lora_stacker(inputs: Dict) -> Dict: + """Transform function for LoraStacker nodes""" + loras_data = inputs.get("loras", []) + result_stack = [] + + # Handle existing stack entries + existing_stack = [] + lora_stack_input = inputs.get("lora_stack", []) + + if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input: + existing_stack = lora_stack_input["lora_stack"] + elif isinstance(lora_stack_input, list): + if not (len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and + isinstance(lora_stack_input[1], int)): + existing_stack = lora_stack_input + + # Add existing entries + if existing_stack: + result_stack.extend(existing_stack) + + # Process new loras + if isinstance(loras_data, dict) and "__value__" in loras_data: + loras_list = loras_data["__value__"] + elif isinstance(loras_data, list): + loras_list = loras_data + else: + loras_list = [] + + for lora in loras_list: + if isinstance(lora, dict) and lora.get("active", False): + lora_name = lora.get("name", "") + strength = float(lora.get("strength", 1.0)) + result_stack.append((lora_name, strength)) + + return {"lora_stack": result_stack} + +def transform_trigger_word_toggle(inputs: Dict) -> str: + """Transform function for TriggerWordToggle nodes""" + toggle_data = inputs.get("toggle_trigger_words", []) + + if isinstance(toggle_data, dict) and "__value__" in toggle_data: + toggle_words = toggle_data["__value__"] + elif isinstance(toggle_data, list): + toggle_words = toggle_data + else: + toggle_words = [] + + # Filter active trigger words + active_words = [] + for item in toggle_words: + if isinstance(item, dict) and item.get("active", False): + word = item.get("text", "") + if word and not word.startswith("__dummy"): + active_words.append(word) + + return ", ".join(active_words) + +def transform_flux_guidance(inputs: Dict) -> Dict: + """Transform function for FluxGuidance nodes""" + result = {} + + if "guidance" in inputs: + result["guidance"] = inputs["guidance"] + + if "conditioning" in inputs: + conditioning = inputs["conditioning"] + if isinstance(conditioning, str): + result["prompt"] = conditioning else: - toggle_words = [] - - # Filter active trigger words - active_words = [] - for item in toggle_words: - if isinstance(item, dict) and item.get("active", False): - word = item.get("text", "") - if word and not word.startswith("__dummy"): - active_words.append(word) - - return ", ".join(active_words) + result["prompt"] = "Unknown prompt" - register_mapper(create_mapper( - node_type="TriggerWord Toggle (LoraManager)", - inputs_to_track=["toggle_trigger_words"], - transform_func=transform_trigger_word_toggle - )) - - # FluxGuidance mapper - def transform_flux_guidance(inputs: Dict) -> Dict: - result = {} - - if "guidance" in inputs: - result["guidance"] = inputs["guidance"] - - if "conditioning" in inputs: - conditioning = inputs["conditioning"] - if isinstance(conditioning, str): - result["prompt"] = conditioning - else: - result["prompt"] = "Unknown prompt" - - return result - - register_mapper(create_mapper( - node_type="FluxGuidance", - inputs_to_track=["guidance", "conditioning"], - transform_func=transform_flux_guidance - )) + return result + +# ============================================================================= +# Node Mapper Definitions +# ============================================================================= + +# Central definition of all supported node types and their configurations +NODE_MAPPERS = { + # ComfyUI core nodes + "KSampler": { + "inputs_to_track": [ + "seed", "steps", "cfg", "sampler_name", "scheduler", + "denoise", "positive", "negative", "latent_image", + "model", "clip_skip" + ], + "transform_func": transform_ksampler + }, + "EmptyLatentImage": { + "inputs_to_track": ["width", "height", "batch_size"], + "transform_func": transform_empty_latent + }, + "EmptySD3LatentImage": { + "inputs_to_track": ["width", "height", "batch_size"], + "transform_func": transform_empty_latent + }, + "CLIPTextEncode": { + "inputs_to_track": ["text", "clip"], + "transform_func": transform_clip_text + }, + "FluxGuidance": { + "inputs_to_track": ["guidance", "conditioning"], + "transform_func": transform_flux_guidance + }, + # LoraManager nodes + "Lora Loader (LoraManager)": { + "inputs_to_track": ["loras", "lora_stack"], + "transform_func": transform_lora_loader + }, + "Lora Stacker (LoraManager)": { + "inputs_to_track": ["loras", "lora_stack"], + "transform_func": transform_lora_stacker + }, + "TriggerWord Toggle (LoraManager)": { + "inputs_to_track": ["toggle_trigger_words"], + "transform_func": transform_trigger_word_toggle + } +} + +def register_default_mappers() -> None: + """Register all default mappers from the NODE_MAPPERS dictionary""" + for node_type, config in NODE_MAPPERS.items(): + mapper = create_mapper( + node_type=node_type, + inputs_to_track=config["inputs_to_track"], + transform_func=config["transform_func"] + ) + register_mapper(mapper) + logger.info(f"Registered {len(NODE_MAPPERS)} default node mappers") # ============================================================================= # Extension Loading