fix: repair flash attention support (#386)
* repair flash attention in _ext
  this does not fix the currently broken fa behind the define, which is only used by VAE
* make flash attention in the diffusion model a runtime flag
  no support for sd3 or video
* remove old flash attention option and switch vae over to attn_ext
* update docs
* format code

Co-authored-by: FSSRepo <FSSRepo@users.noreply.github.com>
Co-authored-by: leejet <leejet714@gmail.com>
parent ea9b647080
commit 1c168d98a5
CMakeLists.txt

@ -29,7 +29,6 @@ option(SD_HIPBLAS "sd: rocm backend" OFF)
option(SD_METAL "sd: metal backend" OFF)
option(SD_VULKAN "sd: vulkan backend" OFF)
option(SD_SYCL "sd: sycl backend" OFF)
option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF)
option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
#option(SD_BUILD_SERVER "sd: build server example" ON)
@ -61,11 +60,6 @@ if (SD_HIPBLAS)
    endif()
endif ()

if(SD_FLASH_ATTN)
    message("-- Use Flash Attention for memory optimization")
    add_definitions(-DSD_USE_FLASH_ATTENTION)
endif()

set(SD_LIB stable-diffusion)

file(GLOB SD_LIB_SOURCES
|
README.md
@ -24,7 +24,7 @@ Inference of Stable Diffusion and Flux in pure C/C++
- Full CUDA, Metal, Vulkan and SYCL backend for GPU acceleration.
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs models
- No need to convert to `.ggml` or `.gguf` anymore!
- Flash Attention for memory usage optimization (only cpu for now)
- Flash Attention for memory usage optimization
- Original `txt2img` and `img2img` mode
- Negative prompt
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
@ -182,11 +182,21 @@ Example of text2img by using SYCL backend:

##### Using Flash Attention

Enabling flash attention reduces memory usage by at least 400 MB. At the moment, it is not supported when CUBLAS is enabled because the kernel implementation is missing.
Enabling flash attention for the diffusion model reduces memory usage by varying amounts of MB.
eg.:
- flux 768x768 ~600mb
- SD2 768x768 ~1400mb

For most backends, it slows things down, but for cuda it generally speeds it up too.
At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal).

Run by adding `--diffusion-fa` to the arguments and watch for:
```
cmake .. -DSD_FLASH_ATTN=ON
cmake --build . --config Release
[INFO ] stable-diffusion.cpp:312 - Using flash attention in the diffusion model
```
and the compute buffer shrink in the debug log:
```
[DEBUG] ggml_extend.hpp:1004 - flux compute buffer size: 650.00 MB(VRAM)
```
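For example, a typical invocation with the flag enabled might look like the following (binary location and model path are placeholders; substitute your own build directory and checkpoint):
```
./build/bin/sd -m ../models/v2-1_768-ema-pruned.safetensors -p "a lovely cat" -W 768 -H 768 --diffusion-fa
```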

### Run
@ -240,6 +250,9 @@ arguments:
  --vae-tiling                       process vae in tiles to reduce memory usage
  --vae-on-cpu                       keep vae in cpu (for low vram)
  --clip-on-cpu                      keep clip in cpu (for low vram)
  --diffusion-fa                     use flash attention in the diffusion model (for low vram)
                                     Might lower quality, since it implies converting k and v to f16.
                                     This might crash if it is not supported by the backend.
  --control-net-cpu                  keep controlnet in cpu (for low vram)
  --canny                            apply canny preprocessor (edge detection)
  --color                            Colors the logging tags according to level
|
clip.hpp
@ -343,8 +343,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
std::string clean_up_tokenization(std::string &text){
|
||||
|
||||
std::string clean_up_tokenization(std::string& text) {
|
||||
std::regex pattern(R"( ,)");
|
||||
// Replace " ," with ","
|
||||
std::string result = std::regex_replace(text, pattern, ",");
|
||||
@ -359,10 +358,10 @@ public:
|
||||
std::u32string ts = decoder[t];
|
||||
// printf("%d, %s \n", t, utf32_to_utf8(ts).c_str());
|
||||
std::string s = utf32_to_utf8(ts);
|
||||
if (s.length() >= 4 ){
|
||||
if(ends_with(s, "</w>")) {
|
||||
if (s.length() >= 4) {
|
||||
if (ends_with(s, "</w>")) {
|
||||
text += s.replace(s.length() - 4, s.length() - 1, "") + " ";
|
||||
}else{
|
||||
} else {
|
||||
text += s;
|
||||
}
|
||||
} else {
|
||||
@ -768,8 +767,7 @@ public:
|
||||
blocks["post_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values,
|
||||
bool return_pooled = true) {
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values, bool return_pooled = true) {
|
||||
// pixel_values: [N, num_channels, image_size, image_size]
|
||||
auto embeddings = std::dynamic_pointer_cast<CLIPVisionEmbeddings>(blocks["embeddings"]);
|
||||
auto pre_layernorm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_layernorm"]);
|
||||
@ -781,7 +779,7 @@ public:
|
||||
x = encoder->forward(ctx, x, -1, false);
|
||||
// print_ggml_tensor(x, true, "ClipVisionModel x: ");
|
||||
auto last_hidden_state = x;
|
||||
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
|
||||
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
|
||||
|
||||
GGML_ASSERT(x->ne[3] == 1);
|
||||
if (return_pooled) {
|
||||
|
common.hpp
@ -245,16 +245,19 @@ protected:
|
||||
int64_t context_dim;
|
||||
int64_t n_head;
|
||||
int64_t d_head;
|
||||
bool flash_attn;
|
||||
|
||||
public:
|
||||
CrossAttention(int64_t query_dim,
|
||||
int64_t context_dim,
|
||||
int64_t n_head,
|
||||
int64_t d_head)
|
||||
int64_t d_head,
|
||||
bool flash_attn = false)
|
||||
: n_head(n_head),
|
||||
d_head(d_head),
|
||||
query_dim(query_dim),
|
||||
context_dim(context_dim) {
|
||||
context_dim(context_dim),
|
||||
flash_attn(flash_attn) {
|
||||
int64_t inner_dim = d_head * n_head;
|
||||
|
||||
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
|
||||
@ -283,7 +286,7 @@ public:
|
||||
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
|
||||
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
|
||||
|
||||
x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false); // [N, n_token, inner_dim]
|
||||
x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim]
|
||||
|
||||
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
|
||||
return x;
|
||||
@ -301,15 +304,16 @@ public:
|
||||
int64_t n_head,
|
||||
int64_t d_head,
|
||||
int64_t context_dim,
|
||||
bool ff_in = false)
|
||||
bool ff_in = false,
|
||||
bool flash_attn = false)
|
||||
: n_head(n_head), d_head(d_head), ff_in(ff_in) {
|
||||
// disable_self_attn is always False
|
||||
// disable_temporal_crossattention is always False
|
||||
// switch_temporal_ca_to_sa is always False
|
||||
// inner_dim is always None or equal to dim
|
||||
// gated_ff is always True
|
||||
blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head));
|
||||
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head));
|
||||
blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head, flash_attn));
|
||||
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn));
|
||||
blocks["ff"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
|
||||
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
|
||||
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
|
||||
@ -374,7 +378,8 @@ public:
|
||||
int64_t n_head,
|
||||
int64_t d_head,
|
||||
int64_t depth,
|
||||
int64_t context_dim)
|
||||
int64_t context_dim,
|
||||
bool flash_attn = false)
|
||||
: in_channels(in_channels),
|
||||
n_head(n_head),
|
||||
d_head(d_head),
|
||||
@ -388,7 +393,7 @@ public:
|
||||
|
||||
for (int i = 0; i < depth; i++) {
|
||||
std::string name = "transformer_blocks." + std::to_string(i);
|
||||
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim));
|
||||
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn));
|
||||
}
|
||||
|
||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
|
||||
|
conditioner.hpp

@ -4,7 +4,6 @@
|
||||
#include "clip.hpp"
|
||||
#include "t5.hpp"
|
||||
|
||||
|
||||
struct SDCondition {
|
||||
struct ggml_tensor* c_crossattn = NULL; // aka context
|
||||
struct ggml_tensor* c_vector = NULL; // aka y
|
||||
@ -44,7 +43,7 @@ struct Conditioner {
|
||||
// ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
|
||||
struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
SDVersion version = VERSION_SD1;
|
||||
SDVersion version = VERSION_SD1;
|
||||
PMVersion pm_version = VERSION_1;
|
||||
CLIPTokenizer tokenizer;
|
||||
ggml_type wtype;
|
||||
@ -61,7 +60,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
ggml_type wtype,
|
||||
const std::string& embd_dir,
|
||||
SDVersion version = VERSION_SD1,
|
||||
PMVersion pv = VERSION_1,
|
||||
PMVersion pv = VERSION_1,
|
||||
int clip_skip = -1)
|
||||
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
|
||||
if (clip_skip <= 0) {
|
||||
@ -162,7 +161,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
tokenize_with_trigger_token(std::string text,
|
||||
int num_input_imgs,
|
||||
int32_t image_token,
|
||||
bool padding = false){
|
||||
bool padding = false) {
|
||||
return tokenize_with_trigger_token(text, num_input_imgs, image_token,
|
||||
text_model->model.n_token, padding);
|
||||
}
|
||||
@ -271,7 +270,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
std::vector<int> clean_input_ids_tmp;
|
||||
for (uint32_t i = 0; i < class_token_index[0]; i++)
|
||||
clean_input_ids_tmp.push_back(clean_input_ids[i]);
|
||||
for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2*num_input_imgs: num_input_imgs); i++)
|
||||
for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++)
|
||||
clean_input_ids_tmp.push_back(class_token);
|
||||
for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
|
||||
clean_input_ids_tmp.push_back(clean_input_ids[i]);
|
||||
@ -287,11 +286,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
// weights.insert(weights.begin(), 1.0);
|
||||
|
||||
tokenizer.pad_tokens(tokens, weights, max_length, padding);
|
||||
int offset = pm_version == VERSION_2 ? 2*num_input_imgs: num_input_imgs;
|
||||
int offset = pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
|
||||
for (uint32_t i = 0; i < tokens.size(); i++) {
|
||||
// if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs
|
||||
if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs
|
||||
// hardcode for now
|
||||
if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs
|
||||
// hardcode for now
|
||||
class_token_mask.push_back(true);
|
||||
else
|
||||
class_token_mask.push_back(false);
|
||||
@ -536,7 +535,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
int height,
|
||||
int num_input_imgs,
|
||||
int adm_in_channels = -1,
|
||||
bool force_zero_embeddings = false){
|
||||
bool force_zero_embeddings = false) {
|
||||
auto image_tokens = convert_token_to_id(trigger_word);
|
||||
// if(image_tokens.size() == 1){
|
||||
// printf(" image token id is: %d \n", image_tokens[0]);
|
||||
@ -964,7 +963,7 @@ struct SD3CLIPEmbedder : public Conditioner {
|
||||
int height,
|
||||
int num_input_imgs,
|
||||
int adm_in_channels = -1,
|
||||
bool force_zero_embeddings = false){
|
||||
bool force_zero_embeddings = false) {
|
||||
GGML_ASSERT(0 && "Not implemented yet!");
|
||||
}
|
||||
|
||||
|
diffusion_model.hpp

@ -32,8 +32,9 @@ struct UNetModel : public DiffusionModel {
|
||||
|
||||
UNetModel(ggml_backend_t backend,
|
||||
ggml_type wtype,
|
||||
SDVersion version = VERSION_SD1)
|
||||
: unet(backend, wtype, version) {
|
||||
SDVersion version = VERSION_SD1,
|
||||
bool flash_attn = false)
|
||||
: unet(backend, wtype, version, flash_attn) {
|
||||
}
|
||||
|
||||
void alloc_params_buffer() {
|
||||
@ -133,8 +134,9 @@ struct FluxModel : public DiffusionModel {
|
||||
|
||||
FluxModel(ggml_backend_t backend,
|
||||
ggml_type wtype,
|
||||
SDVersion version = VERSION_FLUX_DEV)
|
||||
: flux(backend, wtype, version) {
|
||||
SDVersion version = VERSION_FLUX_DEV,
|
||||
bool flash_attn = false)
|
||||
: flux(backend, wtype, version, flash_attn) {
|
||||
}
|
||||
|
||||
void alloc_params_buffer() {
|
||||
|
examples/cli/main.cpp

@ -116,6 +116,7 @@ struct SDParams {
|
||||
bool normalize_input = false;
|
||||
bool clip_on_cpu = false;
|
||||
bool vae_on_cpu = false;
|
||||
bool diffusion_flash_attn = false;
|
||||
bool canny_preprocess = false;
|
||||
bool color = false;
|
||||
int upscale_repeats = 1;
|
||||
@ -151,6 +152,7 @@ void print_params(SDParams params) {
|
||||
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
|
||||
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
|
||||
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
|
||||
printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false");
|
||||
printf(" strength(control): %.2f\n", params.control_strength);
|
||||
printf(" prompt: %s\n", params.prompt.c_str());
|
||||
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
|
||||
@ -227,6 +229,9 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
|
||||
printf(" --vae-on-cpu keep vae in cpu (for low vram)\n");
|
||||
printf(" --clip-on-cpu keep clip in cpu (for low vram)\n");
|
||||
printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n");
|
||||
printf(" Might lower quality, since it implies converting k and v to f16.\n");
|
||||
printf(" This might crash if it is not supported by the backend.\n");
|
||||
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
|
||||
printf(" --canny apply canny preprocessor (edge detection)\n");
|
||||
printf(" --color Colors the logging tags according to level\n");
|
||||
@ -477,6 +482,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
params.clip_on_cpu = true; // will slow down get_learned_condiotion but necessary for low MEM GPUs
|
||||
} else if (arg == "--vae-on-cpu") {
|
||||
params.vae_on_cpu = true; // will slow down latent decoding but necessary for low MEM GPUs
|
||||
} else if (arg == "--diffusion-fa") {
|
||||
params.diffusion_flash_attn = true; // can reduce MEM significantly
|
||||
} else if (arg == "--canny") {
|
||||
params.canny_preprocess = true;
|
||||
} else if (arg == "-b" || arg == "--batch-count") {
|
||||
@ -868,7 +875,8 @@ int main(int argc, const char* argv[]) {
|
||||
params.schedule,
|
||||
params.clip_on_cpu,
|
||||
params.control_net_cpu,
|
||||
params.vae_on_cpu);
|
||||
params.vae_on_cpu,
|
||||
params.diffusion_flash_attn);
|
||||
|
||||
if (sd_ctx == NULL) {
|
||||
printf("new_sd_ctx_t failed\n");
|
||||
|
flux.hpp
@ -115,25 +115,28 @@ namespace Flux {
|
||||
struct ggml_tensor* q,
|
||||
struct ggml_tensor* k,
|
||||
struct ggml_tensor* v,
|
||||
struct ggml_tensor* pe) {
|
||||
struct ggml_tensor* pe,
|
||||
bool flash_attn) {
|
||||
// q,k,v: [N, L, n_head, d_head]
|
||||
// pe: [L, d_head/2, 2, 2]
|
||||
// return: [N, L, n_head*d_head]
|
||||
q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
|
||||
k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]
|
||||
|
||||
auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true); // [N, L, n_head*d_head]
|
||||
auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true, flash_attn); // [N, L, n_head*d_head]
|
||||
return x;
|
||||
}
|
||||
|
||||
struct SelfAttention : public GGMLBlock {
|
||||
public:
|
||||
int64_t num_heads;
|
||||
bool flash_attn;
|
||||
|
||||
public:
|
||||
SelfAttention(int64_t dim,
|
||||
int64_t num_heads = 8,
|
||||
bool qkv_bias = false)
|
||||
bool qkv_bias = false,
|
||||
bool flash_attn = false)
|
||||
: num_heads(num_heads) {
|
||||
int64_t head_dim = dim / num_heads;
|
||||
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
|
||||
@ -167,9 +170,9 @@ namespace Flux {
|
||||
// x: [N, n_token, dim]
|
||||
// pe: [n_token, d_head/2, 2, 2]
|
||||
// return [N, n_token, dim]
|
||||
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
|
||||
x = attention(ctx, qkv[0], qkv[1], qkv[2], pe); // [N, n_token, dim]
|
||||
x = post_attention(ctx, x); // [N, n_token, dim]
|
||||
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
|
||||
x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, flash_attn); // [N, n_token, dim]
|
||||
x = post_attention(ctx, x); // [N, n_token, dim]
|
||||
return x;
|
||||
}
|
||||
};
|
||||
@ -237,15 +240,19 @@ namespace Flux {
|
||||
}
|
||||
|
||||
struct DoubleStreamBlock : public GGMLBlock {
|
||||
bool flash_attn;
|
||||
|
||||
public:
|
||||
DoubleStreamBlock(int64_t hidden_size,
|
||||
int64_t num_heads,
|
||||
float mlp_ratio,
|
||||
bool qkv_bias = false) {
|
||||
bool qkv_bias = false,
|
||||
bool flash_attn = false)
|
||||
: flash_attn(flash_attn) {
|
||||
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
|
||||
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
|
||||
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias));
|
||||
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
|
||||
|
||||
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
|
||||
@ -254,7 +261,7 @@ namespace Flux {
|
||||
|
||||
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
|
||||
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias));
|
||||
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
|
||||
|
||||
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||
blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
|
||||
@ -316,7 +323,7 @@ namespace Flux {
|
||||
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||
|
||||
auto attn = attention(ctx, q, k, v, pe); // [N, n_txt_token + n_img_token, n_head*d_head]
|
||||
auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
|
||||
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
||||
auto txt_attn_out = ggml_view_3d(ctx,
|
||||
attn,
|
||||
@ -364,13 +371,15 @@ namespace Flux {
|
||||
int64_t num_heads;
|
||||
int64_t hidden_size;
|
||||
int64_t mlp_hidden_dim;
|
||||
bool flash_attn;
|
||||
|
||||
public:
|
||||
SingleStreamBlock(int64_t hidden_size,
|
||||
int64_t num_heads,
|
||||
float mlp_ratio = 4.0f,
|
||||
float qk_scale = 0.f)
|
||||
: hidden_size(hidden_size), num_heads(num_heads) {
|
||||
float qk_scale = 0.f,
|
||||
bool flash_attn = false)
|
||||
: hidden_size(hidden_size), num_heads(num_heads), flash_attn(flash_attn) {
|
||||
int64_t head_dim = hidden_size / num_heads;
|
||||
float scale = qk_scale;
|
||||
if (scale <= 0.f) {
|
||||
@ -433,7 +442,7 @@ namespace Flux {
|
||||
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
|
||||
q = norm->query_norm(ctx, q);
|
||||
k = norm->key_norm(ctx, k);
|
||||
auto attn = attention(ctx, q, k, v, pe); // [N, n_token, hidden_size]
|
||||
auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_token, hidden_size]
|
||||
|
||||
auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
|
||||
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
|
||||
@ -492,6 +501,7 @@ namespace Flux {
|
||||
int theta = 10000;
|
||||
bool qkv_bias = true;
|
||||
bool guidance_embed = true;
|
||||
bool flash_attn = true;
|
||||
};
|
||||
|
||||
struct Flux : public GGMLBlock {
|
||||
@ -646,13 +656,16 @@ namespace Flux {
|
||||
blocks["double_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new DoubleStreamBlock(params.hidden_size,
|
||||
params.num_heads,
|
||||
params.mlp_ratio,
|
||||
params.qkv_bias));
|
||||
params.qkv_bias,
|
||||
params.flash_attn));
|
||||
}
|
||||
|
||||
for (int i = 0; i < params.depth_single_blocks; i++) {
|
||||
blocks["single_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new SingleStreamBlock(params.hidden_size,
|
||||
params.num_heads,
|
||||
params.mlp_ratio));
|
||||
params.mlp_ratio,
|
||||
0.f,
|
||||
params.flash_attn));
|
||||
}
|
||||
|
||||
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, out_channels));
|
||||
@ -817,8 +830,10 @@ namespace Flux {
|
||||
|
||||
FluxRunner(ggml_backend_t backend,
|
||||
ggml_type wtype,
|
||||
SDVersion version = VERSION_FLUX_DEV)
|
||||
SDVersion version = VERSION_FLUX_DEV,
|
||||
bool flash_attn = false)
|
||||
: GGMLRunner(backend, wtype) {
|
||||
flux_params.flash_attn = flash_attn;
|
||||
if (version == VERSION_FLUX_SCHNELL) {
|
||||
flux_params.guidance_embed = false;
|
||||
}
|
||||
|
ggml_extend.hpp

@ -666,32 +666,6 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_qkv(struct ggml_context
|
||||
return {q, k, v};
|
||||
}
|
||||
|
||||
// q: [N * n_head, n_token, d_head]
|
||||
// k: [N * n_head, n_k, d_head]
|
||||
// v: [N * n_head, d_head, n_k]
|
||||
// return: [N * n_head, n_token, d_head]
|
||||
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx,
|
||||
struct ggml_tensor* q,
|
||||
struct ggml_tensor* k,
|
||||
struct ggml_tensor* v,
|
||||
bool mask = false) {
|
||||
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
|
||||
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
|
||||
#else
|
||||
float d_head = (float)q->ne[0];
|
||||
|
||||
struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k]
|
||||
kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head));
|
||||
if (mask) {
|
||||
kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
|
||||
}
|
||||
kq = ggml_soft_max_inplace(ctx, kq);
|
||||
|
||||
struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head]
|
||||
#endif
|
||||
return kqv;
|
||||
}
|
||||
|
||||
// q: [N, L_q, C] or [N*n_head, L_q, d_head]
|
||||
// k: [N, L_k, C] or [N*n_head, L_k, d_head]
|
||||
// v: [N, L_k, C] or [N, L_k, n_head, d_head]
|
||||
@ -703,7 +677,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
|
||||
int64_t n_head,
|
||||
struct ggml_tensor* mask = NULL,
|
||||
bool diag_mask_inf = false,
|
||||
bool skip_reshape = false) {
|
||||
bool skip_reshape = false,
|
||||
bool flash_attn = false) {
|
||||
int64_t L_q;
|
||||
int64_t L_k;
|
||||
int64_t C;
|
||||
@ -734,13 +709,42 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
    float scale = (1.0f / sqrt((float)d_head));

    bool use_flash_attn = false;
    ggml_tensor* kqv = NULL;
    if (use_flash_attn) {
    // if (flash_attn) {
    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
    // }
    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
    GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));

    bool can_use_flash_attn = true;
    can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
    can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0;  // double check

    // cuda max d_head seems to be 256, cpu does seem to work with 512
    can_use_flash_attn = can_use_flash_attn && d_head <= 256;  // double check

    if (mask != nullptr) {
        // TODO(Green-Sky): figure out if we can bend t5 to work too
        can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
        can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
    }

    // TODO(Green-Sky): more pad or disable for funny tensor shapes

    ggml_tensor* kqv = nullptr;
    // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
    if (can_use_flash_attn && flash_attn) {
        // LOG_DEBUG("using flash attention");
        k = ggml_cast(ctx, k, GGML_TYPE_F16);

        v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
        v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
        LOG_DEBUG("k->ne[1] == %d", k->ne[1]);
        v = ggml_cast(ctx, v, GGML_TYPE_F16);

        kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
        ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);

        // kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
        kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
    } else {
        v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, n_head, d_head, L_k]
        v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);  // [N * n_head, d_head, L_k]
|
||||
@ -756,10 +760,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
|
||||
kq = ggml_soft_max_inplace(ctx, kq);
|
||||
|
||||
kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head]
|
||||
|
||||
kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head]
|
||||
kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); // [N, L_q, n_head, d_head]
|
||||
}
|
||||
|
||||
kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head]
|
||||
kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, L_q, n_head, d_head]
|
||||
kqv = ggml_cont(ctx, kqv);
|
||||
kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N); // [N, L_q, C]
|
||||
|
||||
return kqv;
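The eligibility check threaded through `ggml_nn_attention_ext` boils down to a few shape constraints; restated as a self-contained sketch (an illustrative helper, not code from the repo):
```cpp
#include <cstdint>

// Flash attention is only attempted when the shapes fit the fused kernel:
// the key length is a multiple of 256, the head dimension is a multiple of 64
// and at most 256, and any attention mask is plain 2D (no per-head/per-batch dims).
static bool can_use_flash_attention(int64_t L_k, int64_t d_head, const int64_t* mask_ne) {
    bool ok = (L_k % 256 == 0) && (d_head % 64 == 0) && (d_head <= 256);
    if (mask_ne != nullptr) {  // mask_ne: the four ggml dims of the mask, if one is used
        ok = ok && mask_ne[2] == 1 && mask_ne[3] == 1;
    }
    return ok;
}
```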
|
||||
@ -1222,7 +1228,6 @@ protected:
|
||||
if (bias) {
|
||||
params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_features);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public:
|
||||
|
model.cpp
@ -148,19 +148,19 @@ std::unordered_map<std::string, std::string> vae_decoder_name_map = {
|
||||
|
||||
std::unordered_map<std::string, std::string> pmid_v2_name_map = {
|
||||
{"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.weight",
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"},
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"},
|
||||
{"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.3.weight",
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"},
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"},
|
||||
{"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.weight",
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"},
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"},
|
||||
{"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.3.weight",
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc2.weight"},
|
||||
{"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.weight",
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"},
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"},
|
||||
{"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.3.weight",
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"},
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"},
|
||||
{"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.weight",
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"},
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"},
|
||||
{"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.3.weight",
|
||||
"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc2.weight"},
|
||||
{"pmid.qformer_perceiver.token_proj.0.bias",
|
||||
@ -650,9 +650,8 @@ uint16_t f8_e4m3_to_f16(uint8_t f8) {
|
||||
return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
|
||||
}
|
||||
|
||||
|
||||
uint16_t f8_e5m2_to_f16(uint8_t fp8) {
|
||||
uint8_t sign = (fp8 >> 7) & 0x1;
|
||||
uint8_t sign = (fp8 >> 7) & 0x1;
|
||||
uint8_t exponent = (fp8 >> 2) & 0x1F;
|
||||
uint8_t mantissa = fp8 & 0x3;
|
||||
|
||||
@ -660,23 +659,23 @@ uint16_t f8_e5m2_to_f16(uint8_t fp8) {
|
||||
uint16_t fp16_exponent;
|
||||
uint16_t fp16_mantissa;
|
||||
|
||||
if (exponent == 0 && mantissa == 0) { //zero
|
||||
if (exponent == 0 && mantissa == 0) { // zero
|
||||
return fp16_sign;
|
||||
}
|
||||
|
||||
if (exponent == 0x1F) { //NAN and INF
|
||||
if (exponent == 0x1F) { // NAN and INF
|
||||
fp16_exponent = 0x1F;
|
||||
fp16_mantissa = mantissa ? (mantissa << 8) : 0;
|
||||
return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
|
||||
}
|
||||
|
||||
if (exponent == 0) { //subnormal numbers
|
||||
if (exponent == 0) { // subnormal numbers
|
||||
fp16_exponent = 0;
|
||||
fp16_mantissa = (mantissa << 8);
|
||||
return fp16_sign | fp16_mantissa;
|
||||
}
|
||||
|
||||
//normal numbers
|
||||
// normal numbers
|
||||
int16_t true_exponent = (int16_t)exponent - 15 + 15;
|
||||
if (true_exponent <= 0) {
|
||||
fp16_exponent = 0;
|
||||
@ -1434,10 +1433,9 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
|
||||
std::string name = zip_entry_name(zip);
|
||||
size_t pos = name.find("data.pkl");
|
||||
if (pos != std::string::npos) {
|
||||
|
||||
std::string dir = name.substr(0, pos);
|
||||
printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str());
|
||||
void* pkl_data = NULL;
|
||||
void* pkl_data = NULL;
|
||||
size_t pkl_size;
|
||||
zip_entry_read(zip, &pkl_data, &pkl_size);
|
||||
|
||||
|
pmid.hpp
@ -6,7 +6,6 @@
|
||||
#include "clip.hpp"
|
||||
#include "lora.hpp"
|
||||
|
||||
|
||||
struct FuseBlock : public GGMLBlock {
|
||||
// network hparams
|
||||
int in_dim;
|
||||
@ -78,22 +77,20 @@ class QFormerPerceiver(nn.Module):
|
||||
return out
|
||||
*/
|
||||
|
||||
|
||||
struct PMFeedForward : public GGMLBlock {
|
||||
// network hparams
|
||||
int dim;
|
||||
|
||||
public:
|
||||
PMFeedForward(int d, int multi=4)
|
||||
: dim(d) {
|
||||
PMFeedForward(int d, int multi = 4)
|
||||
: dim(d) {
|
||||
int inner_dim = dim * multi;
|
||||
blocks["0"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
|
||||
blocks["1"] = std::shared_ptr<GGMLBlock>(new Mlp(dim, inner_dim, dim, false));
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x){
|
||||
|
||||
struct ggml_tensor* x) {
|
||||
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["0"]);
|
||||
auto ff = std::dynamic_pointer_cast<Mlp>(blocks["1"]);
|
||||
|
||||
@ -101,37 +98,35 @@ public:
|
||||
x = ff->forward(ctx, x);
|
||||
return x;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct PerceiverAttention : public GGMLBlock {
|
||||
// network hparams
|
||||
float scale; // = dim_head**-0.5
|
||||
int dim_head; // = dim_head
|
||||
int heads; // = heads
|
||||
float scale; // = dim_head**-0.5
|
||||
int dim_head; // = dim_head
|
||||
int heads; // = heads
|
||||
public:
|
||||
PerceiverAttention(int dim, int dim_h=64, int h=8)
|
||||
: scale(powf(dim_h, -0.5)), dim_head(dim_h), heads(h) {
|
||||
|
||||
int inner_dim = dim_head * heads;
|
||||
PerceiverAttention(int dim, int dim_h = 64, int h = 8)
|
||||
: scale(powf(dim_h, -0.5)), dim_head(dim_h), heads(h) {
|
||||
int inner_dim = dim_head * heads;
|
||||
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
|
||||
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
|
||||
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(dim, inner_dim, false));
|
||||
blocks["to_kv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, inner_dim*2, false));
|
||||
blocks["to_kv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, inner_dim * 2, false));
|
||||
blocks["to_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim, false));
|
||||
}
|
||||
|
||||
struct ggml_tensor* reshape_tensor(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
int heads) {
|
||||
struct ggml_tensor* x,
|
||||
int heads) {
|
||||
int64_t ne[4];
|
||||
for(int i = 0; i < 4; ++i)
|
||||
ne[i] = x->ne[i];
|
||||
for (int i = 0; i < 4; ++i)
|
||||
ne[i] = x->ne[i];
|
||||
// print_ggml_tensor(x, true, "PerceiverAttention reshape x 0: ");
|
||||
// printf("heads = %d \n", heads);
|
||||
// x = ggml_view_4d(ctx, x, x->ne[0], x->ne[1], heads, x->ne[2]/heads,
|
||||
// x->nb[1], x->nb[2], x->nb[3], 0);
|
||||
x = ggml_reshape_4d(ctx, x, x->ne[0]/heads, heads, x->ne[1], x->ne[2]);
|
||||
x = ggml_reshape_4d(ctx, x, x->ne[0] / heads, heads, x->ne[1], x->ne[2]);
|
||||
// x = ggml_view_4d(ctx, x, x->ne[0]/heads, heads, x->ne[1], x->ne[2],
|
||||
// x->nb[1], x->nb[2], x->nb[3], 0);
|
||||
// x = ggml_cont(ctx, x);
|
||||
@ -142,49 +137,46 @@ public:
|
||||
}
|
||||
|
||||
std::vector<struct ggml_tensor*> chunk_half(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x){
|
||||
|
||||
auto tlo = ggml_view_4d(ctx, x, x->ne[0]/2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0);
|
||||
auto tli = ggml_view_4d(ctx, x, x->ne[0]/2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], x->nb[0]*x->ne[0]/2);
|
||||
struct ggml_tensor* x) {
|
||||
auto tlo = ggml_view_4d(ctx, x, x->ne[0] / 2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0);
|
||||
auto tli = ggml_view_4d(ctx, x, x->ne[0] / 2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], x->nb[0] * x->ne[0] / 2);
|
||||
return {ggml_cont(ctx, tlo),
|
||||
ggml_cont(ctx, tli)};
|
||||
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
struct ggml_tensor* latents){
|
||||
|
||||
struct ggml_tensor* latents) {
|
||||
// x (torch.Tensor): image features
|
||||
// shape (b, n1, D)
|
||||
// latent (torch.Tensor): latent features
|
||||
// shape (b, n2, D)
|
||||
int64_t ne[4];
|
||||
for(int i = 0; i < 4; ++i)
|
||||
ne[i] = latents->ne[i];
|
||||
for (int i = 0; i < 4; ++i)
|
||||
ne[i] = latents->ne[i];
|
||||
|
||||
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
|
||||
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
|
||||
x = norm1->forward(ctx, x);
|
||||
latents = norm2->forward(ctx, latents);
|
||||
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
|
||||
auto q = to_q->forward(ctx, latents);
|
||||
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
|
||||
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
|
||||
x = norm1->forward(ctx, x);
|
||||
latents = norm2->forward(ctx, latents);
|
||||
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
|
||||
auto q = to_q->forward(ctx, latents);
|
||||
|
||||
auto kv_input = ggml_concat(ctx, x, latents, 1);
|
||||
auto to_kv = std::dynamic_pointer_cast<Linear>(blocks["to_kv"]);
|
||||
auto kv = to_kv->forward(ctx, kv_input);
|
||||
auto k = ggml_view_4d(ctx, kv, kv->ne[0]/2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1]/2, kv->nb[2]/2, kv->nb[3]/2, 0);
|
||||
auto v = ggml_view_4d(ctx, kv, kv->ne[0]/2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1]/2, kv->nb[2]/2, kv->nb[3]/2, kv->nb[0]*(kv->ne[0]/2));
|
||||
k = ggml_cont(ctx, k);
|
||||
v = ggml_cont(ctx, v);
|
||||
q = reshape_tensor(ctx, q, heads);
|
||||
k = reshape_tensor(ctx, k, heads);
|
||||
v = reshape_tensor(ctx, v, heads);
|
||||
scale = 1.f / sqrt(sqrt((float)dim_head));
|
||||
k = ggml_scale_inplace(ctx, k, scale);
|
||||
q = ggml_scale_inplace(ctx, q, scale);
|
||||
auto kv = to_kv->forward(ctx, kv_input);
|
||||
auto k = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, 0);
|
||||
auto v = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, kv->nb[0] * (kv->ne[0] / 2));
|
||||
k = ggml_cont(ctx, k);
|
||||
v = ggml_cont(ctx, v);
|
||||
q = reshape_tensor(ctx, q, heads);
|
||||
k = reshape_tensor(ctx, k, heads);
|
||||
v = reshape_tensor(ctx, v, heads);
|
||||
scale = 1.f / sqrt(sqrt((float)dim_head));
|
||||
k = ggml_scale_inplace(ctx, k, scale);
|
||||
q = ggml_scale_inplace(ctx, q, scale);
|
||||
// auto weight = ggml_mul_mat(ctx, q, k);
|
||||
auto weight = ggml_mul_mat(ctx, k, q); // NOTE order of mul is opposite to pytorch
|
||||
auto weight = ggml_mul_mat(ctx, k, q); // NOTE order of mul is opposite to pytorch
|
||||
|
||||
// GGML's softmax() is equivalent to pytorch's softmax(x, dim=-1)
|
||||
// in this case, dimension along which Softmax will be computed is the last dim
|
||||
@ -192,13 +184,13 @@ public:
|
||||
// last dimension (varying most rapidly) corresponds to GGML's first (varying most rapidly).
|
||||
// weight = ggml_soft_max(ctx, weight);
|
||||
weight = ggml_soft_max_inplace(ctx, weight);
|
||||
v = ggml_cont(ctx, ggml_transpose(ctx, v));
|
||||
v = ggml_cont(ctx, ggml_transpose(ctx, v));
|
||||
// auto out = ggml_mul_mat(ctx, weight, v);
|
||||
auto out = ggml_mul_mat(ctx, v, weight); // NOTE order of mul is opposite to pytorch
|
||||
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));
|
||||
out = ggml_reshape_3d(ctx, out, ne[0], ne[1], ggml_nelements(out)/(ne[0]*ne[1]));
|
||||
auto to_out = std::dynamic_pointer_cast<Linear>(blocks["to_out"]);
|
||||
out = to_out->forward(ctx, out);
|
||||
auto out = ggml_mul_mat(ctx, v, weight); // NOTE order of mul is opposite to pytorch
|
||||
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));
|
||||
out = ggml_reshape_3d(ctx, out, ne[0], ne[1], ggml_nelements(out) / (ne[0] * ne[1]));
|
||||
auto to_out = std::dynamic_pointer_cast<Linear>(blocks["to_out"]);
|
||||
out = to_out->forward(ctx, out);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
@ -206,45 +198,46 @@ public:
|
||||
struct FacePerceiverResampler : public GGMLBlock {
|
||||
// network hparams
|
||||
int depth;
|
||||
|
||||
public:
|
||||
FacePerceiverResampler( int dim=768,
|
||||
int d=4,
|
||||
int dim_head=64,
|
||||
int heads=16,
|
||||
int embedding_dim=1280,
|
||||
int output_dim=768,
|
||||
int ff_mult=4)
|
||||
: depth(d) {
|
||||
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Linear(embedding_dim, dim, true));
|
||||
FacePerceiverResampler(int dim = 768,
|
||||
int d = 4,
|
||||
int dim_head = 64,
|
||||
int heads = 16,
|
||||
int embedding_dim = 1280,
|
||||
int output_dim = 768,
|
||||
int ff_mult = 4)
|
||||
: depth(d) {
|
||||
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Linear(embedding_dim, dim, true));
|
||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(dim, output_dim, true));
|
||||
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new LayerNorm(output_dim));
|
||||
|
||||
for (int i = 0; i < depth; i++) {
|
||||
std::string name = "layers." + std::to_string(i) + ".0";
|
||||
blocks[name] = std::shared_ptr<GGMLBlock>(new PerceiverAttention(dim, dim_head, heads));
|
||||
name = "layers." + std::to_string(i) + ".1";
|
||||
name = "layers." + std::to_string(i) + ".1";
|
||||
blocks[name] = std::shared_ptr<GGMLBlock>(new PMFeedForward(dim, ff_mult));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||
struct ggml_tensor* latents,
|
||||
struct ggml_tensor* x){
|
||||
struct ggml_tensor* x) {
|
||||
// x: [N, channels, h, w]
|
||||
auto proj_in = std::dynamic_pointer_cast<Linear>(blocks["proj_in"]);
|
||||
auto proj_out = std::dynamic_pointer_cast<Linear>(blocks["proj_out"]);
|
||||
auto norm_out = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_out"]);
|
||||
auto proj_in = std::dynamic_pointer_cast<Linear>(blocks["proj_in"]);
|
||||
auto proj_out = std::dynamic_pointer_cast<Linear>(blocks["proj_out"]);
|
||||
auto norm_out = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_out"]);
|
||||
|
||||
x = proj_in->forward(ctx, x);
|
||||
for (int i = 0; i < depth; i++) {
|
||||
std::string name = "layers." + std::to_string(i) + ".0";
|
||||
auto attn = std::dynamic_pointer_cast<PerceiverAttention>(blocks[name]);
|
||||
name = "layers." + std::to_string(i) + ".1";
|
||||
auto ff = std::dynamic_pointer_cast<PMFeedForward>(blocks[name]);
|
||||
auto t = attn->forward(ctx, x, latents);
|
||||
latents = ggml_add(ctx, t, latents);
|
||||
t = ff->forward(ctx, latents);
|
||||
latents = ggml_add(ctx, t, latents);
|
||||
auto attn = std::dynamic_pointer_cast<PerceiverAttention>(blocks[name]);
|
||||
name = "layers." + std::to_string(i) + ".1";
|
||||
auto ff = std::dynamic_pointer_cast<PMFeedForward>(blocks[name]);
|
||||
auto t = attn->forward(ctx, x, latents);
|
||||
latents = ggml_add(ctx, t, latents);
|
||||
t = ff->forward(ctx, latents);
|
||||
latents = ggml_add(ctx, t, latents);
|
||||
}
|
||||
latents = proj_out->forward(ctx, latents);
|
||||
latents = norm_out->forward(ctx, latents);
|
||||
@ -258,24 +251,22 @@ struct QFormerPerceiver : public GGMLBlock {
|
||||
int cross_attention_dim;
|
||||
bool use_residul;
|
||||
|
||||
|
||||
public:
|
||||
QFormerPerceiver(int id_embeddings_dim, int cross_attention_d, int num_t, int embedding_dim=1024,
|
||||
bool use_r=true, int ratio=4)
|
||||
: cross_attention_dim(cross_attention_d), num_tokens(num_t), use_residul(use_r) {
|
||||
blocks["token_proj"] = std::shared_ptr<GGMLBlock>(new Mlp(id_embeddings_dim,
|
||||
id_embeddings_dim*ratio,
|
||||
cross_attention_dim*num_tokens,
|
||||
true));
|
||||
blocks["token_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(cross_attention_d));
|
||||
QFormerPerceiver(int id_embeddings_dim, int cross_attention_d, int num_t, int embedding_dim = 1024, bool use_r = true, int ratio = 4)
|
||||
: cross_attention_dim(cross_attention_d), num_tokens(num_t), use_residul(use_r) {
|
||||
blocks["token_proj"] = std::shared_ptr<GGMLBlock>(new Mlp(id_embeddings_dim,
|
||||
id_embeddings_dim * ratio,
|
||||
cross_attention_dim * num_tokens,
|
||||
true));
|
||||
blocks["token_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(cross_attention_d));
|
||||
blocks["perceiver_resampler"] = std::shared_ptr<GGMLBlock>(new FacePerceiverResampler(
|
||||
cross_attention_dim,
|
||||
4,
|
||||
128,
|
||||
cross_attention_dim / 128,
|
||||
embedding_dim,
|
||||
cross_attention_dim,
|
||||
4));
|
||||
cross_attention_dim,
|
||||
4,
|
||||
128,
|
||||
cross_attention_dim / 128,
|
||||
embedding_dim,
|
||||
cross_attention_dim,
|
||||
4));
|
||||
}
|
||||
|
||||
/*
|
||||
@ -291,18 +282,18 @@ public:
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
struct ggml_tensor* last_hidden_state){
|
||||
struct ggml_tensor* last_hidden_state) {
|
||||
// x: [N, channels, h, w]
|
||||
auto token_proj = std::dynamic_pointer_cast<Mlp>(blocks["token_proj"]);
|
||||
auto token_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["token_norm"]);
|
||||
auto token_proj = std::dynamic_pointer_cast<Mlp>(blocks["token_proj"]);
|
||||
auto token_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["token_norm"]);
|
||||
auto perceiver_resampler = std::dynamic_pointer_cast<FacePerceiverResampler>(blocks["perceiver_resampler"]);
|
||||
|
||||
x = token_proj->forward(ctx, x);
|
||||
int64_t nel = ggml_nelements(x);
|
||||
x = ggml_reshape_3d(ctx, x, cross_attention_dim, num_tokens, nel/(cross_attention_dim*num_tokens));
|
||||
x = token_norm->forward(ctx, x);
|
||||
x = token_proj->forward(ctx, x);
|
||||
int64_t nel = ggml_nelements(x);
|
||||
x = ggml_reshape_3d(ctx, x, cross_attention_dim, num_tokens, nel / (cross_attention_dim * num_tokens));
|
||||
x = token_norm->forward(ctx, x);
|
||||
struct ggml_tensor* out = perceiver_resampler->forward(ctx, x, last_hidden_state);
|
||||
if(use_residul)
|
||||
if (use_residul)
|
||||
out = ggml_add(ctx, x, out);
|
||||
return out;
|
||||
}
|
||||
@ -346,8 +337,6 @@ class FacePerceiverResampler(torch.nn.Module):
|
||||
return self.norm_out(latents)
|
||||
*/
|
||||
|
||||
|
||||
|
||||
/*
|
||||
|
||||
def FeedForward(dim, mult=4):
|
||||
@ -417,9 +406,6 @@ class PerceiverAttention(nn.Module):
|
||||
|
||||
*/
|
||||
|
||||
|
||||
|
||||
|
||||
struct FuseModule : public GGMLBlock {
|
||||
// network hparams
|
||||
int embed_dim;
|
||||
@ -485,8 +471,8 @@ public:
|
||||
// print_ggml_tensor(class_tokens_mask_pos, true, "class_tokens_mask_pos");
|
||||
struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos);
|
||||
ggml_set_name(image_token_embeds, "image_token_embeds");
|
||||
valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0],
|
||||
ggml_nelements(valid_id_embeds)/valid_id_embeds->ne[0]);
|
||||
valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0],
|
||||
ggml_nelements(valid_id_embeds) / valid_id_embeds->ne[0]);
|
||||
struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds);
|
||||
|
||||
// stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3));
|
||||
@ -555,14 +541,13 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
|
||||
};
|
||||
|
||||
struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionModelProjection {
|
||||
|
||||
int cross_attention_dim;
|
||||
int num_tokens;
|
||||
|
||||
PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock(int id_embeddings_dim=512)
|
||||
PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock(int id_embeddings_dim = 512)
|
||||
: CLIPVisionModelProjection(OPENAI_CLIP_VIT_L_14),
|
||||
cross_attention_dim (2048),
|
||||
num_tokens(2) {
|
||||
cross_attention_dim(2048),
|
||||
num_tokens(2) {
|
||||
blocks["visual_projection_2"] = std::shared_ptr<GGMLBlock>(new Linear(1024, 1280, false));
|
||||
blocks["fuse_module"] = std::shared_ptr<GGMLBlock>(new FuseModule(2048));
|
||||
/*
|
||||
@ -575,10 +560,9 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
|
||||
cross_attention_dim,
|
||||
self.num_tokens,
|
||||
)*/
|
||||
blocks["qformer_perceiver"] = std::shared_ptr<GGMLBlock>(new QFormerPerceiver(id_embeddings_dim,
|
||||
cross_attention_dim,
|
||||
num_tokens));
|
||||
|
||||
blocks["qformer_perceiver"] = std::shared_ptr<GGMLBlock>(new QFormerPerceiver(id_embeddings_dim,
|
||||
cross_attention_dim,
|
||||
num_tokens));
|
||||
}
|
||||
|
||||
/*
|
||||
@ -603,13 +587,13 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
|
||||
struct ggml_tensor* left,
|
||||
struct ggml_tensor* right) {
|
||||
// x: [N, channels, h, w]
|
||||
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
|
||||
auto fuse_module = std::dynamic_pointer_cast<FuseModule>(blocks["fuse_module"]);
|
||||
auto qformer_perceiver = std::dynamic_pointer_cast<QFormerPerceiver>(blocks["qformer_perceiver"]);
|
||||
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
|
||||
auto fuse_module = std::dynamic_pointer_cast<FuseModule>(blocks["fuse_module"]);
|
||||
auto qformer_perceiver = std::dynamic_pointer_cast<QFormerPerceiver>(blocks["qformer_perceiver"]);
|
||||
|
||||
// struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size]
|
||||
struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false); // [N, hidden_size]
|
||||
id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state);
|
||||
struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false); // [N, hidden_size]
|
||||
id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state);
|
||||
|
||||
struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx,
|
||||
prompt_embeds,
|
||||
@ -623,7 +607,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
|
||||
|
||||
struct PhotoMakerIDEncoder : public GGMLRunner {
|
||||
public:
|
||||
SDVersion version = VERSION_SDXL;
|
||||
SDVersion version = VERSION_SDXL;
|
||||
PMVersion pm_version = VERSION_1;
|
||||
PhotoMakerIDEncoderBlock id_encoder;
|
||||
PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2;
|
||||
@ -639,15 +623,14 @@ public:
|
||||
std::vector<float> zeros_right;
|
||||
|
||||
public:
|
||||
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL,
|
||||
PMVersion pm_v = VERSION_1, float sty = 20.f)
|
||||
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = VERSION_1, float sty = 20.f)
|
||||
: GGMLRunner(backend, wtype),
|
||||
version(version),
|
||||
pm_version(pm_v),
|
||||
style_strength(sty) {
|
||||
if(pm_version == VERSION_1){
|
||||
if (pm_version == VERSION_1) {
|
||||
id_encoder.init(params_ctx, wtype);
|
||||
}else if(pm_version == VERSION_2){
|
||||
} else if (pm_version == VERSION_2) {
|
||||
id_encoder2.init(params_ctx, wtype);
|
||||
}
|
||||
}
|
||||
@ -656,17 +639,15 @@ public:
|
||||
return "pmid";
|
||||
}
|
||||
|
||||
PMVersion get_version() const{
|
||||
PMVersion get_version() const {
|
||||
return pm_version;
|
||||
}
|
||||
|
||||
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||
if(pm_version == VERSION_1)
|
||||
if (pm_version == VERSION_1)
|
||||
id_encoder.get_param_tensors(tensors, prefix);
|
||||
else if(pm_version == VERSION_2)
|
||||
else if (pm_version == VERSION_2)
|
||||
id_encoder2.get_param_tensors(tensors, prefix);
|
||||
|
||||
}
|
||||
|
||||
struct ggml_cgraph* build_graph( // struct ggml_allocr* allocr,
|
||||
@ -753,14 +734,14 @@ public:
|
||||
}
|
||||
}
|
||||
struct ggml_tensor* updated_prompt_embeds = NULL;
|
||||
if(pm_version == VERSION_1)
|
||||
if (pm_version == VERSION_1)
|
||||
updated_prompt_embeds = id_encoder.forward(ctx0,
|
||||
id_pixel_values_d,
|
||||
prompt_embeds_d,
|
||||
class_tokens_mask_d,
|
||||
class_tokens_mask_pos,
|
||||
left, right);
|
||||
else if(pm_version == VERSION_2)
|
||||
id_pixel_values_d,
|
||||
prompt_embeds_d,
|
||||
class_tokens_mask_d,
|
||||
class_tokens_mask_pos,
|
||||
left, right);
|
||||
else if (pm_version == VERSION_2)
|
||||
updated_prompt_embeds = id_encoder2.forward(ctx0,
|
||||
id_pixel_values_d,
|
||||
prompt_embeds_d,
|
||||
@ -791,22 +772,19 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct PhotoMakerIDEmbed : public GGMLRunner {
|
||||
|
||||
std::map<std::string, struct ggml_tensor*> tensors;
|
||||
std::string file_path;
|
||||
ModelLoader *model_loader;
|
||||
ModelLoader* model_loader;
|
||||
bool load_failed = false;
|
||||
bool applied = false;
|
||||
|
||||
PhotoMakerIDEmbed(ggml_backend_t backend,
|
||||
ggml_type wtype,
|
||||
ModelLoader *ml,
|
||||
const std::string& file_path = "",
|
||||
const std::string& prefix = "")
|
||||
: file_path(file_path), GGMLRunner(backend, wtype),
|
||||
model_loader(ml) {
|
||||
ggml_type wtype,
|
||||
ModelLoader* ml,
|
||||
const std::string& file_path = "",
|
||||
const std::string& prefix = "")
|
||||
: file_path(file_path), GGMLRunner(backend, wtype), model_loader(ml) {
|
||||
if (!model_loader->init_from_file(file_path, prefix)) {
|
||||
load_failed = true;
|
||||
}
|
||||
@ -837,7 +815,7 @@ struct PhotoMakerIDEmbed : public GGMLRunner {
|
||||
tensor_storage.type,
|
||||
tensor_storage.n_dims,
|
||||
tensor_storage.ne);
|
||||
tensors[name] = real;
|
||||
tensors[name] = real;
|
||||
} else {
|
||||
auto real = tensors[name];
|
||||
*dst_tensor = real;
|
||||
@ -856,11 +834,10 @@ struct PhotoMakerIDEmbed : public GGMLRunner {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
struct ggml_tensor* get(){
|
||||
struct ggml_tensor* get() {
|
||||
std::map<std::string, struct ggml_tensor*>::iterator pos;
|
||||
pos = tensors.find("pmid.id_embeds");
|
||||
if(pos != tensors.end())
|
||||
if (pos != tensors.end())
|
||||
return pos->second;
|
||||
return NULL;
|
||||
}
|
||||
|
stable-diffusion.cpp

@ -156,7 +156,8 @@ public:
|
||||
schedule_t schedule,
|
||||
bool clip_on_cpu,
|
||||
bool control_net_cpu,
|
||||
bool vae_on_cpu) {
|
||||
bool vae_on_cpu,
|
||||
bool diffusion_flash_attn) {
|
||||
use_tiny_autoencoder = taesd_path.size() > 0;
|
||||
#ifdef SD_USE_CUBLAS
|
||||
LOG_DEBUG("Using CUDA backend");
|
||||
@ -185,13 +186,7 @@ public:
|
||||
LOG_DEBUG("Using CPU backend");
|
||||
backend = ggml_backend_cpu_init();
|
||||
}
|
||||
#ifdef SD_USE_FLASH_ATTENTION
|
||||
#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) || defined(SD_USE_VULKAN)
|
||||
LOG_WARN("Flash Attention not supported with GPU Backend");
|
||||
#else
|
||||
LOG_INFO("Flash Attention enabled");
|
||||
#endif
|
||||
#endif
|
||||
|
||||
ModelLoader model_loader;
|
||||
|
||||
vae_tiling = vae_tiling_;
|
||||
@ -325,19 +320,25 @@ public:
|
||||
LOG_INFO("CLIP: Using CPU backend");
|
||||
clip_backend = ggml_backend_cpu_init();
|
||||
}
|
||||
if (diffusion_flash_attn) {
|
||||
LOG_INFO("Using flash attention in the diffusion model");
|
||||
}
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
|
||||
if (diffusion_flash_attn) {
|
||||
LOG_WARN("flash attention in this diffusion model is currently unsupported!");
|
||||
}
|
||||
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
|
||||
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
||||
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
|
||||
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version);
|
||||
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn);
|
||||
} else {
|
||||
if(id_embeddings_path.find("v2") != std::string::npos) {
|
||||
if (id_embeddings_path.find("v2") != std::string::npos) {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version, VERSION_2);
|
||||
}else{
|
||||
} else {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version);
|
||||
}
|
||||
diffusion_model = std::make_shared<UNetModel>(backend, diffusion_model_wtype, version);
|
||||
}
|
||||
diffusion_model = std::make_shared<UNetModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn);
|
||||
}
|
||||
cond_stage_model->alloc_params_buffer();
|
||||
cond_stage_model->get_param_tensors(tensors);
|
||||
@ -371,7 +372,7 @@ public:
|
||||
control_net = std::make_shared<ControlNet>(controlnet_backend, diffusion_model_wtype, version);
|
||||
}
|
||||
|
||||
if(id_embeddings_path.find("v2") != std::string::npos) {
|
||||
if (id_embeddings_path.find("v2") != std::string::npos) {
|
||||
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version, VERSION_2);
|
||||
LOG_INFO("using PhotoMaker Version 2");
|
||||
} else {
|
||||
@ -1081,7 +1082,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
|
||||
enum schedule_t s,
|
||||
bool keep_clip_on_cpu,
|
||||
bool keep_control_net_cpu,
|
||||
bool keep_vae_on_cpu) {
|
||||
bool keep_vae_on_cpu,
|
||||
bool diffusion_flash_attn) {
|
||||
sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
|
||||
if (sd_ctx == NULL) {
|
||||
return NULL;
|
||||
@ -1122,7 +1124,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
|
||||
s,
|
||||
keep_clip_on_cpu,
|
||||
keep_control_net_cpu,
|
||||
keep_vae_on_cpu)) {
|
||||
keep_vae_on_cpu,
|
||||
diffusion_flash_attn)) {
|
||||
delete sd_ctx->sd;
|
||||
sd_ctx->sd = NULL;
|
||||
free(sd_ctx);
|
||||
@ -1217,7 +1220,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
|
||||
for (std::string img_file : img_files) {
|
||||
int c = 0;
|
||||
int width, height;
|
||||
if(ends_with(img_file, "safetensors")){
|
||||
if (ends_with(img_file, "safetensors")) {
|
||||
continue;
|
||||
}
|
||||
uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3);
|
||||
@ -1257,18 +1260,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
|
||||
else
|
||||
sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL);
|
||||
}
|
||||
t0 = ggml_time_ms();
|
||||
auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
|
||||
sd_ctx->sd->n_threads, prompt,
|
||||
clip_skip,
|
||||
width,
|
||||
height,
|
||||
num_input_images,
|
||||
sd_ctx->sd->diffusion_model->get_adm_in_channels());
|
||||
id_cond = std::get<0>(cond_tup);
|
||||
class_tokens_mask = std::get<1>(cond_tup); //
|
||||
t0 = ggml_time_ms();
|
||||
auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
|
||||
sd_ctx->sd->n_threads, prompt,
|
||||
clip_skip,
|
||||
width,
|
||||
height,
|
||||
num_input_images,
|
||||
sd_ctx->sd->diffusion_model->get_adm_in_channels());
|
||||
id_cond = std::get<0>(cond_tup);
|
||||
class_tokens_mask = std::get<1>(cond_tup); //
|
||||
struct ggml_tensor* id_embeds = NULL;
|
||||
if(pmv2){
|
||||
if (pmv2) {
|
||||
// id_embeds = sd_ctx->sd->pmid_id_embeds->get();
|
||||
id_embeds = load_tensor_from_file(work_ctx, path_join(input_id_images_path, "id_embeds.bin"));
|
||||
// print_ggml_tensor(id_embeds, true, "id_embeds:");
|
||||
|
stable-diffusion.h

@ -142,7 +142,8 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
|
||||
enum schedule_t s,
|
||||
bool keep_clip_on_cpu,
|
||||
bool keep_control_net_cpu,
|
||||
bool keep_vae_on_cpu);
|
||||
bool keep_vae_on_cpu,
|
||||
bool diffusion_flash_attn);
|
||||
|
||||
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
|
||||
|
||||
|
unet.hpp
@ -183,7 +183,7 @@ public:
|
||||
int model_channels = 320;
|
||||
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
|
||||
|
||||
UnetModelBlock(SDVersion version = VERSION_SD1)
|
||||
UnetModelBlock(SDVersion version = VERSION_SD1, bool flash_attn = false)
|
||||
: version(version) {
|
||||
if (version == VERSION_SD2) {
|
||||
context_dim = 1024;
|
||||
@ -242,7 +242,7 @@ public:
|
||||
if (version == VERSION_SVD) {
|
||||
return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim);
|
||||
} else {
|
||||
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim);
|
||||
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, flash_attn);
|
||||
}
|
||||
};
|
||||
|
||||
@ -533,8 +533,9 @@ struct UNetModelRunner : public GGMLRunner {
|
||||
|
||||
UNetModelRunner(ggml_backend_t backend,
|
||||
ggml_type wtype,
|
||||
SDVersion version = VERSION_SD1)
|
||||
: GGMLRunner(backend, wtype), unet(version) {
|
||||
SDVersion version = VERSION_SD1,
|
||||
bool flash_attn = false)
|
||||
: GGMLRunner(backend, wtype), unet(version, flash_attn) {
|
||||
unet.init(params_ctx, wtype);
|
||||
}
|
||||
|
||||
|
util.cpp
@ -279,12 +279,12 @@ std::string path_join(const std::string& p1, const std::string& p2) {
|
||||
std::vector<std::string> splitString(const std::string& str, char delimiter) {
|
||||
std::vector<std::string> result;
|
||||
size_t start = 0;
|
||||
size_t end = str.find(delimiter);
|
||||
size_t end = str.find(delimiter);
|
||||
|
||||
while (end != std::string::npos) {
|
||||
result.push_back(str.substr(start, end - start));
|
||||
start = end + 1;
|
||||
end = str.find(delimiter, start);
|
||||
end = str.find(delimiter, start);
|
||||
}
|
||||
|
||||
// Add the last segment after the last delimiter
|
||||
@ -293,7 +293,6 @@ std::vector<std::string> splitString(const std::string& str, char delimiter) {
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
sd_image_t* preprocess_id_image(sd_image_t* img) {
|
||||
int shortest_edge = 224;
|
||||
int size = shortest_edge;
|
||||
|
vae.hpp
@ -99,10 +99,12 @@ public:
        k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3));  // [N, h, w, in_channels]
        k = ggml_reshape_3d(ctx, k, c, h * w, n);              // [N, h * w, in_channels]

        auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
        v = ggml_reshape_3d(ctx, v, h * w, c, n); // [N, in_channels, h * w]
        auto v = v_proj->forward(ctx, h_);                      // [N, in_channels, h, w]
        v      = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, h, w, in_channels]
        v      = ggml_reshape_3d(ctx, v, c, h * w, n);              // [N, h * w, in_channels]

        h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels]
        // h_ = ggml_nn_attention(ctx, q, k, v, false);  // [N, h * w, in_channels]
        h_ = ggml_nn_attention_ext(ctx, q, k, v, 1, nullptr, false, true, false);

        h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3));  // [N, in_channels, h * w]
        h_ = ggml_reshape_4d(ctx, h_, w, h, c, n);               // [N, in_channels, h, w]