feat: adapt to more weight formats
commit bd62138751
parent 3a25179d52
models/.gitignore (vendored) | 1 +
@@ -1,4 +1,5 @@
 *.bin
 *.ckpt
 *.safetensor
+*.safetensors
 *.log
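The new *.safetensors pattern covers checkpoints saved with the safetensors library, one of the extra weight formats this change targets. A minimal loading sketch (the file name is only illustrative):

    import torch
    from safetensors.torch import load_file

    # illustrative path; any safetensors checkpoint dropped into models/ would do
    state_dict = load_file("models/v2-1_768-ema-pruned.safetensors", device="cpu")
    print(f"loaded {len(state_dict)} tensors")  # every value is already a torch.Tensor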
@@ -179,9 +179,9 @@ def preprocess(state_dict):
     state_dict["alphas_cumprod"] = alphas_cumprod
 
     new_state_dict = {}
-    for name in state_dict.keys():
+    for name, w in state_dict.items():
         # ignore unused tensors
-        if not isinstance(state_dict[name], torch.Tensor):
+        if not isinstance(w, torch.Tensor):
             continue
         skip = False
         for unused_tensor in unused_tensors:
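Binding each value as w up front lets the later conversions reuse it instead of re-indexing state_dict[name]. For context, a rough sketch of the kind of dict this loop walks, assuming a pickled .ckpt (path and unwrapping are illustrative, not taken from this script):

    import torch

    ckpt = torch.load("models/model.ckpt", map_location="cpu")  # illustrative path
    state_dict = ckpt.get("state_dict", ckpt)  # .ckpt files often nest the weights
    for name, w in state_dict.items():
        if not isinstance(w, torch.Tensor):  # e.g. step counters or other metadata
            continue
        # ...the conversions operate on w from here on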
@@ -191,12 +191,24 @@ def preprocess(state_dict):
         if skip:
             continue
 
+        # # convert BF16 to FP16
+        if w.dtype == torch.bfloat16:
+            w = w.to(torch.float16)
+
         # convert open_clip to hf CLIPTextModel (for SD2.x)
         open_clip_to_hf_clip_model = {
             "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
             "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
             "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
             "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
+            "first_stage_model.decoder.mid.attn_1.to_k.bias": "first_stage_model.decoder.mid.attn_1.k.bias",
+            "first_stage_model.decoder.mid.attn_1.to_k.weight": "first_stage_model.decoder.mid.attn_1.k.weight",
+            "first_stage_model.decoder.mid.attn_1.to_out.0.bias": "first_stage_model.decoder.mid.attn_1.proj_out.bias",
+            "first_stage_model.decoder.mid.attn_1.to_out.0.weight": "first_stage_model.decoder.mid.attn_1.proj_out.weight",
+            "first_stage_model.decoder.mid.attn_1.to_q.bias": "first_stage_model.decoder.mid.attn_1.q.bias",
+            "first_stage_model.decoder.mid.attn_1.to_q.weight": "first_stage_model.decoder.mid.attn_1.q.weight",
+            "first_stage_model.decoder.mid.attn_1.to_v.bias": "first_stage_model.decoder.mid.attn_1.v.bias",
+            "first_stage_model.decoder.mid.attn_1.to_v.weight": "first_stage_model.decoder.mid.attn_1.v.weight",
         }
         open_clip_to_hk_clip_resblock = {
             "attn.out_proj.bias": "self_attn.out_proj.bias",
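The new BF16 branch downcasts to FP16 before any other conversion runs, presumably because the downstream export path (and numpy, which has no native bfloat16 dtype) only handles FP16/FP32. A standalone sketch of the same idea:

    import torch

    w = torch.randn(320, 320, dtype=torch.bfloat16)  # illustrative bf16 weight
    if w.dtype == torch.bfloat16:
        w = w.to(torch.float16)  # same 2 bytes per element, but numpy-compatible
    assert w.dtype == torch.float16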
@@ -214,22 +226,19 @@ def preprocess(state_dict):
         hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers."
         if name in open_clip_to_hf_clip_model:
             new_name = open_clip_to_hf_clip_model[name]
-            new_state_dict[new_name] = state_dict[name]
             print(f"preprocess {name} => {new_name}")
-            continue
+            name = new_name
         if name.startswith(open_clip_resblock_prefix):
             remain = name[len(open_clip_resblock_prefix):]
             idx = remain.split(".")[0]
             suffix = remain[len(idx)+1:]
             if suffix == "attn.in_proj_weight":
-                w = state_dict[name]
                 w_q, w_k, w_v = w.chunk(3)
                 for new_suffix, new_w in zip(["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], [w_q, w_k, w_v]):
                     new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
                     new_state_dict[new_name] = new_w
                     print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}")
             elif suffix == "attn.in_proj_bias":
-                w = state_dict[name]
                 w_q, w_k, w_v = w.chunk(3)
                 for new_suffix, new_w in zip(["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"], [w_q, w_k, w_v]):
                     new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
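The in_proj handling above relies on open_clip storing the fused attention projection as a single (3*dim, dim) matrix with Q, K, V stacked along dim 0, so chunk(3) recovers the per-projection weights that the HF CLIPTextModel layout expects. A small sketch with illustrative dimensions:

    import torch

    embed_dim = 1024  # SD2.x text-encoder width, used here only as an example
    in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
    w_q, w_k, w_v = in_proj_weight.chunk(3)  # split along dim 0
    assert w_q.shape == (embed_dim, embed_dim)
    assert torch.equal(w_k, in_proj_weight[embed_dim:2 * embed_dim])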
@@ -238,20 +247,27 @@ def preprocess(state_dict):
             else:
                 new_suffix = open_clip_to_hk_clip_resblock[suffix]
                 new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
-                new_state_dict[new_name] = state_dict[name]
+                new_state_dict[new_name] = w
                 print(f"preprocess {name} => {new_name}")
             continue
 
         # convert unet transformer linear to conv2d 1x1
         if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
-            w = state_dict[name]
-            if len(state_dict[name].shape) == 2:
+            if len(w.shape) == 2:
                 new_w = w.unsqueeze(2).unsqueeze(3)
                 new_state_dict[name] = new_w
                 print(f"preprocess {name} {w.size()} => {name} {new_w.size()}")
                 continue
 
-        new_state_dict[name] = state_dict[name]
+        # convert vae attn block linear to conv2d 1x1
+        if name.startswith("first_stage_model.") and "attn_1" in name:
+            if len(w.shape) == 2:
+                new_w = w.unsqueeze(2).unsqueeze(3)
+                new_state_dict[name] = new_w
+                print(f"preprocess {name} {w.size()} => {name} {new_w.size()}")
+                continue
+
+        new_state_dict[name] = w
     return new_state_dict
 
 def convert(model_path, out_type = None, out_file=None):
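Both linear-to-conv2d branches use the same identity: a Linear weight of shape (out, in), reshaped to (out, in, 1, 1), acts as a 1x1 convolution over an NCHW feature map. A sketch verifying that equivalence (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    w = torch.randn(320, 320)              # linear weight (out_features, in_features)
    x = torch.randn(1, 320, 64, 64)        # NCHW feature map
    w_conv = w.unsqueeze(2).unsqueeze(3)   # -> (320, 320, 1, 1), as in preprocess()

    y_conv = F.conv2d(x, w_conv)
    y_lin = torch.einsum("oi,bihw->bohw", w, x)
    assert torch.allclose(y_conv, y_lin, atol=1e-4)

Because the open_clip-to-HF rename now falls through (name = new_name) instead of returning early, the renamed first_stage_model attn_1 weights also reach this reshaping, so checkpoints that store those layers as Linear end up in the same conv-shaped form.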