feat: support 16 channel tae (taesd/taef1) (#527)
This commit is contained in:
parent
b5cc1422da
commit
d50473dc49
@ -360,7 +360,7 @@ public:
|
|||||||
first_stage_model->alloc_params_buffer();
|
first_stage_model->alloc_params_buffer();
|
||||||
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
||||||
} else {
|
} else {
|
||||||
tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only);
|
tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only, version);
|
||||||
}
|
}
|
||||||
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");
|
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");
|
||||||
|
|
||||||
|
23
tae.hpp
23
tae.hpp
@ -62,7 +62,8 @@ class TinyEncoder : public UnaryBlock {
|
|||||||
int num_blocks = 3;
|
int num_blocks = 3;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TinyEncoder() {
|
TinyEncoder(int z_channels = 4)
|
||||||
|
: z_channels(z_channels) {
|
||||||
int index = 0;
|
int index = 0;
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
|
||||||
@ -106,7 +107,10 @@ class TinyDecoder : public UnaryBlock {
|
|||||||
int num_blocks = 3;
|
int num_blocks = 3;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TinyDecoder(int index = 0) {
|
TinyDecoder(int z_channels = 4)
|
||||||
|
: z_channels(z_channels) {
|
||||||
|
int index = 0;
|
||||||
|
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1}));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1}));
|
||||||
index++; // nn.ReLU()
|
index++; // nn.ReLU()
|
||||||
|
|
||||||
@ -163,12 +167,16 @@ protected:
|
|||||||
bool decode_only;
|
bool decode_only;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TAESD(bool decode_only = true)
|
TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
|
||||||
: decode_only(decode_only) {
|
: decode_only(decode_only) {
|
||||||
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder());
|
int z_channels = 4;
|
||||||
|
if (sd_version_is_dit(version)) {
|
||||||
|
z_channels = 16;
|
||||||
|
}
|
||||||
|
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels));
|
||||||
|
|
||||||
if (!decode_only) {
|
if (!decode_only) {
|
||||||
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder());
|
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,9 +198,10 @@ struct TinyAutoEncoder : public GGMLRunner {
|
|||||||
TinyAutoEncoder(ggml_backend_t backend,
|
TinyAutoEncoder(ggml_backend_t backend,
|
||||||
std::map<std::string, enum ggml_type>& tensor_types,
|
std::map<std::string, enum ggml_type>& tensor_types,
|
||||||
const std::string prefix,
|
const std::string prefix,
|
||||||
bool decoder_only = true)
|
bool decoder_only = true,
|
||||||
|
SDVersion version = VERSION_SD1)
|
||||||
: decode_only(decoder_only),
|
: decode_only(decoder_only),
|
||||||
taesd(decode_only),
|
taesd(decode_only, version),
|
||||||
GGMLRunner(backend) {
|
GGMLRunner(backend) {
|
||||||
taesd.init(params_ctx, tensor_types, prefix);
|
taesd.init(params_ctx, tensor_types, prefix);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user