feat: add flux support (#356)
* add flux support
* avoid build failures in non-CUDA environments
* fix schnell support
* add k quants support
* add support for applying lora to quantized tensors
* add inplace conversion support for f8_e4m3 (#359): this is done in the same way as for bf16 — just as bf16 converts losslessly to fp32, f8_e4m3 converts losslessly to fp16
* add xlabs flux comfy converted lora support
* update docs

---------

Co-authored-by: Erik Scholz <Green-Sky@users.noreply.github.com>
This commit is contained in:
@@ -7,9 +7,8 @@
|
||||
#include <vector>
|
||||
|
||||
// #include "preprocessing.hpp"
|
||||
#include "mmdit.hpp"
|
||||
#include "flux.hpp"
|
||||
#include "stable-diffusion.h"
|
||||
#include "t5.hpp"
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#define STB_IMAGE_STATIC
|
||||
@@ -68,6 +67,9 @@ struct SDParams {
|
||||
SDMode mode = TXT2IMG;
|
||||
|
||||
std::string model_path;
|
||||
std::string clip_l_path;
|
||||
std::string t5xxl_path;
|
||||
std::string diffusion_model_path;
|
||||
std::string vae_path;
|
||||
std::string taesd_path;
|
||||
std::string esrgan_path;
|
||||
@@ -85,6 +87,7 @@ struct SDParams {
|
||||
std::string negative_prompt;
|
||||
float min_cfg = 1.0f;
|
||||
float cfg_scale = 7.0f;
|
||||
float guidance = 3.5f;
|
||||
float style_ratio = 20.f;
|
||||
int clip_skip = -1; // <= 0 represents unspecified
|
||||
int width = 512;
|
||||
@@ -120,6 +123,9 @@ void print_params(SDParams params) {
|
||||
printf(" mode: %s\n", modes_str[params.mode]);
|
||||
printf(" model_path: %s\n", params.model_path.c_str());
|
||||
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
|
||||
printf(" clip_l_path: %s\n", params.clip_l_path.c_str());
|
||||
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
|
||||
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
|
||||
printf(" vae_path: %s\n", params.vae_path.c_str());
|
||||
printf(" taesd_path: %s\n", params.taesd_path.c_str());
|
||||
printf(" esrgan_path: %s\n", params.esrgan_path.c_str());
|
||||
@@ -140,6 +146,7 @@ void print_params(SDParams params) {
|
||||
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
|
||||
printf(" min_cfg: %.2f\n", params.min_cfg);
|
||||
printf(" cfg_scale: %.2f\n", params.cfg_scale);
|
||||
printf(" guidance: %.2f\n", params.guidance);
|
||||
printf(" clip_skip: %d\n", params.clip_skip);
|
||||
printf(" width: %d\n", params.width);
|
||||
printf(" height: %d\n", params.height);
|
||||
@@ -172,7 +179,7 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" --normalize-input normalize PHOTOMAKER input id images\n");
|
||||
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n");
|
||||
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
|
||||
printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n");
|
||||
printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)\n");
|
||||
printf(" If not specified, the default is the type of the weight file.\n");
|
||||
printf(" --lora-model-dir [DIR] lora model directory\n");
|
||||
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
|
||||
@@ -240,6 +247,24 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
break;
|
||||
}
|
||||
params.model_path = argv[i];
|
||||
} else if (arg == "--clip_l") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
params.clip_l_path = argv[i];
|
||||
} else if (arg == "--t5xxl") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
params.t5xxl_path = argv[i];
|
||||
} else if (arg == "--diffusion-model") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
params.diffusion_model_path = argv[i];
|
||||
} else if (arg == "--vae") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
@@ -302,8 +327,14 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
params.wtype = SD_TYPE_Q5_1;
|
||||
} else if (type == "q8_0") {
|
||||
params.wtype = SD_TYPE_Q8_0;
|
||||
} else if (type == "q2_k") {
|
||||
params.wtype = SD_TYPE_Q2_K;
|
||||
} else if (type == "q3_k") {
|
||||
params.wtype = SD_TYPE_Q3_K;
|
||||
} else if (type == "q4_k") {
|
||||
params.wtype = SD_TYPE_Q4_K;
|
||||
} else {
|
||||
fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0]\n",
|
||||
fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k]\n",
|
||||
type.c_str());
|
||||
exit(1);
|
||||
}
|
||||
@@ -359,6 +390,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
break;
|
||||
}
|
||||
params.cfg_scale = std::stof(argv[i]);
|
||||
} else if (arg == "--guidance") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
params.guidance = std::stof(argv[i]);
|
||||
} else if (arg == "--strength") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
@@ -501,8 +538,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (params.model_path.length() == 0) {
|
||||
fprintf(stderr, "error: the following arguments are required: model_path\n");
|
||||
if (params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) {
|
||||
fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
|
||||
print_usage(argc, argv);
|
||||
exit(1);
|
||||
}
|
||||
@@ -570,6 +607,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
|
||||
}
|
||||
parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", ";
|
||||
parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", ";
|
||||
parameter_string += "Guidance: " + std::to_string(params.guidance) + ", ";
|
||||
parameter_string += "Seed: " + std::to_string(seed) + ", ";
|
||||
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
|
||||
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
|
||||
@@ -717,6 +755,9 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
|
||||
sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
|
||||
params.clip_l_path.c_str(),
|
||||
params.t5xxl_path.c_str(),
|
||||
params.diffusion_model_path.c_str(),
|
||||
params.vae_path.c_str(),
|
||||
params.taesd_path.c_str(),
|
||||
params.controlnet_path.c_str(),
|
||||
@@ -770,6 +811,7 @@ int main(int argc, const char* argv[]) {
|
||||
params.negative_prompt.c_str(),
|
||||
params.clip_skip,
|
||||
params.cfg_scale,
|
||||
params.guidance,
|
||||
params.width,
|
||||
params.height,
|
||||
params.sample_method,
|
||||
@@ -830,6 +872,7 @@ int main(int argc, const char* argv[]) {
|
||||
params.negative_prompt.c_str(),
|
||||
params.clip_skip,
|
||||
params.cfg_scale,
|
||||
params.guidance,
|
||||
params.width,
|
||||
params.height,
|
||||
params.sample_method,
|
||||
|
||||
Reference in New Issue
Block a user