chore: make code cleaner

leejet 2023-12-09 17:35:10 +08:00
parent 2eac844bbd
commit 69efe3ce2b
3 changed files with 145 additions and 248 deletions

View File

@@ -1102,6 +1102,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
reader.tensor_storage.file_index = file_index;
reader.tensor_storage.name = prefix + reader.tensor_storage.name;
tensor_storages.push_back(reader.tensor_storage);
// LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
// reset
reader = PickleTensorReader();
}
@@ -1139,7 +1140,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
size_t pkl_size;
zip_entry_read(zip, &pkl_data, &pkl_size);
LOG_DEBUG("%lld", pkl_size);
// LOG_DEBUG("%lld", pkl_size);
parse_data_pkl((uint8_t*)pkl_data, pkl_size, zip, dir, file_index, prefix);

View File

@@ -7,8 +7,8 @@
#include <string>
#include <vector>
#include "ggml/ggml.h"
#include "ggml/ggml-backend.h"
#include "ggml/ggml.h"
#include "json.hpp"
#include "zip.h"

View File

@@ -398,6 +398,64 @@ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
return ggml_group_norm(ctx, a, 32);
}
struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
struct ggml_tensor* b) {
x = ggml_mul_mat(ctx, w, x);
x = ggml_add(ctx, x, b);
return x;
}
// w: [OC, IC, KH, KW]
// x: [N, IC, IH, IW]
// b: [OC,]
// result: [N, OC, OH, OW]
struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
struct ggml_tensor* b,
int s0 = 1,
int s1 = 1,
int p0 = 0,
int p1 = 0,
int d0 = 1,
int d1 = 1) {
x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
if (b != NULL) {
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
x = ggml_add(ctx, x, b);
}
return x;
}
struct ggml_tensor* ggml_nn_layer_norm(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
struct ggml_tensor* b,
float eps = EPS) {
x = ggml_norm(ctx, x, eps);
x = ggml_mul(ctx, x, w);
x = ggml_add(ctx, x, b);
return x;
}
struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
struct ggml_tensor* b,
int num_groups = 32) {
if (x->n_dims == 4) {
w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], 1);
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
}
x = ggml_group_norm(ctx, x, num_groups);
x = ggml_mul(ctx, x, w);
x = ggml_add(ctx, x, b);
return x;
}
std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
std::regex re("<lora:([^:]+):([^>]+)>");
std::smatch matches;
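
The four ggml_nn_* helpers above fold the recurring mul_mat + bias-add, conv + broadcast-bias, and norm + affine sequences into single calls; the remaining hunks are mechanical rewrites onto them. As a hedged sketch of the correspondence (illustrative names, assuming a valid ggml_context* ctx and F32 tensors):

static struct ggml_tensor* example_block(struct ggml_context* ctx,
                                         struct ggml_tensor* x,        // [N, C, h, w]
                                         struct ggml_tensor* conv_w,   // [OC, C, 3, 3]
                                         struct ggml_tensor* conv_b) { // [OC,]
    // before: x = ggml_conv_2d(ctx, conv_w, x, 1, 1, 1, 1, 1, 1);
    //         x = ggml_add(ctx, x, ggml_reshape_4d(ctx, conv_b, 1, 1, conv_b->ne[0], 1));
    // after: one call; the bias reshape/broadcast happens inside the helper
    return ggml_nn_conv_2d(ctx, x, conv_w, conv_b, 1, 1, 1, 1); // [N, OC, h, w]
}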
@@ -749,30 +807,21 @@ struct ResidualAttentionBlock {
struct ggml_tensor* r = x;
// layer norm 1
{
x = ggml_norm(ctx, x, EPS);
x = ggml_add(ctx,
ggml_mul(ctx, x, ln1_w),
ln1_b);
}
x = ggml_nn_layer_norm(ctx, x, ln1_w, ln1_b);
// self-attention
{
struct ggml_tensor* q = ggml_add(ctx,
ggml_mul_mat(ctx, q_w, x),
q_b);
struct ggml_tensor* q = ggml_nn_linear(ctx, x, q_w, q_b);
q = ggml_scale_inplace(ctx, q, attn_scale);
q = ggml_reshape_4d(ctx, q, d_model, n_head, n_token, N); // [N, n_token, n_head, d_model]
q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, n_token, d_model]
q = ggml_reshape_3d(ctx, q, d_model, n_token, n_head * N); // [N * n_head, n_token, d_model]
struct ggml_tensor* k = ggml_add(ctx,
ggml_mul_mat(ctx, k_w, x), k_b);
struct ggml_tensor* k = ggml_nn_linear(ctx, x, k_w, k_b);
k = ggml_reshape_4d(ctx, k, d_model, n_head, n_token, N); // [N, n_token, n_head, d_model]
k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, n_token, d_model]
k = ggml_reshape_3d(ctx, k, d_model, n_token, n_head * N); // [N * n_head, n_token, d_model]
struct ggml_tensor* v = ggml_add(ctx,
ggml_mul_mat(ctx, v_w, x), v_b);
struct ggml_tensor* v = ggml_nn_linear(ctx, x, v_w, v_b);
v = ggml_reshape_4d(ctx, v, d_model, n_head, n_token, N); // [N, n_token, n_head, d_model]
v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_model, n_token]
v = ggml_reshape_3d(ctx, v, n_token, d_model, n_head * N); // [N * n_head, d_model, n_token]
@@ -790,24 +839,17 @@ struct ResidualAttentionBlock {
}
// attention output
x = ggml_mul_mat(ctx, out_w, x);
x = ggml_add(ctx, x, out_b);
x = ggml_nn_linear(ctx, x, out_w, out_b);
// residual
x = ggml_add(ctx, x, r);
r = x;
// layer norm 2
{
x = ggml_norm(ctx, x, EPS);
x = ggml_add(ctx, ggml_mul(ctx, x, ln2_w),
ln2_b);
}
x = ggml_nn_layer_norm(ctx, x, ln2_w, ln2_b);
// mlp
x = ggml_mul_mat(ctx, fc1_w, x);
x = ggml_add(ctx, x, fc1_b);
x = ggml_nn_linear(ctx, x, fc1_w, fc1_b);
if (hidden_size == 1024) { // SD 2.x
x = ggml_gelu_inplace(ctx, x);
@@ -815,8 +857,7 @@ struct ResidualAttentionBlock {
x = ggml_gelu_quick_inplace(ctx, x);
}
x = ggml_mul_mat(ctx, fc2_w, x);
x = ggml_add(ctx, x, fc2_b);
x = ggml_nn_linear(ctx, x, fc2_w, fc2_b);
// residual 2
x = ggml_add(ctx, x, r);
@@ -1004,12 +1045,7 @@ struct CLIPTextModel {
}
// final layer norm
{
x = ggml_norm(ctx0, x, EPS);
x = ggml_add(ctx0, ggml_mul(ctx0, x, final_ln_w),
final_ln_b);
}
x = ggml_nn_layer_norm(ctx0, x, final_ln_w, final_ln_b);
return x; // [N, n_token, hidden_size]
}
@@ -1263,48 +1299,29 @@ struct ResBlock {
// emb: [N, emb_channels]
// in_layers
// group norm 32
auto h = ggml_group_norm_32(ctx, x);
h = ggml_add(ctx,
ggml_mul(ctx,
h,
ggml_reshape_4d(ctx, in_layer_0_w, 1, 1, in_layer_0_w->ne[0], 1)),
ggml_reshape_4d(ctx, in_layer_0_b, 1, 1, in_layer_0_b->ne[0], 1));
// silu
h = ggml_silu_inplace(ctx, h);
// conv2d
h = ggml_conv_2d(ctx, in_layer_2_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx,
h,
ggml_reshape_4d(ctx, in_layer_2_b, 1, 1, in_layer_2_b->ne[0], 1)); // [N, out_channels, h, w]
auto h = ggml_nn_group_norm(ctx, x, in_layer_0_w, in_layer_0_b);
h = ggml_silu_inplace(ctx, h);
h = ggml_nn_conv_2d(ctx, h, in_layer_2_w, in_layer_2_b, 1, 1, 1, 1); // [N, out_channels, h, w]
// emb_layers
auto emb_out = ggml_silu(ctx, emb);
emb_out = ggml_mul_mat(ctx, emb_layer_1_w, emb_out);
emb_out = ggml_add(ctx, emb_out, emb_layer_1_b); // [N, out_channels]
emb_out = ggml_nn_linear(ctx, emb_out, emb_layer_1_w, emb_layer_1_b); // [N, out_channels]
emb_out = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1]
// out_layers
h = ggml_add(ctx, h, emb_out);
// group norm 32
h = ggml_group_norm_inplace(ctx, h, 32);
h = ggml_add(ctx,
ggml_mul(ctx, h, ggml_reshape_4d(ctx, out_layer_0_w, 1, 1, out_layer_0_w->ne[0], 1)),
ggml_reshape_4d(ctx, out_layer_0_b, 1, 1, out_layer_0_b->ne[0], 1));
// silu
h = ggml_nn_group_norm(ctx, h, out_layer_0_w, out_layer_0_b);
h = ggml_silu_inplace(ctx, h);
// dropout, skip for inference
// conv2d
h = ggml_conv_2d(ctx, out_layer_3_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx,
h, ggml_reshape_4d(ctx, out_layer_3_b, 1, 1, out_layer_3_b->ne[0], 1)); // [N, out_channels, h, w
h = ggml_nn_conv_2d(ctx, h, out_layer_3_w, out_layer_3_b, 1, 1, 1, 1); // [N, out_channels, h, w]
// skip connection
if (out_channels != channels) {
x = ggml_conv_2d(ctx, skip_w, x, 1, 1, 0, 0, 1, 1);
x = ggml_add(ctx,
x, ggml_reshape_4d(ctx, skip_b, 1, 1, skip_b->ne[0], 1)); // [N, out_channels, h, w]
x = ggml_nn_conv_2d(ctx, x, skip_w, skip_b); // [N, out_channels, h, w]
}
h = ggml_add(ctx, h, x);
return h; // [N, out_channels, h, w]
}
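
A hedged aside on the emb_layers hunk above (illustrative helper, not part of the commit): ggml stores dimensions reversed (ne[0] is the innermost axis), so reshaping the [N, out_channels] embedding to ne = [1, 1, out_channels, N] is what lets ggml_add broadcast it across the spatial dims of h:

static struct ggml_tensor* add_channel_emb(struct ggml_context* ctx,
                                           struct ggml_tensor* h,     // ne = [W, H, C, N]
                                           struct ggml_tensor* emb) { // ne = [C, N]
    emb = ggml_reshape_4d(ctx, emb, 1, 1, emb->ne[0], emb->ne[1]); // ne = [1, 1, C, N]
    return ggml_add(ctx, h, emb); // broadcasts over W and H
}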
@@ -1479,15 +1496,9 @@ struct SpatialTransformer {
// x: [N, in_channels, h, w]
// context: [N, max_position, hidden_size(aka context_dim)]
auto x_in = x;
// group norm 32
x = ggml_group_norm_32(ctx, x);
x = ggml_add(ctx,
ggml_mul(ctx, x, ggml_reshape_4d(ctx, norm_w, 1, 1, norm_w->ne[0], 1)),
ggml_reshape_4d(ctx, norm_b, 1, 1, norm_b->ne[0], 1));
x = ggml_nn_group_norm(ctx, x, norm_w, norm_b);
// proj_in
x = ggml_conv_2d(ctx, proj_in_w, x, 1, 1, 0, 0, 1, 1);
x = ggml_add(ctx,
x, ggml_reshape_4d(ctx, proj_in_b, 1, 1, proj_in_b->ne[0], 1)); // [N, in_channels, h, w]
x = ggml_nn_conv_2d(ctx, x, proj_in_w, proj_in_b); // [N, in_channels, h, w]
// transformer
const int64_t n = x->ne[3];
@@ -1500,13 +1511,8 @@ struct SpatialTransformer {
{
auto r = x;
// layer norm 1
{
x = ggml_reshape_2d(ctx, x, c, w * h * n);
x = ggml_norm(ctx, x, EPS);
x = ggml_add(ctx,
ggml_mul(ctx, x, transformer.norm1_w),
transformer.norm1_b);
}
x = ggml_reshape_2d(ctx, x, c, w * h * n);
x = ggml_nn_layer_norm(ctx, x, transformer.norm1_w, transformer.norm1_b);
// self-attention
{
@@ -1544,7 +1550,7 @@ struct SpatialTransformer {
// x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * n_head, h * w * n));
x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n);
x = ggml_add(ctx, ggml_mul_mat(ctx, transformer.attn1_out_w, x), transformer.attn1_out_b);
x = ggml_nn_linear(ctx, x, transformer.attn1_out_w, transformer.attn1_out_b);
x = ggml_reshape_4d(ctx, x, c, w, h, n);
}
@@ -1553,11 +1559,7 @@ struct SpatialTransformer {
r = x;
// layer norm 2
{
x = ggml_norm(ctx, x, EPS);
x = ggml_add(ctx,
ggml_mul(ctx, x, transformer.norm2_w), transformer.norm2_b);
}
x = ggml_nn_layer_norm(ctx, x, transformer.norm2_w, transformer.norm2_b);
// cross-attention
{
@@ -1595,7 +1597,7 @@ struct SpatialTransformer {
// x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * n_head, h * w * n)); // [N * h * w, in_channels]
x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n); // [N * h * w, in_channels]
x = ggml_add(ctx, ggml_mul_mat(ctx, transformer.attn2_out_w, x), transformer.attn2_out_b);
x = ggml_nn_linear(ctx, x, transformer.attn2_out_w, transformer.attn2_out_b);
x = ggml_reshape_4d(ctx, x, c, w, h, n);
}
@@ -1604,13 +1606,8 @@ struct SpatialTransformer {
r = x;
// layer norm 3
{
x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels]
x = ggml_norm(ctx, x, EPS);
x = ggml_add(ctx,
ggml_mul(ctx, x, transformer.norm3_w),
transformer.norm3_b);
}
x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels]
x = ggml_nn_layer_norm(ctx, x, transformer.norm3_w, transformer.norm3_b);
// ff
{
@@ -1637,17 +1634,14 @@ struct SpatialTransformer {
transformer.ff_0_proj_b->nb[0] * transformer.ff_0_proj_b->ne[0] / 2); // [in_channels * 4, ]
x = ggml_reshape_2d(ctx, x, c, w * h * n);
auto x_in = x;
x = ggml_mul_mat(ctx, x_w, x_in); // [N * h * w, in_channels * 4]
x = ggml_add(ctx, x, x_b);
auto gate = ggml_mul_mat(ctx, gate_w, x_in); // [N * h * w, in_channels * 4]
gate = ggml_add(ctx, gate, gate_b);
x = ggml_nn_linear(ctx, x_in, x_w, x_b); // [N * h * w, in_channels * 4]
auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b); // [N * h * w, in_channels * 4]
gate = ggml_gelu_inplace(ctx, gate);
x = ggml_mul(ctx, x, gate); // [N * h * w, in_channels * 4]
// fc
x = ggml_mul_mat(ctx, transformer.ff_2_w, x); // [N * h * w, in_channels]
x = ggml_add(ctx, x, transformer.ff_2_b);
x = ggml_nn_linear(ctx, x, transformer.ff_2_w, transformer.ff_2_b); // [N * h * w, in_channels]
}
x = ggml_reshape_4d(ctx, x, c, w, h, n); // [N, h, w, in_channels]
@@ -1655,12 +1649,11 @@ struct SpatialTransformer {
// residual
x = ggml_add(ctx, x, r);
}
x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // // [N, in_channels, h, w]
x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // [N, in_channels, h, w]
// proj_out
x = ggml_conv_2d(ctx, proj_out_w, x, 1, 1, 0, 0, 1, 1);
x = ggml_add(ctx,
x, ggml_reshape_4d(ctx, proj_out_b, 1, 1, proj_out_b->ne[0], 1)); // [N, in_channels, h, w]
x = ggml_nn_conv_2d(ctx, x, proj_out_w, proj_out_b); // [N, in_channels, h, w]
x = ggml_add(ctx, x, x_in);
return x;
}
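
For reference, a hedged sketch of the permute/reshape round-trip the transformer block relies on (illustrative helper; the commit itself only touches the call sites). With ggml's reversed ne order an image tensor is ne = [W, H, C, N], so moving channels to ne[0] lets layer norm and attention treat each pixel as one token:

static struct ggml_tensor* to_tokens(struct ggml_context* ctx, struct ggml_tensor* x) {
    x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // ne: [W, H, C, N] -> [C, W, H, N]
    return ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2] * x->ne[3]); // [N * H * W, C]
}
// inverse, as in the proj_out path above: ggml_permute(ctx, x, 2, 0, 1, 3)
// after reshaping back to ne = [C, W, H, N]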
@@ -1701,17 +1694,14 @@ struct DownSample {
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [N, channels, h, w]
struct ggml_tensor* c = nullptr;
struct ggml_tensor* c = NULL;
if (vae_downsample) {
c = ggml_pad(ctx, x, 1, 1, 0, 0);
c = ggml_conv_2d(ctx, op_w, c, 2, 2, 0, 0, 1, 1);
c = ggml_nn_conv_2d(ctx, c, op_w, op_b, 2, 2, 0, 0);
} else {
c = ggml_conv_2d(ctx, op_w, x, 2, 2, 1, 1, 1, 1);
c = ggml_nn_conv_2d(ctx, x, op_w, op_b, 2, 2, 1, 1);
}
c = ggml_add(ctx,
c,
ggml_reshape_4d(ctx, op_b, 1, 1, op_b->ne[0], 1)); // [N, out_channels, h/2, w/2]
return c;
return c; // [N, out_channels, h/2, w/2]
}
};
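
The two branches above differ only in how the stride-2 conv is padded. A hedged sketch (assumes ggml_pad appends zeros on the high side of each dim, which is what lets it emulate the asymmetric pad=(0, 1, 0, 1) the VAE encoder uses in PyTorch):

static struct ggml_tensor* downsample(struct ggml_context* ctx,
                                      struct ggml_tensor* x, // [N, C, h, w]
                                      struct ggml_tensor* w,
                                      struct ggml_tensor* b,
                                      bool vae_downsample) {
    if (vae_downsample) {
        x = ggml_pad(ctx, x, 1, 1, 0, 0); // right/bottom zero pad: [N, C, h+1, w+1]
        return ggml_nn_conv_2d(ctx, x, w, b, 2, 2, 0, 0); // [N, OC, h/2, w/2]
    }
    return ggml_nn_conv_2d(ctx, x, w, b, 2, 2, 1, 1); // symmetric pad: [N, OC, h/2, w/2]
}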
@@ -1743,11 +1733,8 @@ struct UpSample {
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [N, channels, h, w]
x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2]
x = ggml_conv_2d(ctx, conv_w, x, 1, 1, 1, 1, 1, 1);
x = ggml_add(ctx,
x,
ggml_reshape_4d(ctx, conv_b, 1, 1, conv_b->ne[0], 1)); // [N, out_channels, h*2, w*2]
x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2]
x = ggml_nn_conv_2d(ctx, x, conv_w, conv_b, 1, 1, 1, 1); // [N, out_channels, h*2, w*2]
return x;
}
};
@@ -2212,14 +2199,10 @@ struct UNetModel {
}
// time_embed = nn.Sequential
// Linear
auto emb = ggml_mul_mat(ctx0, time_embed_0_w, t_emb);
emb = ggml_add(ctx0, emb, time_embed_0_b);
// nn.SiLU()
emb = ggml_silu_inplace(ctx0, emb);
// Linear
emb = ggml_mul_mat(ctx0, time_embed_2_w, emb);
emb = ggml_add(ctx0, emb, time_embed_2_b); // [N, time_embed_dim]
auto emb = ggml_nn_linear(ctx0, t_emb, time_embed_0_w, time_embed_0_b);
emb = ggml_silu_inplace(ctx0, emb);
emb = ggml_nn_linear(ctx0, emb, time_embed_2_w, time_embed_2_b); // [N, time_embed_dim]
// SDXL
// label_emd = nn.Sequential
@@ -2227,13 +2210,9 @@ struct UNetModel {
// param y: an [N] Tensor of labels, if class-conditional. (clip g)
// if(y != NULL) {
// auto y_emb = ggml_mul_mat(ctx, label_embed_0_w, y);
// y_emb = ggml_add(ctx, y_emb, label_embed_0_b);
// // nn.SiLU()
// auto y_emb = ggml_nn_linear(ctx, y, label_embed_0_w, label_embed_0_b);
// y_emb = ggml_silu_inplace(ctx, y_emb);
// // Linear
// y_emb = ggml_mul_mat(ctx, label_embed_2_w, y_emb);
// y_emb = ggml_add(ctx, y_emb, label_embed_2_b);
// y_emb = ggml_nn_linear(ctx, y_emb, label_embed_2_w, label_embed_2_b);
// emb = ggml_add(ctx, emb, y_emb);
// }
@@ -2241,11 +2220,8 @@ struct UNetModel {
std::vector<struct ggml_tensor*> hs;
// input block 0
struct ggml_tensor* h = ggml_conv_2d(ctx0, input_block_0_w, x, 1, 1, 1, 1, 1, 1); // [N, model_channels, h, w]
h = ggml_add(ctx0,
h,
ggml_reshape_4d(ctx0, input_block_0_b, 1, 1, input_block_0_b->ne[0], 1)); // [N, model_channels, h, w]
struct ggml_tensor* h = ggml_nn_conv_2d(ctx0, x, input_block_0_w, input_block_0_b, 1, 1, 1, 1); // [N, model_channels, h, w]
ggml_set_name(h, "bench-start");
hs.push_back(h);
// input block 1-11
@@ -2295,18 +2271,11 @@ struct UNetModel {
}
// out
// group norm 32
h = ggml_group_norm_32(ctx0, h);
h = ggml_add(ctx0,
ggml_mul(ctx0, h, ggml_reshape_4d(ctx0, out_0_w, 1, 1, out_0_w->ne[0], 1)),
ggml_reshape_4d(ctx0, out_0_b, 1, 1, out_0_b->ne[0], 1));
// silu
h = ggml_nn_group_norm(ctx0, h, out_0_w, out_0_b);
h = ggml_silu_inplace(ctx0, h);
// conv2d
h = ggml_conv_2d(ctx0, out_2_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx0,
h, ggml_reshape_4d(ctx0, out_2_b, 1, 1, out_2_b->ne[0], 1)); // [N, out_channels, h, w]
h = ggml_nn_conv_2d(ctx0, h, out_2_w, out_2_b, 1, 1, 1, 1); // [N, out_channels, h, w]
ggml_set_name(h, "bench-end");
return h;
}
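
The time_embed rewrite above is exactly nn.Sequential(Linear, SiLU, Linear) in helper form; a hedged restatement (illustrative function, not in the commit):

static struct ggml_tensor* time_embed(struct ggml_context* ctx,
                                      struct ggml_tensor* t_emb, // [N, model_channels]
                                      struct ggml_tensor* w0, struct ggml_tensor* b0,
                                      struct ggml_tensor* w2, struct ggml_tensor* b2) {
    struct ggml_tensor* emb = ggml_nn_linear(ctx, t_emb, w0, b0);
    emb = ggml_silu_inplace(ctx, emb);
    return ggml_nn_linear(ctx, emb, w2, b2); // [N, time_embed_dim]
}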
@@ -2503,38 +2472,19 @@ struct ResnetBlock {
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) {
// z: [N, in_channels, h, w]
// group norm 32
auto h = ggml_group_norm_32(ctx, z);
h = ggml_mul(ctx,
h, ggml_reshape_4d(ctx, norm1_w, 1, 1, norm1_w->ne[0], 1));
h = ggml_add(ctx,
h, ggml_reshape_4d(ctx, norm1_b, 1, 1, norm1_b->ne[0], 1));
// silu
h = ggml_silu_inplace(ctx, h);
// conv2d
h = ggml_conv_2d(ctx, conv1_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx,
h, ggml_reshape_4d(ctx, conv1_b, 1, 1, conv1_b->ne[0], 1)); // [N, out_channels, h, w]
// group norm 32
h = ggml_group_norm_32(ctx, h);
h = ggml_add(ctx,
ggml_mul(ctx, h, ggml_reshape_4d(ctx, norm2_w, 1, 1, norm2_w->ne[0], 1)),
ggml_reshape_4d(ctx, norm2_b, 1, 1, norm2_b->ne[0], 1));
// silu
h = ggml_silu_inplace(ctx, h);
auto h = ggml_nn_group_norm(ctx, z, norm1_w, norm1_b);
h = ggml_silu_inplace(ctx, h);
h = ggml_nn_conv_2d(ctx, h, conv1_w, conv1_b, 1, 1, 1, 1); // [N, out_channels, h, w]
h = ggml_nn_group_norm(ctx, h, norm2_w, norm2_b);
h = ggml_silu_inplace(ctx, h);
// dropout, skip for inference
// conv2d
h = ggml_conv_2d(ctx, conv2_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx,
h, ggml_reshape_4d(ctx, conv2_b, 1, 1, conv2_b->ne[0], 1)); // [N, out_channels, h, w
h = ggml_nn_conv_2d(ctx, h, conv2_w, conv2_b, 1, 1, 1, 1); // [N, out_channels, h, w]
// skip connection
if (out_channels != in_channels) {
z = ggml_conv_2d(ctx, nin_shortcut_w, z, 1, 1, 0, 0, 1, 1);
z = ggml_add(ctx,
z, ggml_reshape_4d(ctx, nin_shortcut_b, 1, 1, nin_shortcut_b->ne[0], 1)); // [N, out_channels, h, w]
z = ggml_nn_conv_2d(ctx, z, nin_shortcut_w, nin_shortcut_b); // [N, out_channels, h, w]
}
h = ggml_add(ctx, h, z);
return h; // [N, out_channels, h, w]
}
@@ -2604,30 +2554,16 @@ struct AttnBlock {
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [N, in_channels, h, w]
// group norm 32
auto h_ = ggml_group_norm_32(ctx, x);
h_ = ggml_add(ctx,
ggml_mul(ctx, h_, ggml_reshape_4d(ctx, norm_w, 1, 1, norm_w->ne[0], 1)),
ggml_reshape_4d(ctx, norm_b, 1, 1, norm_b->ne[0], 1));
auto h_ = ggml_nn_group_norm(ctx, x, norm_w, norm_b);
const int64_t n = h_->ne[3];
const int64_t c = h_->ne[2];
const int64_t h = h_->ne[1];
const int64_t w = h_->ne[0];
// q
auto q = ggml_conv_2d(ctx, q_w, h_, 1, 1, 0, 0, 1, 1);
q = ggml_add(ctx,
q, ggml_reshape_4d(ctx, q_b, 1, 1, q_b->ne[0], 1)); // [N, in_channels, h, w]
// k
auto k = ggml_conv_2d(ctx, k_w, h_, 1, 1, 0, 0, 1, 1);
k = ggml_add(ctx,
k, ggml_reshape_4d(ctx, k_b, 1, 1, k_b->ne[0], 1)); // [N, in_channels, h, w]
// v
auto v = ggml_conv_2d(ctx, v_w, h_, 1, 1, 0, 0, 1, 1);
v = ggml_add(ctx,
v, ggml_reshape_4d(ctx, v_b, 1, 1, v_b->ne[0], 1)); // [N, in_channels, h, w]
auto q = ggml_nn_conv_2d(ctx, h_, q_w, q_b); // [N, in_channels, h, w]
auto k = ggml_nn_conv_2d(ctx, h_, k_w, k_b); // [N, in_channels, h, w]
auto v = ggml_nn_conv_2d(ctx, h_, v_w, v_b); // [N, in_channels, h, w]
q = ggml_cont(ctx, ggml_permute(ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
q = ggml_reshape_3d(ctx, q, c, h * w, n); // [N, h * w, in_channels]
@@ -2645,9 +2581,8 @@ struct AttnBlock {
h_ = ggml_reshape_4d(ctx, h_, w, h, c, n); // [N, in_channels, h, w]
// proj_out
h_ = ggml_conv_2d(ctx, proj_out_w, h_, 1, 1, 0, 0, 1, 1);
h_ = ggml_add(ctx,
h_, ggml_reshape_4d(ctx, proj_out_b, 1, 1, proj_out_b->ne[0], 1)); // [N, in_channels, h, w]
h_ = ggml_nn_conv_2d(ctx, h_, proj_out_w, proj_out_b); // [N, in_channels, h, w]
h_ = ggml_add(ctx, h_, x);
return h_;
}
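
Note that the q/k/v projections above call ggml_nn_conv_2d with no stride/pad arguments: the helper's defaults (s0 = s1 = 1, p0 = p1 = 0, d0 = d1 = 1) make a 1x1-kernel conv a per-pixel linear projection. A hedged one-liner to make that explicit (illustrative name):

static struct ggml_tensor* project_1x1(struct ggml_context* ctx,
                                       struct ggml_tensor* x,   // [N, C_in, h, w]
                                       struct ggml_tensor* w,   // [C_out, C_in, 1, 1]
                                       struct ggml_tensor* b) { // [C_out,]
    return ggml_nn_conv_2d(ctx, x, w, b); // [N, C_out, h, w]
}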
@@ -2814,9 +2749,7 @@ struct Encoder {
// x: [N, in_channels, h, w]
// conv_in
auto h = ggml_conv_2d(ctx, conv_in_w, x, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx,
h, ggml_reshape_4d(ctx, conv_in_b, 1, 1, conv_in_b->ne[0], 1)); // [N, ch, h, w]
auto h = ggml_nn_conv_2d(ctx, x, conv_in_w, conv_in_b, 1, 1, 1, 1); // [N, ch, h, w]
ggml_set_name(h, "b-start");
int len_mults = sizeof(ch_mult) / sizeof(int);
for (int i = 0; i < len_mults; i++) {
@@ -2832,20 +2765,11 @@ struct Encoder {
h = mid.attn_1.forward(ctx, h);
h = mid.block_2.forward(ctx, h); // [N, block_in, h, w]
// group norm 32
h = ggml_group_norm_32(ctx, h);
h = ggml_add(ctx,
ggml_mul(ctx, h, ggml_reshape_4d(ctx, norm_out_w, 1, 1, norm_out_w->ne[0], 1)),
ggml_reshape_4d(ctx, norm_out_b, 1, 1, norm_out_b->ne[0], 1));
// silu
h = ggml_nn_group_norm(ctx, h, norm_out_w, norm_out_b);
h = ggml_silu_inplace(ctx, h);
// conv_out
h = ggml_conv_2d(ctx, conv_out_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx,
h, ggml_reshape_4d(ctx, conv_out_b, 1, 1, conv_out_b->ne[0], 1)); // [N, z_channels*2, h, w]
h = ggml_nn_conv_2d(ctx, h, conv_out_w, conv_out_b, 1, 1, 1, 1); // [N, z_channels*2, h, w]
return h;
}
@@ -3007,9 +2931,7 @@ struct Decoder {
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) {
// z: [N, z_channels, h, w]
// conv_in
auto h = ggml_conv_2d(ctx, conv_in_w, z, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx,
h, ggml_reshape_4d(ctx, conv_in_b, 1, 1, conv_in_b->ne[0], 1)); // [N, block_in, h, w]
auto h = ggml_nn_conv_2d(ctx, z, conv_in_w, conv_in_b, 1, 1, 1, 1); // [N, block_in, h, w]
h = mid.block_1.forward(ctx, h);
h = mid.attn_1.forward(ctx, h);
@@ -3026,19 +2948,11 @@ struct Decoder {
}
// group norm 32
h = ggml_group_norm_32(ctx, h);
h = ggml_add(ctx,
ggml_mul(ctx, h, ggml_reshape_4d(ctx, norm_out_w, 1, 1, norm_out_w->ne[0], 1)),
ggml_reshape_4d(ctx, norm_out_b, 1, 1, norm_out_b->ne[0], 1));
// silu
h = ggml_nn_group_norm(ctx, h, norm_out_w, norm_out_b);
h = ggml_silu_inplace(ctx, h);
// conv_out
h = ggml_conv_2d(ctx, conv_out_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx,
h, ggml_reshape_4d(ctx, conv_out_b, 1, 1, conv_out_b->ne[0], 1)); // [N, out_ch, h, w]
h = ggml_nn_conv_2d(ctx, h, conv_out_w, conv_out_b, 1, 1, 1, 1); // [N, out_ch, h, w]
return h;
}
};
@@ -3187,9 +3101,7 @@ struct AutoEncoderKL {
struct ggml_tensor* decode(struct ggml_context* ctx0, struct ggml_tensor* z) {
// z: [N, z_channels, h, w]
// post_quant_conv
auto h = ggml_conv_2d(ctx0, post_quant_conv_w, z, 1, 1, 0, 0, 1, 1);
h = ggml_add(ctx0,
h, ggml_reshape_4d(ctx0, post_quant_conv_b, 1, 1, post_quant_conv_b->ne[0], 1)); // [N, z_channels, h, w]
auto h = ggml_nn_conv_2d(ctx0, z, post_quant_conv_w, post_quant_conv_b); // [N, z_channels, h, w]
ggml_set_name(h, "bench-start");
h = decoder.forward(ctx0, h);
ggml_set_name(h, "bench-end");
@@ -3200,10 +3112,7 @@ struct AutoEncoderKL {
// x: [N, in_channels, h, w]
auto h = encoder.forward(ctx0, x); // [N, 2*z_channels, h/8, w/8]
// quant_conv
h = ggml_conv_2d(ctx0, quant_conv_w, h, 1, 1, 0, 0, 1, 1);
h = ggml_add(ctx0,
h,
ggml_reshape_4d(ctx0, quant_conv_b, 1, 1, quant_conv_b->ne[0], 1)); // [N, 2*embed_dim, h/8, w/8]
h = ggml_nn_conv_2d(ctx0, h, quant_conv_w, quant_conv_b); // [N, 2*embed_dim, h/8, w/8]
ggml_set_name(h, "b-end");
return h;
}
@@ -3367,25 +3276,16 @@ struct TAEBlock {
ggml_tensor* forward(ggml_context* ctx, ggml_tensor* x) {
// conv(n_in, n_out)
ggml_tensor* h;
h = ggml_conv_2d(ctx, conv_0_w, x, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx, h, ggml_reshape_4d(ctx, conv_0_b, 1, 1, conv_0_b->ne[0], 1));
// relu
h = ggml_nn_conv_2d(ctx, x, conv_0_w, conv_0_b, 1, 1, 1, 1);
h = ggml_relu_inplace(ctx, h);
h = ggml_conv_2d(ctx, conv_1_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx, h, ggml_reshape_4d(ctx, conv_1_b, 1, 1, conv_1_b->ne[0], 1));
// relu
h = ggml_nn_conv_2d(ctx, h, conv_1_w, conv_1_b, 1, 1, 1, 1);
h = ggml_relu_inplace(ctx, h);
h = ggml_conv_2d(ctx, conv_2_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx, h, ggml_reshape_4d(ctx, conv_2_b, 1, 1, conv_2_b->ne[0], 1));
h = ggml_nn_conv_2d(ctx, h, conv_2_w, conv_2_b, 1, 1, 1, 1);
// skip connection
if (in_channels != out_channels) {
// skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
x = ggml_conv_2d(ctx, conv_skip_w, x, 1, 1, 1, 1, 1, 1);
x = ggml_nn_conv_2d(ctx, x, conv_skip_w, NULL, 1, 1, 1, 1);
}
h = ggml_add(ctx, h, x);
@@ -3514,14 +3414,13 @@ struct TinyEncoder {
ggml_tensor* forward(ggml_context* ctx, ggml_tensor* x) {
// conv(3, 64)
auto z = ggml_conv_2d(ctx, conv_input_w, x, 1, 1, 1, 1, 1, 1);
z = ggml_add(ctx, z, ggml_reshape_4d(ctx, conv_input_b, 1, 1, conv_input_b->ne[0], 1));
auto z = ggml_nn_conv_2d(ctx, x, conv_input_w, conv_input_b, 1, 1, 1, 1);
// Block(64, 64)
z = initial_block.forward(ctx, z);
// conv(64, 64, stride=2, bias=False)
z = ggml_conv_2d(ctx, conv_1_w, z, 2, 2, 1, 1, 1, 1);
z = ggml_nn_conv_2d(ctx, z, conv_1_w, NULL, 2, 2, 1, 1);
// Block(64, 64), Block(64, 64), Block(64, 64)
for (int i = 0; i < num_blocks; i++) {
@@ -3529,7 +3428,7 @@ struct TinyEncoder {
}
// conv(64, 64, stride=2, bias=False)
z = ggml_conv_2d(ctx, conv_2_w, z, 2, 2, 1, 1, 1, 1);
z = ggml_nn_conv_2d(ctx, z, conv_2_w, NULL, 2, 2, 1, 1);
// Block(64, 64), Block(64, 64), Block(64, 64)
for (int i = 0; i < num_blocks; i++) {
@@ -3537,7 +3436,7 @@ struct TinyEncoder {
}
// conv(64, 64, stride=2, bias=False)
z = ggml_conv_2d(ctx, conv_3_w, z, 2, 2, 1, 1, 1, 1);
z = ggml_nn_conv_2d(ctx, z, conv_3_w, NULL, 2, 2, 1, 1);
// Block(64, 64), Block(64, 64), Block(64, 64)
for (int i = 0; i < num_blocks; i++) {
@@ -3545,8 +3444,7 @@ struct TinyEncoder {
}
// conv(64, 4)
z = ggml_conv_2d(ctx, conv_final_w, z, 1, 1, 1, 1, 1, 1);
z = ggml_add(ctx, z, ggml_reshape_4d(ctx, conv_final_b, 1, 1, conv_final_b->ne[0], 1));
z = ggml_nn_conv_2d(ctx, z, conv_final_w, conv_final_b, 1, 1, 1, 1);
return z;
}
};
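
The bias=False convs in this encoder simply pass NULL, relying on the b != NULL guard in ggml_nn_conv_2d to skip the add. A hedged sketch (illustrative name):

static struct ggml_tensor* strided_conv_no_bias(struct ggml_context* ctx,
                                                struct ggml_tensor* z,
                                                struct ggml_tensor* w) {
    return ggml_nn_conv_2d(ctx, z, w, NULL, 2, 2, 1, 1); // halves h and w
}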
@@ -3694,8 +3592,7 @@ struct TinyDecoder {
h = ggml_scale(ctx, h, in_scale_3);
// conv(4, 64)
h = ggml_conv_2d(ctx, conv_input_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx, h, ggml_reshape_4d(ctx, conv_input_b, 1, 1, conv_input_b->ne[0], 1));
h = ggml_nn_conv_2d(ctx, h, conv_input_w, conv_input_b, 1, 1, 1, 1);
// nn.ReLU()
h = ggml_relu_inplace(ctx, h);
@@ -3709,7 +3606,7 @@ struct TinyDecoder {
h = ggml_upscale(ctx, h, 2);
// conv(64, 64, bias=False)
h = ggml_conv_2d(ctx, conv_1_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_nn_conv_2d(ctx, h, conv_1_w, NULL, 1, 1, 1, 1);
// Block(64, 64), Block(64, 64), Block(64, 64)
for (int i = 0; i < num_blocks; i++) {
@@ -3720,7 +3617,7 @@ struct TinyDecoder {
h = ggml_upscale(ctx, h, 2);
// conv(64, 64, bias=False)
h = ggml_conv_2d(ctx, conv_2_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_nn_conv_2d(ctx, h, conv_2_w, NULL, 1, 1, 1, 1);
// Block(64, 64), Block(64, 64), Block(64, 64)
for (int i = 0; i < num_blocks; i++) {
@@ -3731,14 +3628,13 @@ struct TinyDecoder {
h = ggml_upscale(ctx, h, 2);
// conv(64, 64, bias=False)
h = ggml_conv_2d(ctx, conv_3_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_nn_conv_2d(ctx, h, conv_3_w, NULL, 1, 1, 1, 1);
// Block(64, 64)
h = final_block.forward(ctx, h);
// conv(64, 3)
h = ggml_conv_2d(ctx, conv_final_w, h, 1, 1, 1, 1, 1, 1);
h = ggml_add(ctx, h, ggml_reshape_4d(ctx, conv_final_b, 1, 1, conv_final_b->ne[0], 1));
h = ggml_nn_conv_2d(ctx, h, conv_final_w, conv_final_b, 1, 1, 1, 1);
return h;
}
};