From ef5c3f74012c3e23f6f244b280b97091a281cd59 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 2 Mar 2024 17:12:39 +0800 Subject: [PATCH] feat: add support for prompt longer than 77 --- clip.hpp | 82 ++++++++++++++++++++++-------- ggml_extend.hpp | 8 +-- stable-diffusion.cpp | 116 +++++++++++++++++++++++++------------------ util.cpp | 4 +- 4 files changed, 133 insertions(+), 77 deletions(-) diff --git a/clip.hpp b/clip.hpp index 31efa5f..9bdbc80 100644 --- a/clip.hpp +++ b/clip.hpp @@ -558,11 +558,14 @@ public: auto token_embed_weight = params["token_embedding.weight"]; auto position_embed_weight = params["position_embedding.weight"]; - GGML_ASSERT(input_ids->ne[0] <= position_embed_weight->ne[0]); + GGML_ASSERT(input_ids->ne[0] == position_embed_weight->ne[1]); + input_ids = ggml_reshape_3d(ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]); + auto token_embedding = ggml_get_rows(ctx, custom_embed_weight != NULL ? custom_embed_weight : token_embed_weight, input_ids); + token_embedding = ggml_reshape_3d(ctx, token_embedding, token_embedding->ne[0], token_embedding->ne[1], token_embedding->ne[3]); // token_embedding + position_embedding auto x = ggml_add(ctx, - ggml_get_rows(ctx, custom_embed_weight != NULL ? custom_embed_weight : token_embed_weight, input_ids), + token_embedding, position_embed_weight); // [N, n_token, embed_dim] return x; } @@ -700,7 +703,7 @@ public: auto final_layer_norm = std::dynamic_pointer_cast(blocks["final_layer_norm"]); auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size] - x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true); + x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true); if (return_pooled || with_final_ln) { x = final_layer_norm->forward(ctx, x); } @@ -889,7 +892,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { return false; } struct ggml_init_params params; - params.mem_size = 32 * 1024; // max for custom embeddings 32 KB + params.mem_size = 10 * 1024 * 1024; // max for custom embeddings 10 MB params.mem_buffer = NULL; params.no_alloc = false; struct ggml_context* embd_ctx = ggml_init(params); @@ -924,9 +927,21 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { struct ggml_tensor* embeddings, size_t max_token_idx = 0, bool return_pooled = false) { + size_t N = input_ids->ne[1]; + size_t n_token = input_ids->ne[0]; + if (input_ids != NULL && input_ids->ne[0] > text_model.n_token) { + GGML_ASSERT(input_ids->ne[0] % text_model.n_token == 0); + input_ids = ggml_reshape_2d(ctx, input_ids, text_model.n_token, input_ids->ne[0] / text_model.n_token); + } + if (input_ids2 != NULL && input_ids2->ne[0] > text_model2.n_token) { + GGML_ASSERT(input_ids2->ne[0] % text_model2.n_token == 0); + input_ids2 = ggml_reshape_2d(ctx, input_ids2, text_model2.n_token, input_ids2->ne[0] / text_model2.n_token); + } + if (return_pooled) { return text_model2.forward(ctx, input_ids2, NULL, max_token_idx, return_pooled); } + auto hidden_states = text_model.forward(ctx, input_ids, embeddings); // [N, n_token, hidden_size] // LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]); if (version == VERSION_XL) { @@ -952,6 +967,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 1, 2, 0, 3)); } + hidden_states = ggml_reshape_3d(ctx, hidden_states, hidden_states->ne[0], n_token, N); // LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]); return hidden_states; } @@ -1057,26 +1073,48 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); weights.insert(weights.end(), curr_tokens.size(), curr_weight); } - tokens.insert(tokens.begin(), BOS_TOKEN_ID); - weights.insert(weights.begin(), 1.0); - if (max_length > 0) { - if (tokens.size() > max_length - 1) { - tokens.resize(max_length - 1); - weights.resize(max_length - 1); - tokens.push_back(EOS_TOKEN_ID); - weights.push_back(1.0); - } else { - tokens.push_back(EOS_TOKEN_ID); - weights.push_back(1.0); - if (padding) { - int pad_token_id = PAD_TOKEN_ID; - if (version == VERSION_2_x) { - pad_token_id = 0; - } - tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id); - weights.insert(weights.end(), max_length - weights.size(), 1.0); + if (max_length > 0 && padding) { + size_t n = std::ceil(tokens.size() * 1.0 / (max_length - 2)); + if (n == 0) { + n = 1; + } + size_t length = max_length * n; + LOG_DEBUG("token length: %llu", length); + std::vector new_tokens; + std::vector new_weights; + new_tokens.push_back(BOS_TOKEN_ID); + new_weights.push_back(1.0); + int token_idx = 0; + for (int i = 1; i < length; i++) { + if (token_idx >= tokens.size()) { + break; } + if (i % max_length == 0) { + new_tokens.push_back(BOS_TOKEN_ID); + new_weights.push_back(1.0); + } else if (i % max_length == max_length - 1) { + new_tokens.push_back(EOS_TOKEN_ID); + new_weights.push_back(1.0); + } else { + new_tokens.push_back(tokens[token_idx]); + new_weights.push_back(weights[token_idx]); + token_idx++; + } + } + + new_tokens.push_back(EOS_TOKEN_ID); + new_weights.push_back(1.0); + tokens = new_tokens; + weights = new_weights; + + if (padding) { + int pad_token_id = PAD_TOKEN_ID; + if (version == VERSION_2_x) { + pad_token_id = 0; + } + tokens.insert(tokens.end(), length - tokens.size(), pad_token_id); + weights.insert(weights.end(), length - weights.size(), 1.0); } } diff --git a/ggml_extend.hpp b/ggml_extend.hpp index dc39beb..ddcdd29 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -1231,9 +1231,9 @@ public: q = ggml_reshape_3d(ctx, q, d_head, n_token, n_head * N); // [N * n_head, n_token, d_head] struct ggml_tensor* k = k_proj->forward(ctx, x); - k = ggml_reshape_4d(ctx, k, d_head, n_head, n_token, N); // [N, n_token, n_head, d_head] - k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, n_token, d_head] - k = ggml_reshape_3d(ctx, k, d_head, n_token, n_head); // [N * n_head, n_token, d_head] + k = ggml_reshape_4d(ctx, k, d_head, n_head, n_token, N); // [N, n_token, n_head, d_head] + k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, n_token, d_head] + k = ggml_reshape_3d(ctx, k, d_head, n_token, n_head * N); // [N * n_head, n_token, d_head] struct ggml_tensor* v = v_proj->forward(ctx, x); v = ggml_reshape_4d(ctx, v, d_head, n_head, n_token, N); // [N, n_token, n_head, d_head] @@ -1245,7 +1245,7 @@ public: kqv = ggml_reshape_4d(ctx, kqv, d_head, n_token, n_head, N); kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, n_token, n_head, d_head] - x = ggml_reshape_2d(ctx, kqv, d_head * n_head, n_token * N); // [N * n_token, d_head * n_head] + x = ggml_reshape_3d(ctx, kqv, d_head * n_head, n_token, N); // [N * n_token, d_head * n_head] x = out_proj->forward(ctx, x); return x; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 5313e7d..0845b4f 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -451,65 +451,83 @@ public: int height, bool force_zero_embeddings = false) { cond_stage_model->set_clip_skip(clip_skip); - auto tokens_and_weights = cond_stage_model->tokenize(text, true); - std::vector& tokens = tokens_and_weights.first; - std::vector& weights = tokens_and_weights.second; - int64_t t0 = ggml_time_ms(); - struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size] - struct ggml_tensor* pooled = NULL; + auto tokens_and_weights = cond_stage_model->tokenize(text, true); + std::vector& tokens = tokens_and_weights.first; + std::vector& weights = tokens_and_weights.second; + int64_t t0 = ggml_time_ms(); + struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size] + struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, hidden_size] + struct ggml_tensor* pooled = NULL; + std::vector hidden_states_vec; - auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens); - struct ggml_tensor* input_ids2 = NULL; - size_t max_token_idx = 0; - if (version == VERSION_XL) { - auto it = std::find(tokens.begin(), tokens.end(), EOS_TOKEN_ID); - if (it != tokens.end()) { - std::fill(std::next(it), tokens.end(), 0); + size_t chunk_len = 77; + size_t chunk_count = tokens.size() / chunk_len; + for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { + std::vector chunk_tokens(tokens.begin() + chunk_idx * chunk_len, + tokens.begin() + (chunk_idx + 1) * chunk_len); + std::vector chunk_weights(weights.begin() + chunk_idx * chunk_len, + weights.begin() + (chunk_idx + 1) * chunk_len); + + auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); + struct ggml_tensor* input_ids2 = NULL; + size_t max_token_idx = 0; + if (version == VERSION_XL) { + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), EOS_TOKEN_ID); + if (it != chunk_tokens.end()) { + std::fill(std::next(it), chunk_tokens.end(), 0); + } + + max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); + + input_ids2 = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); + + // for (int i = 0; i < chunk_tokens.size(); i++) { + // printf("%d ", chunk_tokens[i]); + // } + // printf("\n"); } - max_token_idx = std::min(std::distance(tokens.begin(), it), tokens.size() - 1); - - input_ids2 = vector_to_ggml_tensor_i32(work_ctx, tokens); - - // for (int i = 0; i < tokens.size(); i++) { - // printf("%d ", tokens[i]); + cond_stage_model->compute(n_threads, input_ids, input_ids2, max_token_idx, false, &chunk_hidden_states, work_ctx); + if (version == VERSION_XL && chunk_idx == 0) { + cond_stage_model->compute(n_threads, input_ids, input_ids2, max_token_idx, true, &pooled, work_ctx); + } + // if (pooled != NULL) { + // print_ggml_tensor(chunk_hidden_states); + // print_ggml_tensor(pooled); // } - // printf("\n"); - } - cond_stage_model->compute(n_threads, input_ids, input_ids2, max_token_idx, false, &hidden_states, work_ctx); - if (version == VERSION_XL) { - cond_stage_model->compute(n_threads, input_ids, input_ids2, max_token_idx, true, &pooled, work_ctx); - } - // if (pooled != NULL) { - // print_ggml_tensor(hidden_states); - // print_ggml_tensor(pooled); - // } - - int64_t t1 = ggml_time_ms(); - LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); - ggml_tensor* result = ggml_dup_tensor(work_ctx, hidden_states); - { - float original_mean = ggml_tensor_mean(hidden_states); - for (int i2 = 0; i2 < hidden_states->ne[2]; i2++) { - for (int i1 = 0; i1 < hidden_states->ne[1]; i1++) { - for (int i0 = 0; i0 < hidden_states->ne[0]; i0++) { - float value = ggml_tensor_get_f32(hidden_states, i0, i1, i2); - value *= weights[i1]; - ggml_tensor_set_f32(result, value, i0, i1, i2); + int64_t t1 = ggml_time_ms(); + LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); + ggml_tensor* result = ggml_dup_tensor(work_ctx, chunk_hidden_states); + { + float original_mean = ggml_tensor_mean(chunk_hidden_states); + for (int i2 = 0; i2 < chunk_hidden_states->ne[2]; i2++) { + for (int i1 = 0; i1 < chunk_hidden_states->ne[1]; i1++) { + for (int i0 = 0; i0 < chunk_hidden_states->ne[0]; i0++) { + float value = ggml_tensor_get_f32(chunk_hidden_states, i0, i1, i2); + value *= chunk_weights[i1]; + ggml_tensor_set_f32(result, value, i0, i1, i2); + } } } + float new_mean = ggml_tensor_mean(result); + ggml_tensor_scale(result, (original_mean / new_mean)); } - float new_mean = ggml_tensor_mean(result); - ggml_tensor_scale(result, (original_mean / new_mean)); - } - if (force_zero_embeddings) { - float* vec = (float*)result->data; - for (int i = 0; i < ggml_nelements(result); i++) { - vec[i] = 0; + if (force_zero_embeddings) { + float* vec = (float*)result->data; + for (int i = 0; i < ggml_nelements(result); i++) { + vec[i] = 0; + } } + hidden_states_vec.insert(hidden_states_vec.end(), (float*)result->data, ((float*)result->data) + ggml_nelements(result)); } + hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec); + hidden_states = ggml_reshape_2d(work_ctx, + hidden_states, + chunk_hidden_states->ne[0], + ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); + ggml_tensor* vec = NULL; if (version == VERSION_XL) { int out_dim = 256; @@ -547,7 +565,7 @@ public: GGML_ASSERT(offset == ggml_nbytes(vec)); } // print_ggml_tensor(result); - return {result, vec}; + return {hidden_states, vec}; } std::tuple get_svd_condition(ggml_context* work_ctx, diff --git a/util.cpp b/util.cpp index 346886b..a4e74a3 100644 --- a/util.cpp +++ b/util.cpp @@ -266,14 +266,14 @@ void log_printf(sd_log_level_t level, const char* file, int line, const char* fo level_str = "ERROR"; } - static char log_buffer[LOG_BUFFER_SIZE]; + static char log_buffer[LOG_BUFFER_SIZE + 1]; int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "[%s] %s:%-4d - ", level_str, sd_basename(file).c_str(), line); if (written >= 0 && written < LOG_BUFFER_SIZE) { vsnprintf(log_buffer + written, LOG_BUFFER_SIZE - written, format, args); - strncat(log_buffer, "\n", LOG_BUFFER_SIZE - strlen(log_buffer) - 1); } + strncat(log_buffer, "\n", LOG_BUFFER_SIZE - strlen(log_buffer)); if (sd_log_cb) { sd_log_cb(level, log_buffer, sd_log_cb_data);