feat: add support for prompts longer than 77 tokens
This commit is contained in:
parent b7870a0f89
commit ef5c3f7401

clip.hpp (70 changed lines)
@@ -558,11 +558,14 @@ public:
         auto token_embed_weight = params["token_embedding.weight"];
         auto position_embed_weight = params["position_embedding.weight"];

-        GGML_ASSERT(input_ids->ne[0] <= position_embed_weight->ne[0]);
+        GGML_ASSERT(input_ids->ne[0] == position_embed_weight->ne[1]);
+        input_ids = ggml_reshape_3d(ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]);
+        auto token_embedding = ggml_get_rows(ctx, custom_embed_weight != NULL ? custom_embed_weight : token_embed_weight, input_ids);
+        token_embedding = ggml_reshape_3d(ctx, token_embedding, token_embedding->ne[0], token_embedding->ne[1], token_embedding->ne[3]);

         // token_embedding + position_embedding
         auto x = ggml_add(ctx,
-                          ggml_get_rows(ctx, custom_embed_weight != NULL ? custom_embed_weight : token_embed_weight, input_ids),
+                          token_embedding,
                           position_embed_weight);  // [N, n_token, embed_dim]
         return x;
     }
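Note: with long prompts folded into the batch dimension (see the forward() change below), input_ids arrives here as [n_token, N] rather than a flat vector, so the embedding lookup has to be batched. A shape trace of the new code, under ggml's reversed-dimension convention (ne[0] is the innermost axis):

    // input_ids             : ne = [n_token, N]
    // after ggml_reshape_3d : ne = [n_token, 1, N]
    // after ggml_get_rows   : ne = [embed_dim, n_token, 1, N]
    // after ggml_reshape_3d : ne = [embed_dim, n_token, N]
    // position_embed_weight : ne = [embed_dim, max_position]; ggml_add
    // broadcasts this 2-D table across the N batch rows.

The tightened assert also compares the token count against position_embed_weight->ne[1] (the position axis) instead of ne[0] (the embedding axis), which the old code checked by mistake.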
@@ -889,7 +892,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
            return false;
        }
        struct ggml_init_params params;
-       params.mem_size = 32 * 1024;  // max for custom embeddings 32 KB
+       params.mem_size = 10 * 1024 * 1024;  // max for custom embeddings 10 MB
        params.mem_buffer = NULL;
        params.no_alloc = false;
        struct ggml_context* embd_ctx = ggml_init(params);
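Note: the old 32 KB context was presumably sized for a single embedding vector (a 768-dim f32 vector is 3 KB plus tensor overhead); multi-vector textual-inversion embeddings can exceed that quickly, and 10 MB leaves ample headroom.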
@@ -924,9 +927,21 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
                                struct ggml_tensor* embeddings,
                                size_t max_token_idx = 0,
                                bool return_pooled = false) {
+       size_t N = input_ids->ne[1];
+       size_t n_token = input_ids->ne[0];
+       if (input_ids != NULL && input_ids->ne[0] > text_model.n_token) {
+           GGML_ASSERT(input_ids->ne[0] % text_model.n_token == 0);
+           input_ids = ggml_reshape_2d(ctx, input_ids, text_model.n_token, input_ids->ne[0] / text_model.n_token);
+       }
+       if (input_ids2 != NULL && input_ids2->ne[0] > text_model2.n_token) {
+           GGML_ASSERT(input_ids2->ne[0] % text_model2.n_token == 0);
+           input_ids2 = ggml_reshape_2d(ctx, input_ids2, text_model2.n_token, input_ids2->ne[0] / text_model2.n_token);
+       }
+
        if (return_pooled) {
            return text_model2.forward(ctx, input_ids2, NULL, max_token_idx, return_pooled);
        }

        auto hidden_states = text_model.forward(ctx, input_ids, embeddings);  // [N, n_token, hidden_size]
        // LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
        if (version == VERSION_XL) {
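Note: the chunking works by folding extra tokens into the batch axis. N and n_token capture the caller's view of input_ids before the fold, so forward() can restore that shape on the way out (next hunk). A hypothetical trace for a caller that passes one 154-token sequence, assuming text_model.n_token == 77:

    // caller's input_ids    : ne = [154, 1]  ->  N = 1, n_token = 154
    // after ggml_reshape_2d : ne = [77, 2]   ->  the text model sees a batch of 2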
@@ -952,6 +967,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {

            hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 1, 2, 0, 3));
        }
+       hidden_states = ggml_reshape_3d(ctx, hidden_states, hidden_states->ne[0], n_token, N);
        // LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
        return hidden_states;
    }
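Note: this is the other half of the fold. For the 154-token example above, the encoder produces hidden states of shape [hidden_size, 77, 2], and this reshape flattens the chunk batch back into the caller's original [hidden_size, n_token, N] = [hidden_size, 154, 1], so downstream code sees one long sequence (the element counts match, so ggml_reshape_3d is safe here).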
@@ -1057,26 +1073,48 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
            tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
            weights.insert(weights.end(), curr_tokens.size(), curr_weight);
        }
-       tokens.insert(tokens.begin(), BOS_TOKEN_ID);
-       weights.insert(weights.begin(), 1.0);

-       if (max_length > 0) {
-           if (tokens.size() > max_length - 1) {
-               tokens.resize(max_length - 1);
-               weights.resize(max_length - 1);
-               tokens.push_back(EOS_TOKEN_ID);
-               weights.push_back(1.0);
+       if (max_length > 0 && padding) {
+           size_t n = std::ceil(tokens.size() * 1.0 / (max_length - 2));
+           if (n == 0) {
+               n = 1;
+           }
+           size_t length = max_length * n;
+           LOG_DEBUG("token length: %llu", length);
+           std::vector<int> new_tokens;
+           std::vector<float> new_weights;
+           new_tokens.push_back(BOS_TOKEN_ID);
+           new_weights.push_back(1.0);
+           int token_idx = 0;
+           for (int i = 1; i < length; i++) {
+               if (token_idx >= tokens.size()) {
+                   break;
+               }
+               if (i % max_length == 0) {
+                   new_tokens.push_back(BOS_TOKEN_ID);
+                   new_weights.push_back(1.0);
+               } else if (i % max_length == max_length - 1) {
+                   new_tokens.push_back(EOS_TOKEN_ID);
+                   new_weights.push_back(1.0);
                } else {
-               tokens.push_back(EOS_TOKEN_ID);
-               weights.push_back(1.0);
+                   new_tokens.push_back(tokens[token_idx]);
+                   new_weights.push_back(weights[token_idx]);
+                   token_idx++;
+               }
+           }
+
+           new_tokens.push_back(EOS_TOKEN_ID);
+           new_weights.push_back(1.0);
+           tokens = new_tokens;
+           weights = new_weights;

            if (padding) {
                int pad_token_id = PAD_TOKEN_ID;
                if (version == VERSION_2_x) {
                    pad_token_id = 0;
                }
-               tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id);
-               weights.insert(weights.end(), max_length - weights.size(), 1.0);
-           }
+               tokens.insert(tokens.end(), length - tokens.size(), pad_token_id);
+               weights.insert(weights.end(), length - weights.size(), 1.0);
            }
        }
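Note: instead of truncating to a single 77-token window, the tokenizer now rounds the prompt up to a whole number of max_length-token chunks, each carrying its own BOS/EOS, with max_length - 2 = 75 content slots per chunk. A standalone sketch of the resulting layout (not code from this commit; 49406/49407 are CLIP's usual BOS/EOS ids, used here for illustration):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const int BOS_TOKEN_ID = 49406, EOS_TOKEN_ID = 49407, PAD_TOKEN_ID = 49407;
        const size_t max_length = 77;
        std::vector<int> tokens(100, 42);  // pretend the prompt tokenized to 100 ids

        size_t n = (size_t)std::ceil(tokens.size() * 1.0 / (max_length - 2));
        if (n == 0) {
            n = 1;
        }
        size_t length = max_length * n;  // 154 for a 100-token prompt

        std::vector<int> out;
        out.push_back(BOS_TOKEN_ID);
        size_t token_idx = 0;
        for (size_t i = 1; i < length && token_idx < tokens.size(); i++) {
            if (i % max_length == 0) {
                out.push_back(BOS_TOKEN_ID);  // each chunk restarts with BOS
            } else if (i % max_length == max_length - 1) {
                out.push_back(EOS_TOKEN_ID);  // and ends with EOS
            } else {
                out.push_back(tokens[token_idx++]);
            }
        }
        out.push_back(EOS_TOKEN_ID);
        out.resize(length, PAD_TOKEN_ID);  // pad the tail of the last chunk

        std::printf("chunks=%zu total=%zu\n", out.size() / max_length, out.size());
        return 0;  // prints: chunks=2 total=154
    }

Worked through: 100 tokens give n = ceil(100 / 75) = 2, so length = 154; the first chunk holds tokens 0-74, the second tokens 75-99 plus padding, and every chunk is exactly 77 ids wide.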
@@ -1233,7 +1233,7 @@ public:
        struct ggml_tensor* k = k_proj->forward(ctx, x);
        k = ggml_reshape_4d(ctx, k, d_head, n_head, n_token, N);   // [N, n_token, n_head, d_head]
        k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));      // [N, n_head, n_token, d_head]
-       k = ggml_reshape_3d(ctx, k, d_head, n_token, n_head);      // [N * n_head, n_token, d_head]
+       k = ggml_reshape_3d(ctx, k, d_head, n_token, n_head * N);  // [N * n_head, n_token, d_head]

        struct ggml_tensor* v = v_proj->forward(ctx, x);
        v = ggml_reshape_4d(ctx, v, d_head, n_head, n_token, N);   // [N, n_token, n_head, d_head]
@@ -1245,7 +1245,7 @@ public:
        kqv = ggml_reshape_4d(ctx, kqv, d_head, n_token, n_head, N);
        kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3));  // [N, n_token, n_head, d_head]

-       x = ggml_reshape_2d(ctx, kqv, d_head * n_head, n_token * N);  // [N * n_token, d_head * n_head]
+       x = ggml_reshape_3d(ctx, kqv, d_head * n_head, n_token, N);   // [N, n_token, d_head * n_head]

        x = out_proj->forward(ctx, x);
        return x;
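Note: both attention fixes keep the batch axis alive now that N can exceed 1. The old k reshape to [d_head, n_token, n_head] only has the right element count when N == 1 (ggml_reshape_3d asserts on a mismatch), so batched chunks need n_head * N rows for the attention matmuls. Likewise, reshaping kqv to the 3-D [d_head * n_head, n_token, N] instead of flattening everything into 2-D lets out_proj's linear layer map over each batch row independently.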
@@ -456,46 +456,56 @@ public:
        std::vector<float>& weights = tokens_and_weights.second;
        int64_t t0 = ggml_time_ms();
        struct ggml_tensor* hidden_states = NULL;  // [N, n_token, hidden_size]
+       struct ggml_tensor* chunk_hidden_states = NULL;  // [n_token, hidden_size]
        struct ggml_tensor* pooled = NULL;
+       std::vector<float> hidden_states_vec;

-       auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
+       size_t chunk_len = 77;
+       size_t chunk_count = tokens.size() / chunk_len;
+       for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
+           std::vector<int> chunk_tokens(tokens.begin() + chunk_idx * chunk_len,
+                                         tokens.begin() + (chunk_idx + 1) * chunk_len);
+           std::vector<float> chunk_weights(weights.begin() + chunk_idx * chunk_len,
+                                            weights.begin() + (chunk_idx + 1) * chunk_len);
+
+           auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
            struct ggml_tensor* input_ids2 = NULL;
            size_t max_token_idx = 0;
            if (version == VERSION_XL) {
-               auto it = std::find(tokens.begin(), tokens.end(), EOS_TOKEN_ID);
-               if (it != tokens.end()) {
-                   std::fill(std::next(it), tokens.end(), 0);
+               auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), EOS_TOKEN_ID);
+               if (it != chunk_tokens.end()) {
+                   std::fill(std::next(it), chunk_tokens.end(), 0);
                }

-               max_token_idx = std::min<size_t>(std::distance(tokens.begin(), it), tokens.size() - 1);
+               max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);

-               input_ids2 = vector_to_ggml_tensor_i32(work_ctx, tokens);
+               input_ids2 = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);

-               // for (int i = 0; i < tokens.size(); i++) {
-               //     printf("%d ", tokens[i]);
+               // for (int i = 0; i < chunk_tokens.size(); i++) {
+               //     printf("%d ", chunk_tokens[i]);
                // }
                // printf("\n");
            }

-           cond_stage_model->compute(n_threads, input_ids, input_ids2, max_token_idx, false, &hidden_states, work_ctx);
-           if (version == VERSION_XL) {
+           cond_stage_model->compute(n_threads, input_ids, input_ids2, max_token_idx, false, &chunk_hidden_states, work_ctx);
+           if (version == VERSION_XL && chunk_idx == 0) {
                cond_stage_model->compute(n_threads, input_ids, input_ids2, max_token_idx, true, &pooled, work_ctx);
            }
            // if (pooled != NULL) {
-           //     print_ggml_tensor(hidden_states);
+           //     print_ggml_tensor(chunk_hidden_states);
            //     print_ggml_tensor(pooled);
            // }

            int64_t t1 = ggml_time_ms();
            LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
-           ggml_tensor* result = ggml_dup_tensor(work_ctx, hidden_states);
+           ggml_tensor* result = ggml_dup_tensor(work_ctx, chunk_hidden_states);
            {
-               float original_mean = ggml_tensor_mean(hidden_states);
-               for (int i2 = 0; i2 < hidden_states->ne[2]; i2++) {
-                   for (int i1 = 0; i1 < hidden_states->ne[1]; i1++) {
-                       for (int i0 = 0; i0 < hidden_states->ne[0]; i0++) {
-                           float value = ggml_tensor_get_f32(hidden_states, i0, i1, i2);
-                           value *= weights[i1];
+               float original_mean = ggml_tensor_mean(chunk_hidden_states);
+               for (int i2 = 0; i2 < chunk_hidden_states->ne[2]; i2++) {
+                   for (int i1 = 0; i1 < chunk_hidden_states->ne[1]; i1++) {
+                       for (int i0 = 0; i0 < chunk_hidden_states->ne[0]; i0++) {
+                           float value = ggml_tensor_get_f32(chunk_hidden_states, i0, i1, i2);
+                           value *= chunk_weights[i1];
                            ggml_tensor_set_f32(result, value, i0, i1, i2);
                        }
                    }
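Note: the conditioning pass now runs the text encoder once per 77-token chunk. chunk_count floor-divides, which is exact because the tokenize() change above always pads the token list to a multiple of 77 (assuming it is called with padding enabled, as it appears to be here). Two quieter fixes ride along: the per-token weighting indexes chunk_weights rather than the whole-prompt weights vector, which would be misaligned for every chunk after the first, and the SDXL pooled embedding is computed from chunk 0 only, since downstream code expects a single pooled vector.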
@@ -509,6 +519,14 @@ public:
                    vec[i] = 0;
                }
            }
+           hidden_states_vec.insert(hidden_states_vec.end(), (float*)result->data, ((float*)result->data) + ggml_nelements(result));
+       }
+
+       hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
+       hidden_states = ggml_reshape_2d(work_ctx,
+                                       hidden_states,
+                                       chunk_hidden_states->ne[0],
+                                       ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);

        ggml_tensor* vec = NULL;
        if (version == VERSION_XL) {
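Note: each chunk's weighted hidden states are appended to hidden_states_vec, and the flat buffer is reshaped to [hidden_size, chunk_count * 77] at the end. A standalone sketch with plain vectors in place of ggml tensors (hidden_size = 768 is SD1.x's CLIP hidden size, used here for illustration):

    #include <cstdio>
    #include <vector>

    int main() {
        const size_t hidden_size = 768;
        const size_t chunk_len = 77, chunk_count = 2;

        std::vector<float> hidden_states_vec;
        for (size_t c = 0; c < chunk_count; c++) {
            // stand-in for one chunk's [chunk_len, hidden_size] hidden states
            std::vector<float> chunk(hidden_size * chunk_len, (float)c);
            hidden_states_vec.insert(hidden_states_vec.end(), chunk.begin(), chunk.end());
        }

        // equivalent of ggml_reshape_2d(ctx, t, hidden_size, nelements / hidden_size)
        size_t n_token = hidden_states_vec.size() / hidden_size;
        std::printf("final shape: [%zu, %zu]\n", hidden_size, n_token);  // [768, 154]
        return 0;
    }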
@@ -547,7 +565,7 @@ public:
            GGML_ASSERT(offset == ggml_nbytes(vec));
        }
        // print_ggml_tensor(result);
-       return {result, vec};
+       return {hidden_states, vec};
    }

    std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> get_svd_condition(ggml_context* work_ctx,

util.cpp (4 changed lines)
@@ -266,14 +266,14 @@ void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...) {
        level_str = "ERROR";
    }

-   static char log_buffer[LOG_BUFFER_SIZE];
+   static char log_buffer[LOG_BUFFER_SIZE + 1];

    int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "[%s] %s:%-4d - ", level_str, sd_basename(file).c_str(), line);

    if (written >= 0 && written < LOG_BUFFER_SIZE) {
        vsnprintf(log_buffer + written, LOG_BUFFER_SIZE - written, format, args);
-       strncat(log_buffer, "\n", LOG_BUFFER_SIZE - strlen(log_buffer) - 1);
    }
+   strncat(log_buffer, "\n", LOG_BUFFER_SIZE - strlen(log_buffer));

    if (sd_log_cb) {
        sd_log_cb(level, log_buffer, sd_log_cb_data);
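Note: the util.cpp change is an off-by-one fix. vsnprintf can fill the buffer to LOG_BUFFER_SIZE - 1 characters plus the terminating NUL; appending "\n" afterwards then needs one more byte, which the + 1 provides (strncat writes up to n characters plus a NUL). Moving the strncat out of the if also guarantees a trailing newline even when the formatted message was truncated. A minimal sketch of the worst case, with a toy 8-byte limit standing in for LOG_BUFFER_SIZE:

    #include <cstdio>
    #include <cstring>

    int main() {
        const size_t LOG_BUFFER_SIZE = 8;  // toy value for illustration
        char log_buffer[LOG_BUFFER_SIZE + 1];
        std::snprintf(log_buffer, LOG_BUFFER_SIZE, "123456789");  // truncated to "1234567"
        std::strncat(log_buffer, "\n", LOG_BUFFER_SIZE - std::strlen(log_buffer));
        // '\n' lands at index 7 and the NUL at index 8 -- the extra byte.
        std::printf("%zu\n", std::strlen(log_buffer));  // prints 8
        return 0;
    }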