mirror of
https://github.com/jags111/efficiency-nodes-comfyui.git
synced 2026-03-25 07:05:44 -03:00
patch for lora changes from comfy commit 5a9ddf9
This commit is contained in:
@@ -22,6 +22,9 @@ import psutil
|
|||||||
# Get the absolute path of the parent directory of the current script
|
# Get the absolute path of the parent directory of the current script
|
||||||
my_dir = os.path.dirname(os.path.abspath(__file__))
|
my_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
# Add the My directory path to the sys.path list
|
||||||
|
sys.path.append(my_dir)
|
||||||
|
|
||||||
# Construct the absolute path to the ComfyUI directory
|
# Construct the absolute path to the ComfyUI directory
|
||||||
comfy_dir = os.path.abspath(os.path.join(my_dir, '..', '..'))
|
comfy_dir = os.path.abspath(os.path.join(my_dir, '..', '..'))
|
||||||
|
|
||||||
@@ -36,6 +39,13 @@ import comfy.samplers
|
|||||||
import comfy.sd
|
import comfy.sd
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
|
||||||
|
# Load legacy lora functions
|
||||||
|
from lora_patch import load_lora, load_lora_for_models
|
||||||
|
|
||||||
|
# Replace the lora functions with the legacy functions
|
||||||
|
comfy.sd.load_lora = load_lora
|
||||||
|
comfy.sd.load_lora_for_models = load_lora_for_models
|
||||||
|
|
||||||
MAX_RESOLUTION=8192
|
MAX_RESOLUTION=8192
|
||||||
|
|
||||||
# Tensor to PIL (grabbed from WAS Suite)
|
# Tensor to PIL (grabbed from WAS Suite)
|
||||||
@@ -1568,6 +1578,7 @@ class TSC_XYplot:
|
|||||||
if (X_type == Y_type):
|
if (X_type == Y_type):
|
||||||
if X_type != "Nothing":
|
if X_type != "Nothing":
|
||||||
print(f"\033[31mXY Plot Error:\033[0m X and Y input types must be different.")
|
print(f"\033[31mXY Plot Error:\033[0m X and Y input types must be different.")
|
||||||
|
'''
|
||||||
else:
|
else:
|
||||||
# Print XY Plot Inputs
|
# Print XY Plot Inputs
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
@@ -1575,6 +1586,7 @@ class TSC_XYplot:
|
|||||||
print(f"(X) {X_type}: {X_value}")
|
print(f"(X) {X_type}: {X_value}")
|
||||||
print(f"(Y) {Y_type}: {Y_value}")
|
print(f"(Y) {Y_type}: {Y_value}")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
|
'''
|
||||||
return (None,)
|
return (None,)
|
||||||
|
|
||||||
# Check that dependencies is connected for Checkpoint and LoRA plots
|
# Check that dependencies is connected for Checkpoint and LoRA plots
|
||||||
|
|||||||
262
lora_patch.py
Normal file
262
lora_patch.py
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Get the absolute path of the parent directory of the current script
|
||||||
|
my_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
# Add the My directory path to the sys.path list
|
||||||
|
sys.path.append(my_dir)
|
||||||
|
|
||||||
|
# Construct the absolute path to the ComfyUI directory
|
||||||
|
comfy_dir = os.path.abspath(os.path.join(my_dir, '..', '..'))
|
||||||
|
|
||||||
|
# Add the ComfyUI directory path to the sys.path list
|
||||||
|
sys.path.append(comfy_dir)
|
||||||
|
|
||||||
|
from comfy import utils
|
||||||
|
|
||||||
|
LORA_CLIP_MAP = {
|
||||||
|
"mlp.fc1": "mlp_fc1",
|
||||||
|
"mlp.fc2": "mlp_fc2",
|
||||||
|
"self_attn.k_proj": "self_attn_k_proj",
|
||||||
|
"self_attn.q_proj": "self_attn_q_proj",
|
||||||
|
"self_attn.v_proj": "self_attn_v_proj",
|
||||||
|
"self_attn.out_proj": "self_attn_out_proj",
|
||||||
|
}
|
||||||
|
|
||||||
|
LORA_UNET_MAP_ATTENTIONS = {
|
||||||
|
"proj_in": "proj_in",
|
||||||
|
"proj_out": "proj_out",
|
||||||
|
"transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q",
|
||||||
|
"transformer_blocks.0.attn1.to_k": "transformer_blocks_0_attn1_to_k",
|
||||||
|
"transformer_blocks.0.attn1.to_v": "transformer_blocks_0_attn1_to_v",
|
||||||
|
"transformer_blocks.0.attn1.to_out.0": "transformer_blocks_0_attn1_to_out_0",
|
||||||
|
"transformer_blocks.0.attn2.to_q": "transformer_blocks_0_attn2_to_q",
|
||||||
|
"transformer_blocks.0.attn2.to_k": "transformer_blocks_0_attn2_to_k",
|
||||||
|
"transformer_blocks.0.attn2.to_v": "transformer_blocks_0_attn2_to_v",
|
||||||
|
"transformer_blocks.0.attn2.to_out.0": "transformer_blocks_0_attn2_to_out_0",
|
||||||
|
"transformer_blocks.0.ff.net.0.proj": "transformer_blocks_0_ff_net_0_proj",
|
||||||
|
"transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2",
|
||||||
|
}
|
||||||
|
|
||||||
|
LORA_UNET_MAP_RESNET = {
|
||||||
|
"in_layers.2": "resnets_{}_conv1",
|
||||||
|
"emb_layers.1": "resnets_{}_time_emb_proj",
|
||||||
|
"out_layers.3": "resnets_{}_conv2",
|
||||||
|
"skip_connection": "resnets_{}_conv_shortcut"
|
||||||
|
}
|
||||||
|
|
||||||
|
def load_lora(path, to_load):
|
||||||
|
lora = utils.load_torch_file(path)
|
||||||
|
patch_dict = {}
|
||||||
|
loaded_keys = set()
|
||||||
|
for x in to_load:
|
||||||
|
alpha_name = "{}.alpha".format(x)
|
||||||
|
alpha = None
|
||||||
|
if alpha_name in lora.keys():
|
||||||
|
alpha = lora[alpha_name].item()
|
||||||
|
loaded_keys.add(alpha_name)
|
||||||
|
|
||||||
|
A_name = "{}.lora_up.weight".format(x)
|
||||||
|
B_name = "{}.lora_down.weight".format(x)
|
||||||
|
mid_name = "{}.lora_mid.weight".format(x)
|
||||||
|
|
||||||
|
if A_name in lora.keys():
|
||||||
|
mid = None
|
||||||
|
if mid_name in lora.keys():
|
||||||
|
mid = lora[mid_name]
|
||||||
|
loaded_keys.add(mid_name)
|
||||||
|
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
|
||||||
|
loaded_keys.add(A_name)
|
||||||
|
loaded_keys.add(B_name)
|
||||||
|
|
||||||
|
|
||||||
|
######## loha
|
||||||
|
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
||||||
|
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
||||||
|
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
||||||
|
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
||||||
|
hada_t1_name = "{}.hada_t1".format(x)
|
||||||
|
hada_t2_name = "{}.hada_t2".format(x)
|
||||||
|
if hada_w1_a_name in lora.keys():
|
||||||
|
hada_t1 = None
|
||||||
|
hada_t2 = None
|
||||||
|
if hada_t1_name in lora.keys():
|
||||||
|
hada_t1 = lora[hada_t1_name]
|
||||||
|
hada_t2 = lora[hada_t2_name]
|
||||||
|
loaded_keys.add(hada_t1_name)
|
||||||
|
loaded_keys.add(hada_t2_name)
|
||||||
|
|
||||||
|
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
|
||||||
|
loaded_keys.add(hada_w1_a_name)
|
||||||
|
loaded_keys.add(hada_w1_b_name)
|
||||||
|
loaded_keys.add(hada_w2_a_name)
|
||||||
|
loaded_keys.add(hada_w2_b_name)
|
||||||
|
|
||||||
|
|
||||||
|
######## lokr
|
||||||
|
lokr_w1_name = "{}.lokr_w1".format(x)
|
||||||
|
lokr_w2_name = "{}.lokr_w2".format(x)
|
||||||
|
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
||||||
|
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
||||||
|
lokr_t2_name = "{}.lokr_t2".format(x)
|
||||||
|
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
||||||
|
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
||||||
|
|
||||||
|
lokr_w1 = None
|
||||||
|
if lokr_w1_name in lora.keys():
|
||||||
|
lokr_w1 = lora[lokr_w1_name]
|
||||||
|
loaded_keys.add(lokr_w1_name)
|
||||||
|
|
||||||
|
lokr_w2 = None
|
||||||
|
if lokr_w2_name in lora.keys():
|
||||||
|
lokr_w2 = lora[lokr_w2_name]
|
||||||
|
loaded_keys.add(lokr_w2_name)
|
||||||
|
|
||||||
|
lokr_w1_a = None
|
||||||
|
if lokr_w1_a_name in lora.keys():
|
||||||
|
lokr_w1_a = lora[lokr_w1_a_name]
|
||||||
|
loaded_keys.add(lokr_w1_a_name)
|
||||||
|
|
||||||
|
lokr_w1_b = None
|
||||||
|
if lokr_w1_b_name in lora.keys():
|
||||||
|
lokr_w1_b = lora[lokr_w1_b_name]
|
||||||
|
loaded_keys.add(lokr_w1_b_name)
|
||||||
|
|
||||||
|
lokr_w2_a = None
|
||||||
|
if lokr_w2_a_name in lora.keys():
|
||||||
|
lokr_w2_a = lora[lokr_w2_a_name]
|
||||||
|
loaded_keys.add(lokr_w2_a_name)
|
||||||
|
|
||||||
|
lokr_w2_b = None
|
||||||
|
if lokr_w2_b_name in lora.keys():
|
||||||
|
lokr_w2_b = lora[lokr_w2_b_name]
|
||||||
|
loaded_keys.add(lokr_w2_b_name)
|
||||||
|
|
||||||
|
lokr_t2 = None
|
||||||
|
if lokr_t2_name in lora.keys():
|
||||||
|
lokr_t2 = lora[lokr_t2_name]
|
||||||
|
loaded_keys.add(lokr_t2_name)
|
||||||
|
|
||||||
|
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||||
|
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
|
||||||
|
|
||||||
|
for x in lora.keys():
|
||||||
|
if x not in loaded_keys:
|
||||||
|
print("lora key not loaded", x)
|
||||||
|
return patch_dict
|
||||||
|
|
||||||
|
def model_lora_keys(model, key_map={}):
|
||||||
|
sdk = model.state_dict().keys()
|
||||||
|
|
||||||
|
counter = 0
|
||||||
|
for b in range(12):
|
||||||
|
tk = "diffusion_model.input_blocks.{}.1".format(b)
|
||||||
|
up_counter = 0
|
||||||
|
for c in LORA_UNET_MAP_ATTENTIONS:
|
||||||
|
k = "{}.{}.weight".format(tk, c)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP_ATTENTIONS[c])
|
||||||
|
key_map[lora_key] = k
|
||||||
|
up_counter += 1
|
||||||
|
if up_counter >= 4:
|
||||||
|
counter += 1
|
||||||
|
for c in LORA_UNET_MAP_ATTENTIONS:
|
||||||
|
k = "diffusion_model.middle_block.1.{}.weight".format(c)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
|
||||||
|
key_map[lora_key] = k
|
||||||
|
counter = 3
|
||||||
|
for b in range(12):
|
||||||
|
tk = "diffusion_model.output_blocks.{}.1".format(b)
|
||||||
|
up_counter = 0
|
||||||
|
for c in LORA_UNET_MAP_ATTENTIONS:
|
||||||
|
k = "{}.{}.weight".format(tk, c)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP_ATTENTIONS[c])
|
||||||
|
key_map[lora_key] = k
|
||||||
|
up_counter += 1
|
||||||
|
if up_counter >= 4:
|
||||||
|
counter += 1
|
||||||
|
counter = 0
|
||||||
|
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||||
|
for b in range(24):
|
||||||
|
for c in LORA_CLIP_MAP:
|
||||||
|
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||||
|
key_map[lora_key] = k
|
||||||
|
|
||||||
|
|
||||||
|
#Locon stuff
|
||||||
|
ds_counter = 0
|
||||||
|
counter = 0
|
||||||
|
for b in range(12):
|
||||||
|
tk = "diffusion_model.input_blocks.{}.0".format(b)
|
||||||
|
key_in = False
|
||||||
|
for c in LORA_UNET_MAP_RESNET:
|
||||||
|
k = "{}.{}.weight".format(tk, c)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = "lora_unet_down_blocks_{}_{}".format(counter // 2, LORA_UNET_MAP_RESNET[c].format(counter % 2))
|
||||||
|
key_map[lora_key] = k
|
||||||
|
key_in = True
|
||||||
|
for bb in range(3):
|
||||||
|
k = "{}.{}.op.weight".format(tk[:-2], bb)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = "lora_unet_down_blocks_{}_downsamplers_0_conv".format(ds_counter)
|
||||||
|
key_map[lora_key] = k
|
||||||
|
ds_counter += 1
|
||||||
|
if key_in:
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
counter = 0
|
||||||
|
for b in range(3):
|
||||||
|
tk = "diffusion_model.middle_block.{}".format(b)
|
||||||
|
key_in = False
|
||||||
|
for c in LORA_UNET_MAP_RESNET:
|
||||||
|
k = "{}.{}.weight".format(tk, c)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = "lora_unet_mid_block_{}".format(LORA_UNET_MAP_RESNET[c].format(counter))
|
||||||
|
key_map[lora_key] = k
|
||||||
|
key_in = True
|
||||||
|
if key_in:
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
counter = 0
|
||||||
|
us_counter = 0
|
||||||
|
for b in range(12):
|
||||||
|
tk = "diffusion_model.output_blocks.{}.0".format(b)
|
||||||
|
key_in = False
|
||||||
|
for c in LORA_UNET_MAP_RESNET:
|
||||||
|
k = "{}.{}.weight".format(tk, c)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = "lora_unet_up_blocks_{}_{}".format(counter // 3, LORA_UNET_MAP_RESNET[c].format(counter % 3))
|
||||||
|
key_map[lora_key] = k
|
||||||
|
key_in = True
|
||||||
|
for bb in range(3):
|
||||||
|
k = "{}.{}.conv.weight".format(tk[:-2], bb)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = "lora_unet_up_blocks_{}_upsamplers_0_conv".format(us_counter)
|
||||||
|
key_map[lora_key] = k
|
||||||
|
us_counter += 1
|
||||||
|
if key_in:
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
return key_map
|
||||||
|
|
||||||
|
def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
|
||||||
|
key_map = model_lora_keys(model.model)
|
||||||
|
key_map = model_lora_keys(clip.cond_stage_model, key_map)
|
||||||
|
loaded = load_lora(lora_path, key_map)
|
||||||
|
new_modelpatcher = model.clone()
|
||||||
|
k = new_modelpatcher.add_patches(loaded, strength_model)
|
||||||
|
new_clip = clip.clone()
|
||||||
|
k1 = new_clip.add_patches(loaded, strength_clip)
|
||||||
|
k = set(k)
|
||||||
|
k1 = set(k1)
|
||||||
|
for x in loaded:
|
||||||
|
if (x not in k) and (x not in k1):
|
||||||
|
print("NOT LOADED", x)
|
||||||
|
|
||||||
|
return (new_modelpatcher, new_clip)
|
||||||
|
|
||||||
Reference in New Issue
Block a user