diff --git a/conditioner.hpp b/conditioner.hpp index 0e8f5a3..43d0a6d 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -10,8 +10,8 @@ struct SDCondition { struct ggml_tensor* c_concat = NULL; SDCondition() = default; - SDCondition(struct ggml_tensor* c_crossattn, struct ggml_tensor* c_vector, struct ggml_tensor* c_concat) : - c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {} + SDCondition(struct ggml_tensor* c_crossattn, struct ggml_tensor* c_vector, struct ggml_tensor* c_concat) + : c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {} }; struct Conditioner { @@ -978,7 +978,6 @@ struct SD3CLIPEmbedder : public Conditioner { } }; - struct FluxCLIPEmbedder : public Conditioner { ggml_type wtype; CLIPTokenizer clip_l_tokenizer; @@ -987,8 +986,8 @@ struct FluxCLIPEmbedder : public Conditioner { std::shared_ptr t5; FluxCLIPEmbedder(ggml_backend_t backend, - ggml_type wtype, - int clip_skip = -1) + ggml_type wtype, + int clip_skip = -1) : wtype(wtype) { if (clip_skip <= 0) { clip_skip = 2; @@ -1085,10 +1084,10 @@ struct FluxCLIPEmbedder : public Conditioner { auto& t5_tokens = token_and_weights[1].first; auto& t5_weights = token_and_weights[1].second; - int64_t t0 = ggml_time_ms(); - struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096] - struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096] - struct ggml_tensor* pooled = NULL; // [768,] + int64_t t0 = ggml_time_ms(); + struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096] + struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096] + struct ggml_tensor* pooled = NULL; // [768,] std::vector hidden_states_vec; size_t chunk_len = 256; diff --git a/denoiser.hpp b/denoiser.hpp index 85e4a0b..5d4cb32 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -351,7 +351,6 @@ struct DiscreteFlowDenoiser : public Denoiser { } }; - float flux_time_shift(float mu, float sigma, float t) { return std::exp(mu) / (std::exp(mu) + std::pow((1.0 / t - 1.0), sigma)); } @@ -369,7 +368,7 @@ struct FluxFlowDenoiser : public Denoiser { void set_parameters(float shift = 1.15f) { this->shift = shift; for (int i = 1; i < TIMESTEPS + 1; i++) { - sigmas[i - 1] = t_to_sigma(i/TIMESTEPS * TIMESTEPS); + sigmas[i - 1] = t_to_sigma(i / TIMESTEPS * TIMESTEPS); } } diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 5c214e1..2530f71 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -1,9 +1,9 @@ #ifndef __DIFFUSION_MODEL_H__ #define __DIFFUSION_MODEL_H__ +#include "flux.hpp" #include "mmdit.hpp" #include "unet.hpp" -#include "flux.hpp" struct DiffusionModel { virtual void compute(int n_threads, @@ -124,13 +124,12 @@ struct MMDiTModel : public DiffusionModel { } }; - struct FluxModel : public DiffusionModel { Flux::FluxRunner flux; FluxModel(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_FLUX_DEV) + ggml_type wtype, + SDVersion version = VERSION_FLUX_DEV) : flux(backend, wtype, version) { } diff --git a/flux.hpp b/flux.hpp index 3b398b4..84d3cad 100644 --- a/flux.hpp +++ b/flux.hpp @@ -10,956 +10,952 @@ namespace Flux { -struct MLPEmbedder : public UnaryBlock { -public: - MLPEmbedder(int64_t in_dim, int64_t hidden_dim) { - blocks["in_layer"] = std::shared_ptr(new Linear(in_dim, hidden_dim, true)); - blocks["out_layer"] = std::shared_ptr(new Linear(hidden_dim, hidden_dim, true)); + struct MLPEmbedder : public UnaryBlock { + public: + MLPEmbedder(int64_t in_dim, int64_t hidden_dim) { + blocks["in_layer"] = std::shared_ptr(new Linear(in_dim, hidden_dim, true)); + 
blocks["out_layer"] = std::shared_ptr(new Linear(hidden_dim, hidden_dim, true)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + // x: [..., in_dim] + // return: [..., hidden_dim] + auto in_layer = std::dynamic_pointer_cast(blocks["in_layer"]); + auto out_layer = std::dynamic_pointer_cast(blocks["out_layer"]); + + x = in_layer->forward(ctx, x); + x = ggml_silu_inplace(ctx, x); + x = out_layer->forward(ctx, x); + return x; + } + }; + + class RMSNorm : public UnaryBlock { + protected: + int64_t hidden_size; + float eps; + + void init_params(struct ggml_context* ctx, ggml_type wtype) { + params["scale"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + } + + public: + RMSNorm(int64_t hidden_size, + float eps = 1e-06f) + : hidden_size(hidden_size), + eps(eps) {} + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* w = params["scale"]; + x = ggml_rms_norm(ctx, x, eps); + x = ggml_mul(ctx, x, w); + return x; + } + }; + + struct QKNorm : public GGMLBlock { + public: + QKNorm(int64_t dim) { + blocks["query_norm"] = std::shared_ptr(new RMSNorm(dim)); + blocks["key_norm"] = std::shared_ptr(new RMSNorm(dim)); + } + + struct ggml_tensor* query_norm(struct ggml_context* ctx, struct ggml_tensor* x) { + // x: [..., dim] + // return: [..., dim] + auto norm = std::dynamic_pointer_cast(blocks["query_norm"]); + + x = norm->forward(ctx, x); + return x; + } + + struct ggml_tensor* key_norm(struct ggml_context* ctx, struct ggml_tensor* x) { + // x: [..., dim] + // return: [..., dim] + auto norm = std::dynamic_pointer_cast(blocks["key_norm"]); + + x = norm->forward(ctx, x); + return x; + } + }; + + __STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* pe) { + // x: [N, L, n_head, d_head] + // pe: [L, d_head/2, 2, 2] + int64_t d_head = x->ne[0]; + int64_t n_head = x->ne[1]; + int64_t L = x->ne[2]; + int64_t N = x->ne[3]; + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head] + x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2] + x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2] + + int64_t offset = x->nb[2] * x->ne[2]; + auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2] + auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2] + x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1] + x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1] + auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]); + x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2] + x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2] + + pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2] + offset = pe->nb[2] * pe->ne[2]; + auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2] + auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2] + + auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2] + x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N); // 
[N*n_head, L, d_head] + return x_out; } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { - // x: [..., in_dim] - // return: [..., hidden_dim] - auto in_layer = std::dynamic_pointer_cast(blocks["in_layer"]); - auto out_layer = std::dynamic_pointer_cast(blocks["out_layer"]); + __STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx, + struct ggml_tensor* q, + struct ggml_tensor* k, + struct ggml_tensor* v, + struct ggml_tensor* pe) { + // q,k,v: [N, L, n_head, d_head] + // pe: [L, d_head/2, 2, 2] + // return: [N, L, n_head*d_head] + q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head] + k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head] - x = in_layer->forward(ctx, x); - x = ggml_silu_inplace(ctx, x); - x = out_layer->forward(ctx, x); - return x; - } -}; - -class RMSNorm : public UnaryBlock { -protected: - int64_t hidden_size; - float eps; - - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["scale"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); - } - -public: - RMSNorm(int64_t hidden_size, - float eps = 1e-06f) - : hidden_size(hidden_size), - eps(eps) {} - - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { - struct ggml_tensor* w = params["scale"]; - x = ggml_rms_norm(ctx, x, eps); - x = ggml_mul(ctx, x, w); - return x; - } -}; - - -struct QKNorm : public GGMLBlock { -public: - QKNorm(int64_t dim) { - blocks["query_norm"] = std::shared_ptr(new RMSNorm(dim)); - blocks["key_norm"] = std::shared_ptr(new RMSNorm(dim)); - } - - struct ggml_tensor* query_norm(struct ggml_context* ctx, struct ggml_tensor* x) { - // x: [..., dim] - // return: [..., dim] - auto norm = std::dynamic_pointer_cast(blocks["query_norm"]); - - x = norm->forward(ctx, x); + auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true); // [N, L, n_head*d_head] return x; } - struct ggml_tensor* key_norm(struct ggml_context* ctx, struct ggml_tensor* x) { - // x: [..., dim] - // return: [..., dim] - auto norm = std::dynamic_pointer_cast(blocks["key_norm"]); + struct SelfAttention : public GGMLBlock { + public: + int64_t num_heads; - x = norm->forward(ctx, x); - return x; - } -}; + public: + SelfAttention(int64_t dim, + int64_t num_heads = 8, + bool qkv_bias = false) + : num_heads(num_heads) { + int64_t head_dim = dim / num_heads; + blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); + blocks["proj"] = std::shared_ptr(new Linear(dim, dim)); + } -__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx, - struct ggml_tensor* x, - struct ggml_tensor* pe) { - // x: [N, L, n_head, d_head] - // pe: [L, d_head/2, 2, 2] - int64_t d_head = x->ne[0]; - int64_t n_head = x->ne[1]; - int64_t L = x->ne[2]; - int64_t N = x->ne[3]; - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head] - x = ggml_reshape_4d(ctx, x, 2, d_head/2, L, n_head * N); // [N * n_head, L, d_head/2, 2] - x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2] + std::vector pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) { + auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); - int64_t offset = x->nb[2] * x->ne[2]; - auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2] - auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset 
* 1); // [N * n_head, L, d_head/2] - x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1] - x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1] - auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]); - x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2] - x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2] + auto qkv = qkv_proj->forward(ctx, x); + auto qkv_vec = split_qkv(ctx, qkv); + int64_t head_dim = qkv_vec[0]->ne[0] / num_heads; + auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); + auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); + auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); + q = norm->query_norm(ctx, q); + k = norm->key_norm(ctx, k); + return {q, k, v}; + } - pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2] - offset = pe->nb[2] * pe->ne[2]; - auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2] - auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2] + struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) { + auto proj = std::dynamic_pointer_cast(blocks["proj"]); - auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2] - x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head*N); // [N*n_head, L, d_head] - return x_out; -} + x = proj->forward(ctx, x); // [N, n_token, dim] + return x; + } -__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx, - struct ggml_tensor* q, - struct ggml_tensor* k, - struct ggml_tensor* v, - struct ggml_tensor* pe) { - // q,k,v: [N, L, n_head, d_head] - // pe: [L, d_head/2, 2, 2] - // return: [N, L, n_head*d_head] - q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head] - k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head] + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe) { + // x: [N, n_token, dim] + // pe: [n_token, d_head/2, 2, 2] + // return [N, n_token, dim] + auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] + x = attention(ctx, qkv[0], qkv[1], qkv[2], pe); // [N, n_token, dim] + x = post_attention(ctx, x); // [N, n_token, dim] + return x; + } + }; - auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true); // [N, L, n_head*d_head] - return x; -} + struct ModulationOut { + ggml_tensor* shift = NULL; + ggml_tensor* scale = NULL; + ggml_tensor* gate = NULL; -struct SelfAttention : public GGMLBlock { -public: - int64_t num_heads; + ModulationOut(ggml_tensor* shift = NULL, ggml_tensor* scale = NULL, ggml_tensor* gate = NULL) + : shift(shift), scale(scale), gate(gate) {} + }; -public: - SelfAttention(int64_t dim, - int64_t num_heads = 8, - bool qkv_bias = false) - : num_heads(num_heads) { - int64_t head_dim = dim / num_heads; - blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); - blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); - blocks["proj"] = std::shared_ptr(new Linear(dim, dim)); - } + struct Modulation : public GGMLBlock { + public: + bool is_double; + int multiplier; - std::vector pre_attention(struct ggml_context* ctx, struct 
ggml_tensor* x) { - auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); - auto norm = std::dynamic_pointer_cast(blocks["norm"]); + public: + Modulation(int64_t dim, bool is_double) + : is_double(is_double) { + multiplier = is_double ? 6 : 3; + blocks["lin"] = std::shared_ptr(new Linear(dim, dim * multiplier)); + } + std::vector forward(struct ggml_context* ctx, struct ggml_tensor* vec) { + // x: [N, dim] + // return: [ModulationOut, ModulationOut] + auto lin = std::dynamic_pointer_cast(blocks["lin"]); - auto qkv = qkv_proj->forward(ctx, x); - auto qkv_vec = split_qkv(ctx, qkv); - int64_t head_dim = qkv_vec[0]->ne[0] / num_heads; - auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); - auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); - auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); - q = norm->query_norm(ctx, q); - k = norm->key_norm(ctx, k); - return {q, k, v}; - } + auto out = ggml_silu(ctx, vec); + out = lin->forward(ctx, out); // [N, multiplier*dim] - struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) { - auto proj = std::dynamic_pointer_cast(blocks["proj"]); + auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim] + m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim] - x = proj->forward(ctx, x); // [N, n_token, dim] + int64_t offset = m->nb[1] * m->ne[1]; + auto shift_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, dim] + auto scale_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, dim] + auto gate_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, dim] + + if (is_double) { + auto shift_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, dim] + auto scale_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, dim] + auto gate_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, dim] + return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut(shift_1, scale_1, gate_1)}; + } + + return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut()}; + } + }; + + __STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* shift, + struct ggml_tensor* scale) { + // x: [N, L, C] + // scale: [N, C] + // shift: [N, C] + scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C] + shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]); // [N, 1, C] + x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); + x = ggml_add(ctx, x, shift); return x; } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe) { - // x: [N, n_token, dim] - // pe: [n_token, d_head/2, 2, 2] - // return [N, n_token, dim] - auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] - x = attention(ctx, qkv[0], qkv[1], qkv[2], pe); // [N, n_token, dim] - x = post_attention(ctx, x); // [N, n_token, dim] - return x; - } -}; + struct DoubleStreamBlock : public GGMLBlock { + public: + DoubleStreamBlock(int64_t hidden_size, + int64_t num_heads, + float mlp_ratio, + bool qkv_bias = false) { + int64_t mlp_hidden_dim = hidden_size * mlp_ratio; + blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, 
false)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); + blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["img_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); + // img_mlp.1 is nn.GELU(approximate="tanh") + blocks["img_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); -struct ModulationOut { - ggml_tensor* shift = NULL; - ggml_tensor* scale = NULL; - ggml_tensor* gate = NULL; + blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); - ModulationOut(ggml_tensor* shift = NULL, ggml_tensor* scale = NULL, ggml_tensor* gate = NULL) - : shift(shift), scale(scale), gate(gate) {} -}; - -struct Modulation : public GGMLBlock { -public: - bool is_double; - int multiplier; -public: - Modulation(int64_t dim, bool is_double): is_double(is_double) { - multiplier = is_double? 6 : 3; - blocks["lin"] = std::shared_ptr(new Linear(dim, dim * multiplier)); - } - - std::vector forward(struct ggml_context* ctx, struct ggml_tensor* vec) { - // x: [N, dim] - // return: [ModulationOut, ModulationOut] - auto lin = std::dynamic_pointer_cast(blocks["lin"]); - - auto out = ggml_silu(ctx, vec); - out = lin->forward(ctx, out); // [N, multiplier*dim] - - auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim] - m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim] - - int64_t offset = m->nb[1] * m->ne[1]; - auto shift_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, dim] - auto scale_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, dim] - auto gate_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, dim] - - if (is_double) { - auto shift_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, dim] - auto scale_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, dim] - auto gate_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, dim] - return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut(shift_1, scale_1, gate_1)}; + blocks["txt_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["txt_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); + // img_mlp.1 is nn.GELU(approximate="tanh") + blocks["txt_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); } - return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut()}; - } -}; + std::pair forward(struct ggml_context* ctx, + struct ggml_tensor* img, + struct ggml_tensor* txt, + struct ggml_tensor* vec, + struct ggml_tensor* pe) { + // img: [N, n_img_token, hidden_size] + // txt: [N, n_txt_token, hidden_size] + // pe: [n_img_token + n_txt_token, d_head/2, 2, 2] + // return: ([N, n_img_token, hidden_size], [N, n_txt_token, hidden_size]) -__STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx, - struct ggml_tensor* x, - struct ggml_tensor* shift, - struct ggml_tensor* scale) { - // x: [N, L, C] - // scale: [N, C] - // shift: [N, C] - scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C] - shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]); // [N, 1, C] - x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); - x = ggml_add(ctx, x, shift); - return x; 
-} + auto img_mod = std::dynamic_pointer_cast(blocks["img_mod"]); + auto img_norm1 = std::dynamic_pointer_cast(blocks["img_norm1"]); + auto img_attn = std::dynamic_pointer_cast(blocks["img_attn"]); -struct DoubleStreamBlock : public GGMLBlock { -public: - DoubleStreamBlock(int64_t hidden_size, - int64_t num_heads, - float mlp_ratio, - bool qkv_bias = false) { - int64_t mlp_hidden_dim = hidden_size * mlp_ratio; - blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); - blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); + auto img_norm2 = std::dynamic_pointer_cast(blocks["img_norm2"]); + auto img_mlp_0 = std::dynamic_pointer_cast(blocks["img_mlp.0"]); + auto img_mlp_2 = std::dynamic_pointer_cast(blocks["img_mlp.2"]); - blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["img_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); - // img_mlp.1 is nn.GELU(approximate="tanh") - blocks["img_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); + auto txt_mod = std::dynamic_pointer_cast(blocks["txt_mod"]); + auto txt_norm1 = std::dynamic_pointer_cast(blocks["txt_norm1"]); + auto txt_attn = std::dynamic_pointer_cast(blocks["txt_attn"]); - blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); - blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); + auto txt_norm2 = std::dynamic_pointer_cast(blocks["txt_norm2"]); + auto txt_mlp_0 = std::dynamic_pointer_cast(blocks["txt_mlp.0"]); + auto txt_mlp_2 = std::dynamic_pointer_cast(blocks["txt_mlp.2"]); - blocks["txt_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["txt_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); - // img_mlp.1 is nn.GELU(approximate="tanh") - blocks["txt_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); - } + auto img_mods = img_mod->forward(ctx, vec); + ModulationOut img_mod1 = img_mods[0]; + ModulationOut img_mod2 = img_mods[1]; + auto txt_mods = txt_mod->forward(ctx, vec); + ModulationOut txt_mod1 = txt_mods[0]; + ModulationOut txt_mod2 = txt_mods[1]; - std::pair forward(struct ggml_context* ctx, - struct ggml_tensor* img, - struct ggml_tensor* txt, - struct ggml_tensor* vec, - struct ggml_tensor* pe) { - // img: [N, n_img_token, hidden_size] - // txt: [N, n_txt_token, hidden_size] - // pe: [n_img_token + n_txt_token, d_head/2, 2, 2] - // return: ([N, n_img_token, hidden_size], [N, n_txt_token, hidden_size]) - - auto img_mod = std::dynamic_pointer_cast(blocks["img_mod"]); - auto img_norm1 = std::dynamic_pointer_cast(blocks["img_norm1"]); - auto img_attn = std::dynamic_pointer_cast(blocks["img_attn"]); + // prepare image for attention + auto img_modulated = img_norm1->forward(ctx, img); + img_modulated = Flux::modulate(ctx, img_modulated, img_mod1.shift, img_mod1.scale); + auto img_qkv = img_attn->pre_attention(ctx, img_modulated); // q,k,v: [N, n_img_token, n_head, d_head] + auto img_q = img_qkv[0]; + auto img_k = img_qkv[1]; + auto img_v = img_qkv[2]; - auto img_norm2 = std::dynamic_pointer_cast(blocks["img_norm2"]); - auto img_mlp_0 = std::dynamic_pointer_cast(blocks["img_mlp.0"]); - auto img_mlp_2 = std::dynamic_pointer_cast(blocks["img_mlp.2"]); + // prepare txt for attention + auto txt_modulated = txt_norm1->forward(ctx, 
txt); + txt_modulated = Flux::modulate(ctx, txt_modulated, txt_mod1.shift, txt_mod1.scale); + auto txt_qkv = txt_attn->pre_attention(ctx, txt_modulated); // q,k,v: [N, n_txt_token, n_head, d_head] + auto txt_q = txt_qkv[0]; + auto txt_k = txt_qkv[1]; + auto txt_v = txt_qkv[2]; - auto txt_mod = std::dynamic_pointer_cast(blocks["txt_mod"]); - auto txt_norm1 = std::dynamic_pointer_cast(blocks["txt_norm1"]); - auto txt_attn = std::dynamic_pointer_cast(blocks["txt_attn"]); + // run actual attention + auto q = ggml_concat(ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto txt_norm2 = std::dynamic_pointer_cast(blocks["txt_norm2"]); - auto txt_mlp_0 = std::dynamic_pointer_cast(blocks["txt_mlp.0"]); - auto txt_mlp_2 = std::dynamic_pointer_cast(blocks["txt_mlp.2"]); + auto attn = attention(ctx, q, k, v, pe); // [N, n_txt_token + n_img_token, n_head*d_head] + attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] + auto txt_attn_out = ggml_view_3d(ctx, + attn, + attn->ne[0], + attn->ne[1], + txt->ne[1], + attn->nb[1], + attn->nb[2], + 0); // [n_txt_token, N, hidden_size] + txt_attn_out = ggml_cont(ctx, ggml_permute(ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size] + auto img_attn_out = ggml_view_3d(ctx, + attn, + attn->ne[0], + attn->ne[1], + img->ne[1], + attn->nb[1], + attn->nb[2], + attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] + img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + // calculate the img bloks + img = ggml_add(ctx, img, ggml_mul(ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate)); - auto img_mods = img_mod->forward(ctx, vec); - ModulationOut img_mod1 = img_mods[0]; - ModulationOut img_mod2 = img_mods[1]; - auto txt_mods = txt_mod->forward(ctx, vec); - ModulationOut txt_mod1 = txt_mods[0]; - ModulationOut txt_mod2 = txt_mods[1]; + auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale)); + img_mlp_out = ggml_gelu_inplace(ctx, img_mlp_out); + img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out); - // prepare image for attention - auto img_modulated = img_norm1->forward(ctx, img); - img_modulated = Flux::modulate(ctx, img_modulated, img_mod1.shift, img_mod1.scale); - auto img_qkv = img_attn->pre_attention(ctx, img_modulated); // q,k,v: [N, n_img_token, n_head, d_head] - auto img_q = img_qkv[0]; - auto img_k = img_qkv[1]; - auto img_v = img_qkv[2]; + img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_mod2.gate)); - // prepare txt for attention - auto txt_modulated = txt_norm1->forward(ctx, txt); - txt_modulated = Flux::modulate(ctx, txt_modulated, txt_mod1.shift, txt_mod1.scale); - auto txt_qkv = txt_attn->pre_attention(ctx, txt_modulated); // q,k,v: [N, n_txt_token, n_head, d_head] - auto txt_q = txt_qkv[0]; - auto txt_k = txt_qkv[1]; - auto txt_v = txt_qkv[2]; + // calculate the txt bloks + txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate)); - // run actual attention - auto q = ggml_concat(ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto v = ggml_concat(ctx, txt_v, img_v, 
2); // [N, n_txt_token + n_img_token, n_head, d_head] + auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale)); + txt_mlp_out = ggml_gelu_inplace(ctx, txt_mlp_out); + txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out); - auto attn = attention(ctx, q, k, v, pe); // [N, n_txt_token + n_img_token, n_head*d_head] - attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] - auto txt_attn_out = ggml_view_3d(ctx, - attn, - attn->ne[0], - attn->ne[1], - txt->ne[1], - attn->nb[1], - attn->nb[2], - 0); // [n_txt_token, N, hidden_size] - txt_attn_out = ggml_cont(ctx, ggml_permute(ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size] - auto img_attn_out = ggml_view_3d(ctx, - attn, - attn->ne[0], - attn->ne[1], - img->ne[1], - attn->nb[1], - attn->nb[2], - attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] - img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_mod2.gate)); - // calculate the img bloks - img = ggml_add(ctx, img, ggml_mul(ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate)); - - auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale)); - img_mlp_out = ggml_gelu_inplace(ctx, img_mlp_out); - img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out); - - img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_mod2.gate)); - - // calculate the txt bloks - txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate)); - - auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale)); - txt_mlp_out = ggml_gelu_inplace(ctx, txt_mlp_out); - txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out); - - txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_mod2.gate)); - - return {img, txt}; - } -}; - - -struct SingleStreamBlock : public GGMLBlock { -public: - int64_t num_heads; - int64_t hidden_size; - int64_t mlp_hidden_dim; -public: - SingleStreamBlock(int64_t hidden_size, - int64_t num_heads, - float mlp_ratio = 4.0f, - float qk_scale = 0.f) : - hidden_size(hidden_size), num_heads(num_heads) { - int64_t head_dim = hidden_size / num_heads; - float scale = qk_scale; - if (scale <= 0.f) { - scale = 1 / sqrt((float)head_dim); + return {img, txt}; } - mlp_hidden_dim = hidden_size * mlp_ratio; + }; - blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim)); - blocks["linear2"] = std::shared_ptr(new Linear(hidden_size + mlp_hidden_dim, hidden_size)); - blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); - blocks["pre_norm"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - // mlp_act is nn.GELU(approximate="tanh") - blocks["modulation"] = std::shared_ptr(new Modulation(hidden_size, false)); - } + struct SingleStreamBlock : public GGMLBlock { + public: + int64_t num_heads; + int64_t hidden_size; + int64_t mlp_hidden_dim; - struct ggml_tensor* forward(struct ggml_context* ctx, - struct ggml_tensor* x, - struct ggml_tensor* vec, - struct ggml_tensor* pe) { - // x: [N, n_token, hidden_size] - // pe: [n_token, d_head/2, 2, 2] - // return: [N, n_token, hidden_size] - - auto linear1 = std::dynamic_pointer_cast(blocks["linear1"]); - auto linear2 = std::dynamic_pointer_cast(blocks["linear2"]); - auto norm = 
std::dynamic_pointer_cast(blocks["norm"]); - auto pre_norm = std::dynamic_pointer_cast(blocks["pre_norm"]); - auto modulation = std::dynamic_pointer_cast(blocks["modulation"]); - - auto mods = modulation->forward(ctx, vec); - ModulationOut mod = mods[0]; - - auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); - auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] - qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token] - - auto qkv = ggml_view_3d(ctx, - qkv_mlp, - qkv_mlp->ne[0], - qkv_mlp->ne[1], - hidden_size * 3, - qkv_mlp->nb[1], - qkv_mlp->nb[2], - 0); // [hidden_size * 3 , N, n_token] - qkv = ggml_cont(ctx, ggml_permute(ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3] - auto mlp = ggml_view_3d(ctx, - qkv_mlp, - qkv_mlp->ne[0], - qkv_mlp->ne[1], - mlp_hidden_dim, - qkv_mlp->nb[1], - qkv_mlp->nb[2], - qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim , N, n_token] - mlp = ggml_cont(ctx, ggml_permute(ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim] - - auto qkv_vec = split_qkv(ctx, qkv); // q,k,v: [N, n_token, hidden_size] - int64_t head_dim = hidden_size / num_heads; - auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] - auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] - auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] - q = norm->query_norm(ctx, q); - k = norm->key_norm(ctx, k); - auto attn = attention(ctx, q, k, v, pe); // [N, n_token, hidden_size] - - auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] - auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] - - output = ggml_add(ctx, x, ggml_mul(ctx, output, mod.gate)); - return output; - } -}; - - -struct LastLayer : public GGMLBlock { -public: - LastLayer(int64_t hidden_size, - int64_t patch_size, - int64_t out_channels) { - blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); - blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); - blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); - } - - struct ggml_tensor* forward(struct ggml_context* ctx, - struct ggml_tensor* x, - struct ggml_tensor* c) { - // x: [N, n_token, hidden_size] - // c: [N, hidden_size] - // return: [N, n_token, patch_size * patch_size * out_channels] - auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); - auto linear = std::dynamic_pointer_cast(blocks["linear"]); - auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] - m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] - m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] - - int64_t offset = m->nb[1] * m->ne[1]; - auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] - - x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale); - x = linear->forward(ctx, x); - - return x; - } -}; - 
-struct FluxParams { - int64_t in_channels = 64; - int64_t vec_in_dim=768; - int64_t context_in_dim = 4096; - int64_t hidden_size = 3072; - float mlp_ratio = 4.0f; - int64_t num_heads = 24; - int64_t depth = 19; - int64_t depth_single_blocks = 38; - std::vector axes_dim = {16, 56, 56}; - int64_t axes_dim_sum = 128; - int theta = 10000; - bool qkv_bias = true; - bool guidance_embed = true; -}; - - -struct Flux : public GGMLBlock { -public: - std::vector linspace(float start, float end, int num) { - std::vector result(num); - float step = (end - start) / (num - 1); - for (int i = 0; i < num; ++i) { - result[i] = start + i * step; - } - return result; - } - - std::vector> transpose(const std::vector>& mat) { - int rows = mat.size(); - int cols = mat[0].size(); - std::vector> transposed(cols, std::vector(rows)); - for (int i = 0; i < rows; ++i) { - for (int j = 0; j < cols; ++j) { - transposed[j][i] = mat[i][j]; + public: + SingleStreamBlock(int64_t hidden_size, + int64_t num_heads, + float mlp_ratio = 4.0f, + float qk_scale = 0.f) + : hidden_size(hidden_size), num_heads(num_heads) { + int64_t head_dim = hidden_size / num_heads; + float scale = qk_scale; + if (scale <= 0.f) { + scale = 1 / sqrt((float)head_dim); } - } - return transposed; - } + mlp_hidden_dim = hidden_size * mlp_ratio; - std::vector flatten(const std::vector>& vec) { - std::vector flat_vec; - for (const auto& sub_vec : vec) { - flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end()); - } - return flat_vec; - } - - std::vector> rope(const std::vector& pos, int dim, int theta) { - assert(dim % 2 == 0); - int half_dim = dim / 2; - - std::vector scale = linspace(0, (dim * 1.0f - 2) / dim, half_dim); - - std::vector omega(half_dim); - for (int i = 0; i < half_dim; ++i) { - omega[i] = 1.0 / std::pow(theta, scale[i]); - } - - int pos_size = pos.size(); - std::vector> out(pos_size, std::vector(half_dim)); - for (int i = 0; i < pos_size; ++i) { - for (int j = 0; j < half_dim; ++j) { - out[i][j] = pos[i] * omega[j]; - } - } - - std::vector> result(pos_size, std::vector(half_dim * 4)); - for (int i = 0; i < pos_size; ++i) { - for (int j = 0; j < half_dim; ++j) { - result[i][4 * j] = std::cos(out[i][j]); - result[i][4 * j + 1] = -std::sin(out[i][j]); - result[i][4 * j + 2] = std::sin(out[i][j]); - result[i][4 * j + 3] = std::cos(out[i][j]); - } - } - - return result; - } - - // Generate IDs for image patches and text - std::vector> gen_ids(int h, int w, int patch_size, int bs, int context_len) { - int h_len = (h + (patch_size / 2)) / patch_size; - int w_len = (w + (patch_size / 2)) / patch_size; - - std::vector> img_ids(h_len * w_len, std::vector(3, 0.0)); - - std::vector row_ids = linspace(0, h_len - 1, h_len); - std::vector col_ids = linspace(0, w_len - 1, w_len); - - for (int i = 0; i < h_len; ++i) { - for (int j = 0; j < w_len; ++j) { - img_ids[i * w_len + j][1] = row_ids[i]; - img_ids[i * w_len + j][2] = col_ids[j]; - } - } - - std::vector> img_ids_repeated(bs * img_ids.size(), std::vector(3)); - for (int i = 0; i < bs; ++i) { - for (int j = 0; j < img_ids.size(); ++j) { - img_ids_repeated[i * img_ids.size() + j] = img_ids[j]; - } - } - - std::vector> txt_ids(bs * context_len, std::vector(3, 0.0)); - std::vector> ids(bs * (context_len + img_ids.size()), std::vector(3)); - for (int i = 0; i < bs; ++i) { - for (int j = 0; j < context_len; ++j) { - ids[i * (context_len + img_ids.size()) + j] = txt_ids[j]; - } - for (int j = 0; j < img_ids.size(); ++j) { - ids[i * (context_len + img_ids.size()) + context_len + j] = 
img_ids_repeated[i * img_ids.size() + j]; - } - } - - return ids; - } - - // Generate positional embeddings - std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector& axes_dim) { - std::vector> ids = gen_ids(h, w, patch_size, bs, context_len); - std::vector> trans_ids = transpose(ids); - size_t pos_len = ids.size(); - int num_axes = axes_dim.size(); - for (int i = 0; i < pos_len; i++) { - // std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl; + blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim)); + blocks["linear2"] = std::shared_ptr(new Linear(hidden_size + mlp_hidden_dim, hidden_size)); + blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); + blocks["pre_norm"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + // mlp_act is nn.GELU(approximate="tanh") + blocks["modulation"] = std::shared_ptr(new Modulation(hidden_size, false)); } - - int emb_dim = 0; - for (int d : axes_dim) emb_dim += d / 2; - - std::vector> emb(bs * pos_len, std::vector(emb_dim * 2 * 2, 0.0)); - int offset = 0; - for (int i = 0; i < num_axes; ++i) { - std::vector> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2] - for (int b = 0; b < bs; ++b) { - for (int j = 0; j < pos_len; ++j) { - for (int k = 0; k < rope_emb[0].size(); ++k) { - emb[b * pos_len + j][offset + k] = rope_emb[j][k]; - } + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* vec, + struct ggml_tensor* pe) { + // x: [N, n_token, hidden_size] + // pe: [n_token, d_head/2, 2, 2] + // return: [N, n_token, hidden_size] + + auto linear1 = std::dynamic_pointer_cast(blocks["linear1"]); + auto linear2 = std::dynamic_pointer_cast(blocks["linear2"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto pre_norm = std::dynamic_pointer_cast(blocks["pre_norm"]); + auto modulation = std::dynamic_pointer_cast(blocks["modulation"]); + + auto mods = modulation->forward(ctx, vec); + ModulationOut mod = mods[0]; + + auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); + auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] + qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token] + + auto qkv = ggml_view_3d(ctx, + qkv_mlp, + qkv_mlp->ne[0], + qkv_mlp->ne[1], + hidden_size * 3, + qkv_mlp->nb[1], + qkv_mlp->nb[2], + 0); // [hidden_size * 3 , N, n_token] + qkv = ggml_cont(ctx, ggml_permute(ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3] + auto mlp = ggml_view_3d(ctx, + qkv_mlp, + qkv_mlp->ne[0], + qkv_mlp->ne[1], + mlp_hidden_dim, + qkv_mlp->nb[1], + qkv_mlp->nb[2], + qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim , N, n_token] + mlp = ggml_cont(ctx, ggml_permute(ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim] + + auto qkv_vec = split_qkv(ctx, qkv); // q,k,v: [N, n_token, hidden_size] + int64_t head_dim = hidden_size / num_heads; + auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] + auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] + auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] + q = norm->query_norm(ctx, q); + k = 
norm->key_norm(ctx, k); + auto attn = attention(ctx, q, k, v, pe); // [N, n_token, hidden_size] + + auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] + auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] + + output = ggml_add(ctx, x, ggml_mul(ctx, output, mod.gate)); + return output; + } + }; + + struct LastLayer : public GGMLBlock { + public: + LastLayer(int64_t hidden_size, + int64_t patch_size, + int64_t out_channels) { + blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); + blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); + blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* c) { + // x: [N, n_token, hidden_size] + // c: [N, hidden_size] + // return: [N, n_token, patch_size * patch_size * out_channels] + auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] + m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] + m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] + + int64_t offset = m->nb[1] * m->ne[1]; + auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + + x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale); + x = linear->forward(ctx, x); + + return x; + } + }; + + struct FluxParams { + int64_t in_channels = 64; + int64_t vec_in_dim = 768; + int64_t context_in_dim = 4096; + int64_t hidden_size = 3072; + float mlp_ratio = 4.0f; + int64_t num_heads = 24; + int64_t depth = 19; + int64_t depth_single_blocks = 38; + std::vector axes_dim = {16, 56, 56}; + int64_t axes_dim_sum = 128; + int theta = 10000; + bool qkv_bias = true; + bool guidance_embed = true; + }; + + struct Flux : public GGMLBlock { + public: + std::vector linspace(float start, float end, int num) { + std::vector result(num); + float step = (end - start) / (num - 1); + for (int i = 0; i < num; ++i) { + result[i] = start + i * step; + } + return result; + } + + std::vector> transpose(const std::vector>& mat) { + int rows = mat.size(); + int cols = mat[0].size(); + std::vector> transposed(cols, std::vector(rows)); + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + transposed[j][i] = mat[i][j]; } } - offset += rope_emb[0].size(); - } - - return flatten(emb); - } -public: - FluxParams params; - Flux() {} - Flux(FluxParams params) : params(params) { - int64_t out_channels = params.in_channels; - int64_t pe_dim = params.hidden_size / params.num_heads; - - blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size)); - blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); - blocks["vector_in"] = std::shared_ptr(new MLPEmbedder(params.vec_in_dim, params.hidden_size)); - if (params.guidance_embed) { - blocks["guidance_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); - } - blocks["txt_in"] = std::shared_ptr(new Linear(params.context_in_dim, params.hidden_size)); - - for (int i = 0; 
i < params.depth; i++) { - blocks["double_blocks." + std::to_string(i)] = std::shared_ptr(new DoubleStreamBlock(params.hidden_size, - params.num_heads, - params.mlp_ratio, - params.qkv_bias)); + return transposed; } - for (int i = 0; i < params.depth_single_blocks; i++) { - blocks["single_blocks." + std::to_string(i)] = std::shared_ptr(new SingleStreamBlock(params.hidden_size, - params.num_heads, - params.mlp_ratio)); + std::vector flatten(const std::vector>& vec) { + std::vector flat_vec; + for (const auto& sub_vec : vec) { + flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end()); + } + return flat_vec; } - blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, out_channels)); - } + std::vector> rope(const std::vector& pos, int dim, int theta) { + assert(dim % 2 == 0); + int half_dim = dim / 2; - struct ggml_tensor* patchify(struct ggml_context* ctx, - struct ggml_tensor* x, - int64_t patch_size) { - // x: [N, C, H, W] - // return: [N, h*w, C * patch_size * patch_size] - int64_t N = x->ne[3]; - int64_t C = x->ne[2]; - int64_t H = x->ne[1]; - int64_t W = x->ne[0]; - int64_t p = patch_size; - int64_t h = H / patch_size; - int64_t w = W / patch_size; - - GGML_ASSERT(h * p == H && w * p == W); + std::vector scale = linspace(0, (dim * 1.0f - 2) / dim, half_dim); - x = ggml_reshape_4d(ctx, x, p, w, p, h*C*N); // [N*C*h, p, w, p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p] - x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p] - x = ggml_reshape_3d(ctx, x, p*p*C, w*h, N); // [N, h*w, C*p*p] - return x; - } + std::vector omega(half_dim); + for (int i = 0; i < half_dim; ++i) { + omega[i] = 1.0 / std::pow(theta, scale[i]); + } - struct ggml_tensor* unpatchify(struct ggml_context* ctx, - struct ggml_tensor* x, - int64_t h, - int64_t w, - int64_t patch_size) { - // x: [N, h*w, C*patch_size*patch_size] - // return: [N, C, H, W] - int64_t N = x->ne[2]; - int64_t C = x->ne[0] / patch_size / patch_size; - int64_t H = h * patch_size; - int64_t W = w * patch_size; - int64_t p = patch_size; - - GGML_ASSERT(C * p * p == x->ne[0]); + int pos_size = pos.size(); + std::vector> out(pos_size, std::vector(half_dim)); + for (int i = 0; i < pos_size; ++i) { + for (int j = 0; j < half_dim; ++j) { + out[i][j] = pos[i] * omega[j]; + } + } - x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p] - x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p] - x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p] + std::vector> result(pos_size, std::vector(half_dim * 4)); + for (int i = 0; i < pos_size; ++i) { + for (int j = 0; j < half_dim; ++j) { + result[i][4 * j] = std::cos(out[i][j]); + result[i][4 * j + 1] = -std::sin(out[i][j]); + result[i][4 * j + 2] = std::sin(out[i][j]); + result[i][4 * j + 3] = std::cos(out[i][j]); + } + } - return x; - } - - struct ggml_tensor* forward_orig(struct ggml_context* ctx, - struct ggml_tensor* img, - struct ggml_tensor* txt, - struct ggml_tensor* timesteps, - struct ggml_tensor* y, - struct ggml_tensor* guidance, - struct ggml_tensor* pe) { - auto img_in = std::dynamic_pointer_cast(blocks["img_in"]); - auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); - auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); - auto txt_in = 
std::dynamic_pointer_cast(blocks["txt_in"]); - auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]); - - img = img_in->forward(ctx, img); - auto vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f)); - - if (params.guidance_embed) { - GGML_ASSERT(guidance != NULL); - auto guidance_in = std::dynamic_pointer_cast(blocks["guidance_in"]); - // bf16 and fp16 result is different - auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f); - vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in)); + return result; } - vec = ggml_add(ctx, vec, vector_in->forward(ctx, y)); - txt = txt_in->forward(ctx, txt); + // Generate IDs for image patches and text + std::vector> gen_ids(int h, int w, int patch_size, int bs, int context_len) { + int h_len = (h + (patch_size / 2)) / patch_size; + int w_len = (w + (patch_size / 2)) / patch_size; - for (int i = 0; i < params.depth; i++) { - auto block = std::dynamic_pointer_cast(blocks["double_blocks." + std::to_string(i)]); + std::vector> img_ids(h_len * w_len, std::vector(3, 0.0)); - auto img_txt = block->forward(ctx, img, txt, vec, pe); - img = img_txt.first; // [N, n_img_token, hidden_size] - txt = img_txt.second; // [N, n_txt_token, hidden_size] + std::vector row_ids = linspace(0, h_len - 1, h_len); + std::vector col_ids = linspace(0, w_len - 1, w_len); + + for (int i = 0; i < h_len; ++i) { + for (int j = 0; j < w_len; ++j) { + img_ids[i * w_len + j][1] = row_ids[i]; + img_ids[i * w_len + j][2] = col_ids[j]; + } + } + + std::vector> img_ids_repeated(bs * img_ids.size(), std::vector(3)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < img_ids.size(); ++j) { + img_ids_repeated[i * img_ids.size() + j] = img_ids[j]; + } + } + + std::vector> txt_ids(bs * context_len, std::vector(3, 0.0)); + std::vector> ids(bs * (context_len + img_ids.size()), std::vector(3)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < context_len; ++j) { + ids[i * (context_len + img_ids.size()) + j] = txt_ids[j]; + } + for (int j = 0; j < img_ids.size(); ++j) { + ids[i * (context_len + img_ids.size()) + context_len + j] = img_ids_repeated[i * img_ids.size() + j]; + } + } + + return ids; } - auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size] - for (int i = 0; i < params.depth_single_blocks; i++) { - auto block = std::dynamic_pointer_cast(blocks["single_blocks." 
+ std::to_string(i)]); + // Generate positional embeddings + std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector& axes_dim) { + std::vector> ids = gen_ids(h, w, patch_size, bs, context_len); + std::vector> trans_ids = transpose(ids); + size_t pos_len = ids.size(); + int num_axes = axes_dim.size(); + for (int i = 0; i < pos_len; i++) { + // std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl; + } - txt_img = block->forward(ctx, txt_img, vec, pe); + int emb_dim = 0; + for (int d : axes_dim) + emb_dim += d / 2; + + std::vector> emb(bs * pos_len, std::vector(emb_dim * 2 * 2, 0.0)); + int offset = 0; + for (int i = 0; i < num_axes; ++i) { + std::vector> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2] + for (int b = 0; b < bs; ++b) { + for (int j = 0; j < pos_len; ++j) { + for (int k = 0; k < rope_emb[0].size(); ++k) { + emb[b * pos_len + j][offset + k] = rope_emb[j][k]; + } + } + } + offset += rope_emb[0].size(); + } + + return flatten(emb); } - txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] - img = ggml_view_3d(ctx, - txt_img, - txt_img->ne[0], - txt_img->ne[1], - img->ne[1], - txt_img->nb[1], - txt_img->nb[2], - txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] - img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + public: + FluxParams params; + Flux() {} + Flux(FluxParams params) + : params(params) { + int64_t out_channels = params.in_channels; + int64_t pe_dim = params.hidden_size / params.num_heads; - img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels) + blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size)); + blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + blocks["vector_in"] = std::shared_ptr(new MLPEmbedder(params.vec_in_dim, params.hidden_size)); + if (params.guidance_embed) { + blocks["guidance_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + } + blocks["txt_in"] = std::shared_ptr(new Linear(params.context_in_dim, params.hidden_size)); - return img; - } + for (int i = 0; i < params.depth; i++) { + blocks["double_blocks." + std::to_string(i)] = std::shared_ptr(new DoubleStreamBlock(params.hidden_size, + params.num_heads, + params.mlp_ratio, + params.qkv_bias)); + } - struct ggml_tensor* forward(struct ggml_context* ctx, - struct ggml_tensor* x, - struct ggml_tensor* timestep, - struct ggml_tensor* context, - struct ggml_tensor* y, - struct ggml_tensor* guidance, - struct ggml_tensor* pe) { - // Forward pass of DiT. - // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) - // timestep: (N,) tensor of diffusion timesteps - // context: (N, L, D) - // y: (N, adm_in_channels) tensor of class labels - // guidance: (N,) - // pe: (L, d_head/2, 2, 2) - // return: (N, C, H, W) + for (int i = 0; i < params.depth_single_blocks; i++) { + blocks["single_blocks." 
+ std::to_string(i)] = std::shared_ptr(new SingleStreamBlock(params.hidden_size, + params.num_heads, + params.mlp_ratio)); + } - GGML_ASSERT(x->ne[3] == 1); - - int64_t W = x->ne[0]; - int64_t H = x->ne[1]; - int64_t patch_size = 2; - int pad_h = (patch_size - H % patch_size) % patch_size; - int pad_w = (patch_size - W % patch_size) % patch_size; - x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] - - // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) - auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] - - auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe); // [N, h*w, C * patch_size * patch_size] - - // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) - out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w] - - return out; - } -}; - - -struct FluxRunner : public GGMLRunner { -public: - FluxParams flux_params; - Flux flux; - std::vector pe_vec; // for cache - - FluxRunner(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_FLUX_DEV) - : GGMLRunner(backend, wtype) { - if (version == VERSION_FLUX_SCHNELL) { - flux_params.guidance_embed = false; + blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, out_channels)); } - flux = Flux(flux_params); - flux.init(params_ctx, wtype); - } - std::string get_desc() { - return "flux"; - } + struct ggml_tensor* patchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t patch_size) { + // x: [N, C, H, W] + // return: [N, h*w, C * patch_size * patch_size] + int64_t N = x->ne[3]; + int64_t C = x->ne[2]; + int64_t H = x->ne[1]; + int64_t W = x->ne[0]; + int64_t p = patch_size; + int64_t h = H / patch_size; + int64_t w = W / patch_size; - void get_param_tensors(std::map& tensors, const std::string prefix) { - flux.get_param_tensors(tensors, prefix); - } + GGML_ASSERT(h * p == H && w * p == W); - struct ggml_cgraph* build_graph(struct ggml_tensor* x, - struct ggml_tensor* timesteps, + x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p] + x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p] + x = ggml_reshape_3d(ctx, x, p * p * C, w * h, N); // [N, h*w, C*p*p] + return x; + } + + struct ggml_tensor* unpatchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t h, + int64_t w, + int64_t patch_size) { + // x: [N, h*w, C*patch_size*patch_size] + // return: [N, C, H, W] + int64_t N = x->ne[2]; + int64_t C = x->ne[0] / patch_size / patch_size; + int64_t H = h * patch_size; + int64_t W = w * patch_size; + int64_t p = patch_size; + + GGML_ASSERT(C * p * p == x->ne[0]); + + x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p] + x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p] + x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p] + + return x; + } + + struct ggml_tensor* forward_orig(struct ggml_context* ctx, + struct ggml_tensor* img, + struct ggml_tensor* txt, + struct ggml_tensor* timesteps, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + struct ggml_tensor* pe) { + auto img_in = 
+
+        struct ggml_tensor* forward_orig(struct ggml_context* ctx,
+                                         struct ggml_tensor* img,
+                                         struct ggml_tensor* txt,
+                                         struct ggml_tensor* timesteps,
+                                         struct ggml_tensor* y,
+                                         struct ggml_tensor* guidance,
+                                         struct ggml_tensor* pe) {
+            auto img_in      = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
+            auto time_in     = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
+            auto vector_in   = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
+            auto txt_in      = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]);
+            auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks["final_layer"]);
+
+            img      = img_in->forward(ctx, img);
+            auto vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f));
+
+            if (params.guidance_embed) {
+                GGML_ASSERT(guidance != NULL);
+                auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
+                // bf16 and fp16 result is different
+                auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f);
+                vec       = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in));
+            }
+
+            vec = ggml_add(ctx, vec, vector_in->forward(ctx, y));
+            txt = txt_in->forward(ctx, txt);
+
+            for (int i = 0; i < params.depth; i++) {
+                auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);
+
+                auto img_txt = block->forward(ctx, img, txt, vec, pe);
+                img          = img_txt.first;   // [N, n_img_token, hidden_size]
+                txt          = img_txt.second;  // [N, n_txt_token, hidden_size]
+            }
+
+            auto txt_img = ggml_concat(ctx, txt, img, 1);  // [N, n_txt_token + n_img_token, hidden_size]
+            for (int i = 0; i < params.depth_single_blocks; i++) {
+                auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);
+
+                txt_img = block->forward(ctx, txt_img, vec, pe);
+            }
+
+            txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
+            img     = ggml_view_3d(ctx,
+                                   txt_img,
+                                   txt_img->ne[0],
+                                   txt_img->ne[1],
+                                   img->ne[1],
+                                   txt_img->nb[1],
+                                   txt_img->nb[2],
+                                   txt_img->nb[2] * txt->ne[1]);  // [n_img_token, N, hidden_size]
+            img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3));  // [N, n_img_token, hidden_size]
+
+            img = final_layer->forward(ctx, img, vec);  // (N, T, patch_size ** 2 * out_channels)
+
+            return img;
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                    struct ggml_tensor* x,
+                                    struct ggml_tensor* timestep,
                                     struct ggml_tensor* context,
                                     struct ggml_tensor* y,
-                                    struct ggml_tensor* guidance) {
-        GGML_ASSERT(x->ne[3] == 1);
-        struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
+                                    struct ggml_tensor* guidance,
+                                    struct ggml_tensor* pe) {
+            // Forward pass of DiT.
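            // [editor's note] On the pe shape documented below: with the fixed
            // patch_size = 2 used in this function, L = n_txt_token + (H/2)*(W/2)
            // after padding. E.g. a 64x64 latent with a 256-token context gives
            // L = 256 + 32*32 = 1280 rows of d_head/2 2x2 rotation matrices.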
+ // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + // timestep: (N,) tensor of diffusion timesteps + // context: (N, L, D) + // y: (N, adm_in_channels) tensor of class labels + // guidance: (N,) + // pe: (L, d_head/2, 2, 2) + // return: (N, C, H, W) - x = to_backend(x); - context = to_backend(context); - y = to_backend(y); - timesteps = to_backend(timesteps); - if (flux_params.guidance_embed) { - guidance = to_backend(guidance); + GGML_ASSERT(x->ne[3] == 1); + + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t patch_size = 2; + int pad_h = (patch_size - H % patch_size) % patch_size; + int pad_w = (patch_size - W % patch_size) % patch_size; + x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + + // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] + + auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe); // [N, h*w, C * patch_size * patch_size] + + // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) + out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w] + + return out; + } + }; + + struct FluxRunner : public GGMLRunner { + public: + FluxParams flux_params; + Flux flux; + std::vector pe_vec; // for cache + + FluxRunner(ggml_backend_t backend, + ggml_type wtype, + SDVersion version = VERSION_FLUX_DEV) + : GGMLRunner(backend, wtype) { + if (version == VERSION_FLUX_SCHNELL) { + flux_params.guidance_embed = false; + } + flux = Flux(flux_params); + flux.init(params_ctx, wtype); } - pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], flux_params.theta, flux_params.axes_dim); - int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; - // LOG_DEBUG("pos_len %d", pos_len); - auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum/2, pos_len); - // pe->data = pe_vec.data(); - // print_ggml_tensor(pe); - // pe->data = NULL; - set_backend_tensor_data(pe, pe_vec.data()); - - - struct ggml_tensor* out = flux.forward(compute_ctx, - x, - timesteps, - context, - y, - guidance, - pe); - - ggml_build_forward_expand(gf, out); - - return gf; - } - - void compute(int n_threads, - struct ggml_tensor* x, - struct ggml_tensor* timesteps, - struct ggml_tensor* context, - struct ggml_tensor* y, - struct ggml_tensor* guidance, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { - // x: [N, in_channels, h, w] - // timesteps: [N, ] - // context: [N, max_position, hidden_size] - // y: [N, adm_in_channels] or [1, adm_in_channels] - // guidance: [N, ] - auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, y, guidance); - }; - - GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); - } - - void test() { - struct ggml_init_params params; - params.mem_size = static_cast(20 * 1024 * 1024); // 20 MB - params.mem_buffer = NULL; - params.no_alloc = false; - - struct ggml_context* work_ctx = ggml_init(params); - GGML_ASSERT(work_ctx != NULL); - - { - // cpu f16: - // cuda f16: nan - // cuda q8_0: pass - auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1); - ggml_set_f32(x, 0.01f); - // print_ggml_tensor(x); - - std::vector timesteps_vec(1, 999.f); - auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); - - std::vector guidance_vec(1, 3.5f); - auto 
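
// [editor's note] The pos_len arithmetic in build_graph() (deleted above,
// re-added reindented below): gen_pe() emits, per position, one 2x2 matrix
// for each of the axes_dim_sum/2 frequency pairs, i.e. axes_dim_sum * 2
// floats, hence
//     pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
// With FLUX's usual axes_dim of {16, 56, 56} (axes_dim_sum = 128, an
// assumption from the model defaults, not shown in this hunk), each position
// occupies 256 floats of the cached pe_vec.
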
guidance = vector_to_ggml_tensor(work_ctx, guidance_vec); - - auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 256, 1); - ggml_set_f32(context, 0.01f); - // print_ggml_tensor(context); - - auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1); - ggml_set_f32(y, 0.01f); - // print_ggml_tensor(y); - - struct ggml_tensor* out = NULL; - - int t0 = ggml_time_ms(); - compute(8, x, timesteps, context, y, guidance, &out, work_ctx); - int t1 = ggml_time_ms(); - - print_ggml_tensor(out); - LOG_DEBUG("flux test done in %dms", t1 - t0); + std::string get_desc() { + return "flux"; } - } - static void load_from_file_and_test(const std::string& file_path) { - // ggml_backend_t backend = ggml_backend_cuda_init(0); - ggml_backend_t backend = ggml_backend_cpu_init(); - ggml_type model_data_type = GGML_TYPE_Q8_0; - std::shared_ptr flux = std::shared_ptr(new FluxRunner(backend, model_data_type)); - { - LOG_INFO("loading from '%s'", file_path.c_str()); + void get_param_tensors(std::map& tensors, const std::string prefix) { + flux.get_param_tensors(tensors, prefix); + } - flux->alloc_params_buffer(); - std::map tensors; - flux->get_param_tensors(tensors, "model.diffusion_model"); + struct ggml_cgraph* build_graph(struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* y, + struct ggml_tensor* guidance) { + GGML_ASSERT(x->ne[3] == 1); + struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); - ModelLoader model_loader; - if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { - LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); - return; + x = to_backend(x); + context = to_backend(context); + y = to_backend(y); + timesteps = to_backend(timesteps); + if (flux_params.guidance_embed) { + guidance = to_backend(guidance); } - bool success = model_loader.load_tensors(tensors, backend); + pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], flux_params.theta, flux_params.axes_dim); + int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; + // LOG_DEBUG("pos_len %d", pos_len); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); + // pe->data = pe_vec.data(); + // print_ggml_tensor(pe); + // pe->data = NULL; + set_backend_tensor_data(pe, pe_vec.data()); - if (!success) { - LOG_ERROR("load tensors from model loader failed"); - return; - } + struct ggml_tensor* out = flux.forward(compute_ctx, + x, + timesteps, + context, + y, + guidance, + pe); - LOG_INFO("flux model loaded"); + ggml_build_forward_expand(gf, out); + + return gf; } - flux->test(); - } -}; -} // namespace Flux + void compute(int n_threads, + struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL) { + // x: [N, in_channels, h, w] + // timesteps: [N, ] + // context: [N, max_position, hidden_size] + // y: [N, adm_in_channels] or [1, adm_in_channels] + // guidance: [N, ] + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(x, timesteps, context, y, guidance); + }; -#endif // __FLUX_HPP__ \ No newline at end of file + GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + } + + void test() { + struct ggml_init_params params; + params.mem_size = static_cast(20 * 1024 * 1024); // 20 MB + params.mem_buffer = NULL; + params.no_alloc = false; + + struct 
ggml_context* work_ctx = ggml_init(params); + GGML_ASSERT(work_ctx != NULL); + + { + // cpu f16: + // cuda f16: nan + // cuda q8_0: pass + auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1); + ggml_set_f32(x, 0.01f); + // print_ggml_tensor(x); + + std::vector timesteps_vec(1, 999.f); + auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); + + std::vector guidance_vec(1, 3.5f); + auto guidance = vector_to_ggml_tensor(work_ctx, guidance_vec); + + auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 256, 1); + ggml_set_f32(context, 0.01f); + // print_ggml_tensor(context); + + auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1); + ggml_set_f32(y, 0.01f); + // print_ggml_tensor(y); + + struct ggml_tensor* out = NULL; + + int t0 = ggml_time_ms(); + compute(8, x, timesteps, context, y, guidance, &out, work_ctx); + int t1 = ggml_time_ms(); + + print_ggml_tensor(out); + LOG_DEBUG("flux test done in %dms", t1 - t0); + } + } + + static void load_from_file_and_test(const std::string& file_path) { + // ggml_backend_t backend = ggml_backend_cuda_init(0); + ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_type model_data_type = GGML_TYPE_Q8_0; + std::shared_ptr flux = std::shared_ptr(new FluxRunner(backend, model_data_type)); + { + LOG_INFO("loading from '%s'", file_path.c_str()); + + flux->alloc_params_buffer(); + std::map tensors; + flux->get_param_tensors(tensors, "model.diffusion_model"); + + ModelLoader model_loader; + if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { + LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); + return; + } + + bool success = model_loader.load_tensors(tensors, backend); + + if (!success) { + LOG_ERROR("load tensors from model loader failed"); + return; + } + + LOG_INFO("flux model loaded"); + } + flux->test(); + } + }; + +} // namespace Flux + +#endif // __FLUX_HPP__ \ No newline at end of file diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 32ad1e6..19410e8 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -541,7 +541,7 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const __STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx, struct ggml_tensor* a) { - const float eps = 1e-6f; // default eps parameter + const float eps = 1e-6f; // default eps parameter return ggml_group_norm(ctx, a, 32, eps); } @@ -683,27 +683,27 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* int64_t n_head, struct ggml_tensor* mask = NULL, bool diag_mask_inf = false, - bool skip_reshape = false) { + bool skip_reshape = false) { int64_t L_q; int64_t L_k; - int64_t C ; - int64_t N ; + int64_t C; + int64_t N; int64_t d_head; if (!skip_reshape) { - L_q = q->ne[1]; - L_k = k->ne[1]; - C = q->ne[0]; - N = q->ne[2]; + L_q = q->ne[1]; + L_k = k->ne[1]; + C = q->ne[0]; + N = q->ne[2]; d_head = C / n_head; - q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head] - q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head] - q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head] + q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head] + q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head] + q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head] k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] k = ggml_cont(ctx, 
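
// [editor's note] The reshape/permute chain around this point lays q, k and v
// out as [N * n_head, L, d_head] so one batched matmul per head evaluates
// softmax(q k^T * scale) v, with scale = 1/sqrt(d_head) as set just below.
// Scalar sketch of a single head (dot/softmax/axpy are illustrative helpers,
// not ggml API):
//     for (int i = 0; i < L_q; i++) {
//         for (int j = 0; j < L_k; j++) s[j] = dot(q[i], k[j]) * scale;
//         softmax(s, L_k);
//         for (int j = 0; j < L_k; j++) axpy(out[i], s[j], v[j]);  // out[i] += s[j] * v[j]
//     }
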
ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] - v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] + v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] } else { L_q = q->ne[1]; L_k = k->ne[1]; @@ -712,10 +712,10 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* C = d_head * n_head; } - float scale = (1.0f / sqrt((float)d_head)); + float scale = (1.0f / sqrt((float)d_head)); bool use_flash_attn = false; - ggml_tensor* kqv = NULL; + ggml_tensor* kqv = NULL; if (use_flash_attn) { v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] @@ -770,8 +770,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1); } - const float eps = 1e-6f; // default eps parameter - x = ggml_group_norm(ctx, x, num_groups, eps); + const float eps = 1e-6f; // default eps parameter + x = ggml_group_norm(ctx, x, num_groups, eps); if (w != NULL && b != NULL) { x = ggml_mul(ctx, x, w); // b = ggml_repeat(ctx, b, x); @@ -781,7 +781,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct } __STATIC_INLINE__ void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor* tensor, void* data, size_t offset, size_t size) { -#if defined (SD_USE_CUBLAS) || defined (SD_USE_SYCL) +#if defined(SD_USE_CUBLAS) || defined(SD_USE_SYCL) if (!ggml_backend_is_cpu(backend)) { ggml_backend_tensor_get_async(backend, tensor, data, offset, size); ggml_backend_synchronize(backend); @@ -889,7 +889,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding( struct ggml_context* ctx, struct ggml_tensor* timesteps, int dim, - int max_period = 10000, + int max_period = 10000, float time_factor = 1.0f) { timesteps = ggml_scale(ctx, timesteps, time_factor); return ggml_timestep_embedding(ctx, timesteps, dim, max_period); diff --git a/lora.hpp b/lora.hpp index 309378f..6b28a03 100644 --- a/lora.hpp +++ b/lora.hpp @@ -10,10 +10,10 @@ struct LoraModel : public GGMLRunner { std::map lora_tensors; std::string file_path; ModelLoader model_loader; - bool load_failed = false; - bool applied = false; + bool load_failed = false; + bool applied = false; std::vector zero_index_vec = {0}; - ggml_tensor* zero_index = NULL; + ggml_tensor* zero_index = NULL; LoraModel(ggml_backend_t backend, ggml_type wtype, @@ -72,8 +72,8 @@ struct LoraModel : public GGMLRunner { ggml_tensor* to_f32(ggml_context* ctx, ggml_tensor* a) { auto out = ggml_reshape_1d(ctx, a, ggml_nelements(a)); - out = ggml_get_rows(ctx, out, zero_index); - out = ggml_reshape(ctx, out, a); + out = ggml_get_rows(ctx, out, zero_index); + out = ggml_reshape(ctx, out, a); return out; } diff --git a/model.cpp b/model.cpp index f5c0701..e693eec 100644 --- a/model.cpp +++ b/model.cpp @@ -567,10 +567,10 @@ uint16_t f8_e4m3_to_f16(uint8_t f8) { return ggml_fp32_to_fp16(NAN); } - uint32_t sign = f8 & 0x80; + uint32_t sign = f8 & 0x80; uint32_t exponent = (f8 & 0x78) >> 3; uint32_t mantissa = f8 & 0x07; - uint32_t result = sign << 24; + uint32_t result = sign << 24; if (exponent == 0) { if (mantissa > 0) { exponent = 0x7f - exponent_bias; @@ -1399,8 +1399,8 @@ ggml_type ModelLoader::get_sd_wtype() { if (tensor_storage.name.find(".weight") != std::string::npos && 
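
// [editor's note] On f8_e4m3_to_f16() above: e4m3 is 1 sign, 4 exponent and
// 3 mantissa bits with exponent bias 7, which matches the masks used there.
// Worked example, assuming that layout:
//     f8 = 0x44 -> sign 0, exponent (0x44 & 0x78) >> 3 = 8, mantissa 4
//     value = 2^(8-7) * (1 + 4/8) = 3.0f
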
(tensor_storage.name.find("time_embed") != std::string::npos || - tensor_storage.name.find("context_embedder") != std::string::npos || - tensor_storage.name.find("time_in") != std::string::npos)) { + tensor_storage.name.find("context_embedder") != std::string::npos || + tensor_storage.name.find("time_in") != std::string::npos)) { return tensor_storage.type; } } @@ -1414,9 +1414,9 @@ ggml_type ModelLoader::get_conditioner_wtype() { } if ((tensor_storage.name.find("text_encoders") == std::string::npos && - tensor_storage.name.find("cond_stage_model") == std::string::npos && - tensor_storage.name.find("te.text_model.") == std::string::npos && - tensor_storage.name.find("conditioner") == std::string::npos)) { + tensor_storage.name.find("cond_stage_model") == std::string::npos && + tensor_storage.name.find("te.text_model.") == std::string::npos && + tensor_storage.name.find("conditioner") == std::string::npos)) { continue; } @@ -1427,7 +1427,6 @@ ggml_type ModelLoader::get_conditioner_wtype() { return GGML_TYPE_COUNT; } - ggml_type ModelLoader::get_diffusion_model_wtype() { for (auto& tensor_storage : tensor_storages) { if (is_unused_tensor(tensor_storage.name)) { @@ -1440,8 +1439,8 @@ ggml_type ModelLoader::get_diffusion_model_wtype() { if (tensor_storage.name.find(".weight") != std::string::npos && (tensor_storage.name.find("time_embed") != std::string::npos || - tensor_storage.name.find("context_embedder") != std::string::npos || - tensor_storage.name.find("time_in") != std::string::npos)) { + tensor_storage.name.find("context_embedder") != std::string::npos || + tensor_storage.name.find("time_in") != std::string::npos)) { return tensor_storage.type; } } diff --git a/model.h b/model.h index f96c067..5cd6314 100644 --- a/model.h +++ b/model.h @@ -165,4 +165,3 @@ public: }; #endif // __MODEL_H__ - diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 74f6102..1bbe0d9 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -69,11 +69,10 @@ public: ggml_backend_t clip_backend = NULL; ggml_backend_t control_net_backend = NULL; ggml_backend_t vae_backend = NULL; - ggml_type model_wtype = GGML_TYPE_COUNT; - ggml_type conditioner_wtype = GGML_TYPE_COUNT; - ggml_type diffusion_model_wtype = GGML_TYPE_COUNT; - ggml_type vae_wtype = GGML_TYPE_COUNT; - + ggml_type model_wtype = GGML_TYPE_COUNT; + ggml_type conditioner_wtype = GGML_TYPE_COUNT; + ggml_type diffusion_model_wtype = GGML_TYPE_COUNT; + ggml_type vae_wtype = GGML_TYPE_COUNT; SDVersion version; bool vae_decode_only = false; @@ -171,7 +170,7 @@ public: backend = ggml_backend_cpu_init(); } #ifdef SD_USE_FLASH_ATTENTION -#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined (SD_USE_SYCL) +#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) LOG_WARN("Flash Attention not supported with GPU Backend"); #else LOG_INFO("Flash Attention enabled"); @@ -243,10 +242,10 @@ public: vae_wtype = wtype; } } else { - model_wtype = wtype; - conditioner_wtype = wtype; + model_wtype = wtype; + conditioner_wtype = wtype; diffusion_model_wtype = wtype; - vae_wtype = wtype; + vae_wtype = wtype; } if (version == VERSION_SDXL) { @@ -290,7 +289,7 @@ public: first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); } else { - clip_backend = backend; + clip_backend = backend; bool use_t5xxl = false; if (version == VERSION_SD3_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { use_t5xxl = true; @@ -508,7 +507,7 @@ public: LOG_INFO("running in Flux FLOW 
mode"); float shift = 1.15f; if (version == VERSION_FLUX_SCHNELL) { - shift = 1.0f; // TODO: validate + shift = 1.0f; // TODO: validate } denoiser = std::make_shared(shift); } else if (is_using_v_parameterization) { diff --git a/stable-diffusion.h b/stable-diffusion.h index 0225b34..f616eef 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -55,37 +55,37 @@ enum schedule_t { // same as enum ggml_type enum sd_type_t { - SD_TYPE_F32 = 0, - SD_TYPE_F16 = 1, - SD_TYPE_Q4_0 = 2, - SD_TYPE_Q4_1 = 3, + SD_TYPE_F32 = 0, + SD_TYPE_F16 = 1, + SD_TYPE_Q4_0 = 2, + SD_TYPE_Q4_1 = 3, // SD_TYPE_Q4_2 = 4, support has been removed // SD_TYPE_Q4_3 = 5, support has been removed - SD_TYPE_Q5_0 = 6, - SD_TYPE_Q5_1 = 7, - SD_TYPE_Q8_0 = 8, - SD_TYPE_Q8_1 = 9, - SD_TYPE_Q2_K = 10, - SD_TYPE_Q3_K = 11, - SD_TYPE_Q4_K = 12, - SD_TYPE_Q5_K = 13, - SD_TYPE_Q6_K = 14, - SD_TYPE_Q8_K = 15, - SD_TYPE_IQ2_XXS = 16, - SD_TYPE_IQ2_XS = 17, - SD_TYPE_IQ3_XXS = 18, - SD_TYPE_IQ1_S = 19, - SD_TYPE_IQ4_NL = 20, - SD_TYPE_IQ3_S = 21, - SD_TYPE_IQ2_S = 22, - SD_TYPE_IQ4_XS = 23, - SD_TYPE_I8 = 24, - SD_TYPE_I16 = 25, - SD_TYPE_I32 = 26, - SD_TYPE_I64 = 27, - SD_TYPE_F64 = 28, - SD_TYPE_IQ1_M = 29, - SD_TYPE_BF16 = 30, + SD_TYPE_Q5_0 = 6, + SD_TYPE_Q5_1 = 7, + SD_TYPE_Q8_0 = 8, + SD_TYPE_Q8_1 = 9, + SD_TYPE_Q2_K = 10, + SD_TYPE_Q3_K = 11, + SD_TYPE_Q4_K = 12, + SD_TYPE_Q5_K = 13, + SD_TYPE_Q6_K = 14, + SD_TYPE_Q8_K = 15, + SD_TYPE_IQ2_XXS = 16, + SD_TYPE_IQ2_XS = 17, + SD_TYPE_IQ3_XXS = 18, + SD_TYPE_IQ1_S = 19, + SD_TYPE_IQ4_NL = 20, + SD_TYPE_IQ3_S = 21, + SD_TYPE_IQ2_S = 22, + SD_TYPE_IQ4_XS = 23, + SD_TYPE_I8 = 24, + SD_TYPE_I16 = 25, + SD_TYPE_I32 = 26, + SD_TYPE_I64 = 27, + SD_TYPE_F64 = 28, + SD_TYPE_IQ1_M = 29, + SD_TYPE_BF16 = 30, SD_TYPE_Q4_0_4_4 = 31, SD_TYPE_Q4_0_4_8 = 32, SD_TYPE_Q4_0_8_8 = 33,