diff --git a/README.md b/README.md
index a067eb4..7b2961a 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors).
 - [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support
+- [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support
 - 16-bit, 32-bit float support
 - 4-bit, 5-bit and 8-bit integer quantization support
 - Accelerated memory-efficient CPU inference
@@ -151,7 +152,7 @@ cmake --build . --config Release
 ### Run
 
 ```
-usage: ./build/bin/sd [arguments]
+usage: ./bin/sd [arguments]
 
 arguments:
   -h, --help                         show this help message and exit
@@ -163,6 +164,9 @@ arguments:
   --taesd [TAESD_PATH]               path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
   --control-net [CONTROL_PATH]       path to control net model
   --embd-dir [EMBEDDING_PATH]        path to embeddings.
+  --stacked-id-embd-dir [DIR]        path to PHOTOMAKER stacked id embeddings.
+  --input-id-images-dir [DIR]        path to PHOTOMAKER input id images dir.
+  --normalize-input                  normalize PHOTOMAKER input id images
   --upscale-model [ESRGAN_PATH]      path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.
   --upscale-repeats                  Run the ESRGAN upscaler this many times (default 1)
   --type [TYPE]                      weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)
@@ -175,6 +179,7 @@ arguments:
   -n, --negative-prompt PROMPT       the negative prompt (default: "")
   --cfg-scale SCALE                  unconditional guidance scale: (default: 7.0)
   --strength STRENGTH                strength for noising/unnoising (default: 0.75)
+  --style-ratio STYLE-RATIO          strength for keeping input identity (default: 20%)
   --control-strength STRENGTH        strength to apply Control Net (default: 0.9)
                                      1.0 corresponds to full destruction of information in init image
   -H, --height H                     image height, in pixel space (default: 512)
@@ -299,6 +304,39 @@ You can use ESRGAN to upscale the generated images. At the moment, only the [Rea
 sd -m ../models/v1-5-pruned-emaonly.safetensors -p "a lovely cat" --upscale-model ../models/RealESRGAN_x4plus_anime_6B.pth
 ```
 
+#### Using PhotoMaker to personalize image generation
+
+You can use [PhotoMaker](https://github.com/TencentARC/PhotoMaker) to personalize generated images with your own ID.
+
+**NOTE**: currently PhotoMaker **ONLY** works with **SDXL** (any SDXL model file will work).
+
+Download the PhotoMaker model file (in safetensors format) [here](https://huggingface.co/bssrdf/PhotoMaker). The official release of the model file (in .bin format) does not work with ```stable-diffusion.cpp```.
+
+- Specify the PhotoMaker model path using the `--stacked-id-embd-dir PATH` parameter.
+- Specify the input images path using the `--input-id-images-dir PATH` parameter.
+    - input images **must** all have the same width and height for preprocessing (to be improved)
+
+In the prompt, make sure you have a class word followed by the trigger word ```"img"``` (hard-coded for now). The class word can be one of ```"man, woman, girl, boy"```. If the input ID images contain Asian faces, add ```Asian``` before the class word.
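+
+As an illustration, a prompt such as the following keeps the class word `man` immediately before the trigger word `img`; the rest of the wording is arbitrary and only shows where the two words go:
+
+```
+a man img, wearing a suit, upper body, studio portrait photo, best quality
+```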
+
+Another PhotoMaker-specific parameter:
+
+- ```--style-ratio (0-100)%```: defaults to 20; values of 10-20 typically give good results. A lower ratio follows the input ID more faithfully (which does not necessarily mean better quality).
+
+Other parameters recommended for running PhotoMaker:
+
+- ```--cfg-scale 5.0```
+- ```-H 1024```
+- ```-W 1024```
+
+On low-memory GPUs (<= 8GB), running with the ```--vae-on-cpu``` option is recommended to get artifact-free images.
+
+Example:
+
+```bash
+bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --stacked-id-embd-dir ../models/photomaker-v1.safetensors --input-id-images-dir ../assets/examples/scarletthead_woman -p "a girl img, retro futurism, retro game art style but extremely beautiful, intricate details, masterpiece, best quality, space-themed, cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed" -n "realistic, photo-realistic, worst quality, greyscale, bad anatomy, bad hands, error, text" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --style-ratio 10 --vae-on-cpu -o output.png
+```
+
 ### Docker
 
 #### Building using Docker
@@ -345,3 +383,4 @@ Thank you to all the people who have already contributed to stable-diffusion.cpp
 - [k-diffusion](https://github.com/crowsonkb/k-diffusion)
 - [latent-consistency-model](https://github.com/luosiallen/latent-consistency-model)
 - [generative-models](https://github.com/Stability-AI/generative-models/)
+- [PhotoMaker](https://github.com/TencentARC/PhotoMaker)
diff --git a/assets/photomaker_examples/lenna_woman/lenna.jpg b/assets/photomaker_examples/lenna_woman/lenna.jpg
new file mode 100644
index 0000000..ca3ef19
Binary files /dev/null and b/assets/photomaker_examples/lenna_woman/lenna.jpg differ
diff --git a/assets/photomaker_examples/newton_man/newton_0.jpg b/assets/photomaker_examples/newton_man/newton_0.jpg
new file mode 100644
index 0000000..71ba285
Binary files /dev/null and b/assets/photomaker_examples/newton_man/newton_0.jpg differ
diff --git a/assets/photomaker_examples/newton_man/newton_1.jpg b/assets/photomaker_examples/newton_man/newton_1.jpg
new file mode 100644
index 0000000..a59ed8c
Binary files /dev/null and b/assets/photomaker_examples/newton_man/newton_1.jpg differ
diff --git a/assets/photomaker_examples/newton_man/newton_2.png b/assets/photomaker_examples/newton_man/newton_2.png
new file mode 100644
index 0000000..d8d4b94
Binary files /dev/null and b/assets/photomaker_examples/newton_man/newton_2.png differ
diff --git a/assets/photomaker_examples/newton_man/newton_3.jpg b/assets/photomaker_examples/newton_man/newton_3.jpg
new file mode 100644
index 0000000..852867e
Binary files /dev/null and b/assets/photomaker_examples/newton_man/newton_3.jpg differ
diff --git a/assets/photomaker_examples/scarletthead_woman/scarlett_0.jpg b/assets/photomaker_examples/scarletthead_woman/scarlett_0.jpg
new file mode 100644
index 0000000..ce9435a
Binary files /dev/null and b/assets/photomaker_examples/scarletthead_woman/scarlett_0.jpg differ
diff --git a/assets/photomaker_examples/scarletthead_woman/scarlett_1.jpg b/assets/photomaker_examples/scarletthead_woman/scarlett_1.jpg
new file mode 100644
index 0000000..2326996
Binary files /dev/null and b/assets/photomaker_examples/scarletthead_woman/scarlett_1.jpg differ
diff --git a/assets/photomaker_examples/scarletthead_woman/scarlett_2.jpg b/assets/photomaker_examples/scarletthead_woman/scarlett_2.jpg
new file mode 100644
index 0000000..93ae735
Binary files /dev/null and
b/assets/photomaker_examples/scarletthead_woman/scarlett_2.jpg differ diff --git a/assets/photomaker_examples/scarletthead_woman/scarlett_3.jpg b/assets/photomaker_examples/scarletthead_woman/scarlett_3.jpg new file mode 100644 index 0000000..ccdca4b Binary files /dev/null and b/assets/photomaker_examples/scarletthead_woman/scarlett_3.jpg differ diff --git a/assets/photomaker_examples/yangmi_woman/yangmi_1.jpg b/assets/photomaker_examples/yangmi_woman/yangmi_1.jpg new file mode 100644 index 0000000..20fe66c Binary files /dev/null and b/assets/photomaker_examples/yangmi_woman/yangmi_1.jpg differ diff --git a/assets/photomaker_examples/yangmi_woman/yangmi_2.jpeg b/assets/photomaker_examples/yangmi_woman/yangmi_2.jpeg new file mode 100644 index 0000000..9ed4743 Binary files /dev/null and b/assets/photomaker_examples/yangmi_woman/yangmi_2.jpeg differ diff --git a/assets/photomaker_examples/yangmi_woman/yangmi_3.jpg b/assets/photomaker_examples/yangmi_woman/yangmi_3.jpg new file mode 100644 index 0000000..e840c1c Binary files /dev/null and b/assets/photomaker_examples/yangmi_woman/yangmi_3.jpg differ diff --git a/assets/photomaker_examples/yangmi_woman/yangmi_4.jpg b/assets/photomaker_examples/yangmi_woman/yangmi_4.jpg new file mode 100644 index 0000000..f436011 Binary files /dev/null and b/assets/photomaker_examples/yangmi_woman/yangmi_4.jpg differ diff --git a/assets/photomaker_examples/yangmi_woman/yangmi_5.jpg b/assets/photomaker_examples/yangmi_woman/yangmi_5.jpg new file mode 100644 index 0000000..95e7714 Binary files /dev/null and b/assets/photomaker_examples/yangmi_woman/yangmi_5.jpg differ diff --git a/assets/photomaker_examples/yangmi_woman/yangmi_6.jpg b/assets/photomaker_examples/yangmi_woman/yangmi_6.jpg new file mode 100644 index 0000000..8c7c442 Binary files /dev/null and b/assets/photomaker_examples/yangmi_woman/yangmi_6.jpg differ diff --git a/clip.hpp b/clip.hpp index 2ba12da..42bfd08 100644 --- a/clip.hpp +++ b/clip.hpp @@ -75,9 +75,13 @@ class CLIPTokenizer { private: SDVersion version = VERSION_1_x; std::map byte_encoder; + std::map byte_decoder; std::map encoder; + std::map decoder; std::map, int> bpe_ranks; std::regex pat; + int encoder_len; + int bpe_len; static std::string strip(const std::string& str) { std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); @@ -118,7 +122,11 @@ public: void load_from_merges(const std::string& merges_utf8_str) { auto byte_unicode_pairs = bytes_to_unicode(); - byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + // printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size()); + byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + for (auto& pair : byte_unicode_pairs) { + byte_decoder[pair.second] = pair.first; + } // for (auto & pair: byte_unicode_pairs) { // std::cout << pair.first << ": " << pair.second << std::endl; // } @@ -138,6 +146,8 @@ public: size_t space_pos = merge.find(' '); merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); + // printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(), + // utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); } std::vector vocab; for (const auto& pair : byte_unicode_pairs) { @@ -154,15 +164,36 @@ public: LOG_DEBUG("vocab size: %llu", vocab.size()); int i = 0; for (const auto& token : vocab) { - encoder[token] = i++; + encoder[token] = i; + decoder[i] = token; 
+ i++; + } + encoder_len = i; + + auto it = encoder.find(utf8_to_utf32("img")); + if (it != encoder.end()) { + LOG_DEBUG(" trigger word img already in vocab"); + } else { + LOG_DEBUG(" trigger word img not in vocab yet"); } int rank = 0; for (const auto& merge : merge_pairs) { bpe_ranks[merge] = rank++; } + bpe_len = rank; }; + void add_token(const std::string& text) { + std::u32string token = utf8_to_utf32(text); + auto it = encoder.find(token); + if (it != encoder.end()) { + encoder[token] = encoder_len; + decoder[encoder_len] = token; + encoder_len++; + } + } + std::u32string bpe(const std::u32string& token) { std::vector word; @@ -243,6 +274,7 @@ public: size_t max_length = 0, bool padding = false) { std::vector tokens = encode(text, on_new_token_cb); + tokens.insert(tokens.begin(), BOS_TOKEN_ID); if (max_length > 0) { if (tokens.size() > max_length - 1) { @@ -259,9 +291,34 @@ public: } } } + return tokens; } + std::string decode(const std::vector& tokens) { + std::string text = ""; + for (int t : tokens) { + if (t == 49406 || t == 49407) + continue; + std::u32string ts = decoder[t]; + // printf("%d, %s \n", t, utf32_to_utf8(ts).c_str()); + std::string s = utf32_to_utf8(ts); + if (s.length() >= 4 && ends_with(s, "")) { + text += " " + s.replace(s.length() - 4, s.length() - 1, ""); + } else { + text += " " + s; + } + } + // std::vector bytes; + // for (auto c : text){ + // bytes.push_back(byte_decoder[c]); + // } + + // std::string s((char *)bytes.data()); + // std::string s = ""; + return trim(text); + } + std::vector encode(std::string text, on_new_token_cb_t on_new_token_cb) { std::string original_text = text; std::vector bpe_tokens; @@ -308,7 +365,8 @@ public: ss << "\"" << token << "\", "; } ss << "]"; - LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); + // LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); + // printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str()); return bpe_tokens; } }; @@ -469,7 +527,8 @@ public: : d_model(d_model), n_head(n_head), intermediate_size(intermediate_size) { - blocks["self_attn"] = std::shared_ptr(new MultiheadAttention(d_model, n_head)); + blocks["self_attn"] = std::shared_ptr(new MultiheadAttention(d_model, n_head, true)); + blocks["layer_norm1"] = std::shared_ptr(new LayerNorm(d_model)); blocks["layer_norm2"] = std::shared_ptr(new LayerNorm(d_model)); @@ -508,7 +567,7 @@ public: struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int clip_skip = -1, bool mask = true) { // x: [N, n_token, d_model] int layer_idx = n_layer - 1; - LOG_DEBUG("clip_skip %d", clip_skip); + // LOG_DEBUG("clip_skip %d", clip_skip); if (clip_skip > 0) { layer_idx = n_layer - clip_skip; } @@ -520,7 +579,7 @@ public: } std::string name = "layers." + std::to_string(i); auto layer = std::dynamic_pointer_cast(blocks[name]); - x = layer->forward(ctx, x); // [N, n_token, d_model] + x = layer->forward(ctx, x, mask); // [N, n_token, d_model] // LOG_DEBUG("layer %d", i); } return x; @@ -703,7 +762,7 @@ public: auto final_layer_norm = std::dynamic_pointer_cast(blocks["final_layer_norm"]); auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size] - x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true); + x = encoder->forward(ctx, x, return_pooled ? 
-1 : clip_skip, true); if (return_pooled || with_final_ln) { x = final_layer_norm->forward(ctx, x); } @@ -720,11 +779,6 @@ public: }; class CLIPVisionModel : public GGMLBlock { -protected: - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["visual_projection"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size); - } - public: // network hparams int32_t num_channels = 3; @@ -735,16 +789,14 @@ public: int32_t intermediate_size = 4096; int32_t n_head = 16; int32_t n_layer = 24; - int32_t projection_dim = 768; public: - CLIPVisionModel(CLIPVersion version = OPEN_CLIP_VIT_H_14) { + CLIPVisionModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14) { if (version == OPEN_CLIP_VIT_H_14) { hidden_size = 1280; intermediate_size = 5120; n_head = 16; n_layer = 32; - projection_dim = 1024; } else if (version == OPEN_CLIP_VIT_BIGG_14) { hidden_size = 1664; intermediate_size = 8192; @@ -758,9 +810,8 @@ public: blocks["post_layernorm"] = std::shared_ptr(new LayerNorm(hidden_size)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values) { + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values, bool return_pooled = true) { // pixel_values: [N, num_channels, image_size, image_size] - // return: // [N, projection_dim] auto embeddings = std::dynamic_pointer_cast(blocks["embeddings"]); auto pre_layernorm = std::dynamic_pointer_cast(blocks["pre_layernorm"]); auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); @@ -768,26 +819,60 @@ public: auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim] x = pre_layernorm->forward(ctx, x); - x = encoder->forward(ctx, x, -1, true); + x = encoder->forward(ctx, x, -1, false); x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size] - GGML_ASSERT(x->ne[2] == 1); - int64_t max_token_idx = 0; - ggml_tensor* pooled = ggml_view_1d(ctx, x, x->ne[0], x->nb[1] * max_token_idx); // assert N == 1 - auto visual_projection = params["visual_projection"]; - pooled = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, visual_projection)), pooled); - return pooled; // [N, projection_dim] + GGML_ASSERT(x->ne[3] == 1); + if (return_pooled) { + ggml_tensor* pooled = ggml_cont(ctx, ggml_view_2d(ctx, x, x->ne[0], x->ne[2], x->nb[2], 0)); + return pooled; // [N, hidden_size] + } else { + return x; // [N, n_token, hidden_size] + } + } +}; + +class CLIPProjection : public UnaryBlock { +protected: + int64_t in_features; + int64_t out_features; + bool transpose_weight; + + void init_params(struct ggml_context* ctx, ggml_type wtype) { + if (transpose_weight) { + LOG_ERROR("transpose_weight"); + params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features); + } else { + params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features); + } + } + +public: + CLIPProjection(int64_t in_features, + int64_t out_features, + bool transpose_weight = false) + : in_features(in_features), + out_features(out_features), + transpose_weight(transpose_weight) {} + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* w = params["weight"]; + if (transpose_weight) { + w = ggml_cont(ctx, ggml_transpose(ctx, w)); + } + return ggml_nn_linear(ctx, x, w, NULL); } }; class CLIPVisionModelProjection : public GGMLBlock { public: int32_t hidden_size = 1024; - int32_t projection_dim = 1024; + int32_t projection_dim = 768; int32_t image_size = 224; public: - CLIPVisionModelProjection(CLIPVersion version = 
OPEN_CLIP_VIT_H_14) { + CLIPVisionModelProjection(CLIPVersion version = OPENAI_CLIP_VIT_L_14, + bool transpose_proj_w = false) { if (version == OPEN_CLIP_VIT_H_14) { hidden_size = 1280; projection_dim = 1024; @@ -795,17 +880,17 @@ public: hidden_size = 1664; } - blocks["visual_model"] = std::shared_ptr(new CLIPVisionModel(version)); - blocks["visual_projection"] = std::shared_ptr(new Linear(hidden_size, projection_dim, false)); + blocks["vision_model"] = std::shared_ptr(new CLIPVisionModel(version)); + blocks["visual_projection"] = std::shared_ptr(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w)); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values) { // pixel_values: [N, num_channels, image_size, image_size] - // return: [N, num_positions, projection_dim] - auto visual_model = std::dynamic_pointer_cast(blocks["visual_model"]); - auto visual_projection = std::dynamic_pointer_cast(blocks["visual_projection"]); + // return: [N, projection_dim] + auto vision_model = std::dynamic_pointer_cast(blocks["vision_model"]); + auto visual_projection = std::dynamic_pointer_cast(blocks["visual_projection"]); - auto x = visual_model->forward(ctx, pixel_values); // [N, embed_dim] + auto x = vision_model->forward(ctx, pixel_values); // [N, hidden_size] x = visual_projection->forward(ctx, x); // [N, projection_dim] return x; // [N, projection_dim] @@ -1029,6 +1114,205 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { return tokenize(text, text_model.n_token, padding); } + std::tuple, std::vector, std::vector> + tokenize_with_trigger_token(std::string text, + int num_input_imgs, + int32_t image_token, + bool padding = false) { + return tokenize_with_trigger_token(text, num_input_imgs, image_token, + text_model.n_token, padding); + } + + std::vector convert_token_to_id(std::string text) { + auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool { + size_t word_end = str.find(","); + std::string embd_name = word_end == std::string::npos ? 
str : str.substr(0, word_end); + embd_name = trim(embd_name); + std::string embd_path = get_full_path(embd_dir, embd_name + ".pt"); + if (embd_path.size() == 0) { + embd_path = get_full_path(embd_dir, embd_name + ".ckpt"); + } + if (embd_path.size() == 0) { + embd_path = get_full_path(embd_dir, embd_name + ".safetensors"); + } + if (embd_path.size() > 0) { + if (load_embedding(embd_name, embd_path, bpe_tokens)) { + if (word_end != std::string::npos) { + str = str.substr(word_end); + } else { + str = ""; + } + return true; + } + } + return false; + }; + std::vector curr_tokens = tokenizer.encode(text, on_new_token_cb); + return curr_tokens; + } + + std::string decode(const std::vector& tokens) { + return tokenizer.decode(tokens); + } + + void pad_tokens(std::vector& tokens, + std::vector& weights, + size_t max_length = 0, + bool padding = false) { + if (max_length > 0 && padding) { + size_t n = std::ceil(tokens.size() * 1.0 / (max_length - 2)); + if (n == 0) { + n = 1; + } + size_t length = max_length * n; + LOG_DEBUG("token length: %llu", length); + std::vector new_tokens; + std::vector new_weights; + new_tokens.push_back(BOS_TOKEN_ID); + new_weights.push_back(1.0); + int token_idx = 0; + for (int i = 1; i < length; i++) { + if (token_idx >= tokens.size()) { + break; + } + if (i % max_length == 0) { + new_tokens.push_back(BOS_TOKEN_ID); + new_weights.push_back(1.0); + } else if (i % max_length == max_length - 1) { + new_tokens.push_back(EOS_TOKEN_ID); + new_weights.push_back(1.0); + } else { + new_tokens.push_back(tokens[token_idx]); + new_weights.push_back(weights[token_idx]); + token_idx++; + } + } + + new_tokens.push_back(EOS_TOKEN_ID); + new_weights.push_back(1.0); + tokens = new_tokens; + weights = new_weights; + + if (padding) { + int pad_token_id = PAD_TOKEN_ID; + if (version == VERSION_2_x) { + pad_token_id = 0; + } + tokens.insert(tokens.end(), length - tokens.size(), pad_token_id); + weights.insert(weights.end(), length - weights.size(), 1.0); + } + } + } + + std::tuple, std::vector, std::vector> + tokenize_with_trigger_token(std::string text, + int num_input_imgs, + int32_t image_token, + size_t max_length = 0, + bool padding = false) { + auto parsed_attention = parse_prompt_attention(text); + + { + std::stringstream ss; + ss << "["; + for (const auto& item : parsed_attention) { + ss << "['" << item.first << "', " << item.second << "], "; + } + ss << "]"; + LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); + } + + auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool { + size_t word_end = str.find(","); + std::string embd_name = word_end == std::string::npos ? 
str : str.substr(0, word_end); + embd_name = trim(embd_name); + std::string embd_path = get_full_path(embd_dir, embd_name + ".pt"); + if (embd_path.size() == 0) { + embd_path = get_full_path(embd_dir, embd_name + ".ckpt"); + } + if (embd_path.size() == 0) { + embd_path = get_full_path(embd_dir, embd_name + ".safetensors"); + } + if (embd_path.size() > 0) { + if (load_embedding(embd_name, embd_path, bpe_tokens)) { + if (word_end != std::string::npos) { + str = str.substr(word_end); + } else { + str = ""; + } + return true; + } + } + return false; + }; + + std::vector tokens; + std::vector weights; + std::vector class_token_mask; + int32_t class_idx = -1, tokens_acc = 0; + for (const auto& item : parsed_attention) { + std::vector class_token_index; + std::vector clean_input_ids; + const std::string& curr_text = item.first; + float curr_weight = item.second; + // printf(" %s: %f \n", curr_text.c_str(), curr_weight); + std::vector curr_tokens = tokenizer.encode(curr_text, on_new_token_cb); + int32_t clean_index = 0; + for (uint32_t i = 0; i < curr_tokens.size(); i++) { + int token_id = curr_tokens[i]; + if (token_id == image_token) + class_token_index.push_back(clean_index - 1); + else { + clean_input_ids.push_back(token_id); + clean_index++; + } + } + // GGML_ASSERT(class_token_index.size() == 1); // PhotoMaker currently does not support multiple + // trigger words in a single prompt. + if (class_token_index.size() == 1) { + // Expand the class word token and corresponding mask + int class_token = clean_input_ids[class_token_index[0]]; + class_idx = tokens_acc + class_token_index[0]; + std::vector clean_input_ids_tmp; + for (uint32_t i = 0; i < class_token_index[0]; i++) + clean_input_ids_tmp.push_back(clean_input_ids[i]); + for (uint32_t i = 0; i < num_input_imgs; i++) + clean_input_ids_tmp.push_back(class_token); + for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++) + clean_input_ids_tmp.push_back(clean_input_ids[i]); + clean_input_ids.clear(); + clean_input_ids = clean_input_ids_tmp; + } + tokens_acc += clean_index; + tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end()); + weights.insert(weights.end(), clean_input_ids.size(), curr_weight); + } + tokens.insert(tokens.begin(), BOS_TOKEN_ID); + weights.insert(weights.begin(), 1.0); + + pad_tokens(tokens, weights, max_length, padding); + + for (uint32_t i = 0; i < tokens.size(); i++) { + if (class_idx + 1 <= i && i < class_idx + 1 + num_input_imgs) + class_token_mask.push_back(true); + else + class_token_mask.push_back(false); + } + + // printf("["); + // for (int i = 0; i < tokens.size(); i++) { + // printf("%d, ", class_token_mask[i] ? 
1 : 0); + // } + // printf("]\n"); + + // for (int i = 0; i < tokens.size(); i++) { + // std::cout << tokens[i] << ":" << weights[i] << ", "; + // } + // std::cout << std::endl; + + return std::make_tuple(tokens, weights, class_token_mask); + } + std::pair, std::vector> tokenize(std::string text, size_t max_length = 0, bool padding = false) { @@ -1078,49 +1362,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { weights.insert(weights.end(), curr_tokens.size(), curr_weight); } - if (max_length > 0 && padding) { - size_t n = std::ceil(tokens.size() * 1.0 / (max_length - 2)); - if (n == 0) { - n = 1; - } - size_t length = max_length * n; - LOG_DEBUG("token length: %llu", length); - std::vector new_tokens; - std::vector new_weights; - new_tokens.push_back(BOS_TOKEN_ID); - new_weights.push_back(1.0); - int token_idx = 0; - for (int i = 1; i < length; i++) { - if (token_idx >= tokens.size()) { - break; - } - if (i % max_length == 0) { - new_tokens.push_back(BOS_TOKEN_ID); - new_weights.push_back(1.0); - } else if (i % max_length == max_length - 1) { - new_tokens.push_back(EOS_TOKEN_ID); - new_weights.push_back(1.0); - } else { - new_tokens.push_back(tokens[token_idx]); - new_weights.push_back(weights[token_idx]); - token_idx++; - } - } - - new_tokens.push_back(EOS_TOKEN_ID); - new_weights.push_back(1.0); - tokens = new_tokens; - weights = new_weights; - - if (padding) { - int pad_token_id = PAD_TOKEN_ID; - if (version == VERSION_2_x) { - pad_token_id = 0; - } - tokens.insert(tokens.end(), length - tokens.size(), pad_token_id); - weights.insert(weights.end(), length - weights.size(), 1.0); - } - } + pad_tokens(tokens, weights, max_length, padding); // for (int i = 0; i < tokens.size(); i++) { // std::cout << tokens[i] << ":" << weights[i] << ", "; @@ -1132,10 +1374,10 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { }; struct FrozenCLIPVisionEmbedder : public GGMLModule { - CLIPVisionModel vision_model; + CLIPVisionModelProjection vision_model; FrozenCLIPVisionEmbedder(ggml_backend_t backend, ggml_type wtype) - : GGMLModule(backend, wtype) { + : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLModule(backend, wtype) { vision_model.init(params_ctx, wtype); } @@ -1152,7 +1394,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLModule { } void get_param_tensors(std::map& tensors, const std::string prefix) { - vision_model.get_param_tensors(tensors, prefix + "transformer.visual_model"); + vision_model.get_param_tensors(tensors, prefix + "transformer"); } struct ggml_cgraph* build_graph(struct ggml_tensor* pixel_values) { diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 6b74b69..e5b082c 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -10,6 +10,7 @@ #include "stable-diffusion.h" #define STB_IMAGE_IMPLEMENTATION +#define STB_IMAGE_STATIC #include "stb_image.h" #define STB_IMAGE_WRITE_IMPLEMENTATION @@ -65,6 +66,8 @@ struct SDParams { std::string esrgan_path; std::string controlnet_path; std::string embeddings_path; + std::string stacked_id_embeddings_path; + std::string input_id_images_path; sd_type_t wtype = SD_TYPE_COUNT; std::string lora_model_dir; std::string output_path = "output.png"; @@ -73,12 +76,13 @@ struct SDParams { std::string prompt; std::string negative_prompt; - float min_cfg = 1.0f; - float cfg_scale = 7.0f; - int clip_skip = -1; // <= 0 represents unspecified - int width = 512; - int height = 512; - int batch_count = 1; + float min_cfg = 1.0f; + float cfg_scale = 7.0f; + float style_ratio = 20.f; + int clip_skip = -1; // <= 0 
represents unspecified + int width = 512; + int height = 512; + int batch_count = 1; int video_frames = 6; int motion_bucket_id = 127; @@ -95,6 +99,9 @@ struct SDParams { bool verbose = false; bool vae_tiling = false; bool control_net_cpu = false; + bool normalize_input = false; + bool clip_on_cpu = false; + bool vae_on_cpu = false; bool canny_preprocess = false; int upscale_repeats = 1; }; @@ -110,10 +117,16 @@ void print_params(SDParams params) { printf(" esrgan_path: %s\n", params.esrgan_path.c_str()); printf(" controlnet_path: %s\n", params.controlnet_path.c_str()); printf(" embeddings_path: %s\n", params.embeddings_path.c_str()); + printf(" stacked_id_embeddings_path: %s\n", params.stacked_id_embeddings_path.c_str()); + printf(" input_id_images_path: %s\n", params.input_id_images_path.c_str()); + printf(" style ratio: %.2f\n", params.style_ratio); + printf(" normzalize input image : %s\n", params.normalize_input ? "true" : "false"); printf(" output_path: %s\n", params.output_path.c_str()); printf(" init_img: %s\n", params.input_path.c_str()); printf(" control_image: %s\n", params.control_image_path.c_str()); + printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false"); printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false"); + printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false"); printf(" strength(control): %.2f\n", params.control_strength); printf(" prompt: %s\n", params.prompt.c_str()); printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); @@ -146,6 +159,9 @@ void print_usage(int argc, const char* argv[]) { printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); printf(" --control-net [CONTROL_PATH] path to control net model\n"); printf(" --embd-dir [EMBEDDING_PATH] path to embeddings.\n"); + printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings.\n"); + printf(" --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir.\n"); + printf(" --normalize-input normalize PHOTOMAKER input id images\n"); printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. 
Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n"); printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n"); printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n"); @@ -158,6 +174,7 @@ void print_usage(int argc, const char* argv[]) { printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n"); printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n"); printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n"); + printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n"); printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n"); printf(" 1.0 corresponds to full destruction of information in init image\n"); printf(" -H, --height H image height, in pixel space (default: 512)\n"); @@ -244,6 +261,18 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.embeddings_path = argv[i]; + } else if (arg == "--stacked-id-embd-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.stacked_id_embeddings_path = argv[i]; + } else if (arg == "--input-id-images-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.input_id_images_path = argv[i]; } else if (arg == "--type") { if (++i >= argc) { invalid_arg = true; @@ -327,6 +356,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.strength = std::stof(argv[i]); + } else if (arg == "--style-ratio") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.style_ratio = std::stof(argv[i]); } else if (arg == "--control-strength") { if (++i >= argc) { invalid_arg = true; @@ -361,6 +396,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.vae_tiling = true; } else if (arg == "--control-net-cpu") { params.control_net_cpu = true; + } else if (arg == "--normalize-input") { + params.normalize_input = true; + } else if (arg == "--clip-on-cpu") { + params.clip_on_cpu = true; // will slow down get_learned_condiotion but necessary for low MEM GPUs + } else if (arg == "--vae-on-cpu") { + params.vae_on_cpu = true; // will slow down latent decoding but necessary for low MEM GPUs } else if (arg == "--canny") { params.canny_preprocess = true; } else if (arg == "-b" || arg == "--batch-count") { @@ -613,6 +654,7 @@ int main(int argc, const char* argv[]) { params.controlnet_path.c_str(), params.lora_model_dir.c_str(), params.embeddings_path.c_str(), + params.stacked_id_embeddings_path.c_str(), vae_decode_only, params.vae_tiling, true, @@ -620,7 +662,9 @@ int main(int argc, const char* argv[]) { params.wtype, params.rng_type, params.schedule, - params.control_net_cpu); + params.clip_on_cpu, + params.control_net_cpu, + params.vae_on_cpu); if (sd_ctx == NULL) { printf("new_sd_ctx_t failed\n"); @@ -664,7 +708,10 @@ int main(int argc, const char* argv[]) { params.seed, params.batch_count, control_image, - params.control_strength); + params.control_strength, + params.style_ratio, + params.normalize_input, + params.input_id_images_path.c_str()); } else { sd_image_t input_image = {(uint32_t)params.width, (uint32_t)params.height, diff --git a/ggml_extend.hpp b/ggml_extend.hpp index c2d6552..25a7cdc 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -80,8 +80,27 @@ __STATIC_INLINE__ ggml_fp16_t ggml_tensor_get_f16(const ggml_tensor* tensor, int return *(ggml_fp16_t*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * 
tensor->nb[0]); } -__STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_only = false) { - printf("shape(%zu, %zu, %zu, %zu)\n", tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); +static struct ggml_tensor* get_tensor_from_graph(struct ggml_cgraph* gf, const char* name) { + struct ggml_tensor* res = NULL; + for (int i = 0; i < gf->n_nodes; i++) { + // printf("%d, %s \n", i, gf->nodes[i]->name); + if (strcmp(ggml_get_name(gf->nodes[i]), name) == 0) { + res = gf->nodes[i]; + break; + } + } + for (int i = 0; i < gf->n_leafs; i++) { + // printf("%d, %s \n", i, gf->leafs[i]->name); + if (strcmp(ggml_get_name(gf->leafs[i]), name) == 0) { + res = gf->leafs[i]; + break; + } + } + return res; +} + +__STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_only = false, const char* mark = "") { + printf("%s (%s): shape(%zu, %zu, %zu, %zu)\n", mark, ggml_type_name(tensor->type), tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); fflush(stdout); if (shape_only) { return; @@ -217,6 +236,23 @@ __STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input) { return image_data; } +__STATIC_INLINE__ uint8_t* sd_tensor_to_mul_image(struct ggml_tensor* input, int idx) { + int64_t width = input->ne[0]; + int64_t height = input->ne[1]; + int64_t channels = input->ne[2]; + GGML_ASSERT(channels == 3 && input->type == GGML_TYPE_F32); + uint8_t* image_data = (uint8_t*)malloc(width * height * channels); + for (int iy = 0; iy < height; iy++) { + for (int ix = 0; ix < width; ix++) { + for (int k = 0; k < channels; k++) { + float value = ggml_tensor_get_f32(input, ix, iy, k, idx); + *(image_data + iy * width * channels + ix * channels + k) = (uint8_t)(value * 255.0f); + } + } + } + return image_data; +} + __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data, struct ggml_tensor* output, bool scale = true) { @@ -237,6 +273,28 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data, } } +__STATIC_INLINE__ void sd_mul_images_to_tensor(const uint8_t* image_data, + struct ggml_tensor* output, + int idx, + float* mean = NULL, + float* std = NULL) { + int64_t width = output->ne[0]; + int64_t height = output->ne[1]; + int64_t channels = output->ne[2]; + GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32); + for (int iy = 0; iy < height; iy++) { + for (int ix = 0; ix < width; ix++) { + for (int k = 0; k < channels; k++) { + int value = *(image_data + iy * width * channels + ix * channels + k); + float pixel_val = value / 255.0f; + if (mean != NULL && std != NULL) + pixel_val = (pixel_val - mean[k]) / std[k]; + ggml_tensor_set_f32(output, pixel_val, ix, iy, k, idx); + } + } + } +} + __STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data, struct ggml_tensor* output, bool scale = true) { @@ -247,7 +305,7 @@ __STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data, for (int iy = 0; iy < height; iy++) { for (int ix = 0; ix < width; ix++) { for (int k = 0; k < channels; k++) { - float value = *(image_data + iy * width * channels + ix * channels + k); + int value = *(image_data + iy * width * channels + ix * channels + k); if (scale) { value /= 255.f; } @@ -771,7 +829,10 @@ protected: // compute the required memory size_t compute_buffer_size = ggml_gallocr_get_buffer_size(compute_allocr, 0); - LOG_DEBUG("%s compute buffer size: %.2f MB", get_desc().c_str(), compute_buffer_size / 1024.0 / 1024.0); + LOG_DEBUG("%s compute buffer size: %.2f MB(%s)", + get_desc().c_str(), + compute_buffer_size / 
1024.0 / 1024.0, + ggml_backend_is_cpu(backend) ? "RAM" : "VRAM"); return true; } @@ -816,8 +877,11 @@ public: return false; } size_t params_buffer_size = ggml_backend_buffer_get_size(params_buffer); - LOG_DEBUG("%s params backend buffer size = % 6.2f MB (%i tensors)", - get_desc().c_str(), params_buffer_size / (1024.0 * 1024.0), num_tensors); + LOG_DEBUG("%s params backend buffer size = % 6.2f MB(%s) (%i tensors)", + get_desc().c_str(), + params_buffer_size / (1024.0 * 1024.0), + ggml_backend_is_cpu(backend) ? "RAM" : "VRAM", + num_tensors); return true; } @@ -865,11 +929,8 @@ public: alloc_compute_buffer(get_graph); reset_compute_ctx(); struct ggml_cgraph* gf = get_graph(); - GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf)); - cpy_data_to_backend_tensor(); - if (ggml_backend_is_cpu(backend)) { ggml_backend_cpu_set_n_threads(backend, n_threads); } @@ -879,13 +940,11 @@ public: ggml_backend_metal_set_n_cb(backend, n_threads); } #endif - ggml_backend_graph_compute(backend, gf); #ifdef GGML_PERF ggml_graph_print(gf); #endif - if (output != NULL) { auto result = gf->nodes[gf->n_nodes - 1]; if (*output == NULL && output_ctx != NULL) { @@ -977,13 +1036,11 @@ public: } for (auto& pair : blocks) { auto& block = pair.second; - block->get_param_tensors(tensors, prefix + pair.first); } for (auto& pair : params) { - struct ggml_tensor* param = pair.second; - + struct ggml_tensor* param = pair.second; tensors[prefix + pair.first] = pair.second; } } @@ -1243,11 +1300,10 @@ public: struct ggml_tensor* kqv = ggml_nn_attention(ctx, q, k, v, mask); // [N * n_head, n_token, d_head] kqv = ggml_reshape_4d(ctx, kqv, d_head, n_token, n_head, N); - kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, n_token, n_head, d_head] + kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, n_token, n_head, d_head] + x = ggml_reshape_3d(ctx, kqv, d_head * n_head, n_token, N); // [N, n_token, d_head * n_head] - x = ggml_reshape_3d(ctx, kqv, d_head * n_head, n_token, N); // [N * n_token, d_head * n_head] - - x = out_proj->forward(ctx, x); + x = out_proj->forward(ctx, x); // [N, n_token, embed_dim] return x; } }; diff --git a/lora.hpp b/lora.hpp index 7eb42e1..1336e82 100644 --- a/lora.hpp +++ b/lora.hpp @@ -14,9 +14,10 @@ struct LoraModel : public GGMLModule { LoraModel(ggml_backend_t backend, ggml_type wtype, - const std::string file_path = "") + const std::string& file_path = "", + const std::string& prefix = "") : file_path(file_path), GGMLModule(backend, wtype) { - if (!model_loader.init_from_file(file_path)) { + if (!model_loader.init_from_file(file_path, prefix)) { load_failed = true; } } @@ -33,8 +34,7 @@ struct LoraModel : public GGMLModule { return model_loader.get_params_mem_size(NULL); } - - bool load_from_file() { + bool load_from_file(bool filter_tensor = false) { LOG_INFO("loading LoRA from '%s'", file_path.c_str()); if (load_failed) { @@ -46,6 +46,11 @@ struct LoraModel : public GGMLModule { auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { const std::string& name = tensor_storage.name; + if (filter_tensor && !contains(name, "lora")) { + // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str()); + return true; + } + if (dry_run) { struct ggml_tensor* real = ggml_new_tensor(params_ctx, tensor_storage.type, @@ -66,7 +71,6 @@ struct LoraModel : public GGMLModule { dry_run = false; model_loader.load_tensors(on_new_tensor_cb, backend); - LOG_DEBUG("finished loaded lora"); return true; } @@ -85,6 +89,10 @@ struct LoraModel : public 
GGMLModule { } k_tensor = k_tensor.substr(0, k_pos); replace_all_chars(k_tensor, '.', '_'); + // LOG_DEBUG("k_tensor %s", k_tensor.c_str()); + if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") { // fix for SDXL + k_tensor = "model_diffusion_model_output_blocks_2_1_conv"; + } std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight"; std::string lora_down_name = "lora." + k_tensor + ".lora_down.weight"; std::string alpha_name = "lora." + k_tensor + ".alpha"; diff --git a/model.cpp b/model.cpp index 5925a7d..78b1dc3 100644 --- a/model.cpp +++ b/model.cpp @@ -108,14 +108,14 @@ std::unordered_map open_clip_to_hf_clip_model = { {"model.positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"}, {"model.token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"}, {"model.text_projection", "transformer.text_model.text_projection"}, - {"model.visual.class_embedding", "transformer.visual_model.embeddings.class_embedding"}, - {"model.visual.conv1.weight", "transformer.visual_model.embeddings.patch_embedding.weight"}, - {"model.visual.ln_post.bias", "transformer.visual_model.post_layernorm.bias"}, - {"model.visual.ln_post.weight", "transformer.visual_model.post_layernorm.weight"}, - {"model.visual.ln_pre.bias", "transformer.visual_model.pre_layernorm.bias"}, - {"model.visual.ln_pre.weight", "transformer.visual_model.pre_layernorm.weight"}, - {"model.visual.positional_embedding", "transformer.visual_model.embeddings.position_embedding.weight"}, - {"model.visual.proj", "transformer.visual_model.visual_projection"}, + {"model.visual.class_embedding", "transformer.vision_model.embeddings.class_embedding"}, + {"model.visual.conv1.weight", "transformer.vision_model.embeddings.patch_embedding.weight"}, + {"model.visual.ln_post.bias", "transformer.vision_model.post_layernorm.bias"}, + {"model.visual.ln_post.weight", "transformer.vision_model.post_layernorm.weight"}, + {"model.visual.ln_pre.bias", "transformer.vision_model.pre_layernorm.bias"}, + {"model.visual.ln_pre.weight", "transformer.vision_model.pre_layernorm.weight"}, + {"model.visual.positional_embedding", "transformer.vision_model.embeddings.position_embedding.weight"}, + {"model.visual.proj", "transformer.visual_projection.weight"}, }; std::unordered_map open_clip_to_hk_clip_resblock = { @@ -157,6 +157,10 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) { } else if (starts_with(new_name, "cond_stage_model.")) { prefix = "cond_stage_model."; new_name = new_name.substr(strlen("cond_stage_model.")); + } else if (ends_with(new_name, "vision_model.visual_projection.weight")) { + prefix = new_name.substr(0, new_name.size() - strlen("vision_model.visual_projection.weight")); + new_name = prefix + "visual_projection.weight"; + return new_name; } else { return new_name; } @@ -186,7 +190,7 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) { replace_suffix(); open_clip_resblock_prefix = "model.visual.transformer.resblocks."; - hf_clip_resblock_prefix = "transformer.visual_model.encoder.layers."; + hf_clip_resblock_prefix = "transformer.vision_model.encoder.layers."; replace_suffix(); @@ -248,7 +252,7 @@ std::unordered_map> su }, }; -std::string convert_diffusers_name_to_compvis(const std::string& key, char seq) { +std::string convert_diffusers_name_to_compvis(std::string key, char seq) { std::vector m; auto match = [](std::vector& match_list, const std::regex& regex, const std::string& key) { @@ -282,6 +286,11 @@ std::string 
convert_diffusers_name_to_compvis(const std::string& key, char seq) return inner_key; }; + // convert attn to out + if (ends_with(key, "to_out")) { + key += format("%c0", seq); + } + // unet if (match(m, std::regex(format("unet%cconv_in(.*)", seq)), key)) { return format("model%cdiffusion_model%cinput_blocks%c0%c0", seq, seq, seq, seq) + m[0]; @@ -391,8 +400,8 @@ std::string convert_diffusers_name_to_compvis(const std::string& key, char seq) } std::string convert_tensor_name(const std::string& name) { - std::string new_name; - if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.")) { + std::string new_name = name; + if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || ends_with(name, ".vision_model.visual_projection.weight")) { new_name = convert_open_clip_to_hf_clip(name); } else if (starts_with(name, "first_stage_model.decoder")) { new_name = convert_vae_decoder_name(name); @@ -416,6 +425,26 @@ std::string convert_tensor_name(const std::string& name) { } else { new_name = name; } + } else if (contains(name, "lora_up") || contains(name, "lora_down") || contains(name, "lora.up") || contains(name, "lora.down")) { + size_t pos = new_name.find(".processor"); + if (pos != std::string::npos) { + new_name.replace(pos, strlen(".processor"), ""); + } + pos = new_name.find_last_of('_'); + if (pos != std::string::npos) { + std::string name_without_network_parts = new_name.substr(0, pos); + std::string network_part = new_name.substr(pos + 1); + // LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str()); + std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.'); + replace_all_chars(new_key, '.', '_'); + if (starts_with(network_part, "lora.")) { + network_part = "lora_" + network_part.substr(5); + } + if (new_key.size() > 0) { + new_name = "lora." + new_key + "." 
+ network_part; + } + // LOG_DEBUG("new name: %s", new_name.c_str()); + } } else if (starts_with(name, "unet") || starts_with(name, "vae") || starts_with(name, "te")) { // for diffuser size_t pos = name.find_last_of('.'); if (pos != std::string::npos) { @@ -830,7 +859,6 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const } TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin); - tensor_storage.reverse_ne(); size_t tensor_data_size = end - begin; @@ -1169,7 +1197,9 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer, if (reader.phase == PickleTensorReader::READ_DIMENS) { reader.tensor_storage.reverse_ne(); reader.tensor_storage.file_index = file_index; - reader.tensor_storage.name = prefix + reader.tensor_storage.name; + // if(strcmp(prefix.c_str(), "scarlett") == 0) + // printf(" got tensor %s \n ", reader.tensor_storage.name.c_str()); + reader.tensor_storage.name = prefix + reader.tensor_storage.name; tensor_storages.push_back(reader.tensor_storage); // LOG_DEBUG("%s", reader.tensor_storage.name.c_str()); // reset @@ -1272,7 +1302,8 @@ std::string ModelLoader::load_merges() { return merges_utf8_str; } -void remove_duplicates(std::vector& vec) { +std::vector remove_duplicates(const std::vector& vec) { + std::vector res; std::unordered_map name_to_index_map; for (size_t i = 0; i < vec.size(); ++i) { @@ -1280,13 +1311,16 @@ void remove_duplicates(std::vector& vec) { auto it = name_to_index_map.find(current_name); if (it != name_to_index_map.end()) { - vec[it->second] = vec[i]; + res[it->second] = vec[i]; } else { name_to_index_map[current_name] = i; + res.push_back(vec[i]); } } - vec.resize(name_to_index_map.size()); + // vec.resize(name_to_index_map.size()); + + return res; } bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend) { @@ -1300,7 +1334,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend preprocess_tensor(tensor_storage, processed_tensor_storages); } - remove_duplicates(processed_tensor_storages); + std::vector dedup = remove_duplicates(processed_tensor_storages); + processed_tensor_storages = dedup; + bool success = true; for (size_t file_index = 0; file_index < file_paths_.size(); file_index++) { std::string file_path = file_paths_[file_index]; @@ -1362,7 +1398,6 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend if (tensor_storage.file_index != file_index) { continue; } - ggml_tensor* dst_tensor = NULL; success = on_new_tensor_cb(tensor_storage, &dst_tensor); diff --git a/model.h b/model.h index c50bc2a..5b41421 100644 --- a/model.h +++ b/model.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "ggml/ggml-backend.h" diff --git a/pmid.hpp b/pmid.hpp new file mode 100644 index 0000000..54b0142 --- /dev/null +++ b/pmid.hpp @@ -0,0 +1,305 @@ +#ifndef __PMI_HPP__ +#define __PMI_HPP__ + +#include "ggml_extend.hpp" + +#include "clip.hpp" +#include "lora.hpp" + +struct FuseBlock : public GGMLBlock { + // network hparams + int in_dim; + int out_dim; + int hidden_dim; + bool use_residue; + +public: + FuseBlock(int i_d, int o_d, int h_d, bool use_residue = true) + : in_dim(i_d), out_dim(o_d), hidden_dim(h_d), use_residue(use_residue) { + blocks["fc1"] = std::shared_ptr(new Linear(in_dim, hidden_dim, true)); + blocks["fc2"] = std::shared_ptr(new Linear(hidden_dim, out_dim, true)); + blocks["layernorm"] = std::shared_ptr(new LayerNorm(in_dim)); + } + + struct ggml_tensor* 
forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+        // x: [N, channels, h, w]
+
+        auto fc1        = std::dynamic_pointer_cast<Linear>(blocks["fc1"]);
+        auto fc2        = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
+        auto layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["layernorm"]);
+
+        struct ggml_tensor* r = x;
+        // x = ggml_nn_layer_norm(ctx, x, ln_w, ln_b);
+        x = layer_norm->forward(ctx, x);
+        // x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b);
+        x = fc1->forward(ctx, x);
+        x = ggml_gelu_inplace(ctx, x);
+        x = fc2->forward(ctx, x);
+        // x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b);
+        if (use_residue)
+            x = ggml_add(ctx, x, r);
+        return x;
+    }
+};
+
+// FuseModule merges the projected face-ID embeddings into the prompt embeddings at the class-token positions.
+struct FuseModule : public GGMLBlock {
+    // network hparams
+    int embed_dim;
+
+public:
+    FuseModule(int imb_d)
+        : embed_dim(imb_d) {
+        blocks["mlp1"]       = std::shared_ptr<GGMLBlock>(new FuseBlock(imb_d * 2, imb_d, imb_d, false));
+        blocks["mlp2"]       = std::shared_ptr<GGMLBlock>(new FuseBlock(imb_d, imb_d, imb_d, true));
+        blocks["layer_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(embed_dim));
+    }
+
+    struct ggml_tensor* fuse_fn(struct ggml_context* ctx,
+                                struct ggml_tensor* prompt_embeds,
+                                struct ggml_tensor* id_embeds) {
+        auto mlp1       = std::dynamic_pointer_cast<FuseBlock>(blocks["mlp1"]);
+        auto mlp2       = std::dynamic_pointer_cast<FuseBlock>(blocks["mlp2"]);
+        auto layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm"]);
+
+        auto prompt_embeds0 = ggml_cont(ctx, ggml_permute(ctx, prompt_embeds, 2, 0, 1, 3));
+        auto id_embeds0     = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3));
+        // concat is along dim 2
+        auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds0, id_embeds0);
+        stacked_id_embeds      = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 1, 2, 0, 3));
+
+        // stacked_id_embeds = mlp1.forward(ctx, stacked_id_embeds);
+        // stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds);
+        // stacked_id_embeds = mlp2.forward(ctx, stacked_id_embeds);
+        // stacked_id_embeds = ggml_nn_layer_norm(ctx, stacked_id_embeds, ln_w, ln_b);
+
+        stacked_id_embeds = mlp1->forward(ctx, stacked_id_embeds);
+        stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds);
+        stacked_id_embeds = mlp2->forward(ctx, stacked_id_embeds);
+        stacked_id_embeds = layer_norm->forward(ctx, stacked_id_embeds);
+
+        return stacked_id_embeds;
+    }
+
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                struct ggml_tensor* prompt_embeds,
+                                struct ggml_tensor* id_embeds,
+                                struct ggml_tensor* class_tokens_mask,
+                                struct ggml_tensor* class_tokens_mask_pos,
+                                struct ggml_tensor* left,
+                                struct ggml_tensor* right) {
+        // x: [N, channels, h, w]
+
+        struct ggml_tensor* valid_id_embeds = id_embeds;
+        // # slice out the image token embeddings
+        // print_ggml_tensor(class_tokens_mask_pos, false);
+        ggml_set_name(class_tokens_mask_pos, "class_tokens_mask_pos");
+        ggml_set_name(prompt_embeds, "prompt_embeds");
+        // print_ggml_tensor(valid_id_embeds, true, "valid_id_embeds");
+        // print_ggml_tensor(class_tokens_mask_pos, true, "class_tokens_mask_pos");
+        struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos);
+        ggml_set_name(image_token_embeds, "image_token_embeds");
+        struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds);
+
+        stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3));
+        if (left && right) {
+            stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds);
+            stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right);
+        } else if (left) {
+            stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds);
+        } else if (right) {
+            stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right);
+        }
+        stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3));
+        class_tokens_mask = ggml_cont(ctx, ggml_transpose(ctx, class_tokens_mask));
+        class_tokens_mask = ggml_repeat(ctx, class_tokens_mask, prompt_embeds);
+        prompt_embeds     = ggml_mul(ctx, prompt_embeds, class_tokens_mask);
+        struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx, prompt_embeds, stacked_id_embeds);
+        ggml_set_name(updated_prompt_embeds, "updated_prompt_embeds");
+        return updated_prompt_embeds;
+    }
+};
+
+struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
+    PhotoMakerIDEncoderBlock()
+        : CLIPVisionModelProjection(OPENAI_CLIP_VIT_L_14) {
+        blocks["visual_projection_2"] = std::shared_ptr<GGMLBlock>(new Linear(1024, 1280, false));
+        blocks["fuse_module"]         = std::shared_ptr<GGMLBlock>(new FuseModule(2048));
+    }
+
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                struct ggml_tensor* id_pixel_values,
+                                struct ggml_tensor* prompt_embeds,
+                                struct ggml_tensor* class_tokens_mask,
+                                struct ggml_tensor* class_tokens_mask_pos,
+                                struct ggml_tensor* left,
+                                struct ggml_tensor* right) {
+        // x: [N, channels, h, w]
+        auto vision_model        = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
+        auto visual_projection   = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);
+        auto visual_projection_2 = std::dynamic_pointer_cast<Linear>(blocks["visual_projection_2"]);
+        auto fuse_module         = std::dynamic_pointer_cast<FuseModule>(blocks["fuse_module"]);
+
+        struct ggml_tensor* shared_id_embeds = vision_model->forward(ctx, id_pixel_values);          // [N, hidden_size]
+        struct ggml_tensor* id_embeds        = visual_projection->forward(ctx, shared_id_embeds);    // [N, proj_dim(768)]
+        struct ggml_tensor* id_embeds_2      = visual_projection_2->forward(ctx, shared_id_embeds);  // [N, 1280]
+
+        id_embeds   = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3));
+        id_embeds_2 = ggml_cont(ctx, ggml_permute(ctx, id_embeds_2, 2, 0, 1, 3));
+
+        id_embeds = ggml_concat(ctx, id_embeds, id_embeds_2);  // [batch_size, seq_length, 1, 2048] check whether concat at dim 2 is right
+        id_embeds = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 1, 2, 0, 3));
+
+        struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx,
+                                                                         prompt_embeds,
+                                                                         id_embeds,
+                                                                         class_tokens_mask,
+                                                                         class_tokens_mask_pos,
+                                                                         left, right);
+        return updated_prompt_embeds;
+    }
+};
+
+// PhotoMakerIDEncoder wraps the ID-encoder block as a GGML module: it builds the compute graph over the
+// stacked input ID images and returns the updated prompt embeddings.
+struct PhotoMakerIDEncoder : public GGMLModule {
+public:
+    SDVersion version = VERSION_XL;
+    PhotoMakerIDEncoderBlock id_encoder;
+    float style_strength;
+
+    std::vector<float> ctm;
+    std::vector<ggml_fp16_t> ctmf16;
+    std::vector<int> ctmpos;
+
+    std::vector<ggml_fp16_t> zeros_left_16;
+    std::vector<float> zeros_left;
+    std::vector<ggml_fp16_t> zeros_right_16;
+    std::vector<float> zeros_right;
+
+public:
+    PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_XL, float sty = 20.f)
+        : GGMLModule(backend, wtype),
+          version(version),
+          style_strength(sty) {
+        id_encoder.init(params_ctx, wtype);
+    }
+
+    std::string get_desc() {
+        return "pmid";
+    }
+
+    size_t get_params_mem_size() {
+        size_t params_mem_size = id_encoder.get_params_mem_size();
+        return params_mem_size;
+    }
+
+    size_t get_params_num() {
+        size_t params_num = id_encoder.get_params_num();
+        return params_num;
+    }
+
+    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
+        id_encoder.get_param_tensors(tensors, prefix);
+    }
+
+    struct ggml_cgraph* build_graph(  // struct ggml_allocr* allocr,
+                                    struct ggml_tensor* id_pixel_values,
+                                    struct ggml_tensor* prompt_embeds,
+                                    std::vector<bool>& class_tokens_mask) {
+        ctm.clear();
+        ctmf16.clear();
+        ctmpos.clear();
+        zeros_left.clear();
+        zeros_left_16.clear();
+        zeros_right.clear();
+        zeros_right_16.clear();
+
+        ggml_context* ctx0 = compute_ctx;
+
+        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
+
+        int64_t hidden_size = prompt_embeds->ne[0];
+        int64_t seq_length  = prompt_embeds->ne[1];
+        ggml_type type      = GGML_TYPE_F32;
+
+        struct ggml_tensor* class_tokens_mask_d = ggml_new_tensor_1d(ctx0, type, class_tokens_mask.size());
+
+        struct ggml_tensor* id_pixel_values_d = to_backend(id_pixel_values);
+        struct ggml_tensor* prompt_embeds_d   = to_backend(prompt_embeds);
+
+        struct ggml_tensor* left  = NULL;
+        struct ggml_tensor* right = NULL;
+        for (int i = 0; i < class_tokens_mask.size(); i++) {
+            if (class_tokens_mask[i]) {
+                ctm.push_back(0.f);                        // here use 0.f instead of 1.f to make a scale mask
+                ctmf16.push_back(ggml_fp32_to_fp16(0.f));  // here use 0.f instead of 1.f to make a scale mask
+                ctmpos.push_back(i);
+            } else {
+                ctm.push_back(1.f);                        // here use 1.f instead of 0.f to make a scale mask
+                ctmf16.push_back(ggml_fp32_to_fp16(1.f));  // here use 0.f instead of 1.f to make a scale mask
+            }
+        }
+        if (ctmpos[0] > 0) {
+            left = ggml_new_tensor_3d(ctx0, type, hidden_size, 1, ctmpos[0]);
+        }
+        if (ctmpos[ctmpos.size() - 1] < seq_length - 1) {
+            right = ggml_new_tensor_3d(ctx0, type,
+                                       hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1);
+        }
+        struct ggml_tensor* class_tokens_mask_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctmpos.size());
+
+        {
+            if (type == GGML_TYPE_F16)
+                set_backend_tensor_data(class_tokens_mask_d, ctmf16.data());
+            else
+                set_backend_tensor_data(class_tokens_mask_d, ctm.data());
+            set_backend_tensor_data(class_tokens_mask_pos, ctmpos.data());
+            if (left) {
+                if (type == GGML_TYPE_F16) {
+                    for (int i = 0; i < ggml_nelements(left); ++i)
+                        zeros_left_16.push_back(ggml_fp32_to_fp16(0.f));
+                    set_backend_tensor_data(left, zeros_left_16.data());
+                } else {
+                    for (int i = 0; i < ggml_nelements(left); ++i)
+                        zeros_left.push_back(0.f);
+                    set_backend_tensor_data(left, zeros_left.data());
+                }
+            }
+            if (right) {
+                if (type == GGML_TYPE_F16) {
+                    for (int i = 0; i < ggml_nelements(right); ++i)
+                        zeros_right_16.push_back(ggml_fp32_to_fp16(0.f));
+                    set_backend_tensor_data(right, zeros_right_16.data());
+                } else {
+                    for (int i = 0; i < ggml_nelements(right); ++i)
+                        zeros_right.push_back(0.f);
+                    set_backend_tensor_data(right, zeros_right.data());
+                }
+            }
+        }
+        struct ggml_tensor* updated_prompt_embeds = id_encoder.forward(ctx0,
+                                                                       id_pixel_values_d,
+                                                                       prompt_embeds_d,
+                                                                       class_tokens_mask_d,
+                                                                       class_tokens_mask_pos,
+                                                                       left, right);
+        ggml_build_forward_expand(gf, updated_prompt_embeds);
+
+        return gf;
+    }
+
+    void compute(const int n_threads,
+                 struct ggml_tensor* id_pixel_values,
+                 struct ggml_tensor* prompt_embeds,
+                 std::vector<bool>& class_tokens_mask,
+                 struct ggml_tensor** updated_prompt_embeds,
+                 ggml_context* output_ctx) {
+        auto get_graph = [&]() -> struct ggml_cgraph* {
+            // return build_graph(compute_allocr, id_pixel_values, prompt_embeds, class_tokens_mask);
+            return build_graph(id_pixel_values, prompt_embeds, class_tokens_mask);
+        };
+
+        // GGMLModule::compute(get_graph, n_threads, updated_prompt_embeds);
+        GGMLModule::compute(get_graph, n_threads, true, updated_prompt_embeds, output_ctx);
+    }
+};
+
+#endif  // __PMI_HPP__
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 1f1a0d0..4d622dd 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -11,10 +11,19 @@
#include "denoiser.hpp"
#include "esrgan.hpp"
#include "lora.hpp"
+#include "pmid.hpp"
#include "tae.hpp"
#include "unet.hpp"
#include "vae.hpp"
+#define STB_IMAGE_IMPLEMENTATION
+#define STB_IMAGE_STATIC
+#include "stb_image.h"
+
+// #define STB_IMAGE_WRITE_IMPLEMENTATION
+// #define STB_IMAGE_WRITE_STATIC
+// #include "stb_image_write.h"
+
const char* model_version_to_str[] = {
    "1.x",
    "2.x",
@@ -56,8 +65,11 @@ void calculate_alphas_cumprod(float* alphas_cumprod,
class StableDiffusionGGML {
public:
-    ggml_backend_t backend    = NULL;  // general backend
-    ggml_type model_data_type = GGML_TYPE_COUNT;
+    ggml_backend_t backend             = NULL;  // general backend
+    ggml_backend_t clip_backend        = NULL;
+    ggml_backend_t control_net_backend = NULL;
+    ggml_backend_t vae_backend         = NULL;
+    ggml_type model_data_type          = GGML_TYPE_COUNT;
    SDVersion version;
    bool vae_decode_only = false;
@@ -73,10 +85,13 @@ public:
    std::shared_ptr<AutoEncoderKL> first_stage_model;
    std::shared_ptr<TinyAutoEncoder> tae_first_stage;
    std::shared_ptr<ControlNet> control_net;
+    std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
+    std::shared_ptr<LoraModel> pmid_lora;
    std::string taesd_path;
    bool use_tiny_autoencoder = false;
    bool vae_tiling           = false;
+    bool stacked_id           = false;
    std::map<std::string, struct ggml_tensor*> tensors;
@@ -86,6 +101,8 @@ public:
    std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();
+    std::string trigger_word = "img";  // should be user settable
+
    StableDiffusionGGML() = default;
    StableDiffusionGGML(int n_threads,
@@ -106,17 +123,23 @@ public:
    ~StableDiffusionGGML() {
        ggml_backend_free(backend);
+        ggml_backend_free(clip_backend);
+        ggml_backend_free(control_net_backend);
+        ggml_backend_free(vae_backend);
    }
    bool load_from_file(const std::string& model_path,
                        const std::string& vae_path,
                        const std::string control_net_path,
                        const std::string embeddings_path,
+                       const std::string id_embeddings_path,
                        const std::string& taesd_path,
                        bool vae_tiling_,
                        ggml_type wtype,
                        schedule_t schedule,
-                       bool control_net_cpu) {
+                       bool clip_on_cpu,
+                       bool control_net_cpu,
+                       bool vae_on_cpu) {
        use_tiny_autoencoder = taesd_path.size() > 0;
#ifdef SD_USE_CUBLAS
        LOG_DEBUG("Using CUDA backend");
@@ -161,6 +184,7 @@ public:
            LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str());
            return false;
        }
+        LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]);
        if (wtype == GGML_TYPE_COUNT) {
            model_data_type = model_loader.get_sd_wtype();
@@ -195,7 +219,12 @@ public:
            first_stage_model->alloc_params_buffer();
            first_stage_model->get_param_tensors(tensors, "first_stage_model");
        } else {
-            cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(backend, model_data_type, version);
+            clip_backend = backend;
+            if (clip_on_cpu && !ggml_backend_is_cpu(backend)) {
+                LOG_INFO("CLIP: Using CPU backend");
+                clip_backend = ggml_backend_cpu_init();
+            }
+            cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_data_type, version);
            cond_stage_model->alloc_params_buffer();
            cond_stage_model->get_param_tensors(tensors, "cond_stage_model.");
@@ -211,24 +240,59 @@ public:
            }
            if (!use_tiny_autoencoder) {
-                first_stage_model = std::make_shared<AutoEncoderKL>(backend, vae_type, vae_decode_only);
+                if (vae_on_cpu && !ggml_backend_is_cpu(backend)) {
+                    LOG_INFO("VAE Autoencoder: Using CPU backend");
+                    vae_backend = ggml_backend_cpu_init();
+                } else {
+                    vae_backend = backend;
+                }
+                first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend, vae_type, vae_decode_only);
                first_stage_model->alloc_params_buffer();
                first_stage_model->get_param_tensors(tensors, "first_stage_model");
            } else {
                tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_data_type, vae_decode_only);
            }
+            // first_stage_model->get_param_tensors(tensors, "first_stage_model.");
            if (control_net_path.size() > 0) {
-                ggml_backend_t cn_backend = NULL;
+                ggml_backend_t controlnet_backend = NULL;
                if (control_net_cpu && !ggml_backend_is_cpu(backend)) {
                    LOG_DEBUG("ControlNet: Using CPU backend");
-                    cn_backend = ggml_backend_cpu_init();
+                    controlnet_backend = ggml_backend_cpu_init();
                } else {
-                    cn_backend = backend;
+                    controlnet_backend = backend;
                }
-                control_net = std::make_shared<ControlNet>(cn_backend, model_data_type, version);
+                control_net = std::make_shared<ControlNet>(controlnet_backend, model_data_type, version);
            }
+            pmid_model = std::make_shared<PhotoMakerIDEncoder>(clip_backend, model_data_type, version);
+            if (id_embeddings_path.size() > 0) {
+                pmid_lora = std::make_shared<LoraModel>(backend, model_data_type, id_embeddings_path, "");
+                if (!pmid_lora->load_from_file(true)) {
+                    LOG_WARN("load photomaker lora tensors from %s failed", id_embeddings_path.c_str());
+                    return false;
+                }
+                LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", id_embeddings_path.c_str());
+                if (!model_loader.init_from_file(id_embeddings_path, "pmid.")) {
+                    LOG_WARN("loading stacked ID embedding from '%s' failed", id_embeddings_path.c_str());
+                } else {
+                    stacked_id = true;
+                }
+            }
+            if (stacked_id) {
+                if (!pmid_model->alloc_params_buffer()) {
+                    LOG_ERROR(" pmid model params buffer allocation failed");
+                    return false;
+                }
+                // LOG_INFO("pmid param memory buffer size = %.2fMB ",
+                //          pmid_model->params_buffer_size / 1024.0 / 1024.0);
+                pmid_model->get_param_tensors(tensors, "pmid");
+            }
+            // if(stacked_id){
+            //     pmid_model.init_params(GGML_TYPE_F32);
+            //     pmid_model.map_by_name(tensors, "pmid.");
+            // }
+
        LOG_DEBUG("loading vocab");
        std::string merges_utf8_str = model_loader.load_merges();
        if (merges_utf8_str.size() == 0) {
@@ -250,6 +314,7 @@
        // load weights
        LOG_DEBUG("loading weights");
+
        int64_t t0 = ggml_time_ms();
        std::set<std::string> ignore_tensors;
@@ -257,6 +322,10 @@
        if (use_tiny_autoencoder) {
            ignore_tensors.insert("first_stage_model.");
        }
+        if (stacked_id) {
+            ignore_tensors.insert("lora.");
+        }
+
        if (vae_decode_only) {
            ignore_tensors.insert("first_stage_model.encoder");
            ignore_tensors.insert("first_stage_model.quant");
@@ -296,14 +365,54 @@
            }
            control_net_params_mem_size = control_net->get_params_mem_size();
        }
+        size_t pmid_params_mem_size = 0;
+        if (stacked_id) {
+            pmid_params_mem_size = pmid_model->get_params_mem_size();
+        }
-        size_t total_params_size = clip_params_mem_size + unet_params_mem_size + vae_params_mem_size + control_net_params_mem_size;
-        LOG_INFO("total params memory size = %.2fMB (clip %.2fMB, unet %.2fMB, vae %.2fMB, controlnet %.2fMB)",
-                 total_params_size / 1024.0 / 1024.0,
-                 clip_params_mem_size / 1024.0 / 1024.0,
-                 unet_params_mem_size / 1024.0 / 1024.0,
-                 vae_params_mem_size / 1024.0 / 1024.0,
-                 control_net_params_mem_size / 1024.0 / 1024.0);
+        size_t total_params_ram_size  = 0;
+        size_t total_params_vram_size = 0;
+        if (ggml_backend_is_cpu(clip_backend)) {
+            total_params_ram_size += clip_params_mem_size + pmid_params_mem_size;
+        } else {
+            total_params_vram_size += clip_params_mem_size + pmid_params_mem_size;
+        }
+
+        if (ggml_backend_is_cpu(backend)) {
+            total_params_ram_size += unet_params_mem_size;
+        } else {
+            total_params_vram_size += unet_params_mem_size;
+        }
+
+        if (ggml_backend_is_cpu(vae_backend)) {
+            total_params_ram_size += vae_params_mem_size;
+        } else {
+            total_params_vram_size += vae_params_mem_size;
+        }
+
+        if (ggml_backend_is_cpu(control_net_backend)) {
+            total_params_ram_size += control_net_params_mem_size;
+        } else {
+            total_params_vram_size += control_net_params_mem_size;
+        }
+
+        size_t total_params_size = total_params_ram_size + total_params_vram_size;
+        LOG_INFO(
+            "total params memory size = %.2fMB (VRAM %.2fMB, RAM %.2fMB): "
+            "clip %.2fMB(%s), unet %.2fMB(%s), vae %.2fMB(%s), controlnet %.2fMB(%s), pmid %.2fMB(%s)",
+            total_params_size / 1024.0 / 1024.0,
+            total_params_vram_size / 1024.0 / 1024.0,
+            total_params_ram_size / 1024.0 / 1024.0,
+            clip_params_mem_size / 1024.0 / 1024.0,
+            ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM",
+            unet_params_mem_size / 1024.0 / 1024.0,
+            ggml_backend_is_cpu(backend) ? "RAM" : "VRAM",
+            vae_params_mem_size / 1024.0 / 1024.0,
+            ggml_backend_is_cpu(vae_backend) ? "RAM" : "VRAM",
+            control_net_params_mem_size / 1024.0 / 1024.0,
+            ggml_backend_is_cpu(control_net_backend) ? "RAM" : "VRAM",
+            pmid_params_mem_size / 1024.0 / 1024.0,
+            ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM");
        }
        int64_t t1 = ggml_time_ms();
@@ -444,16 +553,80 @@ public:
        curr_lora_state = lora_state;
    }
+    std::string remove_trigger_from_prompt(ggml_context* work_ctx,
+                                           const std::string& prompt) {
+        auto image_tokens = cond_stage_model->convert_token_to_id(trigger_word);
+        GGML_ASSERT(image_tokens.size() == 1);
+        auto tokens_and_weights  = cond_stage_model->tokenize(prompt, false);
+        std::vector<int>& tokens = tokens_and_weights.first;
+        auto it                  = std::find(tokens.begin(), tokens.end(), image_tokens[0]);
+        GGML_ASSERT(it != tokens.end());  // prompt must have trigger word
+        tokens.erase(it);
+        return cond_stage_model->decode(tokens);
+    }
+
+    std::tuple<ggml_tensor*, ggml_tensor*, std::vector<bool>>
+    get_learned_condition_with_trigger(ggml_context* work_ctx,
+                                       const std::string& text,
+                                       int clip_skip,
+                                       int width,
+                                       int height,
+                                       int num_input_imgs,
+                                       bool force_zero_embeddings = false) {
+        auto image_tokens = cond_stage_model->convert_token_to_id(trigger_word);
+        // if(image_tokens.size() == 1){
+        //     printf(" image token id is: %d \n", image_tokens[0]);
+        // }
+        GGML_ASSERT(image_tokens.size() == 1);
+        auto tokens_and_weights = cond_stage_model->tokenize_with_trigger_token(text,
+                                                                                num_input_imgs,
+                                                                                image_tokens[0],
+                                                                                true);
+        std::vector<int>& tokens    = std::get<0>(tokens_and_weights);
+        std::vector<float>& weights = std::get<1>(tokens_and_weights);
+        std::vector<bool>& clsm     = std::get<2>(tokens_and_weights);
+        // printf("tokens: \n");
+        // for(int i = 0; i < tokens.size(); ++i)
+        //     printf("%d ", tokens[i]);
+        // printf("\n");
+        // printf("clsm: \n");
+        // for(int i = 0; i < clsm.size(); ++i)
+        //     printf("%d ", clsm[i]?1:0);
+        // printf("\n");
+        auto cond = get_learned_condition_common(work_ctx, tokens, weights, clip_skip, width, height, force_zero_embeddings);
+        return std::make_tuple(cond.first, cond.second, clsm);
+    }
+
+    ggml_tensor* id_encoder(ggml_context* work_ctx,
+                            ggml_tensor* init_img,
+                            ggml_tensor* prompts_embeds,
+                            std::vector<bool>& class_tokens_mask) {
+        ggml_tensor* res = NULL;
+        pmid_model->compute(n_threads, init_img, prompts_embeds, class_tokens_mask, &res, work_ctx);
+
+        return res;
+    }
+
    std::pair<ggml_tensor*, ggml_tensor*> get_learned_condition(ggml_context* work_ctx,
                                                                 const std::string& text,
                                                                 int clip_skip,
                                                                 int width,
                                                                 int height,
                                                                 bool force_zero_embeddings = false) {
+        auto tokens_and_weights     = cond_stage_model->tokenize(text, true);
+        std::vector<int>& tokens    = tokens_and_weights.first;
+        std::vector<float>& weights = tokens_and_weights.second;
+        return get_learned_condition_common(work_ctx, tokens, weights, clip_skip, width, height, force_zero_embeddings);
+    }
+
+    std::pair<ggml_tensor*, ggml_tensor*> get_learned_condition_common(ggml_context* work_ctx,
+                                                                       std::vector<int>& tokens,
+                                                                       std::vector<float>& weights,
+                                                                       int
clip_skip, + int width, + int height, + bool force_zero_embeddings = false) { cond_stage_model->set_clip_skip(clip_skip); - auto tokens_and_weights = cond_stage_model->tokenize(text, true); - std::vector& tokens = tokens_and_weights.first; - std::vector& weights = tokens_and_weights.second; int64_t t0 = ggml_time_ms(); struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size] struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, hidden_size] @@ -466,7 +639,7 @@ public: std::vector chunk_tokens(tokens.begin() + chunk_idx * chunk_len, tokens.begin() + (chunk_idx + 1) * chunk_len); std::vector chunk_weights(weights.begin() + chunk_idx * chunk_len, - weights.begin() + (chunk_idx + 1) * chunk_len); + weights.begin() + (chunk_idx + 1) * chunk_len); auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); struct ggml_tensor* input_ids2 = NULL; @@ -664,7 +837,10 @@ public: float min_cfg, float cfg_scale, sample_method_t method, - const std::vector& sigmas) { + const std::vector& sigmas, + int start_merge_step, + ggml_tensor* c_id, + ggml_tensor* c_vec_id) { size_t steps = sigmas.size() - 1; // x_t = load_tensor_from_file(work_ctx, "./rand0.bin"); // print_ggml_tensor(x_t); @@ -730,17 +906,30 @@ public: // GGML_ASSERT(0); } - // cond - diffusion_model->compute(n_threads, - noised_input, - timesteps, - c, - c_concat, - c_vector, - -1, - controls, - control_strength, - &out_cond); + if (start_merge_step == -1 || step <= start_merge_step) { + // cond + diffusion_model->compute(n_threads, + noised_input, + timesteps, + c, + c_concat, + c_vector, + -1, + controls, + control_strength, + &out_cond); + } else { + diffusion_model->compute(n_threads, + noised_input, + timesteps, + c_id, + c_concat, + c_vec_id, + -1, + controls, + control_strength, + &out_cond); + } float* negative_data = NULL; if (has_unconditioned) { @@ -1283,6 +1472,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, const char* control_net_path_c_str, const char* lora_model_dir_c_str, const char* embed_dir_c_str, + const char* id_embed_dir_c_str, bool vae_decode_only, bool vae_tiling, bool free_params_immediately, @@ -1290,7 +1480,9 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, enum sd_type_t wtype, enum rng_type_t rng_type, enum schedule_t s, - bool keep_control_net_cpu) { + bool keep_clip_on_cpu, + bool keep_control_net_cpu, + bool keep_vae_on_cpu) { sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t)); if (sd_ctx == NULL) { return NULL; @@ -1300,6 +1492,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, std::string taesd_path(taesd_path_c_str); std::string control_net_path(control_net_path_c_str); std::string embd_path(embed_dir_c_str); + std::string id_embd_path(id_embed_dir_c_str); std::string lora_model_dir(lora_model_dir_c_str); sd_ctx->sd = new StableDiffusionGGML(n_threads, @@ -1315,11 +1508,14 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, vae_path, control_net_path, embd_path, + id_embd_path, taesd_path, vae_tiling, (ggml_type)wtype, s, - keep_control_net_cpu)) { + keep_clip_on_cpu, + keep_control_net_cpu, + keep_vae_on_cpu)) { delete sd_ctx->sd; sd_ctx->sd = NULL; free(sd_ctx); @@ -1348,7 +1544,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, int64_t seed, int batch_count, const sd_image_t* control_cond, - float control_strength) { + float control_strength, + float style_ratio, + bool normalize_input, + const char* input_id_images_path_c_str) { LOG_DEBUG("txt2img %dx%d", width, height); if (sd_ctx == NULL) { return NULL; @@ -1356,6 +1555,35 @@ sd_image_t* txt2img(sd_ctx_t* 
sd_ctx, // LOG_DEBUG("%s %s %f %d %d %d", prompt_c_str, negative_prompt_c_str, cfg_scale, sample_steps, seed, batch_count); std::string prompt(prompt_c_str); std::string negative_prompt(negative_prompt_c_str); + std::string input_id_images_path(input_id_images_path_c_str); + + // preprocess input id images + std::vector input_id_images; + if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) { + std::vector img_files = get_files_from_dir(input_id_images_path); + for (std::string img_file : img_files) { + int c = 0; + int width, height; + uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3); + if (input_image_buffer == NULL) { + LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str()); + continue; + } else { + LOG_INFO("PhotoMaker loaded image from '%s'", img_file.c_str()); + } + sd_image_t* input_image = NULL; + input_image = new sd_image_t{(uint32_t)width, + (uint32_t)height, + 3, + input_image_buffer}; + input_image = preprocess_id_image(input_image); + if (input_image == NULL) { + LOG_ERROR("preprocess input id image from '%s' failed", img_file.c_str()); + continue; + } + input_id_images.push_back(input_image); + } + } // extract and remove lora auto result_pair = extract_and_remove_lora(prompt); @@ -1372,8 +1600,22 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, sd_ctx->sd->apply_loras(lora_f2m); int64_t t1 = ggml_time_ms(); LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); + + if (sd_ctx->sd->stacked_id) { + t0 = ggml_time_ms(); + sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->n_threads); + t1 = ggml_time_ms(); + LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); + if (sd_ctx->sd->free_params_immediately) { + sd_ctx->sd->pmid_lora->free_params_buffer(); + } + } + struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB + if (sd_ctx->sd->stacked_id) { + params.mem_size += static_cast(10 * 1024 * 1024); // 10 MB + } params.mem_size += width * height * 3 * sizeof(float); params.mem_size *= batch_count; params.mem_buffer = NULL; @@ -1394,10 +1636,67 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, seed = rand(); } - t0 = ggml_time_ms(); - auto cond_pair = sd_ctx->sd->get_learned_condition(work_ctx, prompt, clip_skip, width, height); - ggml_tensor* c = cond_pair.first; - ggml_tensor* c_vector = cond_pair.second; // [adm_in_channels, ] + std::string prompt_text_only; + ggml_tensor* init_img = NULL; + ggml_tensor* prompts_embeds = NULL; + ggml_tensor* pooled_prompts_embeds = NULL; + // ggml_tensor* class_tokens_mask = NULL; + std::vector class_tokens_mask; + if (sd_ctx->sd->stacked_id) { + if (input_id_images.size() > 0) { + sd_ctx->sd->pmid_model->style_strength = style_ratio; + int32_t w = input_id_images[0]->width; + int32_t h = input_id_images[0]->height; + int32_t channels = input_id_images[0]->channel; + int32_t num_input_images = (int32_t)input_id_images.size(); + init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, w, h, channels, num_input_images); + // TODO: move these to somewhere else and be user settable + float mean[] = {0.48145466f, 0.4578275f, 0.40821073f}; + float std[] = {0.26862954f, 0.26130258f, 0.27577711f}; + for (int i = 0; i < num_input_images; i++) { + sd_image_t* init_image = input_id_images[i]; + if (normalize_input) + sd_mul_images_to_tensor(init_image->data, init_img, i, mean, std); + else + sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL); + } + t0 = ggml_time_ms(); + auto cond_tup = 
sd_ctx->sd->get_learned_condition_with_trigger(work_ctx, prompt, + clip_skip, width, height, num_input_images); + prompts_embeds = std::get<0>(cond_tup); + pooled_prompts_embeds = std::get<1>(cond_tup); // [adm_in_channels, ] + class_tokens_mask = std::get<2>(cond_tup); // + + prompts_embeds = sd_ctx->sd->id_encoder(work_ctx, init_img, prompts_embeds, class_tokens_mask); + t1 = ggml_time_ms(); + LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0); + if (sd_ctx->sd->free_params_immediately) { + sd_ctx->sd->pmid_model->free_params_buffer(); + } + // Encode input prompt without the trigger word for delayed conditioning + prompt_text_only = sd_ctx->sd->remove_trigger_from_prompt(work_ctx, prompt); + // printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str()); + prompt = prompt_text_only; // + if (sample_steps < 50) { + LOG_INFO("sampling steps increases from %d to 50 for PHOTOMAKER", sample_steps); + sample_steps = 50; + } + } else { + LOG_WARN("Provided PhotoMaker model file, but NO input ID images"); + LOG_WARN("Turn off PhotoMaker"); + sd_ctx->sd->stacked_id = false; + } + } + for (sd_image_t* img : input_id_images) { + free(img->data); + } + input_id_images.clear(); + + t0 = ggml_time_ms(); + auto cond_pair = sd_ctx->sd->get_learned_condition(work_ctx, prompt, clip_skip, width, height); + ggml_tensor* c = cond_pair.first; + ggml_tensor* c_vector = cond_pair.second; // [adm_in_channels, ] + struct ggml_tensor* uc = NULL; struct ggml_tensor* uc_vector = NULL; if (cfg_scale != 1.0) { @@ -1438,6 +1737,14 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, std::vector sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps); + int start_merge_step = -1; + if (sd_ctx->sd->stacked_id) { + start_merge_step = int(sd_ctx->sd->pmid_model->style_strength / 100.f * sample_steps); + if (start_merge_step > 30) + start_merge_step = 30; + LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step); + } + struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, x_t, NULL, @@ -1452,7 +1759,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, cfg_scale, cfg_scale, sample_method, - sigmas); + sigmas, + start_merge_step, + prompts_embeds, + pooled_prompts_embeds); // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); // print_ggml_tensor(x_0); int64_t sampling_end = ggml_time_ms(); @@ -1619,7 +1929,10 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, cfg_scale, cfg_scale, sample_method, - sigma_sched); + sigma_sched, + -1, + NULL, + NULL); // struct ggml_tensor *x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); // print_ggml_tensor(x_0); int64_t t3 = ggml_time_ms(); @@ -1755,7 +2068,10 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, min_cfg, cfg_scale, sample_method, - sigmas); + sigmas, + -1, + NULL, + NULL); int64_t t2 = ggml_time_ms(); LOG_INFO("sampling completed, taking %.2fs", (t2 - t1) * 1.0f / 1000); diff --git a/stable-diffusion.h b/stable-diffusion.h index 1b32111..1b3cd14 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -65,12 +65,12 @@ enum sd_type_t { SD_TYPE_Q8_0 = 8, SD_TYPE_Q8_1 = 9, // k-quantizations - SD_TYPE_Q2_K = 10, - SD_TYPE_Q3_K = 11, - SD_TYPE_Q4_K = 12, - SD_TYPE_Q5_K = 13, - SD_TYPE_Q6_K = 14, - SD_TYPE_Q8_K = 15, + SD_TYPE_Q2_K = 10, + SD_TYPE_Q3_K = 11, + SD_TYPE_Q4_K = 12, + SD_TYPE_Q5_K = 13, + SD_TYPE_Q6_K = 14, + SD_TYPE_Q8_K = 15, SD_TYPE_IQ2_XXS = 16, SD_TYPE_IQ2_XS = 17, SD_TYPE_IQ3_XXS = 18, @@ -95,7 +95,7 @@ enum sd_log_level_t { }; typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data); -typedef void 
(*sd_progress_cb_t)(int step,int steps,float time, void* data); +typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data); SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data); SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data); @@ -117,6 +117,7 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path, const char* control_net_path_c_str, const char* lora_model_dir, const char* embed_dir_c_str, + const char* stacked_id_embed_dir_c_str, bool vae_decode_only, bool vae_tiling, bool free_params_immediately, @@ -124,7 +125,9 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path, enum sd_type_t wtype, enum rng_type_t rng_type, enum schedule_t s, - bool keep_control_net_cpu); + bool keep_clip_on_cpu, + bool keep_control_net_cpu, + bool keep_vae_on_cpu); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); @@ -140,7 +143,10 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, int64_t seed, int batch_count, const sd_image_t* control_cond, - float control_strength); + float control_strength, + float style_strength, + bool normalize_input, + const char* input_id_images_path); SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, sd_image_t init_image, diff --git a/thirdparty/stb_image_resize.h b/thirdparty/stb_image_resize.h new file mode 100644 index 0000000..bcca92c --- /dev/null +++ b/thirdparty/stb_image_resize.h @@ -0,0 +1,2585 @@ +/* stb_image_resize - v0.90 - public domain image resizing + by Jorge L Rodriguez (@VinoBS) - 2014 + http://github.com/nothings/stb + + Written with emphasis on usability, portability, and efficiency. (No + SIMD or threads, so it be easily outperformed by libs that use those.) + Only scaling and translation is supported, no rotations or shears. + Easy API downsamples w/Mitchell filter, upsamples w/cubic interpolation. + + COMPILING & LINKING + In one C/C++ file that #includes this file, do this: + #define STB_IMAGE_RESIZE_IMPLEMENTATION + before the #include. That will create the implementation in that file. + + QUICKSTART + stbir_resize_uint8( input_pixels , in_w , in_h , 0, + output_pixels, out_w, out_h, 0, num_channels) + stbir_resize_float(...) + stbir_resize_uint8_srgb( input_pixels , in_w , in_h , 0, + output_pixels, out_w, out_h, 0, + num_channels , alpha_chan , 0) + stbir_resize_uint8_srgb_edgemode( + input_pixels , in_w , in_h , 0, + output_pixels, out_w, out_h, 0, + num_channels , alpha_chan , 0, STBIR_EDGE_CLAMP) + // WRAP/REFLECT/ZERO + + FULL API + See the "header file" section of the source for API documentation. + + ADDITIONAL DOCUMENTATION + + SRGB & FLOATING POINT REPRESENTATION + The sRGB functions presume IEEE floating point. If you do not have + IEEE floating point, define STBIR_NON_IEEE_FLOAT. This will use + a slower implementation. + + MEMORY ALLOCATION + The resize functions here perform a single memory allocation using + malloc. To control the memory allocation, before the #include that + triggers the implementation, do: + + #define STBIR_MALLOC(size,context) ... + #define STBIR_FREE(ptr,context) ... + + Each resize function makes exactly one call to malloc/free, so to use + temp memory, store the temp memory in the context and return that. + + ASSERT + Define STBIR_ASSERT(boolval) to override assert() and not use assert.h + + OPTIMIZATION + Define STBIR_SATURATE_INT to compute clamp values in-range using + integer operations instead of float operations. This may be faster + on some platforms. 
+ + DEFAULT FILTERS + For functions which don't provide explicit control over what filters + to use, you can change the compile-time defaults with + + #define STBIR_DEFAULT_FILTER_UPSAMPLE STBIR_FILTER_something + #define STBIR_DEFAULT_FILTER_DOWNSAMPLE STBIR_FILTER_something + + See stbir_filter in the header-file section for the list of filters. + + NEW FILTERS + A number of 1D filter kernels are used. For a list of + supported filters see the stbir_filter enum. To add a new filter, + write a filter function and add it to stbir__filter_info_table. + + PROGRESS + For interactive use with slow resize operations, you can install + a progress-report callback: + + #define STBIR_PROGRESS_REPORT(val) some_func(val) + + The parameter val is a float which goes from 0 to 1 as progress is made. + + For example: + + static void my_progress_report(float progress); + #define STBIR_PROGRESS_REPORT(val) my_progress_report(val) + + #define STB_IMAGE_RESIZE_IMPLEMENTATION + #include "stb_image_resize.h" + + static void my_progress_report(float progress) + { + printf("Progress: %f%%\n", progress*100); + } + + MAX CHANNELS + If your image has more than 64 channels, define STBIR_MAX_CHANNELS + to the max you'll have. + + ALPHA CHANNEL + Most of the resizing functions provide the ability to control how + the alpha channel of an image is processed. The important things + to know about this: + + 1. The best mathematically-behaved version of alpha to use is + called "premultiplied alpha", in which the other color channels + have had the alpha value multiplied in. If you use premultiplied + alpha, linear filtering (such as image resampling done by this + library, or performed in texture units on GPUs) does the "right + thing". While premultiplied alpha is standard in the movie CGI + industry, it is still uncommon in the videogame/real-time world. + + If you linearly filter non-premultiplied alpha, strange effects + occur. (For example, the average of 1% opaque bright green + and 99% opaque black produces 50% transparent dark green when + non-premultiplied, whereas premultiplied it produces 50% + transparent near-black. The former introduces green energy + that doesn't exist in the source image.) + + 2. Artists should not edit premultiplied-alpha images; artists + want non-premultiplied alpha images. Thus, art tools generally output + non-premultiplied alpha images. + + 3. You will get best results in most cases by converting images + to premultiplied alpha before processing them mathematically. + + 4. If you pass the flag STBIR_FLAG_ALPHA_PREMULTIPLIED, the + resizer does not do anything special for the alpha channel; + it is resampled identically to other channels. This produces + the correct results for premultiplied-alpha images, but produces + less-than-ideal results for non-premultiplied-alpha images. + + 5. If you do not pass the flag STBIR_FLAG_ALPHA_PREMULTIPLIED, + then the resizer weights the contribution of input pixels + based on their alpha values, or, equivalently, it multiplies + the alpha value into the color channels, resamples, then divides + by the resultant alpha value. Input pixels which have alpha=0 do + not contribute at all to output pixels unless _all_ of the input + pixels affecting that output pixel have alpha=0, in which case + the result for that pixel is the same as it would be without + STBIR_FLAG_ALPHA_PREMULTIPLIED. However, this is only true for + input images in integer formats. 
For input images in float format, + input pixels with alpha=0 have no effect, and output pixels + which have alpha=0 will be 0 in all channels. (For float images, + you can manually achieve the same result by adding a tiny epsilon + value to the alpha channel of every image, and then subtracting + or clamping it at the end.) + + 6. You can suppress the behavior described in #5 and make + all-0-alpha pixels have 0 in all channels by #defining + STBIR_NO_ALPHA_EPSILON. + + 7. You can separately control whether the alpha channel is + interpreted as linear or affected by the colorspace. By default + it is linear; you almost never want to apply the colorspace. + (For example, graphics hardware does not apply sRGB conversion + to the alpha channel.) + + ADDITIONAL CONTRIBUTORS + Sean Barrett: API design, optimizations + + REVISIONS + 0.90 (2014-09-17) first released version + + LICENSE + This software is in the public domain. Where that dedication is not + recognized, you are granted a perpetual, irrevocable license to copy + and modify this file as you see fit. + + TODO + Don't decode all of the image data when only processing a partial tile + Don't use full-width decode buffers when only processing a partial tile + When processing wide images, break processing into tiles so data fits in L1 cache + Installable filters? + Resize that respects alpha test coverage + (Reference code: FloatImage::alphaTestCoverage and FloatImage::scaleAlphaToCoverage: + https://code.google.com/p/nvidia-texture-tools/source/browse/trunk/src/nvimage/FloatImage.cpp ) +*/ + +#ifndef STBIR_INCLUDE_STB_IMAGE_RESIZE_H +#define STBIR_INCLUDE_STB_IMAGE_RESIZE_H + +#ifdef _MSC_VER +typedef unsigned char stbir_uint8; +typedef unsigned short stbir_uint16; +typedef unsigned int stbir_uint32; +#else +#include +typedef uint8_t stbir_uint8; +typedef uint16_t stbir_uint16; +typedef uint32_t stbir_uint32; +#endif + +#ifdef STB_IMAGE_RESIZE_STATIC +#define STBIRDEF static +#else +#ifdef __cplusplus +#define STBIRDEF extern "C" +#else +#define STBIRDEF extern +#endif +#endif + + +////////////////////////////////////////////////////////////////////////////// +// +// Easy-to-use API: +// +// * "input pixels" points to an array of image data with 'num_channels' channels (e.g. RGB=3, RGBA=4) +// * input_w is input image width (x-axis), input_h is input image height (y-axis) +// * stride is the offset between successive rows of image data in memory, in bytes. you can +// specify 0 to mean packed continuously in memory +// * alpha channel is treated identically to other channels. +// * colorspace is linear or sRGB as specified by function name +// * returned result is 1 for success or 0 in case of an error. +// #define STBIR_ASSERT() to trigger an assert on parameter validation errors. +// * Memory required grows approximately linearly with input and output size, but with +// discontinuities at input_w == output_w and input_h == output_h. +// * These functions use a "default" resampling filter defined at compile time. To change the filter, +// you can change the compile-time defaults by #defining STBIR_DEFAULT_FILTER_UPSAMPLE +// and STBIR_DEFAULT_FILTER_DOWNSAMPLE, or you can use the medium-complexity API. 
+ +STBIRDEF int stbir_resize_uint8( const unsigned char *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + unsigned char *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + int num_channels); + +STBIRDEF int stbir_resize_float( const float *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + float *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + int num_channels); + + +// The following functions interpret image data as gamma-corrected sRGB. +// Specify STBIR_ALPHA_CHANNEL_NONE if you have no alpha channel, +// or otherwise provide the index of the alpha channel. Flags value +// of 0 will probably do the right thing if you're not sure what +// the flags mean. + +#define STBIR_ALPHA_CHANNEL_NONE -1 + +// Set this flag if your texture has premultiplied alpha. Otherwise, stbir will +// use alpha-weighted resampling (effectively premultiplying, resampling, +// then unpremultiplying). +#define STBIR_FLAG_ALPHA_PREMULTIPLIED (1 << 0) +// The specified alpha channel should be handled as gamma-corrected value even +// when doing sRGB operations. +#define STBIR_FLAG_ALPHA_USES_COLORSPACE (1 << 1) + +STBIRDEF int stbir_resize_uint8_srgb(const unsigned char *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + unsigned char *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + int num_channels, int alpha_channel, int flags); + + +typedef enum +{ + STBIR_EDGE_CLAMP = 1, + STBIR_EDGE_REFLECT = 2, + STBIR_EDGE_WRAP = 3, + STBIR_EDGE_ZERO = 4, +} stbir_edge; + +// This function adds the ability to specify how requests to sample off the edge of the image are handled. +STBIRDEF int stbir_resize_uint8_srgb_edgemode(const unsigned char *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + unsigned char *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_wrap_mode); + +////////////////////////////////////////////////////////////////////////////// +// +// Medium-complexity API +// +// This extends the easy-to-use API as follows: +// +// * Alpha-channel can be processed separately +// * If alpha_channel is not STBIR_ALPHA_CHANNEL_NONE +// * Alpha channel will not be gamma corrected (unless flags&STBIR_FLAG_GAMMA_CORRECT) +// * Filters will be weighted by alpha channel (unless flags&STBIR_FLAG_ALPHA_PREMULTIPLIED) +// * Filter can be selected explicitly +// * uint16 image type +// * sRGB colorspace available for all types +// * context parameter for passing to STBIR_MALLOC + +typedef enum +{ + STBIR_FILTER_DEFAULT = 0, // use same filter type that easy-to-use API chooses + STBIR_FILTER_BOX = 1, // A trapezoid w/1-pixel wide ramps, same result as box for integer scale ratios + STBIR_FILTER_TRIANGLE = 2, // On upsampling, produces same results as bilinear texture filtering + STBIR_FILTER_CUBICBSPLINE = 3, // The cubic b-spline (aka Mitchell-Netrevalli with B=1,C=0), gaussian-esque + STBIR_FILTER_CATMULLROM = 4, // An interpolating cubic spline + STBIR_FILTER_MITCHELL = 5, // Mitchell-Netrevalli filter with B=1/3, C=1/3 +} stbir_filter; + +typedef enum +{ + STBIR_COLORSPACE_LINEAR, + STBIR_COLORSPACE_SRGB, + + STBIR_MAX_COLORSPACES, +} stbir_colorspace; + +// The following functions are all identical except for the type of the image data + +STBIRDEF int stbir_resize_uint8_generic( const unsigned char *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + unsigned char 
*output_pixels, int output_w, int output_h, int output_stride_in_bytes, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_wrap_mode, stbir_filter filter, stbir_colorspace space, + void *alloc_context); + +STBIRDEF int stbir_resize_uint16_generic(const stbir_uint16 *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + stbir_uint16 *output_pixels , int output_w, int output_h, int output_stride_in_bytes, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_wrap_mode, stbir_filter filter, stbir_colorspace space, + void *alloc_context); + +STBIRDEF int stbir_resize_float_generic( const float *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + float *output_pixels , int output_w, int output_h, int output_stride_in_bytes, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_wrap_mode, stbir_filter filter, stbir_colorspace space, + void *alloc_context); + + + +////////////////////////////////////////////////////////////////////////////// +// +// Full-complexity API +// +// This extends the medium API as follows: +// +// * uint32 image type +// * not typesafe +// * separate filter types for each axis +// * separate edge modes for each axis +// * can specify scale explicitly for subpixel correctness +// * can specify image source tile using texture coordinates + +typedef enum +{ + STBIR_TYPE_UINT8 , + STBIR_TYPE_UINT16, + STBIR_TYPE_UINT32, + STBIR_TYPE_FLOAT , + + STBIR_MAX_TYPES +} stbir_datatype; + +STBIRDEF int stbir_resize( const void *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + void *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + stbir_datatype datatype, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_mode_horizontal, stbir_edge edge_mode_vertical, + stbir_filter filter_horizontal, stbir_filter filter_vertical, + stbir_colorspace space, void *alloc_context); + +STBIRDEF int stbir_resize_subpixel(const void *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + void *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + stbir_datatype datatype, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_mode_horizontal, stbir_edge edge_mode_vertical, + stbir_filter filter_horizontal, stbir_filter filter_vertical, + stbir_colorspace space, void *alloc_context, + float x_scale, float y_scale, + float x_offset, float y_offset); + +STBIRDEF int stbir_resize_region( const void *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + void *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + stbir_datatype datatype, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_mode_horizontal, stbir_edge edge_mode_vertical, + stbir_filter filter_horizontal, stbir_filter filter_vertical, + stbir_colorspace space, void *alloc_context, + float s0, float t0, float s1, float t1); +// (s0, t0) & (s1, t1) are the top-left and bottom right corner (uv addressing style: [0, 1]x[0, 1]) of a region of the input image to use. + +// +// +//// end header file ///////////////////////////////////////////////////// +#endif // STBIR_INCLUDE_STB_IMAGE_RESIZE_H + + + + + +#ifdef STB_IMAGE_RESIZE_IMPLEMENTATION + +#ifndef STBIR_ASSERT +#include +#define STBIR_ASSERT(x) assert(x) +#endif + +#ifdef STBIR_DEBUG +#define STBIR__DEBUG_ASSERT STBIR_ASSERT +#else +#define STBIR__DEBUG_ASSERT +#endif + +// If you hit this it means I haven't done it yet. 
+#define STBIR__UNIMPLEMENTED(x) STBIR_ASSERT(!(x)) + +// For memset +#include + +#include + +#ifndef STBIR_MALLOC +#include +#define STBIR_MALLOC(size,c) malloc(size) +#define STBIR_FREE(ptr,c) free(ptr) +#endif + +#ifndef _MSC_VER +#ifdef __cplusplus +#define stbir__inline inline +#else +#define stbir__inline +#endif +#else +#define stbir__inline __forceinline +#endif + + +// should produce compiler error if size is wrong +typedef unsigned char stbir__validate_uint32[sizeof(stbir_uint32) == 4 ? 1 : -1]; + +#ifdef _MSC_VER +#define STBIR__NOTUSED(v) (void)(v) +#else +#define STBIR__NOTUSED(v) (void)sizeof(v) +#endif + +#define STBIR__ARRAY_SIZE(a) (sizeof((a))/sizeof((a)[0])) + +#ifndef STBIR_DEFAULT_FILTER_UPSAMPLE +#define STBIR_DEFAULT_FILTER_UPSAMPLE STBIR_FILTER_CATMULLROM +#endif + +#ifndef STBIR_DEFAULT_FILTER_DOWNSAMPLE +#define STBIR_DEFAULT_FILTER_DOWNSAMPLE STBIR_FILTER_MITCHELL +#endif + +#ifndef STBIR_PROGRESS_REPORT +#define STBIR_PROGRESS_REPORT(float_0_to_1) +#endif + +#ifndef STBIR_MAX_CHANNELS +#define STBIR_MAX_CHANNELS 64 +#endif + +#if STBIR_MAX_CHANNELS > 65536 +#error "Too many channels; STBIR_MAX_CHANNELS must be no more than 65536." +// because we store the indices in 16-bit variables +#endif + +// This value is added to alpha just before premultiplication to avoid +// zeroing out color values. It is equivalent to 2^-80. If you don't want +// that behavior (it may interfere if you have floating point images with +// very small alpha values) then you can define STBIR_NO_ALPHA_EPSILON to +// disable it. +#ifndef STBIR_ALPHA_EPSILON +#define STBIR_ALPHA_EPSILON ((float)1 / (1 << 20) / (1 << 20) / (1 << 20) / (1 << 20)) +#endif + + + +#ifdef _MSC_VER +#define STBIR__UNUSED_PARAM(v) (void)(v) +#else +#define STBIR__UNUSED_PARAM(v) (void)sizeof(v) +#endif + +// must match stbir_datatype +static unsigned char stbir__type_size[] = { + 1, // STBIR_TYPE_UINT8 + 2, // STBIR_TYPE_UINT16 + 4, // STBIR_TYPE_UINT32 + 4, // STBIR_TYPE_FLOAT +}; + +// Kernel function centered at 0 +typedef float (stbir__kernel_fn)(float x, float scale); +typedef float (stbir__support_fn)(float scale); + +typedef struct +{ + stbir__kernel_fn* kernel; + stbir__support_fn* support; +} stbir__filter_info; + +// When upsampling, the contributors are which source pixels contribute. +// When downsampling, the contributors are which destination pixels are contributed to. 
+typedef struct +{ + int n0; // First contributing pixel + int n1; // Last contributing pixel +} stbir__contributors; + +typedef struct +{ + const void* input_data; + int input_w; + int input_h; + int input_stride_bytes; + + void* output_data; + int output_w; + int output_h; + int output_stride_bytes; + + float s0, t0, s1, t1; + + float horizontal_shift; // Units: output pixels + float vertical_shift; // Units: output pixels + float horizontal_scale; + float vertical_scale; + + int channels; + int alpha_channel; + stbir_uint32 flags; + stbir_datatype type; + stbir_filter horizontal_filter; + stbir_filter vertical_filter; + stbir_edge edge_horizontal; + stbir_edge edge_vertical; + stbir_colorspace colorspace; + + stbir__contributors* horizontal_contributors; + float* horizontal_coefficients; + + stbir__contributors* vertical_contributors; + float* vertical_coefficients; + + int decode_buffer_pixels; + float* decode_buffer; + + float* horizontal_buffer; + + // cache these because ceil/floor are inexplicably showing up in profile + int horizontal_coefficient_width; + int vertical_coefficient_width; + int horizontal_filter_pixel_width; + int vertical_filter_pixel_width; + int horizontal_filter_pixel_margin; + int vertical_filter_pixel_margin; + int horizontal_num_contributors; + int vertical_num_contributors; + + int ring_buffer_length_bytes; // The length of an individual entry in the ring buffer. The total number of ring buffers is stbir__get_filter_pixel_width(filter) + int ring_buffer_first_scanline; + int ring_buffer_last_scanline; + int ring_buffer_begin_index; + float* ring_buffer; + + float* encode_buffer; // A temporary buffer to store floats so we don't lose precision while we do multiply-adds. + + int horizontal_contributors_size; + int horizontal_coefficients_size; + int vertical_contributors_size; + int vertical_coefficients_size; + int decode_buffer_size; + int horizontal_buffer_size; + int ring_buffer_size; + int encode_buffer_size; +} stbir__info; + +static stbir__inline int stbir__min(int a, int b) +{ + return a < b ? a : b; +} + +static stbir__inline int stbir__max(int a, int b) +{ + return a > b ? 
a : b; +} + +static stbir__inline float stbir__saturate(float x) +{ + if (x < 0) + return 0; + + if (x > 1) + return 1; + + return x; +} + +#ifdef STBIR_SATURATE_INT +static stbir__inline stbir_uint8 stbir__saturate8(int x) +{ + if ((unsigned int) x <= 255) + return x; + + if (x < 0) + return 0; + + return 255; +} + +static stbir__inline stbir_uint16 stbir__saturate16(int x) +{ + if ((unsigned int) x <= 65535) + return x; + + if (x < 0) + return 0; + + return 65535; +} +#endif + +static float stbir__srgb_uchar_to_linear_float[256] = { + 0.000000f, 0.000304f, 0.000607f, 0.000911f, 0.001214f, 0.001518f, 0.001821f, 0.002125f, 0.002428f, 0.002732f, 0.003035f, + 0.003347f, 0.003677f, 0.004025f, 0.004391f, 0.004777f, 0.005182f, 0.005605f, 0.006049f, 0.006512f, 0.006995f, 0.007499f, + 0.008023f, 0.008568f, 0.009134f, 0.009721f, 0.010330f, 0.010960f, 0.011612f, 0.012286f, 0.012983f, 0.013702f, 0.014444f, + 0.015209f, 0.015996f, 0.016807f, 0.017642f, 0.018500f, 0.019382f, 0.020289f, 0.021219f, 0.022174f, 0.023153f, 0.024158f, + 0.025187f, 0.026241f, 0.027321f, 0.028426f, 0.029557f, 0.030713f, 0.031896f, 0.033105f, 0.034340f, 0.035601f, 0.036889f, + 0.038204f, 0.039546f, 0.040915f, 0.042311f, 0.043735f, 0.045186f, 0.046665f, 0.048172f, 0.049707f, 0.051269f, 0.052861f, + 0.054480f, 0.056128f, 0.057805f, 0.059511f, 0.061246f, 0.063010f, 0.064803f, 0.066626f, 0.068478f, 0.070360f, 0.072272f, + 0.074214f, 0.076185f, 0.078187f, 0.080220f, 0.082283f, 0.084376f, 0.086500f, 0.088656f, 0.090842f, 0.093059f, 0.095307f, + 0.097587f, 0.099899f, 0.102242f, 0.104616f, 0.107023f, 0.109462f, 0.111932f, 0.114435f, 0.116971f, 0.119538f, 0.122139f, + 0.124772f, 0.127438f, 0.130136f, 0.132868f, 0.135633f, 0.138432f, 0.141263f, 0.144128f, 0.147027f, 0.149960f, 0.152926f, + 0.155926f, 0.158961f, 0.162029f, 0.165132f, 0.168269f, 0.171441f, 0.174647f, 0.177888f, 0.181164f, 0.184475f, 0.187821f, + 0.191202f, 0.194618f, 0.198069f, 0.201556f, 0.205079f, 0.208637f, 0.212231f, 0.215861f, 0.219526f, 0.223228f, 0.226966f, + 0.230740f, 0.234551f, 0.238398f, 0.242281f, 0.246201f, 0.250158f, 0.254152f, 0.258183f, 0.262251f, 0.266356f, 0.270498f, + 0.274677f, 0.278894f, 0.283149f, 0.287441f, 0.291771f, 0.296138f, 0.300544f, 0.304987f, 0.309469f, 0.313989f, 0.318547f, + 0.323143f, 0.327778f, 0.332452f, 0.337164f, 0.341914f, 0.346704f, 0.351533f, 0.356400f, 0.361307f, 0.366253f, 0.371238f, + 0.376262f, 0.381326f, 0.386430f, 0.391573f, 0.396755f, 0.401978f, 0.407240f, 0.412543f, 0.417885f, 0.423268f, 0.428691f, + 0.434154f, 0.439657f, 0.445201f, 0.450786f, 0.456411f, 0.462077f, 0.467784f, 0.473532f, 0.479320f, 0.485150f, 0.491021f, + 0.496933f, 0.502887f, 0.508881f, 0.514918f, 0.520996f, 0.527115f, 0.533276f, 0.539480f, 0.545725f, 0.552011f, 0.558340f, + 0.564712f, 0.571125f, 0.577581f, 0.584078f, 0.590619f, 0.597202f, 0.603827f, 0.610496f, 0.617207f, 0.623960f, 0.630757f, + 0.637597f, 0.644480f, 0.651406f, 0.658375f, 0.665387f, 0.672443f, 0.679543f, 0.686685f, 0.693872f, 0.701102f, 0.708376f, + 0.715694f, 0.723055f, 0.730461f, 0.737911f, 0.745404f, 0.752942f, 0.760525f, 0.768151f, 0.775822f, 0.783538f, 0.791298f, + 0.799103f, 0.806952f, 0.814847f, 0.822786f, 0.830770f, 0.838799f, 0.846873f, 0.854993f, 0.863157f, 0.871367f, 0.879622f, + 0.887923f, 0.896269f, 0.904661f, 0.913099f, 0.921582f, 0.930111f, 0.938686f, 0.947307f, 0.955974f, 0.964686f, 0.973445f, + 0.982251f, 0.991102f, 1.0f +}; + +static float stbir__srgb_to_linear(float f) +{ + if (f <= 0.04045f) + return f / 12.92f; + else + return (float)pow((f + 0.055f) / 1.055f, 2.4f); +} 
+ +static float stbir__linear_to_srgb(float f) +{ + if (f <= 0.0031308f) + return f * 12.92f; + else + return 1.055f * (float)pow(f, 1 / 2.4f) - 0.055f; +} + +#ifndef STBIR_NON_IEEE_FLOAT +// From https://gist.github.com/rygorous/2203834 + +typedef union +{ + stbir_uint32 u; + float f; +} stbir__FP32; + +static const stbir_uint32 fp32_to_srgb8_tab4[104] = { + 0x0073000d, 0x007a000d, 0x0080000d, 0x0087000d, 0x008d000d, 0x0094000d, 0x009a000d, 0x00a1000d, + 0x00a7001a, 0x00b4001a, 0x00c1001a, 0x00ce001a, 0x00da001a, 0x00e7001a, 0x00f4001a, 0x0101001a, + 0x010e0033, 0x01280033, 0x01410033, 0x015b0033, 0x01750033, 0x018f0033, 0x01a80033, 0x01c20033, + 0x01dc0067, 0x020f0067, 0x02430067, 0x02760067, 0x02aa0067, 0x02dd0067, 0x03110067, 0x03440067, + 0x037800ce, 0x03df00ce, 0x044600ce, 0x04ad00ce, 0x051400ce, 0x057b00c5, 0x05dd00bc, 0x063b00b5, + 0x06970158, 0x07420142, 0x07e30130, 0x087b0120, 0x090b0112, 0x09940106, 0x0a1700fc, 0x0a9500f2, + 0x0b0f01cb, 0x0bf401ae, 0x0ccb0195, 0x0d950180, 0x0e56016e, 0x0f0d015e, 0x0fbc0150, 0x10630143, + 0x11070264, 0x1238023e, 0x1357021d, 0x14660201, 0x156601e9, 0x165a01d3, 0x174401c0, 0x182401af, + 0x18fe0331, 0x1a9602fe, 0x1c1502d2, 0x1d7e02ad, 0x1ed4028d, 0x201a0270, 0x21520256, 0x227d0240, + 0x239f0443, 0x25c003fe, 0x27bf03c4, 0x29a10392, 0x2b6a0367, 0x2d1d0341, 0x2ebe031f, 0x304d0300, + 0x31d105b0, 0x34a80555, 0x37520507, 0x39d504c5, 0x3c37048b, 0x3e7c0458, 0x40a8042a, 0x42bd0401, + 0x44c20798, 0x488e071e, 0x4c1c06b6, 0x4f76065d, 0x52a50610, 0x55ac05cc, 0x5892058f, 0x5b590559, + 0x5e0c0a23, 0x631c0980, 0x67db08f6, 0x6c55087f, 0x70940818, 0x74a007bd, 0x787d076c, 0x7c330723, +}; + +static stbir_uint8 stbir__linear_to_srgb_uchar(float in) +{ + static const stbir__FP32 almostone = { 0x3f7fffff }; // 1-eps + static const stbir__FP32 minval = { (127-13) << 23 }; + stbir_uint32 tab,bias,scale,t; + stbir__FP32 f; + + // Clamp to [2^(-13), 1-eps]; these two values map to 0 and 1, respectively. + // The tests are carefully written so that NaNs map to 0, same as in the reference + // implementation. 
+ if (!(in > minval.f)) // written this way to catch NaNs + in = minval.f; + if (in > almostone.f) + in = almostone.f; + + // Do the table lookup and unpack bias, scale + f.f = in; + tab = fp32_to_srgb8_tab4[(f.u - minval.u) >> 20]; + bias = (tab >> 16) << 9; + scale = tab & 0xffff; + + // Grab next-highest mantissa bits and perform linear interpolation + t = (f.u >> 12) & 0xff; + return (unsigned char) ((bias + scale*t) >> 16); +} + +#else +// sRGB transition values, scaled by 1<<28 +static int stbir__srgb_offset_to_linear_scaled[256] = +{ + 0, 40738, 122216, 203693, 285170, 366648, 448125, 529603, + 611080, 692557, 774035, 855852, 942009, 1033024, 1128971, 1229926, + 1335959, 1447142, 1563542, 1685229, 1812268, 1944725, 2082664, 2226148, + 2375238, 2529996, 2690481, 2856753, 3028870, 3206888, 3390865, 3580856, + 3776916, 3979100, 4187460, 4402049, 4622919, 4850123, 5083710, 5323731, + 5570236, 5823273, 6082892, 6349140, 6622065, 6901714, 7188133, 7481369, + 7781466, 8088471, 8402427, 8723380, 9051372, 9386448, 9728650, 10078021, + 10434603, 10798439, 11169569, 11548036, 11933879, 12327139, 12727857, 13136073, + 13551826, 13975156, 14406100, 14844697, 15290987, 15745007, 16206795, 16676389, + 17153826, 17639142, 18132374, 18633560, 19142734, 19659934, 20185196, 20718552, + 21260042, 21809696, 22367554, 22933648, 23508010, 24090680, 24681686, 25281066, + 25888850, 26505076, 27129772, 27762974, 28404716, 29055026, 29713942, 30381490, + 31057708, 31742624, 32436272, 33138682, 33849884, 34569912, 35298800, 36036568, + 36783260, 37538896, 38303512, 39077136, 39859796, 40651528, 41452360, 42262316, + 43081432, 43909732, 44747252, 45594016, 46450052, 47315392, 48190064, 49074096, + 49967516, 50870356, 51782636, 52704392, 53635648, 54576432, 55526772, 56486700, + 57456236, 58435408, 59424248, 60422780, 61431036, 62449032, 63476804, 64514376, + 65561776, 66619028, 67686160, 68763192, 69850160, 70947088, 72053992, 73170912, + 74297864, 75434880, 76581976, 77739184, 78906536, 80084040, 81271736, 82469648, + 83677792, 84896192, 86124888, 87363888, 88613232, 89872928, 91143016, 92423512, + 93714432, 95015816, 96327688, 97650056, 98982952, 100326408, 101680440, 103045072, + 104420320, 105806224, 107202800, 108610064, 110028048, 111456776, 112896264, 114346544, + 115807632, 117279552, 118762328, 120255976, 121760536, 123276016, 124802440, 126339832, + 127888216, 129447616, 131018048, 132599544, 134192112, 135795792, 137410592, 139036528, + 140673648, 142321952, 143981456, 145652208, 147334208, 149027488, 150732064, 152447968, + 154175200, 155913792, 157663776, 159425168, 161197984, 162982240, 164777968, 166585184, + 168403904, 170234160, 172075968, 173929344, 175794320, 177670896, 179559120, 181458992, + 183370528, 185293776, 187228736, 189175424, 191133888, 193104112, 195086128, 197079968, + 199085648, 201103184, 203132592, 205173888, 207227120, 209292272, 211369392, 213458480, + 215559568, 217672656, 219797792, 221934976, 224084240, 226245600, 228419056, 230604656, + 232802400, 235012320, 237234432, 239468736, 241715280, 243974080, 246245120, 248528464, + 250824112, 253132064, 255452368, 257785040, 260130080, 262487520, 264857376, 267239664, +}; + +static stbir_uint8 stbir__linear_to_srgb_uchar(float f) +{ + int x = (int) (f * (1 << 28)); // has headroom so you don't need to clamp + int v = 0; + int i; + + // Refine the guess with a short binary search. 
+ i = v + 128; if (x >= stbir__srgb_offset_to_linear_scaled[i]) v = i; + i = v + 64; if (x >= stbir__srgb_offset_to_linear_scaled[i]) v = i; + i = v + 32; if (x >= stbir__srgb_offset_to_linear_scaled[i]) v = i; + i = v + 16; if (x >= stbir__srgb_offset_to_linear_scaled[i]) v = i; + i = v + 8; if (x >= stbir__srgb_offset_to_linear_scaled[i]) v = i; + i = v + 4; if (x >= stbir__srgb_offset_to_linear_scaled[i]) v = i; + i = v + 2; if (x >= stbir__srgb_offset_to_linear_scaled[i]) v = i; + i = v + 1; if (x >= stbir__srgb_offset_to_linear_scaled[i]) v = i; + + return (stbir_uint8) v; +} +#endif + +static float stbir__filter_trapezoid(float x, float scale) +{ + float halfscale = scale / 2; + float t = 0.5f + halfscale; + STBIR__DEBUG_ASSERT(scale <= 1); + + x = (float)fabs(x); + + if (x >= t) + return 0; + else + { + float r = 0.5f - halfscale; + if (x <= r) + return 1; + else + return (t - x) / scale; + } +} + +static float stbir__support_trapezoid(float scale) +{ + STBIR__DEBUG_ASSERT(scale <= 1); + return 0.5f + scale / 2; +} + +static float stbir__filter_triangle(float x, float s) +{ + STBIR__UNUSED_PARAM(s); + + x = (float)fabs(x); + + if (x <= 1.0f) + return 1 - x; + else + return 0; +} + +static float stbir__filter_cubic(float x, float s) +{ + STBIR__UNUSED_PARAM(s); + + x = (float)fabs(x); + + if (x < 1.0f) + return (4 + x*x*(3*x - 6))/6; + else if (x < 2.0f) + return (8 + x*(-12 + x*(6 - x)))/6; + + return (0.0f); +} + +static float stbir__filter_catmullrom(float x, float s) +{ + STBIR__UNUSED_PARAM(s); + + x = (float)fabs(x); + + if (x < 1.0f) + return 1 - x*x*(2.5f - 1.5f*x); + else if (x < 2.0f) + return 2 - x*(4 + x*(0.5f*x - 2.5f)); + + return (0.0f); +} + +static float stbir__filter_mitchell(float x, float s) +{ + STBIR__UNUSED_PARAM(s); + + x = (float)fabs(x); + + if (x < 1.0f) + return (16 + x*x*(21 * x - 36))/18; + else if (x < 2.0f) + return (32 + x*(-60 + x*(36 - 7*x)))/18; + + return (0.0f); +} + +static float stbir__support_zero(float s) +{ + STBIR__UNUSED_PARAM(s); + return 0; +} + +static float stbir__support_one(float s) +{ + STBIR__UNUSED_PARAM(s); + return 1; +} + +static float stbir__support_two(float s) +{ + STBIR__UNUSED_PARAM(s); + return 2; +} + +static stbir__filter_info stbir__filter_info_table[] = { + { NULL, stbir__support_zero }, + { stbir__filter_trapezoid, stbir__support_trapezoid }, + { stbir__filter_triangle, stbir__support_one }, + { stbir__filter_cubic, stbir__support_two }, + { stbir__filter_catmullrom, stbir__support_two }, + { stbir__filter_mitchell, stbir__support_two }, +}; + +stbir__inline static int stbir__use_upsampling(float ratio) +{ + return ratio > 1; +} + +stbir__inline static int stbir__use_width_upsampling(stbir__info* stbir_info) +{ + return stbir__use_upsampling(stbir_info->horizontal_scale); +} + +stbir__inline static int stbir__use_height_upsampling(stbir__info* stbir_info) +{ + return stbir__use_upsampling(stbir_info->vertical_scale); +} + +// This is the maximum number of input samples that can affect an output sample +// with the given filter +static int stbir__get_filter_pixel_width(stbir_filter filter, float scale) +{ + STBIR_ASSERT(filter != 0); + STBIR_ASSERT(filter < STBIR__ARRAY_SIZE(stbir__filter_info_table)); + + if (stbir__use_upsampling(scale)) + return (int)ceil(stbir__filter_info_table[filter].support(1/scale) * 2); + else + return (int)ceil(stbir__filter_info_table[filter].support(scale) * 2 / scale); +} + +// This is how much to expand buffers to account for filters seeking outside +// the image boundaries. 
+static int stbir__get_filter_pixel_margin(stbir_filter filter, float scale) +{ + return stbir__get_filter_pixel_width(filter, scale) / 2; +} + +static int stbir__get_coefficient_width(stbir_filter filter, float scale) +{ + if (stbir__use_upsampling(scale)) + return (int)ceil(stbir__filter_info_table[filter].support(1 / scale) * 2); + else + return (int)ceil(stbir__filter_info_table[filter].support(scale) * 2); +} + +static int stbir__get_contributors(float scale, stbir_filter filter, int input_size, int output_size) +{ + if (stbir__use_upsampling(scale)) + return output_size; + else + return (input_size + stbir__get_filter_pixel_margin(filter, scale) * 2); +} + +static int stbir__get_total_horizontal_coefficients(stbir__info* info) +{ + return info->horizontal_num_contributors + * stbir__get_coefficient_width (info->horizontal_filter, info->horizontal_scale); +} + +static int stbir__get_total_vertical_coefficients(stbir__info* info) +{ + return info->vertical_num_contributors + * stbir__get_coefficient_width (info->vertical_filter, info->vertical_scale); +} + +static stbir__contributors* stbir__get_contributor(stbir__contributors* contributors, int n) +{ + return &contributors[n]; +} + +// For perf reasons this code is duplicated in stbir__resample_horizontal_upsample/downsample, +// if you change it here change it there too. +static float* stbir__get_coefficient(float* coefficients, stbir_filter filter, float scale, int n, int c) +{ + int width = stbir__get_coefficient_width(filter, scale); + return &coefficients[width*n + c]; +} + +static int stbir__edge_wrap_slow(stbir_edge edge, int n, int max) +{ + switch (edge) + { + case STBIR_EDGE_ZERO: + return 0; // we'll decode the wrong pixel here, and then overwrite with 0s later + + case STBIR_EDGE_CLAMP: + if (n < 0) + return 0; + + if (n >= max) + return max - 1; + + return n; // NOTREACHED + + case STBIR_EDGE_REFLECT: + { + if (n < 0) + { + if (n < max) + return -n; + else + return max - 1; + } + + if (n >= max) + { + int max2 = max * 2; + if (n >= max2) + return 0; + else + return max2 - n - 1; + } + + return n; // NOTREACHED + } + + case STBIR_EDGE_WRAP: + if (n >= 0) + return (n % max); + else + { + int m = (-n) % max; + + if (m != 0) + m = max - m; + + return (m); + } + return n; // NOTREACHED + + default: + STBIR__UNIMPLEMENTED("Unimplemented edge type"); + return 0; + } +} + +stbir__inline static int stbir__edge_wrap(stbir_edge edge, int n, int max) +{ + // avoid per-pixel switch + if (n >= 0 && n < max) + return n; + return stbir__edge_wrap_slow(edge, n, max); +} + +// What input pixels contribute to this output pixel? +static void stbir__calculate_sample_range_upsample(int n, float out_filter_radius, float scale_ratio, float out_shift, int* in_first_pixel, int* in_last_pixel, float* in_center_of_out) +{ + float out_pixel_center = (float)n + 0.5f; + float out_pixel_influence_lowerbound = out_pixel_center - out_filter_radius; + float out_pixel_influence_upperbound = out_pixel_center + out_filter_radius; + + float in_pixel_influence_lowerbound = (out_pixel_influence_lowerbound + out_shift) / scale_ratio; + float in_pixel_influence_upperbound = (out_pixel_influence_upperbound + out_shift) / scale_ratio; + + *in_center_of_out = (out_pixel_center + out_shift) / scale_ratio; + *in_first_pixel = (int)(floor(in_pixel_influence_lowerbound + 0.5)); + *in_last_pixel = (int)(floor(in_pixel_influence_upperbound - 0.5)); +} + +// What output pixels does this input pixel contribute to? 
+static void stbir__calculate_sample_range_downsample(int n, float in_pixels_radius, float scale_ratio, float out_shift, int* out_first_pixel, int* out_last_pixel, float* out_center_of_in) +{ + float in_pixel_center = (float)n + 0.5f; + float in_pixel_influence_lowerbound = in_pixel_center - in_pixels_radius; + float in_pixel_influence_upperbound = in_pixel_center + in_pixels_radius; + + float out_pixel_influence_lowerbound = in_pixel_influence_lowerbound * scale_ratio - out_shift; + float out_pixel_influence_upperbound = in_pixel_influence_upperbound * scale_ratio - out_shift; + + *out_center_of_in = in_pixel_center * scale_ratio - out_shift; + *out_first_pixel = (int)(floor(out_pixel_influence_lowerbound + 0.5)); + *out_last_pixel = (int)(floor(out_pixel_influence_upperbound - 0.5)); +} + +static void stbir__calculate_coefficients_upsample(stbir__info* stbir_info, stbir_filter filter, float scale, int in_first_pixel, int in_last_pixel, float in_center_of_out, stbir__contributors* contributor, float* coefficient_group) +{ + int i; + float total_filter = 0; + float filter_scale; + + STBIR__DEBUG_ASSERT(in_last_pixel - in_first_pixel <= (int)ceil(stbir__filter_info_table[filter].support(1/scale) * 2)); // Taken directly from stbir__get_coefficient_width() which we can't call because we don't know if we're horizontal or vertical. + + contributor->n0 = in_first_pixel; + contributor->n1 = in_last_pixel; + + STBIR__DEBUG_ASSERT(contributor->n1 >= contributor->n0); + + for (i = 0; i <= in_last_pixel - in_first_pixel; i++) + { + float in_pixel_center = (float)(i + in_first_pixel) + 0.5f; + coefficient_group[i] = stbir__filter_info_table[filter].kernel(in_center_of_out - in_pixel_center, 1 / scale); + + // If the coefficient is zero, skip it. (Don't do the <0 check here, we want the influence of those outside pixels.) + if (i == 0 && !coefficient_group[i]) + { + contributor->n0 = ++in_first_pixel; + i--; + continue; + } + + total_filter += coefficient_group[i]; + } + + STBIR__DEBUG_ASSERT(stbir__filter_info_table[filter].kernel((float)(in_last_pixel + 1) + 0.5f - in_center_of_out, 1/scale) == 0); + + STBIR__DEBUG_ASSERT(total_filter > 0.9); + STBIR__DEBUG_ASSERT(total_filter < 1.1f); // Make sure it's not way off. + + // Make sure the sum of all coefficients is 1. + filter_scale = 1 / total_filter; + + for (i = 0; i <= in_last_pixel - in_first_pixel; i++) + coefficient_group[i] *= filter_scale; + + for (i = in_last_pixel - in_first_pixel; i >= 0; i--) + { + if (coefficient_group[i]) + break; + + // This line has no weight. We can skip it. + contributor->n1 = contributor->n0 + i - 1; + } +} + +static void stbir__calculate_coefficients_downsample(stbir__info* stbir_info, stbir_filter filter, float scale_ratio, int out_first_pixel, int out_last_pixel, float out_center_of_in, stbir__contributors* contributor, float* coefficient_group) +{ + int i; + + STBIR__DEBUG_ASSERT(out_last_pixel - out_first_pixel <= (int)ceil(stbir__filter_info_table[filter].support(scale_ratio) * 2)); // Taken directly from stbir__get_coefficient_width() which we can't call because we don't know if we're horizontal or vertical. 
+ + contributor->n0 = out_first_pixel; + contributor->n1 = out_last_pixel; + + STBIR__DEBUG_ASSERT(contributor->n1 >= contributor->n0); + + for (i = 0; i <= out_last_pixel - out_first_pixel; i++) + { + float out_pixel_center = (float)(i + out_first_pixel) + 0.5f; + float x = out_pixel_center - out_center_of_in; + coefficient_group[i] = stbir__filter_info_table[filter].kernel(x, scale_ratio) * scale_ratio; + } + + STBIR__DEBUG_ASSERT(stbir__filter_info_table[filter].kernel((float)(out_last_pixel + 1) + 0.5f - out_center_of_in, scale_ratio) == 0); + + for (i = out_last_pixel - out_first_pixel; i >= 0; i--) + { + if (coefficient_group[i]) + break; + + // This line has no weight. We can skip it. + contributor->n1 = contributor->n0 + i - 1; + } +} + +static void stbir__normalize_downsample_coefficients(stbir__info* stbir_info, stbir__contributors* contributors, float* coefficients, stbir_filter filter, float scale_ratio, float shift, int input_size, int output_size) +{ + int num_contributors = stbir__get_contributors(scale_ratio, filter, input_size, output_size); + int num_coefficients = stbir__get_coefficient_width(filter, scale_ratio); + int i, j; + int skip; + + for (i = 0; i < output_size; i++) + { + float scale; + float total = 0; + + for (j = 0; j < num_contributors; j++) + { + if (i >= contributors[j].n0 && i <= contributors[j].n1) + { + float coefficient = *stbir__get_coefficient(coefficients, filter, scale_ratio, j, i - contributors[j].n0); + total += coefficient; + } + else if (i < contributors[j].n0) + break; + } + + STBIR__DEBUG_ASSERT(total > 0.9f); + STBIR__DEBUG_ASSERT(total < 1.1f); + + scale = 1 / total; + + for (j = 0; j < num_contributors; j++) + { + if (i >= contributors[j].n0 && i <= contributors[j].n1) + *stbir__get_coefficient(coefficients, filter, scale_ratio, j, i - contributors[j].n0) *= scale; + else if (i < contributors[j].n0) + break; + } + } + + // Optimize: Skip zero coefficients and contributions outside of image bounds. + // Do this after normalizing because normalization depends on the n0/n1 values. + for (j = 0; j < num_contributors; j++) + { + int range, max, width; + + skip = 0; + while (*stbir__get_coefficient(coefficients, filter, scale_ratio, j, skip) == 0) + skip++; + + contributors[j].n0 += skip; + + while (contributors[j].n0 < 0) + { + contributors[j].n0++; + skip++; + } + + range = contributors[j].n1 - contributors[j].n0 + 1; + max = stbir__min(num_coefficients, range); + + width = stbir__get_coefficient_width(filter, scale_ratio); + for (i = 0; i < max; i++) + { + if (i + skip >= width) + break; + + *stbir__get_coefficient(coefficients, filter, scale_ratio, j, i) = *stbir__get_coefficient(coefficients, filter, scale_ratio, j, i + skip); + } + + continue; + } + + // Using min to avoid writing into invalid pixels. + for (i = 0; i < num_contributors; i++) + contributors[i].n1 = stbir__min(contributors[i].n1, output_size - 1); +} + +// Each scan line uses the same kernel values so we should calculate the kernel +// values once and then we can use them for every scan line. 
+static void stbir__calculate_filters(stbir__info* stbir_info, stbir__contributors* contributors, float* coefficients, stbir_filter filter, float scale_ratio, float shift, int input_size, int output_size) +{ + int n; + int total_contributors = stbir__get_contributors(scale_ratio, filter, input_size, output_size); + + if (stbir__use_upsampling(scale_ratio)) + { + float out_pixels_radius = stbir__filter_info_table[filter].support(1 / scale_ratio) * scale_ratio; + + // Looping through out pixels + for (n = 0; n < total_contributors; n++) + { + float in_center_of_out; // Center of the current out pixel in the in pixel space + int in_first_pixel, in_last_pixel; + + stbir__calculate_sample_range_upsample(n, out_pixels_radius, scale_ratio, shift, &in_first_pixel, &in_last_pixel, &in_center_of_out); + + stbir__calculate_coefficients_upsample(stbir_info, filter, scale_ratio, in_first_pixel, in_last_pixel, in_center_of_out, stbir__get_contributor(contributors, n), stbir__get_coefficient(coefficients, filter, scale_ratio, n, 0)); + } + } + else + { + float in_pixels_radius = stbir__filter_info_table[filter].support(scale_ratio) / scale_ratio; + + // Looping through in pixels + for (n = 0; n < total_contributors; n++) + { + float out_center_of_in; // Center of the current out pixel in the in pixel space + int out_first_pixel, out_last_pixel; + int n_adjusted = n - stbir__get_filter_pixel_margin(filter, scale_ratio); + + stbir__calculate_sample_range_downsample(n_adjusted, in_pixels_radius, scale_ratio, shift, &out_first_pixel, &out_last_pixel, &out_center_of_in); + + stbir__calculate_coefficients_downsample(stbir_info, filter, scale_ratio, out_first_pixel, out_last_pixel, out_center_of_in, stbir__get_contributor(contributors, n), stbir__get_coefficient(coefficients, filter, scale_ratio, n, 0)); + } + + stbir__normalize_downsample_coefficients(stbir_info, contributors, coefficients, filter, scale_ratio, shift, input_size, output_size); + } +} + +static float* stbir__get_decode_buffer(stbir__info* stbir_info) +{ + // The 0 index of the decode buffer starts after the margin. This makes + // it okay to use negative indexes on the decode buffer. 
+ return &stbir_info->decode_buffer[stbir_info->horizontal_filter_pixel_margin * stbir_info->channels]; +} + +#define STBIR__DECODE(type, colorspace) ((type) * (STBIR_MAX_COLORSPACES) + (colorspace)) + +static void stbir__decode_scanline(stbir__info* stbir_info, int n) +{ + int c; + int channels = stbir_info->channels; + int alpha_channel = stbir_info->alpha_channel; + int type = stbir_info->type; + int colorspace = stbir_info->colorspace; + int input_w = stbir_info->input_w; + int input_stride_bytes = stbir_info->input_stride_bytes; + float* decode_buffer = stbir__get_decode_buffer(stbir_info); + stbir_edge edge_horizontal = stbir_info->edge_horizontal; + stbir_edge edge_vertical = stbir_info->edge_vertical; + int in_buffer_row_offset = stbir__edge_wrap(edge_vertical, n, stbir_info->input_h) * input_stride_bytes; + const void* input_data = (char *) stbir_info->input_data + in_buffer_row_offset; + int max_x = input_w + stbir_info->horizontal_filter_pixel_margin; + int decode = STBIR__DECODE(type, colorspace); + + int x = -stbir_info->horizontal_filter_pixel_margin; + + // special handling for STBIR_EDGE_ZERO because it needs to return an item that doesn't appear in the input, + // and we want to avoid paying overhead on every pixel if not STBIR_EDGE_ZERO + if (edge_vertical == STBIR_EDGE_ZERO && (n < 0 || n >= stbir_info->input_h)) + { + for (; x < max_x; x++) + for (c = 0; c < channels; c++) + decode_buffer[x*channels + c] = 0; + return; + } + + switch (decode) + { + case STBIR__DECODE(STBIR_TYPE_UINT8, STBIR_COLORSPACE_LINEAR): + for (; x < max_x; x++) + { + int decode_pixel_index = x * channels; + int input_pixel_index = stbir__edge_wrap(edge_horizontal, x, input_w) * channels; + for (c = 0; c < channels; c++) + decode_buffer[decode_pixel_index + c] = ((float)((const unsigned char*)input_data)[input_pixel_index + c]) / 255; + } + break; + + case STBIR__DECODE(STBIR_TYPE_UINT8, STBIR_COLORSPACE_SRGB): + for (; x < max_x; x++) + { + int decode_pixel_index = x * channels; + int input_pixel_index = stbir__edge_wrap(edge_horizontal, x, input_w) * channels; + for (c = 0; c < channels; c++) + decode_buffer[decode_pixel_index + c] = stbir__srgb_uchar_to_linear_float[((const unsigned char*)input_data)[input_pixel_index + c]]; + + if (!(stbir_info->flags&STBIR_FLAG_ALPHA_USES_COLORSPACE)) + decode_buffer[decode_pixel_index + alpha_channel] = ((float)((const unsigned char*)input_data)[input_pixel_index + alpha_channel]) / 255; + } + break; + + case STBIR__DECODE(STBIR_TYPE_UINT16, STBIR_COLORSPACE_LINEAR): + for (; x < max_x; x++) + { + int decode_pixel_index = x * channels; + int input_pixel_index = stbir__edge_wrap(edge_horizontal, x, input_w) * channels; + for (c = 0; c < channels; c++) + decode_buffer[decode_pixel_index + c] = ((float)((const unsigned short*)input_data)[input_pixel_index + c]) / 65535; + } + break; + + case STBIR__DECODE(STBIR_TYPE_UINT16, STBIR_COLORSPACE_SRGB): + for (; x < max_x; x++) + { + int decode_pixel_index = x * channels; + int input_pixel_index = stbir__edge_wrap(edge_horizontal, x, input_w) * channels; + for (c = 0; c < channels; c++) + decode_buffer[decode_pixel_index + c] = stbir__srgb_to_linear(((float)((const unsigned short*)input_data)[input_pixel_index + c]) / 65535); + + if (!(stbir_info->flags&STBIR_FLAG_ALPHA_USES_COLORSPACE)) + decode_buffer[decode_pixel_index + alpha_channel] = ((float)((const unsigned short*)input_data)[input_pixel_index + alpha_channel]) / 65535; + } + break; + + case STBIR__DECODE(STBIR_TYPE_UINT32, STBIR_COLORSPACE_LINEAR): + for (; 
x < max_x; x++) + { + int decode_pixel_index = x * channels; + int input_pixel_index = stbir__edge_wrap(edge_horizontal, x, input_w) * channels; + for (c = 0; c < channels; c++) + decode_buffer[decode_pixel_index + c] = (float)(((double)((const unsigned int*)input_data)[input_pixel_index + c]) / 4294967295); + } + break; + + case STBIR__DECODE(STBIR_TYPE_UINT32, STBIR_COLORSPACE_SRGB): + for (; x < max_x; x++) + { + int decode_pixel_index = x * channels; + int input_pixel_index = stbir__edge_wrap(edge_horizontal, x, input_w) * channels; + for (c = 0; c < channels; c++) + decode_buffer[decode_pixel_index + c] = stbir__srgb_to_linear((float)(((double)((const unsigned int*)input_data)[input_pixel_index + c]) / 4294967295)); + + if (!(stbir_info->flags&STBIR_FLAG_ALPHA_USES_COLORSPACE)) + decode_buffer[decode_pixel_index + alpha_channel] = (float)(((double)((const unsigned int*)input_data)[input_pixel_index + alpha_channel]) / 4294967295); + } + break; + + case STBIR__DECODE(STBIR_TYPE_FLOAT, STBIR_COLORSPACE_LINEAR): + for (; x < max_x; x++) + { + int decode_pixel_index = x * channels; + int input_pixel_index = stbir__edge_wrap(edge_horizontal, x, input_w) * channels; + for (c = 0; c < channels; c++) + decode_buffer[decode_pixel_index + c] = ((const float*)input_data)[input_pixel_index + c]; + } + break; + + case STBIR__DECODE(STBIR_TYPE_FLOAT, STBIR_COLORSPACE_SRGB): + for (; x < max_x; x++) + { + int decode_pixel_index = x * channels; + int input_pixel_index = stbir__edge_wrap(edge_horizontal, x, input_w) * channels; + for (c = 0; c < channels; c++) + decode_buffer[decode_pixel_index + c] = stbir__srgb_to_linear(((const float*)input_data)[input_pixel_index + c]); + + if (!(stbir_info->flags&STBIR_FLAG_ALPHA_USES_COLORSPACE)) + decode_buffer[decode_pixel_index + alpha_channel] = ((const float*)input_data)[input_pixel_index + alpha_channel]; + } + + break; + + default: + STBIR__UNIMPLEMENTED("Unknown type/colorspace/channels combination."); + break; + } + + if (!(stbir_info->flags & STBIR_FLAG_ALPHA_PREMULTIPLIED)) + { + for (x = -stbir_info->horizontal_filter_pixel_margin; x < max_x; x++) + { + int decode_pixel_index = x * channels; + + // If the alpha value is 0 it will clobber the color values. Make sure it's not. 
+ float alpha = decode_buffer[decode_pixel_index + alpha_channel]; +#ifndef STBIR_NO_ALPHA_EPSILON + if (stbir_info->type != STBIR_TYPE_FLOAT) { + alpha += STBIR_ALPHA_EPSILON; + decode_buffer[decode_pixel_index + alpha_channel] = alpha; + } +#endif + for (c = 0; c < channels; c++) + { + if (c == alpha_channel) + continue; + + decode_buffer[decode_pixel_index + c] *= alpha; + } + } + } + + if (edge_horizontal == STBIR_EDGE_ZERO) + { + for (x = -stbir_info->horizontal_filter_pixel_margin; x < 0; x++) + { + for (c = 0; c < channels; c++) + decode_buffer[x*channels + c] = 0; + } + for (x = input_w; x < max_x; x++) + { + for (c = 0; c < channels; c++) + decode_buffer[x*channels + c] = 0; + } + } +} + +static float* stbir__get_ring_buffer_entry(float* ring_buffer, int index, int ring_buffer_length) +{ + return &ring_buffer[index * ring_buffer_length]; +} + +static float* stbir__add_empty_ring_buffer_entry(stbir__info* stbir_info, int n) +{ + int ring_buffer_index; + float* ring_buffer; + + if (stbir_info->ring_buffer_begin_index < 0) + { + ring_buffer_index = stbir_info->ring_buffer_begin_index = 0; + stbir_info->ring_buffer_first_scanline = n; + } + else + { + ring_buffer_index = (stbir_info->ring_buffer_begin_index + (stbir_info->ring_buffer_last_scanline - stbir_info->ring_buffer_first_scanline) + 1) % stbir_info->vertical_filter_pixel_width; + STBIR__DEBUG_ASSERT(ring_buffer_index != stbir_info->ring_buffer_begin_index); + } + + ring_buffer = stbir__get_ring_buffer_entry(stbir_info->ring_buffer, ring_buffer_index, stbir_info->ring_buffer_length_bytes / sizeof(float)); + memset(ring_buffer, 0, stbir_info->ring_buffer_length_bytes); + + stbir_info->ring_buffer_last_scanline = n; + + return ring_buffer; +} + + +static void stbir__resample_horizontal_upsample(stbir__info* stbir_info, int n, float* output_buffer) +{ + int x, k; + int output_w = stbir_info->output_w; + int kernel_pixel_width = stbir_info->horizontal_filter_pixel_width; + int channels = stbir_info->channels; + float* decode_buffer = stbir__get_decode_buffer(stbir_info); + stbir__contributors* horizontal_contributors = stbir_info->horizontal_contributors; + float* horizontal_coefficients = stbir_info->horizontal_coefficients; + int coefficient_width = stbir_info->horizontal_coefficient_width; + + for (x = 0; x < output_w; x++) + { + int n0 = horizontal_contributors[x].n0; + int n1 = horizontal_contributors[x].n1; + + int out_pixel_index = x * channels; + int coefficient_group = coefficient_width * x; + int coefficient_counter = 0; + + STBIR__DEBUG_ASSERT(n1 >= n0); + STBIR__DEBUG_ASSERT(n0 >= -stbir_info->horizontal_filter_pixel_margin); + STBIR__DEBUG_ASSERT(n1 >= -stbir_info->horizontal_filter_pixel_margin); + STBIR__DEBUG_ASSERT(n0 < stbir_info->input_w + stbir_info->horizontal_filter_pixel_margin); + STBIR__DEBUG_ASSERT(n1 < stbir_info->input_w + stbir_info->horizontal_filter_pixel_margin); + + switch (channels) { + case 1: + for (k = n0; k <= n1; k++) + { + int in_pixel_index = k * 1; + float coefficient = horizontal_coefficients[coefficient_group + coefficient_counter++]; + STBIR__DEBUG_ASSERT(coefficient != 0); + output_buffer[out_pixel_index + 0] += decode_buffer[in_pixel_index + 0] * coefficient; + } + break; + case 2: + for (k = n0; k <= n1; k++) + { + int in_pixel_index = k * 2; + float coefficient = horizontal_coefficients[coefficient_group + coefficient_counter++]; + STBIR__DEBUG_ASSERT(coefficient != 0); + output_buffer[out_pixel_index + 0] += decode_buffer[in_pixel_index + 0] * coefficient; + 
output_buffer[out_pixel_index + 1] += decode_buffer[in_pixel_index + 1] * coefficient; + } + break; + case 3: + for (k = n0; k <= n1; k++) + { + int in_pixel_index = k * 3; + float coefficient = horizontal_coefficients[coefficient_group + coefficient_counter++]; + STBIR__DEBUG_ASSERT(coefficient != 0); + output_buffer[out_pixel_index + 0] += decode_buffer[in_pixel_index + 0] * coefficient; + output_buffer[out_pixel_index + 1] += decode_buffer[in_pixel_index + 1] * coefficient; + output_buffer[out_pixel_index + 2] += decode_buffer[in_pixel_index + 2] * coefficient; + } + break; + case 4: + for (k = n0; k <= n1; k++) + { + int in_pixel_index = k * 4; + float coefficient = horizontal_coefficients[coefficient_group + coefficient_counter++]; + STBIR__DEBUG_ASSERT(coefficient != 0); + output_buffer[out_pixel_index + 0] += decode_buffer[in_pixel_index + 0] * coefficient; + output_buffer[out_pixel_index + 1] += decode_buffer[in_pixel_index + 1] * coefficient; + output_buffer[out_pixel_index + 2] += decode_buffer[in_pixel_index + 2] * coefficient; + output_buffer[out_pixel_index + 3] += decode_buffer[in_pixel_index + 3] * coefficient; + } + break; + default: + for (k = n0; k <= n1; k++) + { + int in_pixel_index = k * channels; + float coefficient = horizontal_coefficients[coefficient_group + coefficient_counter++]; + int c; + STBIR__DEBUG_ASSERT(coefficient != 0); + for (c = 0; c < channels; c++) + output_buffer[out_pixel_index + c] += decode_buffer[in_pixel_index + c] * coefficient; + } + break; + } + } +} + +static void stbir__resample_horizontal_downsample(stbir__info* stbir_info, int n, float* output_buffer) +{ + int x, k; + int input_w = stbir_info->input_w; + int output_w = stbir_info->output_w; + int kernel_pixel_width = stbir_info->horizontal_filter_pixel_width; + int channels = stbir_info->channels; + float* decode_buffer = stbir__get_decode_buffer(stbir_info); + stbir__contributors* horizontal_contributors = stbir_info->horizontal_contributors; + float* horizontal_coefficients = stbir_info->horizontal_coefficients; + int coefficient_width = stbir_info->horizontal_coefficient_width; + int filter_pixel_margin = stbir_info->horizontal_filter_pixel_margin; + int max_x = input_w + filter_pixel_margin * 2; + + STBIR__DEBUG_ASSERT(!stbir__use_width_upsampling(stbir_info)); + + switch (channels) { + case 1: + for (x = 0; x < max_x; x++) + { + int n0 = horizontal_contributors[x].n0; + int n1 = horizontal_contributors[x].n1; + + int in_x = x - filter_pixel_margin; + int in_pixel_index = in_x * 1; + int max_n = n1; + int coefficient_group = coefficient_width * x; + + for (k = n0; k <= max_n; k++) + { + int out_pixel_index = k * 1; + float coefficient = horizontal_coefficients[coefficient_group + k - n0]; + STBIR__DEBUG_ASSERT(coefficient != 0); + output_buffer[out_pixel_index + 0] += decode_buffer[in_pixel_index + 0] * coefficient; + } + } + break; + + case 2: + for (x = 0; x < max_x; x++) + { + int n0 = horizontal_contributors[x].n0; + int n1 = horizontal_contributors[x].n1; + + int in_x = x - filter_pixel_margin; + int in_pixel_index = in_x * 2; + int max_n = n1; + int coefficient_group = coefficient_width * x; + + for (k = n0; k <= max_n; k++) + { + int out_pixel_index = k * 2; + float coefficient = horizontal_coefficients[coefficient_group + k - n0]; + STBIR__DEBUG_ASSERT(coefficient != 0); + output_buffer[out_pixel_index + 0] += decode_buffer[in_pixel_index + 0] * coefficient; + output_buffer[out_pixel_index + 1] += decode_buffer[in_pixel_index + 1] * coefficient; + } + } + break; + + case 3: + 
for (x = 0; x < max_x; x++) + { + int n0 = horizontal_contributors[x].n0; + int n1 = horizontal_contributors[x].n1; + + int in_x = x - filter_pixel_margin; + int in_pixel_index = in_x * 3; + int max_n = n1; + int coefficient_group = coefficient_width * x; + + for (k = n0; k <= max_n; k++) + { + int out_pixel_index = k * 3; + float coefficient = horizontal_coefficients[coefficient_group + k - n0]; + STBIR__DEBUG_ASSERT(coefficient != 0); + output_buffer[out_pixel_index + 0] += decode_buffer[in_pixel_index + 0] * coefficient; + output_buffer[out_pixel_index + 1] += decode_buffer[in_pixel_index + 1] * coefficient; + output_buffer[out_pixel_index + 2] += decode_buffer[in_pixel_index + 2] * coefficient; + } + } + break; + + case 4: + for (x = 0; x < max_x; x++) + { + int n0 = horizontal_contributors[x].n0; + int n1 = horizontal_contributors[x].n1; + + int in_x = x - filter_pixel_margin; + int in_pixel_index = in_x * 4; + int max_n = n1; + int coefficient_group = coefficient_width * x; + + for (k = n0; k <= max_n; k++) + { + int out_pixel_index = k * 4; + float coefficient = horizontal_coefficients[coefficient_group + k - n0]; + STBIR__DEBUG_ASSERT(coefficient != 0); + output_buffer[out_pixel_index + 0] += decode_buffer[in_pixel_index + 0] * coefficient; + output_buffer[out_pixel_index + 1] += decode_buffer[in_pixel_index + 1] * coefficient; + output_buffer[out_pixel_index + 2] += decode_buffer[in_pixel_index + 2] * coefficient; + output_buffer[out_pixel_index + 3] += decode_buffer[in_pixel_index + 3] * coefficient; + } + } + break; + + default: + for (x = 0; x < max_x; x++) + { + int n0 = horizontal_contributors[x].n0; + int n1 = horizontal_contributors[x].n1; + + int in_x = x - filter_pixel_margin; + int in_pixel_index = in_x * channels; + int max_n = n1; + int coefficient_group = coefficient_width * x; + + for (k = n0; k <= max_n; k++) + { + int c; + int out_pixel_index = k * channels; + float coefficient = horizontal_coefficients[coefficient_group + k - n0]; + STBIR__DEBUG_ASSERT(coefficient != 0); + for (c = 0; c < channels; c++) + output_buffer[out_pixel_index + c] += decode_buffer[in_pixel_index + c] * coefficient; + } + } + break; + } +} + +static void stbir__decode_and_resample_upsample(stbir__info* stbir_info, int n) +{ + // Decode the nth scanline from the source image into the decode buffer. + stbir__decode_scanline(stbir_info, n); + + // Now resample it into the ring buffer. + if (stbir__use_width_upsampling(stbir_info)) + stbir__resample_horizontal_upsample(stbir_info, n, stbir__add_empty_ring_buffer_entry(stbir_info, n)); + else + stbir__resample_horizontal_downsample(stbir_info, n, stbir__add_empty_ring_buffer_entry(stbir_info, n)); + + // Now it's sitting in the ring buffer ready to be used as source for the vertical sampling. +} + +static void stbir__decode_and_resample_downsample(stbir__info* stbir_info, int n) +{ + // Decode the nth scanline from the source image into the decode buffer. + stbir__decode_scanline(stbir_info, n); + + memset(stbir_info->horizontal_buffer, 0, stbir_info->output_w * stbir_info->channels * sizeof(float)); + + // Now resample it into the horizontal buffer. + if (stbir__use_width_upsampling(stbir_info)) + stbir__resample_horizontal_upsample(stbir_info, n, stbir_info->horizontal_buffer); + else + stbir__resample_horizontal_downsample(stbir_info, n, stbir_info->horizontal_buffer); + + // Now it's sitting in the horizontal buffer ready to be distributed into the ring buffers. +} + +// Get the specified scan line from the ring buffer. 
+static float* stbir__get_ring_buffer_scanline(int get_scanline, float* ring_buffer, int begin_index, int first_scanline, int ring_buffer_size, int ring_buffer_length) +{ + int ring_buffer_index = (begin_index + (get_scanline - first_scanline)) % ring_buffer_size; + return stbir__get_ring_buffer_entry(ring_buffer, ring_buffer_index, ring_buffer_length); +} + + +static void stbir__encode_scanline(stbir__info* stbir_info, int num_pixels, void *output_buffer, float *encode_buffer, int channels, int alpha_channel, int decode) +{ + int x; + int n; + int num_nonalpha; + stbir_uint16 nonalpha[STBIR_MAX_CHANNELS]; + + if (!(stbir_info->flags&STBIR_FLAG_ALPHA_PREMULTIPLIED)) + { + for (x=0; x < num_pixels; ++x) + { + int pixel_index = x*channels; + + float alpha = encode_buffer[pixel_index + alpha_channel]; + float reciprocal_alpha = alpha ? 1.0f / alpha : 0; + + // unrolling this produced a 1% slowdown upscaling a large RGBA linear-space image on my machine - stb + for (n = 0; n < channels; n++) + if (n != alpha_channel) + encode_buffer[pixel_index + n] *= reciprocal_alpha; + + // We added in a small epsilon to prevent the color channel from being deleted with zero alpha. + // Because we only add it for integer types, it will automatically be discarded on integer + // conversion, so we don't need to subtract it back out (which would be problematic for + // numeric precision reasons). + } + } + + // build a table of all channels that need colorspace correction, so + // we don't perform colorspace correction on channels that don't need it. + for (x=0, num_nonalpha=0; x < channels; ++x) + if (x != alpha_channel || (stbir_info->flags & STBIR_FLAG_ALPHA_USES_COLORSPACE)) + nonalpha[num_nonalpha++] = x; + + #define STBIR__ROUND_INT(f) ((int) ((f)+0.5)) + #define STBIR__ROUND_UINT(f) ((stbir_uint32) ((f)+0.5)) + + #ifdef STBIR__SATURATE_INT + #define STBIR__ENCODE_LINEAR8(f) stbir__saturate8 (STBIR__ROUND_INT((f) * 255 )) + #define STBIR__ENCODE_LINEAR16(f) stbir__saturate16(STBIR__ROUND_INT((f) * 65535)) + #else + #define STBIR__ENCODE_LINEAR8(f) (unsigned char ) STBIR__ROUND_INT(stbir__saturate(f) * 255 ) + #define STBIR__ENCODE_LINEAR16(f) (unsigned short) STBIR__ROUND_INT(stbir__saturate(f) * 65535) + #endif + + switch (decode) + { + case STBIR__DECODE(STBIR_TYPE_UINT8, STBIR_COLORSPACE_LINEAR): + for (x=0; x < num_pixels; ++x) + { + int pixel_index = x*channels; + + for (n = 0; n < channels; n++) + { + int index = pixel_index + n; + ((unsigned char*)output_buffer)[index] = STBIR__ENCODE_LINEAR8(encode_buffer[index]); + } + } + break; + + case STBIR__DECODE(STBIR_TYPE_UINT8, STBIR_COLORSPACE_SRGB): + for (x=0; x < num_pixels; ++x) + { + int pixel_index = x*channels; + + for (n = 0; n < num_nonalpha; n++) + { + int index = pixel_index + nonalpha[n]; + ((unsigned char*)output_buffer)[index] = stbir__linear_to_srgb_uchar(encode_buffer[index]); + } + + if (!(stbir_info->flags & STBIR_FLAG_ALPHA_USES_COLORSPACE)) + ((unsigned char *)output_buffer)[pixel_index + alpha_channel] = STBIR__ENCODE_LINEAR8(encode_buffer[pixel_index+alpha_channel]); + } + break; + + case STBIR__DECODE(STBIR_TYPE_UINT16, STBIR_COLORSPACE_LINEAR): + for (x=0; x < num_pixels; ++x) + { + int pixel_index = x*channels; + + for (n = 0; n < channels; n++) + { + int index = pixel_index + n; + ((unsigned short*)output_buffer)[index] = STBIR__ENCODE_LINEAR16(encode_buffer[index]); + } + } + break; + + case STBIR__DECODE(STBIR_TYPE_UINT16, STBIR_COLORSPACE_SRGB): + for (x=0; x < num_pixels; ++x) + { + int pixel_index = x*channels; + + for (n 
= 0; n < num_nonalpha; n++) + { + int index = pixel_index + nonalpha[n]; + ((unsigned short*)output_buffer)[index] = (unsigned short)STBIR__ROUND_INT(stbir__linear_to_srgb(stbir__saturate(encode_buffer[index])) * 65535); + } + + if (!(stbir_info->flags&STBIR_FLAG_ALPHA_USES_COLORSPACE)) + ((unsigned short*)output_buffer)[pixel_index + alpha_channel] = STBIR__ENCODE_LINEAR16(encode_buffer[pixel_index + alpha_channel]); + } + + break; + + case STBIR__DECODE(STBIR_TYPE_UINT32, STBIR_COLORSPACE_LINEAR): + for (x=0; x < num_pixels; ++x) + { + int pixel_index = x*channels; + + for (n = 0; n < channels; n++) + { + int index = pixel_index + n; + ((unsigned int*)output_buffer)[index] = (unsigned int)STBIR__ROUND_UINT(((double)stbir__saturate(encode_buffer[index])) * 4294967295); + } + } + break; + + case STBIR__DECODE(STBIR_TYPE_UINT32, STBIR_COLORSPACE_SRGB): + for (x=0; x < num_pixels; ++x) + { + int pixel_index = x*channels; + + for (n = 0; n < num_nonalpha; n++) + { + int index = pixel_index + nonalpha[n]; + ((unsigned int*)output_buffer)[index] = (unsigned int)STBIR__ROUND_UINT(((double)stbir__linear_to_srgb(stbir__saturate(encode_buffer[index]))) * 4294967295); + } + + if (!(stbir_info->flags&STBIR_FLAG_ALPHA_USES_COLORSPACE)) + ((unsigned int*)output_buffer)[pixel_index + alpha_channel] = (unsigned int)STBIR__ROUND_INT(((double)stbir__saturate(encode_buffer[pixel_index + alpha_channel])) * 4294967295); + } + break; + + case STBIR__DECODE(STBIR_TYPE_FLOAT, STBIR_COLORSPACE_LINEAR): + for (x=0; x < num_pixels; ++x) + { + int pixel_index = x*channels; + + for (n = 0; n < channels; n++) + { + int index = pixel_index + n; + ((float*)output_buffer)[index] = encode_buffer[index]; + } + } + break; + + case STBIR__DECODE(STBIR_TYPE_FLOAT, STBIR_COLORSPACE_SRGB): + for (x=0; x < num_pixels; ++x) + { + int pixel_index = x*channels; + + for (n = 0; n < num_nonalpha; n++) + { + int index = pixel_index + nonalpha[n]; + ((float*)output_buffer)[index] = stbir__linear_to_srgb(encode_buffer[index]); + } + + if (!(stbir_info->flags&STBIR_FLAG_ALPHA_USES_COLORSPACE)) + ((float*)output_buffer)[pixel_index + alpha_channel] = encode_buffer[pixel_index + alpha_channel]; + } + break; + + default: + STBIR__UNIMPLEMENTED("Unknown type/colorspace/channels combination."); + break; + } +} + +static void stbir__resample_vertical_upsample(stbir__info* stbir_info, int n, int in_first_scanline, int in_last_scanline, float in_center_of_out) +{ + int x, k; + int output_w = stbir_info->output_w; + stbir__contributors* vertical_contributors = stbir_info->vertical_contributors; + float* vertical_coefficients = stbir_info->vertical_coefficients; + int channels = stbir_info->channels; + int alpha_channel = stbir_info->alpha_channel; + int type = stbir_info->type; + int colorspace = stbir_info->colorspace; + int kernel_pixel_width = stbir_info->vertical_filter_pixel_width; + void* output_data = stbir_info->output_data; + float* encode_buffer = stbir_info->encode_buffer; + int decode = STBIR__DECODE(type, colorspace); + int coefficient_width = stbir_info->vertical_coefficient_width; + int coefficient_counter; + int contributor = n; + + float* ring_buffer = stbir_info->ring_buffer; + int ring_buffer_begin_index = stbir_info->ring_buffer_begin_index; + int ring_buffer_first_scanline = stbir_info->ring_buffer_first_scanline; + int ring_buffer_last_scanline = stbir_info->ring_buffer_last_scanline; + int ring_buffer_length = stbir_info->ring_buffer_length_bytes/sizeof(float); + + int n0,n1, output_row_start; + int coefficient_group = 
coefficient_width * contributor; + + n0 = vertical_contributors[contributor].n0; + n1 = vertical_contributors[contributor].n1; + + output_row_start = n * stbir_info->output_stride_bytes; + + STBIR__DEBUG_ASSERT(stbir__use_height_upsampling(stbir_info)); + + memset(encode_buffer, 0, output_w * sizeof(float) * channels); + + // I tried reblocking this for better cache usage of encode_buffer + // (using x_outer, k, x_inner), but it lost speed. -- stb + + coefficient_counter = 0; + switch (channels) { + case 1: + for (k = n0; k <= n1; k++) + { + int coefficient_index = coefficient_counter++; + float* ring_buffer_entry = stbir__get_ring_buffer_scanline(k, ring_buffer, ring_buffer_begin_index, ring_buffer_first_scanline, kernel_pixel_width, ring_buffer_length); + float coefficient = vertical_coefficients[coefficient_group + coefficient_index]; + for (x = 0; x < output_w; ++x) + { + int in_pixel_index = x * 1; + encode_buffer[in_pixel_index + 0] += ring_buffer_entry[in_pixel_index + 0] * coefficient; + } + } + break; + case 2: + for (k = n0; k <= n1; k++) + { + int coefficient_index = coefficient_counter++; + float* ring_buffer_entry = stbir__get_ring_buffer_scanline(k, ring_buffer, ring_buffer_begin_index, ring_buffer_first_scanline, kernel_pixel_width, ring_buffer_length); + float coefficient = vertical_coefficients[coefficient_group + coefficient_index]; + for (x = 0; x < output_w; ++x) + { + int in_pixel_index = x * 2; + encode_buffer[in_pixel_index + 0] += ring_buffer_entry[in_pixel_index + 0] * coefficient; + encode_buffer[in_pixel_index + 1] += ring_buffer_entry[in_pixel_index + 1] * coefficient; + } + } + break; + case 3: + for (k = n0; k <= n1; k++) + { + int coefficient_index = coefficient_counter++; + float* ring_buffer_entry = stbir__get_ring_buffer_scanline(k, ring_buffer, ring_buffer_begin_index, ring_buffer_first_scanline, kernel_pixel_width, ring_buffer_length); + float coefficient = vertical_coefficients[coefficient_group + coefficient_index]; + for (x = 0; x < output_w; ++x) + { + int in_pixel_index = x * 3; + encode_buffer[in_pixel_index + 0] += ring_buffer_entry[in_pixel_index + 0] * coefficient; + encode_buffer[in_pixel_index + 1] += ring_buffer_entry[in_pixel_index + 1] * coefficient; + encode_buffer[in_pixel_index + 2] += ring_buffer_entry[in_pixel_index + 2] * coefficient; + } + } + break; + case 4: + for (k = n0; k <= n1; k++) + { + int coefficient_index = coefficient_counter++; + float* ring_buffer_entry = stbir__get_ring_buffer_scanline(k, ring_buffer, ring_buffer_begin_index, ring_buffer_first_scanline, kernel_pixel_width, ring_buffer_length); + float coefficient = vertical_coefficients[coefficient_group + coefficient_index]; + for (x = 0; x < output_w; ++x) + { + int in_pixel_index = x * 4; + encode_buffer[in_pixel_index + 0] += ring_buffer_entry[in_pixel_index + 0] * coefficient; + encode_buffer[in_pixel_index + 1] += ring_buffer_entry[in_pixel_index + 1] * coefficient; + encode_buffer[in_pixel_index + 2] += ring_buffer_entry[in_pixel_index + 2] * coefficient; + encode_buffer[in_pixel_index + 3] += ring_buffer_entry[in_pixel_index + 3] * coefficient; + } + } + break; + default: + for (k = n0; k <= n1; k++) + { + int coefficient_index = coefficient_counter++; + float* ring_buffer_entry = stbir__get_ring_buffer_scanline(k, ring_buffer, ring_buffer_begin_index, ring_buffer_first_scanline, kernel_pixel_width, ring_buffer_length); + float coefficient = vertical_coefficients[coefficient_group + coefficient_index]; + for (x = 0; x < output_w; ++x) + { + int in_pixel_index = 
x * channels; + int c; + for (c = 0; c < channels; c++) + encode_buffer[in_pixel_index + c] += ring_buffer_entry[in_pixel_index + c] * coefficient; + } + } + break; + } + stbir__encode_scanline(stbir_info, output_w, (char *) output_data + output_row_start, encode_buffer, channels, alpha_channel, decode); +} + +static void stbir__resample_vertical_downsample(stbir__info* stbir_info, int n, int in_first_scanline, int in_last_scanline, float in_center_of_out) +{ + int x, k; + int output_w = stbir_info->output_w; + int output_h = stbir_info->output_h; + stbir__contributors* vertical_contributors = stbir_info->vertical_contributors; + float* vertical_coefficients = stbir_info->vertical_coefficients; + int channels = stbir_info->channels; + int kernel_pixel_width = stbir_info->vertical_filter_pixel_width; + void* output_data = stbir_info->output_data; + float* horizontal_buffer = stbir_info->horizontal_buffer; + int coefficient_width = stbir_info->vertical_coefficient_width; + int contributor = n + stbir_info->vertical_filter_pixel_margin; + + float* ring_buffer = stbir_info->ring_buffer; + int ring_buffer_begin_index = stbir_info->ring_buffer_begin_index; + int ring_buffer_first_scanline = stbir_info->ring_buffer_first_scanline; + int ring_buffer_last_scanline = stbir_info->ring_buffer_last_scanline; + int ring_buffer_length = stbir_info->ring_buffer_length_bytes/sizeof(float); + int n0,n1; + + n0 = vertical_contributors[contributor].n0; + n1 = vertical_contributors[contributor].n1; + + STBIR__DEBUG_ASSERT(!stbir__use_height_upsampling(stbir_info)); + + for (k = n0; k <= n1; k++) + { + int coefficient_index = k - n0; + int coefficient_group = coefficient_width * contributor; + float coefficient = vertical_coefficients[coefficient_group + coefficient_index]; + + float* ring_buffer_entry = stbir__get_ring_buffer_scanline(k, ring_buffer, ring_buffer_begin_index, ring_buffer_first_scanline, kernel_pixel_width, ring_buffer_length); + + switch (channels) { + case 1: + for (x = 0; x < output_w; x++) + { + int in_pixel_index = x * 1; + ring_buffer_entry[in_pixel_index + 0] += horizontal_buffer[in_pixel_index + 0] * coefficient; + } + break; + case 2: + for (x = 0; x < output_w; x++) + { + int in_pixel_index = x * 2; + ring_buffer_entry[in_pixel_index + 0] += horizontal_buffer[in_pixel_index + 0] * coefficient; + ring_buffer_entry[in_pixel_index + 1] += horizontal_buffer[in_pixel_index + 1] * coefficient; + } + break; + case 3: + for (x = 0; x < output_w; x++) + { + int in_pixel_index = x * 3; + ring_buffer_entry[in_pixel_index + 0] += horizontal_buffer[in_pixel_index + 0] * coefficient; + ring_buffer_entry[in_pixel_index + 1] += horizontal_buffer[in_pixel_index + 1] * coefficient; + ring_buffer_entry[in_pixel_index + 2] += horizontal_buffer[in_pixel_index + 2] * coefficient; + } + break; + case 4: + for (x = 0; x < output_w; x++) + { + int in_pixel_index = x * 4; + ring_buffer_entry[in_pixel_index + 0] += horizontal_buffer[in_pixel_index + 0] * coefficient; + ring_buffer_entry[in_pixel_index + 1] += horizontal_buffer[in_pixel_index + 1] * coefficient; + ring_buffer_entry[in_pixel_index + 2] += horizontal_buffer[in_pixel_index + 2] * coefficient; + ring_buffer_entry[in_pixel_index + 3] += horizontal_buffer[in_pixel_index + 3] * coefficient; + } + break; + default: + for (x = 0; x < output_w; x++) + { + int in_pixel_index = x * channels; + + int c; + for (c = 0; c < channels; c++) + ring_buffer_entry[in_pixel_index + c] += horizontal_buffer[in_pixel_index + c] * coefficient; + } + break; + } + } +} + 
+static void stbir__buffer_loop_upsample(stbir__info* stbir_info) +{ + int y; + float scale_ratio = stbir_info->vertical_scale; + float out_scanlines_radius = stbir__filter_info_table[stbir_info->vertical_filter].support(1/scale_ratio) * scale_ratio; + + STBIR__DEBUG_ASSERT(stbir__use_height_upsampling(stbir_info)); + + for (y = 0; y < stbir_info->output_h; y++) + { + float in_center_of_out = 0; // Center of the current out scanline in the in scanline space + int in_first_scanline = 0, in_last_scanline = 0; + + stbir__calculate_sample_range_upsample(y, out_scanlines_radius, scale_ratio, stbir_info->vertical_shift, &in_first_scanline, &in_last_scanline, &in_center_of_out); + + STBIR__DEBUG_ASSERT(in_last_scanline - in_first_scanline <= stbir_info->vertical_filter_pixel_width); + + if (stbir_info->ring_buffer_begin_index >= 0) + { + // Get rid of whatever we don't need anymore. + while (in_first_scanline > stbir_info->ring_buffer_first_scanline) + { + if (stbir_info->ring_buffer_first_scanline == stbir_info->ring_buffer_last_scanline) + { + // We just popped the last scanline off the ring buffer. + // Reset it to the empty state. + stbir_info->ring_buffer_begin_index = -1; + stbir_info->ring_buffer_first_scanline = 0; + stbir_info->ring_buffer_last_scanline = 0; + break; + } + else + { + stbir_info->ring_buffer_first_scanline++; + stbir_info->ring_buffer_begin_index = (stbir_info->ring_buffer_begin_index + 1) % stbir_info->vertical_filter_pixel_width; + } + } + } + + // Load in new ones. + if (stbir_info->ring_buffer_begin_index < 0) + stbir__decode_and_resample_upsample(stbir_info, in_first_scanline); + + while (in_last_scanline > stbir_info->ring_buffer_last_scanline) + stbir__decode_and_resample_upsample(stbir_info, stbir_info->ring_buffer_last_scanline + 1); + + // Now all buffers should be ready to write a row of vertical sampling. + stbir__resample_vertical_upsample(stbir_info, y, in_first_scanline, in_last_scanline, in_center_of_out); + + STBIR_PROGRESS_REPORT((float)y / stbir_info->output_h); + } +} + +static void stbir__empty_ring_buffer(stbir__info* stbir_info, int first_necessary_scanline) +{ + int output_stride_bytes = stbir_info->output_stride_bytes; + int channels = stbir_info->channels; + int alpha_channel = stbir_info->alpha_channel; + int type = stbir_info->type; + int colorspace = stbir_info->colorspace; + int output_w = stbir_info->output_w; + void* output_data = stbir_info->output_data; + int decode = STBIR__DECODE(type, colorspace); + + float* ring_buffer = stbir_info->ring_buffer; + int ring_buffer_length = stbir_info->ring_buffer_length_bytes/sizeof(float); + + if (stbir_info->ring_buffer_begin_index >= 0) + { + // Get rid of whatever we don't need anymore. + while (first_necessary_scanline > stbir_info->ring_buffer_first_scanline) + { + if (stbir_info->ring_buffer_first_scanline >= 0 && stbir_info->ring_buffer_first_scanline < stbir_info->output_h) + { + int output_row_start = stbir_info->ring_buffer_first_scanline * output_stride_bytes; + float* ring_buffer_entry = stbir__get_ring_buffer_entry(ring_buffer, stbir_info->ring_buffer_begin_index, ring_buffer_length); + stbir__encode_scanline(stbir_info, output_w, (char *) output_data + output_row_start, ring_buffer_entry, channels, alpha_channel, decode); + STBIR_PROGRESS_REPORT((float)stbir_info->ring_buffer_first_scanline / stbir_info->output_h); + } + + if (stbir_info->ring_buffer_first_scanline == stbir_info->ring_buffer_last_scanline) + { + // We just popped the last scanline off the ring buffer. 
+ // Reset it to the empty state. + stbir_info->ring_buffer_begin_index = -1; + stbir_info->ring_buffer_first_scanline = 0; + stbir_info->ring_buffer_last_scanline = 0; + break; + } + else + { + stbir_info->ring_buffer_first_scanline++; + stbir_info->ring_buffer_begin_index = (stbir_info->ring_buffer_begin_index + 1) % stbir_info->vertical_filter_pixel_width; + } + } + } +} + +static void stbir__buffer_loop_downsample(stbir__info* stbir_info) +{ + int y; + float scale_ratio = stbir_info->vertical_scale; + int output_h = stbir_info->output_h; + float in_pixels_radius = stbir__filter_info_table[stbir_info->vertical_filter].support(scale_ratio) / scale_ratio; + int pixel_margin = stbir_info->vertical_filter_pixel_margin; + int max_y = stbir_info->input_h + pixel_margin; + + STBIR__DEBUG_ASSERT(!stbir__use_height_upsampling(stbir_info)); + + for (y = -pixel_margin; y < max_y; y++) + { + float out_center_of_in; // Center of the current out scanline in the in scanline space + int out_first_scanline, out_last_scanline; + + stbir__calculate_sample_range_downsample(y, in_pixels_radius, scale_ratio, stbir_info->vertical_shift, &out_first_scanline, &out_last_scanline, &out_center_of_in); + + STBIR__DEBUG_ASSERT(out_last_scanline - out_first_scanline <= stbir_info->vertical_filter_pixel_width); + + if (out_last_scanline < 0 || out_first_scanline >= output_h) + continue; + + stbir__empty_ring_buffer(stbir_info, out_first_scanline); + + stbir__decode_and_resample_downsample(stbir_info, y); + + // Load in new ones. + if (stbir_info->ring_buffer_begin_index < 0) + stbir__add_empty_ring_buffer_entry(stbir_info, out_first_scanline); + + while (out_last_scanline > stbir_info->ring_buffer_last_scanline) + stbir__add_empty_ring_buffer_entry(stbir_info, stbir_info->ring_buffer_last_scanline + 1); + + // Now the horizontal buffer is ready to write to all ring buffer rows. + stbir__resample_vertical_downsample(stbir_info, y, out_first_scanline, out_last_scanline, out_center_of_in); + } + + stbir__empty_ring_buffer(stbir_info, stbir_info->output_h); +} + +static void stbir__setup(stbir__info *info, int input_w, int input_h, int output_w, int output_h, int channels) +{ + info->input_w = input_w; + info->input_h = input_h; + info->output_w = output_w; + info->output_h = output_h; + info->channels = channels; +} + +static void stbir__calculate_transform(stbir__info *info, float s0, float t0, float s1, float t1, float *transform) +{ + info->s0 = s0; + info->t0 = t0; + info->s1 = s1; + info->t1 = t1; + + if (transform) + { + info->horizontal_scale = transform[0]; + info->vertical_scale = transform[1]; + info->horizontal_shift = transform[2]; + info->vertical_shift = transform[3]; + } + else + { + info->horizontal_scale = ((float)info->output_w / info->input_w) / (s1 - s0); + info->vertical_scale = ((float)info->output_h / info->input_h) / (t1 - t0); + + info->horizontal_shift = s0 * info->input_w / (s1 - s0); + info->vertical_shift = t0 * info->input_h / (t1 - t0); + } +} + +static void stbir__choose_filter(stbir__info *info, stbir_filter h_filter, stbir_filter v_filter) +{ + if (h_filter == 0) + h_filter = stbir__use_upsampling(info->horizontal_scale) ? STBIR_DEFAULT_FILTER_UPSAMPLE : STBIR_DEFAULT_FILTER_DOWNSAMPLE; + if (v_filter == 0) + v_filter = stbir__use_upsampling(info->vertical_scale) ? 
STBIR_DEFAULT_FILTER_UPSAMPLE : STBIR_DEFAULT_FILTER_DOWNSAMPLE; + info->horizontal_filter = h_filter; + info->vertical_filter = v_filter; +} + +static stbir_uint32 stbir__calculate_memory(stbir__info *info) +{ + int pixel_margin = stbir__get_filter_pixel_margin(info->horizontal_filter, info->horizontal_scale); + int filter_height = stbir__get_filter_pixel_width(info->vertical_filter, info->vertical_scale); + + info->horizontal_num_contributors = stbir__get_contributors(info->horizontal_scale, info->horizontal_filter, info->input_w, info->output_w); + info->vertical_num_contributors = stbir__get_contributors(info->vertical_scale , info->vertical_filter , info->input_h, info->output_h); + + info->horizontal_contributors_size = info->horizontal_num_contributors * sizeof(stbir__contributors); + info->horizontal_coefficients_size = stbir__get_total_horizontal_coefficients(info) * sizeof(float); + info->vertical_contributors_size = info->vertical_num_contributors * sizeof(stbir__contributors); + info->vertical_coefficients_size = stbir__get_total_vertical_coefficients(info) * sizeof(float); + info->decode_buffer_size = (info->input_w + pixel_margin * 2) * info->channels * sizeof(float); + info->horizontal_buffer_size = info->output_w * info->channels * sizeof(float); + info->ring_buffer_size = info->output_w * info->channels * filter_height * sizeof(float); + info->encode_buffer_size = info->output_w * info->channels * sizeof(float); + + STBIR_ASSERT(info->horizontal_filter != 0); + STBIR_ASSERT(info->horizontal_filter < STBIR__ARRAY_SIZE(stbir__filter_info_table)); // this now happens too late + STBIR_ASSERT(info->vertical_filter != 0); + STBIR_ASSERT(info->vertical_filter < STBIR__ARRAY_SIZE(stbir__filter_info_table)); // this now happens too late + + if (stbir__use_height_upsampling(info)) + // The horizontal buffer is for when we're downsampling the height and we + // can't output the result of sampling the decode buffer directly into the + // ring buffers. + info->horizontal_buffer_size = 0; + else + // The encode buffer is to retain precision in the height upsampling method + // and isn't used when height downsampling. + info->encode_buffer_size = 0; + + return info->horizontal_contributors_size + info->horizontal_coefficients_size + + info->vertical_contributors_size + info->vertical_coefficients_size + + info->decode_buffer_size + info->horizontal_buffer_size + + info->ring_buffer_size + info->encode_buffer_size; +} + +static int stbir__resize_allocated(stbir__info *info, + const void* input_data, int input_stride_in_bytes, + void* output_data, int output_stride_in_bytes, + int alpha_channel, stbir_uint32 flags, stbir_datatype type, + stbir_edge edge_horizontal, stbir_edge edge_vertical, stbir_colorspace colorspace, + void* tempmem, size_t tempmem_size_in_bytes) +{ + size_t memory_required = stbir__calculate_memory(info); + + int width_stride_input = input_stride_in_bytes ? input_stride_in_bytes : info->channels * info->input_w * stbir__type_size[type]; + int width_stride_output = output_stride_in_bytes ? 
output_stride_in_bytes : info->channels * info->output_w * stbir__type_size[type]; + +#ifdef STBIR_DEBUG_OVERWRITE_TEST +#define OVERWRITE_ARRAY_SIZE 8 + unsigned char overwrite_output_before_pre[OVERWRITE_ARRAY_SIZE]; + unsigned char overwrite_tempmem_before_pre[OVERWRITE_ARRAY_SIZE]; + unsigned char overwrite_output_after_pre[OVERWRITE_ARRAY_SIZE]; + unsigned char overwrite_tempmem_after_pre[OVERWRITE_ARRAY_SIZE]; + + size_t begin_forbidden = width_stride_output * (info->output_h - 1) + info->output_w * info->channels * stbir__type_size[type]; + memcpy(overwrite_output_before_pre, &((unsigned char*)output_data)[-OVERWRITE_ARRAY_SIZE], OVERWRITE_ARRAY_SIZE); + memcpy(overwrite_output_after_pre, &((unsigned char*)output_data)[begin_forbidden], OVERWRITE_ARRAY_SIZE); + memcpy(overwrite_tempmem_before_pre, &((unsigned char*)tempmem)[-OVERWRITE_ARRAY_SIZE], OVERWRITE_ARRAY_SIZE); + memcpy(overwrite_tempmem_after_pre, &((unsigned char*)tempmem)[tempmem_size_in_bytes], OVERWRITE_ARRAY_SIZE); +#endif + + STBIR_ASSERT(info->channels >= 0); + STBIR_ASSERT(info->channels <= STBIR_MAX_CHANNELS); + + if (info->channels < 0 || info->channels > STBIR_MAX_CHANNELS) + return 0; + + STBIR_ASSERT(info->horizontal_filter < STBIR__ARRAY_SIZE(stbir__filter_info_table)); + STBIR_ASSERT(info->vertical_filter < STBIR__ARRAY_SIZE(stbir__filter_info_table)); + + if (info->horizontal_filter >= STBIR__ARRAY_SIZE(stbir__filter_info_table)) + return 0; + if (info->vertical_filter >= STBIR__ARRAY_SIZE(stbir__filter_info_table)) + return 0; + + if (alpha_channel < 0) + flags |= STBIR_FLAG_ALPHA_USES_COLORSPACE | STBIR_FLAG_ALPHA_PREMULTIPLIED; + + if (!(flags&STBIR_FLAG_ALPHA_USES_COLORSPACE) || !(flags&STBIR_FLAG_ALPHA_PREMULTIPLIED)) + STBIR_ASSERT(alpha_channel >= 0 && alpha_channel < info->channels); + + if (alpha_channel >= info->channels) + return 0; + + STBIR_ASSERT(tempmem); + + if (!tempmem) + return 0; + + STBIR_ASSERT(tempmem_size_in_bytes >= memory_required); + + if (tempmem_size_in_bytes < memory_required) + return 0; + + memset(tempmem, 0, tempmem_size_in_bytes); + + info->input_data = input_data; + info->input_stride_bytes = width_stride_input; + + info->output_data = output_data; + info->output_stride_bytes = width_stride_output; + + info->alpha_channel = alpha_channel; + info->flags = flags; + info->type = type; + info->edge_horizontal = edge_horizontal; + info->edge_vertical = edge_vertical; + info->colorspace = colorspace; + + info->horizontal_coefficient_width = stbir__get_coefficient_width (info->horizontal_filter, info->horizontal_scale); + info->vertical_coefficient_width = stbir__get_coefficient_width (info->vertical_filter , info->vertical_scale ); + info->horizontal_filter_pixel_width = stbir__get_filter_pixel_width (info->horizontal_filter, info->horizontal_scale); + info->vertical_filter_pixel_width = stbir__get_filter_pixel_width (info->vertical_filter , info->vertical_scale ); + info->horizontal_filter_pixel_margin = stbir__get_filter_pixel_margin(info->horizontal_filter, info->horizontal_scale); + info->vertical_filter_pixel_margin = stbir__get_filter_pixel_margin(info->vertical_filter , info->vertical_scale ); + + info->ring_buffer_length_bytes = info->output_w * info->channels * sizeof(float); + info->decode_buffer_pixels = info->input_w + info->horizontal_filter_pixel_margin * 2; + +#define STBIR__NEXT_MEMPTR(current, newtype) (newtype*)(((unsigned char*)current) + current##_size) + + info->horizontal_contributors = (stbir__contributors *) tempmem; + info->horizontal_coefficients = 
STBIR__NEXT_MEMPTR(info->horizontal_contributors, float); + info->vertical_contributors = STBIR__NEXT_MEMPTR(info->horizontal_coefficients, stbir__contributors); + info->vertical_coefficients = STBIR__NEXT_MEMPTR(info->vertical_contributors, float); + info->decode_buffer = STBIR__NEXT_MEMPTR(info->vertical_coefficients, float); + + if (stbir__use_height_upsampling(info)) + { + info->horizontal_buffer = NULL; + info->ring_buffer = STBIR__NEXT_MEMPTR(info->decode_buffer, float); + info->encode_buffer = STBIR__NEXT_MEMPTR(info->ring_buffer, float); + + STBIR__DEBUG_ASSERT((size_t)STBIR__NEXT_MEMPTR(info->encode_buffer, unsigned char) == (size_t)tempmem + tempmem_size_in_bytes); + } + else + { + info->horizontal_buffer = STBIR__NEXT_MEMPTR(info->decode_buffer, float); + info->ring_buffer = STBIR__NEXT_MEMPTR(info->horizontal_buffer, float); + info->encode_buffer = NULL; + + STBIR__DEBUG_ASSERT((size_t)STBIR__NEXT_MEMPTR(info->ring_buffer, unsigned char) == (size_t)tempmem + tempmem_size_in_bytes); + } + +#undef STBIR__NEXT_MEMPTR + + // This signals that the ring buffer is empty + info->ring_buffer_begin_index = -1; + + stbir__calculate_filters(info, info->horizontal_contributors, info->horizontal_coefficients, info->horizontal_filter, info->horizontal_scale, info->horizontal_shift, info->input_w, info->output_w); + stbir__calculate_filters(info, info->vertical_contributors, info->vertical_coefficients, info->vertical_filter, info->vertical_scale, info->vertical_shift, info->input_h, info->output_h); + + STBIR_PROGRESS_REPORT(0); + + if (stbir__use_height_upsampling(info)) + stbir__buffer_loop_upsample(info); + else + stbir__buffer_loop_downsample(info); + + STBIR_PROGRESS_REPORT(1); + +#ifdef STBIR_DEBUG_OVERWRITE_TEST + STBIR__DEBUG_ASSERT(memcmp(overwrite_output_before_pre, &((unsigned char*)output_data)[-OVERWRITE_ARRAY_SIZE], OVERWRITE_ARRAY_SIZE) == 0); + STBIR__DEBUG_ASSERT(memcmp(overwrite_output_after_pre, &((unsigned char*)output_data)[begin_forbidden], OVERWRITE_ARRAY_SIZE) == 0); + STBIR__DEBUG_ASSERT(memcmp(overwrite_tempmem_before_pre, &((unsigned char*)tempmem)[-OVERWRITE_ARRAY_SIZE], OVERWRITE_ARRAY_SIZE) == 0); + STBIR__DEBUG_ASSERT(memcmp(overwrite_tempmem_after_pre, &((unsigned char*)tempmem)[tempmem_size_in_bytes], OVERWRITE_ARRAY_SIZE) == 0); +#endif + + return 1; +} + + +static int stbir__resize_arbitrary( + void *alloc_context, + const void* input_data, int input_w, int input_h, int input_stride_in_bytes, + void* output_data, int output_w, int output_h, int output_stride_in_bytes, + float s0, float t0, float s1, float t1, float *transform, + int channels, int alpha_channel, stbir_uint32 flags, stbir_datatype type, + stbir_filter h_filter, stbir_filter v_filter, + stbir_edge edge_horizontal, stbir_edge edge_vertical, stbir_colorspace colorspace) +{ + stbir__info info; + int result; + size_t memory_required; + void* extra_memory; + + stbir__setup(&info, input_w, input_h, output_w, output_h, channels); + stbir__calculate_transform(&info, s0,t0,s1,t1,transform); + stbir__choose_filter(&info, h_filter, v_filter); + memory_required = stbir__calculate_memory(&info); + extra_memory = STBIR_MALLOC(memory_required, alloc_context); + + if (!extra_memory) + return 0; + + result = stbir__resize_allocated(&info, input_data, input_stride_in_bytes, + output_data, output_stride_in_bytes, + alpha_channel, flags, type, + edge_horizontal, edge_vertical, + colorspace, extra_memory, memory_required); + + STBIR_FREE(extra_memory, alloc_context); + + return result; +} + +STBIRDEF int 
stbir_resize_uint8( const unsigned char *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + unsigned char *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + int num_channels) +{ + return stbir__resize_arbitrary(NULL, input_pixels, input_w, input_h, input_stride_in_bytes, + output_pixels, output_w, output_h, output_stride_in_bytes, + 0,0,1,1,NULL,num_channels,-1,0, STBIR_TYPE_UINT8, STBIR_FILTER_DEFAULT, STBIR_FILTER_DEFAULT, + STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_COLORSPACE_LINEAR); +} + +STBIRDEF int stbir_resize_float( const float *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + float *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + int num_channels) +{ + return stbir__resize_arbitrary(NULL, input_pixels, input_w, input_h, input_stride_in_bytes, + output_pixels, output_w, output_h, output_stride_in_bytes, + 0,0,1,1,NULL,num_channels,-1,0, STBIR_TYPE_FLOAT, STBIR_FILTER_DEFAULT, STBIR_FILTER_DEFAULT, + STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_COLORSPACE_LINEAR); +} + +STBIRDEF int stbir_resize_uint8_srgb(const unsigned char *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + unsigned char *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + int num_channels, int alpha_channel, int flags) +{ + return stbir__resize_arbitrary(NULL, input_pixels, input_w, input_h, input_stride_in_bytes, + output_pixels, output_w, output_h, output_stride_in_bytes, + 0,0,1,1,NULL,num_channels,alpha_channel,flags, STBIR_TYPE_UINT8, STBIR_FILTER_DEFAULT, STBIR_FILTER_DEFAULT, + STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_COLORSPACE_SRGB); +} + +STBIRDEF int stbir_resize_uint8_srgb_edgemode(const unsigned char *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + unsigned char *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_wrap_mode) +{ + return stbir__resize_arbitrary(NULL, input_pixels, input_w, input_h, input_stride_in_bytes, + output_pixels, output_w, output_h, output_stride_in_bytes, + 0,0,1,1,NULL,num_channels,alpha_channel,flags, STBIR_TYPE_UINT8, STBIR_FILTER_DEFAULT, STBIR_FILTER_DEFAULT, + edge_wrap_mode, edge_wrap_mode, STBIR_COLORSPACE_SRGB); +} + +STBIRDEF int stbir_resize_uint8_generic( const unsigned char *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + unsigned char *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_wrap_mode, stbir_filter filter, stbir_colorspace space, + void *alloc_context) +{ + return stbir__resize_arbitrary(alloc_context, input_pixels, input_w, input_h, input_stride_in_bytes, + output_pixels, output_w, output_h, output_stride_in_bytes, + 0,0,1,1,NULL,num_channels,alpha_channel,flags, STBIR_TYPE_UINT8, filter, filter, + edge_wrap_mode, edge_wrap_mode, space); +} + +STBIRDEF int stbir_resize_uint16_generic(const stbir_uint16 *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + stbir_uint16 *output_pixels , int output_w, int output_h, int output_stride_in_bytes, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_wrap_mode, stbir_filter filter, stbir_colorspace space, + void *alloc_context) +{ + return stbir__resize_arbitrary(alloc_context, input_pixels, input_w, input_h, input_stride_in_bytes, + output_pixels, output_w, output_h, output_stride_in_bytes, + 
0,0,1,1,NULL,num_channels,alpha_channel,flags, STBIR_TYPE_UINT16, filter, filter, + edge_wrap_mode, edge_wrap_mode, space); +} + + +STBIRDEF int stbir_resize_float_generic( const float *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + float *output_pixels , int output_w, int output_h, int output_stride_in_bytes, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_wrap_mode, stbir_filter filter, stbir_colorspace space, + void *alloc_context) +{ + return stbir__resize_arbitrary(alloc_context, input_pixels, input_w, input_h, input_stride_in_bytes, + output_pixels, output_w, output_h, output_stride_in_bytes, + 0,0,1,1,NULL,num_channels,alpha_channel,flags, STBIR_TYPE_FLOAT, filter, filter, + edge_wrap_mode, edge_wrap_mode, space); +} + + +STBIRDEF int stbir_resize( const void *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + void *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + stbir_datatype datatype, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_mode_horizontal, stbir_edge edge_mode_vertical, + stbir_filter filter_horizontal, stbir_filter filter_vertical, + stbir_colorspace space, void *alloc_context) +{ + return stbir__resize_arbitrary(alloc_context, input_pixels, input_w, input_h, input_stride_in_bytes, + output_pixels, output_w, output_h, output_stride_in_bytes, + 0,0,1,1,NULL,num_channels,alpha_channel,flags, datatype, filter_horizontal, filter_vertical, + edge_mode_horizontal, edge_mode_vertical, space); +} + + +STBIRDEF int stbir_resize_subpixel(const void *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + void *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + stbir_datatype datatype, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_mode_horizontal, stbir_edge edge_mode_vertical, + stbir_filter filter_horizontal, stbir_filter filter_vertical, + stbir_colorspace space, void *alloc_context, + float x_scale, float y_scale, + float x_offset, float y_offset) +{ + float transform[4]; + transform[0] = x_scale; + transform[1] = y_scale; + transform[2] = x_offset; + transform[3] = y_offset; + return stbir__resize_arbitrary(alloc_context, input_pixels, input_w, input_h, input_stride_in_bytes, + output_pixels, output_w, output_h, output_stride_in_bytes, + 0,0,1,1,transform,num_channels,alpha_channel,flags, datatype, filter_horizontal, filter_vertical, + edge_mode_horizontal, edge_mode_vertical, space); +} + +STBIRDEF int stbir_resize_region( const void *input_pixels , int input_w , int input_h , int input_stride_in_bytes, + void *output_pixels, int output_w, int output_h, int output_stride_in_bytes, + stbir_datatype datatype, + int num_channels, int alpha_channel, int flags, + stbir_edge edge_mode_horizontal, stbir_edge edge_mode_vertical, + stbir_filter filter_horizontal, stbir_filter filter_vertical, + stbir_colorspace space, void *alloc_context, + float s0, float t0, float s1, float t1) +{ + return stbir__resize_arbitrary(alloc_context, input_pixels, input_w, input_h, input_stride_in_bytes, + output_pixels, output_w, output_h, output_stride_in_bytes, + s0,t0,s1,t1,NULL,num_channels,alpha_channel,flags, datatype, filter_horizontal, filter_vertical, + edge_mode_horizontal, edge_mode_vertical, space); +} + +#endif // STB_IMAGE_RESIZE_IMPLEMENTATION diff --git a/util.cpp b/util.cpp index 94b7314..96310cb 100644 --- a/util.cpp +++ b/util.cpp @@ -25,6 +25,9 @@ #include "ggml/ggml.h" #include "stable-diffusion.h" +#define 
STB_IMAGE_RESIZE_IMPLEMENTATION
+#include "stb_image_resize.h"
+
 bool ends_with(const std::string& str, const std::string& ending) {
     if (str.length() >= ending.length()) {
         return (str.compare(str.length() - ending.length(), ending.length(), ending) == 0);
@@ -40,6 +43,13 @@ bool starts_with(const std::string& str, const std::string& start) {
     return false;
 }
 
+bool contains(const std::string& str, const std::string& substr) {
+    if (str.find(substr) != std::string::npos) {
+        return true;
+    }
+    return false;
+}
+
 void replace_all_chars(std::string& str, char target, char replacement) {
     for (size_t i = 0; i < str.length(); ++i) {
         if (str[i] == target) {
@@ -88,6 +98,43 @@ std::string get_full_path(const std::string& dir, const std::string& filename) {
     }
 }
 
+std::vector<std::string> get_files_from_dir(const std::string& dir) {
+    std::vector<std::string> files;
+
+    WIN32_FIND_DATA findFileData;
+    HANDLE hFind;
+
+    char currentDirectory[MAX_PATH];
+    GetCurrentDirectory(MAX_PATH, currentDirectory);
+
+    char directoryPath[MAX_PATH];  // this is absolute path
+    sprintf(directoryPath, "%s\\%s\\*", currentDirectory, dir.c_str());
+
+    // Find the first file in the directory
+    hFind = FindFirstFile(directoryPath, &findFileData);
+
+    // Check if the directory was found
+    if (hFind == INVALID_HANDLE_VALUE) {
+        printf("Unable to find directory.\n");
+        return files;
+    }
+
+    // Loop through all files in the directory
+    do {
+        // Check if the found file is a regular file (not a directory)
+        if (!(findFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) {
+            files.push_back(std::string(currentDirectory) + "\\" + dir + "\\" + std::string(findFileData.cFileName));
+        }
+    } while (FindNextFile(hFind, &findFileData) != 0);
+
+    // Close the handle
+    FindClose(hFind);
+
+    sort(files.begin(), files.end());
+
+    return files;
+}
+
 #else  // Unix
 #include <dirent.h>
 #include <sys/stat.h>
@@ -102,6 +149,7 @@ bool is_directory(const std::string& path) {
     return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
 }
 
+// TODO: add windows version
 std::string get_full_path(const std::string& dir, const std::string& filename) {
     DIR* dp = opendir(dir.c_str());
 
@@ -121,6 +169,27 @@ std::string get_full_path(const std::string& dir, const std::string& filename) {
     return "";
 }
 
+std::vector<std::string> get_files_from_dir(const std::string& dir) {
+    std::vector<std::string> files;
+
+    DIR* dp = opendir(dir.c_str());
+
+    if (dp != nullptr) {
+        struct dirent* entry;
+
+        while ((entry = readdir(dp)) != nullptr) {
+            std::string fname = dir + "/" + entry->d_name;
+            if (!is_directory(fname))
+                files.push_back(fname);
+        }
+        closedir(dp);
+    }
+
+    sort(files.begin(), files.end());
+
+    return files;
+}
+
 #endif
 
 // get_num_physical_cores is copy from
@@ -161,8 +230,8 @@ int32_t get_num_physical_cores() {
     return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
 }
 
-static sd_progress_cb_t sd_progress_cb = NULL;
-void* sd_progress_cb_data = NULL;
+static sd_progress_cb_t sd_progress_cb = NULL;
+void* sd_progress_cb_data = NULL;
 
 std::u32string utf8_to_utf32(const std::string& utf8_str) {
     std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
@@ -207,9 +276,42 @@ std::string path_join(const std::string& p1, const std::string& p2) {
     return p1 + "/" + p2;
 }
 
+sd_image_t* preprocess_id_image(sd_image_t* img) {
+    int shortest_edge = 224;
+    int size = shortest_edge;
+    sd_image_t* resized = NULL;
+    uint32_t w = img->width;
+    uint32_t h = img->height;
+    uint32_t c = img->channel;
+
+    // 1. do resize using stb_resize functions
+
+    unsigned char* buf = (unsigned char*)malloc(sizeof(unsigned char) * 3 * size * size);  // NOTE: sized for a 3-channel (RGB) id image
+    if (!stbir_resize_uint8(img->data, w, h, 0,
+                            buf, size, size, 0,
+                            c)) {
+        fprintf(stderr, "%s: resize operation failed\n", __func__);
+        return resized;
+    }
+
+    // 2. do center crop (likely unnecessary due to step 1)
+
+    // 3. do rescale
+
+    // 4. do normalize
+
+    // 3 and 4 will need to be done in float format.
+
+    resized = new sd_image_t{(uint32_t)shortest_edge,
+                             (uint32_t)shortest_edge,
+                             3,
+                             buf};
+    return resized;
+}
+
 void pretty_progress(int step, int steps, float time) {
     if (sd_progress_cb) {
-        sd_progress_cb(step,steps,time, sd_progress_cb_data);
+        sd_progress_cb(step, steps, time, sd_progress_cb_data);
         return;
     }
     if (step == 0) {
@@ -255,9 +357,8 @@ std::string trim(const std::string& s) {
     return rtrim(ltrim(s));
 }
 
-static sd_log_cb_t sd_log_cb = NULL;
-void* sd_log_cb_data = NULL;
-
+static sd_log_cb_t sd_log_cb = NULL;
+void* sd_log_cb_data = NULL;
 
 #define LOG_BUFFER_SIZE 1024
 
diff --git a/util.h b/util.h
index c562ba5..b8b941d 100644
--- a/util.h
+++ b/util.h
@@ -3,11 +3,13 @@
 
 #include <cstdint>
 #include <string>
+#include <vector>
 
 #include "stable-diffusion.h"
 
 bool ends_with(const std::string& str, const std::string& ending);
 bool starts_with(const std::string& str, const std::string& start);
+bool contains(const std::string& str, const std::string& substr);
 
 std::string format(const char* fmt, ...);
 
@@ -17,10 +19,16 @@ bool file_exists(const std::string& filename);
 bool is_directory(const std::string& path);
 
 std::string get_full_path(const std::string& dir, const std::string& filename);
+std::vector<std::string> get_files_from_dir(const std::string& dir);
+
 std::u32string utf8_to_utf32(const std::string& utf8_str);
 std::string utf32_to_utf8(const std::u32string& utf32_str);
 std::u32string unicode_value_to_utf32(int unicode_value);
 
+sd_image_t* preprocess_id_image(sd_image_t* img);
+
+// std::string sd_basename(const std::string& path);
+
 typedef struct {
     uint32_t width;
     uint32_t height;