checkpoint

This commit is contained in:
Will Miao
2025-04-01 19:17:43 +08:00
parent 195866b00d
commit 27db60ce68
6 changed files with 601 additions and 270 deletions

View File

@@ -59,6 +59,59 @@ class WorkflowParser:
self.processed_nodes.remove(node_id)
return result
def find_primary_sampler_node(self, workflow: Dict) -> Optional[str]:
"""
Find the primary sampler node in the workflow.
Priority:
1. First try to find a SamplerCustomAdvanced node
2. If not found, look for KSampler nodes with denoise=1.0
3. If still not found, use the first KSampler node
Args:
workflow: The workflow data as a dictionary
Returns:
The node ID of the primary sampler node, or None if not found
"""
# First check for SamplerCustomAdvanced nodes
sampler_advanced_nodes = []
ksampler_nodes = []
# Scan workflow for sampler nodes
for node_id, node_data in workflow.items():
node_type = node_data.get("class_type")
if node_type == "SamplerCustomAdvanced":
sampler_advanced_nodes.append(node_id)
elif node_type == "KSampler":
ksampler_nodes.append(node_id)
# If we found SamplerCustomAdvanced nodes, return the first one
if sampler_advanced_nodes:
logger.info(f"Found SamplerCustomAdvanced node: {sampler_advanced_nodes[0]}")
return sampler_advanced_nodes[0]
# If we have KSampler nodes, look for one with denoise=1.0
if ksampler_nodes:
for node_id in ksampler_nodes:
node_data = workflow[node_id]
inputs = node_data.get("inputs", {})
denoise = inputs.get("denoise", 0)
# Check if denoise is 1.0 (allowing for small floating point differences)
if abs(float(denoise) - 1.0) < 0.001:
logger.info(f"Found KSampler node with denoise=1.0: {node_id}")
return node_id
# If no KSampler with denoise=1.0 found, use the first one
logger.info(f"No KSampler with denoise=1.0 found, using first KSampler: {ksampler_nodes[0]}")
return ksampler_nodes[0]
# No sampler nodes found
logger.warning("No sampler nodes found in workflow")
return None
def collect_loras_from_model(self, model_input: List, workflow: Dict) -> str:
"""Collect loras information from the model node chain"""
if not isinstance(model_input, list) or len(model_input) != 2:
@@ -107,23 +160,23 @@ class WorkflowParser:
self.processed_nodes = set()
self.node_results_cache = {}
# Find the KSampler node
ksampler_node_id = find_node_by_type(workflow, "KSampler")
if not ksampler_node_id:
logger.warning("No KSampler node found in workflow")
# Find the primary sampler node
sampler_node_id = self.find_primary_sampler_node(workflow)
if not sampler_node_id:
logger.warning("No suitable sampler node found in workflow")
return {}
# Start parsing from the KSampler node
# Start parsing from the sampler node
result = {
"gen_params": {},
"loras": ""
}
# Process KSampler node to extract parameters
ksampler_result = self.process_node(ksampler_node_id, workflow)
if ksampler_result:
# Process sampler node to extract parameters
sampler_result = self.process_node(sampler_node_id, workflow)
if sampler_result:
# Process the result
for key, value in ksampler_result.items():
for key, value in sampler_result.items():
# Special handling for the positive prompt from FluxGuidance
if key == "positive" and isinstance(value, dict):
# Extract guidance value
@@ -138,8 +191,8 @@ class WorkflowParser:
result["gen_params"][key] = value
# Process the positive prompt node if it exists and we don't have a prompt yet
if "prompt" not in result["gen_params"] and "positive" in ksampler_result:
positive_value = ksampler_result.get("positive")
if "prompt" not in result["gen_params"] and "positive" in sampler_result:
positive_value = sampler_result.get("positive")
if isinstance(positive_value, str):
result["gen_params"]["prompt"] = positive_value
@@ -152,11 +205,11 @@ class WorkflowParser:
if "guidance" in node_inputs:
result["gen_params"]["guidance"] = node_inputs["guidance"]
# Extract loras from the model input of KSampler
ksampler_node = workflow.get(ksampler_node_id, {})
ksampler_inputs = ksampler_node.get("inputs", {})
if "model" in ksampler_inputs and isinstance(ksampler_inputs["model"], list):
loras_text = self.collect_loras_from_model(ksampler_inputs["model"], workflow)
# Extract loras from the model input of sampler
sampler_node = workflow.get(sampler_node_id, {})
sampler_inputs = sampler_node.get("inputs", {})
if "model" in sampler_inputs and isinstance(sampler_inputs["model"], list):
loras_text = self.collect_loras_from_model(sampler_inputs["model"], workflow)
if loras_text:
result["loras"] = loras_text
@@ -164,9 +217,9 @@ class WorkflowParser:
if "cfg" in result["gen_params"]:
result["gen_params"]["cfg_scale"] = result["gen_params"].pop("cfg")
# Add clip_skip = 2 to match reference output if not already present
# Add clip_skip = 1 to match reference output if not already present
if "clip_skip" not in result["gen_params"]:
result["gen_params"]["clip_skip"] = "2"
result["gen_params"]["clip_skip"] = "1"
# Ensure the prompt is a string and not a nested dictionary
if "prompt" in result["gen_params"] and isinstance(result["gen_params"]["prompt"], dict):