From e3bf1f763c962487ce2be50f027959e2ea7d4fa7 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 7 May 2025 17:13:30 +0800 Subject: [PATCH] refactor: remove workflow parsing module and associated files for cleanup --- py/server_routes.py | 26 --- py/workflow/__init__.py | 3 - py/workflow/cli.py | 58 ------- py/workflow/ext/__init__.py | 3 - py/workflow/ext/comfyui_core.py | 285 -------------------------------- py/workflow/ext/kjnodes.py | 74 --------- py/workflow/main.py | 37 ----- py/workflow/mappers.py | 282 ------------------------------- py/workflow/parser.py | 181 -------------------- py/workflow/test.py | 63 ------- py/workflow/utils.py | 120 -------------- 11 files changed, 1132 deletions(-) delete mode 100644 py/server_routes.py delete mode 100644 py/workflow/__init__.py delete mode 100644 py/workflow/cli.py delete mode 100644 py/workflow/ext/__init__.py delete mode 100644 py/workflow/ext/comfyui_core.py delete mode 100644 py/workflow/ext/kjnodes.py delete mode 100644 py/workflow/main.py delete mode 100644 py/workflow/mappers.py delete mode 100644 py/workflow/parser.py delete mode 100644 py/workflow/test.py delete mode 100644 py/workflow/utils.py diff --git a/py/server_routes.py b/py/server_routes.py deleted file mode 100644 index 68ee9749..00000000 --- a/py/server_routes.py +++ /dev/null @@ -1,26 +0,0 @@ -from aiohttp import web -from server import PromptServer -from .nodes.utils import get_lora_info - -@PromptServer.instance.routes.post("/loramanager/get_trigger_words") -async def get_trigger_words(request): - json_data = await request.json() - lora_names = json_data.get("lora_names", []) - node_ids = json_data.get("node_ids", []) - - all_trigger_words = [] - for lora_name in lora_names: - _, trigger_words = await get_lora_info(lora_name) - all_trigger_words.extend(trigger_words) - - # Format the trigger words - trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" - - # Send update to all connected trigger word toggle nodes - for node_id in node_ids: - PromptServer.instance.send_sync("trigger_word_update", { - "id": node_id, - "message": trigger_words_text - }) - - return web.json_response({"success": True}) diff --git a/py/workflow/__init__.py b/py/workflow/__init__.py deleted file mode 100644 index 5bb0929b..00000000 --- a/py/workflow/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -ComfyUI workflow parsing module to extract generation parameters -""" \ No newline at end of file diff --git a/py/workflow/cli.py b/py/workflow/cli.py deleted file mode 100644 index ab39ed4a..00000000 --- a/py/workflow/cli.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -Command-line interface for the ComfyUI workflow parser -""" -import argparse -import json -import os -import logging -import sys -from .parser import parse_workflow - -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] -) -logger = logging.getLogger(__name__) - -def main(): - """Entry point for the CLI""" - parser = argparse.ArgumentParser(description='Parse ComfyUI workflow files') - parser.add_argument('input', help='Input workflow JSON file path') - parser.add_argument('-o', '--output', help='Output JSON file path') - parser.add_argument('-p', '--pretty', action='store_true', help='Pretty print JSON output') - parser.add_argument('--debug', action='store_true', help='Enable debug logging') - - args = parser.parse_args() - - # Set logging level - if args.debug: - logging.getLogger().setLevel(logging.DEBUG) - - # Validate input file - if not os.path.isfile(args.input): - logger.error(f"Input file not found: {args.input}") - sys.exit(1) - - # Parse workflow - try: - result = parse_workflow(args.input, args.output) - - # Print result to console if output file not specified - if not args.output: - if args.pretty: - print(json.dumps(result, indent=4)) - else: - print(json.dumps(result)) - else: - logger.info(f"Output saved to: {args.output}") - - except Exception as e: - logger.error(f"Error parsing workflow: {e}") - if args.debug: - import traceback - traceback.print_exc() - sys.exit(1) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/py/workflow/ext/__init__.py b/py/workflow/ext/__init__.py deleted file mode 100644 index 86e11ab6..00000000 --- a/py/workflow/ext/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Extension directory for custom node mappers -""" \ No newline at end of file diff --git a/py/workflow/ext/comfyui_core.py b/py/workflow/ext/comfyui_core.py deleted file mode 100644 index 5a116d59..00000000 --- a/py/workflow/ext/comfyui_core.py +++ /dev/null @@ -1,285 +0,0 @@ -""" -ComfyUI Core nodes mappers extension for workflow parsing -""" -import logging -from typing import Dict, Any, List - -logger = logging.getLogger(__name__) - -# ============================================================================= -# Transform Functions -# ============================================================================= - -def transform_random_noise(inputs: Dict) -> Dict: - """Transform function for RandomNoise node""" - return {"seed": str(inputs.get("noise_seed", ""))} - -def transform_ksampler_select(inputs: Dict) -> Dict: - """Transform function for KSamplerSelect node""" - return {"sampler": inputs.get("sampler_name", "")} - -def transform_basic_scheduler(inputs: Dict) -> Dict: - """Transform function for BasicScheduler node""" - result = { - "scheduler": inputs.get("scheduler", ""), - "denoise": str(inputs.get("denoise", "1.0")) - } - - # Get steps from inputs or steps input - if "steps" in inputs: - if isinstance(inputs["steps"], str): - result["steps"] = inputs["steps"] - elif isinstance(inputs["steps"], dict) and "value" in inputs["steps"]: - result["steps"] = str(inputs["steps"]["value"]) - else: - result["steps"] = str(inputs["steps"]) - - return result - -def transform_basic_guider(inputs: Dict) -> Dict: - """Transform function for BasicGuider node""" - result = {} - - # Process conditioning - if "conditioning" in inputs: - if isinstance(inputs["conditioning"], str): - result["prompt"] = inputs["conditioning"] - elif isinstance(inputs["conditioning"], dict): - result["conditioning"] = inputs["conditioning"] - - # Get model information if needed - if "model" in inputs and isinstance(inputs["model"], dict): - result["model"] = inputs["model"] - - return result - -def transform_model_sampling_flux(inputs: Dict) -> Dict: - """Transform function for ModelSamplingFlux - mostly a pass-through node""" - # This node is primarily used for routing, so we mostly pass through values - - return inputs["model"] - -def transform_sampler_custom_advanced(inputs: Dict) -> Dict: - """Transform function for SamplerCustomAdvanced node""" - result = {} - - # Extract seed from noise - if "noise" in inputs and isinstance(inputs["noise"], dict): - result["seed"] = str(inputs["noise"].get("seed", "")) - - # Extract sampler info - if "sampler" in inputs and isinstance(inputs["sampler"], dict): - sampler = inputs["sampler"].get("sampler", "") - if sampler: - result["sampler"] = sampler - - # Extract scheduler, steps, denoise from sigmas - if "sigmas" in inputs and isinstance(inputs["sigmas"], dict): - sigmas = inputs["sigmas"] - result["scheduler"] = sigmas.get("scheduler", "") - result["steps"] = str(sigmas.get("steps", "")) - result["denoise"] = str(sigmas.get("denoise", "1.0")) - - # Extract prompt and guidance from guider - if "guider" in inputs and isinstance(inputs["guider"], dict): - guider = inputs["guider"] - - # Get prompt from conditioning - if "conditioning" in guider and isinstance(guider["conditioning"], str): - result["prompt"] = guider["conditioning"] - elif "conditioning" in guider and isinstance(guider["conditioning"], dict): - result["guidance"] = guider["conditioning"].get("guidance", "") - result["prompt"] = guider["conditioning"].get("prompt", "") - - if "model" in guider and isinstance(guider["model"], dict): - result["checkpoint"] = guider["model"].get("checkpoint", "") - result["loras"] = guider["model"].get("loras", "") - result["clip_skip"] = str(int(guider["model"].get("clip_skip", "-1")) * -1) - - # Extract dimensions from latent_image - if "latent_image" in inputs and isinstance(inputs["latent_image"], dict): - latent = inputs["latent_image"] - width = latent.get("width", 0) - height = latent.get("height", 0) - if width and height: - result["width"] = width - result["height"] = height - result["size"] = f"{width}x{height}" - - return result - -def transform_ksampler(inputs: Dict) -> Dict: - """Transform function for KSampler nodes""" - result = { - "seed": str(inputs.get("seed", "")), - "steps": str(inputs.get("steps", "")), - "cfg": str(inputs.get("cfg", "")), - "sampler": inputs.get("sampler_name", ""), - "scheduler": inputs.get("scheduler", ""), - } - - # Process positive prompt - if "positive" in inputs: - result["prompt"] = inputs["positive"] - - # Process negative prompt - if "negative" in inputs: - result["negative_prompt"] = inputs["negative"] - - # Get dimensions from latent image - if "latent_image" in inputs and isinstance(inputs["latent_image"], dict): - width = inputs["latent_image"].get("width", 0) - height = inputs["latent_image"].get("height", 0) - if width and height: - result["size"] = f"{width}x{height}" - - # Add clip_skip if present - if "clip_skip" in inputs: - result["clip_skip"] = str(inputs.get("clip_skip", "")) - - # Add guidance if present - if "guidance" in inputs: - result["guidance"] = str(inputs.get("guidance", "")) - - # Add model if present - if "model" in inputs: - result["checkpoint"] = inputs.get("model", {}).get("checkpoint", "") - result["loras"] = inputs.get("model", {}).get("loras", "") - result["clip_skip"] = str(inputs.get("model", {}).get("clip_skip", -1) * -1) - - return result - -def transform_empty_latent(inputs: Dict) -> Dict: - """Transform function for EmptyLatentImage nodes""" - width = inputs.get("width", 0) - height = inputs.get("height", 0) - return {"width": width, "height": height, "size": f"{width}x{height}"} - -def transform_clip_text(inputs: Dict) -> Any: - """Transform function for CLIPTextEncode nodes""" - return inputs.get("text", "") - -def transform_flux_guidance(inputs: Dict) -> Dict: - """Transform function for FluxGuidance nodes""" - result = {} - - if "guidance" in inputs: - result["guidance"] = inputs["guidance"] - - if "conditioning" in inputs: - conditioning = inputs["conditioning"] - if isinstance(conditioning, str): - result["prompt"] = conditioning - else: - result["prompt"] = "Unknown prompt" - - return result - -def transform_unet_loader(inputs: Dict) -> Dict: - """Transform function for UNETLoader node""" - unet_name = inputs.get("unet_name", "") - return {"checkpoint": unet_name} if unet_name else {} - -def transform_checkpoint_loader(inputs: Dict) -> Dict: - """Transform function for CheckpointLoaderSimple node""" - ckpt_name = inputs.get("ckpt_name", "") - return {"checkpoint": ckpt_name} if ckpt_name else {} - -def transform_latent_upscale_by(inputs: Dict) -> Dict: - """Transform function for LatentUpscaleBy node""" - result = {} - - width = inputs["samples"].get("width", 0) * inputs["scale_by"] - height = inputs["samples"].get("height", 0) * inputs["scale_by"] - result["width"] = width - result["height"] = height - result["size"] = f"{width}x{height}" - - return result - -def transform_clip_set_last_layer(inputs: Dict) -> Dict: - """Transform function for CLIPSetLastLayer node""" - result = {} - - if "stop_at_clip_layer" in inputs: - result["clip_skip"] = inputs["stop_at_clip_layer"] - - return result - -# ============================================================================= -# Node Mapper Definitions -# ============================================================================= - -# Define the mappers for ComfyUI core nodes not in main mapper -NODE_MAPPERS_EXT = { - # KSamplers - "SamplerCustomAdvanced": { - "inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"], - "transform_func": transform_sampler_custom_advanced - }, - "KSampler": { - "inputs_to_track": [ - "seed", "steps", "cfg", "sampler_name", "scheduler", - "denoise", "positive", "negative", "latent_image", - "model", "clip_skip" - ], - "transform_func": transform_ksampler - }, - # ComfyUI core nodes - "EmptyLatentImage": { - "inputs_to_track": ["width", "height", "batch_size"], - "transform_func": transform_empty_latent - }, - "EmptySD3LatentImage": { - "inputs_to_track": ["width", "height", "batch_size"], - "transform_func": transform_empty_latent - }, - "CLIPTextEncode": { - "inputs_to_track": ["text", "clip"], - "transform_func": transform_clip_text - }, - "FluxGuidance": { - "inputs_to_track": ["guidance", "conditioning"], - "transform_func": transform_flux_guidance - }, - "RandomNoise": { - "inputs_to_track": ["noise_seed"], - "transform_func": transform_random_noise - }, - "KSamplerSelect": { - "inputs_to_track": ["sampler_name"], - "transform_func": transform_ksampler_select - }, - "BasicScheduler": { - "inputs_to_track": ["scheduler", "steps", "denoise", "model"], - "transform_func": transform_basic_scheduler - }, - "BasicGuider": { - "inputs_to_track": ["model", "conditioning"], - "transform_func": transform_basic_guider - }, - "ModelSamplingFlux": { - "inputs_to_track": ["max_shift", "base_shift", "width", "height", "model"], - "transform_func": transform_model_sampling_flux - }, - "UNETLoader": { - "inputs_to_track": ["unet_name"], - "transform_func": transform_unet_loader - }, - "CheckpointLoaderSimple": { - "inputs_to_track": ["ckpt_name"], - "transform_func": transform_checkpoint_loader - }, - "LatentUpscale": { - "inputs_to_track": ["width", "height"], - "transform_func": transform_empty_latent - }, - "LatentUpscaleBy": { - "inputs_to_track": ["samples", "scale_by"], - "transform_func": transform_latent_upscale_by - }, - "CLIPSetLastLayer": { - "inputs_to_track": ["clip", "stop_at_clip_layer"], - "transform_func": transform_clip_set_last_layer - } -} \ No newline at end of file diff --git a/py/workflow/ext/kjnodes.py b/py/workflow/ext/kjnodes.py deleted file mode 100644 index 8ea99d2c..00000000 --- a/py/workflow/ext/kjnodes.py +++ /dev/null @@ -1,74 +0,0 @@ -""" -KJNodes mappers extension for ComfyUI workflow parsing -""" -import logging -import re -from typing import Dict, Any - -logger = logging.getLogger(__name__) - -# ============================================================================= -# Transform Functions -# ============================================================================= - -def transform_join_strings(inputs: Dict) -> str: - """Transform function for JoinStrings nodes""" - string1 = inputs.get("string1", "") - string2 = inputs.get("string2", "") - delimiter = inputs.get("delimiter", "") - return f"{string1}{delimiter}{string2}" - -def transform_string_constant(inputs: Dict) -> str: - """Transform function for StringConstant nodes""" - return inputs.get("string", "") - -def transform_empty_latent_presets(inputs: Dict) -> Dict: - """Transform function for EmptyLatentImagePresets nodes""" - dimensions = inputs.get("dimensions", "") - invert = inputs.get("invert", False) - - # Extract width and height from dimensions string - # Expected format: "width x height (ratio)" or similar - width = 0 - height = 0 - - if dimensions: - # Try to extract dimensions using regex - match = re.search(r'(\d+)\s*x\s*(\d+)', dimensions) - if match: - width = int(match.group(1)) - height = int(match.group(2)) - - # If invert is True, swap width and height - if invert and width and height: - width, height = height, width - - return {"width": width, "height": height, "size": f"{width}x{height}"} - -def transform_int_constant(inputs: Dict) -> int: - """Transform function for INTConstant nodes""" - return inputs.get("value", 0) - -# ============================================================================= -# Node Mapper Definitions -# ============================================================================= - -# Define the mappers for KJNodes -NODE_MAPPERS_EXT = { - "JoinStrings": { - "inputs_to_track": ["string1", "string2", "delimiter"], - "transform_func": transform_join_strings - }, - "StringConstantMultiline": { - "inputs_to_track": ["string"], - "transform_func": transform_string_constant - }, - "EmptyLatentImagePresets": { - "inputs_to_track": ["dimensions", "invert", "batch_size"], - "transform_func": transform_empty_latent_presets - }, - "INTConstant": { - "inputs_to_track": ["value"], - "transform_func": transform_int_constant - } -} \ No newline at end of file diff --git a/py/workflow/main.py b/py/workflow/main.py deleted file mode 100644 index 2f46591d..00000000 --- a/py/workflow/main.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Main entry point for the workflow parser module -""" -import os -import sys -import logging -from typing import Dict, Optional, Union - -# Add the parent directory to sys.path to enable imports -SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) -ROOT_DIR = os.path.abspath(os.path.join(SCRIPT_DIR, '..', '..')) -sys.path.insert(0, os.path.dirname(SCRIPT_DIR)) - -from .parser import parse_workflow - -logger = logging.getLogger(__name__) - -def parse_comfyui_workflow( - workflow_path: str, - output_path: Optional[str] = None -) -> Dict: - """ - Parse a ComfyUI workflow file and extract generation parameters - - Args: - workflow_path: Path to the workflow JSON file - output_path: Optional path to save the output JSON - - Returns: - Dictionary containing extracted parameters - """ - return parse_workflow(workflow_path, output_path) - -if __name__ == "__main__": - # If run directly, use the CLI - from .cli import main - main() \ No newline at end of file diff --git a/py/workflow/mappers.py b/py/workflow/mappers.py deleted file mode 100644 index 528aeefb..00000000 --- a/py/workflow/mappers.py +++ /dev/null @@ -1,282 +0,0 @@ -""" -Node mappers for ComfyUI workflow parsing -""" -import logging -import os -import importlib.util -import inspect -from typing import Dict, List, Any, Optional, Union, Type, Callable, Tuple - -logger = logging.getLogger(__name__) - -# Global mapper registry -_MAPPER_REGISTRY: Dict[str, Dict] = {} - -# ============================================================================= -# Mapper Definition Functions -# ============================================================================= - -def create_mapper( - node_type: str, - inputs_to_track: List[str], - transform_func: Callable[[Dict], Any] = None -) -> Dict: - """Create a mapper definition for a node type""" - mapper = { - "node_type": node_type, - "inputs_to_track": inputs_to_track, - "transform": transform_func or (lambda inputs: inputs) - } - return mapper - -def register_mapper(mapper: Dict) -> None: - """Register a node mapper in the global registry""" - _MAPPER_REGISTRY[mapper["node_type"]] = mapper - logger.debug(f"Registered mapper for node type: {mapper['node_type']}") - -def get_mapper(node_type: str) -> Optional[Dict]: - """Get a mapper for the specified node type""" - return _MAPPER_REGISTRY.get(node_type) - -def get_all_mappers() -> Dict[str, Dict]: - """Get all registered mappers""" - return _MAPPER_REGISTRY.copy() - -# ============================================================================= -# Node Processing Function -# ============================================================================= - -def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'WorkflowParser') -> Any: # type: ignore - """Process a node using its mapper and extract relevant information""" - node_type = node_data.get("class_type") - mapper = get_mapper(node_type) - - if not mapper: - logger.warning(f"No mapper found for node type: {node_type}") - return None - - result = {} - - # Extract inputs based on the mapper's tracked inputs - for input_name in mapper["inputs_to_track"]: - if input_name in node_data.get("inputs", {}): - input_value = node_data["inputs"][input_name] - - # Check if input is a reference to another node's output - if isinstance(input_value, list) and len(input_value) == 2: - try: - # Format is [node_id, output_slot] - ref_node_id, output_slot = input_value - # Convert node_id to string if it's an integer - if isinstance(ref_node_id, int): - ref_node_id = str(ref_node_id) - - # Recursively process the referenced node - ref_value = parser.process_node(ref_node_id, workflow) - - if ref_value is not None: - result[input_name] = ref_value - else: - # If we couldn't get a value from the reference, store the raw value - result[input_name] = input_value - except Exception as e: - logger.error(f"Error processing reference in node {node_id}, input {input_name}: {e}") - result[input_name] = input_value - else: - # Direct value - result[input_name] = input_value - - # Apply the transform function - try: - return mapper["transform"](result) - except Exception as e: - logger.error(f"Error in transform function for node {node_id} of type {node_type}: {e}") - return result - -# ============================================================================= -# Transform Functions -# ============================================================================= - - - -def transform_lora_loader(inputs: Dict) -> Dict: - """Transform function for LoraLoader nodes""" - loras_data = inputs.get("loras", []) - lora_stack = inputs.get("lora_stack", {}).get("lora_stack", []) - - lora_texts = [] - - # Process loras array - if isinstance(loras_data, dict) and "__value__" in loras_data: - loras_list = loras_data["__value__"] - elif isinstance(loras_data, list): - loras_list = loras_data - else: - loras_list = [] - - # Process each active lora entry - for lora in loras_list: - if isinstance(lora, dict) and lora.get("active", False): - lora_name = lora.get("name", "") - strength = lora.get("strength", 1.0) - lora_texts.append(f"") - - # Process lora_stack if valid - if lora_stack and isinstance(lora_stack, list): - if not (len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int)): - for stack_entry in lora_stack: - lora_name = stack_entry[0] - strength = stack_entry[1] - lora_texts.append(f"") - - result = { - "checkpoint": inputs.get("model", {}).get("checkpoint", ""), - "loras": " ".join(lora_texts) - } - - if "clip" in inputs and isinstance(inputs["clip"], dict): - result["clip_skip"] = inputs["clip"].get("clip_skip", "-1") - - return result - -def transform_lora_stacker(inputs: Dict) -> Dict: - """Transform function for LoraStacker nodes""" - loras_data = inputs.get("loras", []) - result_stack = [] - - # Handle existing stack entries - existing_stack = [] - lora_stack_input = inputs.get("lora_stack", []) - - if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input: - existing_stack = lora_stack_input["lora_stack"] - elif isinstance(lora_stack_input, list): - if not (len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and - isinstance(lora_stack_input[1], int)): - existing_stack = lora_stack_input - - # Add existing entries - if existing_stack: - result_stack.extend(existing_stack) - - # Process new loras - if isinstance(loras_data, dict) and "__value__" in loras_data: - loras_list = loras_data["__value__"] - elif isinstance(loras_data, list): - loras_list = loras_data - else: - loras_list = [] - - for lora in loras_list: - if isinstance(lora, dict) and lora.get("active", False): - lora_name = lora.get("name", "") - strength = float(lora.get("strength", 1.0)) - result_stack.append((lora_name, strength)) - - return {"lora_stack": result_stack} - -def transform_trigger_word_toggle(inputs: Dict) -> str: - """Transform function for TriggerWordToggle nodes""" - toggle_data = inputs.get("toggle_trigger_words", []) - - if isinstance(toggle_data, dict) and "__value__" in toggle_data: - toggle_words = toggle_data["__value__"] - elif isinstance(toggle_data, list): - toggle_words = toggle_data - else: - toggle_words = [] - - # Filter active trigger words - active_words = [] - for item in toggle_words: - if isinstance(item, dict) and item.get("active", False): - word = item.get("text", "") - if word and not word.startswith("__dummy"): - active_words.append(word) - - return ", ".join(active_words) - -# ============================================================================= -# Node Mapper Definitions -# ============================================================================= - -# Central definition of all supported node types and their configurations -NODE_MAPPERS = { - - # LoraManager nodes - "Lora Loader (LoraManager)": { - "inputs_to_track": ["model", "clip", "loras", "lora_stack"], - "transform_func": transform_lora_loader - }, - "Lora Stacker (LoraManager)": { - "inputs_to_track": ["loras", "lora_stack"], - "transform_func": transform_lora_stacker - }, - "TriggerWord Toggle (LoraManager)": { - "inputs_to_track": ["toggle_trigger_words"], - "transform_func": transform_trigger_word_toggle - } -} - -def register_all_mappers() -> None: - """Register all mappers from the NODE_MAPPERS dictionary""" - for node_type, config in NODE_MAPPERS.items(): - mapper = create_mapper( - node_type=node_type, - inputs_to_track=config["inputs_to_track"], - transform_func=config["transform_func"] - ) - register_mapper(mapper) - logger.info(f"Registered {len(NODE_MAPPERS)} node mappers") - -# ============================================================================= -# Extension Loading -# ============================================================================= - -def load_extensions(ext_dir: str = None) -> None: - """ - Load mapper extensions from the specified directory - - Extension files should define a NODE_MAPPERS_EXT dictionary containing mapper configurations. - These will be added to the global NODE_MAPPERS dictionary and registered automatically. - """ - # Use default path if none provided - if ext_dir is None: - # Get the directory of this file - current_dir = os.path.dirname(os.path.abspath(__file__)) - ext_dir = os.path.join(current_dir, 'ext') - - # Ensure the extension directory exists - if not os.path.exists(ext_dir): - os.makedirs(ext_dir, exist_ok=True) - logger.info(f"Created extension directory: {ext_dir}") - return - - # Load each Python file in the extension directory - for filename in os.listdir(ext_dir): - if filename.endswith('.py') and not filename.startswith('_'): - module_path = os.path.join(ext_dir, filename) - module_name = f"workflow.ext.{filename[:-3]}" # Remove .py - - try: - # Load the module - spec = importlib.util.spec_from_file_location(module_name, module_path) - if spec and spec.loader: - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Check if the module defines NODE_MAPPERS_EXT - if hasattr(module, 'NODE_MAPPERS_EXT'): - # Add the extension mappers to the global NODE_MAPPERS dictionary - NODE_MAPPERS.update(module.NODE_MAPPERS_EXT) - logger.info(f"Added {len(module.NODE_MAPPERS_EXT)} mappers from extension: {filename}") - else: - logger.warning(f"Extension {filename} does not define NODE_MAPPERS_EXT dictionary") - except Exception as e: - logger.warning(f"Error loading extension {filename}: {e}") - - # Re-register all mappers after loading extensions - register_all_mappers() - -# Initialize the registry with default mappers -# register_default_mappers() \ No newline at end of file diff --git a/py/workflow/parser.py b/py/workflow/parser.py deleted file mode 100644 index 0a5a02ef..00000000 --- a/py/workflow/parser.py +++ /dev/null @@ -1,181 +0,0 @@ -""" -Main workflow parser implementation for ComfyUI -""" -import json -import logging -from typing import Dict, List, Any, Optional, Union, Set -from .mappers import get_mapper, get_all_mappers, load_extensions, process_node -from .utils import ( - load_workflow, save_output, find_node_by_type, - trace_model_path -) - -logger = logging.getLogger(__name__) - -class WorkflowParser: - """Parser for ComfyUI workflows""" - - def __init__(self): - """Initialize the parser with mappers""" - self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles - self.node_results_cache: Dict[str, Any] = {} # Cache for processed node results - - # Load extensions - load_extensions() - - def process_node(self, node_id: str, workflow: Dict) -> Any: - """Process a single node and extract relevant information""" - # Return cached result if available - if node_id in self.node_results_cache: - return self.node_results_cache[node_id] - - # Check if we're in a cycle - if node_id in self.processed_nodes: - return None - - # Mark this node as being processed (to detect cycles) - self.processed_nodes.add(node_id) - - if node_id not in workflow: - self.processed_nodes.remove(node_id) - return None - - node_data = workflow[node_id] - node_type = node_data.get("class_type") - - result = None - if get_mapper(node_type): - try: - result = process_node(node_id, node_data, workflow, self) - # Cache the result - self.node_results_cache[node_id] = result - except Exception as e: - logger.error(f"Error processing node {node_id} of type {node_type}: {e}", exc_info=True) - # Return a partial result or None depending on how we want to handle errors - result = {} - - # Remove node from processed set to allow it to be processed again in a different context - self.processed_nodes.remove(node_id) - return result - - def find_primary_sampler_node(self, workflow: Dict) -> Optional[str]: - """ - Find the primary sampler node in the workflow. - - Priority: - 1. First try to find a SamplerCustomAdvanced node - 2. If not found, look for KSampler nodes with denoise=1.0 - 3. If still not found, use the first KSampler node - - Args: - workflow: The workflow data as a dictionary - - Returns: - The node ID of the primary sampler node, or None if not found - """ - # First check for SamplerCustomAdvanced nodes - sampler_advanced_nodes = [] - ksampler_nodes = [] - - # Scan workflow for sampler nodes - for node_id, node_data in workflow.items(): - node_type = node_data.get("class_type") - - if node_type == "SamplerCustomAdvanced": - sampler_advanced_nodes.append(node_id) - elif node_type == "KSampler": - ksampler_nodes.append(node_id) - - # If we found SamplerCustomAdvanced nodes, return the first one - if sampler_advanced_nodes: - logger.debug(f"Found SamplerCustomAdvanced node: {sampler_advanced_nodes[0]}") - return sampler_advanced_nodes[0] - - # If we have KSampler nodes, look for one with denoise=1.0 - if ksampler_nodes: - for node_id in ksampler_nodes: - node_data = workflow[node_id] - inputs = node_data.get("inputs", {}) - denoise = inputs.get("denoise", 0) - - # Check if denoise is 1.0 (allowing for small floating point differences) - if abs(float(denoise) - 1.0) < 0.001: - logger.debug(f"Found KSampler node with denoise=1.0: {node_id}") - return node_id - - # If no KSampler with denoise=1.0 found, use the first one - logger.debug(f"No KSampler with denoise=1.0 found, using first KSampler: {ksampler_nodes[0]}") - return ksampler_nodes[0] - - # No sampler nodes found - logger.warning("No sampler nodes found in workflow") - return None - - def parse_workflow(self, workflow_data: Union[str, Dict], output_path: Optional[str] = None) -> Dict: - """ - Parse the workflow and extract generation parameters - - Args: - workflow_data: The workflow data as a dictionary or a file path - output_path: Optional path to save the output JSON - - Returns: - Dictionary containing extracted parameters - """ - # Load workflow from file if needed - if isinstance(workflow_data, str): - workflow = load_workflow(workflow_data) - else: - workflow = workflow_data - - # Reset the processed nodes tracker and cache - self.processed_nodes = set() - self.node_results_cache = {} - - # Find the primary sampler node - sampler_node_id = self.find_primary_sampler_node(workflow) - if not sampler_node_id: - logger.warning("No suitable sampler node found in workflow") - return {} - - # Process sampler node to extract parameters - sampler_result = self.process_node(sampler_node_id, workflow) - if not sampler_result: - return {} - - # Return the sampler result directly - it's already in the format we need - # This simplifies the structure and makes it easier to use in recipe_routes.py - - # Handle standard ComfyUI names vs our output format - if "cfg" in sampler_result: - sampler_result["cfg_scale"] = sampler_result.pop("cfg") - - # Add clip_skip = 1 to match reference output if not already present - if "clip_skip" not in sampler_result: - sampler_result["clip_skip"] = "1" - - # Ensure the prompt is a string and not a nested dictionary - if "prompt" in sampler_result and isinstance(sampler_result["prompt"], dict): - if "prompt" in sampler_result["prompt"]: - sampler_result["prompt"] = sampler_result["prompt"]["prompt"] - - # Save the result if requested - if output_path: - save_output(sampler_result, output_path) - - return sampler_result - - -def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dict: - """ - Parse a ComfyUI workflow file and extract generation parameters - - Args: - workflow_path: Path to the workflow JSON file - output_path: Optional path to save the output JSON - - Returns: - Dictionary containing extracted parameters - """ - parser = WorkflowParser() - return parser.parse_workflow(workflow_path, output_path) \ No newline at end of file diff --git a/py/workflow/test.py b/py/workflow/test.py deleted file mode 100644 index 0b14673e..00000000 --- a/py/workflow/test.py +++ /dev/null @@ -1,63 +0,0 @@ -""" -Test script for the ComfyUI workflow parser -""" -import os -import json -import logging -from .parser import parse_workflow - -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] -) -logger = logging.getLogger(__name__) - -# Configure paths -SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) -ROOT_DIR = os.path.abspath(os.path.join(SCRIPT_DIR, '..', '..')) -REFS_DIR = os.path.join(ROOT_DIR, 'refs') -OUTPUT_DIR = os.path.join(ROOT_DIR, 'output') - -def test_parse_flux_workflow(): - """Test parsing the flux example workflow""" - # Ensure output directory exists - os.makedirs(OUTPUT_DIR, exist_ok=True) - - # Define input and output paths - input_path = os.path.join(REFS_DIR, 'flux_prompt.json') - output_path = os.path.join(OUTPUT_DIR, 'parsed_flux_output.json') - - # Parse workflow - logger.info(f"Parsing workflow: {input_path}") - result = parse_workflow(input_path, output_path) - - # Print result summary - logger.info(f"Output saved to: {output_path}") - logger.info(f"Parsing completed. Result summary:") - logger.info(f" LoRAs: {result.get('loras', '')}") - - gen_params = result.get('gen_params', {}) - logger.info(f" Prompt: {gen_params.get('prompt', '')[:50]}...") - logger.info(f" Steps: {gen_params.get('steps', '')}") - logger.info(f" Sampler: {gen_params.get('sampler', '')}") - logger.info(f" Size: {gen_params.get('size', '')}") - - # Compare with reference output - ref_output_path = os.path.join(REFS_DIR, 'flux_output.json') - try: - with open(ref_output_path, 'r') as f: - ref_output = json.load(f) - - # Simple validation - loras_match = result.get('loras', '') == ref_output.get('loras', '') - prompt_match = gen_params.get('prompt', '') == ref_output.get('gen_params', {}).get('prompt', '') - - logger.info(f"Validation against reference:") - logger.info(f" LoRAs match: {loras_match}") - logger.info(f" Prompt match: {prompt_match}") - except Exception as e: - logger.warning(f"Failed to compare with reference output: {e}") - -if __name__ == "__main__": - test_parse_flux_workflow() \ No newline at end of file diff --git a/py/workflow/utils.py b/py/workflow/utils.py deleted file mode 100644 index aaa333ea..00000000 --- a/py/workflow/utils.py +++ /dev/null @@ -1,120 +0,0 @@ -""" -Utility functions for ComfyUI workflow parsing -""" -import json -import os -import logging -from typing import Dict, List, Any, Optional, Union, Set, Tuple - -logger = logging.getLogger(__name__) - -def load_workflow(workflow_path: str) -> Dict: - """Load a workflow from a JSON file""" - try: - with open(workflow_path, 'r', encoding='utf-8') as f: - return json.load(f) - except Exception as e: - logger.error(f"Error loading workflow from {workflow_path}: {e}") - raise - -def save_output(output: Dict, output_path: str) -> None: - """Save the parsed output to a JSON file""" - os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True) - try: - with open(output_path, 'w', encoding='utf-8') as f: - json.dump(output, f, indent=4) - except Exception as e: - logger.error(f"Error saving output to {output_path}: {e}") - raise - -def find_node_by_type(workflow: Dict, node_type: str) -> Optional[str]: - """Find a node of the specified type in the workflow""" - for node_id, node_data in workflow.items(): - if node_data.get("class_type") == node_type: - return node_id - return None - -def find_nodes_by_type(workflow: Dict, node_type: str) -> List[str]: - """Find all nodes of the specified type in the workflow""" - return [node_id for node_id, node_data in workflow.items() - if node_data.get("class_type") == node_type] - -def get_input_node_ids(workflow: Dict, node_id: str) -> Dict[str, Tuple[str, int]]: - """ - Get the node IDs for all inputs of the given node - - Returns a dictionary mapping input names to (node_id, output_slot) tuples - """ - result = {} - if node_id not in workflow: - return result - - node_data = workflow[node_id] - for input_name, input_value in node_data.get("inputs", {}).items(): - # Check if this input is connected to another node - if isinstance(input_value, list) and len(input_value) == 2: - # Input is connected to another node's output - # Format: [node_id, output_slot] - ref_node_id, output_slot = input_value - result[input_name] = (str(ref_node_id), output_slot) - - return result - -def trace_model_path(workflow: Dict, start_node_id: str) -> List[str]: - """ - Trace the model path backward from KSampler to find all LoRA nodes - - Args: - workflow: The workflow data - start_node_id: The starting node ID (usually KSampler) - - Returns: - List of node IDs in the model path - """ - model_path_nodes = [] - - # Get the model input from the start node - if start_node_id not in workflow: - return model_path_nodes - - # Track visited nodes to avoid cycles - visited = set() - - # Stack for depth-first search - stack = [] - - # Get model input reference if available - start_node = workflow[start_node_id] - if "inputs" in start_node and "model" in start_node["inputs"] and isinstance(start_node["inputs"]["model"], list): - model_ref = start_node["inputs"]["model"] - stack.append(str(model_ref[0])) - - # Perform depth-first search - while stack: - node_id = stack.pop() - - # Skip if already visited - if node_id in visited: - continue - - # Mark as visited - visited.add(node_id) - - # Skip if node doesn't exist - if node_id not in workflow: - continue - - node = workflow[node_id] - node_type = node.get("class_type", "") - - # Add current node to result list if it's a LoRA node - if "Lora" in node_type: - model_path_nodes.append(node_id) - - # Add all input nodes that have a "model" or "lora_stack" output to the stack - if "inputs" in node: - for input_name, input_value in node["inputs"].items(): - if input_name in ["model", "lora_stack"] and isinstance(input_value, list) and len(input_value) == 2: - stack.append(str(input_value[0])) - - return model_path_nodes \ No newline at end of file