feat: implement ESRGAN upscaler + Metal Backend (#104)
* add esrgan upscaler
* add sd_tiling
* support metal backend
* add clip_skip
---------
Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
@@ -59,6 +59,7 @@ struct SDParams {
|
||||
std::string model_path;
|
||||
std::string vae_path;
|
||||
std::string taesd_path;
|
||||
std::string esrgan_path;
|
||||
ggml_type wtype = GGML_TYPE_COUNT;
|
||||
std::string lora_model_dir;
|
||||
std::string output_path = "output.png";
|
||||
@@ -67,6 +68,7 @@ struct SDParams {
|
||||
std::string prompt;
|
||||
std::string negative_prompt;
|
||||
float cfg_scale = 7.0f;
|
||||
int clip_skip = -1; // <= 0 represents unspecified
|
||||
int width = 512;
|
||||
int height = 512;
|
||||
int batch_count = 1;
|
||||
@@ -78,6 +80,7 @@ struct SDParams {
|
||||
RNGType rng_type = CUDA_RNG;
|
||||
int64_t seed = 42;
|
||||
bool verbose = false;
|
||||
bool vae_tiling = false;
|
||||
};
|
||||
|
||||
void print_params(SDParams params) {
|
||||
@@ -88,11 +91,13 @@ void print_params(SDParams params) {
|
||||
printf(" wtype: %s\n", params.wtype < GGML_TYPE_COUNT ? ggml_type_name(params.wtype) : "unspecified");
|
||||
printf(" vae_path: %s\n", params.vae_path.c_str());
|
||||
printf(" taesd_path: %s\n", params.taesd_path.c_str());
|
||||
printf(" esrgan_path: %s\n", params.esrgan_path.c_str());
|
||||
printf(" output_path: %s\n", params.output_path.c_str());
|
||||
printf(" init_img: %s\n", params.input_path.c_str());
|
||||
printf(" prompt: %s\n", params.prompt.c_str());
|
||||
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
|
||||
printf(" cfg_scale: %.2f\n", params.cfg_scale);
|
||||
printf(" clip_skip: %d\n", params.clip_skip);
|
||||
printf(" width: %d\n", params.width);
|
||||
printf(" height: %d\n", params.height);
|
||||
printf(" sample_method: %s\n", sample_method_str[params.sample_method]);
|
||||
@@ -102,6 +107,7 @@ void print_params(SDParams params) {
|
||||
printf(" rng: %s\n", rng_type_to_str[params.rng_type]);
|
||||
printf(" seed: %ld\n", params.seed);
|
||||
printf(" batch_count: %d\n", params.batch_count);
|
||||
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
|
||||
}
|
||||
|
||||
void print_usage(int argc, const char* argv[]) {
|
||||
@@ -115,6 +121,7 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" -m, --model [MODEL] path to model\n");
|
||||
printf(" --vae [VAE] path to vae\n");
|
||||
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
|
||||
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n");
|
||||
printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n");
|
||||
printf(" If not specified, the default is the type of the weight file.\n");
|
||||
printf(" --lora-model-dir [DIR] lora model directory\n");
|
||||
@@ -134,6 +141,9 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
|
||||
printf(" -b, --batch-count COUNT number of images to generate.\n");
|
||||
printf(" --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)\n");
|
||||
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
|
||||
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
|
||||
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
|
||||
printf(" -v, --verbose print extra info\n");
|
||||
}
|
||||
|
||||
@@ -185,6 +195,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
break;
|
||||
}
|
||||
params.taesd_path = argv[i];
|
||||
} else if (arg == "--upscale-model") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
params.esrgan_path = argv[i];
|
||||
} else if (arg == "--type") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
@@ -270,6 +286,14 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
break;
|
||||
}
|
||||
params.sample_steps = std::stoi(argv[i]);
|
||||
} else if (arg == "--clip-skip") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
params.clip_skip = std::stoi(argv[i]);
|
||||
} else if (arg == "--vae-tiling") {
|
||||
params.vae_tiling = true;
|
||||
} else if (arg == "-b" || arg == "--batch-count") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
@@ -458,9 +482,9 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, true, params.lora_model_dir, params.rng_type);
|
||||
StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.vae_tiling, params.lora_model_dir, params.rng_type);
|
||||
|
||||
if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule)) {
|
||||
if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule, params.clip_skip)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -488,6 +512,19 @@ int main(int argc, const char* argv[]) {
|
||||
params.seed);
|
||||
}
|
||||
|
||||
if (params.esrgan_path.size() > 0) {
|
||||
// TODO: support more ESRGAN models, making it easier to set up ESRGAN models.
|
||||
/* The scale factor is hardcoded because only RealESRGAN_x4plus_anime_6B is supported for now.
|
||||
See also: https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan.py
|
||||
|
||||
To avoid this, the upscaler needs to be separated from the stable diffusion pipeline.
|
||||
However, a considerable amount of work would be required for this. It might be better
|
||||
to opt for a complete project refactoring that makes parameters easier to assign.
|
||||
*/
|
||||
params.width *= 4;
|
||||
params.height *= 4;
|
||||
}
|
||||
|
||||
if (results.size() == 0 || results.size() != params.batch_count) {
|
||||
LOG_ERROR("generate failed");
|
||||
return 1;
|
||||
|
||||
Reference in New Issue
Block a user