refactor: remove workflow parsing module and associated files for cleanup

This commit is contained in:
Will Miao
2025-05-07 17:13:30 +08:00
parent 1c6e9d0b69
commit e3bf1f763c
11 changed files with 0 additions and 1132 deletions

View File

@@ -1,26 +0,0 @@
from aiohttp import web
from server import PromptServer
from .nodes.utils import get_lora_info

@PromptServer.instance.routes.post("/loramanager/get_trigger_words")
async def get_trigger_words(request):
    json_data = await request.json()
    lora_names = json_data.get("lora_names", [])
    node_ids = json_data.get("node_ids", [])

    all_trigger_words = []
    for lora_name in lora_names:
        _, trigger_words = await get_lora_info(lora_name)
        all_trigger_words.extend(trigger_words)

    # Format the trigger words
    trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""

    # Send the update to all connected trigger word toggle nodes
    for node_id in node_ids:
        PromptServer.instance.send_sync("trigger_word_update", {
            "id": node_id,
            "message": trigger_words_text
        })

    return web.json_response({"success": True})
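For context, a minimal client call to this route might have looked like the sketch below. The host and port are assumptions (ComfyUI's default is 127.0.0.1:8188); the payload keys mirror the handler above.

import requests

# Hypothetical example; assumes a ComfyUI server with this route on the default port
resp = requests.post(
    "http://127.0.0.1:8188/loramanager/get_trigger_words",
    json={"lora_names": ["my_lora.safetensors"], "node_ids": [12]},
)
print(resp.json())  # expected: {'success': True}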

View File

@@ -1,3 +0,0 @@
"""
ComfyUI workflow parsing module to extract generation parameters
"""

View File

@@ -1,58 +0,0 @@
"""
Command-line interface for the ComfyUI workflow parser
"""
import argparse
import json
import os
import logging
import sys
from .parser import parse_workflow
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)
def main():
"""Entry point for the CLI"""
parser = argparse.ArgumentParser(description='Parse ComfyUI workflow files')
parser.add_argument('input', help='Input workflow JSON file path')
parser.add_argument('-o', '--output', help='Output JSON file path')
parser.add_argument('-p', '--pretty', action='store_true', help='Pretty print JSON output')
parser.add_argument('--debug', action='store_true', help='Enable debug logging')
args = parser.parse_args()
# Set logging level
if args.debug:
logging.getLogger().setLevel(logging.DEBUG)
# Validate input file
if not os.path.isfile(args.input):
logger.error(f"Input file not found: {args.input}")
sys.exit(1)
# Parse workflow
try:
result = parse_workflow(args.input, args.output)
# Print result to console if output file not specified
if not args.output:
if args.pretty:
print(json.dumps(result, indent=4))
else:
print(json.dumps(result))
else:
logger.info(f"Output saved to: {args.output}")
except Exception as e:
logger.error(f"Error parsing workflow: {e}")
if args.debug:
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()
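Typical invocations of this CLI, assuming the package was importable as workflow.cli (the module path is an assumption; the real package name is not shown in this diff):

python -m workflow.cli my_workflow.json --pretty
python -m workflow.cli my_workflow.json -o parsed.json --debug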

View File

@@ -1,3 +0,0 @@
"""
Extension directory for custom node mappers
"""

View File

@@ -1,285 +0,0 @@
"""
ComfyUI Core nodes mappers extension for workflow parsing
"""
import logging
from typing import Dict, Any, List
logger = logging.getLogger(__name__)
# =============================================================================
# Transform Functions
# =============================================================================
def transform_random_noise(inputs: Dict) -> Dict:
"""Transform function for RandomNoise node"""
return {"seed": str(inputs.get("noise_seed", ""))}
def transform_ksampler_select(inputs: Dict) -> Dict:
"""Transform function for KSamplerSelect node"""
return {"sampler": inputs.get("sampler_name", "")}
def transform_basic_scheduler(inputs: Dict) -> Dict:
"""Transform function for BasicScheduler node"""
result = {
"scheduler": inputs.get("scheduler", ""),
"denoise": str(inputs.get("denoise", "1.0"))
}
# Get steps from inputs or steps input
if "steps" in inputs:
if isinstance(inputs["steps"], str):
result["steps"] = inputs["steps"]
elif isinstance(inputs["steps"], dict) and "value" in inputs["steps"]:
result["steps"] = str(inputs["steps"]["value"])
else:
result["steps"] = str(inputs["steps"])
return result
def transform_basic_guider(inputs: Dict) -> Dict:
"""Transform function for BasicGuider node"""
result = {}
# Process conditioning
if "conditioning" in inputs:
if isinstance(inputs["conditioning"], str):
result["prompt"] = inputs["conditioning"]
elif isinstance(inputs["conditioning"], dict):
result["conditioning"] = inputs["conditioning"]
# Get model information if needed
if "model" in inputs and isinstance(inputs["model"], dict):
result["model"] = inputs["model"]
return result
def transform_model_sampling_flux(inputs: Dict) -> Dict:
"""Transform function for ModelSamplingFlux - mostly a pass-through node"""
# This node is primarily used for routing, so we mostly pass through values
return inputs["model"]
def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
"""Transform function for SamplerCustomAdvanced node"""
result = {}
# Extract seed from noise
if "noise" in inputs and isinstance(inputs["noise"], dict):
result["seed"] = str(inputs["noise"].get("seed", ""))
# Extract sampler info
if "sampler" in inputs and isinstance(inputs["sampler"], dict):
sampler = inputs["sampler"].get("sampler", "")
if sampler:
result["sampler"] = sampler
# Extract scheduler, steps, denoise from sigmas
if "sigmas" in inputs and isinstance(inputs["sigmas"], dict):
sigmas = inputs["sigmas"]
result["scheduler"] = sigmas.get("scheduler", "")
result["steps"] = str(sigmas.get("steps", ""))
result["denoise"] = str(sigmas.get("denoise", "1.0"))
# Extract prompt and guidance from guider
if "guider" in inputs and isinstance(inputs["guider"], dict):
guider = inputs["guider"]
# Get prompt from conditioning
if "conditioning" in guider and isinstance(guider["conditioning"], str):
result["prompt"] = guider["conditioning"]
elif "conditioning" in guider and isinstance(guider["conditioning"], dict):
result["guidance"] = guider["conditioning"].get("guidance", "")
result["prompt"] = guider["conditioning"].get("prompt", "")
if "model" in guider and isinstance(guider["model"], dict):
result["checkpoint"] = guider["model"].get("checkpoint", "")
result["loras"] = guider["model"].get("loras", "")
result["clip_skip"] = str(int(guider["model"].get("clip_skip", "-1")) * -1)
# Extract dimensions from latent_image
if "latent_image" in inputs and isinstance(inputs["latent_image"], dict):
latent = inputs["latent_image"]
width = latent.get("width", 0)
height = latent.get("height", 0)
if width and height:
result["width"] = width
result["height"] = height
result["size"] = f"{width}x{height}"
return result
def transform_ksampler(inputs: Dict) -> Dict:
"""Transform function for KSampler nodes"""
result = {
"seed": str(inputs.get("seed", "")),
"steps": str(inputs.get("steps", "")),
"cfg": str(inputs.get("cfg", "")),
"sampler": inputs.get("sampler_name", ""),
"scheduler": inputs.get("scheduler", ""),
}
# Process positive prompt
if "positive" in inputs:
result["prompt"] = inputs["positive"]
# Process negative prompt
if "negative" in inputs:
result["negative_prompt"] = inputs["negative"]
# Get dimensions from latent image
if "latent_image" in inputs and isinstance(inputs["latent_image"], dict):
width = inputs["latent_image"].get("width", 0)
height = inputs["latent_image"].get("height", 0)
if width and height:
result["size"] = f"{width}x{height}"
# Add clip_skip if present
if "clip_skip" in inputs:
result["clip_skip"] = str(inputs.get("clip_skip", ""))
# Add guidance if present
if "guidance" in inputs:
result["guidance"] = str(inputs.get("guidance", ""))
# Add model if present
if "model" in inputs:
result["checkpoint"] = inputs.get("model", {}).get("checkpoint", "")
result["loras"] = inputs.get("model", {}).get("loras", "")
result["clip_skip"] = str(inputs.get("model", {}).get("clip_skip", -1) * -1)
return result
def transform_empty_latent(inputs: Dict) -> Dict:
"""Transform function for EmptyLatentImage nodes"""
width = inputs.get("width", 0)
height = inputs.get("height", 0)
return {"width": width, "height": height, "size": f"{width}x{height}"}
def transform_clip_text(inputs: Dict) -> Any:
"""Transform function for CLIPTextEncode nodes"""
return inputs.get("text", "")
def transform_flux_guidance(inputs: Dict) -> Dict:
"""Transform function for FluxGuidance nodes"""
result = {}
if "guidance" in inputs:
result["guidance"] = inputs["guidance"]
if "conditioning" in inputs:
conditioning = inputs["conditioning"]
if isinstance(conditioning, str):
result["prompt"] = conditioning
else:
result["prompt"] = "Unknown prompt"
return result
def transform_unet_loader(inputs: Dict) -> Dict:
"""Transform function for UNETLoader node"""
unet_name = inputs.get("unet_name", "")
return {"checkpoint": unet_name} if unet_name else {}
def transform_checkpoint_loader(inputs: Dict) -> Dict:
"""Transform function for CheckpointLoaderSimple node"""
ckpt_name = inputs.get("ckpt_name", "")
return {"checkpoint": ckpt_name} if ckpt_name else {}
def transform_latent_upscale_by(inputs: Dict) -> Dict:
"""Transform function for LatentUpscaleBy node"""
result = {}
width = inputs["samples"].get("width", 0) * inputs["scale_by"]
height = inputs["samples"].get("height", 0) * inputs["scale_by"]
result["width"] = width
result["height"] = height
result["size"] = f"{width}x{height}"
return result
def transform_clip_set_last_layer(inputs: Dict) -> Dict:
"""Transform function for CLIPSetLastLayer node"""
result = {}
if "stop_at_clip_layer" in inputs:
result["clip_skip"] = inputs["stop_at_clip_layer"]
return result
# =============================================================================
# Node Mapper Definitions
# =============================================================================
# Define the mappers for ComfyUI core nodes not in main mapper
NODE_MAPPERS_EXT = {
# KSamplers
"SamplerCustomAdvanced": {
"inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"],
"transform_func": transform_sampler_custom_advanced
},
"KSampler": {
"inputs_to_track": [
"seed", "steps", "cfg", "sampler_name", "scheduler",
"denoise", "positive", "negative", "latent_image",
"model", "clip_skip"
],
"transform_func": transform_ksampler
},
# ComfyUI core nodes
"EmptyLatentImage": {
"inputs_to_track": ["width", "height", "batch_size"],
"transform_func": transform_empty_latent
},
"EmptySD3LatentImage": {
"inputs_to_track": ["width", "height", "batch_size"],
"transform_func": transform_empty_latent
},
"CLIPTextEncode": {
"inputs_to_track": ["text", "clip"],
"transform_func": transform_clip_text
},
"FluxGuidance": {
"inputs_to_track": ["guidance", "conditioning"],
"transform_func": transform_flux_guidance
},
"RandomNoise": {
"inputs_to_track": ["noise_seed"],
"transform_func": transform_random_noise
},
"KSamplerSelect": {
"inputs_to_track": ["sampler_name"],
"transform_func": transform_ksampler_select
},
"BasicScheduler": {
"inputs_to_track": ["scheduler", "steps", "denoise", "model"],
"transform_func": transform_basic_scheduler
},
"BasicGuider": {
"inputs_to_track": ["model", "conditioning"],
"transform_func": transform_basic_guider
},
"ModelSamplingFlux": {
"inputs_to_track": ["max_shift", "base_shift", "width", "height", "model"],
"transform_func": transform_model_sampling_flux
},
"UNETLoader": {
"inputs_to_track": ["unet_name"],
"transform_func": transform_unet_loader
},
"CheckpointLoaderSimple": {
"inputs_to_track": ["ckpt_name"],
"transform_func": transform_checkpoint_loader
},
"LatentUpscale": {
"inputs_to_track": ["width", "height"],
"transform_func": transform_empty_latent
},
"LatentUpscaleBy": {
"inputs_to_track": ["samples", "scale_by"],
"transform_func": transform_latent_upscale_by
},
"CLIPSetLastLayer": {
"inputs_to_track": ["clip", "stop_at_clip_layer"],
"transform_func": transform_clip_set_last_layer
}
}
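As a quick sanity check, the transform functions above are plain dict-in/dict-out helpers and can be exercised in isolation; a minimal sketch, assuming the module above is imported:

# Illustrative inputs, not taken from a real workflow
print(transform_empty_latent({"width": 1024, "height": 768}))
# -> {'width': 1024, 'height': 768, 'size': '1024x768'}
print(transform_random_noise({"noise_seed": 42}))
# -> {'seed': '42'}
print(transform_ksampler_select({"sampler_name": "euler"}))
# -> {'sampler': 'euler'}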

View File

@@ -1,74 +0,0 @@
"""
KJNodes mappers extension for ComfyUI workflow parsing
"""
import logging
import re
from typing import Dict, Any
logger = logging.getLogger(__name__)
# =============================================================================
# Transform Functions
# =============================================================================
def transform_join_strings(inputs: Dict) -> str:
"""Transform function for JoinStrings nodes"""
string1 = inputs.get("string1", "")
string2 = inputs.get("string2", "")
delimiter = inputs.get("delimiter", "")
return f"{string1}{delimiter}{string2}"
def transform_string_constant(inputs: Dict) -> str:
"""Transform function for StringConstant nodes"""
return inputs.get("string", "")
def transform_empty_latent_presets(inputs: Dict) -> Dict:
"""Transform function for EmptyLatentImagePresets nodes"""
dimensions = inputs.get("dimensions", "")
invert = inputs.get("invert", False)
# Extract width and height from dimensions string
# Expected format: "width x height (ratio)" or similar
width = 0
height = 0
if dimensions:
# Try to extract dimensions using regex
match = re.search(r'(\d+)\s*x\s*(\d+)', dimensions)
if match:
width = int(match.group(1))
height = int(match.group(2))
# If invert is True, swap width and height
if invert and width and height:
width, height = height, width
return {"width": width, "height": height, "size": f"{width}x{height}"}
def transform_int_constant(inputs: Dict) -> int:
"""Transform function for INTConstant nodes"""
return inputs.get("value", 0)
# =============================================================================
# Node Mapper Definitions
# =============================================================================
# Define the mappers for KJNodes
NODE_MAPPERS_EXT = {
"JoinStrings": {
"inputs_to_track": ["string1", "string2", "delimiter"],
"transform_func": transform_join_strings
},
"StringConstantMultiline": {
"inputs_to_track": ["string"],
"transform_func": transform_string_constant
},
"EmptyLatentImagePresets": {
"inputs_to_track": ["dimensions", "invert", "batch_size"],
"transform_func": transform_empty_latent_presets
},
"INTConstant": {
"inputs_to_track": ["value"],
"transform_func": transform_int_constant
}
}
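For illustration, transform_empty_latent_presets parses the preset label with the regex above and honors the invert flag; a minimal sketch with an invented dimensions string:

# The "1024 x 768 (4:3)" format is inferred from the regex, not from KJNodes docs
print(transform_empty_latent_presets({"dimensions": "1024 x 768 (4:3)", "invert": True}))
# -> {'width': 768, 'height': 1024, 'size': '768x1024'}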

View File

@@ -1,37 +0,0 @@
"""
Main entry point for the workflow parser module
"""
import os
import sys
import logging
from typing import Dict, Optional, Union
# Add the parent directory to sys.path to enable imports
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.abspath(os.path.join(SCRIPT_DIR, '..', '..'))
sys.path.insert(0, os.path.dirname(SCRIPT_DIR))
from .parser import parse_workflow
logger = logging.getLogger(__name__)
def parse_comfyui_workflow(
workflow_path: str,
output_path: Optional[str] = None
) -> Dict:
"""
Parse a ComfyUI workflow file and extract generation parameters
Args:
workflow_path: Path to the workflow JSON file
output_path: Optional path to save the output JSON
Returns:
Dictionary containing extracted parameters
"""
return parse_workflow(workflow_path, output_path)
if __name__ == "__main__":
# If run directly, use the CLI
from .cli import main
main()
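A minimal usage sketch for this entry point (the package name and file paths are placeholders, not taken from this diff):

from workflow.__main__ import parse_comfyui_workflow  # hypothetical package name

params = parse_comfyui_workflow("flux_prompt.json", output_path="parsed.json")
print(params.get("prompt"), params.get("steps"), params.get("size"))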

View File

@@ -1,282 +0,0 @@
"""
Node mappers for ComfyUI workflow parsing
"""
import logging
import os
import importlib.util
import inspect
from typing import Dict, List, Any, Optional, Union, Type, Callable, Tuple
logger = logging.getLogger(__name__)
# Global mapper registry
_MAPPER_REGISTRY: Dict[str, Dict] = {}
# =============================================================================
# Mapper Definition Functions
# =============================================================================
def create_mapper(
node_type: str,
inputs_to_track: List[str],
transform_func: Callable[[Dict], Any] = None
) -> Dict:
"""Create a mapper definition for a node type"""
mapper = {
"node_type": node_type,
"inputs_to_track": inputs_to_track,
"transform": transform_func or (lambda inputs: inputs)
}
return mapper
def register_mapper(mapper: Dict) -> None:
"""Register a node mapper in the global registry"""
_MAPPER_REGISTRY[mapper["node_type"]] = mapper
logger.debug(f"Registered mapper for node type: {mapper['node_type']}")
def get_mapper(node_type: str) -> Optional[Dict]:
"""Get a mapper for the specified node type"""
return _MAPPER_REGISTRY.get(node_type)
def get_all_mappers() -> Dict[str, Dict]:
"""Get all registered mappers"""
return _MAPPER_REGISTRY.copy()
# =============================================================================
# Node Processing Function
# =============================================================================
def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'WorkflowParser') -> Any: # type: ignore
"""Process a node using its mapper and extract relevant information"""
node_type = node_data.get("class_type")
mapper = get_mapper(node_type)
if not mapper:
logger.warning(f"No mapper found for node type: {node_type}")
return None
result = {}
# Extract inputs based on the mapper's tracked inputs
for input_name in mapper["inputs_to_track"]:
if input_name in node_data.get("inputs", {}):
input_value = node_data["inputs"][input_name]
# Check if input is a reference to another node's output
if isinstance(input_value, list) and len(input_value) == 2:
try:
# Format is [node_id, output_slot]
ref_node_id, output_slot = input_value
# Convert node_id to string if it's an integer
if isinstance(ref_node_id, int):
ref_node_id = str(ref_node_id)
# Recursively process the referenced node
ref_value = parser.process_node(ref_node_id, workflow)
if ref_value is not None:
result[input_name] = ref_value
else:
# If we couldn't get a value from the reference, store the raw value
result[input_name] = input_value
except Exception as e:
logger.error(f"Error processing reference in node {node_id}, input {input_name}: {e}")
result[input_name] = input_value
else:
# Direct value
result[input_name] = input_value
# Apply the transform function
try:
return mapper["transform"](result)
except Exception as e:
logger.error(f"Error in transform function for node {node_id} of type {node_type}: {e}")
return result
# =============================================================================
# Transform Functions
# =============================================================================
def transform_lora_loader(inputs: Dict) -> Dict:
"""Transform function for LoraLoader nodes"""
loras_data = inputs.get("loras", [])
lora_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
lora_texts = []
# Process loras array
if isinstance(loras_data, dict) and "__value__" in loras_data:
loras_list = loras_data["__value__"]
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
# Process each active lora entry
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
strength = lora.get("strength", 1.0)
lora_texts.append(f"<lora:{lora_name}:{strength}>")
# Process lora_stack if valid
if lora_stack and isinstance(lora_stack, list):
if not (len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int)):
for stack_entry in lora_stack:
lora_name = stack_entry[0]
strength = stack_entry[1]
lora_texts.append(f"<lora:{lora_name}:{strength}>")
result = {
"checkpoint": inputs.get("model", {}).get("checkpoint", ""),
"loras": " ".join(lora_texts)
}
if "clip" in inputs and isinstance(inputs["clip"], dict):
result["clip_skip"] = inputs["clip"].get("clip_skip", "-1")
return result
def transform_lora_stacker(inputs: Dict) -> Dict:
"""Transform function for LoraStacker nodes"""
loras_data = inputs.get("loras", [])
result_stack = []
# Handle existing stack entries
existing_stack = []
lora_stack_input = inputs.get("lora_stack", [])
if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input:
existing_stack = lora_stack_input["lora_stack"]
elif isinstance(lora_stack_input, list):
if not (len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and
isinstance(lora_stack_input[1], int)):
existing_stack = lora_stack_input
# Add existing entries
if existing_stack:
result_stack.extend(existing_stack)
# Process new loras
if isinstance(loras_data, dict) and "__value__" in loras_data:
loras_list = loras_data["__value__"]
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
strength = float(lora.get("strength", 1.0))
result_stack.append((lora_name, strength))
return {"lora_stack": result_stack}
def transform_trigger_word_toggle(inputs: Dict) -> str:
"""Transform function for TriggerWordToggle nodes"""
toggle_data = inputs.get("toggle_trigger_words", [])
if isinstance(toggle_data, dict) and "__value__" in toggle_data:
toggle_words = toggle_data["__value__"]
elif isinstance(toggle_data, list):
toggle_words = toggle_data
else:
toggle_words = []
# Filter active trigger words
active_words = []
for item in toggle_words:
if isinstance(item, dict) and item.get("active", False):
word = item.get("text", "")
if word and not word.startswith("__dummy"):
active_words.append(word)
return ", ".join(active_words)
# =============================================================================
# Node Mapper Definitions
# =============================================================================
# Central definition of all supported node types and their configurations
NODE_MAPPERS = {
# LoraManager nodes
"Lora Loader (LoraManager)": {
"inputs_to_track": ["model", "clip", "loras", "lora_stack"],
"transform_func": transform_lora_loader
},
"Lora Stacker (LoraManager)": {
"inputs_to_track": ["loras", "lora_stack"],
"transform_func": transform_lora_stacker
},
"TriggerWord Toggle (LoraManager)": {
"inputs_to_track": ["toggle_trigger_words"],
"transform_func": transform_trigger_word_toggle
}
}
def register_all_mappers() -> None:
"""Register all mappers from the NODE_MAPPERS dictionary"""
for node_type, config in NODE_MAPPERS.items():
mapper = create_mapper(
node_type=node_type,
inputs_to_track=config["inputs_to_track"],
transform_func=config["transform_func"]
)
register_mapper(mapper)
logger.info(f"Registered {len(NODE_MAPPERS)} node mappers")
# =============================================================================
# Extension Loading
# =============================================================================
def load_extensions(ext_dir: str = None) -> None:
"""
Load mapper extensions from the specified directory
Extension files should define a NODE_MAPPERS_EXT dictionary containing mapper configurations.
These will be added to the global NODE_MAPPERS dictionary and registered automatically.
"""
# Use default path if none provided
if ext_dir is None:
# Get the directory of this file
current_dir = os.path.dirname(os.path.abspath(__file__))
ext_dir = os.path.join(current_dir, 'ext')
# Ensure the extension directory exists
if not os.path.exists(ext_dir):
os.makedirs(ext_dir, exist_ok=True)
logger.info(f"Created extension directory: {ext_dir}")
return
# Load each Python file in the extension directory
for filename in os.listdir(ext_dir):
if filename.endswith('.py') and not filename.startswith('_'):
module_path = os.path.join(ext_dir, filename)
module_name = f"workflow.ext.{filename[:-3]}" # Remove .py
try:
# Load the module
spec = importlib.util.spec_from_file_location(module_name, module_path)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Check if the module defines NODE_MAPPERS_EXT
if hasattr(module, 'NODE_MAPPERS_EXT'):
# Add the extension mappers to the global NODE_MAPPERS dictionary
NODE_MAPPERS.update(module.NODE_MAPPERS_EXT)
logger.info(f"Added {len(module.NODE_MAPPERS_EXT)} mappers from extension: {filename}")
else:
logger.warning(f"Extension {filename} does not define NODE_MAPPERS_EXT dictionary")
except Exception as e:
logger.warning(f"Error loading extension {filename}: {e}")
# Re-register all mappers after loading extensions
register_all_mappers()
# Initialize the registry with default mappers
# register_default_mappers()
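To illustrate the extension contract documented in load_extensions, a minimal extension file dropped into the ext/ directory might have looked like this sketch (the node type and its input are invented for illustration):

"""
Example mapper extension (illustrative only)
"""
from typing import Dict

def transform_my_node(inputs: Dict) -> Dict:
    """Extract a single tracked value from a hypothetical custom node"""
    return {"my_param": str(inputs.get("my_param", ""))}

# Discovered by load_extensions() and merged into NODE_MAPPERS
NODE_MAPPERS_EXT = {
    "MyCustomNode": {
        "inputs_to_track": ["my_param"],
        "transform_func": transform_my_node
    }
}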

View File

@@ -1,181 +0,0 @@
"""
Main workflow parser implementation for ComfyUI
"""
import json
import logging
from typing import Dict, List, Any, Optional, Union, Set
from .mappers import get_mapper, get_all_mappers, load_extensions, process_node
from .utils import (
load_workflow, save_output, find_node_by_type,
trace_model_path
)
logger = logging.getLogger(__name__)
class WorkflowParser:
"""Parser for ComfyUI workflows"""
def __init__(self):
"""Initialize the parser with mappers"""
self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles
self.node_results_cache: Dict[str, Any] = {} # Cache for processed node results
# Load extensions
load_extensions()
def process_node(self, node_id: str, workflow: Dict) -> Any:
"""Process a single node and extract relevant information"""
# Return cached result if available
if node_id in self.node_results_cache:
return self.node_results_cache[node_id]
# Check if we're in a cycle
if node_id in self.processed_nodes:
return None
# Mark this node as being processed (to detect cycles)
self.processed_nodes.add(node_id)
if node_id not in workflow:
self.processed_nodes.remove(node_id)
return None
node_data = workflow[node_id]
node_type = node_data.get("class_type")
result = None
if get_mapper(node_type):
try:
result = process_node(node_id, node_data, workflow, self)
# Cache the result
self.node_results_cache[node_id] = result
except Exception as e:
logger.error(f"Error processing node {node_id} of type {node_type}: {e}", exc_info=True)
# Return a partial result or None depending on how we want to handle errors
result = {}
# Remove node from processed set to allow it to be processed again in a different context
self.processed_nodes.remove(node_id)
return result
def find_primary_sampler_node(self, workflow: Dict) -> Optional[str]:
"""
Find the primary sampler node in the workflow.
Priority:
1. First try to find a SamplerCustomAdvanced node
2. If not found, look for KSampler nodes with denoise=1.0
3. If still not found, use the first KSampler node
Args:
workflow: The workflow data as a dictionary
Returns:
The node ID of the primary sampler node, or None if not found
"""
# First check for SamplerCustomAdvanced nodes
sampler_advanced_nodes = []
ksampler_nodes = []
# Scan workflow for sampler nodes
for node_id, node_data in workflow.items():
node_type = node_data.get("class_type")
if node_type == "SamplerCustomAdvanced":
sampler_advanced_nodes.append(node_id)
elif node_type == "KSampler":
ksampler_nodes.append(node_id)
# If we found SamplerCustomAdvanced nodes, return the first one
if sampler_advanced_nodes:
logger.debug(f"Found SamplerCustomAdvanced node: {sampler_advanced_nodes[0]}")
return sampler_advanced_nodes[0]
# If we have KSampler nodes, look for one with denoise=1.0
if ksampler_nodes:
for node_id in ksampler_nodes:
node_data = workflow[node_id]
inputs = node_data.get("inputs", {})
denoise = inputs.get("denoise", 0)
# Check if denoise is 1.0 (allowing for small floating point differences)
if abs(float(denoise) - 1.0) < 0.001:
logger.debug(f"Found KSampler node with denoise=1.0: {node_id}")
return node_id
# If no KSampler with denoise=1.0 found, use the first one
logger.debug(f"No KSampler with denoise=1.0 found, using first KSampler: {ksampler_nodes[0]}")
return ksampler_nodes[0]
# No sampler nodes found
logger.warning("No sampler nodes found in workflow")
return None
def parse_workflow(self, workflow_data: Union[str, Dict], output_path: Optional[str] = None) -> Dict:
"""
Parse the workflow and extract generation parameters
Args:
workflow_data: The workflow data as a dictionary or a file path
output_path: Optional path to save the output JSON
Returns:
Dictionary containing extracted parameters
"""
# Load workflow from file if needed
if isinstance(workflow_data, str):
workflow = load_workflow(workflow_data)
else:
workflow = workflow_data
# Reset the processed nodes tracker and cache
self.processed_nodes = set()
self.node_results_cache = {}
# Find the primary sampler node
sampler_node_id = self.find_primary_sampler_node(workflow)
if not sampler_node_id:
logger.warning("No suitable sampler node found in workflow")
return {}
# Process sampler node to extract parameters
sampler_result = self.process_node(sampler_node_id, workflow)
if not sampler_result:
return {}
# Return the sampler result directly - it's already in the format we need
# This simplifies the structure and makes it easier to use in recipe_routes.py
# Handle standard ComfyUI names vs our output format
if "cfg" in sampler_result:
sampler_result["cfg_scale"] = sampler_result.pop("cfg")
# Add clip_skip = 1 to match reference output if not already present
if "clip_skip" not in sampler_result:
sampler_result["clip_skip"] = "1"
# Ensure the prompt is a string and not a nested dictionary
if "prompt" in sampler_result and isinstance(sampler_result["prompt"], dict):
if "prompt" in sampler_result["prompt"]:
sampler_result["prompt"] = sampler_result["prompt"]["prompt"]
# Save the result if requested
if output_path:
save_output(sampler_result, output_path)
return sampler_result
def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dict:
"""
Parse a ComfyUI workflow file and extract generation parameters
Args:
workflow_path: Path to the workflow JSON file
output_path: Optional path to save the output JSON
Returns:
Dictionary containing extracted parameters
"""
parser = WorkflowParser()
return parser.parse_workflow(workflow_path, output_path)

View File

@@ -1,63 +0,0 @@
"""
Test script for the ComfyUI workflow parser
"""
import os
import json
import logging
from .parser import parse_workflow
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)
# Configure paths
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.abspath(os.path.join(SCRIPT_DIR, '..', '..'))
REFS_DIR = os.path.join(ROOT_DIR, 'refs')
OUTPUT_DIR = os.path.join(ROOT_DIR, 'output')
def test_parse_flux_workflow():
"""Test parsing the flux example workflow"""
# Ensure output directory exists
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Define input and output paths
input_path = os.path.join(REFS_DIR, 'flux_prompt.json')
output_path = os.path.join(OUTPUT_DIR, 'parsed_flux_output.json')
# Parse workflow
logger.info(f"Parsing workflow: {input_path}")
result = parse_workflow(input_path, output_path)
# Print result summary
logger.info(f"Output saved to: {output_path}")
logger.info(f"Parsing completed. Result summary:")
logger.info(f" LoRAs: {result.get('loras', '')}")
gen_params = result.get('gen_params', {})
logger.info(f" Prompt: {gen_params.get('prompt', '')[:50]}...")
logger.info(f" Steps: {gen_params.get('steps', '')}")
logger.info(f" Sampler: {gen_params.get('sampler', '')}")
logger.info(f" Size: {gen_params.get('size', '')}")
# Compare with reference output
ref_output_path = os.path.join(REFS_DIR, 'flux_output.json')
try:
with open(ref_output_path, 'r') as f:
ref_output = json.load(f)
# Simple validation
loras_match = result.get('loras', '') == ref_output.get('loras', '')
prompt_match = gen_params.get('prompt', '') == ref_output.get('gen_params', {}).get('prompt', '')
logger.info(f"Validation against reference:")
logger.info(f" LoRAs match: {loras_match}")
logger.info(f" Prompt match: {prompt_match}")
except Exception as e:
logger.warning(f"Failed to compare with reference output: {e}")
if __name__ == "__main__":
test_parse_flux_workflow()

View File

@@ -1,120 +0,0 @@
"""
Utility functions for ComfyUI workflow parsing
"""
import json
import os
import logging
from typing import Dict, List, Any, Optional, Union, Set, Tuple
logger = logging.getLogger(__name__)
def load_workflow(workflow_path: str) -> Dict:
"""Load a workflow from a JSON file"""
try:
with open(workflow_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading workflow from {workflow_path}: {e}")
raise
def save_output(output: Dict, output_path: str) -> None:
"""Save the parsed output to a JSON file"""
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
try:
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(output, f, indent=4)
except Exception as e:
logger.error(f"Error saving output to {output_path}: {e}")
raise
def find_node_by_type(workflow: Dict, node_type: str) -> Optional[str]:
"""Find a node of the specified type in the workflow"""
for node_id, node_data in workflow.items():
if node_data.get("class_type") == node_type:
return node_id
return None
def find_nodes_by_type(workflow: Dict, node_type: str) -> List[str]:
"""Find all nodes of the specified type in the workflow"""
return [node_id for node_id, node_data in workflow.items()
if node_data.get("class_type") == node_type]
def get_input_node_ids(workflow: Dict, node_id: str) -> Dict[str, Tuple[str, int]]:
"""
Get the node IDs for all inputs of the given node
Returns a dictionary mapping input names to (node_id, output_slot) tuples
"""
result = {}
if node_id not in workflow:
return result
node_data = workflow[node_id]
for input_name, input_value in node_data.get("inputs", {}).items():
# Check if this input is connected to another node
if isinstance(input_value, list) and len(input_value) == 2:
# Input is connected to another node's output
# Format: [node_id, output_slot]
ref_node_id, output_slot = input_value
result[input_name] = (str(ref_node_id), output_slot)
return result
def trace_model_path(workflow: Dict, start_node_id: str) -> List[str]:
"""
Trace the model path backward from KSampler to find all LoRA nodes
Args:
workflow: The workflow data
start_node_id: The starting node ID (usually KSampler)
Returns:
List of node IDs in the model path
"""
model_path_nodes = []
# Get the model input from the start node
if start_node_id not in workflow:
return model_path_nodes
# Track visited nodes to avoid cycles
visited = set()
# Stack for depth-first search
stack = []
# Get model input reference if available
start_node = workflow[start_node_id]
if "inputs" in start_node and "model" in start_node["inputs"] and isinstance(start_node["inputs"]["model"], list):
model_ref = start_node["inputs"]["model"]
stack.append(str(model_ref[0]))
# Perform depth-first search
while stack:
node_id = stack.pop()
# Skip if already visited
if node_id in visited:
continue
# Mark as visited
visited.add(node_id)
# Skip if node doesn't exist
if node_id not in workflow:
continue
node = workflow[node_id]
node_type = node.get("class_type", "")
# Add current node to result list if it's a LoRA node
if "Lora" in node_type:
model_path_nodes.append(node_id)
# Add all input nodes that have a "model" or "lora_stack" output to the stack
if "inputs" in node:
for input_name, input_value in node["inputs"].items():
if input_name in ["model", "lora_stack"] and isinstance(input_value, list) and len(input_value) == 2:
stack.append(str(input_value[0]))
return model_path_nodes
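To make the [node_id, output_slot] reference convention concrete, here is a toy workflow dict run through the helpers above (the node types are real ComfyUI types; the values are illustrative):

# Node "2" takes its model from node "1", output slot 0
workflow = {
    "1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "sd15.safetensors"}},
    "2": {"class_type": "KSampler", "inputs": {"model": ["1", 0], "seed": 42}},
}

print(find_node_by_type(workflow, "KSampler"))  # -> '2'
print(get_input_node_ids(workflow, "2"))        # -> {'model': ('1', 0)}
print(trace_model_path(workflow, "2"))          # -> [] (no LoRA nodes upstream)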