fix: support bf16 lora weights (#82)

Erik Scholz 2023-11-20 15:34:17 +01:00 committed by GitHub
parent ae1d5dcebb
commit c874063408


@@ -192,7 +192,7 @@ def preprocess(state_dict):
         if skip:
             continue
-        # # convert BF16 to FP16
+        # convert BF16 to FP16
         if w.dtype == torch.bfloat16:
             w = w.to(torch.float16)
@@ -342,6 +342,11 @@ def preprocess_lora(state_dict):
     for name, w in state_dict.items():
         if not isinstance(w, torch.Tensor):
             continue
+        # convert BF16 to FP16
+        if w.dtype == torch.bfloat16:
+            w = w.to(torch.float16)
         name_without_network_parts, network_part = name.split(".", 1)
         new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
         if new_name_without_network_parts == None:
@@ -421,6 +426,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
             continue
         if name in unused_tensors:
             continue
         data = state_dict[name].numpy()
         n_dims = len(data.shape)
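For context, a minimal sketch of the pattern this commit extends to LoRA weights. The helper name and the toy state dict below are illustrative, not part of convert.py: NumPy has no bfloat16 dtype, so calling .numpy() on a bf16 tensor fails, which is why each weight is cast to float16 before export.

import torch

def cast_bf16_to_fp16(state_dict):
    # Illustrative helper (hypothetical name, not from convert.py):
    # NumPy has no bfloat16 dtype, so Tensor.numpy() raises on bf16 weights.
    # Casting to float16 first keeps the later .numpy() export working.
    out = {}
    for name, w in state_dict.items():
        if isinstance(w, torch.Tensor) and w.dtype == torch.bfloat16:
            w = w.to(torch.float16)
        out[name] = w
    return out

# Toy LoRA-style state dict, just to exercise the cast.
sd = {"lora_up.weight": torch.randn(4, 4, dtype=torch.bfloat16)}
sd = cast_bf16_to_fp16(sd)
print(sd["lora_up.weight"].dtype)          # torch.float16
print(sd["lora_up.weight"].numpy().dtype)  # float16 -> safe to serialize

preprocess() already applied this cast to checkpoint weights; the hunks above add the same cast to preprocess_lora() so bf16 LoRA files convert cleanly as well.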