diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 035f088..6a187ca 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -52,6 +52,71 @@ #define __STATIC_INLINE__ static inline #endif +// n-mode trensor-matrix product +// example: 2-mode product +// A: [ne03, k, ne01, ne00] +// B: k rows, m columns => [k, m] +// result is [ne03, m, ne01, ne00] +__STATIC_INLINE__ struct ggml_tensor* ggml_mul_n_mode(struct ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b, int mode = 0) { + // reshape A + // swap 0th and nth axis + a = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0)); + int ne1 = a->ne[1]; + int ne2 = a->ne[2]; + int ne3 = a->ne[3]; + // make 2D + a = ggml_cont(ctx, ggml_reshape_2d(ctx, a, a->ne[0], (ne3 * ne2 * ne1))); + + struct ggml_tensor* result = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, a, b))); + + // reshape output (same shape as a after permutation except first dim) + result = ggml_reshape_4d(ctx, result, result->ne[0], ne1, ne2, ne3); + // swap back 0th and nth axis + result = ggml_permute(ctx, result, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0); + return result; +} + +__STATIC_INLINE__ struct ggml_tensor* ggml_merge_lora(ggml_context* ctx, struct ggml_tensor* lora_down, struct ggml_tensor* lora_up, struct ggml_tensor* lora_mid = NULL) { + struct ggml_tensor* updown; + // flat lora tensors to multiply it + int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1]; + lora_up = ggml_reshape_2d(ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows); + auto lora_down_n_dims = ggml_n_dims(lora_down); + // assume n_dims should always be a multiple of 2 (otherwise rank 1 doesn't work) + lora_down_n_dims = (lora_down_n_dims + lora_down_n_dims % 2); + int64_t lora_down_rows = lora_down->ne[lora_down_n_dims - 1]; + lora_down = ggml_reshape_2d(ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows); + + // ggml_mul_mat requires tensor b transposed + lora_down = ggml_cont(ctx, ggml_transpose(ctx, lora_down)); + if (lora_mid == NULL) { + updown = ggml_mul_mat(ctx, lora_up, lora_down); + updown = ggml_cont(ctx, ggml_transpose(ctx, updown)); + } else { + // undoing tucker decomposition for conv layers. + // lora_mid has shape (3, 3, Rank, Rank) + // lora_down has shape (Rank, In, 1, 1) + // lora_up has shape (Rank, Out, 1, 1) + // conv layer shape is (3, 3, Out, In) + updown = ggml_mul_n_mode(ctx, ggml_mul_n_mode(ctx, lora_mid, lora_down, 3), lora_up, 2); + updown = ggml_cont(ctx, updown); + } + return updown; +} + +// Kronecker product +// [ne03,ne02,ne01,ne00] x [ne13,ne12,ne11,ne10] => [ne03*ne13,ne02*ne12,ne01*ne11,ne00*ne10] +__STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b) { + return ggml_mul(ctx, + ggml_upscale_ext(ctx, + a, + a->ne[0] * b->ne[0], + a->ne[1] * b->ne[1], + a->ne[2] * b->ne[2], + a->ne[3] * b->ne[3]), + b); +} + __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) { (void)level; (void)user_data; @@ -319,7 +384,7 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data, for (int iy = 0; iy < height; iy++) { float m = ggml_tensor_get_f32(mask, ix, iy); for (int k = 0; k < channels; k++) { - float value = ((float)(m < 254.5/255)) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5; + float value = ((float)(m < 254.5 / 255)) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5; ggml_tensor_set_f32(output, value, ix, iy, k); } } @@ -987,8 +1052,8 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) { } /* SDXL with LoRA requires more space */ -#define MAX_PARAMS_TENSOR_NUM 15360 -#define MAX_GRAPH_SIZE 15360 +#define MAX_PARAMS_TENSOR_NUM 32768 +#define MAX_GRAPH_SIZE 32768 struct GGMLRunner { protected: diff --git a/lora.hpp b/lora.hpp index ea1d03e..d38c711 100644 --- a/lora.hpp +++ b/lora.hpp @@ -197,6 +197,10 @@ struct LoraModel : public GGMLRunner { blk_name.replace(blk_name.find(".joint_blocks"), sizeof(".joint_blocks") - 1, ".transformer_blocks"); } + if (blk_name.find("text_encoders.clip_l") != std::string::npos) { + blk_name.replace(blk_name.find("text_encoders.clip_l"), sizeof("text_encoders.clip_l") - 1, "cond_stage_model"); + } + for (const auto& item : alt_names) { size_t match = blk_name.find(item.first); if (match != std::string::npos) { @@ -217,13 +221,17 @@ struct LoraModel : public GGMLRunner { keys.push_back(split_blk); } } + keys.push_back(blk_name); } - keys.push_back(blk_name); std::vector ret; for (std::string& key : keys) { ret.push_back(key); replace_all_chars(key, '.', '_'); + // fix for some sdxl lora, like lcm-lora-xl + if (key == "model_diffusion_model_output_blocks_2_2_conv") { + ret.push_back("model_diffusion_model_output_blocks_2_1_conv"); + } ret.push_back(key); } return ret; @@ -244,390 +252,545 @@ struct LoraModel : public GGMLRunner { std::vector keys = to_lora_keys(k_tensor, version); if (keys.size() == 0) continue; - ggml_tensor* lora_up = NULL; - ggml_tensor* lora_down = NULL; + for (auto& key : keys) { - std::string alpha_name = ""; - std::string scale_name = ""; - std::string split_q_scale_name = ""; - std::string lora_down_name = ""; - std::string lora_up_name = ""; - - if (starts_with(key, "SPLIT|")) { + bool is_qkv_split = starts_with(key, "SPLIT|"); + if (is_qkv_split) { key = key.substr(sizeof("SPLIT|") - 1); - // TODO: Handle alphas - std::string suffix = ""; - auto split_q_d_name = lora_pre[type] + key + "q" + suffix + lora_downs[type] + ".weight"; - - if (lora_tensors.find(split_q_d_name) == lora_tensors.end()) { - suffix = "_proj"; - split_q_d_name = lora_pre[type] + key + "q" + suffix + lora_downs[type] + ".weight"; - } - if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { - // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1] - // find qkv and mlp up parts in LoRA model - auto split_k_d_name = lora_pre[type] + key + "k" + suffix + lora_downs[type] + ".weight"; - auto split_v_d_name = lora_pre[type] + key + "v" + suffix + lora_downs[type] + ".weight"; - - auto split_q_u_name = lora_pre[type] + key + "q" + suffix + lora_ups[type] + ".weight"; - auto split_k_u_name = lora_pre[type] + key + "k" + suffix + lora_ups[type] + ".weight"; - auto split_v_u_name = lora_pre[type] + key + "v" + suffix + lora_ups[type] + ".weight"; - - auto split_q_scale_name = lora_pre[type] + key + "q" + suffix + ".scale"; - auto split_k_scale_name = lora_pre[type] + key + "k" + suffix + ".scale"; - auto split_v_scale_name = lora_pre[type] + key + "v" + suffix + ".scale"; - - auto split_q_alpha_name = lora_pre[type] + key + "q" + suffix + ".alpha"; - auto split_k_alpha_name = lora_pre[type] + key + "k" + suffix + ".alpha"; - auto split_v_alpha_name = lora_pre[type] + key + "v" + suffix + ".alpha"; - - ggml_tensor* lora_q_down = NULL; - ggml_tensor* lora_q_up = NULL; - ggml_tensor* lora_k_down = NULL; - ggml_tensor* lora_k_up = NULL; - ggml_tensor* lora_v_down = NULL; - ggml_tensor* lora_v_up = NULL; - - lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]); - - if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) { - lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); - } - - if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) { - lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]); - } - - if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) { - lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]); - } - - if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) { - lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]); - } - - if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) { - lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]); - } - - float q_rank = lora_q_up->ne[0]; - float k_rank = lora_k_up->ne[0]; - float v_rank = lora_v_up->ne[0]; - - float lora_q_scale = 1; - float lora_k_scale = 1; - float lora_v_scale = 1; - - if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) { - lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]); - applied_lora_tensors.insert(split_q_scale_name); - } - if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) { - lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]); - applied_lora_tensors.insert(split_k_scale_name); - } - if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) { - lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]); - applied_lora_tensors.insert(split_v_scale_name); - } - - if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) { - float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]); - applied_lora_tensors.insert(split_q_alpha_name); - lora_q_scale = lora_q_alpha / q_rank; - } - if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) { - float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]); - applied_lora_tensors.insert(split_k_alpha_name); - lora_k_scale = lora_k_alpha / k_rank; - } - if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) { - float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]); - applied_lora_tensors.insert(split_v_alpha_name); - lora_v_scale = lora_v_alpha / v_rank; - } - - ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale); - ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale); - ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale); - - // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1] - // print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1] - // print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1] - - // these need to be stitched together this way: - // |q_up,0 ,0 | - // |0 ,k_up,0 | - // |0 ,0 ,v_up| - // (q_down,k_down,v_down) . (q ,k ,v) - - // up_concat will be [9216, R*3, 1, 1] - // down_concat will be [R*3, 3072, 1, 1] - ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), lora_v_down, 1); - - ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up); - ggml_scale(compute_ctx, z, 0); - ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1); - - ggml_tensor* q_up = ggml_concat(compute_ctx, lora_q_up, zz, 1); - ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), z, 1); - ggml_tensor* v_up = ggml_concat(compute_ctx, zz, lora_v_up, 1); - // print_ggml_tensor(q_up, true); //[R, 9216, 1, 1] - // print_ggml_tensor(k_up, true); //[R, 9216, 1, 1] - // print_ggml_tensor(v_up, true); //[R, 9216, 1, 1] - ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), v_up, 0); - // print_ggml_tensor(lora_up_concat, true); //[R*3, 9216, 1, 1] - - lora_down = ggml_cont(compute_ctx, lora_down_concat); - lora_up = ggml_cont(compute_ctx, lora_up_concat); - - applied_lora_tensors.insert(split_q_u_name); - applied_lora_tensors.insert(split_k_u_name); - applied_lora_tensors.insert(split_v_u_name); - - applied_lora_tensors.insert(split_q_d_name); - applied_lora_tensors.insert(split_k_d_name); - applied_lora_tensors.insert(split_v_d_name); - } } - if (starts_with(key, "SPLIT_L|")) { + bool is_qkvm_split = starts_with(key, "SPLIT_L|"); + if (is_qkvm_split) { key = key.substr(sizeof("SPLIT_L|") - 1); - - auto split_q_d_name = lora_pre[type] + key + "attn.to_q" + lora_downs[type] + ".weight"; - if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { - // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1] - // find qkv and mlp up parts in LoRA model - auto split_k_d_name = lora_pre[type] + key + "attn.to_k" + lora_downs[type] + ".weight"; - auto split_v_d_name = lora_pre[type] + key + "attn.to_v" + lora_downs[type] + ".weight"; - - auto split_q_u_name = lora_pre[type] + key + "attn.to_q" + lora_ups[type] + ".weight"; - auto split_k_u_name = lora_pre[type] + key + "attn.to_k" + lora_ups[type] + ".weight"; - auto split_v_u_name = lora_pre[type] + key + "attn.to_v" + lora_ups[type] + ".weight"; - - auto split_m_d_name = lora_pre[type] + key + "proj_mlp" + lora_downs[type] + ".weight"; - auto split_m_u_name = lora_pre[type] + key + "proj_mlp" + lora_ups[type] + ".weight"; - - auto split_q_scale_name = lora_pre[type] + key + "attn.to_q" + ".scale"; - auto split_k_scale_name = lora_pre[type] + key + "attn.to_k" + ".scale"; - auto split_v_scale_name = lora_pre[type] + key + "attn.to_v" + ".scale"; - auto split_m_scale_name = lora_pre[type] + key + "proj_mlp" + ".scale"; - - auto split_q_alpha_name = lora_pre[type] + key + "attn.to_q" + ".alpha"; - auto split_k_alpha_name = lora_pre[type] + key + "attn.to_k" + ".alpha"; - auto split_v_alpha_name = lora_pre[type] + key + "attn.to_v" + ".alpha"; - auto split_m_alpha_name = lora_pre[type] + key + "proj_mlp" + ".alpha"; - - ggml_tensor* lora_q_down = NULL; - ggml_tensor* lora_q_up = NULL; - ggml_tensor* lora_k_down = NULL; - ggml_tensor* lora_k_up = NULL; - ggml_tensor* lora_v_down = NULL; - ggml_tensor* lora_v_up = NULL; - - ggml_tensor* lora_m_down = NULL; - ggml_tensor* lora_m_up = NULL; - - lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); - - if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { - lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]); - } - - if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) { - lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); - } - - if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) { - lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]); - } - - if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) { - lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]); - } - - if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) { - lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]); - } - - if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) { - lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]); - } - - if (lora_tensors.find(split_m_d_name) != lora_tensors.end()) { - lora_m_down = to_f32(compute_ctx, lora_tensors[split_m_d_name]); - } - - if (lora_tensors.find(split_m_u_name) != lora_tensors.end()) { - lora_m_up = to_f32(compute_ctx, lora_tensors[split_m_u_name]); - } - - float q_rank = lora_q_up->ne[0]; - float k_rank = lora_k_up->ne[0]; - float v_rank = lora_v_up->ne[0]; - float m_rank = lora_v_up->ne[0]; - - float lora_q_scale = 1; - float lora_k_scale = 1; - float lora_v_scale = 1; - float lora_m_scale = 1; - - if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) { - lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]); - applied_lora_tensors.insert(split_q_scale_name); - } - if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) { - lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]); - applied_lora_tensors.insert(split_k_scale_name); - } - if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) { - lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]); - applied_lora_tensors.insert(split_v_scale_name); - } - if (lora_tensors.find(split_m_scale_name) != lora_tensors.end()) { - lora_m_scale = ggml_backend_tensor_get_f32(lora_tensors[split_m_scale_name]); - applied_lora_tensors.insert(split_m_scale_name); - } - - if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) { - float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]); - applied_lora_tensors.insert(split_q_alpha_name); - lora_q_scale = lora_q_alpha / q_rank; - } - if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) { - float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]); - applied_lora_tensors.insert(split_k_alpha_name); - lora_k_scale = lora_k_alpha / k_rank; - } - if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) { - float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]); - applied_lora_tensors.insert(split_v_alpha_name); - lora_v_scale = lora_v_alpha / v_rank; - } - if (lora_tensors.find(split_m_alpha_name) != lora_tensors.end()) { - float lora_m_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_m_alpha_name]); - applied_lora_tensors.insert(split_m_alpha_name); - lora_m_scale = lora_m_alpha / m_rank; - } - - ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale); - ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale); - ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale); - ggml_scale_inplace(compute_ctx, lora_m_down, lora_m_scale); - - // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_m_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1] - // print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1] - // print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1] - // print_ggml_tensor(lora_m_up, true); //[R, 12288, 1, 1] - - // these need to be stitched together this way: - // |q_up,0 ,0 ,0 | - // |0 ,k_up,0 ,0 | - // |0 ,0 ,v_up,0 | - // |0 ,0 ,0 ,m_up| - // (q_down,k_down,v_down,m_down) . (q ,k ,v ,m) - - // up_concat will be [21504, R*4, 1, 1] - // down_concat will be [R*4, 3072, 1, 1] - - ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), ggml_concat(compute_ctx, lora_v_down, lora_m_down, 1), 1); - // print_ggml_tensor(lora_down_concat, true); //[3072, R*4, 1, 1] - - // this also means that if rank is bigger than 672, it is less memory efficient to do it this way (should be fine) - // print_ggml_tensor(lora_q_up, true); //[3072, R, 1, 1] - ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up); - ggml_tensor* mlp_z = ggml_dup_tensor(compute_ctx, lora_m_up); - ggml_scale(compute_ctx, z, 0); - ggml_scale(compute_ctx, mlp_z, 0); - ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1); - - ggml_tensor* q_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_up, zz, 1), mlp_z, 1); - ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), ggml_concat(compute_ctx, z, mlp_z, 1), 1); - ggml_tensor* v_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, lora_v_up, 1), mlp_z, 1); - ggml_tensor* m_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, z, 1), lora_m_up, 1); - // print_ggml_tensor(q_up, true); //[R, 21504, 1, 1] - // print_ggml_tensor(k_up, true); //[R, 21504, 1, 1] - // print_ggml_tensor(v_up, true); //[R, 21504, 1, 1] - // print_ggml_tensor(m_up, true); //[R, 21504, 1, 1] - - ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), ggml_concat(compute_ctx, v_up, m_up, 0), 0); - // print_ggml_tensor(lora_up_concat, true); //[R*4, 21504, 1, 1] - - lora_down = ggml_cont(compute_ctx, lora_down_concat); - lora_up = ggml_cont(compute_ctx, lora_up_concat); - - applied_lora_tensors.insert(split_q_u_name); - applied_lora_tensors.insert(split_k_u_name); - applied_lora_tensors.insert(split_v_u_name); - applied_lora_tensors.insert(split_m_u_name); - - applied_lora_tensors.insert(split_q_d_name); - applied_lora_tensors.insert(split_k_d_name); - applied_lora_tensors.insert(split_v_d_name); - applied_lora_tensors.insert(split_m_d_name); - } } - if (lora_up == NULL || lora_down == NULL) { - lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight"; - if (lora_tensors.find(lora_up_name) == lora_tensors.end()) { - if (key == "model_diffusion_model_output_blocks_2_2_conv") { - // fix for some sdxl lora, like lcm-lora-xl - key = "model_diffusion_model_output_blocks_2_1_conv"; - lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight"; - } + struct ggml_tensor* updown = NULL; + float scale_value = 1.0f; + std::string fk = lora_pre[type] + key; + if (lora_tensors.find(fk + ".hada_w1_a") != lora_tensors.end()) { + // LoHa mode + + // TODO: split qkv convention for LoHas (is it ever used?) + if (is_qkv_split || is_qkvm_split) { + LOG_ERROR("Split qkv isn't supported for LoHa models."); + break; + } + std::string alpha_name = ""; + + ggml_tensor* hada_1_mid = NULL; // tau for tucker decomposition + ggml_tensor* hada_1_up = NULL; + ggml_tensor* hada_1_down = NULL; + + ggml_tensor* hada_2_mid = NULL; // tau for tucker decomposition + ggml_tensor* hada_2_up = NULL; + ggml_tensor* hada_2_down = NULL; + + std::string hada_1_mid_name = ""; + std::string hada_1_down_name = ""; + std::string hada_1_up_name = ""; + + std::string hada_2_mid_name = ""; + std::string hada_2_down_name = ""; + std::string hada_2_up_name = ""; + + + hada_1_down_name = fk + ".hada_w1_b"; + hada_1_up_name = fk + ".hada_w1_a"; + hada_1_mid_name = fk + ".hada_t1"; + if (lora_tensors.find(hada_1_down_name) != lora_tensors.end()) { + hada_1_down = to_f32(compute_ctx, lora_tensors[hada_1_down_name]); + } + if (lora_tensors.find(hada_1_up_name) != lora_tensors.end()) { + hada_1_up = to_f32(compute_ctx, lora_tensors[hada_1_up_name]); + } + if (lora_tensors.find(hada_1_mid_name) != lora_tensors.end()) { + hada_1_mid = to_f32(compute_ctx, lora_tensors[hada_1_mid_name]); + applied_lora_tensors.insert(hada_1_mid_name); + hada_1_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_1_up)); } - lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight"; - alpha_name = lora_pre[type] + key + ".alpha"; - scale_name = lora_pre[type] + key + ".scale"; - - if (lora_tensors.find(lora_up_name) != lora_tensors.end()) { - lora_up = lora_tensors[lora_up_name]; + hada_2_down_name = fk + ".hada_w2_b"; + hada_2_up_name = fk + ".hada_w2_a"; + hada_2_mid_name = fk + ".hada_t2"; + if (lora_tensors.find(hada_2_down_name) != lora_tensors.end()) { + hada_2_down = to_f32(compute_ctx, lora_tensors[hada_2_down_name]); + } + if (lora_tensors.find(hada_2_up_name) != lora_tensors.end()) { + hada_2_up = to_f32(compute_ctx, lora_tensors[hada_2_up_name]); + } + if (lora_tensors.find(hada_2_mid_name) != lora_tensors.end()) { + hada_2_mid = to_f32(compute_ctx, lora_tensors[hada_2_mid_name]); + applied_lora_tensors.insert(hada_2_mid_name); + hada_2_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_2_up)); } - if (lora_tensors.find(lora_down_name) != lora_tensors.end()) { - lora_down = lora_tensors[lora_down_name]; - } - applied_lora_tensors.insert(lora_up_name); - applied_lora_tensors.insert(lora_down_name); + alpha_name = fk + ".alpha"; + + applied_lora_tensors.insert(hada_1_down_name); + applied_lora_tensors.insert(hada_1_up_name); + applied_lora_tensors.insert(hada_2_down_name); + applied_lora_tensors.insert(hada_2_up_name); + applied_lora_tensors.insert(alpha_name); - applied_lora_tensors.insert(scale_name); - } + if (hada_1_up == NULL || hada_1_down == NULL || hada_2_up == NULL || hada_2_down == NULL) { + continue; + } - if (lora_up == NULL || lora_down == NULL) { - continue; - } - // calc_scale - int64_t dim = lora_down->ne[ggml_n_dims(lora_down) - 1]; - float scale_value = 1.0f; - if (lora_tensors.find(scale_name) != lora_tensors.end()) { - scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]); - } else if (lora_tensors.find(alpha_name) != lora_tensors.end()) { - float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); - scale_value = alpha / dim; + struct ggml_tensor* updown_1 = ggml_merge_lora(compute_ctx, hada_1_down, hada_1_up, hada_1_mid); + struct ggml_tensor* updown_2 = ggml_merge_lora(compute_ctx, hada_2_down, hada_2_up, hada_2_mid); + updown = ggml_mul_inplace(compute_ctx, updown_1, updown_2); + + // calc_scale + // TODO: .dora_scale? + int64_t rank = hada_1_down->ne[ggml_n_dims(hada_1_down) - 1]; + if (lora_tensors.find(alpha_name) != lora_tensors.end()) { + float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); + scale_value = alpha / rank; + } + } else if (lora_tensors.find(fk + ".lokr_w1") != lora_tensors.end() || lora_tensors.find(fk + ".lokr_w1_a") != lora_tensors.end()) { + // LoKr mode + + // TODO: split qkv convention for LoKrs (is it ever used?) + if (is_qkv_split || is_qkvm_split) { + LOG_ERROR("Split qkv isn't supported for LoKr models."); + break; + } + + std::string alpha_name = fk + ".alpha"; + + ggml_tensor* lokr_w1 = NULL; + ggml_tensor* lokr_w2 = NULL; + + std::string lokr_w1_name = ""; + std::string lokr_w2_name = ""; + + lokr_w1_name = fk + ".lokr_w1"; + lokr_w2_name = fk + ".lokr_w2"; + + if (lora_tensors.find(lokr_w1_name) != lora_tensors.end()) { + lokr_w1 = to_f32(compute_ctx, lora_tensors[lokr_w1_name]); + applied_lora_tensors.insert(lokr_w1_name); + } else { + ggml_tensor* down = NULL; + ggml_tensor* up = NULL; + std::string down_name = lokr_w1_name + "_b"; + std::string up_name = lokr_w1_name + "_a"; + if (lora_tensors.find(down_name) != lora_tensors.end()) { + // w1 should not be low rank normally, sometimes w1 and w2 are swapped + down = to_f32(compute_ctx, lora_tensors[down_name]); + applied_lora_tensors.insert(down_name); + + int64_t rank = down->ne[ggml_n_dims(down) - 1]; + if (lora_tensors.find(alpha_name) != lora_tensors.end()) { + float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); + scale_value = alpha / rank; + } + } + if (lora_tensors.find(up_name) != lora_tensors.end()) { + up = to_f32(compute_ctx, lora_tensors[up_name]); + applied_lora_tensors.insert(up_name); + } + lokr_w1 = ggml_merge_lora(compute_ctx, down, up); + } + if (lora_tensors.find(lokr_w2_name) != lora_tensors.end()) { + lokr_w2 = to_f32(compute_ctx, lora_tensors[lokr_w2_name]); + applied_lora_tensors.insert(lokr_w2_name); + } else { + ggml_tensor* down = NULL; + ggml_tensor* up = NULL; + std::string down_name = lokr_w2_name + "_b"; + std::string up_name = lokr_w2_name + "_a"; + if (lora_tensors.find(down_name) != lora_tensors.end()) { + down = to_f32(compute_ctx, lora_tensors[down_name]); + applied_lora_tensors.insert(down_name); + + int64_t rank = down->ne[ggml_n_dims(down) - 1]; + if (lora_tensors.find(alpha_name) != lora_tensors.end()) { + float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); + scale_value = alpha / rank; + } + } + if (lora_tensors.find(up_name) != lora_tensors.end()) { + up = to_f32(compute_ctx, lora_tensors[up_name]); + applied_lora_tensors.insert(up_name); + } + lokr_w2 = ggml_merge_lora(compute_ctx, down, up); + } + + // Technically it might be unused, but I believe it's the expected behavior + applied_lora_tensors.insert(alpha_name); + + updown = ggml_kronecker(compute_ctx, lokr_w1, lokr_w2); + + } else { + // LoRA mode + ggml_tensor* lora_mid = NULL; // tau for tucker decomposition + ggml_tensor* lora_up = NULL; + ggml_tensor* lora_down = NULL; + + std::string alpha_name = ""; + std::string scale_name = ""; + std::string split_q_scale_name = ""; + std::string lora_mid_name = ""; + std::string lora_down_name = ""; + std::string lora_up_name = ""; + + if (is_qkv_split) { + std::string suffix = ""; + auto split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight"; + + if (lora_tensors.find(split_q_d_name) == lora_tensors.end()) { + suffix = "_proj"; + split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight"; + } + if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { + // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1] + // find qkv and mlp up parts in LoRA model + auto split_k_d_name = fk + "k" + suffix + lora_downs[type] + ".weight"; + auto split_v_d_name = fk + "v" + suffix + lora_downs[type] + ".weight"; + + auto split_q_u_name = fk + "q" + suffix + lora_ups[type] + ".weight"; + auto split_k_u_name = fk + "k" + suffix + lora_ups[type] + ".weight"; + auto split_v_u_name = fk + "v" + suffix + lora_ups[type] + ".weight"; + + auto split_q_scale_name = fk + "q" + suffix + ".scale"; + auto split_k_scale_name = fk + "k" + suffix + ".scale"; + auto split_v_scale_name = fk + "v" + suffix + ".scale"; + + auto split_q_alpha_name = fk + "q" + suffix + ".alpha"; + auto split_k_alpha_name = fk + "k" + suffix + ".alpha"; + auto split_v_alpha_name = fk + "v" + suffix + ".alpha"; + + ggml_tensor* lora_q_down = NULL; + ggml_tensor* lora_q_up = NULL; + ggml_tensor* lora_k_down = NULL; + ggml_tensor* lora_k_up = NULL; + ggml_tensor* lora_v_down = NULL; + ggml_tensor* lora_v_up = NULL; + + lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]); + + if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) { + lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); + } + + if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) { + lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]); + } + + if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) { + lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]); + } + + if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) { + lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]); + } + + if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) { + lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]); + } + + float q_rank = lora_q_up->ne[0]; + float k_rank = lora_k_up->ne[0]; + float v_rank = lora_v_up->ne[0]; + + float lora_q_scale = 1; + float lora_k_scale = 1; + float lora_v_scale = 1; + + if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) { + lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]); + applied_lora_tensors.insert(split_q_scale_name); + } + if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) { + lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]); + applied_lora_tensors.insert(split_k_scale_name); + } + if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) { + lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]); + applied_lora_tensors.insert(split_v_scale_name); + } + + if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) { + float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]); + applied_lora_tensors.insert(split_q_alpha_name); + lora_q_scale = lora_q_alpha / q_rank; + } + if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) { + float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]); + applied_lora_tensors.insert(split_k_alpha_name); + lora_k_scale = lora_k_alpha / k_rank; + } + if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) { + float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]); + applied_lora_tensors.insert(split_v_alpha_name); + lora_v_scale = lora_v_alpha / v_rank; + } + + ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale); + ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale); + ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale); + + // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1] + + // these need to be stitched together this way: + // |q_up,0 ,0 | + // |0 ,k_up,0 | + // |0 ,0 ,v_up| + // (q_down,k_down,v_down) . (q ,k ,v) + + // up_concat will be [9216, R*3, 1, 1] + // down_concat will be [R*3, 3072, 1, 1] + ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), lora_v_down, 1); + + ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up); + ggml_scale(compute_ctx, z, 0); + ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1); + + ggml_tensor* q_up = ggml_concat(compute_ctx, lora_q_up, zz, 1); + ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), z, 1); + ggml_tensor* v_up = ggml_concat(compute_ctx, zz, lora_v_up, 1); + // print_ggml_tensor(q_up, true); //[R, 9216, 1, 1] + // print_ggml_tensor(k_up, true); //[R, 9216, 1, 1] + // print_ggml_tensor(v_up, true); //[R, 9216, 1, 1] + ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), v_up, 0); + // print_ggml_tensor(lora_up_concat, true); //[R*3, 9216, 1, 1] + + lora_down = ggml_cont(compute_ctx, lora_down_concat); + lora_up = ggml_cont(compute_ctx, lora_up_concat); + + applied_lora_tensors.insert(split_q_u_name); + applied_lora_tensors.insert(split_k_u_name); + applied_lora_tensors.insert(split_v_u_name); + + applied_lora_tensors.insert(split_q_d_name); + applied_lora_tensors.insert(split_k_d_name); + applied_lora_tensors.insert(split_v_d_name); + } + } else if (is_qkvm_split) { + auto split_q_d_name = fk + "attn.to_q" + lora_downs[type] + ".weight"; + if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { + // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1] + // find qkv and mlp up parts in LoRA model + auto split_k_d_name = fk + "attn.to_k" + lora_downs[type] + ".weight"; + auto split_v_d_name = fk + "attn.to_v" + lora_downs[type] + ".weight"; + + auto split_q_u_name = fk + "attn.to_q" + lora_ups[type] + ".weight"; + auto split_k_u_name = fk + "attn.to_k" + lora_ups[type] + ".weight"; + auto split_v_u_name = fk + "attn.to_v" + lora_ups[type] + ".weight"; + + auto split_m_d_name = fk + "proj_mlp" + lora_downs[type] + ".weight"; + auto split_m_u_name = fk + "proj_mlp" + lora_ups[type] + ".weight"; + + auto split_q_scale_name = fk + "attn.to_q" + ".scale"; + auto split_k_scale_name = fk + "attn.to_k" + ".scale"; + auto split_v_scale_name = fk + "attn.to_v" + ".scale"; + auto split_m_scale_name = fk + "proj_mlp" + ".scale"; + + auto split_q_alpha_name = fk + "attn.to_q" + ".alpha"; + auto split_k_alpha_name = fk + "attn.to_k" + ".alpha"; + auto split_v_alpha_name = fk + "attn.to_v" + ".alpha"; + auto split_m_alpha_name = fk + "proj_mlp" + ".alpha"; + + ggml_tensor* lora_q_down = NULL; + ggml_tensor* lora_q_up = NULL; + ggml_tensor* lora_k_down = NULL; + ggml_tensor* lora_k_up = NULL; + ggml_tensor* lora_v_down = NULL; + ggml_tensor* lora_v_up = NULL; + + ggml_tensor* lora_m_down = NULL; + ggml_tensor* lora_m_up = NULL; + + lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); + + if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { + lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]); + } + + if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) { + lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); + } + + if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) { + lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]); + } + + if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) { + lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]); + } + + if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) { + lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]); + } + + if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) { + lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]); + } + + if (lora_tensors.find(split_m_d_name) != lora_tensors.end()) { + lora_m_down = to_f32(compute_ctx, lora_tensors[split_m_d_name]); + } + + if (lora_tensors.find(split_m_u_name) != lora_tensors.end()) { + lora_m_up = to_f32(compute_ctx, lora_tensors[split_m_u_name]); + } + + float q_rank = lora_q_up->ne[0]; + float k_rank = lora_k_up->ne[0]; + float v_rank = lora_v_up->ne[0]; + float m_rank = lora_v_up->ne[0]; + + float lora_q_scale = 1; + float lora_k_scale = 1; + float lora_v_scale = 1; + float lora_m_scale = 1; + + if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) { + lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]); + applied_lora_tensors.insert(split_q_scale_name); + } + if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) { + lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]); + applied_lora_tensors.insert(split_k_scale_name); + } + if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) { + lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]); + applied_lora_tensors.insert(split_v_scale_name); + } + if (lora_tensors.find(split_m_scale_name) != lora_tensors.end()) { + lora_m_scale = ggml_backend_tensor_get_f32(lora_tensors[split_m_scale_name]); + applied_lora_tensors.insert(split_m_scale_name); + } + + if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) { + float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]); + applied_lora_tensors.insert(split_q_alpha_name); + lora_q_scale = lora_q_alpha / q_rank; + } + if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) { + float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]); + applied_lora_tensors.insert(split_k_alpha_name); + lora_k_scale = lora_k_alpha / k_rank; + } + if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) { + float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]); + applied_lora_tensors.insert(split_v_alpha_name); + lora_v_scale = lora_v_alpha / v_rank; + } + if (lora_tensors.find(split_m_alpha_name) != lora_tensors.end()) { + float lora_m_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_m_alpha_name]); + applied_lora_tensors.insert(split_m_alpha_name); + lora_m_scale = lora_m_alpha / m_rank; + } + + ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale); + ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale); + ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale); + ggml_scale_inplace(compute_ctx, lora_m_down, lora_m_scale); + + // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_m_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_m_up, true); //[R, 12288, 1, 1] + + // these need to be stitched together this way: + // |q_up,0 ,0 ,0 | + // |0 ,k_up,0 ,0 | + // |0 ,0 ,v_up,0 | + // |0 ,0 ,0 ,m_up| + // (q_down,k_down,v_down,m_down) . (q ,k ,v ,m) + + // up_concat will be [21504, R*4, 1, 1] + // down_concat will be [R*4, 3072, 1, 1] + + ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), ggml_concat(compute_ctx, lora_v_down, lora_m_down, 1), 1); + // print_ggml_tensor(lora_down_concat, true); //[3072, R*4, 1, 1] + + // this also means that if rank is bigger than 672, it is less memory efficient to do it this way (should be fine) + // print_ggml_tensor(lora_q_up, true); //[3072, R, 1, 1] + ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up); + ggml_tensor* mlp_z = ggml_dup_tensor(compute_ctx, lora_m_up); + ggml_scale(compute_ctx, z, 0); + ggml_scale(compute_ctx, mlp_z, 0); + ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1); + + ggml_tensor* q_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_up, zz, 1), mlp_z, 1); + ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), ggml_concat(compute_ctx, z, mlp_z, 1), 1); + ggml_tensor* v_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, lora_v_up, 1), mlp_z, 1); + ggml_tensor* m_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, z, 1), lora_m_up, 1); + // print_ggml_tensor(q_up, true); //[R, 21504, 1, 1] + // print_ggml_tensor(k_up, true); //[R, 21504, 1, 1] + // print_ggml_tensor(v_up, true); //[R, 21504, 1, 1] + // print_ggml_tensor(m_up, true); //[R, 21504, 1, 1] + + ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), ggml_concat(compute_ctx, v_up, m_up, 0), 0); + // print_ggml_tensor(lora_up_concat, true); //[R*4, 21504, 1, 1] + + lora_down = ggml_cont(compute_ctx, lora_down_concat); + lora_up = ggml_cont(compute_ctx, lora_up_concat); + + applied_lora_tensors.insert(split_q_u_name); + applied_lora_tensors.insert(split_k_u_name); + applied_lora_tensors.insert(split_v_u_name); + applied_lora_tensors.insert(split_m_u_name); + + applied_lora_tensors.insert(split_q_d_name); + applied_lora_tensors.insert(split_k_d_name); + applied_lora_tensors.insert(split_v_d_name); + applied_lora_tensors.insert(split_m_d_name); + } + } else { + lora_up_name = fk + lora_ups[type] + ".weight"; + lora_down_name = fk + lora_downs[type] + ".weight"; + lora_mid_name = fk + ".lora_mid.weight"; + + alpha_name = fk + ".alpha"; + scale_name = fk + ".scale"; + + if (lora_tensors.find(lora_up_name) != lora_tensors.end()) { + lora_up = to_f32(compute_ctx, lora_tensors[lora_up_name]); + } + + if (lora_tensors.find(lora_down_name) != lora_tensors.end()) { + lora_down = to_f32(compute_ctx, lora_tensors[lora_down_name]); + } + + if (lora_tensors.find(lora_mid_name) != lora_tensors.end()) { + lora_mid = to_f32(compute_ctx, lora_tensors[lora_mid_name]); + applied_lora_tensors.insert(lora_mid_name); + } + + applied_lora_tensors.insert(lora_up_name); + applied_lora_tensors.insert(lora_down_name); + applied_lora_tensors.insert(alpha_name); + applied_lora_tensors.insert(scale_name); + } + + if (lora_up == NULL || lora_down == NULL) { + continue; + } + // calc_scale + // TODO: .dora_scale? + int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1]; + if (lora_tensors.find(scale_name) != lora_tensors.end()) { + scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]); + } else if (lora_tensors.find(alpha_name) != lora_tensors.end()) { + float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); + scale_value = alpha / rank; + } + + updown = ggml_merge_lora(compute_ctx, lora_down, lora_up, lora_mid); } scale_value *= multiplier; - - // flat lora tensors to multiply it - int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1]; - lora_up = ggml_reshape_2d(compute_ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows); - auto lora_down_n_dims = ggml_n_dims(lora_down); - // assume n_dims should always be a multiple of 2 (otherwise rank 1 doesn't work) - lora_down_n_dims = (lora_down_n_dims + lora_down_n_dims % 2); - int64_t lora_down_rows = lora_down->ne[lora_down_n_dims - 1]; - lora_down = ggml_reshape_2d(compute_ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows); - - // ggml_mul_mat requires tensor b transposed - lora_down = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, lora_down)); - struct ggml_tensor* updown = ggml_mul_mat(compute_ctx, lora_up, lora_down); - updown = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown)); - updown = ggml_reshape(compute_ctx, updown, weight); + updown = ggml_reshape(compute_ctx, updown, weight); GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight)); updown = ggml_scale_inplace(compute_ctx, updown, scale_value); ggml_tensor* final_weight;