feat: partial LyCORIS support (tucker decomposition for LoCon + LoHa + LoKr) (#577)
This commit is contained in:
parent
3753223982
commit
1be2491dcf
@ -52,6 +52,71 @@
|
||||
#define __STATIC_INLINE__ static inline
|
||||
#endif
|
||||
|
||||
// n-mode trensor-matrix product
|
||||
// example: 2-mode product
|
||||
// A: [ne03, k, ne01, ne00]
|
||||
// B: k rows, m columns => [k, m]
|
||||
// result is [ne03, m, ne01, ne00]
|
||||
__STATIC_INLINE__ struct ggml_tensor* ggml_mul_n_mode(struct ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b, int mode = 0) {
|
||||
// reshape A
|
||||
// swap 0th and nth axis
|
||||
a = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0));
|
||||
int ne1 = a->ne[1];
|
||||
int ne2 = a->ne[2];
|
||||
int ne3 = a->ne[3];
|
||||
// make 2D
|
||||
a = ggml_cont(ctx, ggml_reshape_2d(ctx, a, a->ne[0], (ne3 * ne2 * ne1)));
|
||||
|
||||
struct ggml_tensor* result = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, a, b)));
|
||||
|
||||
// reshape output (same shape as a after permutation except first dim)
|
||||
result = ggml_reshape_4d(ctx, result, result->ne[0], ne1, ne2, ne3);
|
||||
// swap back 0th and nth axis
|
||||
result = ggml_permute(ctx, result, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0);
|
||||
return result;
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ struct ggml_tensor* ggml_merge_lora(ggml_context* ctx, struct ggml_tensor* lora_down, struct ggml_tensor* lora_up, struct ggml_tensor* lora_mid = NULL) {
|
||||
struct ggml_tensor* updown;
|
||||
// flat lora tensors to multiply it
|
||||
int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1];
|
||||
lora_up = ggml_reshape_2d(ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows);
|
||||
auto lora_down_n_dims = ggml_n_dims(lora_down);
|
||||
// assume n_dims should always be a multiple of 2 (otherwise rank 1 doesn't work)
|
||||
lora_down_n_dims = (lora_down_n_dims + lora_down_n_dims % 2);
|
||||
int64_t lora_down_rows = lora_down->ne[lora_down_n_dims - 1];
|
||||
lora_down = ggml_reshape_2d(ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows);
|
||||
|
||||
// ggml_mul_mat requires tensor b transposed
|
||||
lora_down = ggml_cont(ctx, ggml_transpose(ctx, lora_down));
|
||||
if (lora_mid == NULL) {
|
||||
updown = ggml_mul_mat(ctx, lora_up, lora_down);
|
||||
updown = ggml_cont(ctx, ggml_transpose(ctx, updown));
|
||||
} else {
|
||||
// undoing tucker decomposition for conv layers.
|
||||
// lora_mid has shape (3, 3, Rank, Rank)
|
||||
// lora_down has shape (Rank, In, 1, 1)
|
||||
// lora_up has shape (Rank, Out, 1, 1)
|
||||
// conv layer shape is (3, 3, Out, In)
|
||||
updown = ggml_mul_n_mode(ctx, ggml_mul_n_mode(ctx, lora_mid, lora_down, 3), lora_up, 2);
|
||||
updown = ggml_cont(ctx, updown);
|
||||
}
|
||||
return updown;
|
||||
}
|
||||
|
||||
// Kronecker product
|
||||
// [ne03,ne02,ne01,ne00] x [ne13,ne12,ne11,ne10] => [ne03*ne13,ne02*ne12,ne01*ne11,ne00*ne10]
|
||||
__STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b) {
|
||||
return ggml_mul(ctx,
|
||||
ggml_upscale_ext(ctx,
|
||||
a,
|
||||
a->ne[0] * b->ne[0],
|
||||
a->ne[1] * b->ne[1],
|
||||
a->ne[2] * b->ne[2],
|
||||
a->ne[3] * b->ne[3]),
|
||||
b);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) {
|
||||
(void)level;
|
||||
(void)user_data;
|
||||
@ -319,7 +384,7 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
|
||||
for (int iy = 0; iy < height; iy++) {
|
||||
float m = ggml_tensor_get_f32(mask, ix, iy);
|
||||
for (int k = 0; k < channels; k++) {
|
||||
float value = ((float)(m < 254.5/255)) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
|
||||
float value = ((float)(m < 254.5 / 255)) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
|
||||
ggml_tensor_set_f32(output, value, ix, iy, k);
|
||||
}
|
||||
}
|
||||
@ -987,8 +1052,8 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
|
||||
}
|
||||
|
||||
/* SDXL with LoRA requires more space */
|
||||
#define MAX_PARAMS_TENSOR_NUM 15360
|
||||
#define MAX_GRAPH_SIZE 15360
|
||||
#define MAX_PARAMS_TENSOR_NUM 32768
|
||||
#define MAX_GRAPH_SIZE 32768
|
||||
|
||||
struct GGMLRunner {
|
||||
protected:
|
||||
|
907
lora.hpp
907
lora.hpp
@ -197,6 +197,10 @@ struct LoraModel : public GGMLRunner {
|
||||
blk_name.replace(blk_name.find(".joint_blocks"), sizeof(".joint_blocks") - 1, ".transformer_blocks");
|
||||
}
|
||||
|
||||
if (blk_name.find("text_encoders.clip_l") != std::string::npos) {
|
||||
blk_name.replace(blk_name.find("text_encoders.clip_l"), sizeof("text_encoders.clip_l") - 1, "cond_stage_model");
|
||||
}
|
||||
|
||||
for (const auto& item : alt_names) {
|
||||
size_t match = blk_name.find(item.first);
|
||||
if (match != std::string::npos) {
|
||||
@ -217,13 +221,17 @@ struct LoraModel : public GGMLRunner {
|
||||
keys.push_back(split_blk);
|
||||
}
|
||||
}
|
||||
keys.push_back(blk_name);
|
||||
}
|
||||
keys.push_back(blk_name);
|
||||
|
||||
std::vector<std::string> ret;
|
||||
for (std::string& key : keys) {
|
||||
ret.push_back(key);
|
||||
replace_all_chars(key, '.', '_');
|
||||
// fix for some sdxl lora, like lcm-lora-xl
|
||||
if (key == "model_diffusion_model_output_blocks_2_2_conv") {
|
||||
ret.push_back("model_diffusion_model_output_blocks_2_1_conv");
|
||||
}
|
||||
ret.push_back(key);
|
||||
}
|
||||
return ret;
|
||||
@ -244,390 +252,545 @@ struct LoraModel : public GGMLRunner {
|
||||
std::vector<std::string> keys = to_lora_keys(k_tensor, version);
|
||||
if (keys.size() == 0)
|
||||
continue;
|
||||
ggml_tensor* lora_up = NULL;
|
||||
ggml_tensor* lora_down = NULL;
|
||||
|
||||
for (auto& key : keys) {
|
||||
std::string alpha_name = "";
|
||||
std::string scale_name = "";
|
||||
std::string split_q_scale_name = "";
|
||||
std::string lora_down_name = "";
|
||||
std::string lora_up_name = "";
|
||||
|
||||
if (starts_with(key, "SPLIT|")) {
|
||||
bool is_qkv_split = starts_with(key, "SPLIT|");
|
||||
if (is_qkv_split) {
|
||||
key = key.substr(sizeof("SPLIT|") - 1);
|
||||
// TODO: Handle alphas
|
||||
std::string suffix = "";
|
||||
auto split_q_d_name = lora_pre[type] + key + "q" + suffix + lora_downs[type] + ".weight";
|
||||
|
||||
if (lora_tensors.find(split_q_d_name) == lora_tensors.end()) {
|
||||
suffix = "_proj";
|
||||
split_q_d_name = lora_pre[type] + key + "q" + suffix + lora_downs[type] + ".weight";
|
||||
}
|
||||
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
|
||||
// print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
|
||||
// find qkv and mlp up parts in LoRA model
|
||||
auto split_k_d_name = lora_pre[type] + key + "k" + suffix + lora_downs[type] + ".weight";
|
||||
auto split_v_d_name = lora_pre[type] + key + "v" + suffix + lora_downs[type] + ".weight";
|
||||
|
||||
auto split_q_u_name = lora_pre[type] + key + "q" + suffix + lora_ups[type] + ".weight";
|
||||
auto split_k_u_name = lora_pre[type] + key + "k" + suffix + lora_ups[type] + ".weight";
|
||||
auto split_v_u_name = lora_pre[type] + key + "v" + suffix + lora_ups[type] + ".weight";
|
||||
|
||||
auto split_q_scale_name = lora_pre[type] + key + "q" + suffix + ".scale";
|
||||
auto split_k_scale_name = lora_pre[type] + key + "k" + suffix + ".scale";
|
||||
auto split_v_scale_name = lora_pre[type] + key + "v" + suffix + ".scale";
|
||||
|
||||
auto split_q_alpha_name = lora_pre[type] + key + "q" + suffix + ".alpha";
|
||||
auto split_k_alpha_name = lora_pre[type] + key + "k" + suffix + ".alpha";
|
||||
auto split_v_alpha_name = lora_pre[type] + key + "v" + suffix + ".alpha";
|
||||
|
||||
ggml_tensor* lora_q_down = NULL;
|
||||
ggml_tensor* lora_q_up = NULL;
|
||||
ggml_tensor* lora_k_down = NULL;
|
||||
ggml_tensor* lora_k_up = NULL;
|
||||
ggml_tensor* lora_v_down = NULL;
|
||||
ggml_tensor* lora_v_up = NULL;
|
||||
|
||||
lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);
|
||||
|
||||
if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
|
||||
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
|
||||
lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
|
||||
lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
|
||||
lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
|
||||
lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
|
||||
}
|
||||
|
||||
float q_rank = lora_q_up->ne[0];
|
||||
float k_rank = lora_k_up->ne[0];
|
||||
float v_rank = lora_v_up->ne[0];
|
||||
|
||||
float lora_q_scale = 1;
|
||||
float lora_k_scale = 1;
|
||||
float lora_v_scale = 1;
|
||||
|
||||
if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
|
||||
lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
|
||||
applied_lora_tensors.insert(split_q_scale_name);
|
||||
}
|
||||
if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
|
||||
lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
|
||||
applied_lora_tensors.insert(split_k_scale_name);
|
||||
}
|
||||
if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
|
||||
lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
|
||||
applied_lora_tensors.insert(split_v_scale_name);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
|
||||
float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
|
||||
applied_lora_tensors.insert(split_q_alpha_name);
|
||||
lora_q_scale = lora_q_alpha / q_rank;
|
||||
}
|
||||
if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
|
||||
float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
|
||||
applied_lora_tensors.insert(split_k_alpha_name);
|
||||
lora_k_scale = lora_k_alpha / k_rank;
|
||||
}
|
||||
if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
|
||||
float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
|
||||
applied_lora_tensors.insert(split_v_alpha_name);
|
||||
lora_v_scale = lora_v_alpha / v_rank;
|
||||
}
|
||||
|
||||
ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale);
|
||||
ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale);
|
||||
ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale);
|
||||
|
||||
// print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1]
|
||||
// print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1]
|
||||
// print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1]
|
||||
|
||||
// these need to be stitched together this way:
|
||||
// |q_up,0 ,0 |
|
||||
// |0 ,k_up,0 |
|
||||
// |0 ,0 ,v_up|
|
||||
// (q_down,k_down,v_down) . (q ,k ,v)
|
||||
|
||||
// up_concat will be [9216, R*3, 1, 1]
|
||||
// down_concat will be [R*3, 3072, 1, 1]
|
||||
ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), lora_v_down, 1);
|
||||
|
||||
ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up);
|
||||
ggml_scale(compute_ctx, z, 0);
|
||||
ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1);
|
||||
|
||||
ggml_tensor* q_up = ggml_concat(compute_ctx, lora_q_up, zz, 1);
|
||||
ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), z, 1);
|
||||
ggml_tensor* v_up = ggml_concat(compute_ctx, zz, lora_v_up, 1);
|
||||
// print_ggml_tensor(q_up, true); //[R, 9216, 1, 1]
|
||||
// print_ggml_tensor(k_up, true); //[R, 9216, 1, 1]
|
||||
// print_ggml_tensor(v_up, true); //[R, 9216, 1, 1]
|
||||
ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), v_up, 0);
|
||||
// print_ggml_tensor(lora_up_concat, true); //[R*3, 9216, 1, 1]
|
||||
|
||||
lora_down = ggml_cont(compute_ctx, lora_down_concat);
|
||||
lora_up = ggml_cont(compute_ctx, lora_up_concat);
|
||||
|
||||
applied_lora_tensors.insert(split_q_u_name);
|
||||
applied_lora_tensors.insert(split_k_u_name);
|
||||
applied_lora_tensors.insert(split_v_u_name);
|
||||
|
||||
applied_lora_tensors.insert(split_q_d_name);
|
||||
applied_lora_tensors.insert(split_k_d_name);
|
||||
applied_lora_tensors.insert(split_v_d_name);
|
||||
}
|
||||
}
|
||||
if (starts_with(key, "SPLIT_L|")) {
|
||||
bool is_qkvm_split = starts_with(key, "SPLIT_L|");
|
||||
if (is_qkvm_split) {
|
||||
key = key.substr(sizeof("SPLIT_L|") - 1);
|
||||
|
||||
auto split_q_d_name = lora_pre[type] + key + "attn.to_q" + lora_downs[type] + ".weight";
|
||||
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
|
||||
// print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
|
||||
// find qkv and mlp up parts in LoRA model
|
||||
auto split_k_d_name = lora_pre[type] + key + "attn.to_k" + lora_downs[type] + ".weight";
|
||||
auto split_v_d_name = lora_pre[type] + key + "attn.to_v" + lora_downs[type] + ".weight";
|
||||
|
||||
auto split_q_u_name = lora_pre[type] + key + "attn.to_q" + lora_ups[type] + ".weight";
|
||||
auto split_k_u_name = lora_pre[type] + key + "attn.to_k" + lora_ups[type] + ".weight";
|
||||
auto split_v_u_name = lora_pre[type] + key + "attn.to_v" + lora_ups[type] + ".weight";
|
||||
|
||||
auto split_m_d_name = lora_pre[type] + key + "proj_mlp" + lora_downs[type] + ".weight";
|
||||
auto split_m_u_name = lora_pre[type] + key + "proj_mlp" + lora_ups[type] + ".weight";
|
||||
|
||||
auto split_q_scale_name = lora_pre[type] + key + "attn.to_q" + ".scale";
|
||||
auto split_k_scale_name = lora_pre[type] + key + "attn.to_k" + ".scale";
|
||||
auto split_v_scale_name = lora_pre[type] + key + "attn.to_v" + ".scale";
|
||||
auto split_m_scale_name = lora_pre[type] + key + "proj_mlp" + ".scale";
|
||||
|
||||
auto split_q_alpha_name = lora_pre[type] + key + "attn.to_q" + ".alpha";
|
||||
auto split_k_alpha_name = lora_pre[type] + key + "attn.to_k" + ".alpha";
|
||||
auto split_v_alpha_name = lora_pre[type] + key + "attn.to_v" + ".alpha";
|
||||
auto split_m_alpha_name = lora_pre[type] + key + "proj_mlp" + ".alpha";
|
||||
|
||||
ggml_tensor* lora_q_down = NULL;
|
||||
ggml_tensor* lora_q_up = NULL;
|
||||
ggml_tensor* lora_k_down = NULL;
|
||||
ggml_tensor* lora_k_up = NULL;
|
||||
ggml_tensor* lora_v_down = NULL;
|
||||
ggml_tensor* lora_v_up = NULL;
|
||||
|
||||
ggml_tensor* lora_m_down = NULL;
|
||||
ggml_tensor* lora_m_up = NULL;
|
||||
|
||||
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
|
||||
|
||||
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
|
||||
lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
|
||||
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
|
||||
lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
|
||||
lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
|
||||
lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
|
||||
lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_m_d_name) != lora_tensors.end()) {
|
||||
lora_m_down = to_f32(compute_ctx, lora_tensors[split_m_d_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_m_u_name) != lora_tensors.end()) {
|
||||
lora_m_up = to_f32(compute_ctx, lora_tensors[split_m_u_name]);
|
||||
}
|
||||
|
||||
float q_rank = lora_q_up->ne[0];
|
||||
float k_rank = lora_k_up->ne[0];
|
||||
float v_rank = lora_v_up->ne[0];
|
||||
float m_rank = lora_v_up->ne[0];
|
||||
|
||||
float lora_q_scale = 1;
|
||||
float lora_k_scale = 1;
|
||||
float lora_v_scale = 1;
|
||||
float lora_m_scale = 1;
|
||||
|
||||
if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
|
||||
lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
|
||||
applied_lora_tensors.insert(split_q_scale_name);
|
||||
}
|
||||
if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
|
||||
lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
|
||||
applied_lora_tensors.insert(split_k_scale_name);
|
||||
}
|
||||
if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
|
||||
lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
|
||||
applied_lora_tensors.insert(split_v_scale_name);
|
||||
}
|
||||
if (lora_tensors.find(split_m_scale_name) != lora_tensors.end()) {
|
||||
lora_m_scale = ggml_backend_tensor_get_f32(lora_tensors[split_m_scale_name]);
|
||||
applied_lora_tensors.insert(split_m_scale_name);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
|
||||
float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
|
||||
applied_lora_tensors.insert(split_q_alpha_name);
|
||||
lora_q_scale = lora_q_alpha / q_rank;
|
||||
}
|
||||
if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
|
||||
float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
|
||||
applied_lora_tensors.insert(split_k_alpha_name);
|
||||
lora_k_scale = lora_k_alpha / k_rank;
|
||||
}
|
||||
if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
|
||||
float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
|
||||
applied_lora_tensors.insert(split_v_alpha_name);
|
||||
lora_v_scale = lora_v_alpha / v_rank;
|
||||
}
|
||||
if (lora_tensors.find(split_m_alpha_name) != lora_tensors.end()) {
|
||||
float lora_m_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_m_alpha_name]);
|
||||
applied_lora_tensors.insert(split_m_alpha_name);
|
||||
lora_m_scale = lora_m_alpha / m_rank;
|
||||
}
|
||||
|
||||
ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale);
|
||||
ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale);
|
||||
ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale);
|
||||
ggml_scale_inplace(compute_ctx, lora_m_down, lora_m_scale);
|
||||
|
||||
// print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_m_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1]
|
||||
// print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1]
|
||||
// print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1]
|
||||
// print_ggml_tensor(lora_m_up, true); //[R, 12288, 1, 1]
|
||||
|
||||
// these need to be stitched together this way:
|
||||
// |q_up,0 ,0 ,0 |
|
||||
// |0 ,k_up,0 ,0 |
|
||||
// |0 ,0 ,v_up,0 |
|
||||
// |0 ,0 ,0 ,m_up|
|
||||
// (q_down,k_down,v_down,m_down) . (q ,k ,v ,m)
|
||||
|
||||
// up_concat will be [21504, R*4, 1, 1]
|
||||
// down_concat will be [R*4, 3072, 1, 1]
|
||||
|
||||
ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), ggml_concat(compute_ctx, lora_v_down, lora_m_down, 1), 1);
|
||||
// print_ggml_tensor(lora_down_concat, true); //[3072, R*4, 1, 1]
|
||||
|
||||
// this also means that if rank is bigger than 672, it is less memory efficient to do it this way (should be fine)
|
||||
// print_ggml_tensor(lora_q_up, true); //[3072, R, 1, 1]
|
||||
ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up);
|
||||
ggml_tensor* mlp_z = ggml_dup_tensor(compute_ctx, lora_m_up);
|
||||
ggml_scale(compute_ctx, z, 0);
|
||||
ggml_scale(compute_ctx, mlp_z, 0);
|
||||
ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1);
|
||||
|
||||
ggml_tensor* q_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_up, zz, 1), mlp_z, 1);
|
||||
ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), ggml_concat(compute_ctx, z, mlp_z, 1), 1);
|
||||
ggml_tensor* v_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, lora_v_up, 1), mlp_z, 1);
|
||||
ggml_tensor* m_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, z, 1), lora_m_up, 1);
|
||||
// print_ggml_tensor(q_up, true); //[R, 21504, 1, 1]
|
||||
// print_ggml_tensor(k_up, true); //[R, 21504, 1, 1]
|
||||
// print_ggml_tensor(v_up, true); //[R, 21504, 1, 1]
|
||||
// print_ggml_tensor(m_up, true); //[R, 21504, 1, 1]
|
||||
|
||||
ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), ggml_concat(compute_ctx, v_up, m_up, 0), 0);
|
||||
// print_ggml_tensor(lora_up_concat, true); //[R*4, 21504, 1, 1]
|
||||
|
||||
lora_down = ggml_cont(compute_ctx, lora_down_concat);
|
||||
lora_up = ggml_cont(compute_ctx, lora_up_concat);
|
||||
|
||||
applied_lora_tensors.insert(split_q_u_name);
|
||||
applied_lora_tensors.insert(split_k_u_name);
|
||||
applied_lora_tensors.insert(split_v_u_name);
|
||||
applied_lora_tensors.insert(split_m_u_name);
|
||||
|
||||
applied_lora_tensors.insert(split_q_d_name);
|
||||
applied_lora_tensors.insert(split_k_d_name);
|
||||
applied_lora_tensors.insert(split_v_d_name);
|
||||
applied_lora_tensors.insert(split_m_d_name);
|
||||
}
|
||||
}
|
||||
if (lora_up == NULL || lora_down == NULL) {
|
||||
lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
|
||||
if (lora_tensors.find(lora_up_name) == lora_tensors.end()) {
|
||||
if (key == "model_diffusion_model_output_blocks_2_2_conv") {
|
||||
// fix for some sdxl lora, like lcm-lora-xl
|
||||
key = "model_diffusion_model_output_blocks_2_1_conv";
|
||||
lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
|
||||
}
|
||||
struct ggml_tensor* updown = NULL;
|
||||
float scale_value = 1.0f;
|
||||
std::string fk = lora_pre[type] + key;
|
||||
if (lora_tensors.find(fk + ".hada_w1_a") != lora_tensors.end()) {
|
||||
// LoHa mode
|
||||
|
||||
// TODO: split qkv convention for LoHas (is it ever used?)
|
||||
if (is_qkv_split || is_qkvm_split) {
|
||||
LOG_ERROR("Split qkv isn't supported for LoHa models.");
|
||||
break;
|
||||
}
|
||||
std::string alpha_name = "";
|
||||
|
||||
ggml_tensor* hada_1_mid = NULL; // tau for tucker decomposition
|
||||
ggml_tensor* hada_1_up = NULL;
|
||||
ggml_tensor* hada_1_down = NULL;
|
||||
|
||||
ggml_tensor* hada_2_mid = NULL; // tau for tucker decomposition
|
||||
ggml_tensor* hada_2_up = NULL;
|
||||
ggml_tensor* hada_2_down = NULL;
|
||||
|
||||
std::string hada_1_mid_name = "";
|
||||
std::string hada_1_down_name = "";
|
||||
std::string hada_1_up_name = "";
|
||||
|
||||
std::string hada_2_mid_name = "";
|
||||
std::string hada_2_down_name = "";
|
||||
std::string hada_2_up_name = "";
|
||||
|
||||
|
||||
hada_1_down_name = fk + ".hada_w1_b";
|
||||
hada_1_up_name = fk + ".hada_w1_a";
|
||||
hada_1_mid_name = fk + ".hada_t1";
|
||||
if (lora_tensors.find(hada_1_down_name) != lora_tensors.end()) {
|
||||
hada_1_down = to_f32(compute_ctx, lora_tensors[hada_1_down_name]);
|
||||
}
|
||||
if (lora_tensors.find(hada_1_up_name) != lora_tensors.end()) {
|
||||
hada_1_up = to_f32(compute_ctx, lora_tensors[hada_1_up_name]);
|
||||
}
|
||||
if (lora_tensors.find(hada_1_mid_name) != lora_tensors.end()) {
|
||||
hada_1_mid = to_f32(compute_ctx, lora_tensors[hada_1_mid_name]);
|
||||
applied_lora_tensors.insert(hada_1_mid_name);
|
||||
hada_1_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_1_up));
|
||||
}
|
||||
|
||||
lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
|
||||
alpha_name = lora_pre[type] + key + ".alpha";
|
||||
scale_name = lora_pre[type] + key + ".scale";
|
||||
|
||||
if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
|
||||
lora_up = lora_tensors[lora_up_name];
|
||||
hada_2_down_name = fk + ".hada_w2_b";
|
||||
hada_2_up_name = fk + ".hada_w2_a";
|
||||
hada_2_mid_name = fk + ".hada_t2";
|
||||
if (lora_tensors.find(hada_2_down_name) != lora_tensors.end()) {
|
||||
hada_2_down = to_f32(compute_ctx, lora_tensors[hada_2_down_name]);
|
||||
}
|
||||
if (lora_tensors.find(hada_2_up_name) != lora_tensors.end()) {
|
||||
hada_2_up = to_f32(compute_ctx, lora_tensors[hada_2_up_name]);
|
||||
}
|
||||
if (lora_tensors.find(hada_2_mid_name) != lora_tensors.end()) {
|
||||
hada_2_mid = to_f32(compute_ctx, lora_tensors[hada_2_mid_name]);
|
||||
applied_lora_tensors.insert(hada_2_mid_name);
|
||||
hada_2_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_2_up));
|
||||
}
|
||||
|
||||
if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
|
||||
lora_down = lora_tensors[lora_down_name];
|
||||
}
|
||||
applied_lora_tensors.insert(lora_up_name);
|
||||
applied_lora_tensors.insert(lora_down_name);
|
||||
alpha_name = fk + ".alpha";
|
||||
|
||||
applied_lora_tensors.insert(hada_1_down_name);
|
||||
applied_lora_tensors.insert(hada_1_up_name);
|
||||
applied_lora_tensors.insert(hada_2_down_name);
|
||||
applied_lora_tensors.insert(hada_2_up_name);
|
||||
|
||||
applied_lora_tensors.insert(alpha_name);
|
||||
applied_lora_tensors.insert(scale_name);
|
||||
}
|
||||
if (hada_1_up == NULL || hada_1_down == NULL || hada_2_up == NULL || hada_2_down == NULL) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (lora_up == NULL || lora_down == NULL) {
|
||||
continue;
|
||||
}
|
||||
// calc_scale
|
||||
int64_t dim = lora_down->ne[ggml_n_dims(lora_down) - 1];
|
||||
float scale_value = 1.0f;
|
||||
if (lora_tensors.find(scale_name) != lora_tensors.end()) {
|
||||
scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]);
|
||||
} else if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
|
||||
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
|
||||
scale_value = alpha / dim;
|
||||
struct ggml_tensor* updown_1 = ggml_merge_lora(compute_ctx, hada_1_down, hada_1_up, hada_1_mid);
|
||||
struct ggml_tensor* updown_2 = ggml_merge_lora(compute_ctx, hada_2_down, hada_2_up, hada_2_mid);
|
||||
updown = ggml_mul_inplace(compute_ctx, updown_1, updown_2);
|
||||
|
||||
// calc_scale
|
||||
// TODO: .dora_scale?
|
||||
int64_t rank = hada_1_down->ne[ggml_n_dims(hada_1_down) - 1];
|
||||
if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
|
||||
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
|
||||
scale_value = alpha / rank;
|
||||
}
|
||||
} else if (lora_tensors.find(fk + ".lokr_w1") != lora_tensors.end() || lora_tensors.find(fk + ".lokr_w1_a") != lora_tensors.end()) {
|
||||
// LoKr mode
|
||||
|
||||
// TODO: split qkv convention for LoKrs (is it ever used?)
|
||||
if (is_qkv_split || is_qkvm_split) {
|
||||
LOG_ERROR("Split qkv isn't supported for LoKr models.");
|
||||
break;
|
||||
}
|
||||
|
||||
std::string alpha_name = fk + ".alpha";
|
||||
|
||||
ggml_tensor* lokr_w1 = NULL;
|
||||
ggml_tensor* lokr_w2 = NULL;
|
||||
|
||||
std::string lokr_w1_name = "";
|
||||
std::string lokr_w2_name = "";
|
||||
|
||||
lokr_w1_name = fk + ".lokr_w1";
|
||||
lokr_w2_name = fk + ".lokr_w2";
|
||||
|
||||
if (lora_tensors.find(lokr_w1_name) != lora_tensors.end()) {
|
||||
lokr_w1 = to_f32(compute_ctx, lora_tensors[lokr_w1_name]);
|
||||
applied_lora_tensors.insert(lokr_w1_name);
|
||||
} else {
|
||||
ggml_tensor* down = NULL;
|
||||
ggml_tensor* up = NULL;
|
||||
std::string down_name = lokr_w1_name + "_b";
|
||||
std::string up_name = lokr_w1_name + "_a";
|
||||
if (lora_tensors.find(down_name) != lora_tensors.end()) {
|
||||
// w1 should not be low rank normally, sometimes w1 and w2 are swapped
|
||||
down = to_f32(compute_ctx, lora_tensors[down_name]);
|
||||
applied_lora_tensors.insert(down_name);
|
||||
|
||||
int64_t rank = down->ne[ggml_n_dims(down) - 1];
|
||||
if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
|
||||
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
|
||||
scale_value = alpha / rank;
|
||||
}
|
||||
}
|
||||
if (lora_tensors.find(up_name) != lora_tensors.end()) {
|
||||
up = to_f32(compute_ctx, lora_tensors[up_name]);
|
||||
applied_lora_tensors.insert(up_name);
|
||||
}
|
||||
lokr_w1 = ggml_merge_lora(compute_ctx, down, up);
|
||||
}
|
||||
if (lora_tensors.find(lokr_w2_name) != lora_tensors.end()) {
|
||||
lokr_w2 = to_f32(compute_ctx, lora_tensors[lokr_w2_name]);
|
||||
applied_lora_tensors.insert(lokr_w2_name);
|
||||
} else {
|
||||
ggml_tensor* down = NULL;
|
||||
ggml_tensor* up = NULL;
|
||||
std::string down_name = lokr_w2_name + "_b";
|
||||
std::string up_name = lokr_w2_name + "_a";
|
||||
if (lora_tensors.find(down_name) != lora_tensors.end()) {
|
||||
down = to_f32(compute_ctx, lora_tensors[down_name]);
|
||||
applied_lora_tensors.insert(down_name);
|
||||
|
||||
int64_t rank = down->ne[ggml_n_dims(down) - 1];
|
||||
if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
|
||||
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
|
||||
scale_value = alpha / rank;
|
||||
}
|
||||
}
|
||||
if (lora_tensors.find(up_name) != lora_tensors.end()) {
|
||||
up = to_f32(compute_ctx, lora_tensors[up_name]);
|
||||
applied_lora_tensors.insert(up_name);
|
||||
}
|
||||
lokr_w2 = ggml_merge_lora(compute_ctx, down, up);
|
||||
}
|
||||
|
||||
// Technically it might be unused, but I believe it's the expected behavior
|
||||
applied_lora_tensors.insert(alpha_name);
|
||||
|
||||
updown = ggml_kronecker(compute_ctx, lokr_w1, lokr_w2);
|
||||
|
||||
} else {
|
||||
// LoRA mode
|
||||
ggml_tensor* lora_mid = NULL; // tau for tucker decomposition
|
||||
ggml_tensor* lora_up = NULL;
|
||||
ggml_tensor* lora_down = NULL;
|
||||
|
||||
std::string alpha_name = "";
|
||||
std::string scale_name = "";
|
||||
std::string split_q_scale_name = "";
|
||||
std::string lora_mid_name = "";
|
||||
std::string lora_down_name = "";
|
||||
std::string lora_up_name = "";
|
||||
|
||||
if (is_qkv_split) {
|
||||
std::string suffix = "";
|
||||
auto split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight";
|
||||
|
||||
if (lora_tensors.find(split_q_d_name) == lora_tensors.end()) {
|
||||
suffix = "_proj";
|
||||
split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight";
|
||||
}
|
||||
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
|
||||
// print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
|
||||
// find qkv and mlp up parts in LoRA model
|
||||
auto split_k_d_name = fk + "k" + suffix + lora_downs[type] + ".weight";
|
||||
auto split_v_d_name = fk + "v" + suffix + lora_downs[type] + ".weight";
|
||||
|
||||
auto split_q_u_name = fk + "q" + suffix + lora_ups[type] + ".weight";
|
||||
auto split_k_u_name = fk + "k" + suffix + lora_ups[type] + ".weight";
|
||||
auto split_v_u_name = fk + "v" + suffix + lora_ups[type] + ".weight";
|
||||
|
||||
auto split_q_scale_name = fk + "q" + suffix + ".scale";
|
||||
auto split_k_scale_name = fk + "k" + suffix + ".scale";
|
||||
auto split_v_scale_name = fk + "v" + suffix + ".scale";
|
||||
|
||||
auto split_q_alpha_name = fk + "q" + suffix + ".alpha";
|
||||
auto split_k_alpha_name = fk + "k" + suffix + ".alpha";
|
||||
auto split_v_alpha_name = fk + "v" + suffix + ".alpha";
|
||||
|
||||
ggml_tensor* lora_q_down = NULL;
|
||||
ggml_tensor* lora_q_up = NULL;
|
||||
ggml_tensor* lora_k_down = NULL;
|
||||
ggml_tensor* lora_k_up = NULL;
|
||||
ggml_tensor* lora_v_down = NULL;
|
||||
ggml_tensor* lora_v_up = NULL;
|
||||
|
||||
lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);
|
||||
|
||||
if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
|
||||
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
|
||||
lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
|
||||
lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
|
||||
lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
|
||||
lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
|
||||
}
|
||||
|
||||
float q_rank = lora_q_up->ne[0];
|
||||
float k_rank = lora_k_up->ne[0];
|
||||
float v_rank = lora_v_up->ne[0];
|
||||
|
||||
float lora_q_scale = 1;
|
||||
float lora_k_scale = 1;
|
||||
float lora_v_scale = 1;
|
||||
|
||||
if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
|
||||
lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
|
||||
applied_lora_tensors.insert(split_q_scale_name);
|
||||
}
|
||||
if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
|
||||
lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
|
||||
applied_lora_tensors.insert(split_k_scale_name);
|
||||
}
|
||||
if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
|
||||
lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
|
||||
applied_lora_tensors.insert(split_v_scale_name);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
|
||||
float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
|
||||
applied_lora_tensors.insert(split_q_alpha_name);
|
||||
lora_q_scale = lora_q_alpha / q_rank;
|
||||
}
|
||||
if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
|
||||
float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
|
||||
applied_lora_tensors.insert(split_k_alpha_name);
|
||||
lora_k_scale = lora_k_alpha / k_rank;
|
||||
}
|
||||
if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
|
||||
float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
|
||||
applied_lora_tensors.insert(split_v_alpha_name);
|
||||
lora_v_scale = lora_v_alpha / v_rank;
|
||||
}
|
||||
|
||||
ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale);
|
||||
ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale);
|
||||
ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale);
|
||||
|
||||
// print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1]
|
||||
// print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1]
|
||||
// print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1]
|
||||
|
||||
// these need to be stitched together this way:
|
||||
// |q_up,0 ,0 |
|
||||
// |0 ,k_up,0 |
|
||||
// |0 ,0 ,v_up|
|
||||
// (q_down,k_down,v_down) . (q ,k ,v)
|
||||
|
||||
// up_concat will be [9216, R*3, 1, 1]
|
||||
// down_concat will be [R*3, 3072, 1, 1]
|
||||
ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), lora_v_down, 1);
|
||||
|
||||
ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up);
|
||||
ggml_scale(compute_ctx, z, 0);
|
||||
ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1);
|
||||
|
||||
ggml_tensor* q_up = ggml_concat(compute_ctx, lora_q_up, zz, 1);
|
||||
ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), z, 1);
|
||||
ggml_tensor* v_up = ggml_concat(compute_ctx, zz, lora_v_up, 1);
|
||||
// print_ggml_tensor(q_up, true); //[R, 9216, 1, 1]
|
||||
// print_ggml_tensor(k_up, true); //[R, 9216, 1, 1]
|
||||
// print_ggml_tensor(v_up, true); //[R, 9216, 1, 1]
|
||||
ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), v_up, 0);
|
||||
// print_ggml_tensor(lora_up_concat, true); //[R*3, 9216, 1, 1]
|
||||
|
||||
lora_down = ggml_cont(compute_ctx, lora_down_concat);
|
||||
lora_up = ggml_cont(compute_ctx, lora_up_concat);
|
||||
|
||||
applied_lora_tensors.insert(split_q_u_name);
|
||||
applied_lora_tensors.insert(split_k_u_name);
|
||||
applied_lora_tensors.insert(split_v_u_name);
|
||||
|
||||
applied_lora_tensors.insert(split_q_d_name);
|
||||
applied_lora_tensors.insert(split_k_d_name);
|
||||
applied_lora_tensors.insert(split_v_d_name);
|
||||
}
|
||||
} else if (is_qkvm_split) {
|
||||
auto split_q_d_name = fk + "attn.to_q" + lora_downs[type] + ".weight";
|
||||
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
|
||||
// print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
|
||||
// find qkv and mlp up parts in LoRA model
|
||||
auto split_k_d_name = fk + "attn.to_k" + lora_downs[type] + ".weight";
|
||||
auto split_v_d_name = fk + "attn.to_v" + lora_downs[type] + ".weight";
|
||||
|
||||
auto split_q_u_name = fk + "attn.to_q" + lora_ups[type] + ".weight";
|
||||
auto split_k_u_name = fk + "attn.to_k" + lora_ups[type] + ".weight";
|
||||
auto split_v_u_name = fk + "attn.to_v" + lora_ups[type] + ".weight";
|
||||
|
||||
auto split_m_d_name = fk + "proj_mlp" + lora_downs[type] + ".weight";
|
||||
auto split_m_u_name = fk + "proj_mlp" + lora_ups[type] + ".weight";
|
||||
|
||||
auto split_q_scale_name = fk + "attn.to_q" + ".scale";
|
||||
auto split_k_scale_name = fk + "attn.to_k" + ".scale";
|
||||
auto split_v_scale_name = fk + "attn.to_v" + ".scale";
|
||||
auto split_m_scale_name = fk + "proj_mlp" + ".scale";
|
||||
|
||||
auto split_q_alpha_name = fk + "attn.to_q" + ".alpha";
|
||||
auto split_k_alpha_name = fk + "attn.to_k" + ".alpha";
|
||||
auto split_v_alpha_name = fk + "attn.to_v" + ".alpha";
|
||||
auto split_m_alpha_name = fk + "proj_mlp" + ".alpha";
|
||||
|
||||
ggml_tensor* lora_q_down = NULL;
|
||||
ggml_tensor* lora_q_up = NULL;
|
||||
ggml_tensor* lora_k_down = NULL;
|
||||
ggml_tensor* lora_k_up = NULL;
|
||||
ggml_tensor* lora_v_down = NULL;
|
||||
ggml_tensor* lora_v_up = NULL;
|
||||
|
||||
ggml_tensor* lora_m_down = NULL;
|
||||
ggml_tensor* lora_m_up = NULL;
|
||||
|
||||
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
|
||||
|
||||
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
|
||||
lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
|
||||
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
|
||||
lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
|
||||
lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
|
||||
lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
|
||||
lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_m_d_name) != lora_tensors.end()) {
|
||||
lora_m_down = to_f32(compute_ctx, lora_tensors[split_m_d_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_m_u_name) != lora_tensors.end()) {
|
||||
lora_m_up = to_f32(compute_ctx, lora_tensors[split_m_u_name]);
|
||||
}
|
||||
|
||||
float q_rank = lora_q_up->ne[0];
|
||||
float k_rank = lora_k_up->ne[0];
|
||||
float v_rank = lora_v_up->ne[0];
|
||||
float m_rank = lora_v_up->ne[0];
|
||||
|
||||
float lora_q_scale = 1;
|
||||
float lora_k_scale = 1;
|
||||
float lora_v_scale = 1;
|
||||
float lora_m_scale = 1;
|
||||
|
||||
if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
|
||||
lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
|
||||
applied_lora_tensors.insert(split_q_scale_name);
|
||||
}
|
||||
if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
|
||||
lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
|
||||
applied_lora_tensors.insert(split_k_scale_name);
|
||||
}
|
||||
if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
|
||||
lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
|
||||
applied_lora_tensors.insert(split_v_scale_name);
|
||||
}
|
||||
if (lora_tensors.find(split_m_scale_name) != lora_tensors.end()) {
|
||||
lora_m_scale = ggml_backend_tensor_get_f32(lora_tensors[split_m_scale_name]);
|
||||
applied_lora_tensors.insert(split_m_scale_name);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
|
||||
float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
|
||||
applied_lora_tensors.insert(split_q_alpha_name);
|
||||
lora_q_scale = lora_q_alpha / q_rank;
|
||||
}
|
||||
if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
|
||||
float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
|
||||
applied_lora_tensors.insert(split_k_alpha_name);
|
||||
lora_k_scale = lora_k_alpha / k_rank;
|
||||
}
|
||||
if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
|
||||
float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
|
||||
applied_lora_tensors.insert(split_v_alpha_name);
|
||||
lora_v_scale = lora_v_alpha / v_rank;
|
||||
}
|
||||
if (lora_tensors.find(split_m_alpha_name) != lora_tensors.end()) {
|
||||
float lora_m_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_m_alpha_name]);
|
||||
applied_lora_tensors.insert(split_m_alpha_name);
|
||||
lora_m_scale = lora_m_alpha / m_rank;
|
||||
}
|
||||
|
||||
ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale);
|
||||
ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale);
|
||||
ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale);
|
||||
ggml_scale_inplace(compute_ctx, lora_m_down, lora_m_scale);
|
||||
|
||||
// print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_m_down, true); //[3072, R, 1, 1]
|
||||
// print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1]
|
||||
// print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1]
|
||||
// print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1]
|
||||
// print_ggml_tensor(lora_m_up, true); //[R, 12288, 1, 1]
|
||||
|
||||
// these need to be stitched together this way:
|
||||
// |q_up,0 ,0 ,0 |
|
||||
// |0 ,k_up,0 ,0 |
|
||||
// |0 ,0 ,v_up,0 |
|
||||
// |0 ,0 ,0 ,m_up|
|
||||
// (q_down,k_down,v_down,m_down) . (q ,k ,v ,m)
|
||||
|
||||
// up_concat will be [21504, R*4, 1, 1]
|
||||
// down_concat will be [R*4, 3072, 1, 1]
|
||||
|
||||
ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), ggml_concat(compute_ctx, lora_v_down, lora_m_down, 1), 1);
|
||||
// print_ggml_tensor(lora_down_concat, true); //[3072, R*4, 1, 1]
|
||||
|
||||
// this also means that if rank is bigger than 672, it is less memory efficient to do it this way (should be fine)
|
||||
// print_ggml_tensor(lora_q_up, true); //[3072, R, 1, 1]
|
||||
ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up);
|
||||
ggml_tensor* mlp_z = ggml_dup_tensor(compute_ctx, lora_m_up);
|
||||
ggml_scale(compute_ctx, z, 0);
|
||||
ggml_scale(compute_ctx, mlp_z, 0);
|
||||
ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1);
|
||||
|
||||
ggml_tensor* q_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_up, zz, 1), mlp_z, 1);
|
||||
ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), ggml_concat(compute_ctx, z, mlp_z, 1), 1);
|
||||
ggml_tensor* v_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, lora_v_up, 1), mlp_z, 1);
|
||||
ggml_tensor* m_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, z, 1), lora_m_up, 1);
|
||||
// print_ggml_tensor(q_up, true); //[R, 21504, 1, 1]
|
||||
// print_ggml_tensor(k_up, true); //[R, 21504, 1, 1]
|
||||
// print_ggml_tensor(v_up, true); //[R, 21504, 1, 1]
|
||||
// print_ggml_tensor(m_up, true); //[R, 21504, 1, 1]
|
||||
|
||||
ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), ggml_concat(compute_ctx, v_up, m_up, 0), 0);
|
||||
// print_ggml_tensor(lora_up_concat, true); //[R*4, 21504, 1, 1]
|
||||
|
||||
lora_down = ggml_cont(compute_ctx, lora_down_concat);
|
||||
lora_up = ggml_cont(compute_ctx, lora_up_concat);
|
||||
|
||||
applied_lora_tensors.insert(split_q_u_name);
|
||||
applied_lora_tensors.insert(split_k_u_name);
|
||||
applied_lora_tensors.insert(split_v_u_name);
|
||||
applied_lora_tensors.insert(split_m_u_name);
|
||||
|
||||
applied_lora_tensors.insert(split_q_d_name);
|
||||
applied_lora_tensors.insert(split_k_d_name);
|
||||
applied_lora_tensors.insert(split_v_d_name);
|
||||
applied_lora_tensors.insert(split_m_d_name);
|
||||
}
|
||||
} else {
|
||||
lora_up_name = fk + lora_ups[type] + ".weight";
|
||||
lora_down_name = fk + lora_downs[type] + ".weight";
|
||||
lora_mid_name = fk + ".lora_mid.weight";
|
||||
|
||||
alpha_name = fk + ".alpha";
|
||||
scale_name = fk + ".scale";
|
||||
|
||||
if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
|
||||
lora_up = to_f32(compute_ctx, lora_tensors[lora_up_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
|
||||
lora_down = to_f32(compute_ctx, lora_tensors[lora_down_name]);
|
||||
}
|
||||
|
||||
if (lora_tensors.find(lora_mid_name) != lora_tensors.end()) {
|
||||
lora_mid = to_f32(compute_ctx, lora_tensors[lora_mid_name]);
|
||||
applied_lora_tensors.insert(lora_mid_name);
|
||||
}
|
||||
|
||||
applied_lora_tensors.insert(lora_up_name);
|
||||
applied_lora_tensors.insert(lora_down_name);
|
||||
applied_lora_tensors.insert(alpha_name);
|
||||
applied_lora_tensors.insert(scale_name);
|
||||
}
|
||||
|
||||
if (lora_up == NULL || lora_down == NULL) {
|
||||
continue;
|
||||
}
|
||||
// calc_scale
|
||||
// TODO: .dora_scale?
|
||||
int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1];
|
||||
if (lora_tensors.find(scale_name) != lora_tensors.end()) {
|
||||
scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]);
|
||||
} else if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
|
||||
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
|
||||
scale_value = alpha / rank;
|
||||
}
|
||||
|
||||
updown = ggml_merge_lora(compute_ctx, lora_down, lora_up, lora_mid);
|
||||
}
|
||||
scale_value *= multiplier;
|
||||
|
||||
// flat lora tensors to multiply it
|
||||
int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1];
|
||||
lora_up = ggml_reshape_2d(compute_ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows);
|
||||
auto lora_down_n_dims = ggml_n_dims(lora_down);
|
||||
// assume n_dims should always be a multiple of 2 (otherwise rank 1 doesn't work)
|
||||
lora_down_n_dims = (lora_down_n_dims + lora_down_n_dims % 2);
|
||||
int64_t lora_down_rows = lora_down->ne[lora_down_n_dims - 1];
|
||||
lora_down = ggml_reshape_2d(compute_ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows);
|
||||
|
||||
// ggml_mul_mat requires tensor b transposed
|
||||
lora_down = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, lora_down));
|
||||
struct ggml_tensor* updown = ggml_mul_mat(compute_ctx, lora_up, lora_down);
|
||||
updown = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown));
|
||||
updown = ggml_reshape(compute_ctx, updown, weight);
|
||||
updown = ggml_reshape(compute_ctx, updown, weight);
|
||||
GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
|
||||
updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
|
||||
ggml_tensor* final_weight;
|
||||
|
Loading…
Reference in New Issue
Block a user