fix: enhance the tokenizer's handing of Unicode (#120)
This commit is contained in:
parent
9842a3f819
commit
8f6b4a39d6
17
model.cpp
17
model.cpp
@ -1192,20 +1192,9 @@ ggml_type ModelLoader::get_sd_wtype() {
|
||||
return GGML_TYPE_COUNT;
|
||||
}
|
||||
|
||||
bool ModelLoader::load_vocab(on_new_token_cb_t on_new_token_cb) {
|
||||
char* vocab_buffer = reinterpret_cast<char*>(vocab_json);
|
||||
nlohmann::json vocab = nlohmann::json::parse(vocab_buffer);
|
||||
std::map<char, int> decoder = unicode_to_byte();
|
||||
for (auto& it : vocab.items()) {
|
||||
int token_id = it.value();
|
||||
std::string token_str = it.key();
|
||||
std::string token = "";
|
||||
for (char c : token_str) {
|
||||
token += decoder[c];
|
||||
}
|
||||
on_new_token_cb(token, token_id);
|
||||
}
|
||||
return true;
|
||||
std::string ModelLoader::load_merges() {
|
||||
std::string merges_utf8_str(reinterpret_cast<const char*>(merges_utf8_c_str), sizeof(merges_utf8_c_str));
|
||||
return merges_utf8_str;
|
||||
}
|
||||
|
||||
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
|
||||
|
2
model.h
2
model.h
@ -115,7 +115,7 @@ public:
|
||||
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
|
||||
SDVersion get_sd_version();
|
||||
ggml_type get_sd_wtype();
|
||||
bool load_vocab(on_new_token_cb_t on_new_token_cb);
|
||||
std::string load_merges();
|
||||
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb);
|
||||
int64_t cal_mem_size(ggml_backend_t backend);
|
||||
~ModelLoader() = default;
|
||||
|
@ -493,12 +493,40 @@ const int BOS_TOKEN_ID = 49406;
|
||||
const int EOS_TOKEN_ID = 49407;
|
||||
const int PAD_TOKEN_ID = 49407;
|
||||
|
||||
std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
|
||||
std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
|
||||
std::set<int> byte_set;
|
||||
for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
|
||||
byte_set.insert(b);
|
||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
||||
}
|
||||
for (int b = 161; b <= 172; ++b) {
|
||||
byte_set.insert(b);
|
||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
||||
}
|
||||
for (int b = 174; b <= 255; ++b) {
|
||||
byte_set.insert(b);
|
||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
||||
}
|
||||
int n = 0;
|
||||
for (int b = 0; b < 256; ++b) {
|
||||
if (byte_set.find(b) == byte_set.end()) {
|
||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(n + 256)));
|
||||
++n;
|
||||
}
|
||||
}
|
||||
// LOG_DEBUG("byte_unicode_pairs %d", byte_unicode_pairs.size());
|
||||
return byte_unicode_pairs;
|
||||
}
|
||||
|
||||
// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
|
||||
// TODO: implement bpe
|
||||
class CLIPTokenizer {
|
||||
private:
|
||||
SDVersion version = VERSION_1_x;
|
||||
std::map<std::string, int32_t> encoder;
|
||||
std::map<int, std::u32string> byte_encoder;
|
||||
std::map<std::u32string, int> encoder;
|
||||
std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
|
||||
std::regex pat;
|
||||
|
||||
static std::string strip(const std::string& str) {
|
||||
@ -521,19 +549,61 @@ private:
|
||||
|
||||
public:
|
||||
CLIPTokenizer(SDVersion version = VERSION_1_x)
|
||||
: version(version){};
|
||||
std::string bpe(std::string token) {
|
||||
std::string word = token + "</w>";
|
||||
: version(version) {}
|
||||
|
||||
void load_from_merges(const std::string& merges_utf8_str) {
|
||||
auto byte_unicode_pairs = bytes_to_unicode();
|
||||
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
||||
// for (auto & pair: byte_unicode_pairs) {
|
||||
// std::cout << pair.first << ": " << pair.second << std::endl;
|
||||
// }
|
||||
std::vector<std::u32string> merges;
|
||||
size_t start = 0;
|
||||
size_t pos;
|
||||
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
|
||||
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
|
||||
merges.push_back(merges_utf32_str.substr(start, pos - start));
|
||||
start = pos + 1;
|
||||
}
|
||||
merges = std::vector<std::u32string>(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1);
|
||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||
for (const auto& merge : merges) {
|
||||
size_t space_pos = merge.find(' ');
|
||||
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||
// LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
|
||||
}
|
||||
std::vector<std::u32string> vocab;
|
||||
for (const auto& pair : byte_unicode_pairs) {
|
||||
vocab.push_back(pair.second);
|
||||
}
|
||||
for (const auto& pair : byte_unicode_pairs) {
|
||||
vocab.push_back(pair.second + utf8_to_utf32("</w>"));
|
||||
}
|
||||
for (const auto& merge : merge_pairs) {
|
||||
vocab.push_back(merge.first + merge.second);
|
||||
}
|
||||
vocab.push_back(utf8_to_utf32("<|startoftext|>"));
|
||||
vocab.push_back(utf8_to_utf32("<|endoftext|>"));
|
||||
LOG_DEBUG("vocab size: %llu", vocab.size());
|
||||
int i = 0;
|
||||
for (const auto& token : vocab) {
|
||||
encoder[token] = i++;
|
||||
}
|
||||
|
||||
int rank = 0;
|
||||
for (const auto& merge : merge_pairs) {
|
||||
bpe_ranks[merge] = rank++;
|
||||
}
|
||||
};
|
||||
|
||||
std::u32string bpe(std::u32string token) {
|
||||
std::u32string word = token + utf8_to_utf32("</w>");
|
||||
if (encoder.find(word) != encoder.end()) {
|
||||
return word;
|
||||
} else if (encoder.find(token) != encoder.end()) {
|
||||
return token;
|
||||
}
|
||||
return UNK_TOKEN;
|
||||
}
|
||||
|
||||
void add_token(std::string token, int32_t token_id) {
|
||||
encoder[token] = token_id;
|
||||
return utf8_to_utf32(UNK_TOKEN);
|
||||
}
|
||||
|
||||
std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
|
||||
@ -571,13 +641,25 @@ public:
|
||||
std::vector<std::string> token_strs;
|
||||
while (std::regex_search(str, matches, pat)) {
|
||||
for (auto& token : matches) {
|
||||
std::istringstream iss(bpe(token));
|
||||
std::vector<std::string> tokens{std::istream_iterator<std::string>{iss},
|
||||
std::istream_iterator<std::string>{}};
|
||||
for (const auto& bpe_token : tokens) {
|
||||
bpe_tokens.push_back(encoder[bpe_token]);
|
||||
token_strs.push_back(bpe_token);
|
||||
std::string token_str = token.str();
|
||||
std::u32string utf32_token;
|
||||
for (int i = 0; i < token_str.length(); i++) {
|
||||
char b = token_str[i];
|
||||
utf32_token += byte_encoder[b];
|
||||
}
|
||||
auto bpe_strs = bpe(utf32_token);
|
||||
size_t start = 0;
|
||||
size_t pos;
|
||||
while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
|
||||
auto bpe_str = bpe_strs.substr(start, pos - start);
|
||||
bpe_tokens.push_back(encoder[bpe_str]);
|
||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||
|
||||
start = pos + 1;
|
||||
}
|
||||
auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
|
||||
bpe_tokens.push_back(encoder[bpe_str]);
|
||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||
}
|
||||
str = matches.suffix();
|
||||
}
|
||||
@ -4323,15 +4405,14 @@ public:
|
||||
LOG_INFO("Stable Diffusion weight type: %s", ggml_type_name(model_data_type));
|
||||
|
||||
LOG_DEBUG("loading vocab");
|
||||
auto add_token = [&](const std::string& token, int32_t token_id) {
|
||||
cond_stage_model.tokenizer.add_token(token, token_id);
|
||||
};
|
||||
bool success = model_loader.load_vocab(add_token);
|
||||
if (!success) {
|
||||
LOG_ERROR("get vocab from file failed: '%s'", model_path.c_str());
|
||||
std::string merges_utf8_str = model_loader.load_merges();
|
||||
if (merges_utf8_str.size() == 0) {
|
||||
LOG_ERROR("get merges failed: '%s'", model_path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
cond_stage_model.tokenizer.load_from_merges(merges_utf8_str);
|
||||
|
||||
// create the ggml context for network params
|
||||
LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
|
||||
|
||||
@ -4431,7 +4512,7 @@ public:
|
||||
|
||||
// print_ggml_tensor(alphas_cumprod_tensor);
|
||||
|
||||
success = model_loader.load_tensors(on_new_tensor_cb);
|
||||
bool success = model_loader.load_tensors(on_new_tensor_cb);
|
||||
if (!success) {
|
||||
LOG_ERROR("load tensors from file failed");
|
||||
ggml_free(ctx);
|
||||
|
17
util.cpp
17
util.cpp
@ -1,7 +1,9 @@
|
||||
#include "util.h"
|
||||
|
||||
#include <stdarg.h>
|
||||
#include <codecvt>
|
||||
#include <fstream>
|
||||
#include <locale>
|
||||
#include <thread>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
@ -119,6 +121,21 @@ int32_t get_num_physical_cores() {
|
||||
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
|
||||
}
|
||||
|
||||
std::u32string utf8_to_utf32(const std::string& utf8_str) {
|
||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
|
||||
return converter.from_bytes(utf8_str);
|
||||
}
|
||||
|
||||
std::string utf32_to_utf8(const std::u32string& utf32_str) {
|
||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
|
||||
return converter.to_bytes(utf32_str);
|
||||
}
|
||||
|
||||
std::u32string unicode_value_to_utf32(int unicode_value) {
|
||||
std::u32string utf32_string = {static_cast<char32_t>(unicode_value)};
|
||||
return utf32_string;
|
||||
}
|
||||
|
||||
std::string basename(const std::string& path) {
|
||||
size_t pos = path.find_last_of('/');
|
||||
if (pos != std::string::npos) {
|
||||
|
4
util.h
4
util.h
@ -14,6 +14,10 @@ void replace_all_chars(std::string& str, char target, char replacement);
|
||||
bool file_exists(const std::string& filename);
|
||||
bool is_directory(const std::string& path);
|
||||
|
||||
std::u32string utf8_to_utf32(const std::string& utf8_str);
|
||||
std::string utf32_to_utf8(const std::u32string& utf32_str);
|
||||
std::u32string unicode_value_to_utf32(int unicode_value);
|
||||
|
||||
std::string basename(const std::string& path);
|
||||
|
||||
std::string path_join(const std::string& p1, const std::string& p2);
|
||||
|
Loading…
Reference in New Issue
Block a user