From 0e64238e4c4c902e0c043b741ae48fe22a2fd0fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=97=BA=E6=97=BA=E7=A2=8E=E5=86=B0=E5=86=B0?= <38837039+Cyberhan123@users.noreply.github.com> Date: Sat, 23 Dec 2023 12:11:07 +0800 Subject: [PATCH] feat: implement the complete bpe function (#119) * implement the complete bpe function --------- Co-authored-by: leejet --- .clang-format | 2 +- README.md | 1 - stable-diffusion.cpp | 99 ++++++++++++++++++++++++++++++++++++++++---- stable-diffusion.h | 4 ++ 4 files changed, 95 insertions(+), 11 deletions(-) diff --git a/.clang-format b/.clang-format index 58d1885..4fe720b 100644 --- a/.clang-format +++ b/.clang-format @@ -7,7 +7,7 @@ IndentCaseLabels: false ColumnLimit: 0 AccessModifierOffset: -4 NamespaceIndentation: All -FixNamespaceComments: false +FixNamespaceComments: false AlignAfterOpenBracket: true AlignConsecutiveAssignments: true IndentCaseLabels: true \ No newline at end of file diff --git a/README.md b/README.md index ed75459..a0765da 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - The current implementation of ggml_conv_2d is slow and has high memory usage - Implement Winograd Convolution 2D for 3x3 kernel filtering - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d) -- [ ] Implement BPE Tokenizer - [ ] Implement [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN/tree/master) upscaler - [ ] k-quants support diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 98b2045..d0c499f 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -520,7 +520,6 @@ std::vector> bytes_to_unicode() { } // Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py -// TODO: implement bpe class CLIPTokenizer { private: SDVersion version = VERSION_1_x; @@ -547,6 +546,21 @@ private: return text; } + static std::set> get_pairs(const std::vector& subwords) { + std::set> pairs; + if (subwords.size() == 0) { + return pairs; + } + std::u32string prev_subword = subwords[0]; + for (int i = 1; i < subwords.size(); i++) { + std::u32string subword = subwords[i]; + std::pair pair(prev_subword, subword); + pairs.insert(pair); + prev_subword = subword; + } + return pairs; + } + public: CLIPTokenizer(SDVersion version = VERSION_1_x) : version(version) {} @@ -565,7 +579,9 @@ public: merges.push_back(merges_utf32_str.substr(start, pos - start)); start = pos + 1; } - merges = std::vector(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1); + // LOG_DEBUG("merges size %llu", merges.size()); + GGML_ASSERT(merges.size() == 48895); + merges = std::vector(merges.begin() + 1, merges.end()); std::vector> merge_pairs; for (const auto& merge : merges) { size_t space_pos = merge.find(' '); @@ -596,14 +612,79 @@ public: } }; - std::u32string bpe(std::u32string token) { - std::u32string word = token + utf8_to_utf32(""); - if (encoder.find(word) != encoder.end()) { - return word; - } else if (encoder.find(token) != encoder.end()) { - return token; + std::u32string bpe(const std::u32string& token) { + std::vector word; + + for (int i = 0; i < token.size() - 1; i++) { + word.emplace_back(1, token[i]); } - return utf8_to_utf32(UNK_TOKEN); + word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("")); + + std::set> pairs = get_pairs(word); + + if (pairs.empty()) { + return token + utf8_to_utf32(""); + } + + while (true) { + auto min_pair_iter = std::min_element(pairs.begin(), + pairs.end(), + [&](const std::pair& a, + const std::pair& b) { + if (bpe_ranks.find(a) == bpe_ranks.end()) { + return false; + } else if (bpe_ranks.find(b) == bpe_ranks.end()) { + return true; + } + return bpe_ranks.at(a) < bpe_ranks.at(b); + }); + + const std::pair& bigram = *min_pair_iter; + + if (bpe_ranks.find(bigram) == bpe_ranks.end()) { + break; + } + + std::u32string first = bigram.first; + std::u32string second = bigram.second; + std::vector new_word; + int32_t i = 0; + + while (i < word.size()) { + auto it = std::find(word.begin() + i, word.end(), first); + if (it == word.end()) { + new_word.insert(new_word.end(), word.begin() + i, word.end()); + break; + } + new_word.insert(new_word.end(), word.begin() + i, it); + i = static_cast(std::distance(word.begin(), it)); + + if (word[i] == first && i < static_cast(word.size()) - 1 && word[i + 1] == second) { + new_word.push_back(first + second); + i += 2; + } else { + new_word.push_back(word[i]); + i += 1; + } + } + + word = new_word; + + if (word.size() == 1) { + break; + } + pairs = get_pairs(word); + } + + std::u32string result; + for (int i = 0; i < word.size(); i++) { + result += word[i]; + if (i != word.size() - 1) { + result += utf8_to_utf32(" "); + } + } + + return result; } std::vector tokenize(std::string text, size_t max_length = 0, bool padding = false) { diff --git a/stable-diffusion.h b/stable-diffusion.h index 095016c..c94f6c7 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -5,6 +5,8 @@ #include #include +#include "ggml/ggml.h" + enum RNGType { STD_DEFAULT_RNG, CUDA_RNG @@ -42,10 +44,12 @@ public: bool free_params_immediately = false, std::string lora_model_dir = "", RNGType rng_type = STD_DEFAULT_RNG); + bool load_from_file(const std::string& model_path, const std::string& vae_path, ggml_type wtype, Schedule d = DEFAULT); + std::vector txt2img( std::string prompt, std::string negative_prompt,