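// stable-diffusion.cpp/examples/cli/main.cpp
//
// Command-line front end for stable-diffusion.cpp: parses options, runs txt2img
// or img2img, and writes the resulting images as PNG files.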

#include <stdio.h>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>
#include "ggml/ggml.h"
#include "stable-diffusion.h"
#include "util.h"
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#define STB_IMAGE_WRITE_STATIC
#include "stb_image_write.h"
// Display names for the RNG type, indexed by params.rng_type (used in verbose
// output and in the PNG parameter text). Index 0 is assumed to be
// STD_DEFAULT_RNG and index 1 CUDA_RNG — TODO confirm against stable-diffusion.h.
const char* rng_type_to_str[] = {"std_default", "cuda"};
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
const char* sample_method_str[] = {
"euler_a",
"euler",
"heun",
"dpm2",
"dpm++2s_a",
"dpm++2m",
"dpm++2mv2",
"lcm",
};
// Sigma-schedule names accepted by --schedule and indexed by the Schedule value
// in params; the order must match the Schedule enum in stable-diffusion.h.
const char* schedule_str[] = {"default", "discrete", "karras"};
// Generation-mode names accepted by -M/--mode, indexed by SDMode
// (TXT2IMG = 0, IMG2IMG = 1).
const char* modes_str[] = {"txt2img", "img2img"};
// Generation mode selected with -M/--mode. MODE_COUNT is the number of real
// modes and bounds the lookup loop over modes_str.
enum SDMode {
    TXT2IMG = 0,  // text -> image
    IMG2IMG = 1,  // init image + text -> image
    MODE_COUNT    // sentinel, not a mode
};
// All options gathered from the command line by parse_args, with their defaults.
struct SDParams {
    int n_threads{-1};  // <= 0 is replaced by the physical core count in parse_args
    SDMode mode{TXT2IMG};

    std::string model_path;
    std::string vae_path;
    ggml_type wtype{GGML_TYPE_COUNT};  // GGML_TYPE_COUNT means "not specified"
    std::string lora_model_dir;

    std::string output_path{"output.png"};
    std::string input_path;  // init image, required for img2img

    std::string prompt;
    std::string negative_prompt;
    float cfg_scale{7.0f};

    int width{512};
    int height{512};
    int batch_count{1};

    SampleMethod sample_method{EULER_A};
    Schedule schedule{DEFAULT};
    int sample_steps{20};
    float strength{0.75f};  // img2img only

    RNGType rng_type{CUDA_RNG};
    int64_t seed{42};  // < 0 is replaced by a random seed in parse_args

    bool verbose{false};
};
void print_params(SDParams params) {
printf("Option: \n");
printf(" n_threads: %d\n", params.n_threads);
printf(" mode: %s\n", modes_str[params.mode]);
printf(" model_path: %s\n", params.model_path.c_str());
printf(" wtype: %s\n", params.wtype < GGML_TYPE_COUNT ? ggml_type_name(params.wtype) : "unspecified");
printf(" vae_path: %s\n", params.vae_path.c_str());
printf(" output_path: %s\n", params.output_path.c_str());
printf(" init_img: %s\n", params.input_path.c_str());
printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
printf(" cfg_scale: %.2f\n", params.cfg_scale);
printf(" width: %d\n", params.width);
printf(" height: %d\n", params.height);
printf(" sample_method: %s\n", sample_method_str[params.sample_method]);
printf(" schedule: %s\n", schedule_str[params.schedule]);
printf(" sample_steps: %d\n", params.sample_steps);
printf(" strength(img2img): %.2f\n", params.strength);
printf(" rng: %s\n", rng_type_to_str[params.rng_type]);
printf(" seed: %ld\n", params.seed);
printf(" batch_count: %d\n", params.batch_count);
}
// Print the command-line help text to stdout. argv[0] is used as the program
// name; argc is unused but kept for call-site compatibility.
void print_usage(int argc, const char* argv[]) {
    (void)argc;
    printf("usage: %s [arguments]\n", argv[0]);
    printf("\n");
    printf("arguments:\n");
    printf(" -h, --help show this help message and exit\n");
    printf(" -M, --mode [txt2img or img2img] generation mode (default: txt2img)\n");
    printf(" -t, --threads N number of threads to use during computation (default: -1).\n");
    printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
    printf(" -m, --model [MODEL] path to model\n");
    printf(" --vae [VAE] path to vae\n");
    printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n");
    // The trailing newline was missing here, so "--lora-model-dir" was glued
    // onto the end of this line.
    printf(" If not specified, the default is the type of the weight file.\n");
    printf(" --lora-model-dir [DIR] lora model directory\n");
    printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
    printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
    printf(" -p, --prompt [PROMPT] the prompt to render\n");
    printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
    printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
    printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
    printf(" 1.0 corresponds to full destruction of information in init image\n");
    printf(" -H, --height H image height, in pixel space (default: 512)\n");
    printf(" -W, --width W image width, in pixel space (default: 512)\n");
    printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, lcm}\n");
    printf(" sampling method (default: \"euler_a\")\n");
    printf(" --steps STEPS number of sample steps (default: 20)\n");
    printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
    printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
    printf(" -b, --batch-count COUNT number of images to generate (default: 1).\n");
    printf(" --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)\n");
    printf(" -v, --verbose print extra info\n");
}
void parse_args(int argc, const char** argv, SDParams& params) {
bool invalid_arg = false;
std::string arg;
for (int i = 1; i < argc; i++) {
arg = argv[i];
if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.n_threads = std::stoi(argv[i]);
} else if (arg == "-M" || arg == "--mode") {
if (++i >= argc) {
invalid_arg = true;
break;
}
const char* mode_selected = argv[i];
int mode_found = -1;
for (int d = 0; d < MODE_COUNT; d++) {
if (!strcmp(mode_selected, modes_str[d])) {
mode_found = d;
}
}
if (mode_found == -1) {
fprintf(stderr, "error: invalid mode %s, must be one of [txt2img, img2img]\n",
mode_selected);
exit(1);
}
params.mode = (SDMode)mode_found;
} else if (arg == "-m" || arg == "--model") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.model_path = argv[i];
} else if (arg == "--vae") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.vae_path = argv[i];
} else if (arg == "--type") {
if (++i >= argc) {
invalid_arg = true;
break;
}
std::string type = argv[i];
if (type == "f32") {
params.wtype = GGML_TYPE_F32;
} else if (type == "f16") {
params.wtype = GGML_TYPE_F16;
} else if (type == "q4_0") {
params.wtype = GGML_TYPE_Q4_0;
} else if (type == "q4_1") {
params.wtype = GGML_TYPE_Q4_1;
} else if (type == "q5_0") {
params.wtype = GGML_TYPE_Q5_0;
} else if (type == "q5_1") {
params.wtype = GGML_TYPE_Q5_1;
} else if (type == "q8_0") {
params.wtype = GGML_TYPE_Q8_0;
} else {
fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0]\n",
type.c_str());
exit(1);
}
} else if (arg == "--lora-model-dir") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.lora_model_dir = argv[i];
} else if (arg == "-i" || arg == "--init-img") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.input_path = argv[i];
} else if (arg == "-o" || arg == "--output") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.output_path = argv[i];
} else if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.prompt = argv[i];
} else if (arg == "-n" || arg == "--negative-prompt") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.negative_prompt = argv[i];
} else if (arg == "--cfg-scale") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.cfg_scale = std::stof(argv[i]);
} else if (arg == "--strength") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.strength = std::stof(argv[i]);
} else if (arg == "-H" || arg == "--height") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.height = std::stoi(argv[i]);
} else if (arg == "-W" || arg == "--width") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.width = std::stoi(argv[i]);
} else if (arg == "--steps") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.sample_steps = std::stoi(argv[i]);
} else if (arg == "-b" || arg == "--batch-count") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.batch_count = std::stoi(argv[i]);
} else if (arg == "--rng") {
if (++i >= argc) {
invalid_arg = true;
break;
}
std::string rng_type_str = argv[i];
if (rng_type_str == "std_default") {
params.rng_type = STD_DEFAULT_RNG;
} else if (rng_type_str == "cuda") {
params.rng_type = CUDA_RNG;
} else {
invalid_arg = true;
break;
}
} else if (arg == "--schedule") {
if (++i >= argc) {
invalid_arg = true;
break;
}
const char* schedule_selected = argv[i];
int schedule_found = -1;
for (int d = 0; d < N_SCHEDULES; d++) {
if (!strcmp(schedule_selected, schedule_str[d])) {
schedule_found = d;
}
}
if (schedule_found == -1) {
invalid_arg = true;
break;
}
params.schedule = (Schedule)schedule_found;
} else if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.seed = std::stoll(argv[i]);
} else if (arg == "--sampling-method") {
if (++i >= argc) {
invalid_arg = true;
break;
}
const char* sample_method_selected = argv[i];
int sample_method_found = -1;
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
if (!strcmp(sample_method_selected, sample_method_str[m])) {
sample_method_found = m;
}
}
if (sample_method_found == -1) {
invalid_arg = true;
break;
}
params.sample_method = (SampleMethod)sample_method_found;
} else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv);
exit(0);
} else if (arg == "-v" || arg == "--verbose") {
params.verbose = true;
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
print_usage(argc, argv);
exit(1);
}
}
if (invalid_arg) {
fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
print_usage(argc, argv);
exit(1);
}
if (params.n_threads <= 0) {
params.n_threads = get_num_physical_cores();
}
if (params.prompt.length() == 0) {
fprintf(stderr, "error: the following arguments are required: prompt\n");
print_usage(argc, argv);
exit(1);
}
if (params.model_path.length() == 0) {
fprintf(stderr, "error: the following arguments are required: model_path\n");
print_usage(argc, argv);
exit(1);
}
if (params.mode == IMG2IMG && params.input_path.length() == 0) {
fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n");
print_usage(argc, argv);
exit(1);
}
if (params.output_path.length() == 0) {
fprintf(stderr, "error: the following arguments are required: output_path\n");
print_usage(argc, argv);
exit(1);
}
if (params.width <= 0 || params.width % 64 != 0) {
fprintf(stderr, "error: the width must be a multiple of 64\n");
exit(1);
}
if (params.height <= 0 || params.height % 64 != 0) {
fprintf(stderr, "error: the height must be a multiple of 64\n");
exit(1);
}
if (params.sample_steps <= 0) {
fprintf(stderr, "error: the sample_steps must be greater than 0\n");
exit(1);
}
if (params.strength < 0.f || params.strength > 1.f) {
fprintf(stderr, "error: can only work with strength in [0.0, 1.0]\n");
exit(1);
}
if (params.seed < 0) {
srand((int)time(NULL));
params.seed = rand();
}
}
std::string get_image_params(SDParams params, int64_t seed) {
std::string parameter_string = params.prompt + "\n";
if (params.negative_prompt.size() != 0) {
parameter_string += "Negative prompt: " + params.negative_prompt + "\n";
}
parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", ";
parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
parameter_string += "Model: " + basename(params.model_path) + ", ";
parameter_string += "RNG: " + std::string(rng_type_to_str[params.rng_type]) + ", ";
parameter_string += "Sampler: " + std::string(sample_method_str[params.sample_method]);
if (params.schedule == KARRAS) {
parameter_string += " karras";
}
parameter_string += ", ";
parameter_string += "Version: stable-diffusion.cpp";
return parameter_string;
}
int main(int argc, const char* argv[]) {
SDParams params;
parse_args(argc, argv, params);
if (params.verbose) {
print_params(params);
printf("%s", sd_get_system_info().c_str());
set_sd_log_level(SDLogLevel::DEBUG);
}
bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
if (params.mode == IMG2IMG) {
vae_decode_only = false;
int c = 0;
input_image_buffer = stbi_load(params.input_path.c_str(), &params.width, &params.height, &c, 3);
if (input_image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str());
return 1;
}
if (c != 3) {
fprintf(stderr, "input image must be a 3 channels RGB image, but got %d channels\n", c);
free(input_image_buffer);
return 1;
}
if (params.width <= 0 || params.width % 64 != 0) {
fprintf(stderr, "error: the width of image must be a multiple of 64\n");
free(input_image_buffer);
return 1;
}
if (params.height <= 0 || params.height % 64 != 0) {
fprintf(stderr, "error: the height of image must be a multiple of 64\n");
free(input_image_buffer);
return 1;
}
}
StableDiffusion sd(params.n_threads, vae_decode_only, true, params.lora_model_dir, params.rng_type);
if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule)) {
return 1;
}
std::vector<uint8_t*> results;
if (params.mode == TXT2IMG) {
results = sd.txt2img(params.prompt,
params.negative_prompt,
params.cfg_scale,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.seed,
params.batch_count);
} else {
results = sd.img2img(input_image_buffer,
params.prompt,
params.negative_prompt,
params.cfg_scale,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.strength,
params.seed);
}
if (results.size() == 0 || results.size() != params.batch_count) {
LOG_ERROR("generate failed");
return 1;
}
size_t last = params.output_path.find_last_of(".");
std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path;
for (int i = 0; i < params.batch_count; i++) {
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png";
stbi_write_png(final_image_path.c_str(), params.width, params.height, 3, results[i], 0, get_image_params(params, params.seed + i).c_str());
LOG_INFO("save result image to '%s'", final_image_path.c_str());
}
return 0;
}