feat: add flux support (#356)

* add flux support

* avoid build failures in non-CUDA environments

* fix schnell support

* add k quants support

* add support for applying lora to quantized tensors

* add inplace conversion support for f8_e4m3 (#359)

in the same way it is done for bf16
like how bf16 converts losslessly to fp32,
f8_e4m3 converts losslessly to fp16

* add xlabs flux comfy converted lora support

* update docs

---------

Co-authored-by: Erik Scholz <Green-Sky@users.noreply.github.com>
This commit is contained in:
leejet
2024-08-24 14:29:52 +08:00
committed by GitHub
parent 697d000f49
commit 64d231f384
25 changed files with 1886 additions and 172 deletions

View File

@@ -7,9 +7,8 @@
#include <vector>
// #include "preprocessing.hpp"
#include "mmdit.hpp"
#include "flux.hpp"
#include "stable-diffusion.h"
#include "t5.hpp"
#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
@@ -68,6 +67,9 @@ struct SDParams {
SDMode mode = TXT2IMG;
std::string model_path;
std::string clip_l_path;
std::string t5xxl_path;
std::string diffusion_model_path;
std::string vae_path;
std::string taesd_path;
std::string esrgan_path;
@@ -85,6 +87,7 @@ struct SDParams {
std::string negative_prompt;
float min_cfg = 1.0f;
float cfg_scale = 7.0f;
float guidance = 3.5f;
float style_ratio = 20.f;
int clip_skip = -1; // <= 0 represents unspecified
int width = 512;
@@ -120,6 +123,9 @@ void print_params(SDParams params) {
printf(" mode: %s\n", modes_str[params.mode]);
printf(" model_path: %s\n", params.model_path.c_str());
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
printf(" clip_l_path: %s\n", params.clip_l_path.c_str());
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
printf(" vae_path: %s\n", params.vae_path.c_str());
printf(" taesd_path: %s\n", params.taesd_path.c_str());
printf(" esrgan_path: %s\n", params.esrgan_path.c_str());
@@ -140,6 +146,7 @@ void print_params(SDParams params) {
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
printf(" min_cfg: %.2f\n", params.min_cfg);
printf(" cfg_scale: %.2f\n", params.cfg_scale);
printf(" guidance: %.2f\n", params.guidance);
printf(" clip_skip: %d\n", params.clip_skip);
printf(" width: %d\n", params.width);
printf(" height: %d\n", params.height);
@@ -172,7 +179,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --normalize-input normalize PHOTOMAKER input id images\n");
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n");
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n");
printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)\n");
printf(" If not specified, the default is the type of the weight file.\n");
printf(" --lora-model-dir [DIR] lora model directory\n");
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
@@ -240,6 +247,24 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.model_path = argv[i];
} else if (arg == "--clip_l") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.clip_l_path = argv[i];
} else if (arg == "--t5xxl") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.t5xxl_path = argv[i];
} else if (arg == "--diffusion-model") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.diffusion_model_path = argv[i];
} else if (arg == "--vae") {
if (++i >= argc) {
invalid_arg = true;
@@ -302,8 +327,14 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.wtype = SD_TYPE_Q5_1;
} else if (type == "q8_0") {
params.wtype = SD_TYPE_Q8_0;
} else if (type == "q2_k") {
params.wtype = SD_TYPE_Q2_K;
} else if (type == "q3_k") {
params.wtype = SD_TYPE_Q3_K;
} else if (type == "q4_k") {
params.wtype = SD_TYPE_Q4_K;
} else {
fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0]\n",
fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k]\n",
type.c_str());
exit(1);
}
@@ -359,6 +390,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.cfg_scale = std::stof(argv[i]);
} else if (arg == "--guidance") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.guidance = std::stof(argv[i]);
} else if (arg == "--strength") {
if (++i >= argc) {
invalid_arg = true;
@@ -501,8 +538,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
exit(1);
}
if (params.model_path.length() == 0) {
fprintf(stderr, "error: the following arguments are required: model_path\n");
if (params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) {
fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
print_usage(argc, argv);
exit(1);
}
@@ -570,6 +607,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
}
parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", ";
parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", ";
parameter_string += "Guidance: " + std::to_string(params.guidance) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
@@ -717,6 +755,9 @@ int main(int argc, const char* argv[]) {
}
sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
params.clip_l_path.c_str(),
params.t5xxl_path.c_str(),
params.diffusion_model_path.c_str(),
params.vae_path.c_str(),
params.taesd_path.c_str(),
params.controlnet_path.c_str(),
@@ -770,6 +811,7 @@ int main(int argc, const char* argv[]) {
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.guidance,
params.width,
params.height,
params.sample_method,
@@ -830,6 +872,7 @@ int main(int argc, const char* argv[]) {
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.guidance,
params.width,
params.height,
params.sample_method,