fix: improve clip text_projection support (#397)

commit 9b1d90bc23 (parent 65fa646684)
Author: stduhpf
Date: 2024-11-23 04:19:27 +01:00 (committed by GitHub)
2 changed files with 37 additions and 44 deletions
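
In short: clip.hpp used to multiply the pooled hidden state by a transposed `text_projection` unconditionally, so the SD3 and Flux conditioners had to stub their pooled embeddings with zero-filled tensors whenever a checkpoint shipped without that matrix (see the removed TODO comments below). The projection is now applied via `ggml_nn_linear` only when the tensor is present, with an identity fallback and a debug log otherwise, which lets the previously commented-out pooled `compute` calls in conditioner.hpp be restored.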

clip.hpp

@@ -711,8 +711,12 @@ public:
         if (return_pooled) {
             auto text_projection = params["text_projection"];
             ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
-            pooled = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, text_projection)), pooled);
-            return pooled;
+            if (text_projection != NULL) {
+                pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
+            } else {
+                LOG_DEBUG("Missing text_projection matrix, assuming identity...");
+            }
+            return pooled;  // [hidden_size, 1, 1]
         }
         return x;  // [N, n_token, hidden_size]
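
For context, `ggml_nn_linear` is a small helper from the project's ggml_extend.hpp; a minimal sketch of it, reproduced here as an assumption rather than verbatim from the header:

// Sketch of the ggml_nn_linear helper (assumed definition; see
// ggml_extend.hpp for the authoritative one). With b == NULL it is a
// plain matrix product: ggml_mul_mat applies w to x directly, which is
// why the explicit ggml_transpose from the old call above is dropped.
static struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
                                          struct ggml_tensor* x,
                                          struct ggml_tensor* w,
                                          struct ggml_tensor* b) {
    x = ggml_mul_mat(ctx, w, x);  // project x through w
    if (b != NULL) {
        x = ggml_add(ctx, x, b);  // optional bias
    }
    return x;
}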

conditioner.hpp

@@ -798,21 +798,17 @@ struct SD3CLIPEmbedder : public Conditioner {
                 }
                 if (chunk_idx == 0) {
-                    // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
-                    // max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
-                    // clip_l->compute(n_threads,
-                    //                 input_ids,
-                    //                 0,
-                    //                 NULL,
-                    //                 max_token_idx,
-                    //                 true,
-                    //                 &pooled_l,
-                    //                 work_ctx);
-                    // clip_l.transformer.text_model.text_projection no in file, ignore
-                    // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
-                    pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
-                    ggml_set_f32(pooled_l, 0.f);
+                    auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
+                    max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+                    clip_l->compute(n_threads,
+                                    input_ids,
+                                    0,
+                                    NULL,
+                                    max_token_idx,
+                                    true,
+                                    &pooled_l,
+                                    work_ctx);
                 }
             }
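
The two restored index lines implement EOS-token pooling: find the first end-of-text token in the chunk, clamp to the last position if none is present, then ask `compute` for the pooled embedding at that index. A standalone illustration (the prompt token ids are made up; 49406/49407 are the usual CLIP BOS/EOS ids). The same pattern is restored below for `clip_g` and for the Flux embedder's `clip_l`.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int EOS_TOKEN_ID = 49407;  // CLIP end-of-text token
    // A padded token chunk: BOS, two prompt tokens, then EOS padding.
    std::vector<int> chunk_tokens = {49406, 320, 1125, 49407, 49407, 49407};
    // Find the first EOS; clamp to the last index in case it is absent.
    auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), EOS_TOKEN_ID);
    size_t max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it),
                                            chunk_tokens.size() - 1);
    printf("pooled output taken from token index %zu\n", max_token_idx);  // prints 3
    return 0;
}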
@@ -852,21 +852,17 @@ struct SD3CLIPEmbedder : public Conditioner {
                 }
                 if (chunk_idx == 0) {
-                    // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
-                    // max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
-                    // clip_g->compute(n_threads,
-                    //                 input_ids,
-                    //                 0,
-                    //                 NULL,
-                    //                 max_token_idx,
-                    //                 true,
-                    //                 &pooled_g,
-                    //                 work_ctx);
-                    // clip_l.transformer.text_model.text_projection no in file, ignore pooled_g too
-                    // TODO: fix pooled_g
-                    pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280);
-                    ggml_set_f32(pooled_g, 0.f);
+                    auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
+                    max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+                    clip_g->compute(n_threads,
+                                    input_ids,
+                                    0,
+                                    NULL,
+                                    max_token_idx,
+                                    true,
+                                    &pooled_g,
+                                    work_ctx);
                 }
             }
@@ -1104,21 +1096,18 @@ struct FluxCLIPEmbedder : public Conditioner {
                 auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
                 size_t max_token_idx = 0;
-                // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
-                // max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
-                // clip_l->compute(n_threads,
-                //                 input_ids,
-                //                 0,
-                //                 NULL,
-                //                 max_token_idx,
-                //                 true,
-                //                 &pooled,
-                //                 work_ctx);
-                // clip_l.transformer.text_model.text_projection no in file, ignore
-                // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
-                pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
-                ggml_set_f32(pooled, 0.f);
+                auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
+                max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+
+                clip_l->compute(n_threads,
+                                input_ids,
+                                0,
+                                NULL,
+                                max_token_idx,
+                                true,
+                                &pooled,
+                                work_ctx);
             }
             // t5
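
For readers following along, here is the restored `compute` call with per-argument notes; the meanings are inferred from the surrounding code, and the parameter names in the real header may differ:

clip_l->compute(n_threads,
                input_ids,      // token ids for this chunk
                0,              // no custom (textual-inversion) embeddings
                NULL,           // no custom embedding data
                max_token_idx,  // EOS position used for pooling
                true,           // return_pooled: take the pooled EOS hidden state
                &pooled,        // output tensor, [hidden_size]
                work_ctx);      // ggml context that owns the output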