feat: add SD2.x support (#40)

This commit is contained in:
leejet
2023-09-03 16:00:33 +08:00
committed by GitHub
parent c542a77a3f
commit 31e77e1573
3 changed files with 507 additions and 140 deletions

View File

@@ -9,6 +9,9 @@ import safetensors.torch
# Directory containing this script; convert() opens "vocab.json" from vocab_dir.
this_file_dir = os.path.dirname(__file__)
vocab_dir = this_file_dir
# Model-family identifiers. convert() picks one and packs it into the header's
# file-type integer as (model_type << 16) | ggml_ftype_str_to_int[out_type].
SD1 = 0  # default when the SD2 open_clip key below is absent from the checkpoint
SD2 = 1  # chosen when "cond_stage_model.model.token_embedding.weight" is present
ggml_ftype_str_to_int = {
"f32": 0,
"f16": 1,
@@ -155,6 +158,8 @@ unused_tensors = [
"posterior_mean_coef1",
"posterior_mean_coef2",
"cond_stage_model.transformer.text_model.embeddings.position_ids",
"cond_stage_model.model.logit_scale",
"cond_stage_model.model.text_projection",
"model_ema.decay",
"model_ema.num_updates",
"control_model",
@@ -162,12 +167,8 @@ unused_tensors = [
"embedding_manager"
]
def convert(model_path, out_type = None, out_file=None):
# load model
with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
clip_vocab = json.load(f)
state_dict = load_model_from_file(model_path)
def preprocess(state_dict):
alphas_cumprod = state_dict.get("alphas_cumprod")
if alphas_cumprod != None:
# print((np.abs(get_alpha_comprod().numpy() - alphas_cumprod.numpy()) < 0.000001).all())
@@ -176,11 +177,100 @@ def convert(model_path, out_type = None, out_file=None):
print("no alphas_cumprod in file, generate new one")
alphas_cumprod = get_alpha_comprod()
state_dict["alphas_cumprod"] = alphas_cumprod
new_state_dict = {}
for name in state_dict.keys():
# ignore unused tensors
if not isinstance(state_dict[name], torch.Tensor):
continue
skip = False
for unused_tensor in unused_tensors:
if name.startswith(unused_tensor):
skip = True
break
if skip:
continue
# convert open_clip to hf CLIPTextModel (for SD2.x)
open_clip_to_hf_clip_model = {
"cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
"cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
"cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
"cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
}
open_clip_to_hk_clip_resblock = {
"attn.out_proj.bias": "self_attn.out_proj.bias",
"attn.out_proj.weight": "self_attn.out_proj.weight",
"ln_1.bias": "layer_norm1.bias",
"ln_1.weight": "layer_norm1.weight",
"ln_2.bias": "layer_norm2.bias",
"ln_2.weight": "layer_norm2.weight",
"mlp.c_fc.bias": "mlp.fc1.bias",
"mlp.c_fc.weight": "mlp.fc1.weight",
"mlp.c_proj.bias": "mlp.fc2.bias",
"mlp.c_proj.weight": "mlp.fc2.weight",
}
open_clip_resblock_prefix = "cond_stage_model.model.transformer.resblocks."
hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers."
if name in open_clip_to_hf_clip_model:
new_name = open_clip_to_hf_clip_model[name]
new_state_dict[new_name] = state_dict[name]
print(f"preprocess {name} => {new_name}")
continue
if name.startswith(open_clip_resblock_prefix):
remain = name[len(open_clip_resblock_prefix):]
idx = remain.split(".")[0]
suffix = remain[len(idx)+1:]
if suffix == "attn.in_proj_weight":
w = state_dict[name]
w_q, w_k, w_v = w.chunk(3)
for new_suffix, new_w in zip(["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], [w_q, w_k, w_v]):
new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
new_state_dict[new_name] = new_w
print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}")
elif suffix == "attn.in_proj_bias":
w = state_dict[name]
w_q, w_k, w_v = w.chunk(3)
for new_suffix, new_w in zip(["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"], [w_q, w_k, w_v]):
new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
new_state_dict[new_name] = new_w
print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}")
else:
new_suffix = open_clip_to_hk_clip_resblock[suffix]
new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
new_state_dict[new_name] = state_dict[name]
print(f"preprocess {name} => {new_name}")
continue
# convert unet transformer linear to conv2d 1x1
if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
w = state_dict[name]
if len(state_dict[name].shape) == 2:
new_w = w.unsqueeze(2).unsqueeze(3)
new_state_dict[name] = new_w
print(f"preprocess {name} {w.size()} => {name} {new_w.size()}")
continue
new_state_dict[name] = state_dict[name]
return new_state_dict
def convert(model_path, out_type = None, out_file=None):
# load model
with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
clip_vocab = json.load(f)
state_dict = load_model_from_file(model_path)
model_type = SD1
if "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
model_type = SD2
print("Stable diffuison 2.x")
else:
print("Stable diffuison 1.x")
state_dict = preprocess(state_dict)
# output option
if out_type == None:
weight = state_dict["cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight"].numpy()
weight = state_dict["model.diffusion_model.input_blocks.0.0.weight"].numpy()
if weight.dtype == np.float32:
out_type = "f32"
elif weight.dtype == np.float16:
@@ -198,8 +288,9 @@ def convert(model_path, out_type = None, out_file=None):
with open(out_file, "wb") as file:
# magic: ggml in hex
file.write(struct.pack("i", 0x67676D6C))
# out type
file.write(struct.pack("i", ggml_ftype_str_to_int[out_type]))
# model & file type
ftype = (model_type << 16) | ggml_ftype_str_to_int[out_type]
file.write(struct.pack("i", ftype))
# vocab
byte_encoder = bytes_to_unicode()