diff --git a/models/.gitignore b/models/.gitignore
index 756d38c..33417ff 100644
--- a/models/.gitignore
+++ b/models/.gitignore
@@ -1,4 +1,5 @@
 *.bin
 *.ckpt
 *.safetensor
+*.safetensors
 *.log
\ No newline at end of file
diff --git a/models/convert.py b/models/convert.py
index b324e20..8ef2fcc 100644
--- a/models/convert.py
+++ b/models/convert.py
@@ -179,9 +179,9 @@ def preprocess(state_dict):
     state_dict["alphas_cumprod"] = alphas_cumprod
 
     new_state_dict = {}
-    for name in state_dict.keys():
+    for name, w in state_dict.items():
         # ignore unused tensors
-        if not isinstance(state_dict[name], torch.Tensor):
+        if not isinstance(w, torch.Tensor):
             continue
         skip = False
         for unused_tensor in unused_tensors:
@@ -190,13 +190,25 @@ def preprocess(state_dict):
                 break
         if skip:
             continue
-        
+
+        # convert BF16 to FP16
+        if w.dtype == torch.bfloat16:
+            w = w.to(torch.float16)
+
         # convert open_clip to hf CLIPTextModel (for SD2.x)
         open_clip_to_hf_clip_model = {
             "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
             "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
             "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
             "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
+            "first_stage_model.decoder.mid.attn_1.to_k.bias": "first_stage_model.decoder.mid.attn_1.k.bias",
+            "first_stage_model.decoder.mid.attn_1.to_k.weight": "first_stage_model.decoder.mid.attn_1.k.weight",
+            "first_stage_model.decoder.mid.attn_1.to_out.0.bias": "first_stage_model.decoder.mid.attn_1.proj_out.bias",
+            "first_stage_model.decoder.mid.attn_1.to_out.0.weight": "first_stage_model.decoder.mid.attn_1.proj_out.weight",
+            "first_stage_model.decoder.mid.attn_1.to_q.bias": "first_stage_model.decoder.mid.attn_1.q.bias",
+            "first_stage_model.decoder.mid.attn_1.to_q.weight": "first_stage_model.decoder.mid.attn_1.q.weight",
+            "first_stage_model.decoder.mid.attn_1.to_v.bias": "first_stage_model.decoder.mid.attn_1.v.bias",
+            "first_stage_model.decoder.mid.attn_1.to_v.weight": "first_stage_model.decoder.mid.attn_1.v.weight",
         }
         open_clip_to_hk_clip_resblock = {
             "attn.out_proj.bias": "self_attn.out_proj.bias",
@@ -214,22 +226,19 @@ def preprocess(state_dict):
         hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers."
         if name in open_clip_to_hf_clip_model:
             new_name = open_clip_to_hf_clip_model[name]
-            new_state_dict[new_name] = state_dict[name]
             print(f"preprocess {name} => {new_name}")
-            continue
+            name = new_name
         if name.startswith(open_clip_resblock_prefix):
             remain = name[len(open_clip_resblock_prefix):]
             idx = remain.split(".")[0]
             suffix = remain[len(idx)+1:]
             if suffix == "attn.in_proj_weight":
-                w = state_dict[name]
                 w_q, w_k, w_v = w.chunk(3)
                 for new_suffix, new_w in zip(["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], [w_q, w_k, w_v]):
                     new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
                     new_state_dict[new_name] = new_w
                     print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}")
             elif suffix == "attn.in_proj_bias":
-                w = state_dict[name]
                 w_q, w_k, w_v = w.chunk(3)
                 for new_suffix, new_w in zip(["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"], [w_q, w_k, w_v]):
                     new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
@@ -238,20 +247,27 @@ def preprocess(state_dict):
             else:
                 new_suffix = open_clip_to_hk_clip_resblock[suffix]
                 new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
-                new_state_dict[new_name] = state_dict[name]
+                new_state_dict[new_name] = w
                 print(f"preprocess {name} => {new_name}")
             continue
 
         # convert unet transformer linear to conv2d 1x1
         if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
-            w = state_dict[name]
-            if len(state_dict[name].shape) == 2:
+            if len(w.shape) == 2:
                 new_w = w.unsqueeze(2).unsqueeze(3)
                 new_state_dict[name] = new_w
                 print(f"preprocess {name} {w.size()} => {name} {new_w.size()}")
                 continue
 
-        new_state_dict[name] = state_dict[name]
+        # convert vae attn block linear to conv2d 1x1
+        if name.startswith("first_stage_model.") and "attn_1" in name:
+            if len(w.shape) == 2:
+                new_w = w.unsqueeze(2).unsqueeze(3)
+                new_state_dict[name] = new_w
+                print(f"preprocess {name} {w.size()} => {name} {new_w.size()}")
+                continue
+
+        new_state_dict[name] = w
     return new_state_dict
 
 def convert(model_path, out_type = None, out_file=None):
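
A note on the new BF16 cast: FP16 has a much narrower exponent range than BF16 (max finite value ~65504 vs ~3.4e38), so values outside that range overflow to inf; trained weights are normally small, so this is rarely a problem in practice. A standalone illustration (not part of convert.py):

```python
import torch

w = torch.tensor([0.02, -1.5, 70000.0], dtype=torch.bfloat16)
print(w.to(torch.float16))  # the third value overflows to inf in FP16
```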
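
The two new linear-to-conv passes rely on the fact that a linear layer applied to every pixel of an [N, C, H, W] feature map is equivalent to a conv2d with the same weight reshaped to [out, in, 1, 1], which is all `w.unsqueeze(2).unsqueeze(3)` does. A minimal standalone check (not part of convert.py; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, c, h, w = 1, 512, 8, 8            # illustrative VAE mid-block shapes
weight = torch.randn(c, c)           # 2-D linear weight, as stored in diffusers-style checkpoints
bias = torch.randn(c)
x = torch.randn(n, c, h, w)

# Linear path: move channels last, apply the linear layer per pixel.
y_linear = F.linear(x.permute(0, 2, 3, 1), weight, bias).permute(0, 3, 1, 2)

# Conv path: same weight as a 1x1 kernel, exactly as the diff produces it.
y_conv = F.conv2d(x, weight.unsqueeze(2).unsqueeze(3), bias)

assert torch.allclose(y_linear, y_conv, atol=1e-4)
```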
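
Also worth calling out: replacing the early `continue` with `name = new_name` is what makes the VAE fix work. A diffusers-style `to_q/to_k/to_v/to_out.0` key is first renamed to the original LDM `q/k/v/proj_out` name and then falls through to the new `attn_1` pass, where its 2-D weight is reshaped for the conv implementation. A standalone sketch of that control flow, with a hypothetical `preprocess_one` helper and one representative key:

```python
import torch

rename = {
    "first_stage_model.decoder.mid.attn_1.to_q.weight":
        "first_stage_model.decoder.mid.attn_1.q.weight",
}

def preprocess_one(name, w, out):
    if name in rename:
        name = rename[name]          # rename, but keep processing this tensor
    if name.startswith("first_stage_model.") and "attn_1" in name:
        if len(w.shape) == 2:        # stored as a 2-D linear weight
            out[name] = w.unsqueeze(2).unsqueeze(3)  # -> [out, in, 1, 1]
            return
    out[name] = w

out = {}
preprocess_one("first_stage_model.decoder.mid.attn_1.to_q.weight",
               torch.randn(512, 512), out)
assert out["first_stage_model.decoder.mid.attn_1.q.weight"].shape == (512, 512, 1, 1)
```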