mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
refactor: remove workflow parsing module and associated files for cleanup
This commit is contained in:
@@ -1,26 +0,0 @@
|
||||
from aiohttp import web
|
||||
from server import PromptServer
|
||||
from .nodes.utils import get_lora_info
|
||||
|
||||
@PromptServer.instance.routes.post("/loramanager/get_trigger_words")
|
||||
async def get_trigger_words(request):
|
||||
json_data = await request.json()
|
||||
lora_names = json_data.get("lora_names", [])
|
||||
node_ids = json_data.get("node_ids", [])
|
||||
|
||||
all_trigger_words = []
|
||||
for lora_name in lora_names:
|
||||
_, trigger_words = await get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format the trigger words
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Send update to all connected trigger word toggle nodes
|
||||
for node_id in node_ids:
|
||||
PromptServer.instance.send_sync("trigger_word_update", {
|
||||
"id": node_id,
|
||||
"message": trigger_words_text
|
||||
})
|
||||
|
||||
return web.json_response({"success": True})
|
||||
@@ -1,3 +0,0 @@
|
||||
"""
|
||||
ComfyUI workflow parsing module to extract generation parameters
|
||||
"""
|
||||
@@ -1,58 +0,0 @@
|
||||
"""
|
||||
Command-line interface for the ComfyUI workflow parser
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import sys
|
||||
from .parser import parse_workflow
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.StreamHandler()]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main():
|
||||
"""Entry point for the CLI"""
|
||||
parser = argparse.ArgumentParser(description='Parse ComfyUI workflow files')
|
||||
parser.add_argument('input', help='Input workflow JSON file path')
|
||||
parser.add_argument('-o', '--output', help='Output JSON file path')
|
||||
parser.add_argument('-p', '--pretty', action='store_true', help='Pretty print JSON output')
|
||||
parser.add_argument('--debug', action='store_true', help='Enable debug logging')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set logging level
|
||||
if args.debug:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
# Validate input file
|
||||
if not os.path.isfile(args.input):
|
||||
logger.error(f"Input file not found: {args.input}")
|
||||
sys.exit(1)
|
||||
|
||||
# Parse workflow
|
||||
try:
|
||||
result = parse_workflow(args.input, args.output)
|
||||
|
||||
# Print result to console if output file not specified
|
||||
if not args.output:
|
||||
if args.pretty:
|
||||
print(json.dumps(result, indent=4))
|
||||
else:
|
||||
print(json.dumps(result))
|
||||
else:
|
||||
logger.info(f"Output saved to: {args.output}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing workflow: {e}")
|
||||
if args.debug:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,3 +0,0 @@
|
||||
"""
|
||||
Extension directory for custom node mappers
|
||||
"""
|
||||
@@ -1,285 +0,0 @@
|
||||
"""
|
||||
ComfyUI Core nodes mappers extension for workflow parsing
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# Transform Functions
|
||||
# =============================================================================
|
||||
|
||||
def transform_random_noise(inputs: Dict) -> Dict:
|
||||
"""Transform function for RandomNoise node"""
|
||||
return {"seed": str(inputs.get("noise_seed", ""))}
|
||||
|
||||
def transform_ksampler_select(inputs: Dict) -> Dict:
|
||||
"""Transform function for KSamplerSelect node"""
|
||||
return {"sampler": inputs.get("sampler_name", "")}
|
||||
|
||||
def transform_basic_scheduler(inputs: Dict) -> Dict:
|
||||
"""Transform function for BasicScheduler node"""
|
||||
result = {
|
||||
"scheduler": inputs.get("scheduler", ""),
|
||||
"denoise": str(inputs.get("denoise", "1.0"))
|
||||
}
|
||||
|
||||
# Get steps from inputs or steps input
|
||||
if "steps" in inputs:
|
||||
if isinstance(inputs["steps"], str):
|
||||
result["steps"] = inputs["steps"]
|
||||
elif isinstance(inputs["steps"], dict) and "value" in inputs["steps"]:
|
||||
result["steps"] = str(inputs["steps"]["value"])
|
||||
else:
|
||||
result["steps"] = str(inputs["steps"])
|
||||
|
||||
return result
|
||||
|
||||
def transform_basic_guider(inputs: Dict) -> Dict:
|
||||
"""Transform function for BasicGuider node"""
|
||||
result = {}
|
||||
|
||||
# Process conditioning
|
||||
if "conditioning" in inputs:
|
||||
if isinstance(inputs["conditioning"], str):
|
||||
result["prompt"] = inputs["conditioning"]
|
||||
elif isinstance(inputs["conditioning"], dict):
|
||||
result["conditioning"] = inputs["conditioning"]
|
||||
|
||||
# Get model information if needed
|
||||
if "model" in inputs and isinstance(inputs["model"], dict):
|
||||
result["model"] = inputs["model"]
|
||||
|
||||
return result
|
||||
|
||||
def transform_model_sampling_flux(inputs: Dict) -> Dict:
|
||||
"""Transform function for ModelSamplingFlux - mostly a pass-through node"""
|
||||
# This node is primarily used for routing, so we mostly pass through values
|
||||
|
||||
return inputs["model"]
|
||||
|
||||
def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
|
||||
"""Transform function for SamplerCustomAdvanced node"""
|
||||
result = {}
|
||||
|
||||
# Extract seed from noise
|
||||
if "noise" in inputs and isinstance(inputs["noise"], dict):
|
||||
result["seed"] = str(inputs["noise"].get("seed", ""))
|
||||
|
||||
# Extract sampler info
|
||||
if "sampler" in inputs and isinstance(inputs["sampler"], dict):
|
||||
sampler = inputs["sampler"].get("sampler", "")
|
||||
if sampler:
|
||||
result["sampler"] = sampler
|
||||
|
||||
# Extract scheduler, steps, denoise from sigmas
|
||||
if "sigmas" in inputs and isinstance(inputs["sigmas"], dict):
|
||||
sigmas = inputs["sigmas"]
|
||||
result["scheduler"] = sigmas.get("scheduler", "")
|
||||
result["steps"] = str(sigmas.get("steps", ""))
|
||||
result["denoise"] = str(sigmas.get("denoise", "1.0"))
|
||||
|
||||
# Extract prompt and guidance from guider
|
||||
if "guider" in inputs and isinstance(inputs["guider"], dict):
|
||||
guider = inputs["guider"]
|
||||
|
||||
# Get prompt from conditioning
|
||||
if "conditioning" in guider and isinstance(guider["conditioning"], str):
|
||||
result["prompt"] = guider["conditioning"]
|
||||
elif "conditioning" in guider and isinstance(guider["conditioning"], dict):
|
||||
result["guidance"] = guider["conditioning"].get("guidance", "")
|
||||
result["prompt"] = guider["conditioning"].get("prompt", "")
|
||||
|
||||
if "model" in guider and isinstance(guider["model"], dict):
|
||||
result["checkpoint"] = guider["model"].get("checkpoint", "")
|
||||
result["loras"] = guider["model"].get("loras", "")
|
||||
result["clip_skip"] = str(int(guider["model"].get("clip_skip", "-1")) * -1)
|
||||
|
||||
# Extract dimensions from latent_image
|
||||
if "latent_image" in inputs and isinstance(inputs["latent_image"], dict):
|
||||
latent = inputs["latent_image"]
|
||||
width = latent.get("width", 0)
|
||||
height = latent.get("height", 0)
|
||||
if width and height:
|
||||
result["width"] = width
|
||||
result["height"] = height
|
||||
result["size"] = f"{width}x{height}"
|
||||
|
||||
return result
|
||||
|
||||
def transform_ksampler(inputs: Dict) -> Dict:
|
||||
"""Transform function for KSampler nodes"""
|
||||
result = {
|
||||
"seed": str(inputs.get("seed", "")),
|
||||
"steps": str(inputs.get("steps", "")),
|
||||
"cfg": str(inputs.get("cfg", "")),
|
||||
"sampler": inputs.get("sampler_name", ""),
|
||||
"scheduler": inputs.get("scheduler", ""),
|
||||
}
|
||||
|
||||
# Process positive prompt
|
||||
if "positive" in inputs:
|
||||
result["prompt"] = inputs["positive"]
|
||||
|
||||
# Process negative prompt
|
||||
if "negative" in inputs:
|
||||
result["negative_prompt"] = inputs["negative"]
|
||||
|
||||
# Get dimensions from latent image
|
||||
if "latent_image" in inputs and isinstance(inputs["latent_image"], dict):
|
||||
width = inputs["latent_image"].get("width", 0)
|
||||
height = inputs["latent_image"].get("height", 0)
|
||||
if width and height:
|
||||
result["size"] = f"{width}x{height}"
|
||||
|
||||
# Add clip_skip if present
|
||||
if "clip_skip" in inputs:
|
||||
result["clip_skip"] = str(inputs.get("clip_skip", ""))
|
||||
|
||||
# Add guidance if present
|
||||
if "guidance" in inputs:
|
||||
result["guidance"] = str(inputs.get("guidance", ""))
|
||||
|
||||
# Add model if present
|
||||
if "model" in inputs:
|
||||
result["checkpoint"] = inputs.get("model", {}).get("checkpoint", "")
|
||||
result["loras"] = inputs.get("model", {}).get("loras", "")
|
||||
result["clip_skip"] = str(inputs.get("model", {}).get("clip_skip", -1) * -1)
|
||||
|
||||
return result
|
||||
|
||||
def transform_empty_latent(inputs: Dict) -> Dict:
|
||||
"""Transform function for EmptyLatentImage nodes"""
|
||||
width = inputs.get("width", 0)
|
||||
height = inputs.get("height", 0)
|
||||
return {"width": width, "height": height, "size": f"{width}x{height}"}
|
||||
|
||||
def transform_clip_text(inputs: Dict) -> Any:
|
||||
"""Transform function for CLIPTextEncode nodes"""
|
||||
return inputs.get("text", "")
|
||||
|
||||
def transform_flux_guidance(inputs: Dict) -> Dict:
|
||||
"""Transform function for FluxGuidance nodes"""
|
||||
result = {}
|
||||
|
||||
if "guidance" in inputs:
|
||||
result["guidance"] = inputs["guidance"]
|
||||
|
||||
if "conditioning" in inputs:
|
||||
conditioning = inputs["conditioning"]
|
||||
if isinstance(conditioning, str):
|
||||
result["prompt"] = conditioning
|
||||
else:
|
||||
result["prompt"] = "Unknown prompt"
|
||||
|
||||
return result
|
||||
|
||||
def transform_unet_loader(inputs: Dict) -> Dict:
|
||||
"""Transform function for UNETLoader node"""
|
||||
unet_name = inputs.get("unet_name", "")
|
||||
return {"checkpoint": unet_name} if unet_name else {}
|
||||
|
||||
def transform_checkpoint_loader(inputs: Dict) -> Dict:
|
||||
"""Transform function for CheckpointLoaderSimple node"""
|
||||
ckpt_name = inputs.get("ckpt_name", "")
|
||||
return {"checkpoint": ckpt_name} if ckpt_name else {}
|
||||
|
||||
def transform_latent_upscale_by(inputs: Dict) -> Dict:
|
||||
"""Transform function for LatentUpscaleBy node"""
|
||||
result = {}
|
||||
|
||||
width = inputs["samples"].get("width", 0) * inputs["scale_by"]
|
||||
height = inputs["samples"].get("height", 0) * inputs["scale_by"]
|
||||
result["width"] = width
|
||||
result["height"] = height
|
||||
result["size"] = f"{width}x{height}"
|
||||
|
||||
return result
|
||||
|
||||
def transform_clip_set_last_layer(inputs: Dict) -> Dict:
|
||||
"""Transform function for CLIPSetLastLayer node"""
|
||||
result = {}
|
||||
|
||||
if "stop_at_clip_layer" in inputs:
|
||||
result["clip_skip"] = inputs["stop_at_clip_layer"]
|
||||
|
||||
return result
|
||||
|
||||
# =============================================================================
|
||||
# Node Mapper Definitions
|
||||
# =============================================================================
|
||||
|
||||
# Define the mappers for ComfyUI core nodes not in main mapper
|
||||
NODE_MAPPERS_EXT = {
|
||||
# KSamplers
|
||||
"SamplerCustomAdvanced": {
|
||||
"inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"],
|
||||
"transform_func": transform_sampler_custom_advanced
|
||||
},
|
||||
"KSampler": {
|
||||
"inputs_to_track": [
|
||||
"seed", "steps", "cfg", "sampler_name", "scheduler",
|
||||
"denoise", "positive", "negative", "latent_image",
|
||||
"model", "clip_skip"
|
||||
],
|
||||
"transform_func": transform_ksampler
|
||||
},
|
||||
# ComfyUI core nodes
|
||||
"EmptyLatentImage": {
|
||||
"inputs_to_track": ["width", "height", "batch_size"],
|
||||
"transform_func": transform_empty_latent
|
||||
},
|
||||
"EmptySD3LatentImage": {
|
||||
"inputs_to_track": ["width", "height", "batch_size"],
|
||||
"transform_func": transform_empty_latent
|
||||
},
|
||||
"CLIPTextEncode": {
|
||||
"inputs_to_track": ["text", "clip"],
|
||||
"transform_func": transform_clip_text
|
||||
},
|
||||
"FluxGuidance": {
|
||||
"inputs_to_track": ["guidance", "conditioning"],
|
||||
"transform_func": transform_flux_guidance
|
||||
},
|
||||
"RandomNoise": {
|
||||
"inputs_to_track": ["noise_seed"],
|
||||
"transform_func": transform_random_noise
|
||||
},
|
||||
"KSamplerSelect": {
|
||||
"inputs_to_track": ["sampler_name"],
|
||||
"transform_func": transform_ksampler_select
|
||||
},
|
||||
"BasicScheduler": {
|
||||
"inputs_to_track": ["scheduler", "steps", "denoise", "model"],
|
||||
"transform_func": transform_basic_scheduler
|
||||
},
|
||||
"BasicGuider": {
|
||||
"inputs_to_track": ["model", "conditioning"],
|
||||
"transform_func": transform_basic_guider
|
||||
},
|
||||
"ModelSamplingFlux": {
|
||||
"inputs_to_track": ["max_shift", "base_shift", "width", "height", "model"],
|
||||
"transform_func": transform_model_sampling_flux
|
||||
},
|
||||
"UNETLoader": {
|
||||
"inputs_to_track": ["unet_name"],
|
||||
"transform_func": transform_unet_loader
|
||||
},
|
||||
"CheckpointLoaderSimple": {
|
||||
"inputs_to_track": ["ckpt_name"],
|
||||
"transform_func": transform_checkpoint_loader
|
||||
},
|
||||
"LatentUpscale": {
|
||||
"inputs_to_track": ["width", "height"],
|
||||
"transform_func": transform_empty_latent
|
||||
},
|
||||
"LatentUpscaleBy": {
|
||||
"inputs_to_track": ["samples", "scale_by"],
|
||||
"transform_func": transform_latent_upscale_by
|
||||
},
|
||||
"CLIPSetLastLayer": {
|
||||
"inputs_to_track": ["clip", "stop_at_clip_layer"],
|
||||
"transform_func": transform_clip_set_last_layer
|
||||
}
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
"""
|
||||
KJNodes mappers extension for ComfyUI workflow parsing
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# Transform Functions
|
||||
# =============================================================================
|
||||
|
||||
def transform_join_strings(inputs: Dict) -> str:
|
||||
"""Transform function for JoinStrings nodes"""
|
||||
string1 = inputs.get("string1", "")
|
||||
string2 = inputs.get("string2", "")
|
||||
delimiter = inputs.get("delimiter", "")
|
||||
return f"{string1}{delimiter}{string2}"
|
||||
|
||||
def transform_string_constant(inputs: Dict) -> str:
|
||||
"""Transform function for StringConstant nodes"""
|
||||
return inputs.get("string", "")
|
||||
|
||||
def transform_empty_latent_presets(inputs: Dict) -> Dict:
|
||||
"""Transform function for EmptyLatentImagePresets nodes"""
|
||||
dimensions = inputs.get("dimensions", "")
|
||||
invert = inputs.get("invert", False)
|
||||
|
||||
# Extract width and height from dimensions string
|
||||
# Expected format: "width x height (ratio)" or similar
|
||||
width = 0
|
||||
height = 0
|
||||
|
||||
if dimensions:
|
||||
# Try to extract dimensions using regex
|
||||
match = re.search(r'(\d+)\s*x\s*(\d+)', dimensions)
|
||||
if match:
|
||||
width = int(match.group(1))
|
||||
height = int(match.group(2))
|
||||
|
||||
# If invert is True, swap width and height
|
||||
if invert and width and height:
|
||||
width, height = height, width
|
||||
|
||||
return {"width": width, "height": height, "size": f"{width}x{height}"}
|
||||
|
||||
def transform_int_constant(inputs: Dict) -> int:
|
||||
"""Transform function for INTConstant nodes"""
|
||||
return inputs.get("value", 0)
|
||||
|
||||
# =============================================================================
|
||||
# Node Mapper Definitions
|
||||
# =============================================================================
|
||||
|
||||
# Define the mappers for KJNodes
|
||||
NODE_MAPPERS_EXT = {
|
||||
"JoinStrings": {
|
||||
"inputs_to_track": ["string1", "string2", "delimiter"],
|
||||
"transform_func": transform_join_strings
|
||||
},
|
||||
"StringConstantMultiline": {
|
||||
"inputs_to_track": ["string"],
|
||||
"transform_func": transform_string_constant
|
||||
},
|
||||
"EmptyLatentImagePresets": {
|
||||
"inputs_to_track": ["dimensions", "invert", "batch_size"],
|
||||
"transform_func": transform_empty_latent_presets
|
||||
},
|
||||
"INTConstant": {
|
||||
"inputs_to_track": ["value"],
|
||||
"transform_func": transform_int_constant
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
"""
|
||||
Main entry point for the workflow parser module
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
# Add the parent directory to sys.path to enable imports
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
ROOT_DIR = os.path.abspath(os.path.join(SCRIPT_DIR, '..', '..'))
|
||||
sys.path.insert(0, os.path.dirname(SCRIPT_DIR))
|
||||
|
||||
from .parser import parse_workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def parse_comfyui_workflow(
|
||||
workflow_path: str,
|
||||
output_path: Optional[str] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
Parse a ComfyUI workflow file and extract generation parameters
|
||||
|
||||
Args:
|
||||
workflow_path: Path to the workflow JSON file
|
||||
output_path: Optional path to save the output JSON
|
||||
|
||||
Returns:
|
||||
Dictionary containing extracted parameters
|
||||
"""
|
||||
return parse_workflow(workflow_path, output_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# If run directly, use the CLI
|
||||
from .cli import main
|
||||
main()
|
||||
@@ -1,282 +0,0 @@
|
||||
"""
|
||||
Node mappers for ComfyUI workflow parsing
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import importlib.util
|
||||
import inspect
|
||||
from typing import Dict, List, Any, Optional, Union, Type, Callable, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global mapper registry
|
||||
_MAPPER_REGISTRY: Dict[str, Dict] = {}
|
||||
|
||||
# =============================================================================
|
||||
# Mapper Definition Functions
|
||||
# =============================================================================
|
||||
|
||||
def create_mapper(
|
||||
node_type: str,
|
||||
inputs_to_track: List[str],
|
||||
transform_func: Callable[[Dict], Any] = None
|
||||
) -> Dict:
|
||||
"""Create a mapper definition for a node type"""
|
||||
mapper = {
|
||||
"node_type": node_type,
|
||||
"inputs_to_track": inputs_to_track,
|
||||
"transform": transform_func or (lambda inputs: inputs)
|
||||
}
|
||||
return mapper
|
||||
|
||||
def register_mapper(mapper: Dict) -> None:
|
||||
"""Register a node mapper in the global registry"""
|
||||
_MAPPER_REGISTRY[mapper["node_type"]] = mapper
|
||||
logger.debug(f"Registered mapper for node type: {mapper['node_type']}")
|
||||
|
||||
def get_mapper(node_type: str) -> Optional[Dict]:
|
||||
"""Get a mapper for the specified node type"""
|
||||
return _MAPPER_REGISTRY.get(node_type)
|
||||
|
||||
def get_all_mappers() -> Dict[str, Dict]:
|
||||
"""Get all registered mappers"""
|
||||
return _MAPPER_REGISTRY.copy()
|
||||
|
||||
# =============================================================================
|
||||
# Node Processing Function
|
||||
# =============================================================================
|
||||
|
||||
def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'WorkflowParser') -> Any: # type: ignore
|
||||
"""Process a node using its mapper and extract relevant information"""
|
||||
node_type = node_data.get("class_type")
|
||||
mapper = get_mapper(node_type)
|
||||
|
||||
if not mapper:
|
||||
logger.warning(f"No mapper found for node type: {node_type}")
|
||||
return None
|
||||
|
||||
result = {}
|
||||
|
||||
# Extract inputs based on the mapper's tracked inputs
|
||||
for input_name in mapper["inputs_to_track"]:
|
||||
if input_name in node_data.get("inputs", {}):
|
||||
input_value = node_data["inputs"][input_name]
|
||||
|
||||
# Check if input is a reference to another node's output
|
||||
if isinstance(input_value, list) and len(input_value) == 2:
|
||||
try:
|
||||
# Format is [node_id, output_slot]
|
||||
ref_node_id, output_slot = input_value
|
||||
# Convert node_id to string if it's an integer
|
||||
if isinstance(ref_node_id, int):
|
||||
ref_node_id = str(ref_node_id)
|
||||
|
||||
# Recursively process the referenced node
|
||||
ref_value = parser.process_node(ref_node_id, workflow)
|
||||
|
||||
if ref_value is not None:
|
||||
result[input_name] = ref_value
|
||||
else:
|
||||
# If we couldn't get a value from the reference, store the raw value
|
||||
result[input_name] = input_value
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing reference in node {node_id}, input {input_name}: {e}")
|
||||
result[input_name] = input_value
|
||||
else:
|
||||
# Direct value
|
||||
result[input_name] = input_value
|
||||
|
||||
# Apply the transform function
|
||||
try:
|
||||
return mapper["transform"](result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in transform function for node {node_id} of type {node_type}: {e}")
|
||||
return result
|
||||
|
||||
# =============================================================================
|
||||
# Transform Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
|
||||
def transform_lora_loader(inputs: Dict) -> Dict:
|
||||
"""Transform function for LoraLoader nodes"""
|
||||
loras_data = inputs.get("loras", [])
|
||||
lora_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
|
||||
|
||||
lora_texts = []
|
||||
|
||||
# Process loras array
|
||||
if isinstance(loras_data, dict) and "__value__" in loras_data:
|
||||
loras_list = loras_data["__value__"]
|
||||
elif isinstance(loras_data, list):
|
||||
loras_list = loras_data
|
||||
else:
|
||||
loras_list = []
|
||||
|
||||
# Process each active lora entry
|
||||
for lora in loras_list:
|
||||
if isinstance(lora, dict) and lora.get("active", False):
|
||||
lora_name = lora.get("name", "")
|
||||
strength = lora.get("strength", 1.0)
|
||||
lora_texts.append(f"<lora:{lora_name}:{strength}>")
|
||||
|
||||
# Process lora_stack if valid
|
||||
if lora_stack and isinstance(lora_stack, list):
|
||||
if not (len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int)):
|
||||
for stack_entry in lora_stack:
|
||||
lora_name = stack_entry[0]
|
||||
strength = stack_entry[1]
|
||||
lora_texts.append(f"<lora:{lora_name}:{strength}>")
|
||||
|
||||
result = {
|
||||
"checkpoint": inputs.get("model", {}).get("checkpoint", ""),
|
||||
"loras": " ".join(lora_texts)
|
||||
}
|
||||
|
||||
if "clip" in inputs and isinstance(inputs["clip"], dict):
|
||||
result["clip_skip"] = inputs["clip"].get("clip_skip", "-1")
|
||||
|
||||
return result
|
||||
|
||||
def transform_lora_stacker(inputs: Dict) -> Dict:
|
||||
"""Transform function for LoraStacker nodes"""
|
||||
loras_data = inputs.get("loras", [])
|
||||
result_stack = []
|
||||
|
||||
# Handle existing stack entries
|
||||
existing_stack = []
|
||||
lora_stack_input = inputs.get("lora_stack", [])
|
||||
|
||||
if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input:
|
||||
existing_stack = lora_stack_input["lora_stack"]
|
||||
elif isinstance(lora_stack_input, list):
|
||||
if not (len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and
|
||||
isinstance(lora_stack_input[1], int)):
|
||||
existing_stack = lora_stack_input
|
||||
|
||||
# Add existing entries
|
||||
if existing_stack:
|
||||
result_stack.extend(existing_stack)
|
||||
|
||||
# Process new loras
|
||||
if isinstance(loras_data, dict) and "__value__" in loras_data:
|
||||
loras_list = loras_data["__value__"]
|
||||
elif isinstance(loras_data, list):
|
||||
loras_list = loras_data
|
||||
else:
|
||||
loras_list = []
|
||||
|
||||
for lora in loras_list:
|
||||
if isinstance(lora, dict) and lora.get("active", False):
|
||||
lora_name = lora.get("name", "")
|
||||
strength = float(lora.get("strength", 1.0))
|
||||
result_stack.append((lora_name, strength))
|
||||
|
||||
return {"lora_stack": result_stack}
|
||||
|
||||
def transform_trigger_word_toggle(inputs: Dict) -> str:
|
||||
"""Transform function for TriggerWordToggle nodes"""
|
||||
toggle_data = inputs.get("toggle_trigger_words", [])
|
||||
|
||||
if isinstance(toggle_data, dict) and "__value__" in toggle_data:
|
||||
toggle_words = toggle_data["__value__"]
|
||||
elif isinstance(toggle_data, list):
|
||||
toggle_words = toggle_data
|
||||
else:
|
||||
toggle_words = []
|
||||
|
||||
# Filter active trigger words
|
||||
active_words = []
|
||||
for item in toggle_words:
|
||||
if isinstance(item, dict) and item.get("active", False):
|
||||
word = item.get("text", "")
|
||||
if word and not word.startswith("__dummy"):
|
||||
active_words.append(word)
|
||||
|
||||
return ", ".join(active_words)
|
||||
|
||||
# =============================================================================
|
||||
# Node Mapper Definitions
|
||||
# =============================================================================
|
||||
|
||||
# Central definition of all supported node types and their configurations
|
||||
NODE_MAPPERS = {
|
||||
|
||||
# LoraManager nodes
|
||||
"Lora Loader (LoraManager)": {
|
||||
"inputs_to_track": ["model", "clip", "loras", "lora_stack"],
|
||||
"transform_func": transform_lora_loader
|
||||
},
|
||||
"Lora Stacker (LoraManager)": {
|
||||
"inputs_to_track": ["loras", "lora_stack"],
|
||||
"transform_func": transform_lora_stacker
|
||||
},
|
||||
"TriggerWord Toggle (LoraManager)": {
|
||||
"inputs_to_track": ["toggle_trigger_words"],
|
||||
"transform_func": transform_trigger_word_toggle
|
||||
}
|
||||
}
|
||||
|
||||
def register_all_mappers() -> None:
|
||||
"""Register all mappers from the NODE_MAPPERS dictionary"""
|
||||
for node_type, config in NODE_MAPPERS.items():
|
||||
mapper = create_mapper(
|
||||
node_type=node_type,
|
||||
inputs_to_track=config["inputs_to_track"],
|
||||
transform_func=config["transform_func"]
|
||||
)
|
||||
register_mapper(mapper)
|
||||
logger.info(f"Registered {len(NODE_MAPPERS)} node mappers")
|
||||
|
||||
# =============================================================================
|
||||
# Extension Loading
|
||||
# =============================================================================
|
||||
|
||||
def load_extensions(ext_dir: str = None) -> None:
|
||||
"""
|
||||
Load mapper extensions from the specified directory
|
||||
|
||||
Extension files should define a NODE_MAPPERS_EXT dictionary containing mapper configurations.
|
||||
These will be added to the global NODE_MAPPERS dictionary and registered automatically.
|
||||
"""
|
||||
# Use default path if none provided
|
||||
if ext_dir is None:
|
||||
# Get the directory of this file
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
ext_dir = os.path.join(current_dir, 'ext')
|
||||
|
||||
# Ensure the extension directory exists
|
||||
if not os.path.exists(ext_dir):
|
||||
os.makedirs(ext_dir, exist_ok=True)
|
||||
logger.info(f"Created extension directory: {ext_dir}")
|
||||
return
|
||||
|
||||
# Load each Python file in the extension directory
|
||||
for filename in os.listdir(ext_dir):
|
||||
if filename.endswith('.py') and not filename.startswith('_'):
|
||||
module_path = os.path.join(ext_dir, filename)
|
||||
module_name = f"workflow.ext.{filename[:-3]}" # Remove .py
|
||||
|
||||
try:
|
||||
# Load the module
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
if spec and spec.loader:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Check if the module defines NODE_MAPPERS_EXT
|
||||
if hasattr(module, 'NODE_MAPPERS_EXT'):
|
||||
# Add the extension mappers to the global NODE_MAPPERS dictionary
|
||||
NODE_MAPPERS.update(module.NODE_MAPPERS_EXT)
|
||||
logger.info(f"Added {len(module.NODE_MAPPERS_EXT)} mappers from extension: {filename}")
|
||||
else:
|
||||
logger.warning(f"Extension {filename} does not define NODE_MAPPERS_EXT dictionary")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading extension {filename}: {e}")
|
||||
|
||||
# Re-register all mappers after loading extensions
|
||||
register_all_mappers()
|
||||
|
||||
# Initialize the registry with default mappers
|
||||
# register_default_mappers()
|
||||
@@ -1,181 +0,0 @@
|
||||
"""
|
||||
Main workflow parser implementation for ComfyUI
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Union, Set
|
||||
from .mappers import get_mapper, get_all_mappers, load_extensions, process_node
|
||||
from .utils import (
|
||||
load_workflow, save_output, find_node_by_type,
|
||||
trace_model_path
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WorkflowParser:
|
||||
"""Parser for ComfyUI workflows"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the parser with mappers"""
|
||||
self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles
|
||||
self.node_results_cache: Dict[str, Any] = {} # Cache for processed node results
|
||||
|
||||
# Load extensions
|
||||
load_extensions()
|
||||
|
||||
def process_node(self, node_id: str, workflow: Dict) -> Any:
|
||||
"""Process a single node and extract relevant information"""
|
||||
# Return cached result if available
|
||||
if node_id in self.node_results_cache:
|
||||
return self.node_results_cache[node_id]
|
||||
|
||||
# Check if we're in a cycle
|
||||
if node_id in self.processed_nodes:
|
||||
return None
|
||||
|
||||
# Mark this node as being processed (to detect cycles)
|
||||
self.processed_nodes.add(node_id)
|
||||
|
||||
if node_id not in workflow:
|
||||
self.processed_nodes.remove(node_id)
|
||||
return None
|
||||
|
||||
node_data = workflow[node_id]
|
||||
node_type = node_data.get("class_type")
|
||||
|
||||
result = None
|
||||
if get_mapper(node_type):
|
||||
try:
|
||||
result = process_node(node_id, node_data, workflow, self)
|
||||
# Cache the result
|
||||
self.node_results_cache[node_id] = result
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing node {node_id} of type {node_type}: {e}", exc_info=True)
|
||||
# Return a partial result or None depending on how we want to handle errors
|
||||
result = {}
|
||||
|
||||
# Remove node from processed set to allow it to be processed again in a different context
|
||||
self.processed_nodes.remove(node_id)
|
||||
return result
|
||||
|
||||
def find_primary_sampler_node(self, workflow: Dict) -> Optional[str]:
|
||||
"""
|
||||
Find the primary sampler node in the workflow.
|
||||
|
||||
Priority:
|
||||
1. First try to find a SamplerCustomAdvanced node
|
||||
2. If not found, look for KSampler nodes with denoise=1.0
|
||||
3. If still not found, use the first KSampler node
|
||||
|
||||
Args:
|
||||
workflow: The workflow data as a dictionary
|
||||
|
||||
Returns:
|
||||
The node ID of the primary sampler node, or None if not found
|
||||
"""
|
||||
# First check for SamplerCustomAdvanced nodes
|
||||
sampler_advanced_nodes = []
|
||||
ksampler_nodes = []
|
||||
|
||||
# Scan workflow for sampler nodes
|
||||
for node_id, node_data in workflow.items():
|
||||
node_type = node_data.get("class_type")
|
||||
|
||||
if node_type == "SamplerCustomAdvanced":
|
||||
sampler_advanced_nodes.append(node_id)
|
||||
elif node_type == "KSampler":
|
||||
ksampler_nodes.append(node_id)
|
||||
|
||||
# If we found SamplerCustomAdvanced nodes, return the first one
|
||||
if sampler_advanced_nodes:
|
||||
logger.debug(f"Found SamplerCustomAdvanced node: {sampler_advanced_nodes[0]}")
|
||||
return sampler_advanced_nodes[0]
|
||||
|
||||
# If we have KSampler nodes, look for one with denoise=1.0
|
||||
if ksampler_nodes:
|
||||
for node_id in ksampler_nodes:
|
||||
node_data = workflow[node_id]
|
||||
inputs = node_data.get("inputs", {})
|
||||
denoise = inputs.get("denoise", 0)
|
||||
|
||||
# Check if denoise is 1.0 (allowing for small floating point differences)
|
||||
if abs(float(denoise) - 1.0) < 0.001:
|
||||
logger.debug(f"Found KSampler node with denoise=1.0: {node_id}")
|
||||
return node_id
|
||||
|
||||
# If no KSampler with denoise=1.0 found, use the first one
|
||||
logger.debug(f"No KSampler with denoise=1.0 found, using first KSampler: {ksampler_nodes[0]}")
|
||||
return ksampler_nodes[0]
|
||||
|
||||
# No sampler nodes found
|
||||
logger.warning("No sampler nodes found in workflow")
|
||||
return None
|
||||
|
||||
def parse_workflow(self, workflow_data: Union[str, Dict], output_path: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Parse the workflow and extract generation parameters
|
||||
|
||||
Args:
|
||||
workflow_data: The workflow data as a dictionary or a file path
|
||||
output_path: Optional path to save the output JSON
|
||||
|
||||
Returns:
|
||||
Dictionary containing extracted parameters
|
||||
"""
|
||||
# Load workflow from file if needed
|
||||
if isinstance(workflow_data, str):
|
||||
workflow = load_workflow(workflow_data)
|
||||
else:
|
||||
workflow = workflow_data
|
||||
|
||||
# Reset the processed nodes tracker and cache
|
||||
self.processed_nodes = set()
|
||||
self.node_results_cache = {}
|
||||
|
||||
# Find the primary sampler node
|
||||
sampler_node_id = self.find_primary_sampler_node(workflow)
|
||||
if not sampler_node_id:
|
||||
logger.warning("No suitable sampler node found in workflow")
|
||||
return {}
|
||||
|
||||
# Process sampler node to extract parameters
|
||||
sampler_result = self.process_node(sampler_node_id, workflow)
|
||||
if not sampler_result:
|
||||
return {}
|
||||
|
||||
# Return the sampler result directly - it's already in the format we need
|
||||
# This simplifies the structure and makes it easier to use in recipe_routes.py
|
||||
|
||||
# Handle standard ComfyUI names vs our output format
|
||||
if "cfg" in sampler_result:
|
||||
sampler_result["cfg_scale"] = sampler_result.pop("cfg")
|
||||
|
||||
# Add clip_skip = 1 to match reference output if not already present
|
||||
if "clip_skip" not in sampler_result:
|
||||
sampler_result["clip_skip"] = "1"
|
||||
|
||||
# Ensure the prompt is a string and not a nested dictionary
|
||||
if "prompt" in sampler_result and isinstance(sampler_result["prompt"], dict):
|
||||
if "prompt" in sampler_result["prompt"]:
|
||||
sampler_result["prompt"] = sampler_result["prompt"]["prompt"]
|
||||
|
||||
# Save the result if requested
|
||||
if output_path:
|
||||
save_output(sampler_result, output_path)
|
||||
|
||||
return sampler_result
|
||||
|
||||
|
||||
def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Parse a ComfyUI workflow file and extract generation parameters
|
||||
|
||||
Args:
|
||||
workflow_path: Path to the workflow JSON file
|
||||
output_path: Optional path to save the output JSON
|
||||
|
||||
Returns:
|
||||
Dictionary containing extracted parameters
|
||||
"""
|
||||
parser = WorkflowParser()
|
||||
return parser.parse_workflow(workflow_path, output_path)
|
||||
@@ -1,63 +0,0 @@
|
||||
"""
|
||||
Test script for the ComfyUI workflow parser
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from .parser import parse_workflow
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.StreamHandler()]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configure paths
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
ROOT_DIR = os.path.abspath(os.path.join(SCRIPT_DIR, '..', '..'))
|
||||
REFS_DIR = os.path.join(ROOT_DIR, 'refs')
|
||||
OUTPUT_DIR = os.path.join(ROOT_DIR, 'output')
|
||||
|
||||
def test_parse_flux_workflow():
|
||||
"""Test parsing the flux example workflow"""
|
||||
# Ensure output directory exists
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# Define input and output paths
|
||||
input_path = os.path.join(REFS_DIR, 'flux_prompt.json')
|
||||
output_path = os.path.join(OUTPUT_DIR, 'parsed_flux_output.json')
|
||||
|
||||
# Parse workflow
|
||||
logger.info(f"Parsing workflow: {input_path}")
|
||||
result = parse_workflow(input_path, output_path)
|
||||
|
||||
# Print result summary
|
||||
logger.info(f"Output saved to: {output_path}")
|
||||
logger.info(f"Parsing completed. Result summary:")
|
||||
logger.info(f" LoRAs: {result.get('loras', '')}")
|
||||
|
||||
gen_params = result.get('gen_params', {})
|
||||
logger.info(f" Prompt: {gen_params.get('prompt', '')[:50]}...")
|
||||
logger.info(f" Steps: {gen_params.get('steps', '')}")
|
||||
logger.info(f" Sampler: {gen_params.get('sampler', '')}")
|
||||
logger.info(f" Size: {gen_params.get('size', '')}")
|
||||
|
||||
# Compare with reference output
|
||||
ref_output_path = os.path.join(REFS_DIR, 'flux_output.json')
|
||||
try:
|
||||
with open(ref_output_path, 'r') as f:
|
||||
ref_output = json.load(f)
|
||||
|
||||
# Simple validation
|
||||
loras_match = result.get('loras', '') == ref_output.get('loras', '')
|
||||
prompt_match = gen_params.get('prompt', '') == ref_output.get('gen_params', {}).get('prompt', '')
|
||||
|
||||
logger.info(f"Validation against reference:")
|
||||
logger.info(f" LoRAs match: {loras_match}")
|
||||
logger.info(f" Prompt match: {prompt_match}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to compare with reference output: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_parse_flux_workflow()
|
||||
@@ -1,120 +0,0 @@
|
||||
"""
|
||||
Utility functions for ComfyUI workflow parsing
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Union, Set, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_workflow(workflow_path: str) -> Dict:
|
||||
"""Load a workflow from a JSON file"""
|
||||
try:
|
||||
with open(workflow_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading workflow from {workflow_path}: {e}")
|
||||
raise
|
||||
|
||||
def save_output(output: Dict, output_path: str) -> None:
|
||||
"""Save the parsed output to a JSON file"""
|
||||
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
|
||||
try:
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(output, f, indent=4)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving output to {output_path}: {e}")
|
||||
raise
|
||||
|
||||
def find_node_by_type(workflow: Dict, node_type: str) -> Optional[str]:
|
||||
"""Find a node of the specified type in the workflow"""
|
||||
for node_id, node_data in workflow.items():
|
||||
if node_data.get("class_type") == node_type:
|
||||
return node_id
|
||||
return None
|
||||
|
||||
def find_nodes_by_type(workflow: Dict, node_type: str) -> List[str]:
|
||||
"""Find all nodes of the specified type in the workflow"""
|
||||
return [node_id for node_id, node_data in workflow.items()
|
||||
if node_data.get("class_type") == node_type]
|
||||
|
||||
def get_input_node_ids(workflow: Dict, node_id: str) -> Dict[str, Tuple[str, int]]:
|
||||
"""
|
||||
Get the node IDs for all inputs of the given node
|
||||
|
||||
Returns a dictionary mapping input names to (node_id, output_slot) tuples
|
||||
"""
|
||||
result = {}
|
||||
if node_id not in workflow:
|
||||
return result
|
||||
|
||||
node_data = workflow[node_id]
|
||||
for input_name, input_value in node_data.get("inputs", {}).items():
|
||||
# Check if this input is connected to another node
|
||||
if isinstance(input_value, list) and len(input_value) == 2:
|
||||
# Input is connected to another node's output
|
||||
# Format: [node_id, output_slot]
|
||||
ref_node_id, output_slot = input_value
|
||||
result[input_name] = (str(ref_node_id), output_slot)
|
||||
|
||||
return result
|
||||
|
||||
def trace_model_path(workflow: Dict, start_node_id: str) -> List[str]:
|
||||
"""
|
||||
Trace the model path backward from KSampler to find all LoRA nodes
|
||||
|
||||
Args:
|
||||
workflow: The workflow data
|
||||
start_node_id: The starting node ID (usually KSampler)
|
||||
|
||||
Returns:
|
||||
List of node IDs in the model path
|
||||
"""
|
||||
model_path_nodes = []
|
||||
|
||||
# Get the model input from the start node
|
||||
if start_node_id not in workflow:
|
||||
return model_path_nodes
|
||||
|
||||
# Track visited nodes to avoid cycles
|
||||
visited = set()
|
||||
|
||||
# Stack for depth-first search
|
||||
stack = []
|
||||
|
||||
# Get model input reference if available
|
||||
start_node = workflow[start_node_id]
|
||||
if "inputs" in start_node and "model" in start_node["inputs"] and isinstance(start_node["inputs"]["model"], list):
|
||||
model_ref = start_node["inputs"]["model"]
|
||||
stack.append(str(model_ref[0]))
|
||||
|
||||
# Perform depth-first search
|
||||
while stack:
|
||||
node_id = stack.pop()
|
||||
|
||||
# Skip if already visited
|
||||
if node_id in visited:
|
||||
continue
|
||||
|
||||
# Mark as visited
|
||||
visited.add(node_id)
|
||||
|
||||
# Skip if node doesn't exist
|
||||
if node_id not in workflow:
|
||||
continue
|
||||
|
||||
node = workflow[node_id]
|
||||
node_type = node.get("class_type", "")
|
||||
|
||||
# Add current node to result list if it's a LoRA node
|
||||
if "Lora" in node_type:
|
||||
model_path_nodes.append(node_id)
|
||||
|
||||
# Add all input nodes that have a "model" or "lora_stack" output to the stack
|
||||
if "inputs" in node:
|
||||
for input_name, input_value in node["inputs"].items():
|
||||
if input_name in ["model", "lora_stack"] and isinstance(input_value, list) and len(input_value) == 2:
|
||||
stack.append(str(input_value[0]))
|
||||
|
||||
return model_path_nodes
|
||||
Reference in New Issue
Block a user