feat: adapt to more weight formats
This commit is contained in:
parent
3a25179d52
commit
bd62138751
1
models/.gitignore
vendored
1
models/.gitignore
vendored
@ -1,4 +1,5 @@
|
||||
*.bin
|
||||
*.ckpt
|
||||
*.safetensor
|
||||
*.safetensors
|
||||
*.log
|
@ -179,9 +179,9 @@ def preprocess(state_dict):
|
||||
state_dict["alphas_cumprod"] = alphas_cumprod
|
||||
|
||||
new_state_dict = {}
|
||||
for name in state_dict.keys():
|
||||
for name, w in state_dict.items():
|
||||
# ignore unused tensors
|
||||
if not isinstance(state_dict[name], torch.Tensor):
|
||||
if not isinstance(w, torch.Tensor):
|
||||
continue
|
||||
skip = False
|
||||
for unused_tensor in unused_tensors:
|
||||
@ -190,13 +190,25 @@ def preprocess(state_dict):
|
||||
break
|
||||
if skip:
|
||||
continue
|
||||
|
||||
|
||||
# # convert BF16 to FP16
|
||||
if w.dtype == torch.bfloat16:
|
||||
w = w.to(torch.float16)
|
||||
|
||||
# convert open_clip to hf CLIPTextModel (for SD2.x)
|
||||
open_clip_to_hf_clip_model = {
|
||||
"cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
|
||||
"cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
|
||||
"cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
|
||||
"cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.to_k.bias": "first_stage_model.decoder.mid.attn_1.k.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.to_k.weight": "first_stage_model.decoder.mid.attn_1.k.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.to_out.0.bias": "first_stage_model.decoder.mid.attn_1.proj_out.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.to_out.0.weight": "first_stage_model.decoder.mid.attn_1.proj_out.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.to_q.bias": "first_stage_model.decoder.mid.attn_1.q.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.to_q.weight": "first_stage_model.decoder.mid.attn_1.q.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.to_v.bias": "first_stage_model.decoder.mid.attn_1.v.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.to_v.weight": "first_stage_model.decoder.mid.attn_1.v.weight",
|
||||
}
|
||||
open_clip_to_hk_clip_resblock = {
|
||||
"attn.out_proj.bias": "self_attn.out_proj.bias",
|
||||
@ -214,22 +226,19 @@ def preprocess(state_dict):
|
||||
hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers."
|
||||
if name in open_clip_to_hf_clip_model:
|
||||
new_name = open_clip_to_hf_clip_model[name]
|
||||
new_state_dict[new_name] = state_dict[name]
|
||||
print(f"preprocess {name} => {new_name}")
|
||||
continue
|
||||
name = new_name
|
||||
if name.startswith(open_clip_resblock_prefix):
|
||||
remain = name[len(open_clip_resblock_prefix):]
|
||||
idx = remain.split(".")[0]
|
||||
suffix = remain[len(idx)+1:]
|
||||
if suffix == "attn.in_proj_weight":
|
||||
w = state_dict[name]
|
||||
w_q, w_k, w_v = w.chunk(3)
|
||||
for new_suffix, new_w in zip(["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], [w_q, w_k, w_v]):
|
||||
new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
|
||||
new_state_dict[new_name] = new_w
|
||||
print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}")
|
||||
elif suffix == "attn.in_proj_bias":
|
||||
w = state_dict[name]
|
||||
w_q, w_k, w_v = w.chunk(3)
|
||||
for new_suffix, new_w in zip(["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"], [w_q, w_k, w_v]):
|
||||
new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
|
||||
@ -238,20 +247,27 @@ def preprocess(state_dict):
|
||||
else:
|
||||
new_suffix = open_clip_to_hk_clip_resblock[suffix]
|
||||
new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
|
||||
new_state_dict[new_name] = state_dict[name]
|
||||
new_state_dict[new_name] = w
|
||||
print(f"preprocess {name} => {new_name}")
|
||||
continue
|
||||
|
||||
# convert unet transformer linear to conv2d 1x1
|
||||
if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
|
||||
w = state_dict[name]
|
||||
if len(state_dict[name].shape) == 2:
|
||||
if len(w.shape) == 2:
|
||||
new_w = w.unsqueeze(2).unsqueeze(3)
|
||||
new_state_dict[name] = new_w
|
||||
print(f"preprocess {name} {w.size()} => {name} {new_w.size()}")
|
||||
continue
|
||||
|
||||
new_state_dict[name] = state_dict[name]
|
||||
# convert vae attn block linear to conv2d 1x1
|
||||
if name.startswith("first_stage_model.") and "attn_1" in name:
|
||||
if len(w.shape) == 2:
|
||||
new_w = w.unsqueeze(2).unsqueeze(3)
|
||||
new_state_dict[name] = new_w
|
||||
print(f"preprocess {name} {w.size()} => {name} {new_w.size()}")
|
||||
continue
|
||||
|
||||
new_state_dict[name] = w
|
||||
return new_state_dict
|
||||
|
||||
def convert(model_path, out_type = None, out_file=None):
|
||||
|
Loading…
Reference in New Issue
Block a user