feat: add LoRA support
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
|
||||
this_file_dir = os.path.dirname(__file__)
|
||||
@@ -270,21 +271,107 @@ def preprocess(state_dict):
|
||||
new_state_dict[name] = w
|
||||
return new_state_dict
|
||||
|
||||
def convert(model_path, out_type = None, out_file=None):
|
||||
# Shared regex helpers for the diffusers -> compvis name translation below.
re_compiled = {}  # memo cache: pattern text -> compiled regex
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_digits = re.compile(r"\d+")

# Per-block-type tensor-suffix renames (diffusers name -> compvis name).
# Attention layers keep their suffixes as-is; resnet layers are remapped.
suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}
|
||||
|
||||
|
||||
def convert_diffusers_name_to_compvis(key):
    """Translate a diffusers-style LoRA tensor name into the compvis scheme.

    Patterns are tried in a fixed order; the first one that matches decides
    the result. Returns the translated name, or None when *key* matches no
    known pattern.
    """

    def capture(out, pattern):
        # Fetch (or lazily compile and memoize) the pattern, then try the key.
        compiled = re_compiled.get(pattern)
        if compiled is None:
            compiled = re.compile(pattern)
            re_compiled[pattern] = compiled

        hit = re.match(compiled, key)
        if not hit:
            return False

        # On success, replace the caller's list contents with the captured
        # groups; purely-numeric captures become ints for block arithmetic.
        out.clear()
        out.extend([int(g) if re.match(re_digits, g) else g for g in hit.groups()])
        return True

    g = []

    if capture(g, r"lora_unet_conv_in(.*)"):
        return f'model_diffusion_model_input_blocks_0_0{g[0]}'

    if capture(g, r"lora_unet_conv_out(.*)"):
        return f'model_diffusion_model_out_2{g[0]}'

    if capture(g, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"model_diffusion_model_time_embed_{g[0] * 2 - 2}{g[1]}"

    if capture(g, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(g[1], {}).get(g[3], g[3])
        return f"model_diffusion_model_input_blocks_{1 + g[0] * 3 + g[2]}_{1 if g[1] == 'attentions' else 0}_{suffix}"

    if capture(g, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(g[0], {}).get(g[2], g[2])
        return f"model_diffusion_model_middle_block_{1 if g[0] == 'attentions' else g[1] * 2}_{suffix}"

    if capture(g, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(g[1], {}).get(g[3], g[3])
        return f"model_diffusion_model_output_blocks_{g[0] * 3 + g[2]}_{1 if g[1] == 'attentions' else 0}_{suffix}"

    if capture(g, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"model_diffusion_model_input_blocks_{3 + g[0] * 3}_0_op"

    if capture(g, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"model_diffusion_model_output_blocks_{2 + g[0] * 3}_{2 if g[0]>0 else 1}_conv"

    if capture(g, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        return f"cond_stage_model_transformer_text_model_encoder_layers_{g[0]}_{g[1]}"

    return None
|
||||
|
||||
def preprocess_lora(state_dict):
    """Rename every LoRA tensor in *state_dict* from diffusers to compvis naming.

    Non-tensor entries (e.g. metadata scalars) are skipped. Raises Exception
    for any tensor whose name cannot be translated, so unknown layouts fail
    loudly instead of being silently dropped.
    """
    new_state_dict = {}
    for name, w in state_dict.items():
        if not isinstance(w, torch.Tensor):
            continue

        # Guard the unpack below: a name with no "." would otherwise raise an
        # opaque ValueError from tuple unpacking.
        if "." not in name:
            raise Exception(f"unknown lora tensor: {name}")

        # e.g. "lora_unet_conv_in.lora_up.weight" -> base name + network suffix
        name_without_network_parts, network_part = name.split(".", 1)

        new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
        # fixed: was `== None`; identity comparison is the correct check
        if new_name_without_network_parts is None:
            raise Exception(f"unknown lora tensor: {name}")

        new_name = new_name_without_network_parts + "." + network_part
        print(f"preprocess {name} => {new_name}")
        new_state_dict[new_name] = w

    return new_state_dict
|
||||
|
||||
def convert(model_path, out_type = None, out_file=None, lora=False):
|
||||
# load model
|
||||
with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
|
||||
clip_vocab = json.load(f)
|
||||
|
||||
if not lora:
|
||||
with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
|
||||
clip_vocab = json.load(f)
|
||||
|
||||
state_dict = load_model_from_file(model_path)
|
||||
model_type = SD1
|
||||
if "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
|
||||
model_type = SD1 # lora only for SD1 now
|
||||
if not lora and "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
|
||||
model_type = SD2
|
||||
print("Stable diffuison 2.x")
|
||||
else:
|
||||
print("Stable diffuison 1.x")
|
||||
state_dict = preprocess(state_dict)
|
||||
if lora:
|
||||
state_dict = preprocess_lora(state_dict)
|
||||
else:
|
||||
state_dict = preprocess(state_dict)
|
||||
|
||||
# output option
|
||||
if lora:
|
||||
out_type = "f16" # only f16 for now
|
||||
if out_type == None:
|
||||
weight = state_dict["model.diffusion_model.input_blocks.0.0.weight"].numpy()
|
||||
if weight.dtype == np.float32:
|
||||
@@ -296,7 +383,10 @@ def convert(model_path, out_type = None, out_file=None):
|
||||
else:
|
||||
raise Exception("unsupported weight type %s" % weight.dtype)
|
||||
if out_file == None:
|
||||
out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
|
||||
if lora:
|
||||
out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-lora.bin"
|
||||
else:
|
||||
out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
|
||||
out_file = os.path.join(os.getcwd(), out_file)
|
||||
print(f"Saving GGML compatible file to {out_file}")
|
||||
|
||||
@@ -309,14 +399,15 @@ def convert(model_path, out_type = None, out_file=None):
|
||||
file.write(struct.pack("i", ftype))
|
||||
|
||||
# vocab
|
||||
byte_encoder = bytes_to_unicode()
|
||||
byte_decoder = {v: k for k, v in byte_encoder.items()}
|
||||
file.write(struct.pack("i", len(clip_vocab)))
|
||||
for key in clip_vocab:
|
||||
text = bytearray([byte_decoder[c] for c in key])
|
||||
file.write(struct.pack("i", len(text)))
|
||||
file.write(text)
|
||||
|
||||
if not lora:
|
||||
byte_encoder = bytes_to_unicode()
|
||||
byte_decoder = {v: k for k, v in byte_encoder.items()}
|
||||
file.write(struct.pack("i", len(clip_vocab)))
|
||||
for key in clip_vocab:
|
||||
text = bytearray([byte_decoder[c] for c in key])
|
||||
file.write(struct.pack("i", len(text)))
|
||||
file.write(text)
|
||||
|
||||
# weights
|
||||
for name in state_dict.keys():
|
||||
if not isinstance(state_dict[name], torch.Tensor):
|
||||
@@ -337,7 +428,7 @@ def convert(model_path, out_type = None, out_file=None):
|
||||
old_type = data.dtype
|
||||
|
||||
ttype = "f32"
|
||||
if n_dims == 4:
|
||||
if n_dims == 4 and not lora:
|
||||
data = data.astype(np.float16)
|
||||
ttype = "f16"
|
||||
elif n_dims == 2 and name[-7:] == ".weight":
|
||||
@@ -380,6 +471,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Stable Diffuison model to GGML compatible file format")
|
||||
parser.add_argument("--out_type", choices=["f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"], help="output format (default: based on input)")
|
||||
parser.add_argument("--out_file", help="path to write to; default: based on input and current working directory")
|
||||
parser.add_argument("--lora", action='store_true', default = False, help="convert lora weight; default: false")
|
||||
parser.add_argument("model_path", help="model file path (*.pth, *.pt, *.ckpt, *.safetensors)")
|
||||
args = parser.parse_args()
|
||||
convert(args.model_path, args.out_type, args.out_file)
|
||||
convert(args.model_path, args.out_type, args.out_file, args.lora)
|
||||
|
||||
Reference in New Issue
Block a user