Efficient Loader Checkpoint Update Fix

Fixed an issue where the checkpoint would not get updated when a LoRA file was loaded.
This commit is contained in:
TSC
2023-05-13 17:38:30 -05:00
committed by GitHub
parent 02c19626ab
commit e88216c874

View File

@@ -50,11 +50,11 @@ def resolve_input_links(prompt, input_value):
# Cache models in RAM
loaded_objects = {
"ckpt": [], # (ckpt_name, location)
"clip": [], # (ckpt_name, location)
"bvae": [], # (ckpt_name, location)
"vae": [], # (vae_name, location)
"lora": [] # (lora_name, location)
"ckpt": [], # (ckpt_name, model)
"clip": [], # (ckpt_name, clip)
"bvae": [], # (ckpt_name, vae)
"vae": [], # (vae_name, vae)
"lora": [] # (lora_name, model_name, model_lora, clip_lora, strength_model, strength_clip)
}
def print_loaded_objects_entries():
@@ -164,20 +164,27 @@ def load_vae(vae_name):
def load_lora(lora_name, model, clip, strength_model, strength_clip):
"""
Extracts the Lora model with a given name from the "lora" array in loaded_objects.
If the Lora model is not found or the strength values change, creates a new Lora object with the given name and adds it to the "lora" array.
If the Lora model is not found or the strength values change or the original model has changed, creates a new Lora object with the given name and adds it to the "lora" array.
"""
global loaded_objects
# Get the model_name (ckpt_name) from the first entry in loaded_objects
model_name = loaded_objects["ckpt"][0][0] if loaded_objects["ckpt"] else None
# Check if lora_name exists in "lora" array
existing_lora = [entry for entry in loaded_objects["lora"] if entry[0] == lora_name]
if existing_lora:
lora_name, model_lora, clip_lora, stored_strength_model, stored_strength_clip = existing_lora[0]
lora_name, stored_model_name, model_lora, clip_lora, stored_strength_model, stored_strength_clip = existing_lora[0]
if strength_model == stored_strength_model and strength_clip == stored_strength_clip:
return model_lora, clip_lora
# Check if the model_name, strength_model, and strength_clip are the same
if model_name == stored_model_name and strength_model == stored_strength_model and strength_clip == stored_strength_clip:
# Check if the model has not changed in the loaded_objects
existing_model = [entry for entry in loaded_objects["ckpt"] if entry[0] == model_name]
if existing_model and existing_model[0][1] == model:
return model_lora, clip_lora
# If Lora model not found or strength values changed, generate new Lora models
# If Lora model not found or strength values changed or model changed, generate new Lora models
lora_path = folder_paths.get_full_path("loras", lora_name)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
@@ -186,7 +193,7 @@ def load_lora(lora_name, model, clip, strength_model, strength_clip):
loaded_objects["lora"].remove(existing_lora[0])
# Update loaded_objects[] array
loaded_objects["lora"].append((lora_name, model_lora, clip_lora, strength_model, strength_clip))
loaded_objects["lora"].append((lora_name, model_name, model_lora, clip_lora, strength_model, strength_clip))
return model_lora, clip_lora
@@ -232,6 +239,7 @@ class TSC_EfficientLoader:
if lora_name != "None":
model, clip = load_lora(lora_name, model, clip, lora_model_strength, lora_clip_strength)
# note: load_lora only works properly (as of now) when ckpt dictionary is only 1 entry long!
# Check for custom VAE
if vae_name != "Baked VAE":