feat: add LoRA support

This commit is contained in:
leejet
2023-11-19 17:43:49 +08:00
parent 536f3af672
commit 9a9f3daf8e
7 changed files with 573 additions and 36 deletions

View File

@@ -4,6 +4,7 @@ import os
import numpy as np
import torch
import re
import safetensors.torch
this_file_dir = os.path.dirname(__file__)
@@ -270,21 +271,107 @@ def preprocess(state_dict):
new_state_dict[name] = w
return new_state_dict
def convert(model_path, out_type = None, out_file=None):
# Patterns and caches used by convert_diffusers_name_to_compvis.
re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}

# Maps a diffusers sub-module suffix to its compvis (ldm) counterpart,
# keyed by the parent block kind; attention suffixes need no translation.
suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}

def convert_diffusers_name_to_compvis(key):
    """Translate a diffusers-style LoRA tensor name (network part already
    stripped) into the underscore-separated compvis/ldm name, or return
    None when the name is not recognized."""

    def groups_for(regex_text):
        # Compile lazily and cache across calls; return the captured groups
        # (all-digit groups converted to int) or None when key doesn't match.
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex
        hit = regex.match(key)
        if hit is None:
            return None
        return [int(g) if re_digits.match(g) else g for g in hit.groups()]

    # Pattern order matters: the more specific unet patterns come first.
    g = groups_for(r"lora_unet_conv_in(.*)")
    if g is not None:
        return f'model_diffusion_model_input_blocks_0_0{g[0]}'

    g = groups_for(r"lora_unet_conv_out(.*)")
    if g is not None:
        return f'model_diffusion_model_out_2{g[0]}'

    g = groups_for(r"lora_unet_time_embedding_linear_(\d+)(.*)")
    if g is not None:
        # linear_1 -> time_embed_0, linear_2 -> time_embed_2
        return f"model_diffusion_model_time_embed_{g[0] * 2 - 2}{g[1]}"

    g = groups_for(r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)")
    if g is not None:
        suffix = suffix_conversion.get(g[1], {}).get(g[3], g[3])
        return f"model_diffusion_model_input_blocks_{1 + g[0] * 3 + g[2]}_{1 if g[1] == 'attentions' else 0}_{suffix}"

    g = groups_for(r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)")
    if g is not None:
        suffix = suffix_conversion.get(g[0], {}).get(g[2], g[2])
        return f"model_diffusion_model_middle_block_{1 if g[0] == 'attentions' else g[1] * 2}_{suffix}"

    g = groups_for(r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)")
    if g is not None:
        suffix = suffix_conversion.get(g[1], {}).get(g[3], g[3])
        return f"model_diffusion_model_output_blocks_{g[0] * 3 + g[2]}_{1 if g[1] == 'attentions' else 0}_{suffix}"

    g = groups_for(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv")
    if g is not None:
        return f"model_diffusion_model_input_blocks_{3 + g[0] * 3}_0_op"

    g = groups_for(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv")
    if g is not None:
        return f"model_diffusion_model_output_blocks_{2 + g[0] * 3}_{2 if g[0]>0 else 1}_conv"

    g = groups_for(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
    if g is not None:
        return f"cond_stage_model_transformer_text_model_encoder_layers_{g[0]}_{g[1]}"

    return None
def preprocess_lora(state_dict):
    """Rename diffusers-style LoRA tensors to their compvis names.

    Non-tensor entries are silently dropped. Each key is split into the
    module path and the network part (e.g. "lora_up.weight"); only the
    module path is translated.

    Raises ValueError for a tensor whose module path cannot be mapped by
    convert_diffusers_name_to_compvis.
    """
    new_state_dict = {}
    for name, w in state_dict.items():
        if not isinstance(w, torch.Tensor):
            continue
        # e.g. "lora_unet_conv_in.lora_up.weight"
        #   -> ("lora_unet_conv_in", "lora_up.weight")
        name_without_network_parts, network_part = name.split(".", 1)
        new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
        # was `== None`: identity comparison is the idiomatic (and safe) check
        if new_name_without_network_parts is None:
            # ValueError (subclass of Exception) is more precise than bare Exception
            raise ValueError(f"unknown lora tensor: {name}")
        new_name = new_name_without_network_parts + "." + network_part
        print(f"preprocess {name} => {new_name}")
        new_state_dict[new_name] = w
    return new_state_dict
def convert(model_path, out_type = None, out_file=None, lora=False):
# load model
with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
clip_vocab = json.load(f)
if not lora:
with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
clip_vocab = json.load(f)
state_dict = load_model_from_file(model_path)
model_type = SD1
if "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
model_type = SD1 # lora only for SD1 now
if not lora and "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
model_type = SD2
print("Stable diffuison 2.x")
else:
print("Stable diffuison 1.x")
state_dict = preprocess(state_dict)
if lora:
state_dict = preprocess_lora(state_dict)
else:
state_dict = preprocess(state_dict)
# output option
if lora:
out_type = "f16" # only f16 for now
if out_type == None:
weight = state_dict["model.diffusion_model.input_blocks.0.0.weight"].numpy()
if weight.dtype == np.float32:
@@ -296,7 +383,10 @@ def convert(model_path, out_type = None, out_file=None):
else:
raise Exception("unsupported weight type %s" % weight.dtype)
if out_file == None:
out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
if lora:
out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-lora.bin"
else:
out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
out_file = os.path.join(os.getcwd(), out_file)
print(f"Saving GGML compatible file to {out_file}")
@@ -309,14 +399,15 @@ def convert(model_path, out_type = None, out_file=None):
file.write(struct.pack("i", ftype))
# vocab
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
file.write(struct.pack("i", len(clip_vocab)))
for key in clip_vocab:
text = bytearray([byte_decoder[c] for c in key])
file.write(struct.pack("i", len(text)))
file.write(text)
if not lora:
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
file.write(struct.pack("i", len(clip_vocab)))
for key in clip_vocab:
text = bytearray([byte_decoder[c] for c in key])
file.write(struct.pack("i", len(text)))
file.write(text)
# weights
for name in state_dict.keys():
if not isinstance(state_dict[name], torch.Tensor):
@@ -337,7 +428,7 @@ def convert(model_path, out_type = None, out_file=None):
old_type = data.dtype
ttype = "f32"
if n_dims == 4:
if n_dims == 4 and not lora:
data = data.astype(np.float16)
ttype = "f16"
elif n_dims == 2 and name[-7:] == ".weight":
@@ -380,6 +471,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Stable Diffuison model to GGML compatible file format")
parser.add_argument("--out_type", choices=["f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"], help="output format (default: based on input)")
parser.add_argument("--out_file", help="path to write to; default: based on input and current working directory")
parser.add_argument("--lora", action='store_true', default = False, help="convert lora weight; default: false")
parser.add_argument("model_path", help="model file path (*.pth, *.pt, *.ckpt, *.safetensors)")
args = parser.parse_args()
convert(args.model_path, args.out_type, args.out_file)
convert(args.model_path, args.out_type, args.out_file, args.lora)