feat: add token weighting support (#13)
This commit is contained in:
parent
7132027862
commit
17095dddea
@ -16,6 +16,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
|
||||
- AVX, AVX2 and AVX512 support for x86 architectures
|
||||
- Original `txt2img` and `img2img` mode
|
||||
- Negative prompt
|
||||
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
|
||||
- Sampling method
|
||||
- `Euler A`
|
||||
- Supported platforms
|
||||
@ -30,7 +31,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
|
||||
- [ ] Make inference faster
|
||||
- The current implementation of ggml_conv_2d is slow and has high memory usage
|
||||
- [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
|
||||
- [ ] [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (eg: token weighting, ...)
|
||||
- [ ] LoRA support
|
||||
- [ ] k-quants support
|
||||
- [ ] Cross-platform reproducibility (perhaps ensuring consistency with the original SD)
|
||||
|
@ -355,6 +355,113 @@ class CLIPTokenizer {
|
||||
}
|
||||
};
|
||||
|
||||
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/prompt_parser.py#L345
|
||||
//
|
||||
// Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
||||
// Accepted tokens are:
|
||||
// (abc) - increases attention to abc by a multiplier of 1.1
|
||||
// (abc:3.12) - increases attention to abc by a multiplier of 3.12
|
||||
// [abc] - decreases attention to abc by a multiplier of 1.1
|
||||
// \( - literal character '('
|
||||
// \[ - literal character '['
|
||||
// \) - literal character ')'
|
||||
// \] - literal character ']'
|
||||
// \\ - literal character '\'
|
||||
// anything else - just text
|
||||
//
|
||||
// >>> parse_prompt_attention('normal text')
|
||||
// [['normal text', 1.0]]
|
||||
// >>> parse_prompt_attention('an (important) word')
|
||||
// [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
||||
// >>> parse_prompt_attention('(unbalanced')
|
||||
// [['unbalanced', 1.1]]
|
||||
// >>> parse_prompt_attention('\(literal\]')
|
||||
// [['(literal]', 1.0]]
|
||||
// >>> parse_prompt_attention('(unnecessary)(parens)')
|
||||
// [['unnecessaryparens', 1.1]]
|
||||
// >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
||||
// [['a ', 1.0],
|
||||
// ['house', 1.5730000000000004],
|
||||
// [' ', 1.1],
|
||||
// ['on', 1.0],
|
||||
// [' a ', 1.1],
|
||||
// ['hill', 0.55],
|
||||
// [', sun, ', 1.1],
|
||||
// ['sky', 1.4641000000000006],
|
||||
// ['.', 1.1]]
|
||||
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text) {
|
||||
std::vector<std::pair<std::string, float>> res;
|
||||
std::vector<int> round_brackets;
|
||||
std::vector<int> square_brackets;
|
||||
|
||||
float round_bracket_multiplier = 1.1f;
|
||||
float square_bracket_multiplier = 1 / 1.1f;
|
||||
|
||||
std::regex re_attention(R"(\\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|]|[^\\()\[\]:]+|:)");
|
||||
std::regex re_break(R"(\s*\bBREAK\b\s*)");
|
||||
|
||||
auto multiply_range = [&](int start_position, float multiplier) {
|
||||
for (int p = start_position; p < res.size(); ++p) {
|
||||
res[p].second *= multiplier;
|
||||
}
|
||||
};
|
||||
|
||||
std::smatch m;
|
||||
std::string remaining_text = text;
|
||||
|
||||
while (std::regex_search(remaining_text, m, re_attention)) {
|
||||
std::string text = m[0];
|
||||
std::string weight = m[1];
|
||||
|
||||
if (text == "(") {
|
||||
round_brackets.push_back(res.size());
|
||||
} else if (text == "[") {
|
||||
square_brackets.push_back(res.size());
|
||||
} else if (!weight.empty()) {
|
||||
if (!round_brackets.empty()) {
|
||||
multiply_range(round_brackets.back(), std::stod(weight));
|
||||
round_brackets.pop_back();
|
||||
}
|
||||
} else if (text == ")" && !round_brackets.empty()) {
|
||||
multiply_range(round_brackets.back(), round_bracket_multiplier);
|
||||
round_brackets.pop_back();
|
||||
} else if (text == "]" && !square_brackets.empty()) {
|
||||
multiply_range(square_brackets.back(), square_bracket_multiplier);
|
||||
square_brackets.pop_back();
|
||||
} else if (text == "\\(") {
|
||||
res.push_back({text.substr(1), 1.0f});
|
||||
} else {
|
||||
res.push_back({text, 1.0f});
|
||||
}
|
||||
|
||||
remaining_text = m.suffix();
|
||||
}
|
||||
|
||||
for (int pos : round_brackets) {
|
||||
multiply_range(pos, round_bracket_multiplier);
|
||||
}
|
||||
|
||||
for (int pos : square_brackets) {
|
||||
multiply_range(pos, square_bracket_multiplier);
|
||||
}
|
||||
|
||||
if (res.empty()) {
|
||||
res.push_back({"", 1.0f});
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
while (i + 1 < res.size()) {
|
||||
if (res[i].second == res[i + 1].second) {
|
||||
res[i].first += res[i + 1].first;
|
||||
res.erase(res.begin() + i + 1);
|
||||
} else {
|
||||
++i;
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/*================================================ FrozenCLIPEmbedder ================================================*/
|
||||
|
||||
struct ResidualAttentionBlock {
|
||||
@ -639,6 +746,61 @@ struct FrozenCLIPEmbedder {
|
||||
}
|
||||
};
|
||||
|
||||
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
|
||||
struct FrozenCLIPEmbedderWithCustomWords {
|
||||
CLIPTokenizer tokenizer;
|
||||
CLIPTextModel text_model;
|
||||
|
||||
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
|
||||
size_t max_length = 0,
|
||||
bool padding = false) {
|
||||
auto parsed_attention = parse_prompt_attention(text);
|
||||
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << "[";
|
||||
for (const auto& item : parsed_attention) {
|
||||
ss << "['" << item.first << "', " << item.second << "], ";
|
||||
}
|
||||
ss << "]";
|
||||
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
|
||||
}
|
||||
|
||||
std::vector<int> tokens;
|
||||
std::vector<float> weights;
|
||||
for (const auto& item : parsed_attention) {
|
||||
const std::string& curr_text = item.first;
|
||||
float curr_weight = item.second;
|
||||
std::vector<int> curr_tokens = tokenizer.encode(curr_text);
|
||||
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
||||
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
||||
}
|
||||
tokens.insert(tokens.begin(), BOS_TOKEN_ID);
|
||||
weights.insert(weights.begin(), 1.0);
|
||||
|
||||
if (max_length > 0) {
|
||||
if (tokens.size() > max_length - 1) {
|
||||
tokens.resize(max_length - 1);
|
||||
weights.resize(max_length - 1);
|
||||
} else {
|
||||
if (padding) {
|
||||
tokens.insert(tokens.end(), max_length - 1 - tokens.size(), PAD_TOKEN_ID);
|
||||
weights.insert(weights.end(), max_length - 1 - weights.size(), 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
tokens.push_back(EOS_TOKEN_ID);
|
||||
weights.push_back(1.0);
|
||||
|
||||
// for (int i = 0; i < tokens.size(); i++) {
|
||||
// std::cout << tokens[i] << ":" << weights[i] << ", ";
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
|
||||
return {tokens, weights};
|
||||
}
|
||||
};
|
||||
|
||||
/*==================================================== UnetModel =====================================================*/
|
||||
|
||||
struct ResBlock {
|
||||
@ -2489,7 +2651,7 @@ class StableDiffusionGGML {
|
||||
size_t max_params_mem_size = 0;
|
||||
size_t max_rt_mem_size = 0;
|
||||
|
||||
FrozenCLIPEmbedder cond_stage_model;
|
||||
FrozenCLIPEmbedderWithCustomWords cond_stage_model;
|
||||
UNetModel diffusion_model;
|
||||
AutoEncoderKL first_stage_model;
|
||||
|
||||
@ -2784,9 +2946,11 @@ class StableDiffusionGGML {
|
||||
}
|
||||
|
||||
ggml_tensor* get_learned_condition(ggml_context* res_ctx, const std::string& text) {
|
||||
std::vector<int32_t> tokens = cond_stage_model.tokenizer.tokenize(text,
|
||||
cond_stage_model.text_model.max_position_embeddings,
|
||||
true);
|
||||
auto tokens_and_weights = cond_stage_model.tokenize(text,
|
||||
cond_stage_model.text_model.max_position_embeddings,
|
||||
true);
|
||||
std::vector<int>& tokens = tokens_and_weights.first;
|
||||
std::vector<float>& weights = tokens_and_weights.second;
|
||||
size_t ctx_size = 1 * 1024 * 1024; // 1MB
|
||||
// calculate the amount of memory required
|
||||
{
|
||||
@ -2848,10 +3012,39 @@ class StableDiffusionGGML {
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_DEBUG("computing condition graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||
|
||||
ggml_tensor* result = ggml_dup_tensor(res_ctx, hidden_states);
|
||||
copy_ggml_tensor(result, hidden_states);
|
||||
ggml_tensor* result = ggml_dup_tensor(res_ctx, hidden_states); // [N, n_token, hidden_size]
|
||||
|
||||
{
|
||||
int64_t nelements = ggml_nelements(hidden_states);
|
||||
float original_mean = 0.f;
|
||||
float new_mean = 0.f;
|
||||
float* vec = (float*)hidden_states->data;
|
||||
for (int i = 0; i < nelements; i++) {
|
||||
original_mean += vec[i] / nelements * 1.0f;
|
||||
}
|
||||
|
||||
for (int i2 = 0; i2 < hidden_states->ne[2]; i2++) {
|
||||
for (int i1 = 0; i1 < hidden_states->ne[1]; i1++) {
|
||||
for (int i0 = 0; i0 < hidden_states->ne[0]; i0++) {
|
||||
float value = ggml_tensor_get_f32(hidden_states, i0, i1, i2);
|
||||
value *= weights[i1];
|
||||
ggml_tensor_set_f32(result, value, i0, i1, i2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
vec = (float*)result->data;
|
||||
for (int i = 0; i < nelements; i++) {
|
||||
new_mean += vec[i] / nelements * 1.0f;
|
||||
}
|
||||
|
||||
for (int i = 0; i < nelements; i++) {
|
||||
vec[i] = vec[i] * (original_mean / new_mean);
|
||||
}
|
||||
}
|
||||
|
||||
// print_ggml_tensor(result);
|
||||
|
||||
size_t rt_mem_size = ctx_size + ggml_curr_max_dynamic_size();
|
||||
if (rt_mem_size > max_rt_mem_size) {
|
||||
max_rt_mem_size = rt_mem_size;
|
||||
|
Loading…
Reference in New Issue
Block a user