feat: support 16 channel tae (taesd/taef1) (#527)

This commit is contained in:
stduhpf 2024-12-28 06:13:48 +01:00 committed by GitHub
parent b5cc1422da
commit d50473dc49
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 8 deletions

View File

@ -360,7 +360,7 @@ public:
first_stage_model->alloc_params_buffer(); first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model"); first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else { } else {
tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only); tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only, version);
} }
// first_stage_model->get_param_tensors(tensors, "first_stage_model."); // first_stage_model->get_param_tensors(tensors, "first_stage_model.");

23
tae.hpp
View File

@ -62,7 +62,8 @@ class TinyEncoder : public UnaryBlock {
int num_blocks = 3; int num_blocks = 3;
public: public:
TinyEncoder() { TinyEncoder(int z_channels = 4)
: z_channels(z_channels) {
int index = 0; int index = 0;
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1})); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels)); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
@ -106,7 +107,10 @@ class TinyDecoder : public UnaryBlock {
int num_blocks = 3; int num_blocks = 3;
public: public:
TinyDecoder(int index = 0) { TinyDecoder(int z_channels = 4)
: z_channels(z_channels) {
int index = 0;
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1})); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1}));
index++; // nn.ReLU() index++; // nn.ReLU()
@ -163,12 +167,16 @@ protected:
bool decode_only; bool decode_only;
public: public:
TAESD(bool decode_only = true) TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
: decode_only(decode_only) { : decode_only(decode_only) {
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder()); int z_channels = 4;
if (sd_version_is_dit(version)) {
z_channels = 16;
}
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels));
if (!decode_only) { if (!decode_only) {
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder()); blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels));
} }
} }
@ -190,9 +198,10 @@ struct TinyAutoEncoder : public GGMLRunner {
TinyAutoEncoder(ggml_backend_t backend, TinyAutoEncoder(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types, std::map<std::string, enum ggml_type>& tensor_types,
const std::string prefix, const std::string prefix,
bool decoder_only = true) bool decoder_only = true,
SDVersion version = VERSION_SD1)
: decode_only(decoder_only), : decode_only(decoder_only),
taesd(decode_only), taesd(decode_only, version),
GGMLRunner(backend) { GGMLRunner(backend) {
taesd.init(params_ctx, tensor_types, prefix); taesd.init(params_ctx, tensor_types, prefix);
} }