From 2b1bc064776c45a32268c21069a04e91933a1eae Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 22 Nov 2024 22:50:14 -0500 Subject: [PATCH] feat: add PhotoMaker Version 2 support (#358) * first attempt at updating to photomaker v2 * continue adding photomaker v2 modules * finishing the last few pieces for photomaker v2; id_embeds need to be done by a manual step and pass as an input file * added a name converter for Photomaker V2; build ok * more debugging underway * failing at cuda mat_mul * updated chunk_half to be more efficient; redo feedforward * fixed a bug: carefully using ggml_view_4d to get chunks of a tensor; strides need to be recalculated or set properly; still failing at soft_max cuda op * redo weight calculation and weight*v * fixed a bug now Photomaker V2 kinds of working * add python script for face detection (Photomaker V2 needs) * updated readme for photomaker * fixed a bug causing PMV1 crashing; both V1 and V2 work * fixed clean_input_ids for PMV2 * fixed a double counting bug in tokenize_with_trigger_token * updated photomaker readme * removed some commented code * improved reconstructing class word free prompt * changed reading id_embed to raw binary using existing load tensor function; this is more efficient than using model load and also makes it easier to work with sd server * minor clean up --------- Co-authored-by: bssrdf --- clip.hpp | 27 +- conditioner.hpp | 24 +- docs/photo_maker.md | 24 +- face_detect.py | 88 ++++++ ggml_extend.hpp | 8 +- model.cpp | 42 ++- model.h | 6 + pmid.hpp | 624 +++++++++++++++++++++++++++++++++++++++++-- stable-diffusion.cpp | 37 ++- util.cpp | 18 ++ util.h | 2 +- 11 files changed, 844 insertions(+), 56 deletions(-) create mode 100644 face_detect.py diff --git a/clip.hpp b/clip.hpp index e0d846a..7c27058 100644 --- a/clip.hpp +++ b/clip.hpp @@ -343,6 +343,14 @@ public: } } + std::string clean_up_tokenization(std::string &text){ + + std::regex pattern(R"( ,)"); + // Replace " ," with "," + std::string result = std::regex_replace(text, pattern, ","); + return result; + } + std::string decode(const std::vector& tokens) { std::string text = ""; for (int t : tokens) { @@ -351,8 +359,12 @@ public: std::u32string ts = decoder[t]; // printf("%d, %s \n", t, utf32_to_utf8(ts).c_str()); std::string s = utf32_to_utf8(ts); - if (s.length() >= 4 && ends_with(s, "")) { - text += " " + s.replace(s.length() - 4, s.length() - 1, ""); + if (s.length() >= 4 ){ + if(ends_with(s, "")) { + text += s.replace(s.length() - 4, s.length() - 1, "") + " "; + }else{ + text += s; + } } else { text += " " + s; } @@ -364,6 +376,7 @@ public: // std::string s((char *)bytes.data()); // std::string s = ""; + text = clean_up_tokenization(text); return trim(text); } @@ -755,7 +768,8 @@ public: blocks["post_layernorm"] = std::shared_ptr(new LayerNorm(hidden_size)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values, bool return_pooled = true) { + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values, + bool return_pooled = true) { // pixel_values: [N, num_channels, image_size, image_size] auto embeddings = std::dynamic_pointer_cast(blocks["embeddings"]); auto pre_layernorm = std::dynamic_pointer_cast(blocks["pre_layernorm"]); @@ -765,14 +779,17 @@ public: auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim] x = pre_layernorm->forward(ctx, x); x = encoder->forward(ctx, x, -1, false); + // print_ggml_tensor(x, true, "ClipVisionModel x: "); + auto last_hidden_state = x; x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size] - GGML_ASSERT(x->ne[3] == 1); + GGML_ASSERT(x->ne[3] == 1); if (return_pooled) { ggml_tensor* pooled = ggml_cont(ctx, ggml_view_2d(ctx, x, x->ne[0], x->ne[2], x->nb[2], 0)); return pooled; // [N, hidden_size] } else { - return x; // [N, n_token, hidden_size] + // return x; // [N, n_token, hidden_size] + return last_hidden_state; // [N, n_token, hidden_size] } } }; diff --git a/conditioner.hpp b/conditioner.hpp index ea02d37..065f352 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -4,6 +4,7 @@ #include "clip.hpp" #include "t5.hpp" + struct SDCondition { struct ggml_tensor* c_crossattn = NULL; // aka context struct ggml_tensor* c_vector = NULL; // aka y @@ -44,6 +45,7 @@ struct Conditioner { // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283 struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { SDVersion version = VERSION_SD1; + PMVersion pm_version = VERSION_1; CLIPTokenizer tokenizer; ggml_type wtype; std::shared_ptr text_model; @@ -59,8 +61,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { ggml_type wtype, const std::string& embd_dir, SDVersion version = VERSION_SD1, + PMVersion pv = VERSION_1, int clip_skip = -1) - : version(version), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) { + : version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) { if (clip_skip <= 0) { clip_skip = 1; if (version == VERSION_SD2 || version == VERSION_SDXL) { @@ -159,7 +162,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { tokenize_with_trigger_token(std::string text, int num_input_imgs, int32_t image_token, - bool padding = false) { + bool padding = false){ return tokenize_with_trigger_token(text, num_input_imgs, image_token, text_model->model.n_token, padding); } @@ -268,7 +271,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { std::vector clean_input_ids_tmp; for (uint32_t i = 0; i < class_token_index[0]; i++) clean_input_ids_tmp.push_back(clean_input_ids[i]); - for (uint32_t i = 0; i < num_input_imgs; i++) + for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2*num_input_imgs: num_input_imgs); i++) clean_input_ids_tmp.push_back(class_token); for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++) clean_input_ids_tmp.push_back(clean_input_ids[i]); @@ -279,13 +282,16 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end()); weights.insert(weights.end(), clean_input_ids.size(), curr_weight); } - tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID); - weights.insert(weights.begin(), 1.0); + // BUG!! double couting, pad_tokens will add BOS at the beginning + // tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID); + // weights.insert(weights.begin(), 1.0); tokenizer.pad_tokens(tokens, weights, max_length, padding); - + int offset = pm_version == VERSION_2 ? 2*num_input_imgs: num_input_imgs; for (uint32_t i = 0; i < tokens.size(); i++) { - if (class_idx + 1 <= i && i < class_idx + 1 + num_input_imgs) + // if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs + if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs + // hardcode for now class_token_mask.push_back(true); else class_token_mask.push_back(false); @@ -530,7 +536,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { int height, int num_input_imgs, int adm_in_channels = -1, - bool force_zero_embeddings = false) { + bool force_zero_embeddings = false){ auto image_tokens = convert_token_to_id(trigger_word); // if(image_tokens.size() == 1){ // printf(" image token id is: %d \n", image_tokens[0]); @@ -958,7 +964,7 @@ struct SD3CLIPEmbedder : public Conditioner { int height, int num_input_imgs, int adm_in_channels = -1, - bool force_zero_embeddings = false) { + bool force_zero_embeddings = false){ GGML_ASSERT(0 && "Not implemented yet!"); } diff --git a/docs/photo_maker.md b/docs/photo_maker.md index b69ad97..8305a33 100644 --- a/docs/photo_maker.md +++ b/docs/photo_maker.md @@ -29,4 +29,26 @@ Example: ```bash bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --stacked-id-embd-dir ../models/photomaker-v1.safetensors --input-id-images-dir ../assets/photomaker_examples/scarletthead_woman -p "a girl img, retro futurism, retro game art style but extremely beautiful, intricate details, masterpiece, best quality, space-themed, cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed" -n "realistic, photo-realistic, worst quality, greyscale, bad anatomy, bad hands, error, text" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --style-ratio 10 --vae-on-cpu -o output.png -``` \ No newline at end of file +``` + +## PhotoMaker Version 2 + +[PhotoMaker Version 2 (PMV2)](https://github.com/TencentARC/PhotoMaker/blob/main/README_pmv2.md) has some key improvements. Unfortunately it has a very heavy dependency which makes running it a bit involved in ```SD.cpp```. + +Running PMV2 is now a two-step process: + +- Run a python script ```face_detect.py``` to obtain **id_embeds** for the given input images +``` +python face_detect.py input_image_dir +``` +An ```id_embeds.safetensors``` file will be generated in ```input_images_dir``` + +**Note: this step is only needed to run once; the same ```id_embeds``` can be reused** + +- Run the same command as in version 1 but replacing ```photomaker-v1.safetensors``` with ```photomaker-v2.safetensors```. + + You can download ```photomaker-v2.safetensors``` from [here](https://huggingface.co/bssrdf/PhotoMakerV2) + +- All the command line parameters from Version 1 remain the same for Version 2 + + diff --git a/face_detect.py b/face_detect.py new file mode 100644 index 0000000..7131af3 --- /dev/null +++ b/face_detect.py @@ -0,0 +1,88 @@ +import os +import sys + +import numpy as np +import torch +from diffusers.utils import load_image +# pip install insightface==0.7.3 +from insightface.app import FaceAnalysis +from insightface.data import get_image as ins_get_image +from safetensors.torch import save_file + +### +# https://github.com/cubiq/ComfyUI_IPAdapter_plus/issues/165#issue-2055829543 +### +class FaceAnalysis2(FaceAnalysis): + # NOTE: allows setting det_size for each detection call. + # the model allows it but the wrapping code from insightface + # doesn't show it, and people end up loading duplicate models + # for different sizes where there is absolutely no need to + def get(self, img, max_num=0, det_size=(640, 640)): + if det_size is not None: + self.det_model.input_size = det_size + + return super().get(img, max_num) + +def analyze_faces(face_analysis: FaceAnalysis, img_data: np.ndarray, det_size=(640, 640)): + # NOTE: try detect faces, if no faces detected, lower det_size until it does + detection_sizes = [None] + [(size, size) for size in range(640, 256, -64)] + [(256, 256)] + + for size in detection_sizes: + faces = face_analysis.get(img_data, det_size=size) + if len(faces) > 0: + return faces + + return [] + +if __name__ == "__main__": + #face_detector = FaceAnalysis2(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition']) + face_detector = FaceAnalysis2(providers=['CPUExecutionProvider'], allowed_modules=['detection', 'recognition']) + face_detector.prepare(ctx_id=0, det_size=(640, 640)) + #input_folder_name = './scarletthead_woman' + input_folder_name = sys.argv[1] + image_basename_list = os.listdir(input_folder_name) + image_path_list = sorted([os.path.join(input_folder_name, basename) for basename in image_basename_list]) + + input_id_images = [] + for image_path in image_path_list: + input_id_images.append(load_image(image_path)) + + id_embed_list = [] + + for img in input_id_images: + img = np.array(img) + img = img[:, :, ::-1] + faces = analyze_faces(face_detector, img) + if len(faces) > 0: + id_embed_list.append(torch.from_numpy((faces[0]['embedding']))) + + if len(id_embed_list) == 0: + raise ValueError(f"No face detected in input image pool") + + id_embeds = torch.stack(id_embed_list) + + # for r in id_embeds: + # print(r) + # #torch.save(id_embeds, input_folder_name+'/id_embeds.pt'); + # weights = dict() + # weights["id_embeds"] = id_embeds + # save_file(weights, input_folder_name+'/id_embeds.safetensors') + + binary_data = id_embeds.numpy().tobytes() + two = 4 + zero = 0 + one = 1 + tensor_name = "id_embeds" +# Write binary data to a file + with open(input_folder_name+'/id_embeds.bin', "wb") as f: + f.write(two.to_bytes(4, byteorder='little')) + f.write((len(tensor_name)).to_bytes(4, byteorder='little')) + f.write(zero.to_bytes(4, byteorder='little')) + f.write((id_embeds.shape[1]).to_bytes(4, byteorder='little')) + f.write((id_embeds.shape[0]).to_bytes(4, byteorder='little')) + f.write(one.to_bytes(4, byteorder='little')) + f.write(one.to_bytes(4, byteorder='little')) + f.write(tensor_name.encode('ascii')) + f.write(binary_data) + + \ No newline at end of file diff --git a/ggml_extend.hpp b/ggml_extend.hpp index e50137d..8dea410 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -1047,6 +1047,11 @@ public: params_buffer_size / (1024.0 * 1024.0), ggml_backend_is_cpu(backend) ? "RAM" : "VRAM", num_tensors); + // printf("%s params backend buffer size = % 6.2f MB(%s) (%i tensors)\n", + // get_desc().c_str(), + // params_buffer_size / (1024.0 * 1024.0), + // ggml_backend_is_cpu(backend) ? "RAM" : "VRAM", + // num_tensors); return true; } @@ -1216,7 +1221,8 @@ protected: params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features); if (bias) { params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_features); - } + } + } public: diff --git a/model.cpp b/model.cpp index ae1c097..5f1e6e1 100644 --- a/model.cpp +++ b/model.cpp @@ -146,6 +146,33 @@ std::unordered_map vae_decoder_name_map = { {"first_stage_model.decoder.mid.attn_1.to_v.weight", "first_stage_model.decoder.mid.attn_1.v.weight"}, }; +std::unordered_map pmid_v2_name_map = { + {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.token_proj.0.bias", + "pmid.qformer_perceiver.token_proj.fc1.bias"}, + {"pmid.qformer_perceiver.token_proj.2.bias", + "pmid.qformer_perceiver.token_proj.fc2.bias"}, + {"pmid.qformer_perceiver.token_proj.0.weight", + "pmid.qformer_perceiver.token_proj.fc1.weight"}, + {"pmid.qformer_perceiver.token_proj.2.weight", + "pmid.qformer_perceiver.token_proj.fc2.weight"}, +}; + std::string convert_open_clip_to_hf_clip(const std::string& name) { std::string new_name = name; std::string prefix; @@ -212,6 +239,13 @@ std::string convert_vae_decoder_name(const std::string& name) { return name; } +std::string convert_pmid_v2_name(const std::string& name) { + if (pmid_v2_name_map.find(name) != pmid_v2_name_map.end()) { + return pmid_v2_name_map[name]; + } + return name; +} + /* If not a SDXL LoRA the unet" prefix will have already been replaced by this * point and "te2" and "te1" don't seem to appear in non-SDXL only "te_" */ std::string convert_sdxl_lora_name(std::string tensor_name) { @@ -443,6 +477,8 @@ std::string convert_tensor_name(std::string name) { new_name = convert_open_clip_to_hf_clip(name); } else if (starts_with(name, "first_stage_model.decoder")) { new_name = convert_vae_decoder_name(name); + } else if (starts_with(name, "pmid.qformer_perceiver")) { + new_name = convert_pmid_v2_name(name); } else if (starts_with(name, "control_model.")) { // for controlnet pth models size_t pos = name.find('.'); if (pos != std::string::npos) { @@ -1015,7 +1051,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const } TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin); - tensor_storage.reverse_ne(); + tensor_storage.reverse_ne(); size_t tensor_data_size = end - begin; @@ -1362,7 +1398,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer, reader.tensor_storage.reverse_ne(); reader.tensor_storage.file_index = file_index; // if(strcmp(prefix.c_str(), "scarlett") == 0) - // printf(" got tensor %s \n ", reader.tensor_storage.name.c_str()); + // printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str()); reader.tensor_storage.name = prefix + reader.tensor_storage.name; tensor_storages.push_back(reader.tensor_storage); // LOG_DEBUG("%s", reader.tensor_storage.name.c_str()); @@ -1398,7 +1434,9 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s std::string name = zip_entry_name(zip); size_t pos = name.find("data.pkl"); if (pos != std::string::npos) { + std::string dir = name.substr(0, pos); + printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str()); void* pkl_data = NULL; size_t pkl_size; zip_entry_read(zip, &pkl_data, &pkl_size); diff --git a/model.h b/model.h index 924d697..77841e8 100644 --- a/model.h +++ b/model.h @@ -31,6 +31,11 @@ enum SDVersion { VERSION_COUNT, }; +enum PMVersion { + VERSION_1, + VERSION_2, +}; + struct TensorStorage { std::string name; ggml_type type = GGML_TYPE_F32; @@ -162,6 +167,7 @@ public: bool load_tensors(std::map& tensors, ggml_backend_t backend, std::set ignore_tensors = {}); + bool save_to_gguf_file(const std::string& file_path, ggml_type type); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); diff --git a/pmid.hpp b/pmid.hpp index 381050f..bde03cc 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -6,6 +6,7 @@ #include "clip.hpp" #include "lora.hpp" + struct FuseBlock : public GGMLBlock { // network hparams int in_dim; @@ -42,6 +43,383 @@ public: } }; +/* +class QFormerPerceiver(nn.Module): + def __init__(self, id_embeddings_dim, cross_attention_dim, num_tokens, embedding_dim=1024, use_residual=True, ratio=4): + super().__init__() + + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.use_residual = use_residual + print(cross_attention_dim*num_tokens) + self.token_proj = nn.Sequential( + nn.Linear(id_embeddings_dim, id_embeddings_dim*ratio), + nn.GELU(), + nn.Linear(id_embeddings_dim*ratio, cross_attention_dim*num_tokens), + ) + self.token_norm = nn.LayerNorm(cross_attention_dim) + self.perceiver_resampler = FacePerceiverResampler( + dim=cross_attention_dim, + depth=4, + dim_head=128, + heads=cross_attention_dim // 128, + embedding_dim=embedding_dim, + output_dim=cross_attention_dim, + ff_mult=4, + ) + + def forward(self, x, last_hidden_state): + x = self.token_proj(x) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.token_norm(x) # cls token + out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens + if self.use_residual: # TODO: if use_residual is not true + out = x + 1.0 * out + return out +*/ + + +struct PMFeedForward : public GGMLBlock { + // network hparams + int dim; + +public: + PMFeedForward(int d, int multi=4) + : dim(d) { + int inner_dim = dim * multi; + blocks["0"] = std::shared_ptr(new LayerNorm(dim)); + blocks["1"] = std::shared_ptr(new Mlp(dim, inner_dim, dim, false)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x){ + + auto norm = std::dynamic_pointer_cast(blocks["0"]); + auto ff = std::dynamic_pointer_cast(blocks["1"]); + + x = norm->forward(ctx, x); + x = ff->forward(ctx, x); + return x; + } + +}; + +struct PerceiverAttention : public GGMLBlock { + // network hparams + float scale; // = dim_head**-0.5 + int dim_head; // = dim_head + int heads; // = heads +public: + PerceiverAttention(int dim, int dim_h=64, int h=8) + : scale(powf(dim_h, -0.5)), dim_head(dim_h), heads(h) { + + int inner_dim = dim_head * heads; + blocks["norm1"] = std::shared_ptr(new LayerNorm(dim)); + blocks["norm2"] = std::shared_ptr(new LayerNorm(dim)); + blocks["to_q"] = std::shared_ptr(new Linear(dim, inner_dim, false)); + blocks["to_kv"] = std::shared_ptr(new Linear(dim, inner_dim*2, false)); + blocks["to_out"] = std::shared_ptr(new Linear(inner_dim, dim, false)); + } + + struct ggml_tensor* reshape_tensor(struct ggml_context* ctx, + struct ggml_tensor* x, + int heads) { + int64_t ne[4]; + for(int i = 0; i < 4; ++i) + ne[i] = x->ne[i]; + // print_ggml_tensor(x, true, "PerceiverAttention reshape x 0: "); + // printf("heads = %d \n", heads); + // x = ggml_view_4d(ctx, x, x->ne[0], x->ne[1], heads, x->ne[2]/heads, + // x->nb[1], x->nb[2], x->nb[3], 0); + x = ggml_reshape_4d(ctx, x, x->ne[0]/heads, heads, x->ne[1], x->ne[2]); + // x = ggml_view_4d(ctx, x, x->ne[0]/heads, heads, x->ne[1], x->ne[2], + // x->nb[1], x->nb[2], x->nb[3], 0); + // x = ggml_cont(ctx, x); + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); + // print_ggml_tensor(x, true, "PerceiverAttention reshape x 1: "); + // x = ggml_reshape_4d(ctx, x, ne[0], heads, ne[1], ne[2]/heads); + return x; + } + + std::vector chunk_half(struct ggml_context* ctx, + struct ggml_tensor* x){ + + auto tlo = ggml_view_4d(ctx, x, x->ne[0]/2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0); + auto tli = ggml_view_4d(ctx, x, x->ne[0]/2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], x->nb[0]*x->ne[0]/2); + return {ggml_cont(ctx, tlo), + ggml_cont(ctx, tli)}; + + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* latents){ + + // x (torch.Tensor): image features + // shape (b, n1, D) + // latent (torch.Tensor): latent features + // shape (b, n2, D) + int64_t ne[4]; + for(int i = 0; i < 4; ++i) + ne[i] = latents->ne[i]; + + auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); + auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); + x = norm1->forward(ctx, x); + latents = norm2->forward(ctx, latents); + auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); + auto q = to_q->forward(ctx, latents); + + auto kv_input = ggml_concat(ctx, x, latents, 1); + auto to_kv = std::dynamic_pointer_cast(blocks["to_kv"]); + auto kv = to_kv->forward(ctx, kv_input); + auto k = ggml_view_4d(ctx, kv, kv->ne[0]/2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1]/2, kv->nb[2]/2, kv->nb[3]/2, 0); + auto v = ggml_view_4d(ctx, kv, kv->ne[0]/2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1]/2, kv->nb[2]/2, kv->nb[3]/2, kv->nb[0]*(kv->ne[0]/2)); + k = ggml_cont(ctx, k); + v = ggml_cont(ctx, v); + q = reshape_tensor(ctx, q, heads); + k = reshape_tensor(ctx, k, heads); + v = reshape_tensor(ctx, v, heads); + scale = 1.f / sqrt(sqrt((float)dim_head)); + k = ggml_scale_inplace(ctx, k, scale); + q = ggml_scale_inplace(ctx, q, scale); + // auto weight = ggml_mul_mat(ctx, q, k); + auto weight = ggml_mul_mat(ctx, k, q); // NOTE order of mul is opposite to pytorch + + // GGML's softmax() is equivalent to pytorch's softmax(x, dim=-1) + // in this case, dimension along which Softmax will be computed is the last dim + // in torch and the first dim in GGML, consistent with the convention that pytorch's + // last dimension (varying most rapidly) corresponds to GGML's first (varying most rapidly). + // weight = ggml_soft_max(ctx, weight); + weight = ggml_soft_max_inplace(ctx, weight); + v = ggml_cont(ctx, ggml_transpose(ctx, v)); + // auto out = ggml_mul_mat(ctx, weight, v); + auto out = ggml_mul_mat(ctx, v, weight); // NOTE order of mul is opposite to pytorch + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); + out = ggml_reshape_3d(ctx, out, ne[0], ne[1], ggml_nelements(out)/(ne[0]*ne[1])); + auto to_out = std::dynamic_pointer_cast(blocks["to_out"]); + out = to_out->forward(ctx, out); + return out; + } +}; + +struct FacePerceiverResampler : public GGMLBlock { + // network hparams + int depth; +public: + FacePerceiverResampler( int dim=768, + int d=4, + int dim_head=64, + int heads=16, + int embedding_dim=1280, + int output_dim=768, + int ff_mult=4) + : depth(d) { + blocks["proj_in"] = std::shared_ptr(new Linear(embedding_dim, dim, true)); + blocks["proj_out"] = std::shared_ptr(new Linear(dim, output_dim, true)); + blocks["norm_out"] = std::shared_ptr(new LayerNorm(output_dim)); + + for (int i = 0; i < depth; i++) { + std::string name = "layers." + std::to_string(i) + ".0"; + blocks[name] = std::shared_ptr(new PerceiverAttention(dim, dim_head, heads)); + name = "layers." + std::to_string(i) + ".1"; + blocks[name] = std::shared_ptr(new PMFeedForward(dim, ff_mult)); + } + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* latents, + struct ggml_tensor* x){ + // x: [N, channels, h, w] + auto proj_in = std::dynamic_pointer_cast(blocks["proj_in"]); + auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); + auto norm_out = std::dynamic_pointer_cast(blocks["norm_out"]); + + x = proj_in->forward(ctx, x); + for (int i = 0; i < depth; i++) { + std::string name = "layers." + std::to_string(i) + ".0"; + auto attn = std::dynamic_pointer_cast(blocks[name]); + name = "layers." + std::to_string(i) + ".1"; + auto ff = std::dynamic_pointer_cast(blocks[name]); + auto t = attn->forward(ctx, x, latents); + latents = ggml_add(ctx, t, latents); + t = ff->forward(ctx, latents); + latents = ggml_add(ctx, t, latents); + } + latents = proj_out->forward(ctx, latents); + latents = norm_out->forward(ctx, latents); + return latents; + } +}; + +struct QFormerPerceiver : public GGMLBlock { + // network hparams + int num_tokens; + int cross_attention_dim; + bool use_residul; + + +public: + QFormerPerceiver(int id_embeddings_dim, int cross_attention_d, int num_t, int embedding_dim=1024, + bool use_r=true, int ratio=4) + : cross_attention_dim(cross_attention_d), num_tokens(num_t), use_residul(use_r) { + blocks["token_proj"] = std::shared_ptr(new Mlp(id_embeddings_dim, + id_embeddings_dim*ratio, + cross_attention_dim*num_tokens, + true)); + blocks["token_norm"] = std::shared_ptr(new LayerNorm(cross_attention_d)); + blocks["perceiver_resampler"] = std::shared_ptr(new FacePerceiverResampler( + cross_attention_dim, + 4, + 128, + cross_attention_dim / 128, + embedding_dim, + cross_attention_dim, + 4)); + } + + /* + def forward(self, x, last_hidden_state): + x = self.token_proj(x) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.token_norm(x) # cls token + out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens + if self.use_residual: # TODO: if use_residual is not true + out = x + 1.0 * out + return out + */ + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* last_hidden_state){ + // x: [N, channels, h, w] + auto token_proj = std::dynamic_pointer_cast(blocks["token_proj"]); + auto token_norm = std::dynamic_pointer_cast(blocks["token_norm"]); + auto perceiver_resampler = std::dynamic_pointer_cast(blocks["perceiver_resampler"]); + + x = token_proj->forward(ctx, x); + int64_t nel = ggml_nelements(x); + x = ggml_reshape_3d(ctx, x, cross_attention_dim, num_tokens, nel/(cross_attention_dim*num_tokens)); + x = token_norm->forward(ctx, x); + struct ggml_tensor* out = perceiver_resampler->forward(ctx, x, last_hidden_state); + if(use_residul) + out = ggml_add(ctx, x, out); + return out; + } +}; + +/* +class FacePerceiverResampler(torch.nn.Module): + def __init__( + self, + *, + dim=768, + depth=4, + dim_head=64, + heads=16, + embedding_dim=1280, + output_dim=768, + ff_mult=4, + ): + super().__init__() + + self.proj_in = torch.nn.Linear(embedding_dim, dim) + self.proj_out = torch.nn.Linear(dim, output_dim) + self.norm_out = torch.nn.LayerNorm(output_dim) + self.layers = torch.nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + torch.nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, latents, x): + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + latents = self.proj_out(latents) + return self.norm_out(latents) +*/ + + + +/* + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + +*/ + + + + struct FuseModule : public GGMLBlock { // network hparams int embed_dim; @@ -61,12 +439,19 @@ public: auto mlp2 = std::dynamic_pointer_cast(blocks["mlp2"]); auto layer_norm = std::dynamic_pointer_cast(blocks["layer_norm"]); - auto prompt_embeds0 = ggml_cont(ctx, ggml_permute(ctx, prompt_embeds, 2, 0, 1, 3)); - auto id_embeds0 = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3)); - // concat is along dim 2 - auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds0, id_embeds0, 2); - stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 1, 2, 0, 3)); + // print_ggml_tensor(id_embeds, true, "Fuseblock id_embeds: "); + // print_ggml_tensor(prompt_embeds, true, "Fuseblock prompt_embeds: "); + // auto prompt_embeds0 = ggml_cont(ctx, ggml_permute(ctx, prompt_embeds, 2, 0, 1, 3)); + // auto id_embeds0 = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3)); + // print_ggml_tensor(id_embeds0, true, "Fuseblock id_embeds0: "); + // print_ggml_tensor(prompt_embeds0, true, "Fuseblock prompt_embeds0: "); + // concat is along dim 2 + // auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds0, id_embeds0, 2); + auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds, id_embeds, 0); + // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 0: "); + // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 1, 2, 0, 3)); + // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 1: "); // stacked_id_embeds = mlp1.forward(ctx, stacked_id_embeds); // stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds); // stacked_id_embeds = mlp2.forward(ctx, stacked_id_embeds); @@ -77,6 +462,8 @@ public: stacked_id_embeds = mlp2->forward(ctx, stacked_id_embeds); stacked_id_embeds = layer_norm->forward(ctx, stacked_id_embeds); + // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 1: "); + return stacked_id_embeds; } @@ -98,23 +485,31 @@ public: // print_ggml_tensor(class_tokens_mask_pos, true, "class_tokens_mask_pos"); struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos); ggml_set_name(image_token_embeds, "image_token_embeds"); - struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds); + valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0], + ggml_nelements(valid_id_embeds)/valid_id_embeds->ne[0]); + struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds); - stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); + // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); + // print_ggml_tensor(stacked_id_embeds, true, "AA stacked_id_embeds"); + // print_ggml_tensor(left, true, "AA left"); + // print_ggml_tensor(right, true, "AA right"); if (left && right) { - stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 2); - stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 2); + stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1); + stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1); } else if (left) { - stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 2); + stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1); } else if (right) { - stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 2); + stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1); } - stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); + // print_ggml_tensor(stacked_id_embeds, true, "BB stacked_id_embeds"); + // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); + // print_ggml_tensor(stacked_id_embeds, true, "CC stacked_id_embeds"); class_tokens_mask = ggml_cont(ctx, ggml_transpose(ctx, class_tokens_mask)); class_tokens_mask = ggml_repeat(ctx, class_tokens_mask, prompt_embeds); prompt_embeds = ggml_mul(ctx, prompt_embeds, class_tokens_mask); struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx, prompt_embeds, stacked_id_embeds); ggml_set_name(updated_prompt_embeds, "updated_prompt_embeds"); + // print_ggml_tensor(updated_prompt_embeds, true, "updated_prompt_embeds: "); return updated_prompt_embeds; } }; @@ -159,10 +554,79 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection { } }; +struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionModelProjection { + + int cross_attention_dim; + int num_tokens; + + PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock(int id_embeddings_dim=512) + : CLIPVisionModelProjection(OPENAI_CLIP_VIT_L_14), + cross_attention_dim (2048), + num_tokens(2) { + blocks["visual_projection_2"] = std::shared_ptr(new Linear(1024, 1280, false)); + blocks["fuse_module"] = std::shared_ptr(new FuseModule(2048)); + /* + cross_attention_dim = 2048 + # projection + self.num_tokens = 2 + self.cross_attention_dim = cross_attention_dim + self.qformer_perceiver = QFormerPerceiver( + id_embeddings_dim, + cross_attention_dim, + self.num_tokens, + )*/ + blocks["qformer_perceiver"] = std::shared_ptr(new QFormerPerceiver(id_embeddings_dim, + cross_attention_dim, + num_tokens)); + + } + + /* + def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + + last_hidden_state = self.vision_model(id_pixel_values)[0] + id_embeds = id_embeds.view(b * num_inputs, -1) + + id_embeds = self.qformer_perceiver(id_embeds, last_hidden_state) + id_embeds = id_embeds.view(b, num_inputs, self.num_tokens, -1) + updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) + */ + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* id_pixel_values, + struct ggml_tensor* prompt_embeds, + struct ggml_tensor* class_tokens_mask, + struct ggml_tensor* class_tokens_mask_pos, + struct ggml_tensor* id_embeds, + struct ggml_tensor* left, + struct ggml_tensor* right) { + // x: [N, channels, h, w] + auto vision_model = std::dynamic_pointer_cast(blocks["vision_model"]); + auto fuse_module = std::dynamic_pointer_cast(blocks["fuse_module"]); + auto qformer_perceiver = std::dynamic_pointer_cast(blocks["qformer_perceiver"]); + + // struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size] + struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false); // [N, hidden_size] + id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state); + + struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx, + prompt_embeds, + id_embeds, + class_tokens_mask, + class_tokens_mask_pos, + left, right); + return updated_prompt_embeds; + } +}; + struct PhotoMakerIDEncoder : public GGMLRunner { public: SDVersion version = VERSION_SDXL; + PMVersion pm_version = VERSION_1; PhotoMakerIDEncoderBlock id_encoder; + PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2; float style_strength; std::vector ctm; @@ -175,25 +639,41 @@ public: std::vector zeros_right; public: - PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, float sty = 20.f) + PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, + PMVersion pm_v = VERSION_1, float sty = 20.f) : GGMLRunner(backend, wtype), version(version), + pm_version(pm_v), style_strength(sty) { - id_encoder.init(params_ctx, wtype); + if(pm_version == VERSION_1){ + id_encoder.init(params_ctx, wtype); + }else if(pm_version == VERSION_2){ + id_encoder2.init(params_ctx, wtype); + } } std::string get_desc() { return "pmid"; } + PMVersion get_version() const{ + return pm_version; + } + + void get_param_tensors(std::map& tensors, const std::string prefix) { - id_encoder.get_param_tensors(tensors, prefix); + if(pm_version == VERSION_1) + id_encoder.get_param_tensors(tensors, prefix); + else if(pm_version == VERSION_2) + id_encoder2.get_param_tensors(tensors, prefix); + } struct ggml_cgraph* build_graph( // struct ggml_allocr* allocr, struct ggml_tensor* id_pixel_values, struct ggml_tensor* prompt_embeds, - std::vector& class_tokens_mask) { + std::vector& class_tokens_mask, + struct ggml_tensor* id_embeds) { ctm.clear(); ctmf16.clear(); ctmpos.clear(); @@ -214,25 +694,32 @@ public: struct ggml_tensor* id_pixel_values_d = to_backend(id_pixel_values); struct ggml_tensor* prompt_embeds_d = to_backend(prompt_embeds); + struct ggml_tensor* id_embeds_d = to_backend(id_embeds); struct ggml_tensor* left = NULL; struct ggml_tensor* right = NULL; for (int i = 0; i < class_tokens_mask.size(); i++) { if (class_tokens_mask[i]) { + // printf(" 1,"); ctm.push_back(0.f); // here use 0.f instead of 1.f to make a scale mask ctmf16.push_back(ggml_fp32_to_fp16(0.f)); // here use 0.f instead of 1.f to make a scale mask ctmpos.push_back(i); } else { + // printf(" 0,"); ctm.push_back(1.f); // here use 1.f instead of 0.f to make a scale mask ctmf16.push_back(ggml_fp32_to_fp16(1.f)); // here use 0.f instead of 1.f to make a scale mask } } + // printf("\n"); if (ctmpos[0] > 0) { - left = ggml_new_tensor_3d(ctx0, type, hidden_size, 1, ctmpos[0]); + // left = ggml_new_tensor_3d(ctx0, type, hidden_size, 1, ctmpos[0]); + left = ggml_new_tensor_3d(ctx0, type, hidden_size, ctmpos[0], 1); } if (ctmpos[ctmpos.size() - 1] < seq_length - 1) { + // right = ggml_new_tensor_3d(ctx0, type, + // hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1); right = ggml_new_tensor_3d(ctx0, type, - hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1); + hidden_size, seq_length - ctmpos[ctmpos.size() - 1] - 1, 1); } struct ggml_tensor* class_tokens_mask_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctmpos.size()); @@ -265,12 +752,23 @@ public: } } } - struct ggml_tensor* updated_prompt_embeds = id_encoder.forward(ctx0, - id_pixel_values_d, - prompt_embeds_d, - class_tokens_mask_d, - class_tokens_mask_pos, - left, right); + struct ggml_tensor* updated_prompt_embeds = NULL; + if(pm_version == VERSION_1) + updated_prompt_embeds = id_encoder.forward(ctx0, + id_pixel_values_d, + prompt_embeds_d, + class_tokens_mask_d, + class_tokens_mask_pos, + left, right); + else if(pm_version == VERSION_2) + updated_prompt_embeds = id_encoder2.forward(ctx0, + id_pixel_values_d, + prompt_embeds_d, + class_tokens_mask_d, + class_tokens_mask_pos, + id_embeds_d, + left, right); + ggml_build_forward_expand(gf, updated_prompt_embeds); return gf; @@ -279,12 +777,13 @@ public: void compute(const int n_threads, struct ggml_tensor* id_pixel_values, struct ggml_tensor* prompt_embeds, + struct ggml_tensor* id_embeds, std::vector& class_tokens_mask, struct ggml_tensor** updated_prompt_embeds, ggml_context* output_ctx) { auto get_graph = [&]() -> struct ggml_cgraph* { // return build_graph(compute_allocr, id_pixel_values, prompt_embeds, class_tokens_mask); - return build_graph(id_pixel_values, prompt_embeds, class_tokens_mask); + return build_graph(id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds); }; // GGMLRunner::compute(get_graph, n_threads, updated_prompt_embeds); @@ -292,4 +791,79 @@ public: } }; + +struct PhotoMakerIDEmbed : public GGMLRunner { + + std::map tensors; + std::string file_path; + ModelLoader *model_loader; + bool load_failed = false; + bool applied = false; + + PhotoMakerIDEmbed(ggml_backend_t backend, + ggml_type wtype, + ModelLoader *ml, + const std::string& file_path = "", + const std::string& prefix = "") + : file_path(file_path), GGMLRunner(backend, wtype), + model_loader(ml) { + if (!model_loader->init_from_file(file_path, prefix)) { + load_failed = true; + } + } + + std::string get_desc() { + return "id_embeds"; + } + + bool load_from_file(bool filter_tensor = false) { + LOG_INFO("loading PhotoMaker ID Embeds from '%s'", file_path.c_str()); + + if (load_failed) { + LOG_ERROR("init photomaker id embed from file failed: '%s'", file_path.c_str()); + return false; + } + + bool dry_run = true; + auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { + const std::string& name = tensor_storage.name; + + if (filter_tensor && !contains(name, "pmid.id_embeds")) { + // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str()); + return true; + } + if (dry_run) { + struct ggml_tensor* real = ggml_new_tensor(params_ctx, + tensor_storage.type, + tensor_storage.n_dims, + tensor_storage.ne); + tensors[name] = real; + } else { + auto real = tensors[name]; + *dst_tensor = real; + } + + return true; + }; + + model_loader->load_tensors(on_new_tensor_cb, backend); + alloc_params_buffer(); + + dry_run = false; + model_loader->load_tensors(on_new_tensor_cb, backend); + + LOG_DEBUG("finished loading PhotoMaker ID Embeds "); + return true; + } + + + struct ggml_tensor* get(){ + std::map::iterator pos; + pos = tensors.find("pmid.id_embeds"); + if(pos != tensors.end()) + return pos->second; + return NULL; + } +}; + #endif // __PMI_HPP__ diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 2297cd3..d70dab1 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -95,6 +95,7 @@ public: std::shared_ptr control_net; std::shared_ptr pmid_model; std::shared_ptr pmid_lora; + std::shared_ptr pmid_id_embeds; std::string taesd_path; bool use_tiny_autoencoder = false; @@ -331,7 +332,11 @@ public: cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); } else { - cond_stage_model = std::make_shared(clip_backend, conditioner_wtype, embeddings_path, version); + if(id_embeddings_path.find("v2") != std::string::npos) { + cond_stage_model = std::make_shared(clip_backend, conditioner_wtype, embeddings_path, version, VERSION_2); + }else{ + cond_stage_model = std::make_shared(clip_backend, conditioner_wtype, embeddings_path, version); + } diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); } cond_stage_model->alloc_params_buffer(); @@ -366,7 +371,12 @@ public: control_net = std::make_shared(controlnet_backend, diffusion_model_wtype, version); } - pmid_model = std::make_shared(clip_backend, model_wtype, version); + if(id_embeddings_path.find("v2") != std::string::npos) { + pmid_model = std::make_shared(backend, model_wtype, version, VERSION_2); + LOG_INFO("using PhotoMaker Version 2"); + } else { + pmid_model = std::make_shared(backend, model_wtype, version); + } if (id_embeddings_path.size() > 0) { pmid_lora = std::make_shared(backend, model_wtype, id_embeddings_path, ""); if (!pmid_lora->load_from_file(true)) { @@ -385,14 +395,8 @@ public: LOG_ERROR(" pmid model params buffer allocation failed"); return false; } - // LOG_INFO("pmid param memory buffer size = %.2fMB ", - // pmid_model->params_buffer_size / 1024.0 / 1024.0); pmid_model->get_param_tensors(tensors, "pmid"); } - // if(stacked_id){ - // pmid_model.init_params(GGML_TYPE_F32); - // pmid_model.map_by_name(tensors, "pmid."); - // } } struct ggml_init_params params; @@ -675,10 +679,10 @@ public: ggml_tensor* id_encoder(ggml_context* work_ctx, ggml_tensor* init_img, ggml_tensor* prompts_embeds, + ggml_tensor* id_embeds, std::vector& class_tokens_mask) { ggml_tensor* res = NULL; - pmid_model->compute(n_threads, init_img, prompts_embeds, class_tokens_mask, &res, work_ctx); - + pmid_model->compute(n_threads, init_img, prompts_embeds, id_embeds, class_tokens_mask, &res, work_ctx); return res; } @@ -1207,11 +1211,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, } // preprocess input id images std::vector input_id_images; + bool pmv2 = sd_ctx->sd->pmid_model->get_version() == VERSION_2; if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) { std::vector img_files = get_files_from_dir(input_id_images_path); for (std::string img_file : img_files) { int c = 0; int width, height; + if(ends_with(img_file, "safetensors")){ + continue; + } uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3); if (input_image_buffer == NULL) { LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str()); @@ -1259,8 +1267,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, sd_ctx->sd->diffusion_model->get_adm_in_channels()); id_cond = std::get<0>(cond_tup); class_tokens_mask = std::get<1>(cond_tup); // - - id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, class_tokens_mask); + struct ggml_tensor* id_embeds = NULL; + if(pmv2){ + // id_embeds = sd_ctx->sd->pmid_id_embeds->get(); + id_embeds = load_tensor_from_file(work_ctx, path_join(input_id_images_path, "id_embeds.bin")); + // print_ggml_tensor(id_embeds, true, "id_embeds:"); + } + id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask); t1 = ggml_time_ms(); LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0); if (sd_ctx->sd->free_params_immediately) { diff --git a/util.cpp b/util.cpp index 5de5ce2..cd058bb 100644 --- a/util.cpp +++ b/util.cpp @@ -276,6 +276,24 @@ std::string path_join(const std::string& p1, const std::string& p2) { return p1 + "/" + p2; } +std::vector splitString(const std::string& str, char delimiter) { + std::vector result; + size_t start = 0; + size_t end = str.find(delimiter); + + while (end != std::string::npos) { + result.push_back(str.substr(start, end - start)); + start = end + 1; + end = str.find(delimiter, start); + } + + // Add the last segment after the last delimiter + result.push_back(str.substr(start)); + + return result; +} + + sd_image_t* preprocess_id_image(sd_image_t* img) { int shortest_edge = 224; int size = shortest_edge; diff --git a/util.h b/util.h index 9b1e673..14fa812 100644 --- a/util.h +++ b/util.h @@ -45,7 +45,7 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size); std::string path_join(const std::string& p1, const std::string& p2); - +std::vector splitString(const std::string& str, char delimiter); void pretty_progress(int step, int steps, float time); void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...);