feat: implement the complete bpe function (#119)
* implement the complete bpe function

Co-authored-by: leejet <leejet714@gmail.com>

parent 8f6b4a39d6
commit 0e64238e4c
.clang-format:

```diff
@@ -7,7 +7,7 @@ IndentCaseLabels: false
 ColumnLimit: 0
 AccessModifierOffset: -4
 NamespaceIndentation: All
 FixNamespaceComments: false
-FixNamespaceComments: false
+AlignAfterOpenBracket: true
 AlignConsecutiveAssignments: true
 IndentCaseLabels: true
```
README.md:

```diff
@@ -51,7 +51,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - The current implementation of ggml_conv_2d is slow and has high memory usage
 - Implement Winograd Convolution 2D for 3x3 kernel filtering
 - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
-- [ ] Implement BPE Tokenizer
 - [ ] Implement [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN/tree/master) upscaler
 - [ ] k-quants support
 
```
stable-diffusion.cpp:

```diff
@@ -520,7 +520,6 @@ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
 }
 
 // Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
-// TODO: implement bpe
 class CLIPTokenizer {
 private:
     SDVersion version = VERSION_1_x;
```
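For orientation, the `bytes_to_unicode()` signature in the hunk header is the CLIP-style byte-to-codepoint table. The repo's actual implementation sits just above this hunk and is not shown here; the sketch below follows the referenced simple_tokenizer.py, where printable bytes map to themselves and the remaining bytes are shifted past 0xFF so no token ever contains a raw control byte:

```cpp
#include <string>
#include <utility>
#include <vector>

// Sketch only: mirrors bytes_to_unicode() from OpenAI's simple_tokenizer.py.
std::vector<std::pair<int, std::u32string>> bytes_to_unicode_sketch() {
    std::vector<std::pair<int, std::u32string>> byte_map;
    int shifted = 0;
    for (int b = 0; b < 256; b++) {
        bool printable = (b >= '!' && b <= '~') ||    // ASCII
                         (b >= 0xA1 && b <= 0xAC) ||  // Latin-1, first run
                         (b >= 0xAE && b <= 0xFF);    // Latin-1, second run
        // Printable bytes keep their codepoint; the rest get 256, 257, ...
        char32_t cp = printable ? static_cast<char32_t>(b)
                                : static_cast<char32_t>(256 + shifted++);
        byte_map.emplace_back(b, std::u32string(1, cp));
    }
    return byte_map;
}
```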
```diff
@@ -547,6 +546,21 @@ private:
         return text;
     }
 
+    static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
+        std::set<std::pair<std::u32string, std::u32string>> pairs;
+        if (subwords.size() == 0) {
+            return pairs;
+        }
+        std::u32string prev_subword = subwords[0];
+        for (int i = 1; i < subwords.size(); i++) {
+            std::u32string subword = subwords[i];
+            std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
+            pairs.insert(pair);
+            prev_subword = subword;
+        }
+        return pairs;
+    }
+
 public:
     CLIPTokenizer(SDVersion version = VERSION_1_x)
         : version(version) {}
```
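The new `get_pairs` helper collects the set of adjacent subword pairs, i.e. the merge candidates for one BPE round. A quick illustration, using `std::string` in place of `std::u32string` so the output is printable (the logic is unchanged):

```cpp
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

int main() {
    std::vector<std::string> subwords = {"h", "e", "l", "l", "o</w>"};

    std::set<std::pair<std::string, std::string>> pairs;
    for (size_t i = 1; i < subwords.size(); i++) {
        // The set deduplicates: a word like "aaaa" yields (a, a) only once.
        pairs.insert({subwords[i - 1], subwords[i]});
    }

    for (const auto& p : pairs) {
        std::cout << "(" << p.first << ", " << p.second << ")\n";
    }
    // Prints, in set order: (e, l) (h, e) (l, l) (l, o</w>)
}
```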
```diff
@@ -565,7 +579,9 @@ public:
             merges.push_back(merges_utf32_str.substr(start, pos - start));
             start = pos + 1;
         }
-        merges = std::vector<std::u32string>(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1);
+        // LOG_DEBUG("merges size %llu", merges.size());
+        GGML_ASSERT(merges.size() == 48895);
+        merges = std::vector<std::u32string>(merges.begin() + 1, merges.end());
         std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
         for (const auto& merge : merges) {
             size_t space_pos = merge.find(' ');
```
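The magic numbers line up with OpenAI's CLIP merge list, stated here as an assumption about the files this code targets: merges.txt ships one version-header line plus 48894 merge rules, so the new `GGML_ASSERT` and the old inline slice bound describe the same file. A worked check:

```cpp
// Worked numbers, assuming CLIP's standard BPE files:
// merges.txt = 1 version-header line + 48894 merge rules = 48895 lines,
// which is exactly the slice bound the old code computed inline.
static_assert(49152 - 256 - 2 + 1 == 48895, "old inline slice bound");
// Final vocab: 256 byte tokens + 256 byte-with-</w> tokens
// + 48894 merged tokens + 2 specials (<|startoftext|>, <|endoftext|>).
static_assert(256 + 256 + 48894 + 2 == 49408, "CLIP vocab size");
```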
```diff
@@ -596,14 +612,79 @@ public:
         }
     };
 
-    std::u32string bpe(std::u32string token) {
-        std::u32string word = token + utf8_to_utf32("</w>");
-        if (encoder.find(word) != encoder.end()) {
-            return word;
-        } else if (encoder.find(token) != encoder.end()) {
-            return token;
+    std::u32string bpe(const std::u32string& token) {
+        std::vector<std::u32string> word;
+
+        for (int i = 0; i < token.size() - 1; i++) {
+            word.emplace_back(1, token[i]);
         }
-        return utf8_to_utf32(UNK_TOKEN);
+        word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("</w>"));
+
+        std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
+
+        if (pairs.empty()) {
+            return token + utf8_to_utf32("</w>");
+        }
+
+        while (true) {
+            auto min_pair_iter = std::min_element(pairs.begin(),
+                                                  pairs.end(),
+                                                  [&](const std::pair<std::u32string, std::u32string>& a,
+                                                      const std::pair<std::u32string, std::u32string>& b) {
+                                                      if (bpe_ranks.find(a) == bpe_ranks.end()) {
+                                                          return false;
+                                                      } else if (bpe_ranks.find(b) == bpe_ranks.end()) {
+                                                          return true;
+                                                      }
+                                                      return bpe_ranks.at(a) < bpe_ranks.at(b);
+                                                  });
+
+            const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
+
+            if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
+                break;
+            }
+
+            std::u32string first = bigram.first;
+            std::u32string second = bigram.second;
+            std::vector<std::u32string> new_word;
+            int32_t i = 0;
+
+            while (i < word.size()) {
+                auto it = std::find(word.begin() + i, word.end(), first);
+                if (it == word.end()) {
+                    new_word.insert(new_word.end(), word.begin() + i, word.end());
+                    break;
+                }
+                new_word.insert(new_word.end(), word.begin() + i, it);
+                i = static_cast<int32_t>(std::distance(word.begin(), it));
+
+                if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
+                    new_word.push_back(first + second);
+                    i += 2;
+                } else {
+                    new_word.push_back(word[i]);
+                    i += 1;
+                }
+            }
+
+            word = new_word;
+
+            if (word.size() == 1) {
+                break;
+            }
+            pairs = get_pairs(word);
+        }
+
+        std::u32string result;
+        for (int i = 0; i < word.size(); i++) {
+            result += word[i];
+            if (i != word.size() - 1) {
+                result += utf8_to_utf32(" ");
+            }
+        }
+
+        return result;
     }
 
     std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
```
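Taken together, the new `bpe` merges the lowest-ranked adjacent pair, rebuilds the word, and repeats until only unranked pairs (or a single subword) remain, then joins the result with spaces. Below is a minimal, self-contained sketch of that same merge loop with `std::string` subwords and a made-up two-rule merge table; the ranks and the input word are hypothetical, for illustration only:

```cpp
#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

using Pair = std::pair<std::string, std::string>;

// Same role as CLIPTokenizer::get_pairs, over std::string subwords.
std::set<Pair> get_pairs(const std::vector<std::string>& subwords) {
    std::set<Pair> pairs;
    for (size_t i = 1; i < subwords.size(); i++) {
        pairs.insert({subwords[i - 1], subwords[i]});
    }
    return pairs;
}

int main() {
    // Hypothetical merge table: lower rank = learned earlier = merged first.
    std::map<Pair, int> bpe_ranks = {
        {{"l", "o"}, 0},
        {{"lo", "w</w>"}, 1},
    };
    // "low" split into single symbols, with </w> marking the word end.
    std::vector<std::string> word = {"l", "o", "w</w>"};

    std::set<Pair> pairs = get_pairs(word);
    while (!pairs.empty()) {
        // Lowest-ranked candidate pair wins; unranked pairs compare as "worst".
        Pair bigram = *std::min_element(
            pairs.begin(), pairs.end(), [&](const Pair& a, const Pair& b) {
                if (bpe_ranks.find(a) == bpe_ranks.end()) return false;
                if (bpe_ranks.find(b) == bpe_ranks.end()) return true;
                return bpe_ranks.at(a) < bpe_ranks.at(b);
            });
        if (bpe_ranks.find(bigram) == bpe_ranks.end()) break;  // no merges left

        // Rebuild the word, fusing each adjacent occurrence of the bigram.
        std::vector<std::string> new_word;
        for (size_t i = 0; i < word.size();) {
            if (i + 1 < word.size() && word[i] == bigram.first && word[i + 1] == bigram.second) {
                new_word.push_back(word[i] + word[i + 1]);
                i += 2;
            } else {
                new_word.push_back(word[i]);
                i += 1;
            }
        }
        word = std::move(new_word);
        if (word.size() == 1) break;
        pairs = get_pairs(word);
    }

    for (const std::string& s : word) std::cout << s << ' ';
    std::cout << '\n';  // prints: low</w>
}
```

Each round rescans all candidate pairs with `std::min_element`, so a round costs O(|pairs|); this mirrors the Python reference implementation rather than optimizing it.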
stable-diffusion.h:

```diff
@@ -5,6 +5,8 @@
 #include <string>
 #include <vector>
 
+#include "ggml/ggml.h"
+
 enum RNGType {
     STD_DEFAULT_RNG,
     CUDA_RNG
@@ -42,10 +44,12 @@ public:
                     bool free_params_immediately = false,
                     std::string lora_model_dir = "",
                     RNGType rng_type = STD_DEFAULT_RNG);
 
     bool load_from_file(const std::string& model_path,
+                        const std::string& vae_path,
+                        ggml_type wtype,
                         Schedule d = DEFAULT);
 
     std::vector<uint8_t*> txt2img(
         std::string prompt,
         std::string negative_prompt,
```
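For context, a hypothetical call site for the updated `load_from_file`; the class name `StableDiffusion`, the file name, and the argument values are assumptions for this sketch, and only the parameter list comes from the header diff above:

```cpp
#include "ggml/ggml.h"
#include "stable-diffusion.h"

int main() {
    // Assumption: the remaining constructor parameters (not shown) have defaults.
    StableDiffusion sd;
    bool ok = sd.load_from_file("sd-v1-4.ckpt",  // model_path (hypothetical file)
                                "",              // vae_path: empty = use the baked-in VAE
                                GGML_TYPE_F16,   // wtype, a ggml_type from ggml/ggml.h
                                DEFAULT);        // Schedule, per the default above
    return ok ? 0 : 1;
}
```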