refactor: reorganize code and use c api (#133)
This commit is contained in:
parent b139434b57
commit 2e79a82f85
@@ -60,7 +60,8 @@ add_subdirectory(thirdparty)
 set(SD_LIB stable-diffusion)
 
-add_library(${SD_LIB} stable-diffusion.h stable-diffusion.cpp model.h model.cpp util.h util.cpp)
+add_library(${SD_LIB} stable-diffusion.h stable-diffusion.cpp model.h model.cpp util.h util.cpp upscaler.cpp
+            ggml_extend.hpp clip.hpp common.hpp unet.hpp tae.hpp esrgan.hpp lora.hpp denoiser.hpp rng.hpp rng_philox.hpp)
 target_link_libraries(${SD_LIB} PUBLIC ggml zip)
 target_include_directories(${SD_LIB} PUBLIC . thirdparty)
 target_compile_features(${SD_LIB} PUBLIC cxx_std_11)
998 clip.hpp Normal file
@@ -0,0 +1,998 @@
#ifndef __CLIP_HPP__
#define __CLIP_HPP__

#include "ggml_extend.hpp"

/*================================================== CLIPTokenizer ===================================================*/

std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
    std::regex re("<lora:([^:]+):([^>]+)>");
    std::smatch matches;
    std::unordered_map<std::string, float> filename2multiplier;

    while (std::regex_search(text, matches, re)) {
        std::string filename = matches[1].str();
        float multiplier     = std::stof(matches[2].str());

        text = std::regex_replace(text, re, "", std::regex_constants::format_first_only);

        if (multiplier == 0.f) {
            continue;
        }

        if (filename2multiplier.find(filename) == filename2multiplier.end()) {
            filename2multiplier[filename] = multiplier;
        } else {
            filename2multiplier[filename] += multiplier;
        }
    }

    return std::make_pair(filename2multiplier, text);
}
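
// Illustrative usage (sketch, not from the original commit): for the prompt
// "a photo <lora:film:0.8> of a cat", the call below returns {{"film", 0.8f}}
// plus the prompt with the lora tag stripped out:
//   auto res = extract_and_remove_lora("a photo <lora:film:0.8> of a cat");
//   // res.first["film"] == 0.8f, res.second == "a photo  of a cat"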

const std::string UNK_TOKEN = "<|endoftext|>";
const std::string BOS_TOKEN = "<|startoftext|>";
const std::string EOS_TOKEN = "<|endoftext|>";
const std::string PAD_TOKEN = "<|endoftext|>";

const int UNK_TOKEN_ID = 49407;
const int BOS_TOKEN_ID = 49406;
const int EOS_TOKEN_ID = 49407;
const int PAD_TOKEN_ID = 49407;

std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
    std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
    std::set<int> byte_set;
    for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
        byte_set.insert(b);
        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
    }
    for (int b = 161; b <= 172; ++b) {
        byte_set.insert(b);
        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
    }
    for (int b = 174; b <= 255; ++b) {
        byte_set.insert(b);
        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
    }
    int n = 0;
    for (int b = 0; b < 256; ++b) {
        if (byte_set.find(b) == byte_set.end()) {
            byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(n + 256)));
            ++n;
        }
    }
    // LOG_DEBUG("byte_unicode_pairs %d", byte_unicode_pairs.size());
    return byte_unicode_pairs;
}
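
// Worked example (illustrative): printable bytes map to themselves ('A' (65)
// -> U+0041), while the remaining bytes are remapped above U+0100 in ascending
// order, so the space byte (32), the 33rd unmapped byte, becomes U+0120 ('Ġ'),
// matching the GPT-2/CLIP byte-level BPE alphabet.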

// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
class CLIPTokenizer {
private:
    SDVersion version = VERSION_1_x;
    std::map<int, std::u32string> byte_encoder;
    std::map<std::u32string, int> encoder;
    std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
    std::regex pat;

    static std::string strip(const std::string& str) {
        std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f");
        std::string::size_type end   = str.find_last_not_of(" \t\n\r\v\f");

        if (start == std::string::npos) {
            // String contains only whitespace characters
            return "";
        }

        return str.substr(start, end - start + 1);
    }

    static std::string whitespace_clean(std::string text) {
        text = std::regex_replace(text, std::regex(R"(\s+)"), " ");
        text = strip(text);
        return text;
    }

    static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
        std::set<std::pair<std::u32string, std::u32string>> pairs;
        if (subwords.size() == 0) {
            return pairs;
        }
        std::u32string prev_subword = subwords[0];
        for (int i = 1; i < subwords.size(); i++) {
            std::u32string subword = subwords[i];
            std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
            pairs.insert(pair);
            prev_subword = subword;
        }
        return pairs;
    }

public:
    CLIPTokenizer(SDVersion version = VERSION_1_x)
        : version(version) {}

    void load_from_merges(const std::string& merges_utf8_str) {
        auto byte_unicode_pairs = bytes_to_unicode();
        byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
        // for (auto & pair: byte_unicode_pairs) {
        //     std::cout << pair.first << ": " << pair.second << std::endl;
        // }
        std::vector<std::u32string> merges;
        size_t start = 0;
        size_t pos;
        std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
        while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
            merges.push_back(merges_utf32_str.substr(start, pos - start));
            start = pos + 1;
        }
        // LOG_DEBUG("merges size %llu", merges.size());
        GGML_ASSERT(merges.size() == 48895);
        merges = std::vector<std::u32string>(merges.begin() + 1, merges.end());
        std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
        for (const auto& merge : merges) {
            size_t space_pos = merge.find(' ');
            merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
            // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
        }
        std::vector<std::u32string> vocab;
        for (const auto& pair : byte_unicode_pairs) {
            vocab.push_back(pair.second);
        }
        for (const auto& pair : byte_unicode_pairs) {
            vocab.push_back(pair.second + utf8_to_utf32("</w>"));
        }
        for (const auto& merge : merge_pairs) {
            vocab.push_back(merge.first + merge.second);
        }
        vocab.push_back(utf8_to_utf32("<|startoftext|>"));
        vocab.push_back(utf8_to_utf32("<|endoftext|>"));
        LOG_DEBUG("vocab size: %llu", vocab.size());
        int i = 0;
        for (const auto& token : vocab) {
            encoder[token] = i++;
        }

        int rank = 0;
        for (const auto& merge : merge_pairs) {
            bpe_ranks[merge] = rank++;
        }
    }

    std::u32string bpe(const std::u32string& token) {
        std::vector<std::u32string> word;

        for (int i = 0; i < token.size() - 1; i++) {
            word.emplace_back(1, token[i]);
        }
        word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("</w>"));

        std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);

        if (pairs.empty()) {
            return token + utf8_to_utf32("</w>");
        }

        while (true) {
            auto min_pair_iter = std::min_element(pairs.begin(),
                                                  pairs.end(),
                                                  [&](const std::pair<std::u32string, std::u32string>& a,
                                                      const std::pair<std::u32string, std::u32string>& b) {
                                                      if (bpe_ranks.find(a) == bpe_ranks.end()) {
                                                          return false;
                                                      } else if (bpe_ranks.find(b) == bpe_ranks.end()) {
                                                          return true;
                                                      }
                                                      return bpe_ranks.at(a) < bpe_ranks.at(b);
                                                  });

            const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;

            if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
                break;
            }

            std::u32string first  = bigram.first;
            std::u32string second = bigram.second;
            std::vector<std::u32string> new_word;
            int32_t i = 0;

            while (i < word.size()) {
                auto it = std::find(word.begin() + i, word.end(), first);
                if (it == word.end()) {
                    new_word.insert(new_word.end(), word.begin() + i, word.end());
                    break;
                }
                new_word.insert(new_word.end(), word.begin() + i, it);
                i = static_cast<int32_t>(std::distance(word.begin(), it));

                if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
                    new_word.push_back(first + second);
                    i += 2;
                } else {
                    new_word.push_back(word[i]);
                    i += 1;
                }
            }

            word = new_word;

            if (word.size() == 1) {
                break;
            }
            pairs = get_pairs(word);
        }

        std::u32string result;
        for (int i = 0; i < word.size(); i++) {
            result += word[i];
            if (i != word.size() - 1) {
                result += utf8_to_utf32(" ");
            }
        }

        return result;
    }
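
    // Illustrative walk-through with a hypothetical merge table: for the input
    // U"low", the word starts as {"l", "o", "w</w>"}; if the ranks contain
    // ("l", "o") and then ("lo", "w</w>"), the loop merges to {"lo", "w</w>"}
    // and finally {"low</w>"}, which is returned as a single subword.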

    std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
        std::vector<int32_t> tokens = encode(text);
        tokens.insert(tokens.begin(), BOS_TOKEN_ID);
        if (max_length > 0) {
            if (tokens.size() > max_length - 1) {
                tokens.resize(max_length - 1);
                tokens.push_back(EOS_TOKEN_ID);
            } else {
                tokens.push_back(EOS_TOKEN_ID);
                if (padding) {
                    int pad_token_id = PAD_TOKEN_ID;
                    if (version == VERSION_2_x) {
                        pad_token_id = 0;
                    }
                    tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id);
                }
            }
        }
        return tokens;
    }

    std::vector<int> encode(std::string text) {
        std::string original_text = text;
        std::vector<int32_t> bpe_tokens;
        text = whitespace_clean(text);
        std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); });

        std::regex pat(R"(<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
                       std::regex::icase);

        std::smatch matches;
        std::string str = text;
        std::vector<std::string> token_strs;
        while (std::regex_search(str, matches, pat)) {
            for (auto& token : matches) {
                std::string token_str = token.str();
                std::u32string utf32_token;
                for (int i = 0; i < token_str.length(); i++) {
                    unsigned char b = token_str[i];  // index as unsigned so bytes >= 0x80 hit the right byte_encoder entry
                    utf32_token += byte_encoder[b];
                }
                auto bpe_strs = bpe(utf32_token);
                size_t start  = 0;
                size_t pos;
                while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
                    auto bpe_str = bpe_strs.substr(start, pos - start);
                    bpe_tokens.push_back(encoder[bpe_str]);
                    token_strs.push_back(utf32_to_utf8(bpe_str));

                    start = pos + 1;
                }
                auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
                bpe_tokens.push_back(encoder[bpe_str]);
                token_strs.push_back(utf32_to_utf8(bpe_str));
            }
            str = matches.suffix();
        }
        std::stringstream ss;
        ss << "[";
        for (auto token : token_strs) {
            ss << "\"" << token << "\", ";
        }
        ss << "]";
        LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
        return bpe_tokens;
    }
};
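
// Illustrative usage (assumes merges_utf8_str already holds the merges text
// extracted from the model file, as the model loader does elsewhere):
//   CLIPTokenizer tokenizer(VERSION_1_x);
//   tokenizer.load_from_merges(merges_utf8_str);
//   std::vector<int> ids = tokenizer.tokenize("a photo of a cat", 77, true);
//   // ids[0] == BOS_TOKEN_ID (49406); EOS/PAD ids fill the tail up to 77.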

// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/prompt_parser.py#L345
//
// Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
// Accepted tokens are:
//   (abc) - increases attention to abc by a multiplier of 1.1
//   (abc:3.12) - increases attention to abc by a multiplier of 3.12
//   [abc] - decreases attention to abc by a multiplier of 1.1
//   \( - literal character '('
//   \[ - literal character '['
//   \) - literal character ')'
//   \] - literal character ']'
//   \\ - literal character '\'
//   anything else - just text
//
// >>> parse_prompt_attention('normal text')
// [['normal text', 1.0]]
// >>> parse_prompt_attention('an (important) word')
// [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
// >>> parse_prompt_attention('(unbalanced')
// [['unbalanced', 1.1]]
// >>> parse_prompt_attention('\(literal\]')
// [['(literal]', 1.0]]
// >>> parse_prompt_attention('(unnecessary)(parens)')
// [['unnecessaryparens', 1.1]]
// >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
// [['a ', 1.0],
//  ['house', 1.5730000000000004],
//  [' ', 1.1],
//  ['on', 1.0],
//  [' a ', 1.1],
//  ['hill', 0.55],
//  [', sun, ', 1.1],
//  ['sky', 1.4641000000000006],
//  ['.', 1.1]]
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text) {
    std::vector<std::pair<std::string, float>> res;
    std::vector<int> round_brackets;
    std::vector<int> square_brackets;

    float round_bracket_multiplier  = 1.1f;
    float square_bracket_multiplier = 1 / 1.1f;

    std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|[^\\()\[\]:]+|:)");
    std::regex re_break(R"(\s*\bBREAK\b\s*)");

    auto multiply_range = [&](int start_position, float multiplier) {
        for (int p = start_position; p < res.size(); ++p) {
            res[p].second *= multiplier;
        }
    };

    std::smatch m;
    std::string remaining_text = text;

    while (std::regex_search(remaining_text, m, re_attention)) {
        std::string text   = m[0];
        std::string weight = m[1];

        if (text == "(") {
            round_brackets.push_back((int)res.size());
        } else if (text == "[") {
            square_brackets.push_back((int)res.size());
        } else if (!weight.empty()) {
            if (!round_brackets.empty()) {
                multiply_range(round_brackets.back(), std::stof(weight));
                round_brackets.pop_back();
            }
        } else if (text == ")" && !round_brackets.empty()) {
            multiply_range(round_brackets.back(), round_bracket_multiplier);
            round_brackets.pop_back();
        } else if (text == "]" && !square_brackets.empty()) {
            multiply_range(square_brackets.back(), square_bracket_multiplier);
            square_brackets.pop_back();
        } else if (text == "\\(") {
            res.push_back({text.substr(1), 1.0f});
        } else {
            res.push_back({text, 1.0f});
        }

        remaining_text = m.suffix();
    }

    for (int pos : round_brackets) {
        multiply_range(pos, round_bracket_multiplier);
    }

    for (int pos : square_brackets) {
        multiply_range(pos, square_bracket_multiplier);
    }

    if (res.empty()) {
        res.push_back({"", 1.0f});
    }

    int i = 0;
    while (i + 1 < res.size()) {
        if (res[i].second == res[i + 1].second) {
            res[i].first += res[i + 1].first;
            res.erase(res.begin() + i + 1);
        } else {
            ++i;
        }
    }

    return res;
}
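
// Illustrative usage, matching the Python examples above:
//   parse_prompt_attention("an (important) word")
//   // -> {{"an ", 1.0f}, {"important", 1.1f}, {" word", 1.0f}}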

/*================================================ FrozenCLIPEmbedder ================================================*/

struct ResidualAttentionBlock {
    int32_t n_head;
    int32_t d_model;
    int32_t hidden_size;  // n_head * d_model
    int32_t intermediate_size;

    // attention
    struct ggml_tensor* q_w;  // [hidden_size, hidden_size]
    struct ggml_tensor* q_b;  // [hidden_size, ]
    struct ggml_tensor* k_w;  // [hidden_size, hidden_size]
    struct ggml_tensor* k_b;  // [hidden_size, ]
    struct ggml_tensor* v_w;  // [hidden_size, hidden_size]
    struct ggml_tensor* v_b;  // [hidden_size, ]

    struct ggml_tensor* out_w;  // [hidden_size, hidden_size]
    struct ggml_tensor* out_b;  // [hidden_size, ]

    // layer norm 1
    struct ggml_tensor* ln1_w;  // [hidden_size, ]
    struct ggml_tensor* ln1_b;  // [hidden_size, ]

    // mlp
    struct ggml_tensor* fc1_w;  // [intermediate_size, hidden_size]
    struct ggml_tensor* fc1_b;  // [intermediate_size, ]

    struct ggml_tensor* fc2_w;  // [hidden_size, intermediate_size]
    struct ggml_tensor* fc2_b;  // [hidden_size, ]

    // layer norm 2
    struct ggml_tensor* ln2_w;  // [hidden_size, ]
    struct ggml_tensor* ln2_b;  // [hidden_size, ]

    struct ggml_tensor* attn_scale;  // [1, ]

    size_t calculate_mem_size(ggml_type wtype) {
        double mem_size = 0;
        mem_size += 4 * hidden_size * hidden_size * ggml_type_sizef(wtype);        // q_w/k_w/v_w/out_w
        mem_size += 8 * hidden_size * ggml_type_sizef(GGML_TYPE_F32);              // q_b/k_b/v_b/out_b/ln1_w/ln1_b/ln2_w/ln2_b
        mem_size += 2 * hidden_size * intermediate_size * ggml_type_sizef(wtype);  // fc1_w/fc2_w
        mem_size += intermediate_size * ggml_type_sizef(GGML_TYPE_F32);            // fc1_b
        mem_size += hidden_size * ggml_type_sizef(GGML_TYPE_F32);                  // fc2_b
        mem_size += ggml_type_sizef(GGML_TYPE_F32);                                // attn_scale
        return static_cast<size_t>(mem_size);
    }

    void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) {
        ln1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);
        ln1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);

        q_w = ggml_new_tensor_2d(ctx, wtype, hidden_size, hidden_size);
        q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);
        k_w = ggml_new_tensor_2d(ctx, wtype, hidden_size, hidden_size);
        k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);
        v_w = ggml_new_tensor_2d(ctx, wtype, hidden_size, hidden_size);
        v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);

        out_w = ggml_new_tensor_2d(ctx, wtype, hidden_size, hidden_size);
        out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);

        fc1_w = ggml_new_tensor_2d(ctx, wtype, hidden_size, intermediate_size);
        fc1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, intermediate_size);

        fc2_w = ggml_new_tensor_2d(ctx, wtype, intermediate_size, hidden_size);
        fc2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);

        ln2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);
        ln2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);

        attn_scale = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
        ggml_allocr_alloc(alloc, attn_scale);
        float scale = 1.0f / sqrt((float)d_model);
        ggml_backend_tensor_set(attn_scale, &scale, 0, sizeof(scale));
    }

    void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        tensors[prefix + "self_attn.q_proj.weight"]   = q_w;
        tensors[prefix + "self_attn.q_proj.bias"]     = q_b;
        tensors[prefix + "self_attn.k_proj.weight"]   = k_w;
        tensors[prefix + "self_attn.k_proj.bias"]     = k_b;
        tensors[prefix + "self_attn.v_proj.weight"]   = v_w;
        tensors[prefix + "self_attn.v_proj.bias"]     = v_b;
        tensors[prefix + "self_attn.out_proj.weight"] = out_w;
        tensors[prefix + "self_attn.out_proj.bias"]   = out_b;

        tensors[prefix + "layer_norm1.weight"] = ln1_w;
        tensors[prefix + "layer_norm1.bias"]   = ln1_b;

        tensors[prefix + "layer_norm2.weight"] = ln2_w;
        tensors[prefix + "layer_norm2.bias"]   = ln2_b;

        tensors[prefix + "mlp.fc1.weight"] = fc1_w;
        tensors[prefix + "mlp.fc1.bias"]   = fc1_b;

        tensors[prefix + "mlp.fc2.weight"] = fc2_w;
        tensors[prefix + "mlp.fc2.bias"]   = fc2_b;
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        // x: [N, n_token, hidden_size]
        int64_t N           = x->ne[2];
        int64_t n_token     = x->ne[1];
        int64_t hidden_size = n_head * d_model;

        struct ggml_tensor* r = x;

        // layer norm 1
        x = ggml_nn_layer_norm(ctx, x, ln1_w, ln1_b);
        // self-attention
        {
            struct ggml_tensor* q = ggml_nn_linear(ctx, x, q_w, q_b);
            q = ggml_scale_inplace(ctx, q, attn_scale);
            q = ggml_reshape_4d(ctx, q, d_model, n_head, n_token, N);   // [N, n_token, n_head, d_model]
            q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));       // [N, n_head, n_token, d_model]
            q = ggml_reshape_3d(ctx, q, d_model, n_token, n_head * N);  // [N * n_head, n_token, d_model]

            struct ggml_tensor* k = ggml_nn_linear(ctx, x, k_w, k_b);
            k = ggml_reshape_4d(ctx, k, d_model, n_head, n_token, N);   // [N, n_token, n_head, d_model]
            k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));       // [N, n_head, n_token, d_model]
            k = ggml_reshape_3d(ctx, k, d_model, n_token, n_head * N);  // [N * n_head, n_token, d_model]

            struct ggml_tensor* v = ggml_nn_linear(ctx, x, v_w, v_b);
            v = ggml_reshape_4d(ctx, v, d_model, n_head, n_token, N);   // [N, n_token, n_head, d_model]
            v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));       // [N, n_head, d_model, n_token]
            v = ggml_reshape_3d(ctx, v, n_token, d_model, n_head * N);  // [N * n_head, d_model, n_token]

            struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q);  // [N * n_head, n_token, n_token]

            kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
            kq = ggml_soft_max_inplace(ctx, kq);

            struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq);  // [N * n_head, n_token, d_model]
            kqv = ggml_reshape_4d(ctx, kqv, d_model, n_token, n_head, N);
            kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3));  // [N, n_token, n_head, d_model]

            x = ggml_reshape_2d(ctx, kqv, d_model * n_head, n_token * N);  // [N * n_token, d_model * n_head]
        }

        // attention output
        x = ggml_nn_linear(ctx, x, out_w, out_b);

        // residual
        x = ggml_add(ctx, x, r);
        r = x;

        // layer norm 2
        x = ggml_nn_layer_norm(ctx, x, ln2_w, ln2_b);

        // mlp
        x = ggml_nn_linear(ctx, x, fc1_w, fc1_b);

        if (hidden_size == 1024 || hidden_size == 1280) {  // SD 2.x
            x = ggml_gelu_inplace(ctx, x);
        } else {  // SD 1.x
            x = ggml_gelu_quick_inplace(ctx, x);
        }

        x = ggml_nn_linear(ctx, x, fc2_w, fc2_b);

        // residual 2
        x = ggml_add(ctx, x, r);
        return x;
    }
};

// OPENAI_CLIP_VIT_L_14: https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
// OPEN_CLIP_VIT_H_14: https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/config.json
// OPEN_CLIP_VIT_BIGG_14: https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/blob/main/config.json (CLIPTextModelWithProjection)
// SDXL CLIPModel
// CLIPTextModelWithProjection seems optional

enum CLIPVersion {
    OPENAI_CLIP_VIT_L_14,   // SD 1.x and SDXL
    OPEN_CLIP_VIT_H_14,     // SD 2.x
    OPEN_CLIP_VIT_BIGG_14,  // SDXL
};

struct CLIPTextModel {
    CLIPVersion version = OPENAI_CLIP_VIT_L_14;
    // network hparams
    int32_t vocab_size              = 49408;
    int32_t max_position_embeddings = 77;
    int32_t hidden_size             = 768;   // 1024 for OPEN_CLIP_VIT_H_14
    int32_t intermediate_size       = 3072;  // 4096 for OPEN_CLIP_VIT_H_14
    int32_t n_head                  = 12;    // num_attention_heads, 16 for OPEN_CLIP_VIT_H_14
    int32_t num_hidden_layers       = 12;    // 24 for OPEN_CLIP_VIT_H_14
    int32_t layer_idx               = 11;
    int32_t projection_dim          = 1280;  // only for OPEN_CLIP_VIT_BIGG_14
    bool with_final_ln              = true;

    // embeddings
    struct ggml_tensor* position_ids;
    struct ggml_tensor* token_embed_weight;
    struct ggml_tensor* position_embed_weight;

    // transformer
    std::vector<ResidualAttentionBlock> resblocks;
    struct ggml_tensor* final_ln_w;
    struct ggml_tensor* final_ln_b;

    struct ggml_tensor* text_projection;

    CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14,
                  int clip_skip       = -1,
                  bool with_final_ln  = true)
        : version(version), with_final_ln(with_final_ln) {
        if (version == OPEN_CLIP_VIT_H_14) {
            hidden_size       = 1024;
            intermediate_size = 4096;
            n_head            = 16;
            num_hidden_layers = 24;
        } else if (version == OPEN_CLIP_VIT_BIGG_14) {  // CLIPTextModelWithProjection
            hidden_size       = 1280;
            intermediate_size = 5120;
            n_head            = 20;
            num_hidden_layers = 32;
        }
        set_clip_skip(clip_skip);
        resblocks.resize(num_hidden_layers);
        set_resblocks_hp_params();
    }

    void set_clip_skip(int clip_skip) {
        if (clip_skip > 0) {
            layer_idx = num_hidden_layers - clip_skip;
        }
    }

    void set_resblocks_hp_params() {
        int d_model = hidden_size / n_head;  // 64 for all supported variants (768/12, 1024/16, 1280/20)
        for (int i = 0; i < num_hidden_layers; i++) {
            resblocks[i].d_model           = d_model;
            resblocks[i].n_head            = n_head;
            resblocks[i].hidden_size       = hidden_size;
            resblocks[i].intermediate_size = intermediate_size;
        }
    }

    size_t calculate_mem_size(ggml_type wtype) {
        double mem_size = 0;
        mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(GGML_TYPE_I32);  // position_ids
        mem_size += hidden_size * vocab_size * ggml_type_sizef(wtype);                       // token_embed_weight
        mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(wtype);          // position_embed_weight
        for (int i = 0; i < num_hidden_layers; i++) {
            mem_size += resblocks[i].calculate_mem_size(wtype);
        }
        mem_size += 2 * hidden_size * ggml_type_sizef(GGML_TYPE_F32);  // final_ln_w/b
        if (version == OPEN_CLIP_VIT_BIGG_14) {
            mem_size += hidden_size * projection_dim * ggml_type_sizef(GGML_TYPE_F32);  // text_projection
        }
        return static_cast<size_t>(mem_size);
    }

    void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        tensors[prefix + "embeddings.token_embedding.weight"]    = token_embed_weight;
        tensors[prefix + "embeddings.position_embedding.weight"] = position_embed_weight;
        tensors[prefix + "final_layer_norm.weight"]              = final_ln_w;
        tensors[prefix + "final_layer_norm.bias"]                = final_ln_b;
        for (int i = 0; i < num_hidden_layers; i++) {
            std::string name = prefix + "encoder.layers." + std::to_string(i) + ".";
            resblocks[i].map_by_name(tensors, name);
        }
        if (version == OPEN_CLIP_VIT_BIGG_14) {
            tensors[prefix + "text_projection"] = text_projection;
        }
    }

    struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, size_t max_token_idx = 0, bool return_pooled = false) {
        // input_ids: [N, n_token]
        GGML_ASSERT(input_ids->ne[0] <= position_ids->ne[0]);

        // token_embedding + position_embedding
        struct ggml_tensor* x;
        x = ggml_add(ctx0,
                     ggml_get_rows(ctx0, token_embed_weight, input_ids),
                     ggml_get_rows(ctx0,
                                   position_embed_weight,
                                   ggml_view_1d(ctx0, position_ids, input_ids->ne[0], 0)));  // [N, n_token, hidden_size]

        // transformer
        for (int i = 0; i < num_hidden_layers; i++) {
            if (!return_pooled && i == layer_idx + 1) {
                // LOG_DEBUG("layer %d", i);
                break;
            }
            x = resblocks[i].forward(ctx0, x);  // [N, n_token, hidden_size]
        }

        // final layer norm
        if (return_pooled || with_final_ln) {
            x = ggml_nn_layer_norm(ctx0, x, final_ln_w, final_ln_b);
        }

        if (return_pooled) {
            // ggml_tensor* idx = ggml_argmax(ctx0, input_ids);
            // ggml_tensor* pooled = ggml_get_rows(ctx0, x, idx);
            // LOG_DEBUG("max_token_idx: %u %u", max_token_idx, x->nb[1]);
            ggml_tensor* pooled = ggml_view_1d(ctx0, x, hidden_size, x->nb[1] * max_token_idx);
            pooled = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, text_projection)), pooled);
            return pooled;
        }

        return x;  // [N, n_token, hidden_size]
    }

    void init_params(ggml_context* ctx, ggml_backend_t backend, ggml_type wtype, ggml_allocr* alloc) {
        position_ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, max_position_embeddings);

        token_embed_weight = ggml_new_tensor_2d(ctx, wtype, hidden_size, vocab_size);

        position_embed_weight = ggml_new_tensor_2d(ctx, wtype, hidden_size, max_position_embeddings);

        for (int i = 0; i < num_hidden_layers; i++) {
            resblocks[i].init_params(ctx, alloc, wtype);
        }

        final_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);

        final_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);

        if (version == OPEN_CLIP_VIT_BIGG_14) {
            text_projection = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size);
        }

        // alloc all tensors linked to this context
        for (struct ggml_tensor* t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            if (t->data == NULL) {
                ggml_allocr_alloc(alloc, t);
            }
        }

        if (ggml_backend_is_cpu(backend)) {
            for (int i = 0; i < max_position_embeddings; i++) {
                ggml_set_i32_1d(position_ids, i, i);
            }
        } else {
            std::vector<int> pos_temp;
            for (int i = 0; i < max_position_embeddings; i++) {
                pos_temp.push_back(i);
            }
            ggml_backend_tensor_set(position_ids, pos_temp.data(), 0, ggml_nbytes(position_ids));
        }
    }
};

// ldm.modules.encoders.modules.FrozenCLIPEmbedder
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
    SDVersion version = VERSION_1_x;
    CLIPTokenizer tokenizer;
    CLIPTextModel text_model;
    CLIPTextModel text_model2;

    FrozenCLIPEmbedderWithCustomWords(SDVersion version = VERSION_1_x, int clip_skip = -1)
        : version(version), tokenizer(version) {
        name = "clip";
        if (clip_skip <= 0) {
            clip_skip = 1;
            if (version == VERSION_2_x || version == VERSION_XL) {
                clip_skip = 2;
            }
        }
        if (version == VERSION_1_x) {
            text_model = CLIPTextModel(OPENAI_CLIP_VIT_L_14, clip_skip);
        } else if (version == VERSION_2_x) {
            text_model = CLIPTextModel(OPEN_CLIP_VIT_H_14, clip_skip);
        } else if (version == VERSION_XL) {
            text_model  = CLIPTextModel(OPENAI_CLIP_VIT_L_14, clip_skip, false);
            text_model2 = CLIPTextModel(OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
        }
    }

    void set_clip_skip(int clip_skip) {
        text_model.set_clip_skip(clip_skip);
        if (version == VERSION_XL) {
            text_model2.set_clip_skip(clip_skip);
        }
    }

    size_t calculate_mem_size() {
        size_t mem_size = text_model.calculate_mem_size(wtype);
        if (version == VERSION_XL) {
            mem_size += text_model2.calculate_mem_size(wtype);
        }
        return mem_size;
    }

    size_t get_num_tensors() {
        size_t num_tensors = (3 + 2 + 37 * text_model.num_hidden_layers);
        if (version == VERSION_XL) {
            num_tensors += (3 + 2 + 37 * text_model2.num_hidden_layers);
        }
        return num_tensors;
    }

    void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        text_model.map_by_name(tensors, prefix + "transformer.text_model.");
        if (version == VERSION_XL) {
            text_model2.map_by_name(tensors, prefix + "1.transformer.text_model.");
        }
    }

    struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, struct ggml_tensor* input_ids2, size_t max_token_idx = 0, bool return_pooled = false) {
        if (return_pooled) {
            return text_model2.forward(ctx0, input_ids2, max_token_idx, return_pooled);
        }
        auto hidden_states = text_model.forward(ctx0, input_ids);  // [N, n_token, hidden_size]
        // LOG_DEBUG("hidden_states: %d %d %d %d %d", hidden_states->n_dims, hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
        if (version == VERSION_XL) {
            hidden_states = ggml_reshape_4d(ctx0,
                                            hidden_states,
                                            hidden_states->ne[0],
                                            hidden_states->ne[1],
                                            hidden_states->ne[2],
                                            hidden_states->ne[3]);
            hidden_states = ggml_cont(ctx0, ggml_permute(ctx0, hidden_states, 2, 0, 1, 3));

            auto hidden_states2 = text_model2.forward(ctx0, input_ids2);  // [N, n_token, hidden_size2]
            hidden_states2 = ggml_reshape_4d(ctx0,
                                             hidden_states2,
                                             hidden_states2->ne[0],
                                             hidden_states2->ne[1],
                                             hidden_states2->ne[2],
                                             hidden_states2->ne[3]);
            hidden_states2 = ggml_cont(ctx0, ggml_permute(ctx0, hidden_states2, 2, 0, 1, 3));

            hidden_states = ggml_concat(ctx0, hidden_states, hidden_states2);  // [N, n_token, hidden_size + hidden_size2]

            hidden_states = ggml_cont(ctx0, ggml_permute(ctx0, hidden_states, 1, 2, 0, 3));
        }
        // LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
        return hidden_states;
    }

    std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
                                                             bool padding = false) {
        return tokenize(text, text_model.max_position_embeddings, padding);
    }

    std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
                                                             size_t max_length = 0,
                                                             bool padding      = false) {
        auto parsed_attention = parse_prompt_attention(text);

        {
            std::stringstream ss;
            ss << "[";
            for (const auto& item : parsed_attention) {
                ss << "['" << item.first << "', " << item.second << "], ";
            }
            ss << "]";
            LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
        }

        std::vector<int> tokens;
        std::vector<float> weights;
        for (const auto& item : parsed_attention) {
            const std::string& curr_text = item.first;
            float curr_weight            = item.second;
            std::vector<int> curr_tokens = tokenizer.encode(curr_text);
            tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
            weights.insert(weights.end(), curr_tokens.size(), curr_weight);
        }
        tokens.insert(tokens.begin(), BOS_TOKEN_ID);
        weights.insert(weights.begin(), 1.0);

        if (max_length > 0) {
            if (tokens.size() > max_length - 1) {
                tokens.resize(max_length - 1);
                weights.resize(max_length - 1);
                tokens.push_back(EOS_TOKEN_ID);
                weights.push_back(1.0);
            } else {
                tokens.push_back(EOS_TOKEN_ID);
                weights.push_back(1.0);
                if (padding) {
                    int pad_token_id = PAD_TOKEN_ID;
                    if (version == VERSION_2_x) {
                        pad_token_id = 0;
                    }
                    tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id);
                    weights.insert(weights.end(), max_length - weights.size(), 1.0);
                }
            }
        }

        // for (int i = 0; i < tokens.size(); i++) {
        //     std::cout << tokens[i] << ":" << weights[i] << ", ";
        // }
        // std::cout << std::endl;

        return {tokens, weights};
    }

    void init_params() {
        ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer);
        text_model.init_params(params_ctx, backend, wtype, alloc);
        if (version == VERSION_XL) {
            text_model2.init_params(params_ctx, backend, wtype, alloc);
        }
        ggml_allocr_free(alloc);
    }

    struct ggml_cgraph* build_graph(struct ggml_allocr* allocr, std::vector<int> tokens, bool return_pooled = false) {
        // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
        static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
        static std::vector<uint8_t> buf(buf_size);

        struct ggml_init_params params = {
            /*.mem_size   =*/buf_size,
            /*.mem_buffer =*/buf.data(),
            /*.no_alloc   =*/true,  // the tensors will be allocated later by ggml_allocr_alloc_graph()
        };

        struct ggml_context* ctx0 = ggml_init(params);

        struct ggml_cgraph* gf = ggml_new_graph(ctx0);

        struct ggml_tensor* input_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, tokens.size());
        ggml_allocr_alloc(allocr, input_ids);

        if (!ggml_allocr_is_measure(allocr)) {
            ggml_backend_tensor_set(input_ids, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids));
        }

        struct ggml_tensor* input_ids2 = NULL;
        size_t max_token_idx           = 0;
        if (version == VERSION_XL) {
            input_ids2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, tokens.size());
            ggml_allocr_alloc(allocr, input_ids2);

            auto it = std::find(tokens.begin(), tokens.end(), EOS_TOKEN_ID);
            if (it != tokens.end()) {
                std::fill(std::next(it), tokens.end(), 0);
            }

            max_token_idx = std::min<size_t>(std::distance(tokens.begin(), it), tokens.size() - 1);

            // for (int i = 0; i < tokens.size(); i++) {
            //     printf("%d ", tokens[i]);
            // }
            // printf("\n");

            if (!ggml_allocr_is_measure(allocr)) {
                ggml_backend_tensor_set(input_ids2, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids2));
            }
        }

        struct ggml_tensor* hidden_states = forward(ctx0, input_ids, input_ids2, max_token_idx, return_pooled);

        ggml_build_forward_expand(gf, hidden_states);
        ggml_free(ctx0);

        return gf;
    }

    void alloc_compute_buffer(ggml_context* work_ctx, int max_tokens) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            bool return_pooled = false;
            if (version == VERSION_XL) {
                return_pooled = true;
            }
            return build_graph(compute_allocr, std::vector<int>(max_tokens), return_pooled);
        };
        GGMLModule::alloc_compute_buffer(get_graph);
    }

    void compute(const int n_threads,
                 std::vector<int> tokens,
                 ggml_tensor* hidden_state_output,
                 ggml_tensor* pooled_output = NULL) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(compute_allocr, tokens, false);
        };
        GGMLModule::compute(get_graph, n_threads, hidden_state_output);

        if (version == VERSION_XL && pooled_output != NULL) {
            auto get_graph = [&]() -> struct ggml_cgraph* {
                return build_graph(compute_allocr, tokens, true);
            };
            GGMLModule::compute(get_graph, n_threads, pooled_output);
        }
    }
};

#endif  // __CLIP_HPP__
86 common.hpp Normal file
@@ -0,0 +1,86 @@
#ifndef __COMMON_HPP__
#define __COMMON_HPP__

#include "ggml_extend.hpp"

struct DownSample {
    // hparams
    int channels;
    int out_channels;

    // conv2d params
    struct ggml_tensor* op_w;  // [out_channels, channels, 3, 3]
    struct ggml_tensor* op_b;  // [out_channels,]

    bool vae_downsample = false;

    size_t calculate_mem_size(ggml_type wtype) {
        double mem_size = 0;
        mem_size += out_channels * channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16);  // op_w
        mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32);                     // op_b
        return static_cast<size_t>(mem_size);
    }

    void init_params(struct ggml_context* ctx, ggml_type wtype) {
        op_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, out_channels);
        op_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
    }

    void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        if (vae_downsample) {
            tensors[prefix + "conv.weight"] = op_w;
            tensors[prefix + "conv.bias"]   = op_b;
        } else {
            tensors[prefix + "op.weight"] = op_w;
            tensors[prefix + "op.bias"]   = op_b;
        }
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        // x: [N, channels, h, w]
        struct ggml_tensor* c = NULL;
        if (vae_downsample) {
            c = ggml_pad(ctx, x, 1, 1, 0, 0);
            c = ggml_nn_conv_2d(ctx, c, op_w, op_b, 2, 2, 0, 0);
        } else {
            c = ggml_nn_conv_2d(ctx, x, op_w, op_b, 2, 2, 1, 1);
        }
        return c;  // [N, out_channels, h/2, w/2]
    }
};

struct UpSample {
    // hparams
    int channels;
    int out_channels;

    // conv2d params
    struct ggml_tensor* conv_w;  // [out_channels, channels, 3, 3]
    struct ggml_tensor* conv_b;  // [out_channels,]

    size_t calculate_mem_size(ggml_type wtype) {
        double mem_size = 0;
        mem_size += out_channels * channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16);  // conv_w
        mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32);                     // conv_b
        return static_cast<size_t>(mem_size);
    }

    void init_params(struct ggml_context* ctx, ggml_type wtype) {
        conv_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, out_channels);
        conv_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
    }

    void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        tensors[prefix + "conv.weight"] = conv_w;
        tensors[prefix + "conv.bias"]   = conv_b;
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        // x: [N, channels, h, w]
        x = ggml_upscale(ctx, x, 2);                              // [N, channels, h*2, w*2]
        x = ggml_nn_conv_2d(ctx, x, conv_w, conv_b, 1, 1, 1, 1);  // [N, out_channels, h*2, w*2]
        return x;
    }
};

#endif  // __COMMON_HPP__
125 denoiser.hpp Normal file
@@ -0,0 +1,125 @@
#ifndef __DENOISER_HPP__
#define __DENOISER_HPP__

#include "ggml_extend.hpp"

/*================================================= CompVisDenoiser ==================================================*/

// Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py

#define TIMESTEPS 1000

struct SigmaSchedule {
    float alphas_cumprod[TIMESTEPS];
    float sigmas[TIMESTEPS];
    float log_sigmas[TIMESTEPS];

    virtual std::vector<float> get_sigmas(uint32_t n) = 0;

    float sigma_to_t(float sigma) {
        float log_sigma = std::log(sigma);
        std::vector<float> dists;
        dists.reserve(TIMESTEPS);
        for (float log_sigma_val : log_sigmas) {
            dists.push_back(log_sigma - log_sigma_val);
        }

        int low_idx = 0;
        for (size_t i = 0; i < TIMESTEPS; i++) {
            if (dists[i] >= 0) {
                low_idx++;
            }
        }
        low_idx      = std::min(std::max(low_idx - 1, 0), TIMESTEPS - 2);
        int high_idx = low_idx + 1;

        float low  = log_sigmas[low_idx];
        float high = log_sigmas[high_idx];
        float w    = (low - log_sigma) / (low - high);
        w          = std::max(0.f, std::min(1.f, w));
        float t    = (1.0f - w) * low_idx + w * high_idx;

        return t;
    }

    float t_to_sigma(float t) {
        int low_idx     = static_cast<int>(std::floor(t));
        int high_idx    = static_cast<int>(std::ceil(t));
        float w         = t - static_cast<float>(low_idx);
        float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
        return std::exp(log_sigma);
    }
};
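
// Worked example (illustrative): sigma_to_t() and t_to_sigma() are inverse
// linear interpolations in log-sigma space. If log_sigmas[10] = -1.2f and
// log_sigmas[11] = -1.0f, then t_to_sigma(10.5f) == expf(-1.1f), and
// sigma_to_t(expf(-1.1f)) recovers t == 10.5f.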

struct DiscreteSchedule : SigmaSchedule {
    std::vector<float> get_sigmas(uint32_t n) {
        std::vector<float> result;

        int t_max = TIMESTEPS - 1;

        if (n == 0) {
            return result;
        } else if (n == 1) {
            result.push_back(t_to_sigma((float)t_max));
            result.push_back(0);
            return result;
        }

        float step = static_cast<float>(t_max) / static_cast<float>(n - 1);
        for (uint32_t i = 0; i < n; ++i) {
            float t = t_max - step * i;
            result.push_back(t_to_sigma(t));
        }
        result.push_back(0);
        return result;
    }
};

struct KarrasSchedule : SigmaSchedule {
    std::vector<float> get_sigmas(uint32_t n) {
        // These *COULD* be function arguments here,
        // but does anybody ever bother to touch them?
        float sigma_min = 0.1f;
        float sigma_max = 10.f;
        float rho       = 7.f;

        std::vector<float> result(n + 1);

        float min_inv_rho = pow(sigma_min, (1.f / rho));
        float max_inv_rho = pow(sigma_max, (1.f / rho));
        for (uint32_t i = 0; i < n; i++) {
            // Eq. (5) from Karras et al 2022
            result[i] = pow(max_inv_rho + (float)i / ((float)n - 1.f) * (min_inv_rho - max_inv_rho), rho);
        }
        result[n] = 0.;
        return result;
    }
};
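
// Worked example (illustrative): with the defaults above and n = 2, Eq. (5)
// reduces to its endpoints: result[0] == sigma_max == 10.f,
// result[1] == sigma_min == 0.1f, and result[2] == 0 terminates the schedule.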

struct Denoiser {
    std::shared_ptr<SigmaSchedule> schedule = std::make_shared<DiscreteSchedule>();

    virtual std::vector<float> get_scalings(float sigma) = 0;
};

struct CompVisDenoiser : public Denoiser {
    float sigma_data = 1.0f;

    std::vector<float> get_scalings(float sigma) {
        float c_out = -sigma;
        float c_in  = 1.0f / std::sqrt(sigma * sigma + sigma_data * sigma_data);
        return {c_out, c_in};
    }
};
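
// Worked example (illustrative): for sigma = 1.0f (with sigma_data = 1.0f),
// get_scalings() returns c_out = -1.0f and c_in = 1/sqrt(2) ≈ 0.7071f: the
// input is rescaled toward unit variance, and c_out scales the eps prediction
// when reconstructing the denoised sample, as in k-diffusion's CompVisDenoiser.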

struct CompVisVDenoiser : public Denoiser {
    float sigma_data = 1.0f;

    std::vector<float> get_scalings(float sigma) {
        float c_skip = sigma_data * sigma_data / (sigma * sigma + sigma_data * sigma_data);
        float c_out  = -sigma * sigma_data / std::sqrt(sigma * sigma + sigma_data * sigma_data);
        float c_in   = 1.0f / std::sqrt(sigma * sigma + sigma_data * sigma_data);
        return {c_skip, c_out, c_in};
    }
};

#endif  // __DENOISER_HPP__
423 esrgan.hpp Normal file
@@ -0,0 +1,423 @@
#ifndef __ESRGAN_HPP__
#define __ESRGAN_HPP__

#include "ggml_extend.hpp"
#include "model.h"

/*
    =================================== ESRGAN ===================================
    References:
    https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan.py
    https://github.com/XPixelGroup/BasicSR/blob/v1.4.2/basicsr/archs/rrdbnet_arch.py
*/

struct ResidualDenseBlock {
    int num_features;
    int num_grow_ch;
    ggml_tensor* conv1_w;  // [num_grow_ch, num_features, 3, 3]
    ggml_tensor* conv1_b;  // [num_grow_ch]

    ggml_tensor* conv2_w;  // [num_grow_ch, num_features + num_grow_ch, 3, 3]
    ggml_tensor* conv2_b;  // [num_grow_ch]

    ggml_tensor* conv3_w;  // [num_grow_ch, num_features + 2 * num_grow_ch, 3, 3]
    ggml_tensor* conv3_b;  // [num_grow_ch]

    ggml_tensor* conv4_w;  // [num_grow_ch, num_features + 3 * num_grow_ch, 3, 3]
    ggml_tensor* conv4_b;  // [num_grow_ch]

    ggml_tensor* conv5_w;  // [num_features, num_features + 4 * num_grow_ch, 3, 3]
    ggml_tensor* conv5_b;  // [num_features]

    ResidualDenseBlock() {}

    ResidualDenseBlock(int num_feat, int n_grow_ch) {
        num_features = num_feat;
        num_grow_ch  = n_grow_ch;
    }

    size_t calculate_mem_size() {
        size_t mem_size = num_features * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv1_w
        mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32);                               // conv1_b

        mem_size += (num_features + num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv2_w
        mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32);                                         // conv2_b

        mem_size += (num_features + 2 * num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv3_w
        mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32);                                             // conv3_b

        mem_size += (num_features + 3 * num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv4_w
        mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32);                                             // conv4_b

        mem_size += (num_features + 4 * num_grow_ch) * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv5_w
        mem_size += num_features * ggml_type_size(GGML_TYPE_F32);                                             // conv5_b

        return mem_size;
    }

    int get_num_tensors() {
        int num_tensors = 10;
        return num_tensors;
    }

    void init_params(ggml_context* ctx) {
        conv1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_grow_ch);
        conv1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch);
        conv2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + num_grow_ch, num_grow_ch);
        conv2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch);
        conv3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + 2 * num_grow_ch, num_grow_ch);
        conv3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch);
        conv4_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + 3 * num_grow_ch, num_grow_ch);
        conv4_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch);
        conv5_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + 4 * num_grow_ch, num_features);
        conv5_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features);
    }

    void map_by_name(std::map<std::string, ggml_tensor*>& tensors, std::string prefix) {
        tensors[prefix + "conv1.weight"] = conv1_w;
        tensors[prefix + "conv1.bias"]   = conv1_b;

        tensors[prefix + "conv2.weight"] = conv2_w;
        tensors[prefix + "conv2.bias"]   = conv2_b;

        tensors[prefix + "conv3.weight"] = conv3_w;
        tensors[prefix + "conv3.bias"]   = conv3_b;

        tensors[prefix + "conv4.weight"] = conv4_w;
        tensors[prefix + "conv4.bias"]   = conv4_b;

        tensors[prefix + "conv5.weight"] = conv5_w;
        tensors[prefix + "conv5.bias"]   = conv5_b;
    }

    ggml_tensor* forward(ggml_context* ctx, ggml_tensor* out_scale, ggml_tensor* x /* feat */) {
        // x1 = self.lrelu(self.conv1(x))
        ggml_tensor* x1 = ggml_nn_conv_2d(ctx, x, conv1_w, conv1_b, 1, 1, 1, 1);
        x1              = ggml_leaky_relu(ctx, x1, 0.2f, true);

        // x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        ggml_tensor* x_cat = ggml_concat(ctx, x, x1);
        ggml_tensor* x2    = ggml_nn_conv_2d(ctx, x_cat, conv2_w, conv2_b, 1, 1, 1, 1);
        x2                 = ggml_leaky_relu(ctx, x2, 0.2f, true);

        // x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x_cat           = ggml_concat(ctx, x_cat, x2);
        ggml_tensor* x3 = ggml_nn_conv_2d(ctx, x_cat, conv3_w, conv3_b, 1, 1, 1, 1);
        x3              = ggml_leaky_relu(ctx, x3, 0.2f, true);

        // x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x_cat           = ggml_concat(ctx, x_cat, x3);
        ggml_tensor* x4 = ggml_nn_conv_2d(ctx, x_cat, conv4_w, conv4_b, 1, 1, 1, 1);
        x4              = ggml_leaky_relu(ctx, x4, 0.2f, true);

        // self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        x_cat           = ggml_concat(ctx, x_cat, x4);
        ggml_tensor* x5 = ggml_nn_conv_2d(ctx, x_cat, conv5_w, conv5_b, 1, 1, 1, 1);

        // return x5 * 0.2 + x
        x5 = ggml_add(ctx, ggml_scale(ctx, x5, out_scale), x);
        return x5;
    }
};

struct EsrganBlock {
    ResidualDenseBlock rd_blocks[3];
    int num_residual_blocks = 3;

    EsrganBlock() {}

    EsrganBlock(int num_feat, int num_grow_ch) {
        for (int i = 0; i < num_residual_blocks; i++) {
            rd_blocks[i] = ResidualDenseBlock(num_feat, num_grow_ch);
        }
    }

    int get_num_tensors() {
        int num_tensors = 0;
        for (int i = 0; i < num_residual_blocks; i++) {
            num_tensors += rd_blocks[i].get_num_tensors();
        }
        return num_tensors;
    }

    size_t calculate_mem_size() {
        size_t mem_size = 0;
        for (int i = 0; i < num_residual_blocks; i++) {
            mem_size += rd_blocks[i].calculate_mem_size();
        }
        return mem_size;
    }

    void init_params(ggml_context* ctx) {
        for (int i = 0; i < num_residual_blocks; i++) {
            rd_blocks[i].init_params(ctx);
        }
    }

    void map_by_name(std::map<std::string, ggml_tensor*>& tensors, std::string prefix) {
        for (int i = 0; i < num_residual_blocks; i++) {
            rd_blocks[i].map_by_name(tensors, prefix + "rdb" + std::to_string(i + 1) + ".");
        }
    }

    ggml_tensor* forward(ggml_context* ctx, ggml_tensor* out_scale, ggml_tensor* x) {
        ggml_tensor* out = x;
        for (int i = 0; i < num_residual_blocks; i++) {
            // out = self.rdb...(x)
            out = rd_blocks[i].forward(ctx, out_scale, out);
        }
        // return out * 0.2 + x
        out = ggml_add(ctx, ggml_scale(ctx, out, out_scale), x);
        return out;
    }
};
|
||||
|
||||
struct ESRGAN : public GGMLModule {
|
||||
int scale = 4; // default RealESRGAN_x4plus_anime_6B
|
||||
int num_blocks = 6; // default RealESRGAN_x4plus_anime_6B
|
||||
int in_channels = 3;
|
||||
int out_channels = 3;
|
||||
int num_features = 64; // default RealESRGAN_x4plus_anime_6B
|
||||
int num_grow_ch = 32; // default RealESRGAN_x4plus_anime_6B
|
||||
int tile_size = 128; // avoid cuda OOM for 4gb VRAM
|
||||
|
||||
ggml_tensor* conv_first_w; // [num_features, in_channels, 3, 3]
|
||||
ggml_tensor* conv_first_b; // [num_features]
|
||||
|
||||
EsrganBlock body_blocks[6];
|
||||
ggml_tensor* conv_body_w; // [num_features, num_features, 3, 3]
|
||||
ggml_tensor* conv_body_b; // [num_features]
|
||||
|
||||
// upsample
|
||||
ggml_tensor* conv_up1_w; // [num_features, num_features, 3, 3]
|
||||
ggml_tensor* conv_up1_b; // [num_features]
|
||||
ggml_tensor* conv_up2_w; // [num_features, num_features, 3, 3]
|
||||
ggml_tensor* conv_up2_b; // [num_features]
|
||||
|
||||
ggml_tensor* conv_hr_w; // [num_features, num_features, 3, 3]
|
||||
ggml_tensor* conv_hr_b; // [num_features]
|
||||
ggml_tensor* conv_last_w; // [out_channels, num_features, 3, 3]
|
||||
ggml_tensor* conv_last_b; // [out_channels]
|
||||
|
||||
bool decode_only = false;
|
||||
|
||||
ESRGAN() {
|
||||
name = "esrgan";
|
||||
for (int i = 0; i < num_blocks; i++) {
|
||||
body_blocks[i] = EsrganBlock(num_features, num_grow_ch);
|
||||
}
|
||||
}
|
||||
|
||||
size_t calculate_mem_size() {
|
||||
size_t mem_size = num_features * in_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_first_w
|
||||
mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_first_b
|
||||
|
||||
for (int i = 0; i < num_blocks; i++) {
|
||||
mem_size += body_blocks[i].calculate_mem_size();
|
||||
}
|
||||
|
||||
mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_body_w
|
||||
mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_body_w
|
||||
|
||||
        // upsample
        mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_up1_w
        mem_size += num_features * ggml_type_size(GGML_TYPE_F32);                         // conv_up1_b

        mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_up2_w
        mem_size += num_features * ggml_type_size(GGML_TYPE_F32);                         // conv_up2_b

        mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_hr_w
        mem_size += num_features * ggml_type_size(GGML_TYPE_F32);                         // conv_hr_b

        mem_size += out_channels * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_last_w
        mem_size += out_channels * ggml_type_size(GGML_TYPE_F32);                         // conv_last_b
        return mem_size;
    }

    size_t get_num_tensors() {
        size_t num_tensors = 12;
        for (int i = 0; i < num_blocks; i++) {
            num_tensors += body_blocks[i].get_num_tensors();
        }
        return num_tensors;
    }

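    // Tensors created in params_ctx start with data == NULL; the allocator
    // below then places each of them into the pre-sized params_buffer.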
    void init_params() {
        ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer);
        conv_first_w = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, in_channels, num_features);
        conv_first_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, num_features);
        conv_body_w  = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, num_features, num_features);
        conv_body_b  = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, num_features);
        conv_up1_w   = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, num_features, num_features);
        conv_up1_b   = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, num_features);
        conv_up2_w   = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, num_features, num_features);
        conv_up2_b   = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, num_features);
        conv_hr_w    = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, num_features, num_features);
        conv_hr_b    = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, num_features);
        conv_last_w  = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, num_features, out_channels);
        conv_last_b  = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, out_channels);

        for (int i = 0; i < num_blocks; i++) {
            body_blocks[i].init_params(params_ctx);
        }

        // alloc all tensors linked to this context
        for (struct ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != NULL; t = ggml_get_next_tensor(params_ctx, t)) {
            if (t->data == NULL) {
                ggml_allocr_alloc(alloc, t);
            }
        }
        ggml_allocr_free(alloc);
    }

    bool load_from_file(const std::string& file_path, ggml_backend_t backend) {
        LOG_INFO("loading esrgan from '%s'", file_path.c_str());

        if (!alloc_params_buffer(backend)) {
            return false;
        }

        std::map<std::string, ggml_tensor*> esrgan_tensors;

        // prepare memory for the weights
        {
            init_params();
            map_by_name(esrgan_tensors);
        }

        ModelLoader model_loader;
        if (!model_loader.init_from_file(file_path)) {
            LOG_ERROR("init esrgan model loader from file failed: '%s'", file_path.c_str());
            return false;
        }

        bool success = model_loader.load_tensors(esrgan_tensors, backend);

        if (!success) {
            LOG_ERROR("load esrgan tensors from model loader failed");
            return false;
        }

        LOG_INFO("esrgan model loaded");
        return success;
    }

    void map_by_name(std::map<std::string, ggml_tensor*>& tensors) {
        tensors["conv_first.weight"] = conv_first_w;
        tensors["conv_first.bias"]   = conv_first_b;

        for (int i = 0; i < num_blocks; i++) {
            body_blocks[i].map_by_name(tensors, "body." + std::to_string(i) + ".");
        }

        tensors["conv_body.weight"] = conv_body_w;
        tensors["conv_body.bias"]   = conv_body_b;

        tensors["conv_up1.weight"] = conv_up1_w;
        tensors["conv_up1.bias"]   = conv_up1_b;
        tensors["conv_up2.weight"] = conv_up2_w;
        tensors["conv_up2.bias"]   = conv_up2_b;
        tensors["conv_hr.weight"]  = conv_hr_w;
        tensors["conv_hr.bias"]    = conv_hr_b;

        tensors["conv_last.weight"] = conv_last_w;
        tensors["conv_last.bias"]   = conv_last_b;
    }

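    // x4 upscaling pipeline: conv_first -> RRDB trunk with skip connection ->
    // two nearest-neighbor 2x upsamples (each followed by conv + leaky ReLU) ->
    // conv_hr -> conv_last.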
    ggml_tensor* forward(ggml_context* ctx0, ggml_tensor* out_scale, ggml_tensor* x /* feat */) {
        // feat = self.conv_first(feat)
        auto h = ggml_nn_conv_2d(ctx0, x, conv_first_w, conv_first_b, 1, 1, 1, 1);

        auto body_h = h;
        // self.body(feat)
        for (int i = 0; i < num_blocks; i++) {
            body_h = body_blocks[i].forward(ctx0, out_scale, body_h);
        }

        // body_feat = self.conv_body(self.body(feat))
        body_h = ggml_nn_conv_2d(ctx0, body_h, conv_body_w, conv_body_b, 1, 1, 1, 1);

        // feat = feat + body_feat
        h = ggml_add(ctx0, h, body_h);

        // upsample
        // feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        h = ggml_upscale(ctx0, h, 2);
        h = ggml_nn_conv_2d(ctx0, h, conv_up1_w, conv_up1_b, 1, 1, 1, 1);
        h = ggml_leaky_relu(ctx0, h, 0.2f, true);

        // feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        h = ggml_upscale(ctx0, h, 2);
        h = ggml_nn_conv_2d(ctx0, h, conv_up2_w, conv_up2_b, 1, 1, 1, 1);
        h = ggml_leaky_relu(ctx0, h, 0.2f, true);

        // out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        h = ggml_nn_conv_2d(ctx0, h, conv_hr_w, conv_hr_b, 1, 1, 1, 1);
        h = ggml_leaky_relu(ctx0, h, 0.2f, true);

        h = ggml_nn_conv_2d(ctx0, h, conv_last_w, conv_last_b, 1, 1, 1, 1);
        return h;
    }

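    // build_graph() runs twice: once under a "measure" allocator so ggml-alloc
    // can size the compute buffer, then again for the real computation, which
    // is why tensor data is only uploaded when the allocator is not measuring.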
    struct ggml_cgraph* build_graph(struct ggml_tensor* x) {
        // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
        static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
        static std::vector<uint8_t> buf(buf_size);

        struct ggml_init_params params = {
            /*.mem_size   =*/buf_size,
            /*.mem_buffer =*/buf.data(),
            /*.no_alloc   =*/true,  // the tensors will be allocated later by ggml_allocr_alloc_graph()
        };

        struct ggml_context* ctx0 = ggml_init(params);

        struct ggml_cgraph* gf = ggml_new_graph(ctx0);

        struct ggml_tensor* x_ = NULL;
        struct ggml_tensor* os = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
        ggml_allocr_alloc(compute_allocr, os);
        if (!ggml_allocr_is_measure(compute_allocr)) {
            float scale = 0.2f;
            ggml_backend_tensor_set(os, &scale, 0, sizeof(scale));
        }

        // it's performing a compute, check if backend isn't cpu
        if (!ggml_backend_is_cpu(backend)) {
            // pass input tensors to gpu memory
            x_ = ggml_dup_tensor(ctx0, x);
            ggml_allocr_alloc(compute_allocr, x_);

            // pass data to device backend
            if (!ggml_allocr_is_measure(compute_allocr)) {
                ggml_backend_tensor_set(x_, x->data, 0, ggml_nbytes(x));
            }
        } else {
            x_ = x;
        }

        struct ggml_tensor* out = forward(ctx0, os, x_);  // use the backend copy, not the host tensor

        ggml_build_forward_expand(gf, out);
        ggml_free(ctx0);

        return gf;
    }

    void alloc_compute_buffer(struct ggml_tensor* x) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(x);
        };
        GGMLModule::alloc_compute_buffer(get_graph);
    }

    void compute(struct ggml_tensor* work_result, const int n_threads, struct ggml_tensor* x) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(x);
        };
        GGMLModule::compute(get_graph, n_threads, work_result);
    }
};

#endif  // __ESRGAN_HPP__
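
// A minimal usage sketch (not part of this commit), driving the module above
// with only the methods defined in this header; `backend`, `input`, `output`,
// and `n_threads` are assumed to be provided by the caller:
//
//     ESRGAN esrgan;
//     if (esrgan.load_from_file("RealESRGAN_x4plus_anime_6B.pth", backend)) {
//         esrgan.alloc_compute_buffer(input);        // measure pass sizes the buffer
//         esrgan.compute(output, n_threads, input);  // real pass writes the result
//     }
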
@ -1,9 +1,12 @@
#include <stdio.h>
#include <ctime>
#include <string.h>
#include <time.h>
#include <iostream>
#include <random>
#include "ggml/ggml.h"
#include <string>
#include <vector>

#include "stable-diffusion.h"
#include "util.h"

#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
@ -12,11 +15,6 @@
#define STB_IMAGE_WRITE_STATIC
#include "stb_image_write.h"

#include <cstring>
#include <iostream>
#include <string>
#include <vector>

const char* rng_type_to_str[] = {
    "std_default",
    "cuda",
@ -60,7 +58,7 @@ struct SDParams {
    std::string vae_path;
    std::string taesd_path;
    std::string esrgan_path;
    ggml_type wtype = GGML_TYPE_COUNT;
    sd_type_t wtype = SD_TYPE_COUNT;
    std::string lora_model_dir;
    std::string output_path = "output.png";
    std::string input_path;
@ -73,22 +71,34 @@ struct SDParams {
    int height = 512;
    int batch_count = 1;

    SampleMethod sample_method = EULER_A;
    Schedule schedule = DEFAULT;
    sample_method_t sample_method = EULER_A;
    schedule_t schedule = DEFAULT;
    int sample_steps = 20;
    float strength = 0.75f;
    RNGType rng_type = CUDA_RNG;
    rng_type_t rng_type = CUDA_RNG;
    int64_t seed = 42;
    bool verbose = false;
    bool vae_tiling = false;
};

static std::string sd_basename(const std::string& path) {
    size_t pos = path.find_last_of('/');
    if (pos != std::string::npos) {
        return path.substr(pos + 1);
    }
    pos = path.find_last_of('\\');
    if (pos != std::string::npos) {
        return path.substr(pos + 1);
    }
    return path;
}

void print_params(SDParams params) {
    printf("Option: \n");
    printf(" n_threads: %d\n", params.n_threads);
    printf(" mode: %s\n", modes_str[params.mode]);
    printf(" model_path: %s\n", params.model_path.c_str());
    printf(" wtype: %s\n", params.wtype < GGML_TYPE_COUNT ? ggml_type_name(params.wtype) : "unspecified");
    printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
    printf(" vae_path: %s\n", params.vae_path.c_str());
    printf(" taesd_path: %s\n", params.taesd_path.c_str());
    printf(" esrgan_path: %s\n", params.esrgan_path.c_str());
@ -208,19 +218,19 @@ void parse_args(int argc, const char** argv, SDParams& params) {
            }
            std::string type = argv[i];
            if (type == "f32") {
                params.wtype = GGML_TYPE_F32;
                params.wtype = SD_TYPE_F32;
            } else if (type == "f16") {
                params.wtype = GGML_TYPE_F16;
                params.wtype = SD_TYPE_F16;
            } else if (type == "q4_0") {
                params.wtype = GGML_TYPE_Q4_0;
                params.wtype = SD_TYPE_Q4_0;
            } else if (type == "q4_1") {
                params.wtype = GGML_TYPE_Q4_1;
                params.wtype = SD_TYPE_Q4_1;
            } else if (type == "q5_0") {
                params.wtype = GGML_TYPE_Q5_0;
                params.wtype = SD_TYPE_Q5_0;
            } else if (type == "q5_1") {
                params.wtype = GGML_TYPE_Q5_1;
                params.wtype = SD_TYPE_Q5_1;
            } else if (type == "q8_0") {
                params.wtype = GGML_TYPE_Q8_0;
                params.wtype = SD_TYPE_Q8_0;
            } else {
                fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0]\n",
                        type.c_str());
@ -330,7 +340,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                invalid_arg = true;
                break;
            }
            params.schedule = (Schedule)schedule_found;
            params.schedule = (schedule_t)schedule_found;
        } else if (arg == "-s" || arg == "--seed") {
            if (++i >= argc) {
                invalid_arg = true;
@ -353,7 +363,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                invalid_arg = true;
                break;
            }
            params.sample_method = (SampleMethod)sample_method_found;
            params.sample_method = (sample_method_t)sample_method_found;
        } else if (arg == "-h" || arg == "--help") {
            print_usage(argc, argv);
            exit(0);
@ -433,7 +443,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
    parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", ";
    parameter_string += "Seed: " + std::to_string(seed) + ", ";
    parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
    parameter_string += "Model: " + basename(params.model_path) + ", ";
    parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
    parameter_string += "RNG: " + std::string(rng_type_to_str[params.rng_type]) + ", ";
    parameter_string += "Sampler: " + std::string(sample_method_str[params.sample_method]);
    if (params.schedule == KARRAS) {
@ -444,14 +454,29 @@ std::string get_image_params(SDParams params, int64_t seed) {
    return parameter_string;
}

void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
    SDParams* params = (SDParams*)data;
    if (!params->verbose && level <= SD_LOG_DEBUG) {
        return;
    }
    if (level <= SD_LOG_INFO) {
        fprintf(stdout, "%s", log);  // print via "%s" so '%' in the message is not treated as a format specifier
        fflush(stdout);
    } else {
        fprintf(stderr, "%s", log);
        fflush(stderr);
    }
}

int main(int argc, const char* argv[]) {
    SDParams params;
    parse_args(argc, argv, params);

    sd_set_log_callback(sd_log_cb, (void*)&params);

    if (params.verbose) {
        print_params(params);
        printf("%s", sd_get_system_info().c_str());
        set_sd_log_level(SDLogLevel::DEBUG);
        printf("%s", sd_get_system_info());
    }

    bool vae_decode_only = true;
@ -482,16 +507,29 @@ int main(int argc, const char* argv[]) {
        }
    }

    StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.vae_tiling, params.lora_model_dir, params.rng_type);
    sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
                                  params.vae_path.c_str(),
                                  params.taesd_path.c_str(),
                                  params.lora_model_dir.c_str(),
                                  vae_decode_only,
                                  params.vae_tiling,
                                  true,
                                  params.n_threads,
                                  params.wtype,
                                  params.rng_type,
                                  params.schedule);

    if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule, params.clip_skip)) {
    if (sd_ctx == NULL) {
        printf("new_sd_ctx_t failed\n");
        return 1;
    }

    std::vector<uint8_t*> results;
    sd_image_t* results;
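    // With the C API, results is a C array of sd_image_t with batch_count
    // entries (NULL on failure) instead of a std::vector of raw buffers.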
    if (params.mode == TXT2IMG) {
        results = sd.txt2img(params.prompt,
                             params.negative_prompt,
        results = txt2img(sd_ctx,
                          params.prompt.c_str(),
                          params.negative_prompt.c_str(),
                          params.clip_skip,
                          params.cfg_scale,
                          params.width,
                          params.height,
@ -500,42 +538,67 @@ int main(int argc, const char* argv[]) {
                          params.seed,
                          params.batch_count);
    } else {
        results = sd.img2img(input_image_buffer,
                             params.prompt,
                             params.negative_prompt,
        sd_image_t input_image = {(uint32_t)params.width,
                                  (uint32_t)params.height,
                                  3,
                                  input_image_buffer};

        results = img2img(sd_ctx,
                          input_image,
                          params.prompt.c_str(),
                          params.negative_prompt.c_str(),
                          params.clip_skip,
                          params.cfg_scale,
                          params.width,
                          params.height,
                          params.sample_method,
                          params.sample_steps,
                          params.strength,
                          params.seed);
                          params.seed,
                          params.batch_count);
    }

    if (params.esrgan_path.size() > 0) {
        // TODO: support more ESRGAN models, making it easier to set up ESRGAN models.
        /* hardcoded scale factor because just RealESRGAN_x4plus_anime_6B is compatible
           See also: https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan.py

           To avoid this, the upscaler needs to be separated from the stable diffusion pipeline.
           However, a considerable amount of work would be required for this. It might be better
           to opt for a complete project refactoring that facilitates the easier assignment of parameters.
        */
        params.width *= 4;
        params.height *= 4;
    }

    if (results.size() == 0 || results.size() != params.batch_count) {
        LOG_ERROR("generate failed");
    if (results == NULL) {
        printf("generate failed\n");
        return 1;
    }

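    // The C API decouples upscaling from generation: an upscaler context is
    // created on demand from esrgan_path, and each image in the batch is run
    // through it after sampling finishes.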
    int upscale_factor = 4;  // unused for RealESRGAN_x4plus_anime_6B.pth
    if (params.esrgan_path.size() > 0) {
        upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
                                                        params.n_threads,
                                                        params.wtype);

        if (upscaler_ctx == NULL) {
            printf("new_upscaler_ctx failed\n");
        } else {
            for (int i = 0; i < params.batch_count; i++) {
                if (results