From 2b67091986f98c86a9d749eab58abc62fbf12507 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sun, 23 Mar 2025 05:21:43 +0800 Subject: [PATCH] Enhance workflow parsing and node mapper registration - Introduced a new WorkflowParser class to streamline workflow parsing and manage node mappers. - Added functionality to load external mappers dynamically from a specified directory. - Refactored LoraLoaderMapper and LoraStackerMapper to handle new data formats for loras and trigger words. - Updated recipe routes to utilize the new WorkflowParser for parsing workflows. - Made adjustments to the flux_prompt.json to reflect changes in active states and class types. --- py/routes/recipe_routes.py | 6 +- py/workflow/README.md | 149 +++++++++++++++++++++ py/workflow/ext/__init__.py | 3 + py/workflow/ext/example_mapper.py | 54 ++++++++ py/workflow/mappers.py | 207 +++++++++++++++++++++++++----- py/workflow/parser.py | 41 ++---- refs/flux_prompt.json | 87 +++++++++++-- 7 files changed, 463 insertions(+), 84 deletions(-) create mode 100644 py/workflow/README.md create mode 100644 py/workflow/ext/__init__.py create mode 100644 py/workflow/ext/example_mapper.py diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 05a157e3..6604e51b 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -12,6 +12,7 @@ from ..services.civitai_client import CivitaiClient from ..services.recipe_scanner import RecipeScanner from ..services.lora_scanner import LoraScanner from ..config import config +from ..workflow.parser import WorkflowParser import time # Add this import at the top logger = logging.getLogger(__name__) @@ -22,6 +23,7 @@ class RecipeRoutes: def __init__(self): self.recipe_scanner = RecipeScanner(LoraScanner()) self.civitai_client = CivitaiClient() + self.parser = WorkflowParser() # Pre-warm the cache self._init_cache_task = None @@ -773,9 +775,7 @@ class RecipeRoutes: latest_image_path = image_files[0][0] # Parse the workflow to extract generation parameters and loras - from ..workflow_params.workflow_parser import parse_workflow - # load_extensions=False to avoid loading extensions for now - parsed_workflow = parse_workflow(workflow_json, load_extensions=False) + parsed_workflow = self.parser.parse_workflow(workflow_json) logger.debug(f"Parsed workflow: {parsed_workflow}") diff --git a/py/workflow/README.md b/py/workflow/README.md new file mode 100644 index 00000000..0cb78d54 --- /dev/null +++ b/py/workflow/README.md @@ -0,0 +1,149 @@ +# ComfyUI Workflow Parser + +本模块提供了一个灵活的解析系统,可以从ComfyUI工作流中提取生成参数和LoRA信息。 + +## 设计理念 + +工作流解析器基于以下设计原则: + +1. **模块化**: 每种节点类型由独立的mapper处理 +2. **可扩展性**: 通过扩展系统轻松添加新的节点类型支持 +3. **回溯**: 通过工作流图的模型输入路径跟踪LoRA节点 +4. **灵活性**: 适应不同的ComfyUI工作流结构 + +## 主要组件 + +### 1. NodeMapper + +`NodeMapper`是所有节点映射器的基类,定义了如何从工作流中提取节点信息: + +```python +class NodeMapper: + def __init__(self, node_type: str, inputs_to_track: List[str]): + self.node_type = node_type + self.inputs_to_track = inputs_to_track + + def process(self, node_id: str, node_data: Dict, workflow: Dict, parser) -> Any: + # 处理节点的通用逻辑 + ... + + def transform(self, inputs: Dict) -> Any: + # 由子类覆盖以提供特定转换 + return inputs +``` + +### 2. WorkflowParser + +主要解析类,通过跟踪工作流图来提取参数: + +```python +parser = WorkflowParser() +result = parser.parse_workflow("workflow.json") +``` + +### 3. 扩展系统 + +允许通过添加新的自定义mapper来扩展支持的节点类型: + +```python +# 在py/workflow/ext/中添加自定义mapper模块 +load_extensions() # 自动加载所有扩展 +``` + +## 使用方法 + +### 基本用法 + +```python +from workflow.parser import parse_workflow + +# 解析工作流并保存结果 +result = parse_workflow("workflow.json", "output.json") +``` + +### 自定义解析 + +```python +from workflow.parser import WorkflowParser +from workflow.mappers import register_mapper, load_extensions + +# 加载扩展 +load_extensions() + +# 创建解析器 +parser = WorkflowParser(load_extensions_on_init=False) # 不自动加载扩展 + +# 解析工作流 +result = parser.parse_workflow(workflow_data) +``` + +## 扩展系统 + +### 添加新的节点映射器 + +在`py/workflow/ext/`目录中创建Python文件,定义从`NodeMapper`继承的类: + +```python +# example_mapper.py +from ..mappers import NodeMapper + +class MyCustomNodeMapper(NodeMapper): + def __init__(self): + super().__init__( + node_type="MyCustomNode", # 节点的class_type + inputs_to_track=["param1", "param2"] # 要提取的参数 + ) + + def transform(self, inputs: Dict) -> Any: + # 处理提取的参数 + return { + "custom_param": inputs.get("param1", "default") + } +``` + +扩展系统会自动加载和注册这些映射器。 + +### LoraManager节点说明 + +LoraManager相关节点的处理方式: + +1. **Lora Loader**: 处理`loras`数组,过滤出`active=true`的条目,和`lora_stack`输入 +2. **Lora Stacker**: 处理`loras`数组和已有的`lora_stack`,构建叠加的LoRA +3. **TriggerWord Toggle**: 从`toggle_trigger_words`中提取`active=true`的条目 + +## 输出格式 + +解析器生成的输出格式如下: + +```json +{ + "gen_params": { + "prompt": "...", + "negative_prompt": "", + "steps": "25", + "sampler": "dpmpp_2m", + "scheduler": "beta", + "cfg": "1", + "seed": "48", + "guidance": 3.5, + "size": "896x1152", + "clip_skip": "2" + }, + "loras": " " +} +``` + +## 高级用法 + +### 直接注册映射器 + +```python +from workflow.mappers import register_mapper +from workflow.mappers import NodeMapper + +# 创建自定义映射器 +class CustomMapper(NodeMapper): + # ...实现映射器 + +# 注册映射器 +register_mapper(CustomMapper()) \ No newline at end of file diff --git a/py/workflow/ext/__init__.py b/py/workflow/ext/__init__.py new file mode 100644 index 00000000..86e11ab6 --- /dev/null +++ b/py/workflow/ext/__init__.py @@ -0,0 +1,3 @@ +""" +Extension directory for custom node mappers +""" \ No newline at end of file diff --git a/py/workflow/ext/example_mapper.py b/py/workflow/ext/example_mapper.py new file mode 100644 index 00000000..652be09e --- /dev/null +++ b/py/workflow/ext/example_mapper.py @@ -0,0 +1,54 @@ +""" +Example extension mapper for demonstrating the extension system +""" +from typing import Dict, Any +from ..mappers import NodeMapper + +class ExampleNodeMapper(NodeMapper): + """Example mapper for custom nodes""" + + def __init__(self): + super().__init__( + node_type="ExampleCustomNode", + inputs_to_track=["param1", "param2", "image"] + ) + + def transform(self, inputs: Dict) -> Dict: + """Transform extracted inputs into the desired output format""" + result = {} + + # Extract interesting parameters + if "param1" in inputs: + result["example_param1"] = inputs["param1"] + + if "param2" in inputs: + result["example_param2"] = inputs["param2"] + + # You can process the data in any way needed + return result + + +class VAEMapperExtension(NodeMapper): + """Extension mapper for VAE nodes""" + + def __init__(self): + super().__init__( + node_type="VAELoader", + inputs_to_track=["vae_name"] + ) + + def transform(self, inputs: Dict) -> Dict: + """Extract VAE information""" + vae_name = inputs.get("vae_name", "") + + # Remove path prefix if present + if "/" in vae_name or "\\" in vae_name: + # Get just the filename without path or extension + vae_name = vae_name.replace("\\", "/").split("/")[-1] + vae_name = vae_name.split(".")[0] # Remove extension + + return {"vae": vae_name} + + +# Note: No need to register manually - extensions are automatically registered +# when the extension system loads this file \ No newline at end of file diff --git a/py/workflow/mappers.py b/py/workflow/mappers.py index 2cbc82ee..57216be8 100644 --- a/py/workflow/mappers.py +++ b/py/workflow/mappers.py @@ -2,10 +2,16 @@ Node mappers for ComfyUI workflow parsing """ import logging -from typing import Dict, List, Any, Optional, Union +import os +import importlib.util +import inspect +from typing import Dict, List, Any, Optional, Union, Type, Callable logger = logging.getLogger(__name__) +# Global mapper registry +_MAPPER_REGISTRY: Dict[str, 'NodeMapper'] = {} + class NodeMapper: """Base class for node mappers that define how to extract information from a specific node type""" @@ -130,32 +136,43 @@ class LoraLoaderMapper(NodeMapper): def __init__(self): super().__init__( node_type="Lora Loader (LoraManager)", - inputs_to_track=["text", "loras", "lora_stack"] + inputs_to_track=["loras", "lora_stack"] ) def transform(self, inputs: Dict) -> Dict: - lora_text = inputs.get("text", "") + # Fallback to loras array if text field doesn't exist or is invalid + loras_data = inputs.get("loras", []) lora_stack = inputs.get("lora_stack", []) - # Process lora_stack if it exists - stack_text = "" - if lora_stack: - # Handle the formatted lora_stack info if available - stack_loras = [] - for lora_path, strength, _ in lora_stack: - lora_name = lora_path.split(os.sep)[-1].split('.')[0] - stack_loras.append(f"") - stack_text = " ".join(stack_loras) - - # Combine lora_text and stack_text - combined_text = lora_text - if stack_text: - combined_text = f"{combined_text} {stack_text}" if combined_text else stack_text + # Process loras array - filter active entries + lora_texts = [] - # Format loras with spaces between them - if combined_text: - # Replace consecutive closing and opening tags with a space - combined_text = combined_text.replace("><", "> <") + # Check if loras_data is a list or a dict with __value__ key (new format) + if isinstance(loras_data, dict) and "__value__" in loras_data: + loras_list = loras_data["__value__"] + elif isinstance(loras_data, list): + loras_list = loras_data + else: + loras_list = [] + + # Process each active lora entry + for lora in loras_list: + if isinstance(lora, dict) and lora.get("active", False): + lora_name = lora.get("name", "") + strength = lora.get("strength", 1.0) + if lora_name and not lora_name.startswith("__dummy"): + lora_texts.append(f"") + + # Process lora_stack if it exists + if lora_stack: + # Format each entry from the stack + for lora_path, strength, _ in lora_stack: + lora_name = os.path.basename(lora_path).split('.')[0] + if lora_name and not lora_name.startswith("__dummy"): + lora_texts.append(f"") + + # Join with spaces + combined_text = " ".join(lora_texts) return {"loras": combined_text} @@ -170,8 +187,34 @@ class LoraStackerMapper(NodeMapper): ) def transform(self, inputs: Dict) -> Dict: - # Return the lora_stack information - return inputs.get("lora_stack", []) + loras_data = inputs.get("loras", []) + existing_stack = inputs.get("lora_stack", []) + result_stack = [] + + # Keep existing stack entries + if existing_stack: + result_stack.extend(existing_stack) + + # Process loras array - filter active entries + # Check if loras_data is a list or a dict with __value__ key (new format) + if isinstance(loras_data, dict) and "__value__" in loras_data: + loras_list = loras_data["__value__"] + elif isinstance(loras_data, list): + loras_list = loras_data + else: + loras_list = [] + + # Process each active lora entry + for lora in loras_list: + if isinstance(lora, dict) and lora.get("active", False): + lora_name = lora.get("name", "") + strength = float(lora.get("strength", 1.0)) + if lora_name and not lora_name.startswith("__dummy"): + # Here we would need the real path, but as a fallback use the name + # In a real implementation, this would require looking up the file path + result_stack.append((lora_name, strength, strength)) + + return {"lora_stack": result_stack} class JoinStringsMapper(NodeMapper): @@ -209,19 +252,31 @@ class TriggerWordToggleMapper(NodeMapper): def __init__(self): super().__init__( node_type="TriggerWord Toggle (LoraManager)", - inputs_to_track=["toggle_trigger_words", "orinalMessage", "trigger_words"] + inputs_to_track=["toggle_trigger_words"] ) def transform(self, inputs: Dict) -> str: - # Get the original message or toggled trigger words - original_message = inputs.get("orinalMessage", "") or inputs.get("trigger_words", "") + toggle_data = inputs.get("toggle_trigger_words", []) + + # check if toggle_words is a list or a dict with __value__ key (new format) + if isinstance(toggle_data, dict) and "__value__" in toggle_data: + toggle_words = toggle_data["__value__"] + elif isinstance(toggle_data, list): + toggle_words = toggle_data + else: + toggle_words = [] - # Fix double commas to match the reference output format - if original_message: - # Replace double commas with single commas - original_message = original_message.replace(",, ", ", ") - - return original_message + # Filter active trigger words + active_words = [] + for item in toggle_words: + if isinstance(item, dict) and item.get("active", False): + word = item.get("text", "") + if word and not word.startswith("__dummy"): + active_words.append(word) + + # Join with commas + result = ", ".join(active_words) + return result class FluxGuidanceMapper(NodeMapper): @@ -251,5 +306,89 @@ class FluxGuidanceMapper(NodeMapper): return result -# Add import os for LoraLoaderMapper to work properly -import os \ No newline at end of file +# ============================================================================= +# Mapper Registry Functions +# ============================================================================= + +def register_mapper(mapper: NodeMapper) -> None: + """Register a node mapper in the global registry""" + _MAPPER_REGISTRY[mapper.node_type] = mapper + logger.debug(f"Registered mapper for node type: {mapper.node_type}") + +def get_mapper(node_type: str) -> Optional[NodeMapper]: + """Get a mapper for the specified node type""" + return _MAPPER_REGISTRY.get(node_type) + +def get_all_mappers() -> Dict[str, NodeMapper]: + """Get all registered mappers""" + return _MAPPER_REGISTRY.copy() + +def register_default_mappers() -> None: + """Register all default mappers""" + default_mappers = [ + KSamplerMapper(), + EmptyLatentImageMapper(), + EmptySD3LatentImageMapper(), + CLIPTextEncodeMapper(), + LoraLoaderMapper(), + LoraStackerMapper(), + JoinStringsMapper(), + StringConstantMapper(), + TriggerWordToggleMapper(), + FluxGuidanceMapper() + ] + + for mapper in default_mappers: + register_mapper(mapper) + +# ============================================================================= +# Extension Loading +# ============================================================================= + +def load_extensions(ext_dir: str = None) -> None: + """ + Load mapper extensions from the specified directory + + Each Python file in the directory will be loaded, and any NodeMapper subclasses + defined in those files will be automatically registered. + """ + # Use default path if none provided + if ext_dir is None: + # Get the directory of this file + current_dir = os.path.dirname(os.path.abspath(__file__)) + ext_dir = os.path.join(current_dir, 'ext') + + # Ensure the extension directory exists + if not os.path.exists(ext_dir): + os.makedirs(ext_dir, exist_ok=True) + logger.info(f"Created extension directory: {ext_dir}") + return + + # Load each Python file in the extension directory + for filename in os.listdir(ext_dir): + if filename.endswith('.py') and not filename.startswith('_'): + module_path = os.path.join(ext_dir, filename) + module_name = f"workflow.ext.{filename[:-3]}" # Remove .py + + try: + # Load the module + spec = importlib.util.spec_from_file_location(module_name, module_path) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Find all NodeMapper subclasses in the module + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) and issubclass(obj, NodeMapper) + and obj != NodeMapper and hasattr(obj, 'node_type')): + # Instantiate and register the mapper + mapper = obj() + register_mapper(mapper) + logger.info(f"Loaded extension mapper: {mapper.node_type} from {filename}") + + except Exception as e: + logger.error(f"Error loading extension {filename}: {e}") + + +# Initialize the registry with default mappers +register_default_mappers() \ No newline at end of file diff --git a/py/workflow/parser.py b/py/workflow/parser.py index 6808492e..875dc476 100644 --- a/py/workflow/parser.py +++ b/py/workflow/parser.py @@ -4,12 +4,7 @@ Main workflow parser implementation for ComfyUI import json import logging from typing import Dict, List, Any, Optional, Union, Set -from .mappers import ( - NodeMapper, KSamplerMapper, EmptyLatentImageMapper, - EmptySD3LatentImageMapper, CLIPTextEncodeMapper, - LoraLoaderMapper, LoraStackerMapper, JoinStringsMapper, - StringConstantMapper, TriggerWordToggleMapper, FluxGuidanceMapper -) +from .mappers import get_mapper, get_all_mappers, load_extensions from .utils import ( load_workflow, save_output, find_node_by_type, trace_model_path @@ -20,33 +15,13 @@ logger = logging.getLogger(__name__) class WorkflowParser: """Parser for ComfyUI workflows""" - def __init__(self): - """Initialize the parser with default node mappers""" - self.node_mappers: Dict[str, NodeMapper] = {} + def __init__(self, load_extensions_on_init: bool = True): + """Initialize the parser with mappers""" self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles - self.register_default_mappers() - - def register_default_mappers(self) -> None: - """Register all default node mappers""" - mappers = [ - KSamplerMapper(), - EmptyLatentImageMapper(), - EmptySD3LatentImageMapper(), - CLIPTextEncodeMapper(), - LoraLoaderMapper(), - LoraStackerMapper(), - JoinStringsMapper(), - StringConstantMapper(), - TriggerWordToggleMapper(), - FluxGuidanceMapper() - ] - for mapper in mappers: - self.register_mapper(mapper) - - def register_mapper(self, mapper: NodeMapper) -> None: - """Register a node mapper""" - self.node_mappers[mapper.node_type] = mapper + # Load extensions if requested + if load_extensions_on_init: + load_extensions() def process_node(self, node_id: str, workflow: Dict) -> Any: """Process a single node and extract relevant information""" @@ -64,8 +39,8 @@ class WorkflowParser: node_type = node_data.get("class_type") result = None - if node_type in self.node_mappers: - mapper = self.node_mappers[node_type] + mapper = get_mapper(node_type) + if mapper: result = mapper.process(node_id, node_data, workflow, self) # Remove node from processed set to allow it to be processed again in a different context diff --git a/refs/flux_prompt.json b/refs/flux_prompt.json index 4f495b95..82a51077 100644 --- a/refs/flux_prompt.json +++ b/refs/flux_prompt.json @@ -44,7 +44,7 @@ }, "31": { "inputs": { - "seed": 48, + "seed": 44, "steps": 25, "cfg": 1, "sampler_name": "dpmpp_2m", @@ -95,7 +95,7 @@ }, "class_type": "FluxGuidance", "_meta": { - "title": "g" + "title": "FluxGuidance" } }, "37": { @@ -175,12 +175,12 @@ { "name": "pp-enchanted-whimsy", "strength": "0.90", - "active": true + "active": false }, { "name": "ral-frctlgmtry_flux", "strength": "0.85", - "active": true + "active": false }, { "name": "pp-storybook_rank2_bf16", @@ -218,17 +218,9 @@ "inputs": { "group_mode": "", "toggle_trigger_words": [ - { - "text": "in the style of ppWhimsy", - "active": true - }, - { - "text": "ral-frctlgmtry", - "active": true - }, { "text": "ppstorybook", - "active": true + "active": false }, { "text": "__dummy_item__", @@ -241,7 +233,7 @@ "_isDummy": true } ], - "orinalMessage": "in the style of ppWhimsy,, ral-frctlgmtry,, ppstorybook", + "orinalMessage": "ppstorybook", "trigger_words": [ "58", 2 @@ -251,5 +243,72 @@ "_meta": { "title": "TriggerWord Toggle (LoraManager)" } + }, + "61": { + "inputs": { + "add_noise": "enable", + "noise_seed": 1111423448930884, + "steps": 20, + "cfg": 8, + "sampler_name": "euler", + "scheduler": "normal", + "start_at_step": 0, + "end_at_step": 10000, + "return_with_leftover_noise": "disable" + }, + "class_type": "KSamplerAdvanced", + "_meta": { + "title": "KSampler (Advanced)" + } + }, + "62": { + "inputs": { + "sigmas": [ + "63", + 0 + ] + }, + "class_type": "SamplerCustomAdvanced", + "_meta": { + "title": "SamplerCustomAdvanced" + } + }, + "63": { + "inputs": { + "scheduler": "normal", + "steps": 20, + "denoise": 1 + }, + "class_type": "BasicScheduler", + "_meta": { + "title": "BasicScheduler" + } + }, + "64": { + "inputs": { + "seed": 1089899258710474, + "steps": 20, + "cfg": 8, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 1 + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "65": { + "inputs": { + "text": ",Stylized geek cat artist with glasses and a paintbrush, smiling at the viewer while holding a sign that reads 'Stay tuned!', solid white background", + "anything": [ + "46", + 0 + ] + }, + "class_type": "easy showAnything", + "_meta": { + "title": "Show Any" + } } } \ No newline at end of file