#ifndef __TAE_HPP__
#define __TAE_HPP__

#include "ggml_extend.hpp"

#include "model.h"

/*
    =================================== TinyAutoEncoder ===================================
    References:
    https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoder_tiny.py
    https://github.com/madebyollin/taesd/blob/main/taesd.py

*/
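
/*
    Layout (mirrors taesd.py / diffusers' AutoencoderTiny):
      TinyEncoder: conv(3, 64) -> Block(64, 64)
                   -> [conv(64, 64, stride=2, bias=False) -> 3 x Block(64, 64)] x 3
                   -> conv(64, 4)
      TinyDecoder: tanh(x / 3) * 3 -> conv(4, 64) -> ReLU
                   -> [3 x Block(64, 64) -> Upsample(x2) -> conv(64, 64, bias=False)] x 3
                   -> Block(64, 64) -> conv(64, 3)
    The encoder reduces the spatial resolution by a factor of 8; the decoder restores it.
*/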

struct TAEBlock {
    int in_channels;
    int out_channels;

    // conv
    ggml_tensor* conv_0_w;  // [out_channels, in_channels, 3, 3]
    ggml_tensor* conv_0_b;  // [out_channels]
    ggml_tensor* conv_1_w;  // [out_channels, out_channels, 3, 3]
    ggml_tensor* conv_1_b;  // [out_channels]
    ggml_tensor* conv_2_w;  // [out_channels, out_channels, 3, 3]
    ggml_tensor* conv_2_b;  // [out_channels]

    // skip
    ggml_tensor* conv_skip_w;  // [out_channels, in_channels, 1, 1]

    size_t calculate_mem_size() {
        size_t mem_size = in_channels * out_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_0_w
        mem_size += out_channels * ggml_type_size(GGML_TYPE_F32);                              // conv_0_b
        mem_size += out_channels * out_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);       // conv_1_w
        mem_size += out_channels * ggml_type_size(GGML_TYPE_F32);                              // conv_1_b
        mem_size += out_channels * out_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);       // conv_2_w
        mem_size += out_channels * ggml_type_size(GGML_TYPE_F32);                              // conv_2_b

        if (in_channels != out_channels) {
            mem_size += in_channels * out_channels * ggml_type_size(GGML_TYPE_F16);  // conv_skip_w
        }
        return mem_size;
    }

    int get_num_tensors() {
        return 6 + (in_channels != out_channels ? 1 : 0);
    }

    void init_params(ggml_context* ctx) {
        // conv.0 = Conv2d(n_in, n_out, 3, padding=1)
        conv_0_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, in_channels, out_channels);
        conv_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);

        // conv.2 = Conv2d(n_out, n_out, 3, padding=1)
        conv_1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, out_channels, out_channels);
        conv_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);

        // conv.4 = Conv2d(n_out, n_out, 3, padding=1)
        conv_2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, out_channels, out_channels);
        conv_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);

        if (in_channels != out_channels) {
            // skip = Conv2d(n_in, n_out, 1, bias=False)
            conv_skip_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, out_channels);
        }
    }

    void map_by_name(std::map<std::string, ggml_tensor*>& tensors, std::string prefix) {
        tensors[prefix + "conv.0.weight"] = conv_0_w;
        tensors[prefix + "conv.0.bias"]   = conv_0_b;

        tensors[prefix + "conv.2.weight"] = conv_1_w;
        tensors[prefix + "conv.2.bias"]   = conv_1_b;

        tensors[prefix + "conv.4.weight"] = conv_2_w;
        tensors[prefix + "conv.4.bias"]   = conv_2_b;

        if (in_channels != out_channels) {
            tensors[prefix + "skip.weight"] = conv_skip_w;
        }
    }

    ggml_tensor* forward(ggml_context* ctx, ggml_tensor* x) {
        // conv(n_in, n_out) -> ReLU -> conv(n_out, n_out) -> ReLU -> conv(n_out, n_out)
        ggml_tensor* h;
        h = ggml_nn_conv_2d(ctx, x, conv_0_w, conv_0_b, 1, 1, 1, 1);
        h = ggml_relu_inplace(ctx, h);
        h = ggml_nn_conv_2d(ctx, h, conv_1_w, conv_1_b, 1, 1, 1, 1);
        h = ggml_relu_inplace(ctx, h);
        h = ggml_nn_conv_2d(ctx, h, conv_2_w, conv_2_b, 1, 1, 1, 1);

        // skip connection
        if (in_channels != out_channels) {
            // skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
            x = ggml_nn_conv_2d(ctx, x, conv_skip_w, NULL, 1, 1, 1, 1);
        }

        h = ggml_add(ctx, h, x);
        h = ggml_relu_inplace(ctx, h);
        return h;
    }
};
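
// Name-mapping example: TinyAutoEncoder registers the decoder with the prefix "decoder.layers.",
// and TinyDecoder hands its first block the prefix "decoder.layers.2." (matching diffusers'
// AutoencoderTiny layer indexing), so map_by_name() above produces entries such as
// "decoder.layers.2.conv.0.weight", "decoder.layers.2.conv.0.bias", ..., "decoder.layers.2.conv.4.bias".
// The skip weight is only registered when in_channels != out_channels, which never happens for
// TAESD's 64 -> 64 blocks.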

struct TinyEncoder {
    int in_channels = 3;
    int z_channels  = 4;
    int channels    = 64;
    int num_blocks  = 3;

    // input
    ggml_tensor* conv_input_w;  // [channels, in_channels, 3, 3]
    ggml_tensor* conv_input_b;  // [channels]
    TAEBlock initial_block;

    ggml_tensor* conv_1_w;  // [channels, channels, 3, 3]
    TAEBlock input_blocks[3];

    // middle
    ggml_tensor* conv_2_w;  // [channels, channels, 3, 3]
    TAEBlock middle_blocks[3];

    // output
    ggml_tensor* conv_3_w;  // [channels, channels, 3, 3]
    TAEBlock output_blocks[3];

    // final
    ggml_tensor* conv_final_w;  // [z_channels, channels, 3, 3]
    ggml_tensor* conv_final_b;  // [z_channels]

    TinyEncoder() {
        for (int i = 0; i < num_blocks; i++) {
            input_blocks[i].in_channels  = channels;
            input_blocks[i].out_channels = channels;

            middle_blocks[i].in_channels  = channels;
            middle_blocks[i].out_channels = channels;

            output_blocks[i].in_channels  = channels;
            output_blocks[i].out_channels = channels;
        }

        initial_block.in_channels  = channels;
        initial_block.out_channels = channels;
    }

    size_t calculate_mem_size() {
        size_t mem_size = channels * in_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_input_w
        mem_size += channels * ggml_type_size(GGML_TYPE_F32);                              // conv_input_b

        mem_size += initial_block.calculate_mem_size();

        mem_size += channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_1_w
        mem_size += channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_2_w
        mem_size += channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_3_w

        for (int i = 0; i < num_blocks; i++) {
            mem_size += input_blocks[i].calculate_mem_size();
            mem_size += middle_blocks[i].calculate_mem_size();
            mem_size += output_blocks[i].calculate_mem_size();
        }
        mem_size += z_channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_final_w
        mem_size += z_channels * ggml_type_size(GGML_TYPE_F32);                     // conv_final_b
        return mem_size;
    }

    int get_num_tensors() {
        int num_tensors = 7;
        for (int i = 0; i < num_blocks; i++) {
            num_tensors += input_blocks[i].get_num_tensors();
            num_tensors += middle_blocks[i].get_num_tensors();
            num_tensors += output_blocks[i].get_num_tensors();
        }
        num_tensors += initial_block.get_num_tensors();
        return num_tensors;
    }

    void init_params(ggml_context* ctx) {
        conv_input_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, in_channels, channels);
        conv_input_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels);

        initial_block.init_params(ctx);

        conv_1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, channels);
        conv_2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, channels);
        conv_3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, channels);

        conv_final_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, z_channels);
        conv_final_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_channels);

        for (int i = 0; i < num_blocks; i++) {
            input_blocks[i].init_params(ctx);
            middle_blocks[i].init_params(ctx);
            output_blocks[i].init_params(ctx);
        }
    }

    void map_by_name(std::map<std::string, ggml_tensor*>& tensors, std::string prefix) {
        tensors[prefix + "0.weight"] = conv_input_w;
        tensors[prefix + "0.bias"]   = conv_input_b;

        initial_block.map_by_name(tensors, prefix + "1.");

        tensors[prefix + "2.weight"] = conv_1_w;
        for (int i = 0; i < num_blocks; i++) {
            input_blocks[i].map_by_name(tensors, prefix + std::to_string(i + 3) + ".");
        }

        tensors[prefix + "6.weight"] = conv_2_w;
        for (int i = 0; i < num_blocks; i++) {
            middle_blocks[i].map_by_name(tensors, prefix + std::to_string(i + 7) + ".");
        }

        tensors[prefix + "10.weight"] = conv_3_w;
        for (int i = 0; i < num_blocks; i++) {
            output_blocks[i].map_by_name(tensors, prefix + std::to_string(i + 11) + ".");
        }

        tensors[prefix + "14.weight"] = conv_final_w;
        tensors[prefix + "14.bias"]   = conv_final_b;
    }

    ggml_tensor* forward(ggml_context* ctx, ggml_tensor* x) {
        // conv(3, 64)
        auto z = ggml_nn_conv_2d(ctx, x, conv_input_w, conv_input_b, 1, 1, 1, 1);

        // Block(64, 64)
        z = initial_block.forward(ctx, z);

        // conv(64, 64, stride=2, bias=False)
        z = ggml_nn_conv_2d(ctx, z, conv_1_w, NULL, 2, 2, 1, 1);

        // Block(64, 64), Block(64, 64), Block(64, 64)
        for (int i = 0; i < num_blocks; i++) {
            z = input_blocks[i].forward(ctx, z);
        }

        // conv(64, 64, stride=2, bias=False)
        z = ggml_nn_conv_2d(ctx, z, conv_2_w, NULL, 2, 2, 1, 1);

        // Block(64, 64), Block(64, 64), Block(64, 64)
        for (int i = 0; i < num_blocks; i++) {
            z = middle_blocks[i].forward(ctx, z);
        }

        // conv(64, 64, stride=2, bias=False)
        z = ggml_nn_conv_2d(ctx, z, conv_3_w, NULL, 2, 2, 1, 1);

        // Block(64, 64), Block(64, 64), Block(64, 64)
        for (int i = 0; i < num_blocks; i++) {
            z = output_blocks[i].forward(ctx, z);
        }

        // conv(64, 4)
        z = ggml_nn_conv_2d(ctx, z, conv_final_w, conv_final_b, 1, 1, 1, 1);
        return z;
    }
};
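
// TinyEncoder::forward maps an image tensor of shape [W, H, 3, N] (ggml order) to a latent of
// shape [W / 8, H / 8, 4, N]: three stride-2 convolutions halve the resolution three times and
// the final conv projects the 64 feature channels down to the 4 latent channels.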

struct TinyDecoder {
    int z_channels      = 4;
    int channels        = 64;
    int output_channels = 3;
    int num_blocks      = 3;

    // input
    ggml_tensor* conv_input_w;  // [channels, z_channels, 3, 3]
    ggml_tensor* conv_input_b;  // [channels]
    TAEBlock input_blocks[3];
    ggml_tensor* conv_1_w;  // [channels, channels, 3, 3]

    // middle
    TAEBlock middle_blocks[3];
    ggml_tensor* conv_2_w;  // [channels, channels, 3, 3]

    // output
    TAEBlock output_blocks[3];
    ggml_tensor* conv_3_w;  // [channels, channels, 3, 3]

    // final
    TAEBlock final_block;
    ggml_tensor* conv_final_w;  // [output_channels, channels, 3, 3]
    ggml_tensor* conv_final_b;  // [output_channels]

    // constant scale factors for the input clamp: tanh(x / 3) * 3
    ggml_tensor* in_scale_1d3;  // [1]
    ggml_tensor* in_scale_3;    // [1]

    TinyDecoder() {
        for (int i = 0; i < num_blocks; i++) {
            input_blocks[i].in_channels  = channels;
            input_blocks[i].out_channels = channels;

            middle_blocks[i].in_channels  = channels;
            middle_blocks[i].out_channels = channels;

            output_blocks[i].in_channels  = channels;
            output_blocks[i].out_channels = channels;
        }

        final_block.in_channels  = channels;
        final_block.out_channels = channels;
    }

    size_t calculate_mem_size() {
        size_t mem_size = channels * z_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_input_w
        mem_size += channels * ggml_type_size(GGML_TYPE_F32);                             // conv_input_b

        for (int i = 0; i < num_blocks; i++) {
            mem_size += input_blocks[i].calculate_mem_size();
        }
        mem_size += channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_1_w

        for (int i = 0; i < num_blocks; i++) {
            mem_size += middle_blocks[i].calculate_mem_size();
        }
        mem_size += channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_2_w

        for (int i = 0; i < num_blocks; i++) {
            mem_size += output_blocks[i].calculate_mem_size();
        }
        mem_size += channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_3_w

        mem_size += final_block.calculate_mem_size();
        mem_size += output_channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_final_w
        mem_size += output_channels * ggml_type_size(GGML_TYPE_F32);                     // conv_final_b
        return mem_size;
    }

    int get_num_tensors() {
        int num_tensors = 9;
        for (int i = 0; i < num_blocks; i++) {
            num_tensors += input_blocks[i].get_num_tensors();
            num_tensors += middle_blocks[i].get_num_tensors();
            num_tensors += output_blocks[i].get_num_tensors();
        }
        num_tensors += final_block.get_num_tensors();
        return num_tensors;
    }

    void init_params(ggml_allocr* alloc, ggml_context* ctx) {
        conv_input_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, z_channels, channels);
        conv_input_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels);

        conv_1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, channels);
        conv_2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, channels);
        conv_3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, channels);

        conv_final_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, output_channels);
        conv_final_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, output_channels);

        for (int i = 0; i < num_blocks; i++) {
            input_blocks[i].init_params(ctx);
            middle_blocks[i].init_params(ctx);
            output_blocks[i].init_params(ctx);
        }

        final_block.init_params(ctx);

        // initialize the constant scale factors
        in_scale_1d3 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
        in_scale_3   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);

        ggml_allocr_alloc(alloc, in_scale_1d3);
        float scale_1d3 = 1.0f / 3.0f;
        ggml_backend_tensor_set(in_scale_1d3, &scale_1d3, 0, sizeof(scale_1d3));

        ggml_allocr_alloc(alloc, in_scale_3);
        float scale_3 = 3.0f;
        ggml_backend_tensor_set(in_scale_3, &scale_3, 0, sizeof(scale_3));
    }

    void map_by_name(std::map<std::string, ggml_tensor*>& tensors, std::string prefix) {
        tensors[prefix + "0.weight"] = conv_input_w;
        tensors[prefix + "0.bias"]   = conv_input_b;

        for (int i = 0; i < num_blocks; i++) {
            input_blocks[i].map_by_name(tensors, prefix + std::to_string(i + 2) + ".");
        }

        tensors[prefix + "6.weight"] = conv_1_w;
        for (int i = 0; i < num_blocks; i++) {
            middle_blocks[i].map_by_name(tensors, prefix + std::to_string(i + 7) + ".");
        }

        tensors[prefix + "11.weight"] = conv_2_w;
        for (int i = 0; i < num_blocks; i++) {
            output_blocks[i].map_by_name(tensors, prefix + std::to_string(i + 12) + ".");
        }

        tensors[prefix + "16.weight"] = conv_3_w;

        final_block.map_by_name(tensors, prefix + "17.");

        tensors[prefix + "18.weight"] = conv_final_w;
        tensors[prefix + "18.bias"]   = conv_final_b;
    }

    ggml_tensor* forward(ggml_context* ctx, ggml_tensor* z) {
        // clamp the latents: torch.tanh(x / 3) * 3
        auto h = ggml_scale(ctx, z, in_scale_1d3);
        h      = ggml_tanh_inplace(ctx, h);
        h      = ggml_scale(ctx, h, in_scale_3);

        // conv(4, 64)
        h = ggml_nn_conv_2d(ctx, h, conv_input_w, conv_input_b, 1, 1, 1, 1);

        // nn.ReLU()
        h = ggml_relu_inplace(ctx, h);

        // Block(64, 64), Block(64, 64), Block(64, 64)
        for (int i = 0; i < num_blocks; i++) {
            h = input_blocks[i].forward(ctx, h);
        }

        // nn.Upsample(scale_factor=2)
        h = ggml_upscale(ctx, h, 2);

        // conv(64, 64, bias=False)
        h = ggml_nn_conv_2d(ctx, h, conv_1_w, NULL, 1, 1, 1, 1);

        // Block(64, 64), Block(64, 64), Block(64, 64)
        for (int i = 0; i < num_blocks; i++) {
            h = middle_blocks[i].forward(ctx, h);
        }

        // nn.Upsample(scale_factor=2)
        h = ggml_upscale(ctx, h, 2);

        // conv(64, 64, bias=False)
        h = ggml_nn_conv_2d(ctx, h, conv_2_w, NULL, 1, 1, 1, 1);

        // Block(64, 64), Block(64, 64), Block(64, 64)
        for (int i = 0; i < num_blocks; i++) {
            h = output_blocks[i].forward(ctx, h);
        }

        // nn.Upsample(scale_factor=2)
        h = ggml_upscale(ctx, h, 2);

        // conv(64, 64, bias=False)
        h = ggml_nn_conv_2d(ctx, h, conv_3_w, NULL, 1, 1, 1, 1);

        // Block(64, 64)
        h = final_block.forward(ctx, h);

        // conv(64, 3)
        h = ggml_nn_conv_2d(ctx, h, conv_final_w, conv_final_b, 1, 1, 1, 1);
        return h;
    }
};
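
// TinyDecoder::forward is the mirror image of the encoder: a latent of shape [w, h, 4, N] is
// clamped with tanh(x / 3) * 3 and brought back to an image tensor of shape [8 * w, 8 * h, 3, N]
// by the three Upsample(scale_factor=2) stages.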

struct TinyAutoEncoder : public GGMLModule {
    TinyEncoder encoder;
    TinyDecoder decoder;
    bool decode_only = false;

    TinyAutoEncoder(bool decode_only_ = true)
        : decode_only(decode_only_) {
        name = "tae";
    }

    size_t calculate_mem_size() {
        size_t mem_size = decoder.calculate_mem_size();
        if (!decode_only) {
            mem_size += encoder.calculate_mem_size();
        }
        mem_size += 1024;  // padding
        return mem_size;
    }

    size_t get_num_tensors() {
        size_t num_tensors = decoder.get_num_tensors();
        if (!decode_only) {
            num_tensors += encoder.get_num_tensors();
        }
        return num_tensors;
    }

    void init_params() {
        ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer);
        decoder.init_params(alloc, params_ctx);
        if (!decode_only) {
            encoder.init_params(params_ctx);
        }

        // allocate all tensors linked to this context
        for (struct ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != NULL; t = ggml_get_next_tensor(params_ctx, t)) {
            if (t->data == NULL) {
                ggml_allocr_alloc(alloc, t);
            }
        }
        ggml_allocr_free(alloc);
    }

    void map_by_name(std::map<std::string, ggml_tensor*>& tensors) {
        decoder.map_by_name(tensors, "decoder.layers.");
        encoder.map_by_name(tensors, "encoder.layers.");
    }

    bool load_from_file(const std::string& file_path, ggml_backend_t backend) {
        LOG_INFO("loading taesd from '%s'", file_path.c_str());

        if (!alloc_params_buffer(backend)) {
            return false;
        }

        std::map<std::string, ggml_tensor*> taesd_tensors;

        // prepare memory for the weights
        {
            init_params();
            map_by_name(taesd_tensors);
        }

        std::map<std::string, struct ggml_tensor*> tensors_need_to_load;
        std::set<std::string> ignore_tensors;
        for (auto& pair : taesd_tensors) {
            const std::string& name = pair.first;

            // in decode-only mode the encoder weights are skipped entirely
            if (decode_only && starts_with(name, "encoder")) {
                ignore_tensors.insert(name);
                continue;
            }

            tensors_need_to_load.insert(pair);
        }

        ModelLoader model_loader;
        if (!model_loader.init_from_file(file_path)) {
            LOG_ERROR("init taesd model loader from file failed: '%s'", file_path.c_str());
            return false;
        }

        bool success = model_loader.load_tensors(tensors_need_to_load, backend, ignore_tensors);

        if (!success) {
            LOG_ERROR("load tae tensors from model loader failed");
            return false;
        }

        LOG_INFO("taesd model loaded");
        return success;
    }

    struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
        // since we are using ggml-alloc, this buffer only needs enough space to hold the
        // ggml_tensor and ggml_cgraph structs, but not the tensor data
        static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
        static std::vector<uint8_t> buf(buf_size);

        struct ggml_init_params params = {
            /*.mem_size   =*/buf_size,
            /*.mem_buffer =*/buf.data(),
            /*.no_alloc   =*/true,  // the tensors will be allocated later by ggml_allocr_alloc_graph()
        };
        // LOG_DEBUG("mem_size %u ", params.mem_size);

        struct ggml_context* ctx0 = ggml_init(params);

        struct ggml_cgraph* gf = ggml_new_graph(ctx0);

        struct ggml_tensor* z_ = NULL;

        // if the backend is not the CPU, the input tensor has to be copied into backend memory
        if (!ggml_backend_is_cpu(backend)) {
            z_ = ggml_dup_tensor(ctx0, z);
            ggml_allocr_alloc(compute_allocr, z_);

            // pass the data to the device backend (skipped during the measure pass)
            if (!ggml_allocr_is_measure(compute_allocr)) {
                ggml_backend_tensor_set(z_, z->data, 0, ggml_nbytes(z));
            }
        } else {
            z_ = z;
        }

        struct ggml_tensor* out = decode_graph ? decoder.forward(ctx0, z_) : encoder.forward(ctx0, z_);

        ggml_build_forward_expand(gf, out);
        ggml_free(ctx0);

        return gf;
    }

    void alloc_compute_buffer(struct ggml_tensor* x, bool decode) {
        // build the graph once so GGMLModule can measure and allocate the compute buffer
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(x, decode);
        };
        GGMLModule::alloc_compute_buffer(get_graph);
    }

    void compute(struct ggml_tensor* work_result, int n_threads, struct ggml_tensor* z, bool decode_graph) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(z, decode_graph);
        };
        GGMLModule::compute(get_graph, n_threads, work_result);
    }
};
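
/*
    Usage sketch (illustrative only): this assumes the surrounding stable-diffusion.cpp plumbing,
    i.e. an initialized ggml_backend_t `backend`, a latent tensor `latent` of shape
    [W / 8, H / 8, 4, 1] and a preallocated F32 result tensor `result` of shape [W, H, 3, 1];
    the variable names and the checkpoint path are made up for the example.

        TinyAutoEncoder taesd(true);  // decode-only: encoder weights are ignored
        if (!taesd.load_from_file("taesd.safetensors", backend)) {
            return;
        }
        taesd.alloc_compute_buffer(latent, true);        // measure + allocate the compute buffer
        taesd.compute(result, n_threads, latent, true);  // run the decoder graph into `result`
*/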

#endif  // __TAE_HPP__