fix: support bf16 lora weights (#82)
This commit is contained in:
parent
ae1d5dcebb
commit
c874063408
@ -192,7 +192,7 @@ def preprocess(state_dict):
|
||||
if skip:
|
||||
continue
|
||||
|
||||
# # convert BF16 to FP16
|
||||
# convert BF16 to FP16
|
||||
if w.dtype == torch.bfloat16:
|
||||
w = w.to(torch.float16)
|
||||
|
||||
@ -342,6 +342,11 @@ def preprocess_lora(state_dict):
|
||||
for name, w in state_dict.items():
|
||||
if not isinstance(w, torch.Tensor):
|
||||
continue
|
||||
|
||||
# convert BF16 to FP16
|
||||
if w.dtype == torch.bfloat16:
|
||||
w = w.to(torch.float16)
|
||||
|
||||
name_without_network_parts, network_part = name.split(".", 1)
|
||||
new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
|
||||
if new_name_without_network_parts == None:
|
||||
@ -421,6 +426,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
|
||||
continue
|
||||
if name in unused_tensors:
|
||||
continue
|
||||
|
||||
data = state_dict[name].numpy()
|
||||
|
||||
n_dims = len(data.shape)
|
||||
|
Loading…
Reference in New Issue
Block a user