fix: repair flash attention support (#386)

* repair flash attention in _ext
this does not fix the currently broken fa behind the define, which is only used by VAE

Co-authored-by: FSSRepo <FSSRepo@users.noreply.github.com>

* make flash attention in the diffusion model a runtime flag
no support for sd3 or video

* remove old flash attention option and switch vae over to attn_ext

* update docs

* format code

---------

Co-authored-by: FSSRepo <FSSRepo@users.noreply.github.com>
Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
Erik Scholz 2024-11-23 05:39:08 +01:00 committed by GitHub
parent ea9b647080
commit 1c168d98a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 334 additions and 314 deletions

View File

@ -29,7 +29,6 @@ option(SD_HIPBLAS "sd: rocm backend" OFF)
option(SD_METAL "sd: metal backend" OFF)
option(SD_VULKAN "sd: vulkan backend" OFF)
option(SD_SYCL "sd: sycl backend" OFF)
option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF)
option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
#option(SD_BUILD_SERVER "sd: build server example" ON)
@ -61,11 +60,6 @@ if (SD_HIPBLAS)
endif()
endif ()
if(SD_FLASH_ATTN)
message("-- Use Flash Attention for memory optimization")
add_definitions(-DSD_USE_FLASH_ATTENTION)
endif()
set(SD_LIB stable-diffusion)
file(GLOB SD_LIB_SOURCES

View File

@ -24,7 +24,7 @@ Inference of Stable Diffusion and Flux in pure C/C++
- Full CUDA, Metal, Vulkan and SYCL backend for GPU acceleration.
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs models
- No need to convert to `.ggml` or `.gguf` anymore!
- Flash Attention for memory usage optimization (only cpu for now)
- Flash Attention for memory usage optimization
- Original `txt2img` and `img2img` mode
- Negative prompt
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
@ -182,11 +182,21 @@ Example of text2img by using SYCL backend:
##### Using Flash Attention
Enabling flash attention reduces memory usage by at least 400 MB. At the moment, it is not supported when CUBLAS is enabled because the kernel implementation is missing.
Enabling flash attention for the diffusion model reduces memory usage by a varying amount.
e.g.:
- flux 768x768 ~600 MB
- SD2 768x768 ~1400 MB
For most backends it slows things down, but for CUDA it generally speeds things up as well.
At the moment, it is only supported for some models and some backends (like CPU, CUDA/ROCm, Metal).
Run by adding `--diffusion-fa` to the arguments and watch for:
```
cmake .. -DSD_FLASH_ATTN=ON
cmake --build . --config Release
[INFO ] stable-diffusion.cpp:312 - Using flash attention in the diffusion model
```
and for the compute buffer to shrink in the debug log:
```
[DEBUG] ggml_extend.hpp:1004 - flux compute buffer size: 650.00 MB(VRAM)
```
### Run
@ -240,6 +250,9 @@ arguments:
--vae-tiling process vae in tiles to reduce memory usage
--vae-on-cpu keep vae in cpu (for low vram)
--clip-on-cpu keep clip in cpu (for low vram)
--diffusion-fa use flash attention in the diffusion model (for low vram)
Might lower quality, since it implies converting k and v to f16.
This might crash if it is not supported by the backend.
--control-net-cpu keep controlnet in cpu (for low vram)
--canny apply canny preprocessor (edge detection)
--color Colors the logging tags according to level

View File

@ -343,8 +343,7 @@ public:
}
}
std::string clean_up_tokenization(std::string &text){
std::string clean_up_tokenization(std::string& text) {
std::regex pattern(R"( ,)");
// Replace " ," with ","
std::string result = std::regex_replace(text, pattern, ",");
@ -359,10 +358,10 @@ public:
std::u32string ts = decoder[t];
// printf("%d, %s \n", t, utf32_to_utf8(ts).c_str());
std::string s = utf32_to_utf8(ts);
if (s.length() >= 4 ){
if(ends_with(s, "</w>")) {
if (s.length() >= 4) {
if (ends_with(s, "</w>")) {
text += s.replace(s.length() - 4, s.length() - 1, "") + " ";
}else{
} else {
text += s;
}
} else {
@ -768,8 +767,7 @@ public:
blocks["post_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values,
bool return_pooled = true) {
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values, bool return_pooled = true) {
// pixel_values: [N, num_channels, image_size, image_size]
auto embeddings = std::dynamic_pointer_cast<CLIPVisionEmbeddings>(blocks["embeddings"]);
auto pre_layernorm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_layernorm"]);
@ -781,7 +779,7 @@ public:
x = encoder->forward(ctx, x, -1, false);
// print_ggml_tensor(x, true, "ClipVisionModel x: ");
auto last_hidden_state = x;
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
GGML_ASSERT(x->ne[3] == 1);
if (return_pooled) {

View File

@ -245,16 +245,19 @@ protected:
int64_t context_dim;
int64_t n_head;
int64_t d_head;
bool flash_attn;
public:
CrossAttention(int64_t query_dim,
int64_t context_dim,
int64_t n_head,
int64_t d_head)
int64_t d_head,
bool flash_attn = false)
: n_head(n_head),
d_head(d_head),
query_dim(query_dim),
context_dim(context_dim) {
context_dim(context_dim),
flash_attn(flash_attn) {
int64_t inner_dim = d_head * n_head;
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
@ -283,7 +286,7 @@ public:
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false); // [N, n_token, inner_dim]
x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim]
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
return x;
@ -301,15 +304,16 @@ public:
int64_t n_head,
int64_t d_head,
int64_t context_dim,
bool ff_in = false)
bool ff_in = false,
bool flash_attn = false)
: n_head(n_head), d_head(d_head), ff_in(ff_in) {
// disable_self_attn is always False
// disable_temporal_crossattention is always False
// switch_temporal_ca_to_sa is always False
// inner_dim is always None or equal to dim
// gated_ff is always True
blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head));
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head));
blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head, flash_attn));
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn));
blocks["ff"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
@ -374,7 +378,8 @@ public:
int64_t n_head,
int64_t d_head,
int64_t depth,
int64_t context_dim)
int64_t context_dim,
bool flash_attn = false)
: in_channels(in_channels),
n_head(n_head),
d_head(d_head),
@ -388,7 +393,7 @@ public:
for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim));
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn));
}
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));

View File

@ -4,7 +4,6 @@
#include "clip.hpp"
#include "t5.hpp"
struct SDCondition {
struct ggml_tensor* c_crossattn = NULL; // aka context
struct ggml_tensor* c_vector = NULL; // aka y
@ -44,7 +43,7 @@ struct Conditioner {
// ldm.modules.encoders.modules.FrozenCLIPEmbedder
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
SDVersion version = VERSION_SD1;
SDVersion version = VERSION_SD1;
PMVersion pm_version = VERSION_1;
CLIPTokenizer tokenizer;
ggml_type wtype;
@ -61,7 +60,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
ggml_type wtype,
const std::string& embd_dir,
SDVersion version = VERSION_SD1,
PMVersion pv = VERSION_1,
PMVersion pv = VERSION_1,
int clip_skip = -1)
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
if (clip_skip <= 0) {
@ -162,7 +161,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
tokenize_with_trigger_token(std::string text,
int num_input_imgs,
int32_t image_token,
bool padding = false){
bool padding = false) {
return tokenize_with_trigger_token(text, num_input_imgs, image_token,
text_model->model.n_token, padding);
}
@ -271,7 +270,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
std::vector<int> clean_input_ids_tmp;
for (uint32_t i = 0; i < class_token_index[0]; i++)
clean_input_ids_tmp.push_back(clean_input_ids[i]);
for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2*num_input_imgs: num_input_imgs); i++)
for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++)
clean_input_ids_tmp.push_back(class_token);
for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
clean_input_ids_tmp.push_back(clean_input_ids[i]);
@ -287,11 +286,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
// weights.insert(weights.begin(), 1.0);
tokenizer.pad_tokens(tokens, weights, max_length, padding);
int offset = pm_version == VERSION_2 ? 2*num_input_imgs: num_input_imgs;
int offset = pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
for (uint32_t i = 0; i < tokens.size(); i++) {
// if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs
if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs
// hardcode for now
if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs
// hardcode for now
class_token_mask.push_back(true);
else
class_token_mask.push_back(false);
@ -536,7 +535,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
int height,
int num_input_imgs,
int adm_in_channels = -1,
bool force_zero_embeddings = false){
bool force_zero_embeddings = false) {
auto image_tokens = convert_token_to_id(trigger_word);
// if(image_tokens.size() == 1){
// printf(" image token id is: %d \n", image_tokens[0]);
@ -964,7 +963,7 @@ struct SD3CLIPEmbedder : public Conditioner {
int height,
int num_input_imgs,
int adm_in_channels = -1,
bool force_zero_embeddings = false){
bool force_zero_embeddings = false) {
GGML_ASSERT(0 && "Not implemented yet!");
}

View File

@ -32,8 +32,9 @@ struct UNetModel : public DiffusionModel {
UNetModel(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_SD1)
: unet(backend, wtype, version) {
SDVersion version = VERSION_SD1,
bool flash_attn = false)
: unet(backend, wtype, version, flash_attn) {
}
void alloc_params_buffer() {
@ -133,8 +134,9 @@ struct FluxModel : public DiffusionModel {
FluxModel(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_FLUX_DEV)
: flux(backend, wtype, version) {
SDVersion version = VERSION_FLUX_DEV,
bool flash_attn = false)
: flux(backend, wtype, version, flash_attn) {
}
void alloc_params_buffer() {

View File

@ -116,6 +116,7 @@ struct SDParams {
bool normalize_input = false;
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool diffusion_flash_attn = false;
bool canny_preprocess = false;
bool color = false;
int upscale_repeats = 1;
@ -151,6 +152,7 @@ void print_params(SDParams params) {
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false");
printf(" strength(control): %.2f\n", params.control_strength);
printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
@ -227,6 +229,9 @@ void print_usage(int argc, const char* argv[]) {
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
printf(" --vae-on-cpu keep vae in cpu (for low vram)\n");
printf(" --clip-on-cpu keep clip in cpu (for low vram)\n");
printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n");
printf(" Might lower quality, since it implies converting k and v to f16.\n");
printf(" This might crash if it is not supported by the backend.\n");
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
printf(" --canny apply canny preprocessor (edge detection)\n");
printf(" --color Colors the logging tags according to level\n");
@ -477,6 +482,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.clip_on_cpu = true; // will slow down get_learned_condiotion but necessary for low MEM GPUs
} else if (arg == "--vae-on-cpu") {
params.vae_on_cpu = true; // will slow down latent decoding but necessary for low MEM GPUs
} else if (arg == "--diffusion-fa") {
params.diffusion_flash_attn = true; // can reduce MEM significantly
} else if (arg == "--canny") {
params.canny_preprocess = true;
} else if (arg == "-b" || arg == "--batch-count") {
@ -868,7 +875,8 @@ int main(int argc, const char* argv[]) {
params.schedule,
params.clip_on_cpu,
params.control_net_cpu,
params.vae_on_cpu);
params.vae_on_cpu,
params.diffusion_flash_attn);
if (sd_ctx == NULL) {
printf("new_sd_ctx_t failed\n");

View File

@ -115,25 +115,28 @@ namespace Flux {
struct ggml_tensor* q,
struct ggml_tensor* k,
struct ggml_tensor* v,
struct ggml_tensor* pe) {
struct ggml_tensor* pe,
bool flash_attn) {
// q,k,v: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
// return: [N, L, n_head*d_head]
q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]
auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true); // [N, L, n_head*d_head]
auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true, flash_attn); // [N, L, n_head*d_head]
return x;
}
struct SelfAttention : public GGMLBlock {
public:
int64_t num_heads;
bool flash_attn;
public:
SelfAttention(int64_t dim,
int64_t num_heads = 8,
bool qkv_bias = false)
bool qkv_bias = false,
bool flash_attn = false)
: num_heads(num_heads) {
int64_t head_dim = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
@ -167,9 +170,9 @@ namespace Flux {
// x: [N, n_token, dim]
// pe: [n_token, d_head/2, 2, 2]
// return [N, n_token, dim]
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
x = attention(ctx, qkv[0], qkv[1], qkv[2], pe); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, flash_attn); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
return x;
}
};
@ -237,15 +240,19 @@ namespace Flux {
}
struct DoubleStreamBlock : public GGMLBlock {
bool flash_attn;
public:
DoubleStreamBlock(int64_t hidden_size,
int64_t num_heads,
float mlp_ratio,
bool qkv_bias = false) {
bool qkv_bias = false,
bool flash_attn = false)
: flash_attn(flash_attn) {
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
@ -254,7 +261,7 @@ namespace Flux {
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
@ -316,7 +323,7 @@ namespace Flux {
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = attention(ctx, q, k, v, pe); // [N, n_txt_token + n_img_token, n_head*d_head]
auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx,
attn,
@ -364,13 +371,15 @@ namespace Flux {
int64_t num_heads;
int64_t hidden_size;
int64_t mlp_hidden_dim;
bool flash_attn;
public:
SingleStreamBlock(int64_t hidden_size,
int64_t num_heads,
float mlp_ratio = 4.0f,
float qk_scale = 0.f)
: hidden_size(hidden_size), num_heads(num_heads) {
float qk_scale = 0.f,
bool flash_attn = false)
: hidden_size(hidden_size), num_heads(num_heads), flash_attn(flash_attn) {
int64_t head_dim = hidden_size / num_heads;
float scale = qk_scale;
if (scale <= 0.f) {
@ -433,7 +442,7 @@ namespace Flux {
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
q = norm->query_norm(ctx, q);
k = norm->key_norm(ctx, k);
auto attn = attention(ctx, q, k, v, pe); // [N, n_token, hidden_size]
auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_token, hidden_size]
auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
@ -492,6 +501,7 @@ namespace Flux {
int theta = 10000;
bool qkv_bias = true;
bool guidance_embed = true;
bool flash_attn = true;
};
struct Flux : public GGMLBlock {
@ -646,13 +656,16 @@ namespace Flux {
blocks["double_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new DoubleStreamBlock(params.hidden_size,
params.num_heads,
params.mlp_ratio,
params.qkv_bias));
params.qkv_bias,
params.flash_attn));
}
for (int i = 0; i < params.depth_single_blocks; i++) {
blocks["single_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new SingleStreamBlock(params.hidden_size,
params.num_heads,
params.mlp_ratio));
params.mlp_ratio,
0.f,
params.flash_attn));
}
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, out_channels));
@ -817,8 +830,10 @@ namespace Flux {
FluxRunner(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_FLUX_DEV)
SDVersion version = VERSION_FLUX_DEV,
bool flash_attn = false)
: GGMLRunner(backend, wtype) {
flux_params.flash_attn = flash_attn;
if (version == VERSION_FLUX_SCHNELL) {
flux_params.guidance_embed = false;
}

View File

@ -666,32 +666,6 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_qkv(struct ggml_context
return {q, k, v};
}
// q: [N * n_head, n_token, d_head]
// k: [N * n_head, n_k, d_head]
// v: [N * n_head, d_head, n_k]
// return: [N * n_head, n_token, d_head]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx,
struct ggml_tensor* q,
struct ggml_tensor* k,
struct ggml_tensor* v,
bool mask = false) {
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
#else
float d_head = (float)q->ne[0];
struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k]
kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head));
if (mask) {
kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
}
kq = ggml_soft_max_inplace(ctx, kq);
struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head]
#endif
return kqv;
}
// q: [N, L_q, C] or [N*n_head, L_q, d_head]
// k: [N, L_k, C] or [N*n_head, L_k, d_head]
// v: [N, L_k, C] or [N, L_k, n_head, d_head]
@ -703,7 +677,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
int64_t n_head,
struct ggml_tensor* mask = NULL,
bool diag_mask_inf = false,
bool skip_reshape = false) {
bool skip_reshape = false,
bool flash_attn = false) {
int64_t L_q;
int64_t L_k;
int64_t C;
@ -734,13 +709,42 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
float scale = (1.0f / sqrt((float)d_head));
bool use_flash_attn = false;
ggml_tensor* kqv = NULL;
if (use_flash_attn) {
// if (flash_attn) {
// LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
// }
// is there anything oddly shaped?? ping Green-Sky if you can trip this assert
GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
bool can_use_flash_attn = true;
can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0; // double check
// cuda max d_head seems to be 256, cpu does seem to work with 512
can_use_flash_attn = can_use_flash_attn && d_head <= 256; // double check
if (mask != nullptr) {
// TODO(Green-Sky): figure out if we can bend t5 to work too
can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
}
// TODO(Green-Sky): more pad or disable for funny tensor shapes
ggml_tensor* kqv = nullptr;
// GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
if (can_use_flash_attn && flash_attn) {
// LOG_DEBUG("using flash attention");
k = ggml_cast(ctx, k, GGML_TYPE_F16);
v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
LOG_DEBUG("k->ne[1] == %d", k->ne[1]);
v = ggml_cast(ctx, v, GGML_TYPE_F16);
kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
// kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
} else {
v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k]
v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k]
@ -756,10 +760,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
kq = ggml_soft_max_inplace(ctx, kq);
kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head]
kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head]
kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); // [N, L_q, n_head, d_head]
}
kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head]
kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, L_q, n_head, d_head]
kqv = ggml_cont(ctx, kqv);
kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N); // [N, L_q, C]
return kqv;
@ -1222,7 +1228,6 @@ protected:
if (bias) {
params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_features);
}
}
public:

View File

@ -148,19 +148,19 @@ std::unordered_map<std::string, std::string> vae_decoder_name_map = {
std::unordered_map<std::string, std::string> pmid_v2_name_map = {
{"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"},
"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.3.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"},
"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"},
"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.3.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc2.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"},
"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.3.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"},
"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"},
"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.3.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc2.weight"},
{"pmid.qformer_perceiver.token_proj.0.bias",
@ -650,9 +650,8 @@ uint16_t f8_e4m3_to_f16(uint8_t f8) {
return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
}
uint16_t f8_e5m2_to_f16(uint8_t fp8) {
uint8_t sign = (fp8 >> 7) & 0x1;
uint8_t sign = (fp8 >> 7) & 0x1;
uint8_t exponent = (fp8 >> 2) & 0x1F;
uint8_t mantissa = fp8 & 0x3;
@ -660,23 +659,23 @@ uint16_t f8_e5m2_to_f16(uint8_t fp8) {
uint16_t fp16_exponent;
uint16_t fp16_mantissa;
if (exponent == 0 && mantissa == 0) { //zero
if (exponent == 0 && mantissa == 0) { // zero
return fp16_sign;
}
if (exponent == 0x1F) { //NAN and INF
if (exponent == 0x1F) { // NAN and INF
fp16_exponent = 0x1F;
fp16_mantissa = mantissa ? (mantissa << 8) : 0;
return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
}
if (exponent == 0) { //subnormal numbers
if (exponent == 0) { // subnormal numbers
fp16_exponent = 0;
fp16_mantissa = (mantissa << 8);
return fp16_sign | fp16_mantissa;
}
//normal numbers
// normal numbers
int16_t true_exponent = (int16_t)exponent - 15 + 15;
if (true_exponent <= 0) {
fp16_exponent = 0;
@ -1434,10 +1433,9 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
std::string name = zip_entry_name(zip);
size_t pos = name.find("data.pkl");
if (pos != std::string::npos) {
std::string dir = name.substr(0, pos);
printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str());
void* pkl_data = NULL;
void* pkl_data = NULL;
size_t pkl_size;
zip_entry_read(zip, &pkl_data, &pkl_size);

273
pmid.hpp
View File

@ -6,7 +6,6 @@
#include "clip.hpp"
#include "lora.hpp"
struct FuseBlock : public GGMLBlock {
// network hparams
int in_dim;
@ -78,22 +77,20 @@ class QFormerPerceiver(nn.Module):
return out
*/
struct PMFeedForward : public GGMLBlock {
// network hparams
int dim;
public:
PMFeedForward(int d, int multi=4)
: dim(d) {
PMFeedForward(int d, int multi = 4)
: dim(d) {
int inner_dim = dim * multi;
blocks["0"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["1"] = std::shared_ptr<GGMLBlock>(new Mlp(dim, inner_dim, dim, false));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x){
struct ggml_tensor* x) {
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["0"]);
auto ff = std::dynamic_pointer_cast<Mlp>(blocks["1"]);
@ -101,37 +98,35 @@ public:
x = ff->forward(ctx, x);
return x;
}
};
struct PerceiverAttention : public GGMLBlock {
// network hparams
float scale; // = dim_head**-0.5
int dim_head; // = dim_head
int heads; // = heads
float scale; // = dim_head**-0.5
int dim_head; // = dim_head
int heads; // = heads
public:
PerceiverAttention(int dim, int dim_h=64, int h=8)
: scale(powf(dim_h, -0.5)), dim_head(dim_h), heads(h) {
int inner_dim = dim_head * heads;
PerceiverAttention(int dim, int dim_h = 64, int h = 8)
: scale(powf(dim_h, -0.5)), dim_head(dim_h), heads(h) {
int inner_dim = dim_head * heads;
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(dim, inner_dim, false));
blocks["to_kv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, inner_dim*2, false));
blocks["to_kv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, inner_dim * 2, false));
blocks["to_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim, false));
}
struct ggml_tensor* reshape_tensor(struct ggml_context* ctx,
struct ggml_tensor* x,
int heads) {
struct ggml_tensor* x,
int heads) {
int64_t ne[4];
for(int i = 0; i < 4; ++i)
ne[i] = x->ne[i];
for (int i = 0; i < 4; ++i)
ne[i] = x->ne[i];
// print_ggml_tensor(x, true, "PerceiverAttention reshape x 0: ");
// printf("heads = %d \n", heads);
// x = ggml_view_4d(ctx, x, x->ne[0], x->ne[1], heads, x->ne[2]/heads,
// x->nb[1], x->nb[2], x->nb[3], 0);
x = ggml_reshape_4d(ctx, x, x->ne[0]/heads, heads, x->ne[1], x->ne[2]);
x = ggml_reshape_4d(ctx, x, x->ne[0] / heads, heads, x->ne[1], x->ne[2]);
// x = ggml_view_4d(ctx, x, x->ne[0]/heads, heads, x->ne[1], x->ne[2],
// x->nb[1], x->nb[2], x->nb[3], 0);
// x = ggml_cont(ctx, x);
@ -142,49 +137,46 @@ public:
}
std::vector<struct ggml_tensor*> chunk_half(struct ggml_context* ctx,
struct ggml_tensor* x){
auto tlo = ggml_view_4d(ctx, x, x->ne[0]/2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0);
auto tli = ggml_view_4d(ctx, x, x->ne[0]/2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], x->nb[0]*x->ne[0]/2);
struct ggml_tensor* x) {
auto tlo = ggml_view_4d(ctx, x, x->ne[0] / 2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0);
auto tli = ggml_view_4d(ctx, x, x->ne[0] / 2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], x->nb[0] * x->ne[0] / 2);
return {ggml_cont(ctx, tlo),
ggml_cont(ctx, tli)};
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* latents){
struct ggml_tensor* latents) {
// x (torch.Tensor): image features
// shape (b, n1, D)
// latent (torch.Tensor): latent features
// shape (b, n2, D)
int64_t ne[4];
for(int i = 0; i < 4; ++i)
ne[i] = latents->ne[i];
for (int i = 0; i < 4; ++i)
ne[i] = latents->ne[i];
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
x = norm1->forward(ctx, x);
latents = norm2->forward(ctx, latents);
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
auto q = to_q->forward(ctx, latents);
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
x = norm1->forward(ctx, x);
latents = norm2->forward(ctx, latents);
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
auto q = to_q->forward(ctx, latents);
auto kv_input = ggml_concat(ctx, x, latents, 1);
auto to_kv = std::dynamic_pointer_cast<Linear>(blocks["to_kv"]);
auto kv = to_kv->forward(ctx, kv_input);
auto k = ggml_view_4d(ctx, kv, kv->ne[0]/2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1]/2, kv->nb[2]/2, kv->nb[3]/2, 0);
auto v = ggml_view_4d(ctx, kv, kv->ne[0]/2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1]/2, kv->nb[2]/2, kv->nb[3]/2, kv->nb[0]*(kv->ne[0]/2));
k = ggml_cont(ctx, k);
v = ggml_cont(ctx, v);
q = reshape_tensor(ctx, q, heads);
k = reshape_tensor(ctx, k, heads);
v = reshape_tensor(ctx, v, heads);
scale = 1.f / sqrt(sqrt((float)dim_head));
k = ggml_scale_inplace(ctx, k, scale);
q = ggml_scale_inplace(ctx, q, scale);
auto kv = to_kv->forward(ctx, kv_input);
auto k = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, 0);
auto v = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, kv->nb[0] * (kv->ne[0] / 2));
k = ggml_cont(ctx, k);
v = ggml_cont(ctx, v);
q = reshape_tensor(ctx, q, heads);
k = reshape_tensor(ctx, k, heads);
v = reshape_tensor(ctx, v, heads);
scale = 1.f / sqrt(sqrt((float)dim_head));
k = ggml_scale_inplace(ctx, k, scale);
q = ggml_scale_inplace(ctx, q, scale);
// auto weight = ggml_mul_mat(ctx, q, k);
auto weight = ggml_mul_mat(ctx, k, q); // NOTE order of mul is opposite to pytorch
auto weight = ggml_mul_mat(ctx, k, q); // NOTE order of mul is opposite to pytorch
// GGML's softmax() is equivalent to pytorch's softmax(x, dim=-1)
// in this case, dimension along which Softmax will be computed is the last dim
@ -192,13 +184,13 @@ public:
// last dimension (varying most rapidly) corresponds to GGML's first (varying most rapidly).
// weight = ggml_soft_max(ctx, weight);
weight = ggml_soft_max_inplace(ctx, weight);
v = ggml_cont(ctx, ggml_transpose(ctx, v));
v = ggml_cont(ctx, ggml_transpose(ctx, v));
// auto out = ggml_mul_mat(ctx, weight, v);
auto out = ggml_mul_mat(ctx, v, weight); // NOTE order of mul is opposite to pytorch
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));
out = ggml_reshape_3d(ctx, out, ne[0], ne[1], ggml_nelements(out)/(ne[0]*ne[1]));
auto to_out = std::dynamic_pointer_cast<Linear>(blocks["to_out"]);
out = to_out->forward(ctx, out);
auto out = ggml_mul_mat(ctx, v, weight); // NOTE order of mul is opposite to pytorch
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));
out = ggml_reshape_3d(ctx, out, ne[0], ne[1], ggml_nelements(out) / (ne[0] * ne[1]));
auto to_out = std::dynamic_pointer_cast<Linear>(blocks["to_out"]);
out = to_out->forward(ctx, out);
return out;
}
};
@ -206,45 +198,46 @@ public:
struct FacePerceiverResampler : public GGMLBlock {
// network hparams
int depth;
public:
// Constructs the resampler and registers its sub-blocks in `blocks`:
//   "proj_in"      Linear(embedding_dim, dim, true)
//   "proj_out"     Linear(dim, output_dim, true)
//   "norm_out"     LayerNorm(output_dim)
//   "layers.<i>.0" PerceiverAttention(dim, dim_head, heads)   for i in [0, depth)
//   "layers.<i>.1" PMFeedForward(dim, ff_mult)                for i in [0, depth)
// `d` is the number of attention/feed-forward pairs and is stored as `depth`.
// NOTE(review): the trailing `true` passed to Linear presumably enables a bias
// term - confirm against the Linear block's constructor.
FacePerceiverResampler( int dim=768,
int d=4,
int dim_head=64,
int heads=16,
int embedding_dim=1280,
int output_dim=768,
int ff_mult=4)
: depth(d) {
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Linear(embedding_dim, dim, true));
FacePerceiverResampler(int dim = 768,
int d = 4,
int dim_head = 64,
int heads = 16,
int embedding_dim = 1280,
int output_dim = 768,
int ff_mult = 4)
: depth(d) {
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Linear(embedding_dim, dim, true));
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(dim, output_dim, true));
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new LayerNorm(output_dim));
// One (attention, feed-forward) pair per layer; the ".0"/".1" suffixes are the
// keys under which the pair is stored in `blocks`.
for (int i = 0; i < depth; i++) {
std::string name = "layers." + std::to_string(i) + ".0";
blocks[name] = std::shared_ptr<GGMLBlock>(new PerceiverAttention(dim, dim_head, heads));
name = "layers." + std::to_string(i) + ".1";
name = "layers." + std::to_string(i) + ".1";
blocks[name] = std::shared_ptr<GGMLBlock>(new PMFeedForward(dim, ff_mult));
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* latents,
struct ggml_tensor* x){
struct ggml_tensor* x) {
// x: [N, channels, h, w]
auto proj_in = std::dynamic_pointer_cast<Linear>(blocks["proj_in"]);
auto proj_out = std::dynamic_pointer_cast<Linear>(blocks["proj_out"]);
auto norm_out = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_out"]);
auto proj_in = std::dynamic_pointer_cast<Linear>(blocks["proj_in"]);
auto proj_out = std::dynamic_pointer_cast<Linear>(blocks["proj_out"]);
auto norm_out = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_out"]);
x = proj_in->forward(ctx, x);
for (int i = 0; i < depth; i++) {
std::string name = "layers." + std::to_string(i) + ".0";
auto attn = std::dynamic_pointer_cast<PerceiverAttention>(blocks[name]);
name = "layers." + std::to_string(i) + ".1";
auto ff = std::dynamic_pointer_cast<PMFeedForward>(blocks[name]);
auto t = attn->forward(ctx, x, latents);
latents = ggml_add(ctx, t, latents);
t = ff->forward(ctx, latents);
latents = ggml_add(ctx, t, latents);
auto attn = std::dynamic_pointer_cast<PerceiverAttention>(blocks[name]);
name = "layers." + std::to_string(i) + ".1";
auto ff = std::dynamic_pointer_cast<PMFeedForward>(blocks[name]);
auto t = attn->forward(ctx, x, latents);
latents = ggml_add(ctx, t, latents);
t = ff->forward(ctx, latents);
latents = ggml_add(ctx, t, latents);
}
latents = proj_out->forward(ctx, latents);
latents = norm_out->forward(ctx, latents);
@ -258,24 +251,22 @@ struct QFormerPerceiver : public GGMLBlock {
int cross_attention_dim;
bool use_residul;
public:
// Constructs the Q-Former style perceiver and registers three sub-blocks:
//   "token_proj"          Mlp(id_embeddings_dim, id_embeddings_dim * ratio,
//                             cross_attention_dim * num_tokens, true)
//   "token_norm"          LayerNorm(cross_attention_d)
//   "perceiver_resampler" FacePerceiverResampler with dim = cross_attention_dim,
//                         4 layers, dim_head = 128, heads = cross_attention_dim / 128,
//                         embedding_dim as given, output_dim = cross_attention_dim,
//                         ff_mult = 4.
// `cross_attention_d`, `num_t` and `use_r` are stored as `cross_attention_dim`,
// `num_tokens` and `use_residul`.
// NOTE(review): heads uses integer division, so this presumably expects
// cross_attention_dim to be a multiple of 128 - confirm with callers.
QFormerPerceiver(int id_embeddings_dim, int cross_attention_d, int num_t, int embedding_dim=1024,
bool use_r=true, int ratio=4)
: cross_attention_dim(cross_attention_d), num_tokens(num_t), use_residul(use_r) {
blocks["token_proj"] = std::shared_ptr<GGMLBlock>(new Mlp(id_embeddings_dim,
id_embeddings_dim*ratio,
cross_attention_dim*num_tokens,
true));
blocks["token_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(cross_attention_d));
QFormerPerceiver(int id_embeddings_dim, int cross_attention_d, int num_t, int embedding_dim = 1024, bool use_r = true, int ratio = 4)
: cross_attention_dim(cross_attention_d), num_tokens(num_t), use_residul(use_r) {
blocks["token_proj"] = std::shared_ptr<GGMLBlock>(new Mlp(id_embeddings_dim,
id_embeddings_dim * ratio,
cross_attention_dim * num_tokens,
true));
blocks["token_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(cross_attention_d));
blocks["perceiver_resampler"] = std::shared_ptr<GGMLBlock>(new FacePerceiverResampler(
cross_attention_dim,
4,
128,
cross_attention_dim / 128,
embedding_dim,
cross_attention_dim,
4));
cross_attention_dim,
4,
128,
cross_attention_dim / 128,
embedding_dim,
cross_attention_dim,
4));
}
/*
@ -291,18 +282,18 @@ public:
// Projects `x` into `num_tokens` tokens of width `cross_attention_dim`,
// normalizes them, and passes them together with `last_hidden_state` to the
// perceiver resampler.
//
// ctx               ggml context used to create the graph nodes.
// x                 input embedding fed to "token_proj".
// last_hidden_state second input forwarded to the resampler unchanged.
// Returns the resampler output, plus the normalized tokens when `use_residul`
// is set.
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* last_hidden_state){
struct ggml_tensor* last_hidden_state) {
// NOTE(review): the shape comment below looks carried over from an image block;
// the reshape further down treats x as (cross_attention_dim, num_tokens, rest)
// in ggml dimension order - confirm the real input shape.
// x: [N, channels, h, w]
auto token_proj = std::dynamic_pointer_cast<Mlp>(blocks["token_proj"]);
auto token_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["token_norm"]);
auto token_proj = std::dynamic_pointer_cast<Mlp>(blocks["token_proj"]);
auto token_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["token_norm"]);
auto perceiver_resampler = std::dynamic_pointer_cast<FacePerceiverResampler>(blocks["perceiver_resampler"]);
// Split the projected vector into tokens; the third dimension is whatever
// remains after dividing the element count by cross_attention_dim * num_tokens.
x = token_proj->forward(ctx, x);
int64_t nel = ggml_nelements(x);
x = ggml_reshape_3d(ctx, x, cross_attention_dim, num_tokens, nel/(cross_attention_dim*num_tokens));
x = token_norm->forward(ctx, x);
x = token_proj->forward(ctx, x);
int64_t nel = ggml_nelements(x);
x = ggml_reshape_3d(ctx, x, cross_attention_dim, num_tokens, nel / (cross_attention_dim * num_tokens));
x = token_norm->forward(ctx, x);
// The resampler's first tensor argument is the normalized tokens, the second is
// last_hidden_state.
struct ggml_tensor* out = perceiver_resampler->forward(ctx, x, last_hidden_state);
// Optional residual connection around the resampler.
if(use_residul)
if (use_residul)
out = ggml_add(ctx, x, out);
return out;
}
@ -346,8 +337,6 @@ class FacePerceiverResampler(torch.nn.Module):
return self.norm_out(latents)
*/
/*
def FeedForward(dim, mult=4):
@ -417,9 +406,6 @@ class PerceiverAttention(nn.Module):
*/
struct FuseModule : public GGMLBlock {
// network hparams
int embed_dim;
@ -485,8 +471,8 @@ public:
// print_ggml_tensor(class_tokens_mask_pos, true, "class_tokens_mask_pos");
struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos);
ggml_set_name(image_token_embeds, "image_token_embeds");
valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0],
ggml_nelements(valid_id_embeds)/valid_id_embeds->ne[0]);
valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0],
ggml_nelements(valid_id_embeds) / valid_id_embeds->ne[0]);
struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds);
// stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3));
@ -555,14 +541,13 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
};
struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionModelProjection {
int cross_attention_dim;
int num_tokens;
PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock(int id_embeddings_dim=512)
PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock(int id_embeddings_dim = 512)
: CLIPVisionModelProjection(OPENAI_CLIP_VIT_L_14),
cross_attention_dim (2048),
num_tokens(2) {
cross_attention_dim(2048),
num_tokens(2) {
blocks["visual_projection_2"] = std::shared_ptr<GGMLBlock>(new Linear(1024, 1280, false));
blocks["fuse_module"] = std::shared_ptr<GGMLBlock>(new FuseModule(2048));
/*
@ -575,10 +560,9 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
cross_attention_dim,
self.num_tokens,
)*/
blocks["qformer_perceiver"] = std::shared_ptr<GGMLBlock>(new QFormerPerceiver(id_embeddings_dim,
cross_attention_dim,
num_tokens));
blocks["qformer_perceiver"] = std::shared_ptr<GGMLBlock>(new QFormerPerceiver(id_embeddings_dim,
cross_attention_dim,
num_tokens));
}
/*
@ -603,13 +587,13 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
struct ggml_tensor* left,
struct ggml_tensor* right) {
// x: [N, channels, h, w]
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
auto fuse_module = std::dynamic_pointer_cast<FuseModule>(blocks["fuse_module"]);
auto qformer_perceiver = std::dynamic_pointer_cast<QFormerPerceiver>(blocks["qformer_perceiver"]);
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
auto fuse_module = std::dynamic_pointer_cast<FuseModule>(blocks["fuse_module"]);
auto qformer_perceiver = std::dynamic_pointer_cast<QFormerPerceiver>(blocks["qformer_perceiver"]);
// struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size]
struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false); // [N, hidden_size]
id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state);
struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false); // [N, hidden_size]
id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state);
struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx,
prompt_embeds,
@ -623,7 +607,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
struct PhotoMakerIDEncoder : public GGMLRunner {
public:
SDVersion version = VERSION_SDXL;
SDVersion version = VERSION_SDXL;
PMVersion pm_version = VERSION_1;
PhotoMakerIDEncoderBlock id_encoder;
PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2;
@ -639,15 +623,14 @@ public:
std::vector<float> zeros_right;
public:
// Constructs the runner and initializes exactly one of the two encoder blocks.
//
// backend, wtype  forwarded to the GGMLRunner base; wtype is also passed to the
//                 selected block's init().
// version         stored in `version` (defaults to VERSION_SDXL).
// pm_v            PhotoMaker version; selects `id_encoder` (VERSION_1) or
//                 `id_encoder2` (VERSION_2). Any other value initializes neither.
// sty             stored in `style_strength` (defaults to 20.f).
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL,
PMVersion pm_v = VERSION_1, float sty = 20.f)
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = VERSION_1, float sty = 20.f)
: GGMLRunner(backend, wtype),
version(version),
pm_version(pm_v),
style_strength(sty) {
// Only the block matching pm_version is initialized with params_ctx.
if(pm_version == VERSION_1){
if (pm_version == VERSION_1) {
id_encoder.init(params_ctx, wtype);
}else if(pm_version == VERSION_2){
} else if (pm_version == VERSION_2) {
id_encoder2.init(params_ctx, wtype);
}
}
@ -656,17 +639,15 @@ public:
return "pmid";
}
// Returns the PhotoMaker version stored in `pm_version`.
PMVersion get_version() const{
PMVersion get_version() const {
return pm_version;
}
// Collects parameter tensors from the encoder block selected by `pm_version`
// by forwarding `tensors` and `prefix` to that block's get_param_tensors().
// When pm_version is neither VERSION_1 nor VERSION_2, `tensors` is left untouched.
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
if(pm_version == VERSION_1)
if (pm_version == VERSION_1)
id_encoder.get_param_tensors(tensors, prefix);
else if(pm_version == VERSION_2)
else if (pm_version == VERSION_2)
id_encoder2.get_param_tensors(tensors, prefix);
}
struct ggml_cgraph* build_graph( // struct ggml_allocr* allocr,
@ -753,14 +734,14 @@ public:
}
}
struct ggml_tensor* updated_prompt_embeds = NULL;
if(pm_version == VERSION_1)
if (pm_version == VERSION_1)
updated_prompt_embeds = id_encoder.forward(ctx0,
id_pixel_values_d,
prompt_embeds_d,
class_tokens_mask_d,
class_tokens_mask_pos,
left, right);
else if(pm_version == VERSION_2)
id_pixel_values_d,
prompt_embeds_d,
class_tokens_mask_d,
class_tokens_mask_pos,
left, right);
else if (pm_version == VERSION_2)
updated_prompt_embeds = id_encoder2.forward(ctx0,
id_pixel_values_d,
prompt_embeds_d,
@ -791,22 +772,19 @@ public:
}
};
struct PhotoMakerIDEmbed : public GGMLRunner {
std::map<std::string, struct ggml_tensor*> tensors;
std::string file_path;
ModelLoader *model_loader;
ModelLoader* model_loader;
bool load_failed = false;
bool applied = false;
PhotoMakerIDEmbed(ggml_backend_t backend,
ggml_type wtype,
ModelLoader *ml,
const std::string& file_path = "",
const std::string& prefix = "")
: file_path(file_path), GGMLRunner(backend, wtype),
model_loader(ml) {
ggml_type wtype,
ModelLoader* ml,
const std::string& file_path = "",
const std::string& prefix = "")
: file_path(file_path), GGMLRunner(backend, wtype), model_loader(ml) {
if (!model_loader->init_from_file(file_path, prefix)) {
load_failed = true;
}
@ -837,7 +815,7 @@ struct PhotoMakerIDEmbed : public GGMLRunner {
tensor_storage.type,
tensor_storage.n_dims,
tensor_storage.ne);
tensors[name] = real;
tensors[name] = real;
} else {
auto real = tensors[name];
*dst_tensor = real;
@ -856,11 +834,10 @@ struct PhotoMakerIDEmbed : public GGMLRunner {
return true;
}
// Looks up the tensor stored under the key "pmid.id_embeds" in `tensors`.
// Returns that tensor, or NULL when the key is not present.
struct ggml_tensor* get(){
struct ggml_tensor* get() {
std::map<std::string, struct ggml_tensor*>::iterator pos;
pos = tensors.find("pmid.id_embeds");
if(pos != tensors.end())
if (pos != tensors.end())
return pos->second;
return NULL;
}

View File

@ -156,7 +156,8 @@ public:
schedule_t schedule,
bool clip_on_cpu,
bool control_net_cpu,
bool vae_on_cpu) {
bool vae_on_cpu,
bool diffusion_flash_attn) {
use_tiny_autoencoder = taesd_path.size() > 0;
#ifdef SD_USE_CUBLAS
LOG_DEBUG("Using CUDA backend");
@ -185,13 +186,7 @@ public:
LOG_DEBUG("Using CPU backend");
backend = ggml_backend_cpu_init();
}
#ifdef SD_USE_FLASH_ATTENTION
#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) || defined(SD_USE_VULKAN)
LOG_WARN("Flash Attention not supported with GPU Backend");
#else
LOG_INFO("Flash Attention enabled");
#endif
#endif
ModelLoader model_loader;
vae_tiling = vae_tiling_;
@ -325,19 +320,25 @@ public:
LOG_INFO("CLIP: Using CPU backend");
clip_backend = ggml_backend_cpu_init();
}
if (diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
}
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
if (diffusion_flash_attn) {
LOG_WARN("flash attention in this diffusion model is currently unsupported!");
}
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version);
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn);
} else {
if(id_embeddings_path.find("v2") != std::string::npos) {
if (id_embeddings_path.find("v2") != std::string::npos) {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version, VERSION_2);
}else{
} else {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version);
}
diffusion_model = std::make_shared<UNetModel>(backend, diffusion_model_wtype, version);
}
diffusion_model = std::make_shared<UNetModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn);
}
cond_stage_model->alloc_params_buffer();
cond_stage_model->get_param_tensors(tensors);
@ -371,7 +372,7 @@ public:
control_net = std::make_shared<ControlNet>(controlnet_backend, diffusion_model_wtype, version);
}
if(id_embeddings_path.find("v2") != std::string::npos) {
if (id_embeddings_path.find("v2") != std::string::npos) {
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version, VERSION_2);
LOG_INFO("using PhotoMaker Version 2");
} else {
@ -1081,7 +1082,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
enum schedule_t s,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu) {
bool keep_vae_on_cpu,
bool diffusion_flash_attn) {
sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
if (sd_ctx == NULL) {
return NULL;
@ -1122,7 +1124,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
s,
keep_clip_on_cpu,
keep_control_net_cpu,
keep_vae_on_cpu)) {
keep_vae_on_cpu,
diffusion_flash_attn)) {
delete sd_ctx->sd;
sd_ctx->sd = NULL;
free(sd_ctx);
@ -1217,7 +1220,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
for (std::string img_file : img_files) {
int c = 0;
int width, height;
if(ends_with(img_file, "safetensors")){
if (ends_with(img_file, "safetensors")) {
continue;
}
uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3);
@ -1257,18 +1260,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
else
sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL);
}
t0 = ggml_time_ms();
auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
sd_ctx->sd->n_threads, prompt,
clip_skip,
width,
height,
num_input_images,
sd_ctx->sd->diffusion_model->get_adm_in_channels());
id_cond = std::get<0>(cond_tup);
class_tokens_mask = std::get<1>(cond_tup); //
t0 = ggml_time_ms();
auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
sd_ctx->sd->n_threads, prompt,
clip_skip,
width,
height,
num_input_images,
sd_ctx->sd->diffusion_model->get_adm_in_channels());
id_cond = std::get<0>(cond_tup);
class_tokens_mask = std::get<1>(cond_tup); //
struct ggml_tensor* id_embeds = NULL;
if(pmv2){
if (pmv2) {
// id_embeds = sd_ctx->sd->pmid_id_embeds->get();
id_embeds = load_tensor_from_file(work_ctx, path_join(input_id_images_path, "id_embeds.bin"));
// print_ggml_tensor(id_embeds, true, "id_embeds:");

View File

@ -142,7 +142,8 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
enum schedule_t s,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu);
bool keep_vae_on_cpu,
bool diffusion_flash_attn);
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);

View File

@ -183,7 +183,7 @@ public:
int model_channels = 320;
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
UnetModelBlock(SDVersion version = VERSION_SD1)
UnetModelBlock(SDVersion version = VERSION_SD1, bool flash_attn = false)
: version(version) {
if (version == VERSION_SD2) {
context_dim = 1024;
@ -242,7 +242,7 @@ public:
if (version == VERSION_SVD) {
return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim);
} else {
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim);
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, flash_attn);
}
};
@ -533,8 +533,9 @@ struct UNetModelRunner : public GGMLRunner {
// Constructs the runner around a `unet` member and initializes its parameters.
//
// backend, wtype  forwarded to the GGMLRunner base; wtype is also passed to
//                 unet.init().
// version         model version handed to the unet block (defaults to VERSION_SD1).
// flash_attn      runtime flag handed to the unet block's constructor; defaults
//                 to false so callers that do not pass it keep the old behaviour.
UNetModelRunner(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_SD1)
: GGMLRunner(backend, wtype), unet(version) {
SDVersion version = VERSION_SD1,
bool flash_attn = false)
: GGMLRunner(backend, wtype), unet(version, flash_attn) {
unet.init(params_ctx, wtype);
}

View File

@ -279,12 +279,12 @@ std::string path_join(const std::string& p1, const std::string& p2) {
std::vector<std::string> splitString(const std::string& str, char delimiter) {
std::vector<std::string> result;
size_t start = 0;
size_t end = str.find(delimiter);
size_t end = str.find(delimiter);
while (end != std::string::npos) {
result.push_back(str.substr(start, end - start));
start = end + 1;
end = str.find(delimiter, start);
end = str.find(delimiter, start);
}
// Add the last segment after the last delimiter
@ -293,7 +293,6 @@ std::vector<std::string> splitString(const std::string& str, char delimiter) {
return result;
}
sd_image_t* preprocess_id_image(sd_image_t* img) {
int shortest_edge = 224;
int size = shortest_edge;

View File

@ -99,10 +99,12 @@ public:
k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
k = ggml_reshape_3d(ctx, k, c, h * w, n); // [N, h * w, in_channels]
auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
v = ggml_reshape_3d(ctx, v, h * w, c, n); // [N, in_channels, h * w]
auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels]
v = ggml_reshape_3d(ctx, v, c, h * w, n); // [N, h * w, in_channels]
h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels]
// h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels]
h_ = ggml_nn_attention_ext(ctx, q, k, v, 1, nullptr, false, true, false);
h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
h_ = ggml_reshape_4d(ctx, h_, w, h, c, n); // [N, in_channels, h, w]