From a8ec5af037534bfd826defa4c25768965e08e3ac Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Apr 2025 06:05:24 +0800 Subject: [PATCH] checkpoint --- py/routes/recipe_routes.py | 8 ++-- py/workflow/ext/comfyui_core.py | 49 ++++++++++++++---------- py/workflow/mappers.py | 18 ++++++++- py/workflow/parser.py | 68 ++++++++------------------------- refs/prompt.json | 54 +------------------------- 5 files changed, 67 insertions(+), 130 deletions(-) diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 037f5c19..a6e3ef91 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -783,8 +783,8 @@ class RecipeRoutes: # Parse the workflow to extract generation parameters and loras parsed_workflow = self.parser.parse_workflow(workflow_json) - if not parsed_workflow or not parsed_workflow.get("gen_params"): - return web.json_response({"error": "Could not extract generation parameters from workflow"}, status=400) + if not parsed_workflow: + return web.json_response({"error": "Could not extract parameters from workflow"}, status=400) # Get the lora stack from the parsed workflow lora_stack = parsed_workflow.get("loras", "") @@ -880,7 +880,9 @@ class RecipeRoutes: "created_date": time.time(), "base_model": most_common_base_model, "loras": loras_data, - "gen_params": parsed_workflow.get("gen_params", {}), # Use the parsed workflow parameters + "checkpoint": parsed_workflow.get("checkpoint", ""), + "gen_params": {key: value for key, value in parsed_workflow.items() + if key not in ['checkpoint', 'loras']}, "loras_stack": lora_stack # Include the original lora stack string } diff --git a/py/workflow/ext/comfyui_core.py b/py/workflow/ext/comfyui_core.py index 59713ee8..73b56f41 100644 --- a/py/workflow/ext/comfyui_core.py +++ b/py/workflow/ext/comfyui_core.py @@ -52,30 +52,15 @@ def transform_basic_guider(inputs: Dict) -> Dict: # Get model information if needed if "model" in inputs and isinstance(inputs["model"], dict): - if "loras" in inputs["model"]: - result["loras"] = inputs["model"]["loras"] + result["model"] = inputs["model"] return result def transform_model_sampling_flux(inputs: Dict) -> Dict: """Transform function for ModelSamplingFlux - mostly a pass-through node""" # This node is primarily used for routing, so we mostly pass through values - result = {} - # Extract any dimensions if present - width = inputs.get("width", 0) - height = inputs.get("height", 0) - if width and height: - result["width"] = width - result["height"] = height - result["size"] = f"{width}x{height}" - - # Pass through model information - if "model" in inputs and isinstance(inputs["model"], dict): - for key, value in inputs["model"].items(): - result[key] = value - - return result + return inputs["model"] def transform_sampler_custom_advanced(inputs: Dict) -> Dict: """Transform function for SamplerCustomAdvanced node""" @@ -110,6 +95,7 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict: result["prompt"] = guider["conditioning"].get("prompt", "") if "model" in guider and isinstance(guider["model"], dict): + result["checkpoint"] = guider["model"].get("checkpoint", "") result["loras"] = guider["model"].get("loras", "") # Extract dimensions from latent_image @@ -124,12 +110,27 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict: return result +def transform_unet_loader(inputs: Dict) -> Dict: + """Transform function for UNETLoader node""" + unet_name = inputs.get("unet_name", "") + return {"checkpoint": unet_name} if unet_name else {} + +def transform_checkpoint_loader(inputs: Dict) -> Dict: + """Transform function for CheckpointLoaderSimple node""" + ckpt_name = inputs.get("ckpt_name", "") + return {"checkpoint": ckpt_name} if ckpt_name else {} + # ============================================================================= # Register Mappers # ============================================================================= # Define the mappers for ComfyUI core nodes not in main mapper COMFYUI_CORE_MAPPERS = { + # KSamplers + "SamplerCustomAdvanced": { + "inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"], + "transform_func": transform_sampler_custom_advanced + }, "RandomNoise": { "inputs_to_track": ["noise_seed"], "transform_func": transform_random_noise @@ -150,9 +151,17 @@ COMFYUI_CORE_MAPPERS = { "inputs_to_track": ["max_shift", "base_shift", "width", "height", "model"], "transform_func": transform_model_sampling_flux }, - "SamplerCustomAdvanced": { - "inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"], - "transform_func": transform_sampler_custom_advanced + "UNETLoader": { + "inputs_to_track": ["unet_name"], + "transform_func": transform_unet_loader + }, + "CheckpointLoaderSimple": { + "inputs_to_track": ["ckpt_name"], + "transform_func": transform_checkpoint_loader + }, + "CheckpointLoader": { + "inputs_to_track": ["ckpt_name"], + "transform_func": transform_checkpoint_loader } } diff --git a/py/workflow/mappers.py b/py/workflow/mappers.py index 96fb02a2..f954cf54 100644 --- a/py/workflow/mappers.py +++ b/py/workflow/mappers.py @@ -125,6 +125,15 @@ def transform_ksampler(inputs: Dict) -> Dict: # Add clip_skip if present if "clip_skip" in inputs: result["clip_skip"] = str(inputs.get("clip_skip", "")) + + # Add guidance if present + if "guidance" in inputs: + result["guidance"] = str(inputs.get("guidance", "")) + + # Add model if present + if "model" in inputs: + result["checkpoint"] = inputs.get("model", {}).get("checkpoint", "") + result["loras"] = inputs.get("model", {}).get("loras", "") return result @@ -167,8 +176,13 @@ def transform_lora_loader(inputs: Dict) -> Dict: lora_name = stack_entry[0] strength = stack_entry[1] lora_texts.append(f"") + + result = { + "checkpoint": inputs.get("model", {}).get("checkpoint", ""), + "loras": " ".join(lora_texts) + } - return {"loras": " ".join(lora_texts)} + return result def transform_lora_stacker(inputs: Dict) -> Dict: """Transform function for LoraStacker nodes""" @@ -276,7 +290,7 @@ NODE_MAPPERS = { }, # LoraManager nodes "Lora Loader (LoraManager)": { - "inputs_to_track": ["loras", "lora_stack"], + "inputs_to_track": ["model", "loras", "lora_stack"], "transform_func": transform_lora_loader }, "Lora Stacker (LoraManager)": { diff --git a/py/workflow/parser.py b/py/workflow/parser.py index e57d8ce4..2c913173 100644 --- a/py/workflow/parser.py +++ b/py/workflow/parser.py @@ -166,71 +166,33 @@ class WorkflowParser: logger.warning("No suitable sampler node found in workflow") return {} - # Start parsing from the sampler node - result = { - "gen_params": {}, - "loras": "" - } - # Process sampler node to extract parameters sampler_result = self.process_node(sampler_node_id, workflow) - if sampler_result: - # Process the result - for key, value in sampler_result.items(): - # Special handling for the positive prompt from FluxGuidance - if key == "positive" and isinstance(value, dict): - # Extract guidance value - if "guidance" in value: - result["gen_params"]["guidance"] = value["guidance"] - - # Extract prompt - if "prompt" in value: - result["gen_params"]["prompt"] = value["prompt"] - else: - # Normal handling for other values - result["gen_params"][key] = value + logger.info(f"Sampler result: {sampler_result}") + if not sampler_result: + return {} - # Process the positive prompt node if it exists and we don't have a prompt yet - if "prompt" not in result["gen_params"] and "positive" in sampler_result: - positive_value = sampler_result.get("positive") - if isinstance(positive_value, str): - result["gen_params"]["prompt"] = positive_value - - # Manually check for FluxGuidance if we don't have guidance value - if "guidance" not in result["gen_params"]: - flux_node_id = find_node_by_type(workflow, "FluxGuidance") - if flux_node_id: - # Get the direct input from the node - node_inputs = workflow[flux_node_id].get("inputs", {}) - if "guidance" in node_inputs: - result["gen_params"]["guidance"] = node_inputs["guidance"] - - # Extract loras from the model input of sampler - sampler_node = workflow.get(sampler_node_id, {}) - sampler_inputs = sampler_node.get("inputs", {}) - if "model" in sampler_inputs and isinstance(sampler_inputs["model"], list): - loras_text = self.collect_loras_from_model(sampler_inputs["model"], workflow) - if loras_text: - result["loras"] = loras_text + # Return the sampler result directly - it's already in the format we need + # This simplifies the structure and makes it easier to use in recipe_routes.py # Handle standard ComfyUI names vs our output format - if "cfg" in result["gen_params"]: - result["gen_params"]["cfg_scale"] = result["gen_params"].pop("cfg") + if "cfg" in sampler_result: + sampler_result["cfg_scale"] = sampler_result.pop("cfg") # Add clip_skip = 1 to match reference output if not already present - if "clip_skip" not in result["gen_params"]: - result["gen_params"]["clip_skip"] = "1" + if "clip_skip" not in sampler_result: + sampler_result["clip_skip"] = "1" # Ensure the prompt is a string and not a nested dictionary - if "prompt" in result["gen_params"] and isinstance(result["gen_params"]["prompt"], dict): - if "prompt" in result["gen_params"]["prompt"]: - result["gen_params"]["prompt"] = result["gen_params"]["prompt"]["prompt"] + if "prompt" in sampler_result and isinstance(sampler_result["prompt"], dict): + if "prompt" in sampler_result["prompt"]: + sampler_result["prompt"] = sampler_result["prompt"]["prompt"] # Save the result if requested if output_path: - save_output(result, output_path) + save_output(sampler_result, output_path) - return result + return sampler_result def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dict: @@ -245,4 +207,4 @@ def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dic Dictionary containing extracted parameters """ parser = WorkflowParser() - return parser.parse_workflow(workflow_path, output_path) \ No newline at end of file + return parser.parse_workflow(workflow_path, output_path) \ No newline at end of file diff --git a/refs/prompt.json b/refs/prompt.json index 535ddec5..96f62b0a 100644 --- a/refs/prompt.json +++ b/refs/prompt.json @@ -254,51 +254,6 @@ "title": "Text Load Line From File" } }, - "223": { - "inputs": { - "filename": "%time_%seed", - "path": "%date", - "extension": "jpeg", - "steps": [ - "246", - 0 - ], - "cfg": 3.5, - "modelname": "flux_dev", - "sampler_name": "dpmpp_2m", - "scheduler": "beta", - "positive": [ - "203", - 0 - ], - "negative": "", - "width": [ - "48", - 1 - ], - "height": [ - "48", - 2 - ], - "lossless_webp": true, - "quality_jpeg_or_webp": 100, - "optimize_png": false, - "counter": 0, - "denoise": 1, - "clip_skip": 1, - "time_format": "%Y-%m-%d-%H%M%S", - "save_workflow_as_json": false, - "embed_workflow_in_png": false, - "images": [ - "8", - 0 - ] - }, - "class_type": "Image Saver", - "_meta": { - "title": "Image Saver" - } - }, "226": { "inputs": { "images": [ @@ -325,11 +280,7 @@ "group_mode": true, "toggle_trigger_words": [ { - "text": "perfection style", - "active": true - }, - { - "text": "mythp0rt", + "text": "bo-exposure", "active": true }, { @@ -343,7 +294,7 @@ "_isDummy": true } ], - "orinalMessage": "perfection style,, mythp0rt", + "orinalMessage": "bo-exposure", "trigger_words": [ "299", 2 @@ -382,7 +333,6 @@ }, "298": { "inputs": { - "text": "flux1/testing/matmillerartFLUX.safetensors,0.2,0.2", "anything": [ "297", 0