// stable-diffusion.cpp/examples/cli/main.cpp
//
// 96 lines
// 3.4 KiB
// C++

#include <stdio.h>
#include <ctime>
#include <random>
#include "common.h"
#include "stable-diffusion.h"
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#define STB_IMAGE_WRITE_STATIC
#include "stb_image_write.h"
// Returns `path` without its trailing ".ext" suffix.
// Only a dot located after the last path separator is treated as the start of
// an extension, so dots in directory components ("./out", "../dir/name",
// "run.v2/name") are left alone. A path with no such dot is returned unchanged.
static std::string cli_strip_output_extension(const std::string& path) {
    const size_t dot = path.find_last_of('.');
    if (dot == std::string::npos) {
        return path;
    }
    const size_t sep = path.find_last_of("/\\");
    if (sep != std::string::npos && dot < sep) {
        // The last dot belongs to a directory component, not to the file name.
        return path;
    }
    return path.substr(0, dot);
}

// CLI entry point: parses the command line into SDParams, optionally loads and
// validates an input image (img2img mode), loads the model, runs txt2img or
// img2img, and writes each resulting image as a PNG next to the requested
// output path ("name.png", "name_2.png", "name_3.png", ...).
//
// Returns 0 when every image was written, 1 on any failure (image load or
// validation, model load, generation, or PNG write).
int main(int argc, const char* argv[]) {
    SDParams params;
    parse_args(argc, argv, params);

    if (params.verbose) {
        print_params(params);
        printf("%s", sd_get_system_info().c_str());
        set_sd_log_level(SDLogLevel::DEBUG);
    }

    bool vae_decode_only        = true;
    uint8_t* input_image_buffer = nullptr;
    if (params.mode == IMG2IMG) {
        vae_decode_only = false;

        // stbi_load overwrites params.width/params.height with the image's
        // dimensions. `c` receives the channel count stored in the file; the
        // returned buffer itself is always 3 channels because req_comp is 3.
        int c              = 0;
        input_image_buffer = stbi_load(params.input_path.c_str(), &params.width, &params.height, &c, 3);
        if (input_image_buffer == nullptr) {
            fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str());
            return 1;
        }
        if (c != 3) {
            fprintf(stderr, "input image must be a 3 channels RGB image, but got %d channels\n", c);
            free(input_image_buffer);
            return 1;
        }
        if (params.width <= 0 || params.width % 64 != 0) {
            fprintf(stderr, "error: the width of image must be a multiple of 64\n");
            free(input_image_buffer);
            return 1;
        }
        if (params.height <= 0 || params.height % 64 != 0) {
            fprintf(stderr, "error: the height of image must be a multiple of 64\n");
            free(input_image_buffer);
            return 1;
        }
    }

    StableDiffusion sd(params.n_threads, vae_decode_only, true, params.lora_model_dir, params.rng_type);
    if (!sd.load_from_file(params.model_path, params.schedule)) {
        // Release the input image like the earlier error paths do; the buffer
        // has not been handed to anything yet. free(nullptr) is a no-op in
        // txt2img mode.
        free(input_image_buffer);
        return 1;
    }

    std::vector<uint8_t*> results;
    if (params.mode == TXT2IMG) {
        results = sd.txt2img(params.prompt,
                             params.negative_prompt,
                             params.cfg_scale,
                             params.width,
                             params.height,
                             params.sample_method,
                             params.sample_steps,
                             params.seed,
                             params.batch_count);
    } else {
        results = sd.img2img(input_image_buffer,
                             params.prompt,
                             params.negative_prompt,
                             params.cfg_scale,
                             params.width,
                             params.height,
                             params.sample_method,
                             params.sample_steps,
                             params.strength,
                             params.seed);
    }

    // NOTE(review): ownership of the input buffer after img2img() and of the
    // returned result buffers is not visible from this file, so they are left
    // to process teardown exactly as before — confirm against StableDiffusion
    // before adding explicit frees here.

    if (results.size() == 0 || results.size() != params.batch_count) {
        fprintf(stderr, "generate failed\n");
        return 1;
    }

    const std::string base_name = cli_strip_output_extension(params.output_path);

    bool all_saved = true;
    for (int i = 0; i < params.batch_count; i++) {
        // The first image keeps the plain name; later ones get a 1-based suffix.
        std::string final_image_path = i > 0 ? base_name + "_" + std::to_string(i + 1) + ".png" : base_name + ".png";

        if (results[i] == nullptr) {
            fprintf(stderr, "no image data for result %d, skipping '%s'\n", i + 1, final_image_path.c_str());
            all_saved = false;
            continue;
        }

        // stbi_write_png reports failure with a zero return; do not claim the
        // file was saved in that case.
        if (!stbi_write_png(final_image_path.c_str(), params.width, params.height, 3, results[i], 0, get_image_params(params, params.seed + i).c_str())) {
            fprintf(stderr, "save result image to '%s' failed\n", final_image_path.c_str());
            all_saved = false;
            continue;
        }
        printf("save result image to '%s'\n", final_image_path.c_str());
    }

    return all_saved ? 0 : 1;
}