fix: support bf16 lora weights (#82)

commit c874063408
parent ae1d5dcebb

@@ -101,7 +101,7 @@ QK8_0 = 32
 def quantize_q8_0(x):
     assert x.shape[-1] % QK8_0 == 0 and x.shape[-1] > QK8_0
     x = x.reshape(-1, QK8_0)
     amax = np.max(np.abs(x), axis=-1, keepdims=True)
     d = amax / ((1 << 7) - 1)
     qs = (x / d).round().clip(min=-128, max=127).astype(np.int8)
     d = d.astype(np.float16).view(np.int8)
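
For context: quantize_q8_0 packs each block of 32 values as one fp16 scale d followed by 32 int8 quants, so dequantization is just qs * d per block. A minimal sketch of the inverse; dequantize_q8_0 is not part of this file, and the [d, qs] block layout is an assumption based on ggml's Q8_0 format:

    import numpy as np

    QK8_0 = 32

    def dequantize_q8_0(buf, n):
        # buf holds n values as consecutive blocks of [d: fp16][qs: 32 x int8]
        blocks = np.frombuffer(buf, dtype=np.int8).reshape(n // QK8_0, 2 + QK8_0)
        d = blocks[:, :2].copy().view(np.float16).astype(np.float32)  # per-block scale
        qs = blocks[:, 2:].astype(np.float32)                         # int8 quants
        return (qs * d).reshape(-1)
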
@@ -178,7 +178,7 @@ def preprocess(state_dict):
         print("no alphas_cumprod in file, generate new one")
         alphas_cumprod = get_alpha_comprod()
         state_dict["alphas_cumprod"] = alphas_cumprod
 
     new_state_dict = {}
     for name, w in state_dict.items():
         # ignore unused tensors
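
Side note: get_alpha_comprod() regenerates the diffusion schedule when the checkpoint lacks it. Its body is not shown in this diff; a sketch of what such a function typically computes, assuming Stable Diffusion's standard scaled-linear beta schedule over 1000 steps:

    import numpy as np

    def get_alpha_comprod(beta_start=0.00085, beta_end=0.0120, n_steps=1000):
        # scaled-linear schedule: betas are linear in sqrt space
        betas = np.linspace(beta_start**0.5, beta_end**0.5, n_steps, dtype=np.float32) ** 2
        return np.cumprod(1.0 - betas, axis=0)  # alpha_bar_t = prod(1 - beta_i)
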
@@ -192,7 +192,7 @@ def preprocess(state_dict):
         if skip:
             continue
 
-        # # convert BF16 to FP16
+        # convert BF16 to FP16
         if w.dtype == torch.bfloat16:
             w = w.to(torch.float16)
 
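
One caveat of this cast: bfloat16 keeps float32's 8-bit exponent while float16 has only 5, so bf16 values above fp16's maximum of 65504 overflow to inf. Weights are tiny in practice, so this is harmless here, but it is easy to demonstrate:

    import torch

    w = torch.tensor([1e5], dtype=torch.bfloat16)  # representable in bf16
    print(w.to(torch.float16))                     # inf: exceeds the fp16 max of 65504
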
@@ -251,7 +251,7 @@ def preprocess(state_dict):
             new_state_dict[new_name] = w
             print(f"preprocess {name} => {new_name}")
             continue
 
         # convert unet transformer linear to conv2d 1x1
         if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
             if len(w.shape) == 2:
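
The body of that branch is cut off in this view. The usual way to turn an nn.Linear weight into an equivalent 1x1 Conv2d weight is to append two singleton spatial dims, which is presumably what follows the len(w.shape) == 2 check:

    import torch

    linear_w = torch.randn(320, 320)                  # [out_features, in_features]
    conv_w = linear_w.reshape(*linear_w.shape, 1, 1)  # [out_ch, in_ch, 1, 1]
    # a 1x1 conv with this kernel computes the same per-pixel matmul as the linear layer
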
@@ -342,6 +342,11 @@ def preprocess_lora(state_dict):
     for name, w in state_dict.items():
         if not isinstance(w, torch.Tensor):
             continue
+
+        # convert BF16 to FP16
+        if w.dtype == torch.bfloat16:
+            w = w.to(torch.float16)
+
         name_without_network_parts, network_part = name.split(".", 1)
         new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
         if new_name_without_network_parts == None:
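
This block is the actual fix: LoRA state dicts saved in bf16 previously reached the .numpy() call in convert() unconverted and crashed there. A toy run of the added lines (the key name below is made up; real LoRA keys follow the naming handled by convert_diffusers_name_to_compvis):

    import torch

    # hypothetical bf16 LoRA entry, for illustration only
    state_dict = {"lora_demo.lora_up.weight": torch.randn(8, 8, dtype=torch.bfloat16)}

    for name, w in state_dict.items():
        if w.dtype == torch.bfloat16:
            w = w.to(torch.float16)
        print(name, w.dtype)  # torch.float16, so .numpy() works downstream
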
@@ -421,6 +426,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
             continue
         if name in unused_tensors:
             continue
 
         data = state_dict[name].numpy()
 
         n_dims = len(data.shape)
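
Note that data = state_dict[name].numpy() is exactly where an unconverted bf16 tensor fails, because NumPy has no native bfloat16 dtype; that failure is what the preprocess_lora change above guards against. A quick demonstration, assuming current PyTorch behavior:

    import torch

    t = torch.zeros(4, dtype=torch.bfloat16)
    try:
        t.numpy()
    except TypeError as e:
        print(e)  # "Got unsupported ScalarType BFloat16"
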
@@ -452,7 +458,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
         else:
             data = data.astype(np.float32)
             ttype = "f32"
 
         print("Processing tensor: {} with shape {}, {} -> {}".format(name, data.shape, old_type, ttype))
 
         # header
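
The branches above this else are not shown, but the common pattern in ggml-style converters is to quantize or cast to f16 only the 2-D weight matrices (hence the earlier n_dims check) and keep everything else, such as biases and norm parameters, as f32. A sketch of that selection, under those assumptions and reusing quantize_q8_0 from this file:

    import numpy as np

    def select_ttype(data, out_type):
        # sketch only: quantize/half 2-D matrices, fall back to f32 otherwise
        if len(data.shape) == 2 and out_type == "q8_0":
            return quantize_q8_0(data.astype(np.float32)), "q8_0"
        if len(data.shape) == 2 and out_type == "f16":
            return data.astype(np.float16), "f16"
        return data.astype(np.float32), "f32"
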