checkpoint

This commit is contained in:
Will Miao
2025-04-01 08:38:49 +08:00
parent 350b81d678
commit 60575b6546
5 changed files with 209 additions and 264 deletions

View File

@@ -5,72 +5,102 @@ import logging
import os
import importlib.util
import inspect
from typing import Dict, List, Any, Optional, Union, Type, Callable
from typing import Dict, List, Any, Optional, Union, Type, Callable, Tuple
logger = logging.getLogger(__name__)
# Global mapper registry
_MAPPER_REGISTRY: Dict[str, 'NodeMapper'] = {}
_MAPPER_REGISTRY: Dict[str, Dict] = {}
class NodeMapper:
"""Base class for node mappers that define how to extract information from a specific node type"""
# =============================================================================
# Mapper Definition Functions
# =============================================================================
def create_mapper(
node_type: str,
inputs_to_track: List[str],
transform_func: Callable[[Dict], Any] = None
) -> Dict:
"""Create a mapper definition for a node type"""
mapper = {
"node_type": node_type,
"inputs_to_track": inputs_to_track,
"transform": transform_func or (lambda inputs: inputs)
}
return mapper
def register_mapper(mapper: Dict) -> None:
"""Register a node mapper in the global registry"""
_MAPPER_REGISTRY[mapper["node_type"]] = mapper
logger.debug(f"Registered mapper for node type: {mapper['node_type']}")
def get_mapper(node_type: str) -> Optional[Dict]:
"""Get a mapper for the specified node type"""
return _MAPPER_REGISTRY.get(node_type)
def get_all_mappers() -> Dict[str, Dict]:
"""Get all registered mappers"""
return _MAPPER_REGISTRY.copy()
# =============================================================================
# Node Processing Function
# =============================================================================
def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'WorkflowParser') -> Any:
"""Process a node using its mapper and extract relevant information"""
node_type = node_data.get("class_type")
mapper = get_mapper(node_type)
def __init__(self, node_type: str, inputs_to_track: List[str]):
self.node_type = node_type
self.inputs_to_track = inputs_to_track
def process(self, node_id: str, node_data: Dict, workflow: Dict, parser: 'WorkflowParser') -> Any: # type: ignore
"""Process the node and extract relevant information"""
result = {}
for input_name in self.inputs_to_track:
if input_name in node_data.get("inputs", {}):
input_value = node_data["inputs"][input_name]
# Check if input is a reference to another node's output
if isinstance(input_value, list) and len(input_value) == 2:
# Format is [node_id, output_slot]
try:
ref_node_id, output_slot = input_value
# Convert node_id to string if it's an integer
if isinstance(ref_node_id, int):
ref_node_id = str(ref_node_id)
# Recursively process the referenced node
ref_value = parser.process_node(ref_node_id, workflow)
# Store the processed value
if ref_value is not None:
result[input_name] = ref_value
else:
# If we couldn't get a value from the reference, store the raw value
result[input_name] = input_value
except Exception as e:
logger.error(f"Error processing reference in node {node_id}, input {input_name}: {e}")
# If we couldn't process the reference, store the raw value
result[input_name] = input_value
else:
# Direct value
result[input_name] = input_value
if not mapper:
return None
# Apply any transformations
return self.transform(result)
result = {}
def transform(self, inputs: Dict) -> Any:
"""Transform the extracted inputs - override in subclasses"""
return inputs
# Extract inputs based on the mapper's tracked inputs
for input_name in mapper["inputs_to_track"]:
if input_name in node_data.get("inputs", {}):
input_value = node_data["inputs"][input_name]
# Check if input is a reference to another node's output
if isinstance(input_value, list) and len(input_value) == 2:
try:
# Format is [node_id, output_slot]
ref_node_id, output_slot = input_value
# Convert node_id to string if it's an integer
if isinstance(ref_node_id, int):
ref_node_id = str(ref_node_id)
# Recursively process the referenced node
ref_value = parser.process_node(ref_node_id, workflow)
if ref_value is not None:
result[input_name] = ref_value
else:
# If we couldn't get a value from the reference, store the raw value
result[input_name] = input_value
except Exception as e:
logger.error(f"Error processing reference in node {node_id}, input {input_name}: {e}")
result[input_name] = input_value
else:
# Direct value
result[input_name] = input_value
# Apply the transform function
try:
return mapper["transform"](result)
except Exception as e:
logger.error(f"Error in transform function for node {node_id} of type {node_type}: {e}")
return result
# =============================================================================
# Default Mapper Definitions
# =============================================================================
class KSamplerMapper(NodeMapper):
"""Mapper for KSampler nodes"""
def register_default_mappers() -> None:
"""Register all default mappers"""
def __init__(self):
super().__init__(
node_type="KSampler",
inputs_to_track=["seed", "steps", "cfg", "sampler_name", "scheduler",
"denoise", "positive", "negative", "latent_image",
"model", "clip_skip"]
)
def transform(self, inputs: Dict) -> Dict:
# KSampler mapper
def transform_ksampler(inputs: Dict) -> Dict:
result = {
"seed": str(inputs.get("seed", "")),
"steps": str(inputs.get("steps", "")),
@@ -99,70 +129,52 @@ class KSamplerMapper(NodeMapper):
result["clip_skip"] = str(inputs.get("clip_skip", ""))
return result
class EmptyLatentImageMapper(NodeMapper):
"""Mapper for EmptyLatentImage nodes"""
def __init__(self):
super().__init__(
node_type="EmptyLatentImage",
inputs_to_track=["width", "height", "batch_size"]
)
register_mapper(create_mapper(
node_type="KSampler",
inputs_to_track=["seed", "steps", "cfg", "sampler_name", "scheduler",
"denoise", "positive", "negative", "latent_image",
"model", "clip_skip"],
transform_func=transform_ksampler
))
def transform(self, inputs: Dict) -> Dict:
# EmptyLatentImage mapper
def transform_empty_latent(inputs: Dict) -> Dict:
width = inputs.get("width", 0)
height = inputs.get("height", 0)
return {"width": width, "height": height, "size": f"{width}x{height}"}
class EmptySD3LatentImageMapper(NodeMapper):
"""Mapper for EmptySD3LatentImage nodes"""
def __init__(self):
super().__init__(
node_type="EmptySD3LatentImage",
inputs_to_track=["width", "height", "batch_size"]
)
register_mapper(create_mapper(
node_type="EmptyLatentImage",
inputs_to_track=["width", "height", "batch_size"],
transform_func=transform_empty_latent
))
def transform(self, inputs: Dict) -> Dict:
width = inputs.get("width", 0)
height = inputs.get("height", 0)
return {"width": width, "height": height, "size": f"{width}x{height}"}
class CLIPTextEncodeMapper(NodeMapper):
"""Mapper for CLIPTextEncode nodes"""
# SD3LatentImage mapper - reuses same transform function as EmptyLatentImage
register_mapper(create_mapper(
node_type="EmptySD3LatentImage",
inputs_to_track=["width", "height", "batch_size"],
transform_func=transform_empty_latent
))
def __init__(self):
super().__init__(
node_type="CLIPTextEncode",
inputs_to_track=["text", "clip"]
)
def transform(self, inputs: Dict) -> Any:
# Simply return the text
# CLIPTextEncode mapper
def transform_clip_text(inputs: Dict) -> Any:
return inputs.get("text", "")
class LoraLoaderMapper(NodeMapper):
"""Mapper for LoraLoader nodes"""
def __init__(self):
super().__init__(
node_type="Lora Loader (LoraManager)",
inputs_to_track=["loras", "lora_stack"]
)
register_mapper(create_mapper(
node_type="CLIPTextEncode",
inputs_to_track=["text", "clip"],
transform_func=transform_clip_text
))
def transform(self, inputs: Dict) -> Dict:
# Fallback to loras array if text field doesn't exist or is invalid
# LoraLoader mapper
def transform_lora_loader(inputs: Dict) -> Dict:
loras_data = inputs.get("loras", [])
lora_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
# Process loras array - filter active entries
lora_texts = []
# Check if loras_data is a list or a dict with __value__ key (new format)
# Process loras array
if isinstance(loras_data, dict) and "__value__" in loras_data:
loras_list = loras_data["__value__"]
elif isinstance(loras_data, list):
@@ -172,42 +184,29 @@ class LoraLoaderMapper(NodeMapper):
# Process each active lora entry
for lora in loras_list:
logger.info(f"Lora: {lora}, active: {lora.get('active')}")
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
strength = lora.get("strength", 1.0)
lora_texts.append(f"<lora:{lora_name}:{strength}>")
# Process lora_stack if it exists and is a valid format (list of tuples)
# Process lora_stack if valid
if lora_stack and isinstance(lora_stack, list):
# If lora_stack is a reference to another node ([node_id, output_slot]),
# we don't process it here as it's already been processed recursively
if len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int):
# This is a reference to another node, already processed
pass
else:
# Format each entry from the stack (assuming it's a list of tuples)
if not (len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int)):
for stack_entry in lora_stack:
lora_name = stack_entry[0]
strength = stack_entry[1]
lora_texts.append(f"<lora:{lora_name}:{strength}>")
# Join with spaces
combined_text = " ".join(lora_texts)
return {"loras": combined_text}
class LoraStackerMapper(NodeMapper):
"""Mapper for LoraStacker nodes"""
return {"loras": " ".join(lora_texts)}
def __init__(self):
super().__init__(
node_type="Lora Stacker (LoraManager)",
inputs_to_track=["loras", "lora_stack"]
)
register_mapper(create_mapper(
node_type="Lora Loader (LoraManager)",
inputs_to_track=["loras", "lora_stack"],
transform_func=transform_lora_loader
))
def transform(self, inputs: Dict) -> Dict:
# LoraStacker mapper
def transform_lora_stacker(inputs: Dict) -> Dict:
loras_data = inputs.get("loras", [])
result_stack = []
@@ -215,25 +214,18 @@ class LoraStackerMapper(NodeMapper):
existing_stack = []
lora_stack_input = inputs.get("lora_stack", [])
# Handle different formats of lora_stack
if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input:
# Format from another LoraStacker node
existing_stack = lora_stack_input["lora_stack"]
elif isinstance(lora_stack_input, list):
# Direct list format or reference format [node_id, output_slot]
if len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and isinstance(lora_stack_input[1], int):
# This is likely a reference that was already processed
pass
else:
# Regular list of tuples/entries
if not (len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and
isinstance(lora_stack_input[1], int)):
existing_stack = lora_stack_input
# Add existing entries first
# Add existing entries
if existing_stack:
result_stack.extend(existing_stack)
# Process loras array - filter active entries
# Check if loras_data is a list or a dict with __value__ key (new format)
# Process new loras
if isinstance(loras_data, dict) and "__value__" in loras_data:
loras_list = loras_data["__value__"]
elif isinstance(loras_data, list):
@@ -241,7 +233,6 @@ class LoraStackerMapper(NodeMapper):
else:
loras_list = []
# Process each active lora entry
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
@@ -249,50 +240,40 @@ class LoraStackerMapper(NodeMapper):
result_stack.append((lora_name, strength))
return {"lora_stack": result_stack}
class JoinStringsMapper(NodeMapper):
"""Mapper for JoinStrings nodes"""
def __init__(self):
super().__init__(
node_type="JoinStrings",
inputs_to_track=["string1", "string2", "delimiter"]
)
register_mapper(create_mapper(
node_type="Lora Stacker (LoraManager)",
inputs_to_track=["loras", "lora_stack"],
transform_func=transform_lora_stacker
))
def transform(self, inputs: Dict) -> str:
# JoinStrings mapper
def transform_join_strings(inputs: Dict) -> str:
string1 = inputs.get("string1", "")
string2 = inputs.get("string2", "")
delimiter = inputs.get("delimiter", "")
return f"{string1}{delimiter}{string2}"
class StringConstantMapper(NodeMapper):
"""Mapper for StringConstant and StringConstantMultiline nodes"""
def __init__(self):
super().__init__(
node_type="StringConstantMultiline",
inputs_to_track=["string"]
)
register_mapper(create_mapper(
node_type="JoinStrings",
inputs_to_track=["string1", "string2", "delimiter"],
transform_func=transform_join_strings
))
def transform(self, inputs: Dict) -> str:
# StringConstant mapper
def transform_string_constant(inputs: Dict) -> str:
return inputs.get("string", "")
class TriggerWordToggleMapper(NodeMapper):
"""Mapper for TriggerWordToggle nodes"""
def __init__(self):
super().__init__(
node_type="TriggerWord Toggle (LoraManager)",
inputs_to_track=["toggle_trigger_words"]
)
register_mapper(create_mapper(
node_type="StringConstantMultiline",
inputs_to_track=["string"],
transform_func=transform_string_constant
))
def transform(self, inputs: Dict) -> str:
# TriggerWordToggle mapper
def transform_trigger_word_toggle(inputs: Dict) -> str:
toggle_data = inputs.get("toggle_trigger_words", [])
# check if toggle_words is a list or a dict with __value__ key (new format)
if isinstance(toggle_data, dict) and "__value__" in toggle_data:
toggle_words = toggle_data["__value__"]
elif isinstance(toggle_data, list):
@@ -308,28 +289,21 @@ class TriggerWordToggleMapper(NodeMapper):
if word and not word.startswith("__dummy"):
active_words.append(word)
# Join with commas
result = ", ".join(active_words)
return result
class FluxGuidanceMapper(NodeMapper):
"""Mapper for FluxGuidance nodes"""
return ", ".join(active_words)
def __init__(self):
super().__init__(
node_type="FluxGuidance",
inputs_to_track=["guidance", "conditioning"]
)
register_mapper(create_mapper(
node_type="TriggerWord Toggle (LoraManager)",
inputs_to_track=["toggle_trigger_words"],
transform_func=transform_trigger_word_toggle
))
def transform(self, inputs: Dict) -> Dict:
# FluxGuidance mapper
def transform_flux_guidance(inputs: Dict) -> Dict:
result = {}
# Handle guidance parameter
if "guidance" in inputs:
result["guidance"] = inputs["guidance"]
# Handle conditioning (the prompt text)
if "conditioning" in inputs:
conditioning = inputs["conditioning"]
if isinstance(conditioning, str):
@@ -338,42 +312,12 @@ class FluxGuidanceMapper(NodeMapper):
result["prompt"] = "Unknown prompt"
return result
# =============================================================================
# Mapper Registry Functions
# =============================================================================
def register_mapper(mapper: NodeMapper) -> None:
"""Register a node mapper in the global registry"""
_MAPPER_REGISTRY[mapper.node_type] = mapper
logger.debug(f"Registered mapper for node type: {mapper.node_type}")
def get_mapper(node_type: str) -> Optional[NodeMapper]:
"""Get a mapper for the specified node type"""
return _MAPPER_REGISTRY.get(node_type)
def get_all_mappers() -> Dict[str, NodeMapper]:
"""Get all registered mappers"""
return _MAPPER_REGISTRY.copy()
def register_default_mappers() -> None:
"""Register all default mappers"""
default_mappers = [
KSamplerMapper(),
EmptyLatentImageMapper(),
EmptySD3LatentImageMapper(),
CLIPTextEncodeMapper(),
LoraLoaderMapper(),
LoraStackerMapper(),
JoinStringsMapper(),
StringConstantMapper(),
TriggerWordToggleMapper(),
FluxGuidanceMapper()
]
for mapper in default_mappers:
register_mapper(mapper)
register_mapper(create_mapper(
node_type="FluxGuidance",
inputs_to_track=["guidance", "conditioning"],
transform_func=transform_flux_guidance
))
# =============================================================================
# Extension Loading
@@ -383,8 +327,8 @@ def load_extensions(ext_dir: str = None) -> None:
"""
Load mapper extensions from the specified directory
Each Python file in the directory will be loaded, and any NodeMapper subclasses
defined in those files will be automatically registered.
Extension files should define mappers using the create_mapper function
and then call register_mapper to add them to the registry.
"""
# Use default path if none provided
if ext_dir is None:
@@ -410,19 +354,9 @@ def load_extensions(ext_dir: str = None) -> None:
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Find all NodeMapper subclasses in the module
for name, obj in inspect.getmembers(module):
if (inspect.isclass(obj) and issubclass(obj, NodeMapper)
and obj != NodeMapper and hasattr(obj, 'node_type')):
# Instantiate and register the mapper
mapper = obj()
register_mapper(mapper)
logger.info(f"Loaded extension mapper: {mapper.node_type} from {filename}")
logger.info(f"Loaded extension module: {filename}")
except Exception as e:
logger.warning(f"Error loading extension {filename}: {e}")
# Initialize the registry with default mappers
register_default_mappers()

View File

@@ -4,7 +4,7 @@ Main workflow parser implementation for ComfyUI
import json
import logging
from typing import Dict, List, Any, Optional, Union, Set
from .mappers import get_mapper, get_all_mappers, load_extensions
from .mappers import get_mapper, get_all_mappers, load_extensions, process_node
from .utils import (
load_workflow, save_output, find_node_by_type,
trace_model_path
@@ -45,10 +45,9 @@ class WorkflowParser:
node_type = node_data.get("class_type")
result = None
mapper = get_mapper(node_type)
if mapper:
if get_mapper(node_type):
try:
result = mapper.process(node_id, node_data, workflow, self)
result = process_node(node_id, node_data, workflow, self)
# Cache the result
self.node_results_cache[node_id] = result
except Exception as e:

View File

@@ -1,13 +1,11 @@
{
"loras": "<lora:ck-neon-retrowave-IL-000012:0.8> <lora:aorunIllstrious:1> <lora:ck-shadow-circuit-IL-000012:0.78> <lora:MoriiMee_Gothic_Niji_Style_Illustrious_r1:0.45> <lora:ck-nc-cyberpunk-IL-000011:0.4>",
"gen_params": {
"prompt": "in the style of ck-rw, aorun, scales, makeup, bare shoulders, pointy ears, dress, claws, in the style of cksc, artist:moriimee, in the style of cknc, masterpiece, best quality, good quality, very aesthetic, absurdres, newest, 8K, depth of field, focused subject, close up, stylized, in gold and neon shades, wabi sabi, 1girl, rainbow angel wings, looking at viewer, dynamic angle, from below, from side, relaxing",
"negative_prompt": "bad quality, worst quality, worst detail, sketch ,signature, watermark, patreon logo, nsfw",
"steps": "20",
"sampler": "euler_ancestral",
"cfg_scale": "8",
"seed": "241",
"size": "832x1216",
"clip_skip": "2"
}
"prompt": "in the style of ck-rw, aorun, scales, makeup, bare shoulders, pointy ears, dress, claws, in the style of cksc, artist:moriimee, in the style of cknc, masterpiece, best quality, good quality, very aesthetic, absurdres, newest, 8K, depth of field, focused subject, close up, stylized, in gold and neon shades, wabi sabi, 1girl, rainbow angel wings, looking at viewer, dynamic angle, from below, from side, relaxing",
"negative_prompt": "bad quality, worst quality, worst detail, sketch ,signature, watermark, patreon logo, nsfw",
"steps": "20",
"sampler": "euler_ancestral",
"cfg_scale": "8",
"seed": "241",
"size": "832x1216",
"clip_skip": "2"
}

View File

@@ -1,7 +1,7 @@
{
"3": {
"inputs": {
"seed": 241,
"seed": 42,
"steps": 20,
"cfg": 8,
"sampler_name": "euler_ancestral",
@@ -121,7 +121,7 @@
},
"21": {
"inputs": {
"string": "masterpiece, best quality, good quality, very aesthetic, absurdres, newest, 8K, depth of field, focused subject, close up, stylized, in gold and neon shades, wabi sabi, 1girl, rainbow angel wings, looking at viewer, dynamic angle, from below, from side, relaxing",
"string": "masterpiece, best quality, good quality, very awa, newest, highres, absurdres, 1girl, solo, dress, standing, flower, outdoors, water, white flower, pink flower, scenery, reflection, rain, dark, ripples, yellow flower, puddle, colorful, abstract, standing on liquidi¼Œ\nvery Wide Shot, limited palette,",
"strip_newlines": false
},
"class_type": "StringConstantMultiline",
@@ -151,15 +151,19 @@
"group_mode": true,
"toggle_trigger_words": [
{
"text": "in the style of ck-rw",
"text": "xxx667_illu",
"active": true
},
{
"text": "in the style of cksc",
"text": "glowing",
"active": true
},
{
"text": "artist:moriimee",
"text": "glitch",
"active": true
},
{
"text": "15546+456868",
"active": true
},
{
@@ -173,7 +177,7 @@
"_isDummy": true
}
],
"orinalMessage": "in the style of ck-rw,, in the style of cksc,, artist:moriimee",
"orinalMessage": "xxx667_illu,, glowing,, glitch,, 15546+456868",
"trigger_words": [
"56",
2
@@ -186,22 +190,32 @@
},
"56": {
"inputs": {
"text": "<lora:ck-shadow-circuit-IL-000012:0.78> <lora:MoriiMee_Gothic_Niji_Style_Illustrious_r1:0.45> <lora:ck-nc-cyberpunk-IL-000011:0.4>",
"text": "<lora:ponyv6_noobE11_2_adamW-000017:0.3> <lora:XXX667:0.4> <lora:114558v4df2fsdf5:0.6> <lora:illustriousXL_stabilizer_v1.23:0.3> <lora:mon_monmon2133:0.5>",
"loras": [
{
"name": "ck-shadow-circuit-IL-000012",
"strength": 0.78,
"name": "ponyv6_noobE11_2_adamW-000017",
"strength": 0.3,
"active": true
},
{
"name": "MoriiMee_Gothic_Niji_Style_Illustrious_r1",
"strength": 0.45,
"active": true
},
{
"name": "ck-nc-cyberpunk-IL-000011",
"name": "XXX667",
"strength": 0.4,
"active": false
"active": true
},
{
"name": "114558v4df2fsdf5",
"strength": 0.6,
"active": true
},
{
"name": "illustriousXL_stabilizer_v1.23",
"strength": 0.3,
"active": true
},
{
"name": "mon_monmon2133",
"strength": 0.5,
"active": true
},
{
"name": "__dummy_item1__",
@@ -273,7 +287,7 @@
{
"name": "ck-neon-retrowave-IL-000012",
"strength": 0.8,
"active": true
"active": false
},
{
"name": "__dummy_item1__",

View File

@@ -824,7 +824,7 @@ async function saveRecipeDirectly(widget) {
try {
// Get the workflow data from the ComfyUI app
const prompt = await app.graphToPrompt();
console.log('Prompt:', prompt.output);
console.log('Prompt:', prompt);
// Show loading toast
if (app && app.extensionManager && app.extensionManager.toast) {