mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
checkpoint
This commit is contained in:
169
py/workflow/ext/comfyui_core.py
Normal file
169
py/workflow/ext/comfyui_core.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
ComfyUI Core nodes mappers extension for workflow parsing
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import the mapper registration functions from the parent module
|
||||
from workflow.mappers import create_mapper, register_mapper
|
||||
|
||||
# =============================================================================
|
||||
# Transform Functions
|
||||
# =============================================================================
|
||||
|
||||
def transform_random_noise(inputs: Dict) -> Dict:
|
||||
"""Transform function for RandomNoise node"""
|
||||
return {"seed": str(inputs.get("noise_seed", ""))}
|
||||
|
||||
def transform_ksampler_select(inputs: Dict) -> Dict:
|
||||
"""Transform function for KSamplerSelect node"""
|
||||
return {"sampler": inputs.get("sampler_name", "")}
|
||||
|
||||
def transform_basic_scheduler(inputs: Dict) -> Dict:
|
||||
"""Transform function for BasicScheduler node"""
|
||||
result = {
|
||||
"scheduler": inputs.get("scheduler", ""),
|
||||
"denoise": str(inputs.get("denoise", "1.0"))
|
||||
}
|
||||
|
||||
# Get steps from inputs or steps input
|
||||
if "steps" in inputs:
|
||||
if isinstance(inputs["steps"], str):
|
||||
result["steps"] = inputs["steps"]
|
||||
elif isinstance(inputs["steps"], dict) and "value" in inputs["steps"]:
|
||||
result["steps"] = str(inputs["steps"]["value"])
|
||||
else:
|
||||
result["steps"] = str(inputs["steps"])
|
||||
|
||||
return result
|
||||
|
||||
def transform_basic_guider(inputs: Dict) -> Dict:
|
||||
"""Transform function for BasicGuider node"""
|
||||
result = {}
|
||||
|
||||
# Process conditioning
|
||||
if "conditioning" in inputs:
|
||||
if isinstance(inputs["conditioning"], str):
|
||||
result["prompt"] = inputs["conditioning"]
|
||||
elif isinstance(inputs["conditioning"], dict):
|
||||
result["conditioning"] = inputs["conditioning"]
|
||||
|
||||
# Get model information if needed
|
||||
if "model" in inputs and isinstance(inputs["model"], dict):
|
||||
if "loras" in inputs["model"]:
|
||||
result["loras"] = inputs["model"]["loras"]
|
||||
|
||||
return result
|
||||
|
||||
def transform_model_sampling_flux(inputs: Dict) -> Dict:
|
||||
"""Transform function for ModelSamplingFlux - mostly a pass-through node"""
|
||||
# This node is primarily used for routing, so we mostly pass through values
|
||||
result = {}
|
||||
|
||||
# Extract any dimensions if present
|
||||
width = inputs.get("width", 0)
|
||||
height = inputs.get("height", 0)
|
||||
if width and height:
|
||||
result["width"] = width
|
||||
result["height"] = height
|
||||
result["size"] = f"{width}x{height}"
|
||||
|
||||
# Pass through model information
|
||||
if "model" in inputs and isinstance(inputs["model"], dict):
|
||||
for key, value in inputs["model"].items():
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
|
||||
def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
|
||||
"""Transform function for SamplerCustomAdvanced node"""
|
||||
result = {}
|
||||
|
||||
# Extract seed from noise
|
||||
if "noise" in inputs and isinstance(inputs["noise"], dict):
|
||||
result["seed"] = str(inputs["noise"].get("seed", ""))
|
||||
|
||||
# Extract sampler info
|
||||
if "sampler" in inputs and isinstance(inputs["sampler"], dict):
|
||||
sampler = inputs["sampler"].get("sampler", "")
|
||||
if sampler:
|
||||
result["sampler"] = sampler
|
||||
|
||||
# Extract scheduler, steps, denoise from sigmas
|
||||
if "sigmas" in inputs and isinstance(inputs["sigmas"], dict):
|
||||
sigmas = inputs["sigmas"]
|
||||
result["scheduler"] = sigmas.get("scheduler", "")
|
||||
result["steps"] = str(sigmas.get("steps", ""))
|
||||
result["denoise"] = str(sigmas.get("denoise", "1.0"))
|
||||
|
||||
# Extract prompt and guidance from guider
|
||||
if "guider" in inputs and isinstance(inputs["guider"], dict):
|
||||
guider = inputs["guider"]
|
||||
|
||||
# Get prompt from conditioning
|
||||
if "conditioning" in guider and isinstance(guider["conditioning"], str):
|
||||
result["prompt"] = guider["conditioning"]
|
||||
elif "conditioning" in guider and isinstance(guider["conditioning"], dict):
|
||||
result["guidance"] = guider["conditioning"].get("guidance", "")
|
||||
result["prompt"] = guider["conditioning"].get("prompt", "")
|
||||
|
||||
if "model" in guider and isinstance(guider["model"], dict):
|
||||
result["loras"] = guider["model"].get("loras", "")
|
||||
|
||||
# Extract dimensions from latent_image
|
||||
if "latent_image" in inputs and isinstance(inputs["latent_image"], dict):
|
||||
latent = inputs["latent_image"]
|
||||
width = latent.get("width", 0)
|
||||
height = latent.get("height", 0)
|
||||
if width and height:
|
||||
result["width"] = width
|
||||
result["height"] = height
|
||||
result["size"] = f"{width}x{height}"
|
||||
|
||||
return result
|
||||
|
||||
# =============================================================================
|
||||
# Register Mappers
|
||||
# =============================================================================
|
||||
|
||||
# Define the mappers for ComfyUI core nodes not in main mapper
|
||||
COMFYUI_CORE_MAPPERS = {
|
||||
"RandomNoise": {
|
||||
"inputs_to_track": ["noise_seed"],
|
||||
"transform_func": transform_random_noise
|
||||
},
|
||||
"KSamplerSelect": {
|
||||
"inputs_to_track": ["sampler_name"],
|
||||
"transform_func": transform_ksampler_select
|
||||
},
|
||||
"BasicScheduler": {
|
||||
"inputs_to_track": ["scheduler", "steps", "denoise", "model"],
|
||||
"transform_func": transform_basic_scheduler
|
||||
},
|
||||
"BasicGuider": {
|
||||
"inputs_to_track": ["model", "conditioning"],
|
||||
"transform_func": transform_basic_guider
|
||||
},
|
||||
"ModelSamplingFlux": {
|
||||
"inputs_to_track": ["max_shift", "base_shift", "width", "height", "model"],
|
||||
"transform_func": transform_model_sampling_flux
|
||||
},
|
||||
"SamplerCustomAdvanced": {
|
||||
"inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"],
|
||||
"transform_func": transform_sampler_custom_advanced
|
||||
}
|
||||
}
|
||||
|
||||
# Register all ComfyUI core mappers
|
||||
for node_type, config in COMFYUI_CORE_MAPPERS.items():
|
||||
mapper = create_mapper(
|
||||
node_type=node_type,
|
||||
inputs_to_track=config["inputs_to_track"],
|
||||
transform_func=config["transform_func"]
|
||||
)
|
||||
register_mapper(mapper)
|
||||
logger.info(f"Registered ComfyUI core mapper for node type: {node_type}")
|
||||
|
||||
logger.info(f"Loaded ComfyUI core extension with {len(COMFYUI_CORE_MAPPERS)} mappers")
|
||||
@@ -1,54 +0,0 @@
|
||||
"""
|
||||
Example extension mapper for demonstrating the extension system
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
from ..mappers import NodeMapper
|
||||
|
||||
class ExampleNodeMapper(NodeMapper):
|
||||
"""Example mapper for custom nodes"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
node_type="ExampleCustomNode",
|
||||
inputs_to_track=["param1", "param2", "image"]
|
||||
)
|
||||
|
||||
def transform(self, inputs: Dict) -> Dict:
|
||||
"""Transform extracted inputs into the desired output format"""
|
||||
result = {}
|
||||
|
||||
# Extract interesting parameters
|
||||
if "param1" in inputs:
|
||||
result["example_param1"] = inputs["param1"]
|
||||
|
||||
if "param2" in inputs:
|
||||
result["example_param2"] = inputs["param2"]
|
||||
|
||||
# You can process the data in any way needed
|
||||
return result
|
||||
|
||||
|
||||
class VAEMapperExtension(NodeMapper):
|
||||
"""Extension mapper for VAE nodes"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
node_type="VAELoader",
|
||||
inputs_to_track=["vae_name"]
|
||||
)
|
||||
|
||||
def transform(self, inputs: Dict) -> Dict:
|
||||
"""Extract VAE information"""
|
||||
vae_name = inputs.get("vae_name", "")
|
||||
|
||||
# Remove path prefix if present
|
||||
if "/" in vae_name or "\\" in vae_name:
|
||||
# Get just the filename without path or extension
|
||||
vae_name = vae_name.replace("\\", "/").split("/")[-1]
|
||||
vae_name = vae_name.split(".")[0] # Remove extension
|
||||
|
||||
return {"vae": vae_name}
|
||||
|
||||
|
||||
# Note: No need to register manually - extensions are automatically registered
|
||||
# when the extension system loads this file
|
||||
@@ -48,6 +48,10 @@ def transform_empty_latent_presets(inputs: Dict) -> Dict:
|
||||
|
||||
return {"width": width, "height": height, "size": f"{width}x{height}"}
|
||||
|
||||
def transform_int_constant(inputs: Dict) -> int:
|
||||
"""Transform function for INTConstant nodes"""
|
||||
return inputs.get("value", 0)
|
||||
|
||||
# =============================================================================
|
||||
# Register Mappers
|
||||
# =============================================================================
|
||||
@@ -65,6 +69,10 @@ KJNODES_MAPPERS = {
|
||||
"EmptyLatentImagePresets": {
|
||||
"inputs_to_track": ["dimensions", "invert", "batch_size"],
|
||||
"transform_func": transform_empty_latent_presets
|
||||
},
|
||||
"INTConstant": {
|
||||
"inputs_to_track": ["value"],
|
||||
"transform_func": transform_int_constant
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -59,6 +59,59 @@ class WorkflowParser:
|
||||
self.processed_nodes.remove(node_id)
|
||||
return result
|
||||
|
||||
def find_primary_sampler_node(self, workflow: Dict) -> Optional[str]:
|
||||
"""
|
||||
Find the primary sampler node in the workflow.
|
||||
|
||||
Priority:
|
||||
1. First try to find a SamplerCustomAdvanced node
|
||||
2. If not found, look for KSampler nodes with denoise=1.0
|
||||
3. If still not found, use the first KSampler node
|
||||
|
||||
Args:
|
||||
workflow: The workflow data as a dictionary
|
||||
|
||||
Returns:
|
||||
The node ID of the primary sampler node, or None if not found
|
||||
"""
|
||||
# First check for SamplerCustomAdvanced nodes
|
||||
sampler_advanced_nodes = []
|
||||
ksampler_nodes = []
|
||||
|
||||
# Scan workflow for sampler nodes
|
||||
for node_id, node_data in workflow.items():
|
||||
node_type = node_data.get("class_type")
|
||||
|
||||
if node_type == "SamplerCustomAdvanced":
|
||||
sampler_advanced_nodes.append(node_id)
|
||||
elif node_type == "KSampler":
|
||||
ksampler_nodes.append(node_id)
|
||||
|
||||
# If we found SamplerCustomAdvanced nodes, return the first one
|
||||
if sampler_advanced_nodes:
|
||||
logger.info(f"Found SamplerCustomAdvanced node: {sampler_advanced_nodes[0]}")
|
||||
return sampler_advanced_nodes[0]
|
||||
|
||||
# If we have KSampler nodes, look for one with denoise=1.0
|
||||
if ksampler_nodes:
|
||||
for node_id in ksampler_nodes:
|
||||
node_data = workflow[node_id]
|
||||
inputs = node_data.get("inputs", {})
|
||||
denoise = inputs.get("denoise", 0)
|
||||
|
||||
# Check if denoise is 1.0 (allowing for small floating point differences)
|
||||
if abs(float(denoise) - 1.0) < 0.001:
|
||||
logger.info(f"Found KSampler node with denoise=1.0: {node_id}")
|
||||
return node_id
|
||||
|
||||
# If no KSampler with denoise=1.0 found, use the first one
|
||||
logger.info(f"No KSampler with denoise=1.0 found, using first KSampler: {ksampler_nodes[0]}")
|
||||
return ksampler_nodes[0]
|
||||
|
||||
# No sampler nodes found
|
||||
logger.warning("No sampler nodes found in workflow")
|
||||
return None
|
||||
|
||||
def collect_loras_from_model(self, model_input: List, workflow: Dict) -> str:
|
||||
"""Collect loras information from the model node chain"""
|
||||
if not isinstance(model_input, list) or len(model_input) != 2:
|
||||
@@ -107,23 +160,23 @@ class WorkflowParser:
|
||||
self.processed_nodes = set()
|
||||
self.node_results_cache = {}
|
||||
|
||||
# Find the KSampler node
|
||||
ksampler_node_id = find_node_by_type(workflow, "KSampler")
|
||||
if not ksampler_node_id:
|
||||
logger.warning("No KSampler node found in workflow")
|
||||
# Find the primary sampler node
|
||||
sampler_node_id = self.find_primary_sampler_node(workflow)
|
||||
if not sampler_node_id:
|
||||
logger.warning("No suitable sampler node found in workflow")
|
||||
return {}
|
||||
|
||||
# Start parsing from the KSampler node
|
||||
# Start parsing from the sampler node
|
||||
result = {
|
||||
"gen_params": {},
|
||||
"loras": ""
|
||||
}
|
||||
|
||||
# Process KSampler node to extract parameters
|
||||
ksampler_result = self.process_node(ksampler_node_id, workflow)
|
||||
if ksampler_result:
|
||||
# Process sampler node to extract parameters
|
||||
sampler_result = self.process_node(sampler_node_id, workflow)
|
||||
if sampler_result:
|
||||
# Process the result
|
||||
for key, value in ksampler_result.items():
|
||||
for key, value in sampler_result.items():
|
||||
# Special handling for the positive prompt from FluxGuidance
|
||||
if key == "positive" and isinstance(value, dict):
|
||||
# Extract guidance value
|
||||
@@ -138,8 +191,8 @@ class WorkflowParser:
|
||||
result["gen_params"][key] = value
|
||||
|
||||
# Process the positive prompt node if it exists and we don't have a prompt yet
|
||||
if "prompt" not in result["gen_params"] and "positive" in ksampler_result:
|
||||
positive_value = ksampler_result.get("positive")
|
||||
if "prompt" not in result["gen_params"] and "positive" in sampler_result:
|
||||
positive_value = sampler_result.get("positive")
|
||||
if isinstance(positive_value, str):
|
||||
result["gen_params"]["prompt"] = positive_value
|
||||
|
||||
@@ -152,11 +205,11 @@ class WorkflowParser:
|
||||
if "guidance" in node_inputs:
|
||||
result["gen_params"]["guidance"] = node_inputs["guidance"]
|
||||
|
||||
# Extract loras from the model input of KSampler
|
||||
ksampler_node = workflow.get(ksampler_node_id, {})
|
||||
ksampler_inputs = ksampler_node.get("inputs", {})
|
||||
if "model" in ksampler_inputs and isinstance(ksampler_inputs["model"], list):
|
||||
loras_text = self.collect_loras_from_model(ksampler_inputs["model"], workflow)
|
||||
# Extract loras from the model input of sampler
|
||||
sampler_node = workflow.get(sampler_node_id, {})
|
||||
sampler_inputs = sampler_node.get("inputs", {})
|
||||
if "model" in sampler_inputs and isinstance(sampler_inputs["model"], list):
|
||||
loras_text = self.collect_loras_from_model(sampler_inputs["model"], workflow)
|
||||
if loras_text:
|
||||
result["loras"] = loras_text
|
||||
|
||||
@@ -164,9 +217,9 @@ class WorkflowParser:
|
||||
if "cfg" in result["gen_params"]:
|
||||
result["gen_params"]["cfg_scale"] = result["gen_params"].pop("cfg")
|
||||
|
||||
# Add clip_skip = 2 to match reference output if not already present
|
||||
# Add clip_skip = 1 to match reference output if not already present
|
||||
if "clip_skip" not in result["gen_params"]:
|
||||
result["gen_params"]["clip_skip"] = "2"
|
||||
result["gen_params"]["clip_skip"] = "1"
|
||||
|
||||
# Ensure the prompt is a string and not a nested dictionary
|
||||
if "prompt" in result["gen_params"] and isinstance(result["gen_params"]["prompt"], dict):
|
||||
|
||||
Reference in New Issue
Block a user