feat: add SD2.x support (#40)
@@ -9,6 +9,9 @@ import safetensors.torch
this_file_dir = os.path.dirname(__file__)
vocab_dir = this_file_dir

SD1 = 0
SD2 = 1

ggml_ftype_str_to_int = {
    "f32": 0,
    "f16": 1,
@@ -155,6 +158,8 @@ unused_tensors = [
    "posterior_mean_coef1",
    "posterior_mean_coef2",
    "cond_stage_model.transformer.text_model.embeddings.position_ids",
    "cond_stage_model.model.logit_scale",
    "cond_stage_model.model.text_projection",
    "model_ema.decay",
    "model_ema.num_updates",
    "control_model",
@@ -162,12 +167,8 @@ unused_tensors = [
    "embedding_manager"
]

def convert(model_path, out_type=None, out_file=None):
    # load model
    with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
        clip_vocab = json.load(f)

    state_dict = load_model_from_file(model_path)

def preprocess(state_dict):
    alphas_cumprod = state_dict.get("alphas_cumprod")
    if alphas_cumprod is not None:
        # print((np.abs(get_alpha_comprod().numpy() - alphas_cumprod.numpy()) < 0.000001).all())
@@ -176,11 +177,100 @@ def convert(model_path, out_type = None, out_file=None):
        print("no alphas_cumprod in file, generate new one")
        alphas_cumprod = get_alpha_comprod()
        state_dict["alphas_cumprod"] = alphas_cumprod

    new_state_dict = {}
    for name in state_dict.keys():
        # ignore unused tensors
        if not isinstance(state_dict[name], torch.Tensor):
            continue
        skip = False
        for unused_tensor in unused_tensors:
            if name.startswith(unused_tensor):
                skip = True
                break
        if skip:
            continue

        # convert open_clip to hf CLIPTextModel (for SD2.x)
        open_clip_to_hf_clip_model = {
            "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
            "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
            "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
            "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
        }
        open_clip_to_hf_clip_resblock = {
            "attn.out_proj.bias": "self_attn.out_proj.bias",
            "attn.out_proj.weight": "self_attn.out_proj.weight",
            "ln_1.bias": "layer_norm1.bias",
            "ln_1.weight": "layer_norm1.weight",
            "ln_2.bias": "layer_norm2.bias",
            "ln_2.weight": "layer_norm2.weight",
            "mlp.c_fc.bias": "mlp.fc1.bias",
            "mlp.c_fc.weight": "mlp.fc1.weight",
            "mlp.c_proj.bias": "mlp.fc2.bias",
            "mlp.c_proj.weight": "mlp.fc2.weight",
        }
        open_clip_resblock_prefix = "cond_stage_model.model.transformer.resblocks."
        hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers."
        if name in open_clip_to_hf_clip_model:
            new_name = open_clip_to_hf_clip_model[name]
            new_state_dict[new_name] = state_dict[name]
            print(f"preprocess {name} => {new_name}")
            continue
        if name.startswith(open_clip_resblock_prefix):
            remain = name[len(open_clip_resblock_prefix):]
            idx = remain.split(".")[0]
            suffix = remain[len(idx)+1:]
            if suffix == "attn.in_proj_weight":
                w = state_dict[name]
                w_q, w_k, w_v = w.chunk(3)
                for new_suffix, new_w in zip(["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], [w_q, w_k, w_v]):
                    new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
                    new_state_dict[new_name] = new_w
                    print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}")
            elif suffix == "attn.in_proj_bias":
                w = state_dict[name]
                w_q, w_k, w_v = w.chunk(3)
                for new_suffix, new_w in zip(["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"], [w_q, w_k, w_v]):
                    new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
                    new_state_dict[new_name] = new_w
                    print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}")
            else:
                new_suffix = open_clip_to_hf_clip_resblock[suffix]
                new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
                new_state_dict[new_name] = state_dict[name]
                print(f"preprocess {name} => {new_name}")
            continue

        # convert unet transformer linear to conv2d 1x1
        if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
            w = state_dict[name]
            if len(state_dict[name].shape) == 2:
                new_w = w.unsqueeze(2).unsqueeze(3)
                new_state_dict[name] = new_w
                print(f"preprocess {name} {w.size()} => {name} {new_w.size()}")
                continue

        new_state_dict[name] = state_dict[name]
    return new_state_dict
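
# [editor's sketch, not part of this commit] A minimal, hypothetical self-check of the
# two weight transformations preprocess() performs for SD2.x checkpoints: splitting the
# fused open_clip attention in_proj into separate q/k/v projections with chunk(3), and
# turning a 2D linear projection weight into an equivalent 1x1 conv2d kernel. All tensor
# sizes below are illustrative assumptions, not values read from a real checkpoint.
def _preprocess_equivalence_sketch():
    import torch
    import torch.nn.functional as F

    # open_clip stores q/k/v stacked along dim 0 of in_proj_weight ([3*d, d]);
    # chunk(3) recovers the three [d, d] projections used by HF CLIPTextModel
    d = 8
    x = torch.randn(2, d)
    in_proj_weight = torch.randn(3 * d, d)
    w_q, w_k, w_v = in_proj_weight.chunk(3)
    q, k, v = F.linear(x, in_proj_weight).chunk(3, dim=-1)
    assert torch.allclose(q, F.linear(x, w_q), atol=1e-6)
    assert torch.allclose(k, F.linear(x, w_k), atol=1e-6)
    assert torch.allclose(v, F.linear(x, w_v), atol=1e-6)

    # a linear weight [out, in] unsqueezed to [out, in, 1, 1] applies the same map at
    # every spatial position, i.e. it behaves as a 1x1 convolution
    w = torch.randn(16, d)
    feat = torch.randn(1, d, 4, 4)
    as_linear = torch.einsum("oi,nihw->nohw", w, feat)
    as_conv = F.conv2d(feat, w.unsqueeze(2).unsqueeze(3))
    assert torch.allclose(as_linear, as_conv, atol=1e-5)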

def convert(model_path, out_type=None, out_file=None):
    # load model
    with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
        clip_vocab = json.load(f)

    state_dict = load_model_from_file(model_path)
    model_type = SD1
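    # [editor's note, not in the original commit] SD2.x checkpoints exported with open_clip
    # keep the text encoder under "cond_stage_model.model.*" keys, while SD1.x uses the HF
    # CLIP layout "cond_stage_model.transformer.text_model.*", so the presence of this one
    # key is enough to tell the two apart.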
    if "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
        model_type = SD2
        print("Stable diffusion 2.x")
    else:
        print("Stable diffusion 1.x")
    state_dict = preprocess(state_dict)

    # output option
    if out_type is None:
        weight = state_dict["cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight"].numpy()
        weight = state_dict["model.diffusion_model.input_blocks.0.0.weight"].numpy()
        if weight.dtype == np.float32:
            out_type = "f32"
        elif weight.dtype == np.float16:
@@ -198,8 +288,9 @@ def convert(model_path, out_type = None, out_file=None):
    with open(out_file, "wb") as file:
        # magic: ggml in hex
        file.write(struct.pack("i", 0x67676D6C))
        # out type
        file.write(struct.pack("i", ggml_ftype_str_to_int[out_type]))
        # model & file type
        ftype = (model_type << 16) | ggml_ftype_str_to_int[out_type]
        file.write(struct.pack("i", ftype))

        # vocab
        byte_encoder = bytes_to_unicode()