fix: repair flash attention support (#386)

* repair flash attention in _ext
this does not fix the currently broken fa behind the define, which is only used by VAE

Co-authored-by: FSSRepo <FSSRepo@users.noreply.github.com>

* make flash attention in the diffusion model a runtime flag
no support for sd3 or video

* remove old flash attention option and switch vae over to attn_ext

* update docs

* format code

---------

Co-authored-by: FSSRepo <FSSRepo@users.noreply.github.com>
Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
Erik Scholz
2024-11-23 05:39:08 +01:00
committed by GitHub
parent ea9b647080
commit 1c168d98a5
17 changed files with 334 additions and 314 deletions

View File

@@ -116,6 +116,7 @@ struct SDParams {
bool normalize_input = false;
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool diffusion_flash_attn = false;
bool canny_preprocess = false;
bool color = false;
int upscale_repeats = 1;
@@ -151,6 +152,7 @@ void print_params(SDParams params) {
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false");
printf(" strength(control): %.2f\n", params.control_strength);
printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
@@ -227,6 +229,9 @@ void print_usage(int argc, const char* argv[]) {
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
printf(" --vae-on-cpu keep vae in cpu (for low vram)\n");
printf(" --clip-on-cpu keep clip in cpu (for low vram)\n");
printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n");
printf(" Might lower quality, since it implies converting k and v to f16.\n");
printf(" This might crash if it is not supported by the backend.\n");
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
printf(" --canny apply canny preprocessor (edge detection)\n");
printf(" --color Colors the logging tags according to level\n");
@@ -477,6 +482,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.clip_on_cpu = true; // will slow down get_learned_condiotion but necessary for low MEM GPUs
} else if (arg == "--vae-on-cpu") {
params.vae_on_cpu = true; // will slow down latent decoding but necessary for low MEM GPUs
} else if (arg == "--diffusion-fa") {
params.diffusion_flash_attn = true; // can reduce MEM significantly
} else if (arg == "--canny") {
params.canny_preprocess = true;
} else if (arg == "-b" || arg == "--batch-count") {
@@ -868,7 +875,8 @@ int main(int argc, const char* argv[]) {
params.schedule,
params.clip_on_cpu,
params.control_net_cpu,
params.vae_on_cpu);
params.vae_on_cpu,
params.diffusion_flash_attn);
if (sd_ctx == NULL) {
printf("new_sd_ctx_t failed\n");