From 1bdc767aafdcd37380f2343d416e9186a8a8da93 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 25 Aug 2024 13:53:16 +0800 Subject: [PATCH] feat: force using f32 for some layers --- flux.hpp | 8 ++++---- ggml_extend.hpp | 9 ++++++--- mmdit.hpp | 12 ++++++------ model.cpp | 12 ++++++++++-- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/flux.hpp b/flux.hpp index 84d3cad..c837a1d 100644 --- a/flux.hpp +++ b/flux.hpp @@ -13,7 +13,7 @@ namespace Flux { struct MLPEmbedder : public UnaryBlock { public: MLPEmbedder(int64_t in_dim, int64_t hidden_dim) { - blocks["in_layer"] = std::shared_ptr(new Linear(in_dim, hidden_dim, true)); + blocks["in_layer"] = std::shared_ptr(new Linear(in_dim, hidden_dim, true, true)); blocks["out_layer"] = std::shared_ptr(new Linear(hidden_dim, hidden_dim, true)); } @@ -449,7 +449,7 @@ namespace Flux { int64_t patch_size, int64_t out_channels) { blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); - blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); + blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels, true, true)); blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); } @@ -634,13 +634,13 @@ namespace Flux { int64_t out_channels = params.in_channels; int64_t pe_dim = params.hidden_size / params.num_heads; - blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size)); + blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size, true, true)); blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); blocks["vector_in"] = std::shared_ptr(new MLPEmbedder(params.vec_in_dim, params.hidden_size)); if (params.guidance_embed) { blocks["guidance_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); } - blocks["txt_in"] = std::shared_ptr(new Linear(params.context_in_dim, params.hidden_size)); + blocks["txt_in"] = std::shared_ptr(new Linear(params.context_in_dim, params.hidden_size, true, true)); for (int i = 0; i < params.depth; i++) { blocks["double_blocks." + std::to_string(i)] = std::shared_ptr(new DoubleStreamBlock(params.hidden_size, diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 19410e8..09e4fcb 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -1187,9 +1187,10 @@ protected: int64_t in_features; int64_t out_features; bool bias; + bool force_f32; void init_params(struct ggml_context* ctx, ggml_type wtype) { - if (in_features % ggml_blck_size(wtype) != 0) { + if (in_features % ggml_blck_size(wtype) != 0 || force_f32) { wtype = GGML_TYPE_F32; } params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features); @@ -1201,10 +1202,12 @@ protected: public: Linear(int64_t in_features, int64_t out_features, - bool bias = true) + bool bias = true, + bool force_f32 = false) : in_features(in_features), out_features(out_features), - bias(bias) {} + bias(bias), + force_f32(force_f32) {} struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { struct ggml_tensor* w = params["weight"]; diff --git a/mmdit.hpp b/mmdit.hpp index 0a4d831..6f3a8a0 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -101,8 +101,8 @@ public: TimestepEmbedder(int64_t hidden_size, int64_t frequency_embedding_size = 256) : frequency_embedding_size(frequency_embedding_size) { - blocks["mlp.0"] = std::shared_ptr(new Linear(frequency_embedding_size, hidden_size)); - blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, hidden_size)); + blocks["mlp.0"] = std::shared_ptr(new Linear(frequency_embedding_size, hidden_size, true, true)); + blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, hidden_size, true, true)); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* t) { @@ -125,8 +125,8 @@ struct VectorEmbedder : public GGMLBlock { public: VectorEmbedder(int64_t input_dim, int64_t hidden_size) { - blocks["mlp.0"] = std::shared_ptr(new Linear(input_dim, hidden_size)); - blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, hidden_size)); + blocks["mlp.0"] = std::shared_ptr(new Linear(input_dim, hidden_size, true, true)); + blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, hidden_size, true, true)); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { @@ -423,7 +423,7 @@ public: int64_t out_channels) { // total_out_channels is always None blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); - blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); + blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels, true, true)); blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); } @@ -510,7 +510,7 @@ public: blocks["y_embedder"] = std::shared_ptr(new VectorEmbedder(adm_in_channels, hidden_size)); } - blocks["context_embedder"] = std::shared_ptr(new Linear(4096, 1536)); + blocks["context_embedder"] = std::shared_ptr(new Linear(4096, 1536, true, true)); for (int i = 0; i < depth; i++) { blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr(new JointBlock(hidden_size, diff --git a/model.cpp b/model.cpp index f6995b9..6ca7e9b 100644 --- a/model.cpp +++ b/model.cpp @@ -1740,9 +1740,17 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage // Pass, do not convert } else if (ends_with(name, ".bias")) { // Pass, do not convert - } else if (contains(name, "img_in.") || contains(name, "time_in.in_layer.") || contains(name, "vector_in.in_layer.") || contains(name, "guidance_in.in_layer.") || contains(name, "final_layer.linear.")) { + } else if (contains(name, "img_in.") || + contains(name, "time_in.in_layer.") || + contains(name, "vector_in.in_layer.") || + contains(name, "guidance_in.in_layer.") || + contains(name, "final_layer.linear.")) { // Pass, do not convert. For FLUX - } else if (contains(name, "x_embedder.") || contains(name, "t_embedder.") || contains(name, "y_embedder.") || contains(name, "context_embedder.")) { + } else if (contains(name, "x_embedder.") || + contains(name, "t_embedder.") || + contains(name, "y_embedder.") || + contains(name, "pos_embed") || + contains(name, "context_embedder.")) { // Pass, do not convert. For MMDiT } else if (contains(name, "time_embed.") || contains(name, "label_emb.")) { // Pass, do not convert. For Unet