fix: enhance the tokenizer's handling of Unicode (#120)
parent 9842a3f819
commit 8f6b4a39d6
model.cpp (17 changes)

@@ -1192,20 +1192,9 @@ ggml_type ModelLoader::get_sd_wtype() {
     return GGML_TYPE_COUNT;
 }
 
-bool ModelLoader::load_vocab(on_new_token_cb_t on_new_token_cb) {
-    char* vocab_buffer = reinterpret_cast<char*>(vocab_json);
-    nlohmann::json vocab = nlohmann::json::parse(vocab_buffer);
-    std::map<char, int> decoder = unicode_to_byte();
-    for (auto& it : vocab.items()) {
-        int token_id = it.value();
-        std::string token_str = it.key();
-        std::string token = "";
-        for (char c : token_str) {
-            token += decoder[c];
-        }
-        on_new_token_cb(token, token_id);
-    }
-    return true;
+std::string ModelLoader::load_merges() {
+    std::string merges_utf8_str(reinterpret_cast<const char*>(merges_utf8_c_str), sizeof(merges_utf8_c_str));
+    return merges_utf8_str;
 }
 
 bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
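The loader no longer parses the embedded vocab.json into byte strings; it now returns the raw merges.txt payload and lets the tokenizer rebuild the vocabulary from it. A sketch of the presumed shape of merges_utf8_c_str (this declaration is not part of the commit): sizeof reports the full payload size only because it is a genuine array, not a pointer.

// Sketch, not from the commit: the presumed shape of merges_utf8_c_str,
// e.g. generated from merges.txt with a tool like `xxd -i`.
// sizeof on a true array yields the byte count of the embedded data;
// on a `const unsigned char*` it would yield only the pointer size,
// so load_merges() depends on this being an array.
const unsigned char merges_utf8_c_str[] = {
    0x23, 0x76, 0x65, 0x72, /* "#ver..." — merges.txt begins with a version header line */
};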
model.h (2 changes)

@@ -115,7 +115,7 @@ public:
     bool init_from_file(const std::string& file_path, const std::string& prefix = "");
     SDVersion get_sd_version();
     ggml_type get_sd_wtype();
-    bool load_vocab(on_new_token_cb_t on_new_token_cb);
+    std::string load_merges();
     bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb);
     int64_t cal_mem_size(ggml_backend_t backend);
     ~ModelLoader() = default;
stable-diffusion.cpp

@@ -493,12 +493,40 @@ const int BOS_TOKEN_ID = 49406;
 const int EOS_TOKEN_ID = 49407;
 const int PAD_TOKEN_ID = 49407;
 
+std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
+    std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
+    std::set<int> byte_set;
+    for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
+        byte_set.insert(b);
+        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
+    }
+    for (int b = 161; b <= 172; ++b) {
+        byte_set.insert(b);
+        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
+    }
+    for (int b = 174; b <= 255; ++b) {
+        byte_set.insert(b);
+        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
+    }
+    int n = 0;
+    for (int b = 0; b < 256; ++b) {
+        if (byte_set.find(b) == byte_set.end()) {
+            byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(n + 256)));
+            ++n;
+        }
+    }
+    // LOG_DEBUG("byte_unicode_pairs %d", byte_unicode_pairs.size());
+    return byte_unicode_pairs;
+}
+
 // Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
 // TODO: implement bpe
 class CLIPTokenizer {
 private:
     SDVersion version = VERSION_1_x;
-    std::map<std::string, int32_t> encoder;
+    std::map<int, std::u32string> byte_encoder;
+    std::map<std::u32string, int> encoder;
+    std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
     std::regex pat;
 
     static std::string strip(const std::string& str) {
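bytes_to_unicode() ports the byte-to-unicode table from OpenAI's CLIP tokenizer (GPT-2-style byte-level BPE): the 188 "printable" byte values ('!'..'~', 161..172, 174..255) keep their own code point, and the remaining 68 bytes are remapped to 256 + n. Every possible input byte thus becomes a distinct, visible character, so BPE never has to handle raw control bytes or split UTF-8 sequences. A standalone sketch of the invariant, assuming bytes_to_unicode() and unicode_value_to_utf32() from this diff are in scope:

// Standalone sketch (not part of the commit): verifies the round-trip
// property of the table built by bytes_to_unicode().
#include <cassert>
#include <map>
#include <string>

void check_byte_unicode_roundtrip() {
    auto pairs = bytes_to_unicode();  // 256 entries, one per byte value
    std::map<std::u32string, int> byte_decoder;
    for (const auto& p : pairs) {
        byte_decoder[p.second] = p.first;
    }
    assert(byte_decoder.size() == 256);  // every byte maps to a distinct code point
    for (const auto& p : pairs) {
        assert(byte_decoder.at(p.second) == p.first);  // decoding inverts encoding
    }
}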
@@ -521,19 +549,61 @@ private:
 
 public:
     CLIPTokenizer(SDVersion version = VERSION_1_x)
-        : version(version){};
-    std::string bpe(std::string token) {
-        std::string word = token + "</w>";
+        : version(version) {}
+    void load_from_merges(const std::string& merges_utf8_str) {
+        auto byte_unicode_pairs = bytes_to_unicode();
+        byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
+        // for (auto & pair: byte_unicode_pairs) {
+        //     std::cout << pair.first << ": " << pair.second << std::endl;
+        // }
+        std::vector<std::u32string> merges;
+        size_t start = 0;
+        size_t pos;
+        std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
+        while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
+            merges.push_back(merges_utf32_str.substr(start, pos - start));
+            start = pos + 1;
+        }
+        merges = std::vector<std::u32string>(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1);
+        std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
+        for (const auto& merge : merges) {
+            size_t space_pos = merge.find(' ');
+            merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
+            // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
+        }
+        std::vector<std::u32string> vocab;
+        for (const auto& pair : byte_unicode_pairs) {
+            vocab.push_back(pair.second);
+        }
+        for (const auto& pair : byte_unicode_pairs) {
+            vocab.push_back(pair.second + utf8_to_utf32("</w>"));
+        }
+        for (const auto& merge : merge_pairs) {
+            vocab.push_back(merge.first + merge.second);
+        }
+        vocab.push_back(utf8_to_utf32("<|startoftext|>"));
+        vocab.push_back(utf8_to_utf32("<|endoftext|>"));
+        LOG_DEBUG("vocab size: %llu", vocab.size());
+        int i = 0;
+        for (const auto& token : vocab) {
+            encoder[token] = i++;
+        }
+
+        int rank = 0;
+        for (const auto& merge : merge_pairs) {
+            bpe_ranks[merge] = rank++;
+        }
+    };
+
+    std::u32string bpe(std::u32string token) {
+        std::u32string word = token + utf8_to_utf32("</w>");
         if (encoder.find(word) != encoder.end()) {
             return word;
         } else if (encoder.find(token) != encoder.end()) {
             return token;
         }
-        return UNK_TOKEN;
-    }
+        return utf8_to_utf32(UNK_TOKEN);
 
-    void add_token(std::string token, int32_t token_id) {
-        encoder[token] = token_id;
     }
 
     std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
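The slice bounds in load_from_merges encode CLIP's fixed vocabulary size: merges.txt opens with a version header line (hence merges.begin() + 1), and the slice keeps (49152 − 256 − 2 + 1) − 1 = 48894 merge rules. The vocabulary built from them is 256 byte tokens + 256 end-of-word variants + 48894 merged pairs + 2 special tokens = 49408 entries, so the last id is 49407 — consistent with BOS_TOKEN_ID = 49406 and EOS_TOKEN_ID = PAD_TOKEN_ID = 49407 above.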
@@ -571,13 +641,25 @@ public:
         std::vector<std::string> token_strs;
         while (std::regex_search(str, matches, pat)) {
             for (auto& token : matches) {
-                std::istringstream iss(bpe(token));
-                std::vector<std::string> tokens{std::istream_iterator<std::string>{iss},
-                                                std::istream_iterator<std::string>{}};
-                for (const auto& bpe_token : tokens) {
-                    bpe_tokens.push_back(encoder[bpe_token]);
-                    token_strs.push_back(bpe_token);
+                std::string token_str = token.str();
+                std::u32string utf32_token;
+                for (int i = 0; i < token_str.length(); i++) {
+                    char b = token_str[i];
+                    utf32_token += byte_encoder[b];
                 }
+                auto bpe_strs = bpe(utf32_token);
+                size_t start = 0;
+                size_t pos;
+                while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
+                    auto bpe_str = bpe_strs.substr(start, pos - start);
+                    bpe_tokens.push_back(encoder[bpe_str]);
+                    token_strs.push_back(utf32_to_utf8(bpe_str));
+
+                    start = pos + 1;
+                }
+                auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
+                bpe_tokens.push_back(encoder[bpe_str]);
+                token_strs.push_back(utf32_to_utf8(bpe_str));
             }
             str = matches.suffix();
         }
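This hunk is the actual Unicode fix: tokenize() now maps each regex match byte-by-byte through byte_encoder before running BPE, instead of feeding raw std::string data to a whitespace splitter. A usage sketch, assuming merges_utf8_str was obtained from ModelLoader::load_merges():

// Usage sketch (names as in this diff):
CLIPTokenizer tokenizer;
tokenizer.load_from_merges(merges_utf8_str);
// 77 is CLIP's usual context length; padding=true asks tokenize()
// to fill the output up to max_length.
std::vector<int> ids = tokenizer.tokenize("a lovely cat", 77, true);

One caveat worth flagging: `char` is signed on most ABIs, so for input bytes >= 0x80 the `byte_encoder[b]` lookup in the loop above indexes the map with a negative key unless `b` is first cast through `unsigned char`.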
@@ -4323,15 +4405,14 @@ public:
         LOG_INFO("Stable Diffusion weight type: %s", ggml_type_name(model_data_type));
 
         LOG_DEBUG("loading vocab");
-        auto add_token = [&](const std::string& token, int32_t token_id) {
-            cond_stage_model.tokenizer.add_token(token, token_id);
-        };
-        bool success = model_loader.load_vocab(add_token);
-        if (!success) {
-            LOG_ERROR("get vocab from file failed: '%s'", model_path.c_str());
+        std::string merges_utf8_str = model_loader.load_merges();
+        if (merges_utf8_str.size() == 0) {
+            LOG_ERROR("get merges failed: '%s'", model_path.c_str());
             return false;
         }
 
+        cond_stage_model.tokenizer.load_from_merges(merges_utf8_str);
+
         // create the ggml context for network params
         LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
 
@@ -4431,7 +4512,7 @@ public:
 
         // print_ggml_tensor(alphas_cumprod_tensor);
 
-        success = model_loader.load_tensors(on_new_tensor_cb);
+        bool success = model_loader.load_tensors(on_new_tensor_cb);
         if (!success) {
             LOG_ERROR("load tensors from file failed");
             ggml_free(ctx);
util.cpp (17 changes)
@@ -1,7 +1,9 @@
 #include "util.h"
 
 #include <stdarg.h>
+#include <codecvt>
 #include <fstream>
+#include <locale>
 #include <thread>
 #include <unordered_set>
 #include <vector>
@@ -119,6 +121,21 @@ int32_t get_num_physical_cores() {
     return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
 }
 
+std::u32string utf8_to_utf32(const std::string& utf8_str) {
+    std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
+    return converter.from_bytes(utf8_str);
+}
+
+std::string utf32_to_utf8(const std::u32string& utf32_str) {
+    std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
+    return converter.to_bytes(utf32_str);
+}
+
+std::u32string unicode_value_to_utf32(int unicode_value) {
+    std::u32string utf32_string = {static_cast<char32_t>(unicode_value)};
+    return utf32_string;
+}
+
 std::string basename(const std::string& path) {
     size_t pos = path.find_last_of('/');
     if (pos != std::string::npos) {
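The three helpers wrap the standard <codecvt> facet, which converts UTF-8 <-> UTF-32 directly (std::codecvt_utf8 and std::wstring_convert are deprecated since C++17 but still available, and the commit relies on them as the lightest portable route). A minimal usage sketch:

// Usage sketch for the helpers above:
void utf_helpers_example() {
    std::u32string u32 = utf8_to_utf32("caf\xC3\xA9");      // "café": 5 UTF-8 bytes -> 4 code points
    std::string back = utf32_to_utf8(u32);                  // round-trips to the original bytes
    std::u32string acute = unicode_value_to_utf32(0x00E9);  // one-element u32string holding U+00E9
}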
util.h (4 changes)
@@ -14,6 +14,10 @@ void replace_all_chars(std::string& str, char target, char replacement);
 bool file_exists(const std::string& filename);
 bool is_directory(const std::string& path);
 
+std::u32string utf8_to_utf32(const std::string& utf8_str);
+std::string utf32_to_utf8(const std::u32string& utf32_str);
+std::u32string unicode_value_to_utf32(int unicode_value);
+
 std::string basename(const std::string& path);
 
 std::string path_join(const std::string& p1, const std::string& p2);