Enhance workflow parsing and node mapper registration

- Introduced a new WorkflowParser class to streamline workflow parsing and manage node mappers.
- Added functionality to load external mappers dynamically from a specified directory.
- Refactored LoraLoaderMapper and LoraStackerMapper to handle new data formats for loras and trigger words.
- Updated recipe routes to utilize the new WorkflowParser for parsing workflows.
- Made adjustments to the flux_prompt.json to reflect changes in active states and class types.
This commit is contained in:
Will Miao
2025-03-23 05:21:43 +08:00
parent 3da35cf0db
commit 2b67091986
7 changed files with 463 additions and 84 deletions

View File

@@ -12,6 +12,7 @@ from ..services.civitai_client import CivitaiClient
from ..services.recipe_scanner import RecipeScanner from ..services.recipe_scanner import RecipeScanner
from ..services.lora_scanner import LoraScanner from ..services.lora_scanner import LoraScanner
from ..config import config from ..config import config
from ..workflow.parser import WorkflowParser
import time # Add this import at the top import time # Add this import at the top
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -22,6 +23,7 @@ class RecipeRoutes:
def __init__(self): def __init__(self):
self.recipe_scanner = RecipeScanner(LoraScanner()) self.recipe_scanner = RecipeScanner(LoraScanner())
self.civitai_client = CivitaiClient() self.civitai_client = CivitaiClient()
self.parser = WorkflowParser()
# Pre-warm the cache # Pre-warm the cache
self._init_cache_task = None self._init_cache_task = None
@@ -773,9 +775,7 @@ class RecipeRoutes:
latest_image_path = image_files[0][0] latest_image_path = image_files[0][0]
# Parse the workflow to extract generation parameters and loras # Parse the workflow to extract generation parameters and loras
from ..workflow_params.workflow_parser import parse_workflow parsed_workflow = self.parser.parse_workflow(workflow_json)
# load_extensions=False to avoid loading extensions for now
parsed_workflow = parse_workflow(workflow_json, load_extensions=False)
logger.debug(f"Parsed workflow: {parsed_workflow}") logger.debug(f"Parsed workflow: {parsed_workflow}")

149
py/workflow/README.md Normal file
View File

@@ -0,0 +1,149 @@
# ComfyUI Workflow Parser
本模块提供了一个灵活的解析系统，可以从ComfyUI工作流中提取生成参数和LoRA信息。
## 设计理念
工作流解析器基于以下设计原则:
1. **模块化**: 每种节点类型由独立的mapper处理
2. **可扩展性**: 通过扩展系统轻松添加新的节点类型支持
3. **回溯**: 通过工作流图的模型输入路径跟踪LoRA节点
4. **灵活性**: 适应不同的ComfyUI工作流结构
## 主要组件
### 1. NodeMapper
`NodeMapper`是所有节点映射器的基类,定义了如何从工作流中提取节点信息:
```python
class NodeMapper:
def __init__(self, node_type: str, inputs_to_track: List[str]):
self.node_type = node_type
self.inputs_to_track = inputs_to_track
def process(self, node_id: str, node_data: Dict, workflow: Dict, parser) -> Any:
# 处理节点的通用逻辑
...
def transform(self, inputs: Dict) -> Any:
# 由子类覆盖以提供特定转换
return inputs
```
### 2. WorkflowParser
主要解析类,通过跟踪工作流图来提取参数:
```python
parser = WorkflowParser()
result = parser.parse_workflow("workflow.json")
```
### 3. 扩展系统
允许通过添加新的自定义mapper来扩展支持的节点类型:
```python
# 在py/workflow/ext/中添加自定义mapper模块
load_extensions() # 自动加载所有扩展
```
## 使用方法
### 基本用法
```python
from workflow.parser import parse_workflow
# 解析工作流并保存结果
result = parse_workflow("workflow.json", "output.json")
```
### 自定义解析
```python
from workflow.parser import WorkflowParser
from workflow.mappers import register_mapper, load_extensions
# 加载扩展
load_extensions()
# 创建解析器
parser = WorkflowParser(load_extensions_on_init=False) # 不自动加载扩展
# 解析工作流
result = parser.parse_workflow(workflow_data)
```
## 扩展系统
### 添加新的节点映射器
在`py/workflow/ext/`目录中创建Python文件，定义从`NodeMapper`继承的类:
```python
# example_mapper.py
from ..mappers import NodeMapper
class MyCustomNodeMapper(NodeMapper):
def __init__(self):
super().__init__(
node_type="MyCustomNode", # 节点的class_type
inputs_to_track=["param1", "param2"] # 要提取的参数
)
def transform(self, inputs: Dict) -> Any:
# 处理提取的参数
return {
"custom_param": inputs.get("param1", "default")
}
```
扩展系统会自动加载和注册这些映射器。
### LoraManager节点说明
LoraManager相关节点的处理方式:
1. **Lora Loader**: 处理`loras`数组,过滤出`active=true`的条目,和`lora_stack`输入
2. **Lora Stacker**: 处理`loras`数组和已有的`lora_stack`构建叠加的LoRA
3. **TriggerWord Toggle**: 从`toggle_trigger_words`中提取`active=true`的条目
## 输出格式
解析器生成的输出格式如下:
```json
{
"gen_params": {
"prompt": "...",
"negative_prompt": "",
"steps": "25",
"sampler": "dpmpp_2m",
"scheduler": "beta",
"cfg": "1",
"seed": "48",
"guidance": 3.5,
"size": "896x1152",
"clip_skip": "2"
},
"loras": "<lora:name1:0.9> <lora:name2:0.8>"
}
```
## 高级用法
### 直接注册映射器
```python
from workflow.mappers import register_mapper
from workflow.mappers import NodeMapper
# 创建自定义映射器
class CustomMapper(NodeMapper):
# ...实现映射器
# 注册映射器
register_mapper(CustomMapper())
```

View File

@@ -0,0 +1,3 @@
"""
Extension directory for custom node mappers
"""

View File

@@ -0,0 +1,54 @@
"""
Example extension mapper for demonstrating the extension system
"""
from typing import Dict, Any
from ..mappers import NodeMapper
class ExampleNodeMapper(NodeMapper):
    """Example mapper for custom nodes.

    Demonstrates the minimal extension contract: declare the node type and
    the inputs to track, then reshape the extracted inputs in transform().
    """

    def __init__(self):
        super().__init__(
            node_type="ExampleCustomNode",
            inputs_to_track=["param1", "param2", "image"]
        )

    def transform(self, inputs: Dict) -> Dict:
        """Transform extracted inputs into the desired output format."""
        # Forward only the parameters of interest, prefixing their names.
        return {
            f"example_{key}": inputs[key]
            for key in ("param1", "param2")
            if key in inputs
        }
class VAEMapperExtension(NodeMapper):
    """Extension mapper for VAE nodes.

    Extracts the bare VAE model name (no directory, no file extension)
    from a ``VAELoader`` node's ``vae_name`` input.
    """

    def __init__(self):
        super().__init__(
            node_type="VAELoader",
            inputs_to_track=["vae_name"]
        )

    def transform(self, inputs: Dict) -> Dict:
        """Extract VAE information as ``{"vae": <base name>}``."""
        vae_name = inputs.get("vae_name", "")
        # Normalize Windows separators, then keep only the final path component.
        if "/" in vae_name or "\\" in vae_name:
            vae_name = vae_name.replace("\\", "/").split("/")[-1]
        # Strip only the LAST extension so dotted names such as
        # "my.vae.safetensors" keep their full stem ("my.vae").
        # Previously split(".")[0] truncated everything after the first dot,
        # and the extension was only stripped when a path prefix was present.
        if "." in vae_name:
            vae_name = vae_name.rsplit(".", 1)[0]
        return {"vae": vae_name}
# Note: No need to register manually - extensions are automatically registered
# when the extension system loads this file

View File

@@ -2,10 +2,16 @@
Node mappers for ComfyUI workflow parsing Node mappers for ComfyUI workflow parsing
""" """
import logging import logging
from typing import Dict, List, Any, Optional, Union import os
import importlib.util
import inspect
from typing import Dict, List, Any, Optional, Union, Type, Callable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Global mapper registry
_MAPPER_REGISTRY: Dict[str, 'NodeMapper'] = {}
class NodeMapper: class NodeMapper:
"""Base class for node mappers that define how to extract information from a specific node type""" """Base class for node mappers that define how to extract information from a specific node type"""
@@ -130,32 +136,43 @@ class LoraLoaderMapper(NodeMapper):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
node_type="Lora Loader (LoraManager)", node_type="Lora Loader (LoraManager)",
inputs_to_track=["text", "loras", "lora_stack"] inputs_to_track=["loras", "lora_stack"]
) )
def transform(self, inputs: Dict) -> Dict: def transform(self, inputs: Dict) -> Dict:
lora_text = inputs.get("text", "") # Fallback to loras array if text field doesn't exist or is invalid
loras_data = inputs.get("loras", [])
lora_stack = inputs.get("lora_stack", []) lora_stack = inputs.get("lora_stack", [])
# Process lora_stack if it exists # Process loras array - filter active entries
stack_text = "" lora_texts = []
if lora_stack:
# Handle the formatted lora_stack info if available
stack_loras = []
for lora_path, strength, _ in lora_stack:
lora_name = lora_path.split(os.sep)[-1].split('.')[0]
stack_loras.append(f"<lora:{lora_name}:{strength}>")
stack_text = " ".join(stack_loras)
# Combine lora_text and stack_text
combined_text = lora_text
if stack_text:
combined_text = f"{combined_text} {stack_text}" if combined_text else stack_text
# Format loras with spaces between them # Check if loras_data is a list or a dict with __value__ key (new format)
if combined_text: if isinstance(loras_data, dict) and "__value__" in loras_data:
# Replace consecutive closing and opening tags with a space loras_list = loras_data["__value__"]
combined_text = combined_text.replace("><", "> <") elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
# Process each active lora entry
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
strength = lora.get("strength", 1.0)
if lora_name and not lora_name.startswith("__dummy"):
lora_texts.append(f"<lora:{lora_name}:{strength}>")
# Process lora_stack if it exists
if lora_stack:
# Format each entry from the stack
for lora_path, strength, _ in lora_stack:
lora_name = os.path.basename(lora_path).split('.')[0]
if lora_name and not lora_name.startswith("__dummy"):
lora_texts.append(f"<lora:{lora_name}:{strength}>")
# Join with spaces
combined_text = " ".join(lora_texts)
return {"loras": combined_text} return {"loras": combined_text}
@@ -170,8 +187,34 @@ class LoraStackerMapper(NodeMapper):
) )
def transform(self, inputs: Dict) -> Dict: def transform(self, inputs: Dict) -> Dict:
# Return the lora_stack information loras_data = inputs.get("loras", [])
return inputs.get("lora_stack", []) existing_stack = inputs.get("lora_stack", [])
result_stack = []
# Keep existing stack entries
if existing_stack:
result_stack.extend(existing_stack)
# Process loras array - filter active entries
# Check if loras_data is a list or a dict with __value__ key (new format)
if isinstance(loras_data, dict) and "__value__" in loras_data:
loras_list = loras_data["__value__"]
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
# Process each active lora entry
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
strength = float(lora.get("strength", 1.0))
if lora_name and not lora_name.startswith("__dummy"):
# Here we would need the real path, but as a fallback use the name
# In a real implementation, this would require looking up the file path
result_stack.append((lora_name, strength, strength))
return {"lora_stack": result_stack}
class JoinStringsMapper(NodeMapper): class JoinStringsMapper(NodeMapper):
@@ -209,19 +252,31 @@ class TriggerWordToggleMapper(NodeMapper):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
node_type="TriggerWord Toggle (LoraManager)", node_type="TriggerWord Toggle (LoraManager)",
inputs_to_track=["toggle_trigger_words", "orinalMessage", "trigger_words"] inputs_to_track=["toggle_trigger_words"]
) )
def transform(self, inputs: Dict) -> str: def transform(self, inputs: Dict) -> str:
# Get the original message or toggled trigger words toggle_data = inputs.get("toggle_trigger_words", [])
original_message = inputs.get("orinalMessage", "") or inputs.get("trigger_words", "")
# check if toggle_words is a list or a dict with __value__ key (new format)
if isinstance(toggle_data, dict) and "__value__" in toggle_data:
toggle_words = toggle_data["__value__"]
elif isinstance(toggle_data, list):
toggle_words = toggle_data
else:
toggle_words = []
# Fix double commas to match the reference output format # Filter active trigger words
if original_message: active_words = []
# Replace double commas with single commas for item in toggle_words:
original_message = original_message.replace(",, ", ", ") if isinstance(item, dict) and item.get("active", False):
word = item.get("text", "")
return original_message if word and not word.startswith("__dummy"):
active_words.append(word)
# Join with commas
result = ", ".join(active_words)
return result
class FluxGuidanceMapper(NodeMapper): class FluxGuidanceMapper(NodeMapper):
@@ -251,5 +306,89 @@ class FluxGuidanceMapper(NodeMapper):
return result return result
# Add import os for LoraLoaderMapper to work properly # =============================================================================
import os # Mapper Registry Functions
# =============================================================================
def register_mapper(mapper: NodeMapper) -> None:
    """Add a mapper to the global registry, keyed by its node type.

    A later registration for the same node type silently replaces the
    earlier one.
    """
    node_type = mapper.node_type
    _MAPPER_REGISTRY[node_type] = mapper
    logger.debug(f"Registered mapper for node type: {node_type}")
def get_mapper(node_type: str) -> Optional[NodeMapper]:
    """Return the mapper registered for *node_type*, or None if absent."""
    try:
        return _MAPPER_REGISTRY[node_type]
    except KeyError:
        return None
def get_all_mappers() -> Dict[str, NodeMapper]:
    """Return a shallow-copy snapshot of the mapper registry."""
    # Copy so callers cannot mutate the global registry through the result.
    return dict(_MAPPER_REGISTRY)
def register_default_mappers() -> None:
    """Instantiate and register every built-in node mapper."""
    default_mapper_classes = (
        KSamplerMapper,
        EmptyLatentImageMapper,
        EmptySD3LatentImageMapper,
        CLIPTextEncodeMapper,
        LoraLoaderMapper,
        LoraStackerMapper,
        JoinStringsMapper,
        StringConstantMapper,
        TriggerWordToggleMapper,
        FluxGuidanceMapper,
    )
    for mapper_cls in default_mapper_classes:
        register_mapper(mapper_cls())
# =============================================================================
# Extension Loading
# =============================================================================
def load_extensions(ext_dir: Optional[str] = None) -> None:
    """
    Load mapper extensions from the specified directory.

    Each Python file in the directory is imported, and any NodeMapper
    subclasses defined in those files are automatically instantiated and
    registered.

    Args:
        ext_dir: Directory to scan for extension modules. Defaults to the
            'ext' subdirectory next to this file; it is created (and the
            function returns) if it does not exist yet.
    """
    if ext_dir is None:
        # Default to the 'ext' directory alongside this module.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        ext_dir = os.path.join(current_dir, 'ext')

    # First run: create the directory so users have somewhere to drop
    # extension files, then bail out — there is nothing to load yet.
    if not os.path.exists(ext_dir):
        os.makedirs(ext_dir, exist_ok=True)
        logger.info(f"Created extension directory: {ext_dir}")
        return

    for filename in os.listdir(ext_dir):
        # Skip non-Python files and private/dunder modules (e.g. __init__.py).
        if not filename.endswith('.py') or filename.startswith('_'):
            continue
        module_path = os.path.join(ext_dir, filename)
        module_name = f"workflow.ext.{filename[:-3]}"  # strip the .py suffix
        try:
            spec = importlib.util.spec_from_file_location(module_name, module_path)
            if spec and spec.loader:
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
                # Register every NodeMapper subclass defined in the module.
                # NOTE(review): hasattr(obj, 'node_type') checks the CLASS;
                # node_type appears to be set per-instance in __init__, so
                # this filter may reject valid mappers — confirm NodeMapper
                # declares a class-level node_type.
                for name, obj in inspect.getmembers(module):
                    if (inspect.isclass(obj) and issubclass(obj, NodeMapper)
                            and obj is not NodeMapper and hasattr(obj, 'node_type')):
                        mapper = obj()
                        register_mapper(mapper)
                        # Include the source file so operators can trace
                        # where each extension mapper came from.
                        logger.info(f"Loaded extension mapper: {mapper.node_type} from {filename}")
        except Exception as e:
            logger.error(f"Error loading extension {filename}: {e}")
# Initialize the registry with default mappers
# Runs at import time so consumers (e.g. WorkflowParser) see the built-in
# mappers without any explicit setup call.
register_default_mappers()

View File

@@ -4,12 +4,7 @@ Main workflow parser implementation for ComfyUI
import json import json
import logging import logging
from typing import Dict, List, Any, Optional, Union, Set from typing import Dict, List, Any, Optional, Union, Set
from .mappers import ( from .mappers import get_mapper, get_all_mappers, load_extensions
NodeMapper, KSamplerMapper, EmptyLatentImageMapper,
EmptySD3LatentImageMapper, CLIPTextEncodeMapper,
LoraLoaderMapper, LoraStackerMapper, JoinStringsMapper,
StringConstantMapper, TriggerWordToggleMapper, FluxGuidanceMapper
)
from .utils import ( from .utils import (
load_workflow, save_output, find_node_by_type, load_workflow, save_output, find_node_by_type,
trace_model_path trace_model_path
@@ -20,33 +15,13 @@ logger = logging.getLogger(__name__)
class WorkflowParser: class WorkflowParser:
"""Parser for ComfyUI workflows""" """Parser for ComfyUI workflows"""
def __init__(self): def __init__(self, load_extensions_on_init: bool = True):
"""Initialize the parser with default node mappers""" """Initialize the parser with mappers"""
self.node_mappers: Dict[str, NodeMapper] = {}
self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles
self.register_default_mappers()
def register_default_mappers(self) -> None:
"""Register all default node mappers"""
mappers = [
KSamplerMapper(),
EmptyLatentImageMapper(),
EmptySD3LatentImageMapper(),
CLIPTextEncodeMapper(),
LoraLoaderMapper(),
LoraStackerMapper(),
JoinStringsMapper(),
StringConstantMapper(),
TriggerWordToggleMapper(),
FluxGuidanceMapper()
]
for mapper in mappers: # Load extensions if requested
self.register_mapper(mapper) if load_extensions_on_init:
load_extensions()
def register_mapper(self, mapper: NodeMapper) -> None:
"""Register a node mapper"""
self.node_mappers[mapper.node_type] = mapper
def process_node(self, node_id: str, workflow: Dict) -> Any: def process_node(self, node_id: str, workflow: Dict) -> Any:
"""Process a single node and extract relevant information""" """Process a single node and extract relevant information"""
@@ -64,8 +39,8 @@ class WorkflowParser:
node_type = node_data.get("class_type") node_type = node_data.get("class_type")
result = None result = None
if node_type in self.node_mappers: mapper = get_mapper(node_type)
mapper = self.node_mappers[node_type] if mapper:
result = mapper.process(node_id, node_data, workflow, self) result = mapper.process(node_id, node_data, workflow, self)
# Remove node from processed set to allow it to be processed again in a different context # Remove node from processed set to allow it to be processed again in a different context

View File

@@ -44,7 +44,7 @@
}, },
"31": { "31": {
"inputs": { "inputs": {
"seed": 48, "seed": 44,
"steps": 25, "steps": 25,
"cfg": 1, "cfg": 1,
"sampler_name": "dpmpp_2m", "sampler_name": "dpmpp_2m",
@@ -95,7 +95,7 @@
}, },
"class_type": "FluxGuidance", "class_type": "FluxGuidance",
"_meta": { "_meta": {
"title": "g" "title": "FluxGuidance"
} }
}, },
"37": { "37": {
@@ -175,12 +175,12 @@
{ {
"name": "pp-enchanted-whimsy", "name": "pp-enchanted-whimsy",
"strength": "0.90", "strength": "0.90",
"active": true "active": false
}, },
{ {
"name": "ral-frctlgmtry_flux", "name": "ral-frctlgmtry_flux",
"strength": "0.85", "strength": "0.85",
"active": true "active": false
}, },
{ {
"name": "pp-storybook_rank2_bf16", "name": "pp-storybook_rank2_bf16",
@@ -218,17 +218,9 @@
"inputs": { "inputs": {
"group_mode": "", "group_mode": "",
"toggle_trigger_words": [ "toggle_trigger_words": [
{
"text": "in the style of ppWhimsy",
"active": true
},
{
"text": "ral-frctlgmtry",
"active": true
},
{ {
"text": "ppstorybook", "text": "ppstorybook",
"active": true "active": false
}, },
{ {
"text": "__dummy_item__", "text": "__dummy_item__",
@@ -241,7 +233,7 @@
"_isDummy": true "_isDummy": true
} }
], ],
"orinalMessage": "in the style of ppWhimsy,, ral-frctlgmtry,, ppstorybook", "orinalMessage": "ppstorybook",
"trigger_words": [ "trigger_words": [
"58", "58",
2 2
@@ -251,5 +243,72 @@
"_meta": { "_meta": {
"title": "TriggerWord Toggle (LoraManager)" "title": "TriggerWord Toggle (LoraManager)"
} }
},
"61": {
"inputs": {
"add_noise": "enable",
"noise_seed": 1111423448930884,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"start_at_step": 0,
"end_at_step": 10000,
"return_with_leftover_noise": "disable"
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"62": {
"inputs": {
"sigmas": [
"63",
0
]
},
"class_type": "SamplerCustomAdvanced",
"_meta": {
"title": "SamplerCustomAdvanced"
}
},
"63": {
"inputs": {
"scheduler": "normal",
"steps": 20,
"denoise": 1
},
"class_type": "BasicScheduler",
"_meta": {
"title": "BasicScheduler"
}
},
"64": {
"inputs": {
"seed": 1089899258710474,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"65": {
"inputs": {
"text": ",Stylized geek cat artist with glasses and a paintbrush, smiling at the viewer while holding a sign that reads 'Stay tuned!', solid white background",
"anything": [
"46",
0
]
},
"class_type": "easy showAnything",
"_meta": {
"title": "Show Any"
}
} }
} }