checkpoint

This commit is contained in:
Will Miao
2025-04-02 06:05:24 +08:00
parent 27db60ce68
commit a8ec5af037
5 changed files with 67 additions and 130 deletions

View File

@@ -783,8 +783,8 @@ class RecipeRoutes:
# Parse the workflow to extract generation parameters and loras # Parse the workflow to extract generation parameters and loras
parsed_workflow = self.parser.parse_workflow(workflow_json) parsed_workflow = self.parser.parse_workflow(workflow_json)
if not parsed_workflow or not parsed_workflow.get("gen_params"): if not parsed_workflow:
return web.json_response({"error": "Could not extract generation parameters from workflow"}, status=400) return web.json_response({"error": "Could not extract parameters from workflow"}, status=400)
# Get the lora stack from the parsed workflow # Get the lora stack from the parsed workflow
lora_stack = parsed_workflow.get("loras", "") lora_stack = parsed_workflow.get("loras", "")
@@ -880,7 +880,9 @@ class RecipeRoutes:
"created_date": time.time(), "created_date": time.time(),
"base_model": most_common_base_model, "base_model": most_common_base_model,
"loras": loras_data, "loras": loras_data,
"gen_params": parsed_workflow.get("gen_params", {}), # Use the parsed workflow parameters "checkpoint": parsed_workflow.get("checkpoint", ""),
"gen_params": {key: value for key, value in parsed_workflow.items()
if key not in ['checkpoint', 'loras']},
"loras_stack": lora_stack # Include the original lora stack string "loras_stack": lora_stack # Include the original lora stack string
} }

View File

@@ -52,30 +52,15 @@ def transform_basic_guider(inputs: Dict) -> Dict:
# Get model information if needed # Get model information if needed
if "model" in inputs and isinstance(inputs["model"], dict): if "model" in inputs and isinstance(inputs["model"], dict):
if "loras" in inputs["model"]: result["model"] = inputs["model"]
result["loras"] = inputs["model"]["loras"]
return result return result
def transform_model_sampling_flux(inputs: Dict) -> Dict: def transform_model_sampling_flux(inputs: Dict) -> Dict:
"""Transform function for ModelSamplingFlux - mostly a pass-through node""" """Transform function for ModelSamplingFlux - mostly a pass-through node"""
# This node is primarily used for routing, so we mostly pass through values # This node is primarily used for routing, so we mostly pass through values
result = {}
# Extract any dimensions if present return inputs["model"]
width = inputs.get("width", 0)
height = inputs.get("height", 0)
if width and height:
result["width"] = width
result["height"] = height
result["size"] = f"{width}x{height}"
# Pass through model information
if "model" in inputs and isinstance(inputs["model"], dict):
for key, value in inputs["model"].items():
result[key] = value
return result
def transform_sampler_custom_advanced(inputs: Dict) -> Dict: def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
"""Transform function for SamplerCustomAdvanced node""" """Transform function for SamplerCustomAdvanced node"""
@@ -110,6 +95,7 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
result["prompt"] = guider["conditioning"].get("prompt", "") result["prompt"] = guider["conditioning"].get("prompt", "")
if "model" in guider and isinstance(guider["model"], dict): if "model" in guider and isinstance(guider["model"], dict):
result["checkpoint"] = guider["model"].get("checkpoint", "")
result["loras"] = guider["model"].get("loras", "") result["loras"] = guider["model"].get("loras", "")
# Extract dimensions from latent_image # Extract dimensions from latent_image
@@ -124,12 +110,27 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
return result return result
def transform_unet_loader(inputs: Dict) -> Dict:
"""Transform function for UNETLoader node"""
unet_name = inputs.get("unet_name", "")
return {"checkpoint": unet_name} if unet_name else {}
def transform_checkpoint_loader(inputs: Dict) -> Dict:
"""Transform function for CheckpointLoaderSimple node"""
ckpt_name = inputs.get("ckpt_name", "")
return {"checkpoint": ckpt_name} if ckpt_name else {}
# ============================================================================= # =============================================================================
# Register Mappers # Register Mappers
# ============================================================================= # =============================================================================
# Define the mappers for ComfyUI core nodes not in main mapper # Define the mappers for ComfyUI core nodes not in main mapper
COMFYUI_CORE_MAPPERS = { COMFYUI_CORE_MAPPERS = {
# KSamplers
"SamplerCustomAdvanced": {
"inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"],
"transform_func": transform_sampler_custom_advanced
},
"RandomNoise": { "RandomNoise": {
"inputs_to_track": ["noise_seed"], "inputs_to_track": ["noise_seed"],
"transform_func": transform_random_noise "transform_func": transform_random_noise
@@ -150,9 +151,17 @@ COMFYUI_CORE_MAPPERS = {
"inputs_to_track": ["max_shift", "base_shift", "width", "height", "model"], "inputs_to_track": ["max_shift", "base_shift", "width", "height", "model"],
"transform_func": transform_model_sampling_flux "transform_func": transform_model_sampling_flux
}, },
"SamplerCustomAdvanced": { "UNETLoader": {
"inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"], "inputs_to_track": ["unet_name"],
"transform_func": transform_sampler_custom_advanced "transform_func": transform_unet_loader
},
"CheckpointLoaderSimple": {
"inputs_to_track": ["ckpt_name"],
"transform_func": transform_checkpoint_loader
},
"CheckpointLoader": {
"inputs_to_track": ["ckpt_name"],
"transform_func": transform_checkpoint_loader
} }
} }

View File

@@ -125,6 +125,15 @@ def transform_ksampler(inputs: Dict) -> Dict:
# Add clip_skip if present # Add clip_skip if present
if "clip_skip" in inputs: if "clip_skip" in inputs:
result["clip_skip"] = str(inputs.get("clip_skip", "")) result["clip_skip"] = str(inputs.get("clip_skip", ""))
# Add guidance if present
if "guidance" in inputs:
result["guidance"] = str(inputs.get("guidance", ""))
# Add model if present
if "model" in inputs:
result["checkpoint"] = inputs.get("model", {}).get("checkpoint", "")
result["loras"] = inputs.get("model", {}).get("loras", "")
return result return result
@@ -167,8 +176,13 @@ def transform_lora_loader(inputs: Dict) -> Dict:
lora_name = stack_entry[0] lora_name = stack_entry[0]
strength = stack_entry[1] strength = stack_entry[1]
lora_texts.append(f"<lora:{lora_name}:{strength}>") lora_texts.append(f"<lora:{lora_name}:{strength}>")
result = {
"checkpoint": inputs.get("model", {}).get("checkpoint", ""),
"loras": " ".join(lora_texts)
}
return {"loras": " ".join(lora_texts)} return result
def transform_lora_stacker(inputs: Dict) -> Dict: def transform_lora_stacker(inputs: Dict) -> Dict:
"""Transform function for LoraStacker nodes""" """Transform function for LoraStacker nodes"""
@@ -276,7 +290,7 @@ NODE_MAPPERS = {
}, },
# LoraManager nodes # LoraManager nodes
"Lora Loader (LoraManager)": { "Lora Loader (LoraManager)": {
"inputs_to_track": ["loras", "lora_stack"], "inputs_to_track": ["model", "loras", "lora_stack"],
"transform_func": transform_lora_loader "transform_func": transform_lora_loader
}, },
"Lora Stacker (LoraManager)": { "Lora Stacker (LoraManager)": {

View File

@@ -166,71 +166,33 @@ class WorkflowParser:
logger.warning("No suitable sampler node found in workflow") logger.warning("No suitable sampler node found in workflow")
return {} return {}
# Start parsing from the sampler node
result = {
"gen_params": {},
"loras": ""
}
# Process sampler node to extract parameters # Process sampler node to extract parameters
sampler_result = self.process_node(sampler_node_id, workflow) sampler_result = self.process_node(sampler_node_id, workflow)
if sampler_result: logger.info(f"Sampler result: {sampler_result}")
# Process the result if not sampler_result:
for key, value in sampler_result.items(): return {}
# Special handling for the positive prompt from FluxGuidance
if key == "positive" and isinstance(value, dict):
# Extract guidance value
if "guidance" in value:
result["gen_params"]["guidance"] = value["guidance"]
# Extract prompt
if "prompt" in value:
result["gen_params"]["prompt"] = value["prompt"]
else:
# Normal handling for other values
result["gen_params"][key] = value
# Process the positive prompt node if it exists and we don't have a prompt yet # Return the sampler result directly - it's already in the format we need
if "prompt" not in result["gen_params"] and "positive" in sampler_result: # This simplifies the structure and makes it easier to use in recipe_routes.py
positive_value = sampler_result.get("positive")
if isinstance(positive_value, str):
result["gen_params"]["prompt"] = positive_value
# Manually check for FluxGuidance if we don't have guidance value
if "guidance" not in result["gen_params"]:
flux_node_id = find_node_by_type(workflow, "FluxGuidance")
if flux_node_id:
# Get the direct input from the node
node_inputs = workflow[flux_node_id].get("inputs", {})
if "guidance" in node_inputs:
result["gen_params"]["guidance"] = node_inputs["guidance"]
# Extract loras from the model input of sampler
sampler_node = workflow.get(sampler_node_id, {})
sampler_inputs = sampler_node.get("inputs", {})
if "model" in sampler_inputs and isinstance(sampler_inputs["model"], list):
loras_text = self.collect_loras_from_model(sampler_inputs["model"], workflow)
if loras_text:
result["loras"] = loras_text
# Handle standard ComfyUI names vs our output format # Handle standard ComfyUI names vs our output format
if "cfg" in result["gen_params"]: if "cfg" in sampler_result:
result["gen_params"]["cfg_scale"] = result["gen_params"].pop("cfg") sampler_result["cfg_scale"] = sampler_result.pop("cfg")
# Add clip_skip = 1 to match reference output if not already present # Add clip_skip = 1 to match reference output if not already present
if "clip_skip" not in result["gen_params"]: if "clip_skip" not in sampler_result:
result["gen_params"]["clip_skip"] = "1" sampler_result["clip_skip"] = "1"
# Ensure the prompt is a string and not a nested dictionary # Ensure the prompt is a string and not a nested dictionary
if "prompt" in result["gen_params"] and isinstance(result["gen_params"]["prompt"], dict): if "prompt" in sampler_result and isinstance(sampler_result["prompt"], dict):
if "prompt" in result["gen_params"]["prompt"]: if "prompt" in sampler_result["prompt"]:
result["gen_params"]["prompt"] = result["gen_params"]["prompt"]["prompt"] sampler_result["prompt"] = sampler_result["prompt"]["prompt"]
# Save the result if requested # Save the result if requested
if output_path: if output_path:
save_output(result, output_path) save_output(sampler_result, output_path)
return result return sampler_result
def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dict: def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dict:
@@ -245,4 +207,4 @@ def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dic
Dictionary containing extracted parameters Dictionary containing extracted parameters
""" """
parser = WorkflowParser() parser = WorkflowParser()
return parser.parse_workflow(workflow_path, output_path) return parser.parse_workflow(workflow_path, output_path)

View File

@@ -254,51 +254,6 @@
"title": "Text Load Line From File" "title": "Text Load Line From File"
} }
}, },
"223": {
"inputs": {
"filename": "%time_%seed",
"path": "%date",
"extension": "jpeg",
"steps": [
"246",
0
],
"cfg": 3.5,
"modelname": "flux_dev",
"sampler_name": "dpmpp_2m",
"scheduler": "beta",
"positive": [
"203",
0
],
"negative": "",
"width": [
"48",
1
],
"height": [
"48",
2
],
"lossless_webp": true,
"quality_jpeg_or_webp": 100,
"optimize_png": false,
"counter": 0,
"denoise": 1,
"clip_skip": 1,
"time_format": "%Y-%m-%d-%H%M%S",
"save_workflow_as_json": false,
"embed_workflow_in_png": false,
"images": [
"8",
0
]
},
"class_type": "Image Saver",
"_meta": {
"title": "Image Saver"
}
},
"226": { "226": {
"inputs": { "inputs": {
"images": [ "images": [
@@ -325,11 +280,7 @@
"group_mode": true, "group_mode": true,
"toggle_trigger_words": [ "toggle_trigger_words": [
{ {
"text": "perfection style", "text": "bo-exposure",
"active": true
},
{
"text": "mythp0rt",
"active": true "active": true
}, },
{ {
@@ -343,7 +294,7 @@
"_isDummy": true "_isDummy": true
} }
], ],
"orinalMessage": "perfection style,, mythp0rt", "orinalMessage": "bo-exposure",
"trigger_words": [ "trigger_words": [
"299", "299",
2 2
@@ -382,7 +333,6 @@
}, },
"298": { "298": {
"inputs": { "inputs": {
"text": "flux1/testing/matmillerartFLUX.safetensors,0.2,0.2",
"anything": [ "anything": [
"297", "297",
0 0