From d50473dc49e88dc692f20bccd52a84fab374b556 Mon Sep 17 00:00:00 2001 From: stduhpf Date: Sat, 28 Dec 2024 06:13:48 +0100 Subject: [PATCH] feat: support 16 channel tae (taesd/taef1) (#527) --- stable-diffusion.cpp | 2 +- tae.hpp | 23 ++++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 7025df8..35d49af 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -360,7 +360,7 @@ public: first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); } else { - tae_first_stage = std::make_shared(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only); + tae_first_stage = std::make_shared(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only, version); } // first_stage_model->get_param_tensors(tensors, "first_stage_model."); diff --git a/tae.hpp b/tae.hpp index fee5e83..6830598 100644 --- a/tae.hpp +++ b/tae.hpp @@ -62,7 +62,8 @@ class TinyEncoder : public UnaryBlock { int num_blocks = 3; public: - TinyEncoder() { + TinyEncoder(int z_channels = 4) + : z_channels(z_channels) { int index = 0; blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1})); blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); @@ -106,7 +107,10 @@ class TinyDecoder : public UnaryBlock { int num_blocks = 3; public: - TinyDecoder(int index = 0) { + TinyDecoder(int z_channels = 4) + : z_channels(z_channels) { + int index = 0; + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1})); index++; // nn.ReLU() @@ -163,12 +167,16 @@ protected: bool decode_only; public: - TAESD(bool decode_only = true) + TAESD(bool decode_only = true, SDVersion version = VERSION_SD1) : decode_only(decode_only) { - blocks["decoder.layers"] = std::shared_ptr(new TinyDecoder()); + int z_channels = 4; + if (sd_version_is_dit(version)) { + z_channels = 16; + } + blocks["decoder.layers"] = std::shared_ptr(new TinyDecoder(z_channels)); if (!decode_only) { - blocks["encoder.layers"] = std::shared_ptr(new TinyEncoder()); + blocks["encoder.layers"] = std::shared_ptr(new TinyEncoder(z_channels)); } } @@ -190,9 +198,10 @@ struct TinyAutoEncoder : public GGMLRunner { TinyAutoEncoder(ggml_backend_t backend, std::map& tensor_types, const std::string prefix, - bool decoder_only = true) + bool decoder_only = true, + SDVersion version = VERSION_SD1) : decode_only(decoder_only), - taesd(decode_only), + taesd(decode_only, version), GGMLRunner(backend) { taesd.init(params_ctx, tensor_types, prefix); }