checkpoint

This commit is contained in:
Will Miao
2025-04-02 06:05:24 +08:00
parent 27db60ce68
commit a8ec5af037
5 changed files with 67 additions and 130 deletions

View File

@@ -52,30 +52,15 @@ def transform_basic_guider(inputs: Dict) -> Dict:
# Get model information if needed
if "model" in inputs and isinstance(inputs["model"], dict):
if "loras" in inputs["model"]:
result["loras"] = inputs["model"]["loras"]
result["model"] = inputs["model"]
return result
def transform_model_sampling_flux(inputs: Dict) -> Dict:
"""Transform function for ModelSamplingFlux - mostly a pass-through node"""
# This node is primarily used for routing, so we mostly pass through values
result = {}
# Extract any dimensions if present
width = inputs.get("width", 0)
height = inputs.get("height", 0)
if width and height:
result["width"] = width
result["height"] = height
result["size"] = f"{width}x{height}"
# Pass through model information
if "model" in inputs and isinstance(inputs["model"], dict):
for key, value in inputs["model"].items():
result[key] = value
return result
return inputs["model"]
def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
"""Transform function for SamplerCustomAdvanced node"""
@@ -110,6 +95,7 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
result["prompt"] = guider["conditioning"].get("prompt", "")
if "model" in guider and isinstance(guider["model"], dict):
result["checkpoint"] = guider["model"].get("checkpoint", "")
result["loras"] = guider["model"].get("loras", "")
# Extract dimensions from latent_image
@@ -124,12 +110,27 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
return result
def transform_unet_loader(inputs: Dict) -> Dict:
"""Transform function for UNETLoader node"""
unet_name = inputs.get("unet_name", "")
return {"checkpoint": unet_name} if unet_name else {}
def transform_checkpoint_loader(inputs: Dict) -> Dict:
"""Transform function for CheckpointLoaderSimple node"""
ckpt_name = inputs.get("ckpt_name", "")
return {"checkpoint": ckpt_name} if ckpt_name else {}
# =============================================================================
# Register Mappers
# =============================================================================
# Define the mappers for ComfyUI core nodes not in main mapper
COMFYUI_CORE_MAPPERS = {
# KSamplers
"SamplerCustomAdvanced": {
"inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"],
"transform_func": transform_sampler_custom_advanced
},
"RandomNoise": {
"inputs_to_track": ["noise_seed"],
"transform_func": transform_random_noise
@@ -150,9 +151,17 @@ COMFYUI_CORE_MAPPERS = {
"inputs_to_track": ["max_shift", "base_shift", "width", "height", "model"],
"transform_func": transform_model_sampling_flux
},
"SamplerCustomAdvanced": {
"inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"],
"transform_func": transform_sampler_custom_advanced
"UNETLoader": {
"inputs_to_track": ["unet_name"],
"transform_func": transform_unet_loader
},
"CheckpointLoaderSimple": {
"inputs_to_track": ["ckpt_name"],
"transform_func": transform_checkpoint_loader
},
"CheckpointLoader": {
"inputs_to_track": ["ckpt_name"],
"transform_func": transform_checkpoint_loader
}
}