feat: implement the complete bpe function (#119)
* implement the complete bpe function --------- Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
parent
8f6b4a39d6
commit
0e64238e4c
@ -51,7 +51,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
|
|||||||
- The current implementation of ggml_conv_2d is slow and has high memory usage
|
- The current implementation of ggml_conv_2d is slow and has high memory usage
|
||||||
- Implement Winograd Convolution 2D for 3x3 kernel filtering
|
- Implement Winograd Convolution 2D for 3x3 kernel filtering
|
||||||
- [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
|
- [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
|
||||||
- [ ] Implement BPE Tokenizer
|
|
||||||
- [ ] Implement [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN/tree/master) upscaler
|
- [ ] Implement [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN/tree/master) upscaler
|
||||||
- [ ] k-quants support
|
- [ ] k-quants support
|
||||||
|
|
||||||
|
@ -520,7 +520,6 @@ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
|
// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
|
||||||
// TODO: implement bpe
|
|
||||||
class CLIPTokenizer {
|
class CLIPTokenizer {
|
||||||
private:
|
private:
|
||||||
SDVersion version = VERSION_1_x;
|
SDVersion version = VERSION_1_x;
|
||||||
@ -547,6 +546,21 @@ private:
|
|||||||
return text;
|
return text;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Collect the set of all adjacent (left, right) subword pairs in `subwords`.
// An empty or single-element input yields an empty set, since no adjacency exists.
static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
    std::set<std::pair<std::u32string, std::u32string>> pairs;
    // Walk each neighbouring pair; duplicates collapse naturally in the set.
    for (size_t idx = 1; idx < subwords.size(); ++idx) {
        pairs.emplace(subwords[idx - 1], subwords[idx]);
    }
    return pairs;
}
|
|
||||||
public:
|
public:
|
||||||
CLIPTokenizer(SDVersion version = VERSION_1_x)
|
CLIPTokenizer(SDVersion version = VERSION_1_x)
|
||||||
: version(version) {}
|
: version(version) {}
|
||||||
@ -565,7 +579,9 @@ public:
|
|||||||
merges.push_back(merges_utf32_str.substr(start, pos - start));
|
merges.push_back(merges_utf32_str.substr(start, pos - start));
|
||||||
start = pos + 1;
|
start = pos + 1;
|
||||||
}
|
}
|
||||||
merges = std::vector<std::u32string>(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1);
|
// LOG_DEBUG("merges size %llu", merges.size());
|
||||||
|
GGML_ASSERT(merges.size() == 48895);
|
||||||
|
merges = std::vector<std::u32string>(merges.begin() + 1, merges.end());
|
||||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||||
for (const auto& merge : merges) {
|
for (const auto& merge : merges) {
|
||||||
size_t space_pos = merge.find(' ');
|
size_t space_pos = merge.find(' ');
|
||||||
@ -596,14 +612,79 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
std::u32string bpe(std::u32string token) {
|
std::u32string bpe(const std::u32string& token) {
|
||||||
std::u32string word = token + utf8_to_utf32("</w>");
|
std::vector<std::u32string> word;
|
||||||
if (encoder.find(word) != encoder.end()) {
|
|
||||||
return word;
|
for (int i = 0; i < token.size() - 1; i++) {
|
||||||
} else if (encoder.find(token) != encoder.end()) {
|
word.emplace_back(1, token[i]);
|
||||||
return token;
|
|
||||||
}
|
}
|
||||||
return utf8_to_utf32(UNK_TOKEN);
|
word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("</w>"));
|
||||||
|
|
||||||
|
std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
|
||||||
|
|
||||||
|
if (pairs.empty()) {
|
||||||
|
return token + utf8_to_utf32("</w>");
|
||||||
|
}
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
auto min_pair_iter = std::min_element(pairs.begin(),
|
||||||
|
pairs.end(),
|
||||||
|
[&](const std::pair<std::u32string, std::u32string>& a,
|
||||||
|
const std::pair<std::u32string, std::u32string>& b) {
|
||||||
|
if (bpe_ranks.find(a) == bpe_ranks.end()) {
|
||||||
|
return false;
|
||||||
|
} else if (bpe_ranks.find(b) == bpe_ranks.end()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return bpe_ranks.at(a) < bpe_ranks.at(b);
|
||||||
|
});
|
||||||
|
|
||||||
|
const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
|
||||||
|
|
||||||
|
if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::u32string first = bigram.first;
|
||||||
|
std::u32string second = bigram.second;
|
||||||
|
std::vector<std::u32string> new_word;
|
||||||
|
int32_t i = 0;
|
||||||
|
|
||||||
|
while (i < word.size()) {
|
||||||
|
auto it = std::find(word.begin() + i, word.end(), first);
|
||||||
|
if (it == word.end()) {
|
||||||
|
new_word.insert(new_word.end(), word.begin() + i, word.end());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
new_word.insert(new_word.end(), word.begin() + i, it);
|
||||||
|
i = static_cast<int32_t>(std::distance(word.begin(), it));
|
||||||
|
|
||||||
|
if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
|
||||||
|
new_word.push_back(first + second);
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
new_word.push_back(word[i]);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
word = new_word;
|
||||||
|
|
||||||
|
if (word.size() == 1) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
pairs = get_pairs(word);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::u32string result;
|
||||||
|
for (int i = 0; i < word.size(); i++) {
|
||||||
|
result += word[i];
|
||||||
|
if (i != word.size() - 1) {
|
||||||
|
result += utf8_to_utf32(" ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
|
std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ggml/ggml.h"
|
||||||
|
|
||||||
enum RNGType {
|
enum RNGType {
|
||||||
STD_DEFAULT_RNG,
|
STD_DEFAULT_RNG,
|
||||||
CUDA_RNG
|
CUDA_RNG
|
||||||
@ -42,10 +44,12 @@ public:
|
|||||||
bool free_params_immediately = false,
|
bool free_params_immediately = false,
|
||||||
std::string lora_model_dir = "",
|
std::string lora_model_dir = "",
|
||||||
RNGType rng_type = STD_DEFAULT_RNG);
|
RNGType rng_type = STD_DEFAULT_RNG);
|
||||||
|
|
||||||
bool load_from_file(const std::string& model_path,
|
bool load_from_file(const std::string& model_path,
|
||||||
const std::string& vae_path,
|
const std::string& vae_path,
|
||||||
ggml_type wtype,
|
ggml_type wtype,
|
||||||
Schedule d = DEFAULT);
|
Schedule d = DEFAULT);
|
||||||
|
|
||||||
std::vector<uint8_t*> txt2img(
|
std::vector<uint8_t*> txt2img(
|
||||||
std::string prompt,
|
std::string prompt,
|
||||||
std::string negative_prompt,
|
std::string negative_prompt,
|
||||||
|
Loading…
Reference in New Issue
Block a user