mirror of
https://github.com/jags111/efficiency-nodes-comfyui.git
synced 2026-03-25 15:15:45 -03:00
Merge pull request #68 from LucianoCirino/SDXL-LoRA-Patch1
Fixed Issues with SDXL LoRAs
This commit is contained in:
177
tsc_sd.py
177
tsc_sd.py
@@ -17,36 +17,6 @@ sys.path.append(comfy_dir)
|
|||||||
from comfy.sd import *
|
from comfy.sd import *
|
||||||
from comfy import utils
|
from comfy import utils
|
||||||
|
|
||||||
LORA_CLIP_MAP = {
|
|
||||||
"mlp.fc1": "mlp_fc1",
|
|
||||||
"mlp.fc2": "mlp_fc2",
|
|
||||||
"self_attn.k_proj": "self_attn_k_proj",
|
|
||||||
"self_attn.q_proj": "self_attn_q_proj",
|
|
||||||
"self_attn.v_proj": "self_attn_v_proj",
|
|
||||||
"self_attn.out_proj": "self_attn_out_proj",
|
|
||||||
}
|
|
||||||
|
|
||||||
LORA_UNET_MAP_ATTENTIONS = {
|
|
||||||
"proj_in": "proj_in",
|
|
||||||
"proj_out": "proj_out",
|
|
||||||
"transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q",
|
|
||||||
"transformer_blocks.0.attn1.to_k": "transformer_blocks_0_attn1_to_k",
|
|
||||||
"transformer_blocks.0.attn1.to_v": "transformer_blocks_0_attn1_to_v",
|
|
||||||
"transformer_blocks.0.attn1.to_out.0": "transformer_blocks_0_attn1_to_out_0",
|
|
||||||
"transformer_blocks.0.attn2.to_q": "transformer_blocks_0_attn2_to_q",
|
|
||||||
"transformer_blocks.0.attn2.to_k": "transformer_blocks_0_attn2_to_k",
|
|
||||||
"transformer_blocks.0.attn2.to_v": "transformer_blocks_0_attn2_to_v",
|
|
||||||
"transformer_blocks.0.attn2.to_out.0": "transformer_blocks_0_attn2_to_out_0",
|
|
||||||
"transformer_blocks.0.ff.net.0.proj": "transformer_blocks_0_ff_net_0_proj",
|
|
||||||
"transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2",
|
|
||||||
}
|
|
||||||
|
|
||||||
LORA_UNET_MAP_RESNET = {
|
|
||||||
"in_layers.2": "resnets_{}_conv1",
|
|
||||||
"emb_layers.1": "resnets_{}_time_emb_proj",
|
|
||||||
"out_layers.3": "resnets_{}_conv2",
|
|
||||||
"skip_connection": "resnets_{}_conv_shortcut"
|
|
||||||
}
|
|
||||||
|
|
||||||
def load_lora_tsc(path, to_load):
|
def load_lora_tsc(path, to_load):
|
||||||
lora = utils.load_torch_file(path)
|
lora = utils.load_torch_file(path)
|
||||||
@@ -148,107 +118,9 @@ def load_lora_tsc(path, to_load):
|
|||||||
print("lora key not loaded", x)
|
print("lora key not loaded", x)
|
||||||
return patch_dict
|
return patch_dict
|
||||||
|
|
||||||
def model_lora_keys(model, key_map={}):
|
|
||||||
sdk = model.state_dict().keys()
|
|
||||||
|
|
||||||
counter = 0
|
|
||||||
for b in range(12):
|
|
||||||
tk = "diffusion_model.input_blocks.{}.1".format(b)
|
|
||||||
up_counter = 0
|
|
||||||
for c in LORA_UNET_MAP_ATTENTIONS:
|
|
||||||
k = "{}.{}.weight".format(tk, c)
|
|
||||||
if k in sdk:
|
|
||||||
lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP_ATTENTIONS[c])
|
|
||||||
key_map[lora_key] = k
|
|
||||||
up_counter += 1
|
|
||||||
if up_counter >= 4:
|
|
||||||
counter += 1
|
|
||||||
for c in LORA_UNET_MAP_ATTENTIONS:
|
|
||||||
k = "diffusion_model.middle_block.1.{}.weight".format(c)
|
|
||||||
if k in sdk:
|
|
||||||
lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
|
|
||||||
key_map[lora_key] = k
|
|
||||||
counter = 3
|
|
||||||
for b in range(12):
|
|
||||||
tk = "diffusion_model.output_blocks.{}.1".format(b)
|
|
||||||
up_counter = 0
|
|
||||||
for c in LORA_UNET_MAP_ATTENTIONS:
|
|
||||||
k = "{}.{}.weight".format(tk, c)
|
|
||||||
if k in sdk:
|
|
||||||
lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP_ATTENTIONS[c])
|
|
||||||
key_map[lora_key] = k
|
|
||||||
up_counter += 1
|
|
||||||
if up_counter >= 4:
|
|
||||||
counter += 1
|
|
||||||
counter = 0
|
|
||||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
|
||||||
for b in range(24):
|
|
||||||
for c in LORA_CLIP_MAP:
|
|
||||||
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
|
||||||
if k in sdk:
|
|
||||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
|
||||||
key_map[lora_key] = k
|
|
||||||
|
|
||||||
|
|
||||||
#Locon stuff
|
|
||||||
ds_counter = 0
|
|
||||||
counter = 0
|
|
||||||
for b in range(12):
|
|
||||||
tk = "diffusion_model.input_blocks.{}.0".format(b)
|
|
||||||
key_in = False
|
|
||||||
for c in LORA_UNET_MAP_RESNET:
|
|
||||||
k = "{}.{}.weight".format(tk, c)
|
|
||||||
if k in sdk:
|
|
||||||
lora_key = "lora_unet_down_blocks_{}_{}".format(counter // 2, LORA_UNET_MAP_RESNET[c].format(counter % 2))
|
|
||||||
key_map[lora_key] = k
|
|
||||||
key_in = True
|
|
||||||
for bb in range(3):
|
|
||||||
k = "{}.{}.op.weight".format(tk[:-2], bb)
|
|
||||||
if k in sdk:
|
|
||||||
lora_key = "lora_unet_down_blocks_{}_downsamplers_0_conv".format(ds_counter)
|
|
||||||
key_map[lora_key] = k
|
|
||||||
ds_counter += 1
|
|
||||||
if key_in:
|
|
||||||
counter += 1
|
|
||||||
|
|
||||||
counter = 0
|
|
||||||
for b in range(3):
|
|
||||||
tk = "diffusion_model.middle_block.{}".format(b)
|
|
||||||
key_in = False
|
|
||||||
for c in LORA_UNET_MAP_RESNET:
|
|
||||||
k = "{}.{}.weight".format(tk, c)
|
|
||||||
if k in sdk:
|
|
||||||
lora_key = "lora_unet_mid_block_{}".format(LORA_UNET_MAP_RESNET[c].format(counter))
|
|
||||||
key_map[lora_key] = k
|
|
||||||
key_in = True
|
|
||||||
if key_in:
|
|
||||||
counter += 1
|
|
||||||
|
|
||||||
counter = 0
|
|
||||||
us_counter = 0
|
|
||||||
for b in range(12):
|
|
||||||
tk = "diffusion_model.output_blocks.{}.0".format(b)
|
|
||||||
key_in = False
|
|
||||||
for c in LORA_UNET_MAP_RESNET:
|
|
||||||
k = "{}.{}.weight".format(tk, c)
|
|
||||||
if k in sdk:
|
|
||||||
lora_key = "lora_unet_up_blocks_{}_{}".format(counter // 3, LORA_UNET_MAP_RESNET[c].format(counter % 3))
|
|
||||||
key_map[lora_key] = k
|
|
||||||
key_in = True
|
|
||||||
for bb in range(3):
|
|
||||||
k = "{}.{}.conv.weight".format(tk[:-2], bb)
|
|
||||||
if k in sdk:
|
|
||||||
lora_key = "lora_unet_up_blocks_{}_upsamplers_0_conv".format(us_counter)
|
|
||||||
key_map[lora_key] = k
|
|
||||||
us_counter += 1
|
|
||||||
if key_in:
|
|
||||||
counter += 1
|
|
||||||
|
|
||||||
return key_map
|
|
||||||
|
|
||||||
def load_lora_for_models_tsc(model, clip, lora_path, strength_model, strength_clip):
|
def load_lora_for_models_tsc(model, clip, lora_path, strength_model, strength_clip):
|
||||||
key_map = model_lora_keys(model.model)
|
key_map = model_lora_keys_unet(model.model)
|
||||||
key_map = model_lora_keys(clip.cond_stage_model, key_map)
|
key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||||
loaded = load_lora_tsc(lora_path, key_map)
|
loaded = load_lora_tsc(lora_path, key_map)
|
||||||
new_modelpatcher = model.clone()
|
new_modelpatcher = model.clone()
|
||||||
k = new_modelpatcher.add_patches(loaded, strength_model)
|
k = new_modelpatcher.add_patches(loaded, strength_model)
|
||||||
@@ -261,48 +133,3 @@ def load_lora_for_models_tsc(model, clip, lora_path, strength_model, strength_cl
|
|||||||
print("NOT LOADED", x)
|
print("NOT LOADED", x)
|
||||||
|
|
||||||
return (new_modelpatcher, new_clip)
|
return (new_modelpatcher, new_clip)
|
||||||
|
|
||||||
def load_checkpoint_guess_config_tsc(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
|
|
||||||
sd = utils.load_torch_file(ckpt_path)
|
|
||||||
sd_keys = sd.keys()
|
|
||||||
clip = None
|
|
||||||
clipvision = None
|
|
||||||
vae = None
|
|
||||||
model = None
|
|
||||||
clip_target = None
|
|
||||||
|
|
||||||
parameters = calculate_parameters(sd, "model.diffusion_model.")
|
|
||||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
|
||||||
|
|
||||||
class WeightsLoader(torch.nn.Module):
|
|
||||||
pass
|
|
||||||
|
|
||||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16)
|
|
||||||
if model_config is None:
|
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
|
||||||
|
|
||||||
if model_config.clip_vision_prefix is not None:
|
|
||||||
if output_clipvision:
|
|
||||||
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
|
|
||||||
|
|
||||||
offload_device = model_management.unet_offload_device()
|
|
||||||
model = model_config.get_model(sd, "model.diffusion_model.")
|
|
||||||
model = model.to(offload_device)
|
|
||||||
model.load_model_weights(sd, "model.diffusion_model.")
|
|
||||||
|
|
||||||
if output_vae:
|
|
||||||
vae = VAE()
|
|
||||||
w = WeightsLoader()
|
|
||||||
w.first_stage_model = vae.first_stage_model
|
|
||||||
load_model_weights(w, sd)
|
|
||||||
|
|
||||||
if output_clip:
|
|
||||||
w = WeightsLoader()
|
|
||||||
clip_target = model_config.clip_target()
|
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
|
||||||
w.cond_stage_model = clip.cond_stage_model
|
|
||||||
sd = model_config.process_clip_state_dict(sd)
|
|
||||||
load_model_weights(w, sd)
|
|
||||||
|
|
||||||
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision)
|
|
||||||
|
|
||||||
|
|||||||
22
tsc_utils.py
22
tsc_utils.py
@@ -7,6 +7,8 @@ import numpy as np
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import io
|
||||||
|
from contextlib import contextmanager
|
||||||
import json
|
import json
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
@@ -176,8 +178,9 @@ def load_checkpoint(ckpt_name, id, output_vae=True, cache=None, cache_overwrite=
|
|||||||
return model, clip, vae
|
return model, clip, vae
|
||||||
|
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||||
out = load_checkpoint_guess_config_tsc(ckpt_path, output_vae, output_clip=True,
|
with suppress_output():
|
||||||
embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = load_checkpoint_guess_config(ckpt_path, output_vae, output_clip=True,
|
||||||
|
embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
model = out[0]
|
model = out[0]
|
||||||
clip = out[1]
|
clip = out[1]
|
||||||
vae = out[2] # bvae
|
vae = out[2] # bvae
|
||||||
@@ -425,3 +428,18 @@ def print_last_helds(id=None):
|
|||||||
print(f" [{i}] Output: {output}")
|
print(f" [{i}] Output: {output}")
|
||||||
print("-" * 40) # Print a separator line
|
print("-" * 40) # Print a separator line
|
||||||
print("\n") # Print an empty line
|
print("\n") # Print an empty line
|
||||||
|
|
||||||
|
# For suppressing print outputs from functions
|
||||||
|
@contextmanager
|
||||||
|
def suppress_output():
|
||||||
|
original_stdout = sys.stdout
|
||||||
|
original_stderr = sys.stderr
|
||||||
|
|
||||||
|
sys.stdout = io.StringIO()
|
||||||
|
sys.stderr = io.StringIO()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
sys.stdout = original_stdout
|
||||||
|
sys.stderr = original_stderr
|
||||||
Reference in New Issue
Block a user