checkpoint

This commit is contained in:
Will Miao
2025-04-02 06:05:24 +08:00
parent 27db60ce68
commit a8ec5af037
5 changed files with 67 additions and 130 deletions

View File

@@ -783,8 +783,8 @@ class RecipeRoutes:
# Parse the workflow to extract generation parameters and loras
parsed_workflow = self.parser.parse_workflow(workflow_json)
if not parsed_workflow or not parsed_workflow.get("gen_params"):
return web.json_response({"error": "Could not extract generation parameters from workflow"}, status=400)
if not parsed_workflow:
return web.json_response({"error": "Could not extract parameters from workflow"}, status=400)
# Get the lora stack from the parsed workflow
lora_stack = parsed_workflow.get("loras", "")
@@ -880,7 +880,9 @@ class RecipeRoutes:
"created_date": time.time(),
"base_model": most_common_base_model,
"loras": loras_data,
"gen_params": parsed_workflow.get("gen_params", {}), # Use the parsed workflow parameters
"checkpoint": parsed_workflow.get("checkpoint", ""),
"gen_params": {key: value for key, value in parsed_workflow.items()
if key not in ['checkpoint', 'loras']},
"loras_stack": lora_stack # Include the original lora stack string
}

View File

@@ -52,30 +52,15 @@ def transform_basic_guider(inputs: Dict) -> Dict:
# Get model information if needed
if "model" in inputs and isinstance(inputs["model"], dict):
if "loras" in inputs["model"]:
result["loras"] = inputs["model"]["loras"]
result["model"] = inputs["model"]
return result
def transform_model_sampling_flux(inputs: Dict) -> Dict:
"""Transform function for ModelSamplingFlux - mostly a pass-through node"""
# This node is primarily used for routing, so we mostly pass through values
result = {}
# Extract any dimensions if present
width = inputs.get("width", 0)
height = inputs.get("height", 0)
if width and height:
result["width"] = width
result["height"] = height
result["size"] = f"{width}x{height}"
# Pass through model information
if "model" in inputs and isinstance(inputs["model"], dict):
for key, value in inputs["model"].items():
result[key] = value
return result
return inputs["model"]
def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
"""Transform function for SamplerCustomAdvanced node"""
@@ -110,6 +95,7 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
result["prompt"] = guider["conditioning"].get("prompt", "")
if "model" in guider and isinstance(guider["model"], dict):
result["checkpoint"] = guider["model"].get("checkpoint", "")
result["loras"] = guider["model"].get("loras", "")
# Extract dimensions from latent_image
@@ -124,12 +110,27 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
return result
def transform_unet_loader(inputs: Dict) -> Dict:
    """Map a UNETLoader node's inputs to checkpoint metadata.

    Returns {"checkpoint": <unet_name>} when a model name is present,
    otherwise an empty dict so the caller can merge results safely.
    """
    model_name = inputs.get("unet_name", "")
    if not model_name:
        return {}
    return {"checkpoint": model_name}
def transform_checkpoint_loader(inputs: Dict) -> Dict:
    """Map a CheckpointLoaderSimple node's inputs to checkpoint metadata.

    Returns {"checkpoint": <ckpt_name>} when a checkpoint name is present,
    otherwise an empty dict so the caller can merge results safely.
    """
    name = inputs.get("ckpt_name", "")
    if not name:
        return {}
    return {"checkpoint": name}
# =============================================================================
# Register Mappers
# =============================================================================
# Define the mappers for ComfyUI core nodes not in main mapper
COMFYUI_CORE_MAPPERS = {
# KSamplers
"SamplerCustomAdvanced": {
"inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"],
"transform_func": transform_sampler_custom_advanced
},
"RandomNoise": {
"inputs_to_track": ["noise_seed"],
"transform_func": transform_random_noise
@@ -150,9 +151,17 @@ COMFYUI_CORE_MAPPERS = {
"inputs_to_track": ["max_shift", "base_shift", "width", "height", "model"],
"transform_func": transform_model_sampling_flux
},
"SamplerCustomAdvanced": {
"inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"],
"transform_func": transform_sampler_custom_advanced
"UNETLoader": {
"inputs_to_track": ["unet_name"],
"transform_func": transform_unet_loader
},
"CheckpointLoaderSimple": {
"inputs_to_track": ["ckpt_name"],
"transform_func": transform_checkpoint_loader
},
"CheckpointLoader": {
"inputs_to_track": ["ckpt_name"],
"transform_func": transform_checkpoint_loader
}
}

View File

@@ -126,6 +126,15 @@ def transform_ksampler(inputs: Dict) -> Dict:
if "clip_skip" in inputs:
result["clip_skip"] = str(inputs.get("clip_skip", ""))
# Add guidance if present
if "guidance" in inputs:
result["guidance"] = str(inputs.get("guidance", ""))
# Add model if present
if "model" in inputs:
result["checkpoint"] = inputs.get("model", {}).get("checkpoint", "")
result["loras"] = inputs.get("model", {}).get("loras", "")
return result
def transform_empty_latent(inputs: Dict) -> Dict:
@@ -168,7 +177,12 @@ def transform_lora_loader(inputs: Dict) -> Dict:
strength = stack_entry[1]
lora_texts.append(f"<lora:{lora_name}:{strength}>")
return {"loras": " ".join(lora_texts)}
result = {
"checkpoint": inputs.get("model", {}).get("checkpoint", ""),
"loras": " ".join(lora_texts)
}
return result
def transform_lora_stacker(inputs: Dict) -> Dict:
"""Transform function for LoraStacker nodes"""
@@ -276,7 +290,7 @@ NODE_MAPPERS = {
},
# LoraManager nodes
"Lora Loader (LoraManager)": {
"inputs_to_track": ["loras", "lora_stack"],
"inputs_to_track": ["model", "loras", "lora_stack"],
"transform_func": transform_lora_loader
},
"Lora Stacker (LoraManager)": {

View File

@@ -166,71 +166,33 @@ class WorkflowParser:
logger.warning("No suitable sampler node found in workflow")
return {}
# Start parsing from the sampler node
result = {
"gen_params": {},
"loras": ""
}
# Process sampler node to extract parameters
sampler_result = self.process_node(sampler_node_id, workflow)
if sampler_result:
# Process the result
for key, value in sampler_result.items():
# Special handling for the positive prompt from FluxGuidance
if key == "positive" and isinstance(value, dict):
# Extract guidance value
if "guidance" in value:
result["gen_params"]["guidance"] = value["guidance"]
logger.info(f"Sampler result: {sampler_result}")
if not sampler_result:
return {}
# Extract prompt
if "prompt" in value:
result["gen_params"]["prompt"] = value["prompt"]
else:
# Normal handling for other values
result["gen_params"][key] = value
# Process the positive prompt node if it exists and we don't have a prompt yet
if "prompt" not in result["gen_params"] and "positive" in sampler_result:
positive_value = sampler_result.get("positive")
if isinstance(positive_value, str):
result["gen_params"]["prompt"] = positive_value
# Manually check for FluxGuidance if we don't have guidance value
if "guidance" not in result["gen_params"]:
flux_node_id = find_node_by_type(workflow, "FluxGuidance")
if flux_node_id:
# Get the direct input from the node
node_inputs = workflow[flux_node_id].get("inputs", {})
if "guidance" in node_inputs:
result["gen_params"]["guidance"] = node_inputs["guidance"]
# Extract loras from the model input of sampler
sampler_node = workflow.get(sampler_node_id, {})
sampler_inputs = sampler_node.get("inputs", {})
if "model" in sampler_inputs and isinstance(sampler_inputs["model"], list):
loras_text = self.collect_loras_from_model(sampler_inputs["model"], workflow)
if loras_text:
result["loras"] = loras_text
# Return the sampler result directly - it's already in the format we need
# This simplifies the structure and makes it easier to use in recipe_routes.py
# Handle standard ComfyUI names vs our output format
if "cfg" in result["gen_params"]:
result["gen_params"]["cfg_scale"] = result["gen_params"].pop("cfg")
if "cfg" in sampler_result:
sampler_result["cfg_scale"] = sampler_result.pop("cfg")
# Add clip_skip = 1 to match reference output if not already present
if "clip_skip" not in result["gen_params"]:
result["gen_params"]["clip_skip"] = "1"
if "clip_skip" not in sampler_result:
sampler_result["clip_skip"] = "1"
# Ensure the prompt is a string and not a nested dictionary
if "prompt" in result["gen_params"] and isinstance(result["gen_params"]["prompt"], dict):
if "prompt" in result["gen_params"]["prompt"]:
result["gen_params"]["prompt"] = result["gen_params"]["prompt"]["prompt"]
if "prompt" in sampler_result and isinstance(sampler_result["prompt"], dict):
if "prompt" in sampler_result["prompt"]:
sampler_result["prompt"] = sampler_result["prompt"]["prompt"]
# Save the result if requested
if output_path:
save_output(result, output_path)
save_output(sampler_result, output_path)
return result
return sampler_result
def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dict:

View File

@@ -254,51 +254,6 @@
"title": "Text Load Line From File"
}
},
"223": {
"inputs": {
"filename": "%time_%seed",
"path": "%date",
"extension": "jpeg",
"steps": [
"246",
0
],
"cfg": 3.5,
"modelname": "flux_dev",
"sampler_name": "dpmpp_2m",
"scheduler": "beta",
"positive": [
"203",
0
],
"negative": "",
"width": [
"48",
1
],
"height": [
"48",
2
],
"lossless_webp": true,
"quality_jpeg_or_webp": 100,
"optimize_png": false,
"counter": 0,
"denoise": 1,
"clip_skip": 1,
"time_format": "%Y-%m-%d-%H%M%S",
"save_workflow_as_json": false,
"embed_workflow_in_png": false,
"images": [
"8",
0
]
},
"class_type": "Image Saver",
"_meta": {
"title": "Image Saver"
}
},
"226": {
"inputs": {
"images": [
@@ -325,11 +280,7 @@
"group_mode": true,
"toggle_trigger_words": [
{
"text": "perfection style",
"active": true
},
{
"text": "mythp0rt",
"text": "bo-exposure",
"active": true
},
{
@@ -343,7 +294,7 @@
"_isDummy": true
}
],
"orinalMessage": "perfection style,, mythp0rt",
"orinalMessage": "bo-exposure",
"trigger_words": [
"299",
2
@@ -382,7 +333,6 @@
},
"298": {
"inputs": {
"text": "flux1/testing/matmillerartFLUX.safetensors,0.2,0.2",
"anything": [
"297",
0