fix: repair flash attention support (#386)
* repair flash attention in _ext
  this does not fix the currently broken fa behind the define, which is only used by VAE

  Co-authored-by: FSSRepo <FSSRepo@users.noreply.github.com>

* make flash attention in the diffusion model a runtime flag
  no support for sd3 or video

* remove old flash attention option and switch vae over to attn_ext

* update docs

* format code

---------

Co-authored-by: FSSRepo <FSSRepo@users.noreply.github.com>
Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
@@ -116,6 +116,7 @@ struct SDParams {
|
||||
bool normalize_input = false;
|
||||
bool clip_on_cpu = false;
|
||||
bool vae_on_cpu = false;
|
||||
bool diffusion_flash_attn = false;
|
||||
bool canny_preprocess = false;
|
||||
bool color = false;
|
||||
int upscale_repeats = 1;
|
||||
@@ -151,6 +152,7 @@ void print_params(SDParams params) {
|
||||
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
|
||||
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
|
||||
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
|
||||
printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false");
|
||||
printf(" strength(control): %.2f\n", params.control_strength);
|
||||
printf(" prompt: %s\n", params.prompt.c_str());
|
||||
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
|
||||
@@ -227,6 +229,9 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
|
||||
printf(" --vae-on-cpu keep vae in cpu (for low vram)\n");
|
||||
printf(" --clip-on-cpu keep clip in cpu (for low vram)\n");
|
||||
printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n");
|
||||
printf(" Might lower quality, since it implies converting k and v to f16.\n");
|
||||
printf(" This might crash if it is not supported by the backend.\n");
|
||||
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
|
||||
printf(" --canny apply canny preprocessor (edge detection)\n");
|
||||
printf(" --color Colors the logging tags according to level\n");
|
||||
@@ -477,6 +482,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
params.clip_on_cpu = true; // will slow down get_learned_condition but necessary for low MEM GPUs
|
||||
} else if (arg == "--vae-on-cpu") {
|
||||
params.vae_on_cpu = true; // will slow down latent decoding but necessary for low MEM GPUs
|
||||
} else if (arg == "--diffusion-fa") {
|
||||
params.diffusion_flash_attn = true; // can reduce MEM significantly
|
||||
} else if (arg == "--canny") {
|
||||
params.canny_preprocess = true;
|
||||
} else if (arg == "-b" || arg == "--batch-count") {
|
||||
@@ -868,7 +875,8 @@ int main(int argc, const char* argv[]) {
|
||||
params.schedule,
|
||||
params.clip_on_cpu,
|
||||
params.control_net_cpu,
|
||||
params.vae_on_cpu);
|
||||
params.vae_on_cpu,
|
||||
params.diffusion_flash_attn);
|
||||
|
||||
if (sd_ctx == NULL) {
|
||||
printf("new_sd_ctx_t failed\n");
|
||||
|
||||
Reference in New Issue
Block a user