checkpoint

This commit is contained in:
Will Miao
2025-04-02 06:05:24 +08:00
parent 27db60ce68
commit a8ec5af037
5 changed files with 67 additions and 130 deletions

View File

@@ -783,8 +783,8 @@ class RecipeRoutes:
# Parse the workflow to extract generation parameters and loras
parsed_workflow = self.parser.parse_workflow(workflow_json)
if not parsed_workflow or not parsed_workflow.get("gen_params"):
return web.json_response({"error": "Could not extract generation parameters from workflow"}, status=400)
if not parsed_workflow:
return web.json_response({"error": "Could not extract parameters from workflow"}, status=400)
# Get the lora stack from the parsed workflow
lora_stack = parsed_workflow.get("loras", "")
@@ -880,7 +880,9 @@ class RecipeRoutes:
"created_date": time.time(),
"base_model": most_common_base_model,
"loras": loras_data,
"gen_params": parsed_workflow.get("gen_params", {}), # Use the parsed workflow parameters
"checkpoint": parsed_workflow.get("checkpoint", ""),
"gen_params": {key: value for key, value in parsed_workflow.items()
if key not in ['checkpoint', 'loras']},
"loras_stack": lora_stack # Include the original lora stack string
}

View File

@@ -52,30 +52,15 @@ def transform_basic_guider(inputs: Dict) -> Dict:
# Get model information if needed
if "model" in inputs and isinstance(inputs["model"], dict):
if "loras" in inputs["model"]:
result["loras"] = inputs["model"]["loras"]
result["model"] = inputs["model"]
return result
def transform_model_sampling_flux(inputs: Dict) -> Dict:
"""Transform function for ModelSamplingFlux - mostly a pass-through node"""
# This node is primarily used for routing, so we mostly pass through values
result = {}
# Extract any dimensions if present
width = inputs.get("width", 0)
height = inputs.get("height", 0)
if width and height:
result["width"] = width
result["height"] = height
result["size"] = f"{width}x{height}"
# Pass through model information
if "model" in inputs and isinstance(inputs["model"], dict):
for key, value in inputs["model"].items():
result[key] = value
return result
return inputs["model"]
def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
"""Transform function for SamplerCustomAdvanced node"""
@@ -110,6 +95,7 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
result["prompt"] = guider["conditioning"].get("prompt", "")
if "model" in guider and isinstance(guider["model"], dict):
result["checkpoint"] = guider["model"].get("checkpoint", "")
result["loras"] = guider["model"].get("loras", "")
# Extract dimensions from latent_image
@@ -124,12 +110,27 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
return result
def transform_unet_loader(inputs: Dict) -> Dict:
"""Transform function for UNETLoader node"""
unet_name = inputs.get("unet_name", "")
return {"checkpoint": unet_name} if unet_name else {}
def transform_checkpoint_loader(inputs: Dict) -> Dict:
"""Transform function for CheckpointLoaderSimple node"""
ckpt_name = inputs.get("ckpt_name", "")
return {"checkpoint": ckpt_name} if ckpt_name else {}
# =============================================================================
# Register Mappers
# =============================================================================
# Define the mappers for ComfyUI core nodes not in main mapper
COMFYUI_CORE_MAPPERS = {
# KSamplers
"SamplerCustomAdvanced": {
"inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"],
"transform_func": transform_sampler_custom_advanced
},
"RandomNoise": {
"inputs_to_track": ["noise_seed"],
"transform_func": transform_random_noise
@@ -150,9 +151,17 @@ COMFYUI_CORE_MAPPERS = {
"inputs_to_track": ["max_shift", "base_shift", "width", "height", "model"],
"transform_func": transform_model_sampling_flux
},
"SamplerCustomAdvanced": {
"inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"],
"transform_func": transform_sampler_custom_advanced
"UNETLoader": {
"inputs_to_track": ["unet_name"],
"transform_func": transform_unet_loader
},
"CheckpointLoaderSimple": {
"inputs_to_track": ["ckpt_name"],
"transform_func": transform_checkpoint_loader
},
"CheckpointLoader": {
"inputs_to_track": ["ckpt_name"],
"transform_func": transform_checkpoint_loader
}
}

View File

@@ -125,6 +125,15 @@ def transform_ksampler(inputs: Dict) -> Dict:
# Add clip_skip if present
if "clip_skip" in inputs:
result["clip_skip"] = str(inputs.get("clip_skip", ""))
# Add guidance if present
if "guidance" in inputs:
result["guidance"] = str(inputs.get("guidance", ""))
# Add model if present
if "model" in inputs:
result["checkpoint"] = inputs.get("model", {}).get("checkpoint", "")
result["loras"] = inputs.get("model", {}).get("loras", "")
return result
@@ -167,8 +176,13 @@ def transform_lora_loader(inputs: Dict) -> Dict:
lora_name = stack_entry[0]
strength = stack_entry[1]
lora_texts.append(f"<lora:{lora_name}:{strength}>")
result = {
"checkpoint": inputs.get("model", {}).get("checkpoint", ""),
"loras": " ".join(lora_texts)
}
return {"loras": " ".join(lora_texts)}
return result
def transform_lora_stacker(inputs: Dict) -> Dict:
"""Transform function for LoraStacker nodes"""
@@ -276,7 +290,7 @@ NODE_MAPPERS = {
},
# LoraManager nodes
"Lora Loader (LoraManager)": {
"inputs_to_track": ["loras", "lora_stack"],
"inputs_to_track": ["model", "loras", "lora_stack"],
"transform_func": transform_lora_loader
},
"Lora Stacker (LoraManager)": {

View File

@@ -166,71 +166,33 @@ class WorkflowParser:
logger.warning("No suitable sampler node found in workflow")
return {}
# Start parsing from the sampler node
result = {
"gen_params": {},
"loras": ""
}
# Process sampler node to extract parameters
sampler_result = self.process_node(sampler_node_id, workflow)
if sampler_result:
# Process the result
for key, value in sampler_result.items():
# Special handling for the positive prompt from FluxGuidance
if key == "positive" and isinstance(value, dict):
# Extract guidance value
if "guidance" in value:
result["gen_params"]["guidance"] = value["guidance"]
# Extract prompt
if "prompt" in value:
result["gen_params"]["prompt"] = value["prompt"]
else:
# Normal handling for other values
result["gen_params"][key] = value
logger.info(f"Sampler result: {sampler_result}")
if not sampler_result:
return {}
# Process the positive prompt node if it exists and we don't have a prompt yet
if "prompt" not in result["gen_params"] and "positive" in sampler_result:
positive_value = sampler_result.get("positive")
if isinstance(positive_value, str):
result["gen_params"]["prompt"] = positive_value
# Manually check for FluxGuidance if we don't have guidance value
if "guidance" not in result["gen_params"]:
flux_node_id = find_node_by_type(workflow, "FluxGuidance")
if flux_node_id:
# Get the direct input from the node
node_inputs = workflow[flux_node_id].get("inputs", {})
if "guidance" in node_inputs:
result["gen_params"]["guidance"] = node_inputs["guidance"]
# Extract loras from the model input of sampler
sampler_node = workflow.get(sampler_node_id, {})
sampler_inputs = sampler_node.get("inputs", {})
if "model" in sampler_inputs and isinstance(sampler_inputs["model"], list):
loras_text = self.collect_loras_from_model(sampler_inputs["model"], workflow)
if loras_text:
result["loras"] = loras_text
# Return the sampler result directly - it's already in the format we need
# This simplifies the structure and makes it easier to use in recipe_routes.py
# Handle standard ComfyUI names vs our output format
if "cfg" in result["gen_params"]:
result["gen_params"]["cfg_scale"] = result["gen_params"].pop("cfg")
if "cfg" in sampler_result:
sampler_result["cfg_scale"] = sampler_result.pop("cfg")
# Add clip_skip = 1 to match reference output if not already present
if "clip_skip" not in result["gen_params"]:
result["gen_params"]["clip_skip"] = "1"
if "clip_skip" not in sampler_result:
sampler_result["clip_skip"] = "1"
# Ensure the prompt is a string and not a nested dictionary
if "prompt" in result["gen_params"] and isinstance(result["gen_params"]["prompt"], dict):
if "prompt" in result["gen_params"]["prompt"]:
result["gen_params"]["prompt"] = result["gen_params"]["prompt"]["prompt"]
if "prompt" in sampler_result and isinstance(sampler_result["prompt"], dict):
if "prompt" in sampler_result["prompt"]:
sampler_result["prompt"] = sampler_result["prompt"]["prompt"]
# Save the result if requested
if output_path:
save_output(result, output_path)
save_output(sampler_result, output_path)
return result
return sampler_result
def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dict:
@@ -245,4 +207,4 @@ def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dic
Dictionary containing extracted parameters
"""
parser = WorkflowParser()
return parser.parse_workflow(workflow_path, output_path)
return parser.parse_workflow(workflow_path, output_path)