fix: support sdxl embeddings (#621)

Author: stduhpf (committed by GitHub)
Date: 2025-03-09 05:21:23 +01:00
Parent: 30b3ac8e62
Commit: 3fb275a67b

@@ -51,7 +51,8 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
     std::string trigger_word = "img"; // should be user settable
     std::string embd_dir;
-    int32_t num_custom_embeddings = 0;
+    int32_t num_custom_embeddings   = 0;
+    int32_t num_custom_embeddings_2 = 0;
     std::vector<uint8_t> token_embed_custom;
     std::vector<std::string> readed_embeddings;
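
SDXL pairs the first CLIP text encoder with a second, larger one (text_model2 in this file), so custom embeddings loaded for that encoder get their own counter, num_custom_embeddings_2. The hunk below routes each loaded embedding tensor to whichever encoder matches its hidden size; a standalone sketch of that routing follows the second hunk.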
@@ -131,28 +132,55 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
         params.no_alloc = false;
         struct ggml_context* embd_ctx = ggml_init(params);
         struct ggml_tensor* embd = NULL;
-        int64_t hidden_size = text_model->model.hidden_size;
+        struct ggml_tensor* embd2 = NULL;
         auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
-            if (tensor_storage.ne[0] != hidden_size) {
-                LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size);
-                return false;
+            if (tensor_storage.ne[0] != text_model->model.hidden_size) {
+                if (text_model2) {
+                    if (tensor_storage.ne[0] == text_model2->model.hidden_size) {
+                        embd2 = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, text_model2->model.hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
+                        *dst_tensor = embd2;
+                    } else {
+                        LOG_DEBUG("embedding wrong hidden size, got %i, expected %i or %i", tensor_storage.ne[0], text_model->model.hidden_size, text_model2->model.hidden_size);
+                        return false;
+                    }
+                } else {
+                    LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], text_model->model.hidden_size);
+                    return false;
+                }
+            } else {
+                embd = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, text_model->model.hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
+                *dst_tensor = embd;
             }
-            embd = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
-            *dst_tensor = embd;
             return true;
         };
         model_loader.load_tensors(on_load, NULL);
         readed_embeddings.push_back(embd_name);
-        token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
-        memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(embd->type)),
-               embd->data,
-               ggml_nbytes(embd));
-        for (int i = 0; i < embd->ne[1]; i++) {
-            bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings);
-            // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
-            num_custom_embeddings++;
+        if (embd) {
+            int64_t hidden_size = text_model->model.hidden_size;
+            token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
+            memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(embd->type)),
+                   embd->data,
+                   ggml_nbytes(embd));
+            for (int i = 0; i < embd->ne[1]; i++) {
+                bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings);
+                // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
+                num_custom_embeddings++;
+            }
+            LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings);
+        }
+        if (embd2) {
+            int64_t hidden_size = text_model2->model.hidden_size;
+            token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd2));
+            memcpy((void*)(token_embed_custom.data() + num_custom_embeddings_2 * hidden_size * ggml_type_size(embd2->type)),
+                   embd2->data,
+                   ggml_nbytes(embd2));
+            for (int i = 0; i < embd2->ne[1]; i++) {
+                bpe_tokens.push_back(text_model2->model.vocab_size + num_custom_embeddings_2);
+                // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
+                num_custom_embeddings_2++;
+            }
+            LOG_DEBUG("embedding '%s' applied, custom embeddings: %i (text model 2)", embd_name.c_str(), num_custom_embeddings_2);
         }
-        LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings);
         return true;
     }
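
For reference, a minimal standalone sketch of the routing logic this hunk introduces. The names (EncoderSlot, apply_embedding) are hypothetical and plain std::vector buffers stand in for ggml tensors; the hidden sizes 768/1280 and vocab size 49408 are the usual CLIP-L / OpenCLIP-G values for SDXL, not values taken from this diff.

// Hypothetical, simplified sketch of the dispatch above; not the project's API.
#include <cstdint>
#include <cstdio>
#include <vector>

struct EncoderSlot {                 // stand-in for one CLIP text model
    int64_t hidden_size;             // 768 for CLIP-L, 1280 for OpenCLIP-G (SDXL)
    int32_t vocab_size;              // custom tokens get ids >= vocab_size
    int32_t num_custom;              // like num_custom_embeddings / num_custom_embeddings_2
    std::vector<float> custom_rows;  // like token_embed_custom, one encoder's slice
};

// Append the embedding rows to whichever encoder matches `width`, and return
// the new custom token ids (empty on a hidden-size mismatch).
std::vector<int32_t> apply_embedding(const std::vector<float>& rows,
                                     int64_t width, int64_t n_tokens,
                                     EncoderSlot& enc1, EncoderSlot* enc2) {
    EncoderSlot* target = NULL;
    if (width == enc1.hidden_size) {
        target = &enc1;
    } else if (enc2 && width == enc2->hidden_size) {
        target = enc2;
    } else {
        std::fprintf(stderr, "embedding wrong hidden size: %lld\n", (long long)width);
        return std::vector<int32_t>();
    }
    target->custom_rows.insert(target->custom_rows.end(), rows.begin(), rows.end());
    std::vector<int32_t> ids;
    for (int64_t i = 0; i < n_tokens; i++) {
        ids.push_back(target->vocab_size + target->num_custom++);
    }
    return ids;
}

int main() {
    EncoderSlot clip_l = {768, 49408, 0, std::vector<float>()};
    EncoderSlot clip_g = {1280, 49408, 0, std::vector<float>()};
    std::vector<float> rows(2 * 1280, 0.0f);  // fake 2-token embedding for the second encoder
    std::vector<int32_t> ids = apply_embedding(rows, 1280, 2, clip_l, &clip_g);
    for (size_t i = 0; i < ids.size(); i++) {
        std::printf("custom token id: %d\n", ids[i]);
    }
    return 0;
}

As in the patch, each encoder keeps its own counter so the byte offsets into its custom-embedding buffer and the token ids handed to the tokenizer stay independent.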