fix: enhance the tokenizer's handling of Unicode (#120)

This commit is contained in:
leejet 2023-12-21 00:22:03 +08:00 committed by GitHub
parent 9842a3f819
commit 8f6b4a39d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 43830 additions and 80117 deletions

View File

@ -1192,20 +1192,9 @@ ggml_type ModelLoader::get_sd_wtype() {
return GGML_TYPE_COUNT; return GGML_TYPE_COUNT;
} }
bool ModelLoader::load_vocab(on_new_token_cb_t on_new_token_cb) { std::string ModelLoader::load_merges() {
char* vocab_buffer = reinterpret_cast<char*>(vocab_json); std::string merges_utf8_str(reinterpret_cast<const char*>(merges_utf8_c_str), sizeof(merges_utf8_c_str));
nlohmann::json vocab = nlohmann::json::parse(vocab_buffer); return merges_utf8_str;
std::map<char, int> decoder = unicode_to_byte();
for (auto& it : vocab.items()) {
int token_id = it.value();
std::string token_str = it.key();
std::string token = "";
for (char c : token_str) {
token += decoder[c];
}
on_new_token_cb(token, token_id);
}
return true;
} }
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) { bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {

View File

@ -115,7 +115,7 @@ public:
bool init_from_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_file(const std::string& file_path, const std::string& prefix = "");
SDVersion get_sd_version(); SDVersion get_sd_version();
ggml_type get_sd_wtype(); ggml_type get_sd_wtype();
bool load_vocab(on_new_token_cb_t on_new_token_cb); std::string load_merges();
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb); bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb);
int64_t cal_mem_size(ggml_backend_t backend); int64_t cal_mem_size(ggml_backend_t backend);
~ModelLoader() = default; ~ModelLoader() = default;

View File

@ -493,12 +493,40 @@ const int BOS_TOKEN_ID = 49406;
const int EOS_TOKEN_ID = 49407; const int EOS_TOKEN_ID = 49407;
const int PAD_TOKEN_ID = 49407; const int PAD_TOKEN_ID = 49407;
std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
    // Build the GPT-2/CLIP byte -> unicode table: every byte 0..255 is paired
    // with a printable unicode code point so raw bytes can be processed as
    // unicode strings during BPE. Printable bytes map to themselves; the
    // remaining (non-printable) bytes are shifted up into the 256+ range.
    // Ref: bytes_to_unicode() in OpenAI CLIP's simple_tokenizer.py.
    std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
    byte_unicode_pairs.reserve(256);
    std::set<int> byte_set;
    // Map every byte in [lo, hi] to its own code point (identity mapping).
    auto add_identity_range = [&](int lo, int hi) {
        for (int b = lo; b <= hi; ++b) {
            byte_set.insert(b);
            byte_unicode_pairs.emplace_back(b, unicode_value_to_utf32(b));
        }
    };
    add_identity_range(static_cast<int>('!'), static_cast<int>('~'));  // printable ASCII
    add_identity_range(161, 172);                                      // printable Latin-1, part 1
    add_identity_range(174, 255);                                      // printable Latin-1, part 2
    // Every byte not covered above gets a fresh code point starting at 256.
    int n = 0;
    for (int b = 0; b < 256; ++b) {
        if (byte_set.find(b) == byte_set.end()) {
            byte_unicode_pairs.emplace_back(b, unicode_value_to_utf32(n + 256));
            ++n;
        }
    }
    return byte_unicode_pairs;
}
// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py // Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
// TODO: implement bpe // TODO: implement bpe
class CLIPTokenizer { class CLIPTokenizer {
private: private:
SDVersion version = VERSION_1_x; SDVersion version = VERSION_1_x;
std::map<std::string, int32_t> encoder; std::map<int, std::u32string> byte_encoder;
std::map<std::u32string, int> encoder;
std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
std::regex pat; std::regex pat;
static std::string strip(const std::string& str) { static std::string strip(const std::string& str) {
@ -521,19 +549,61 @@ private:
public: public:
CLIPTokenizer(SDVersion version = VERSION_1_x) CLIPTokenizer(SDVersion version = VERSION_1_x)
: version(version){}; : version(version) {}
std::string bpe(std::string token) {
std::string word = token + "</w>"; void load_from_merges(const std::string& merges_utf8_str) {
auto byte_unicode_pairs = bytes_to_unicode();
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
// for (auto & pair: byte_unicode_pairs) {
// std::cout << pair.first << ": " << pair.second << std::endl;
// }
std::vector<std::u32string> merges;
size_t start = 0;
size_t pos;
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
merges.push_back(merges_utf32_str.substr(start, pos - start));
start = pos + 1;
}
merges = std::vector<std::u32string>(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1);
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
for (const auto& merge : merges) {
size_t space_pos = merge.find(' ');
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
// LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
}
std::vector<std::u32string> vocab;
for (const auto& pair : byte_unicode_pairs) {
vocab.push_back(pair.second);
}
for (const auto& pair : byte_unicode_pairs) {
vocab.push_back(pair.second + utf8_to_utf32("</w>"));
}
for (const auto& merge : merge_pairs) {
vocab.push_back(merge.first + merge.second);
}
vocab.push_back(utf8_to_utf32("<|startoftext|>"));
vocab.push_back(utf8_to_utf32("<|endoftext|>"));
LOG_DEBUG("vocab size: %llu", vocab.size());
int i = 0;
for (const auto& token : vocab) {
encoder[token] = i++;
}
int rank = 0;
for (const auto& merge : merge_pairs) {
bpe_ranks[merge] = rank++;
}
};
std::u32string bpe(std::u32string token) {
std::u32string word = token + utf8_to_utf32("</w>");
if (encoder.find(word) != encoder.end()) { if (encoder.find(word) != encoder.end()) {
return word; return word;
} else if (encoder.find(token) != encoder.end()) { } else if (encoder.find(token) != encoder.end()) {
return token; return token;
} }
return UNK_TOKEN; return utf8_to_utf32(UNK_TOKEN);
}
void add_token(std::string token, int32_t token_id) {
encoder[token] = token_id;
} }
std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) { std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
@ -571,13 +641,25 @@ public:
std::vector<std::string> token_strs; std::vector<std::string> token_strs;
while (std::regex_search(str, matches, pat)) { while (std::regex_search(str, matches, pat)) {
for (auto& token : matches) { for (auto& token : matches) {
std::istringstream iss(bpe(token)); std::string token_str = token.str();
std::vector<std::string> tokens{std::istream_iterator<std::string>{iss}, std::u32string utf32_token;
std::istream_iterator<std::string>{}}; for (int i = 0; i < token_str.length(); i++) {
for (const auto& bpe_token : tokens) { char b = token_str[i];
bpe_tokens.push_back(encoder[bpe_token]); utf32_token += byte_encoder[b];
token_strs.push_back(bpe_token);
} }
auto bpe_strs = bpe(utf32_token);
size_t start = 0;
size_t pos;
while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
auto bpe_str = bpe_strs.substr(start, pos - start);
bpe_tokens.push_back(encoder[bpe_str]);
token_strs.push_back(utf32_to_utf8(bpe_str));
start = pos + 1;
}
auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
bpe_tokens.push_back(encoder[bpe_str]);
token_strs.push_back(utf32_to_utf8(bpe_str));
} }
str = matches.suffix(); str = matches.suffix();
} }
@ -4323,15 +4405,14 @@ public:
LOG_INFO("Stable Diffusion weight type: %s", ggml_type_name(model_data_type)); LOG_INFO("Stable Diffusion weight type: %s", ggml_type_name(model_data_type));
LOG_DEBUG("loading vocab"); LOG_DEBUG("loading vocab");
auto add_token = [&](const std::string& token, int32_t token_id) { std::string merges_utf8_str = model_loader.load_merges();
cond_stage_model.tokenizer.add_token(token, token_id); if (merges_utf8_str.size() == 0) {
}; LOG_ERROR("get merges failed: '%s'", model_path.c_str());
bool success = model_loader.load_vocab(add_token);
if (!success) {
LOG_ERROR("get vocab from file failed: '%s'", model_path.c_str());
return false; return false;
} }
cond_stage_model.tokenizer.load_from_merges(merges_utf8_str);
// create the ggml context for network params // create the ggml context for network params
LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor)); LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
@ -4431,7 +4512,7 @@ public:
// print_ggml_tensor(alphas_cumprod_tensor); // print_ggml_tensor(alphas_cumprod_tensor);
success = model_loader.load_tensors(on_new_tensor_cb); bool success = model_loader.load_tensors(on_new_tensor_cb);
if (!success) { if (!success) {
LOG_ERROR("load tensors from file failed"); LOG_ERROR("load tensors from file failed");
ggml_free(ctx); ggml_free(ctx);

View File

@ -1,7 +1,9 @@
#include "util.h" #include "util.h"
#include <stdarg.h> #include <stdarg.h>
#include <codecvt>
#include <fstream> #include <fstream>
#include <locale>
#include <thread> #include <thread>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
@ -119,6 +121,21 @@ int32_t get_num_physical_cores() {
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
} }
std::u32string utf8_to_utf32(const std::string& utf8_str) {
    // Decode a UTF-8 byte string into UTF-32 code points.
    // Hand-rolled decoder instead of std::wstring_convert/std::codecvt_utf8,
    // which are deprecated since C++17. Invalid or truncated byte sequences
    // decode to U+FFFD (replacement character) instead of throwing, so a
    // malformed user prompt cannot abort tokenization.
    // NOTE: overlong encodings are not rejected; inputs here (merges data,
    // internal literals) are expected to be well-formed UTF-8.
    std::u32string utf32_str;
    const size_t len = utf8_str.size();
    size_t i = 0;
    while (i < len) {
        const unsigned char lead = static_cast<unsigned char>(utf8_str[i]);
        char32_t cp = 0;
        size_t n_cont = 0;  // continuation bytes expected after the lead byte
        if (lead < 0x80) {
            cp = lead;
        } else if ((lead & 0xE0) == 0xC0) {
            cp = lead & 0x1F;
            n_cont = 1;
        } else if ((lead & 0xF0) == 0xE0) {
            cp = lead & 0x0F;
            n_cont = 2;
        } else if ((lead & 0xF8) == 0xF0) {
            cp = lead & 0x07;
            n_cont = 3;
        } else {
            // invalid lead byte (continuation byte or 0xF8..0xFF)
            utf32_str += static_cast<char32_t>(0xFFFD);
            ++i;
            continue;
        }
        if (i + n_cont >= len) {
            // sequence truncated at end of input
            utf32_str += static_cast<char32_t>(0xFFFD);
            break;
        }
        bool valid = true;
        for (size_t k = 1; k <= n_cont; ++k) {
            const unsigned char c = static_cast<unsigned char>(utf8_str[i + k]);
            if ((c & 0xC0) != 0x80) {  // not a 10xxxxxx continuation byte
                valid = false;
                break;
            }
            cp = (cp << 6) | (c & 0x3F);
        }
        if (!valid || cp > 0x10FFFF || (cp >= 0xD800 && cp <= 0xDFFF)) {
            utf32_str += static_cast<char32_t>(0xFFFD);
            ++i;  // resync on the byte after the bad lead
            continue;
        }
        utf32_str += cp;
        i += n_cont + 1;
    }
    return utf32_str;
}
std::string utf32_to_utf8(const std::u32string& utf32_str) {
    // Encode UTF-32 code points as a UTF-8 byte string.
    // Hand-rolled encoder instead of std::wstring_convert/std::codecvt_utf8,
    // which are deprecated since C++17. Code points that are not valid
    // scalar values (surrogates, > U+10FFFF) are encoded as U+FFFD instead
    // of throwing.
    std::string utf8_str;
    for (char32_t cp : utf32_str) {
        if ((cp >= 0xD800 && cp <= 0xDFFF) || cp > 0x10FFFF) {
            cp = 0xFFFD;  // replace invalid scalar values
        }
        if (cp < 0x80) {
            utf8_str += static_cast<char>(cp);
        } else if (cp < 0x800) {
            utf8_str += static_cast<char>(0xC0 | (cp >> 6));
            utf8_str += static_cast<char>(0x80 | (cp & 0x3F));
        } else if (cp < 0x10000) {
            utf8_str += static_cast<char>(0xE0 | (cp >> 12));
            utf8_str += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
            utf8_str += static_cast<char>(0x80 | (cp & 0x3F));
        } else {
            utf8_str += static_cast<char>(0xF0 | (cp >> 18));
            utf8_str += static_cast<char>(0x80 | ((cp >> 12) & 0x3F));
            utf8_str += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
            utf8_str += static_cast<char>(0x80 | (cp & 0x3F));
        }
    }
    return utf8_str;
}
std::u32string unicode_value_to_utf32(int unicode_value) {
    // Wrap a single Unicode code point value in a one-character UTF-32 string.
    return std::u32string(1, static_cast<char32_t>(unicode_value));
}
std::string basename(const std::string& path) { std::string basename(const std::string& path) {
size_t pos = path.find_last_of('/'); size_t pos = path.find_last_of('/');
if (pos != std::string::npos) { if (pos != std::string::npos) {

4
util.h
View File

@ -14,6 +14,10 @@ void replace_all_chars(std::string& str, char target, char replacement);
bool file_exists(const std::string& filename); bool file_exists(const std::string& filename);
bool is_directory(const std::string& path); bool is_directory(const std::string& path);
std::u32string utf8_to_utf32(const std::string& utf8_str);
std::string utf32_to_utf8(const std::u32string& utf32_str);
std::u32string unicode_value_to_utf32(int unicode_value);
std::string basename(const std::string& path); std::string basename(const std::string& path);
std::string path_join(const std::string& p1, const std::string& p2); std::string path_join(const std::string& p1, const std::string& p2);

123782
vocab.hpp

File diff suppressed because it is too large Load Diff