refactor: reorganize code and use c api (#133)
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
#include <stdio.h>
|
||||
#include <ctime>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include "ggml/ggml.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "stable-diffusion.h"
|
||||
#include "util.h"
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
@@ -12,11 +15,6 @@
|
||||
#define STB_IMAGE_WRITE_STATIC
|
||||
#include "stb_image_write.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
const char* rng_type_to_str[] = {
|
||||
"std_default",
|
||||
"cuda",
|
||||
@@ -60,7 +58,7 @@ struct SDParams {
|
||||
std::string vae_path;
|
||||
std::string taesd_path;
|
||||
std::string esrgan_path;
|
||||
ggml_type wtype = GGML_TYPE_COUNT;
|
||||
sd_type_t wtype = SD_TYPE_COUNT;
|
||||
std::string lora_model_dir;
|
||||
std::string output_path = "output.png";
|
||||
std::string input_path;
|
||||
@@ -73,22 +71,34 @@ struct SDParams {
|
||||
int height = 512;
|
||||
int batch_count = 1;
|
||||
|
||||
SampleMethod sample_method = EULER_A;
|
||||
Schedule schedule = DEFAULT;
|
||||
int sample_steps = 20;
|
||||
float strength = 0.75f;
|
||||
RNGType rng_type = CUDA_RNG;
|
||||
int64_t seed = 42;
|
||||
bool verbose = false;
|
||||
bool vae_tiling = false;
|
||||
sample_method_t sample_method = EULER_A;
|
||||
schedule_t schedule = DEFAULT;
|
||||
int sample_steps = 20;
|
||||
float strength = 0.75f;
|
||||
rng_type_t rng_type = CUDA_RNG;
|
||||
int64_t seed = 42;
|
||||
bool verbose = false;
|
||||
bool vae_tiling = false;
|
||||
};
|
||||
|
||||
/**
 * Return the final component of a file path.
 *
 * Both '/' and '\\' are treated as directory separators. The last separator
 * of either kind wins, so mixed-separator paths such as "dir/sub\\file.ckpt"
 * yield "file.ckpt" rather than "sub\\file.ckpt".
 *
 * @param path File path; may be empty or contain no separator.
 * @return Everything after the last separator, or the whole of `path` when it
 *         contains no separator. A path ending in a separator yields "".
 */
static std::string sd_basename(const std::string& path) {
    // Search for either separator in one pass so the right-most one is used
    // regardless of which kind it is.
    const size_t pos = path.find_last_of("/\\");
    if (pos == std::string::npos) {
        return path;
    }
    return path.substr(pos + 1);
}
|
||||
|
||||
void print_params(SDParams params) {
|
||||
printf("Option: \n");
|
||||
printf(" n_threads: %d\n", params.n_threads);
|
||||
printf(" mode: %s\n", modes_str[params.mode]);
|
||||
printf(" model_path: %s\n", params.model_path.c_str());
|
||||
printf(" wtype: %s\n", params.wtype < GGML_TYPE_COUNT ? ggml_type_name(params.wtype) : "unspecified");
|
||||
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
|
||||
printf(" vae_path: %s\n", params.vae_path.c_str());
|
||||
printf(" taesd_path: %s\n", params.taesd_path.c_str());
|
||||
printf(" esrgan_path: %s\n", params.esrgan_path.c_str());
|
||||
@@ -208,19 +218,19 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
}
|
||||
std::string type = argv[i];
|
||||
if (type == "f32") {
|
||||
params.wtype = GGML_TYPE_F32;
|
||||
params.wtype = SD_TYPE_F32;
|
||||
} else if (type == "f16") {
|
||||
params.wtype = GGML_TYPE_F16;
|
||||
params.wtype = SD_TYPE_F16;
|
||||
} else if (type == "q4_0") {
|
||||
params.wtype = GGML_TYPE_Q4_0;
|
||||
params.wtype = SD_TYPE_Q4_0;
|
||||
} else if (type == "q4_1") {
|
||||
params.wtype = GGML_TYPE_Q4_1;
|
||||
params.wtype = SD_TYPE_Q4_1;
|
||||
} else if (type == "q5_0") {
|
||||
params.wtype = GGML_TYPE_Q5_0;
|
||||
params.wtype = SD_TYPE_Q5_0;
|
||||
} else if (type == "q5_1") {
|
||||
params.wtype = GGML_TYPE_Q5_1;
|
||||
params.wtype = SD_TYPE_Q5_1;
|
||||
} else if (type == "q8_0") {
|
||||
params.wtype = GGML_TYPE_Q8_0;
|
||||
params.wtype = SD_TYPE_Q8_0;
|
||||
} else {
|
||||
fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0]\n",
|
||||
type.c_str());
|
||||
@@ -330,7 +340,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
params.schedule = (Schedule)schedule_found;
|
||||
params.schedule = (schedule_t)schedule_found;
|
||||
} else if (arg == "-s" || arg == "--seed") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
@@ -353,7 +363,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
params.sample_method = (SampleMethod)sample_method_found;
|
||||
params.sample_method = (sample_method_t)sample_method_found;
|
||||
} else if (arg == "-h" || arg == "--help") {
|
||||
print_usage(argc, argv);
|
||||
exit(0);
|
||||
@@ -433,7 +443,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
|
||||
parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", ";
|
||||
parameter_string += "Seed: " + std::to_string(seed) + ", ";
|
||||
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
|
||||
parameter_string += "Model: " + basename(params.model_path) + ", ";
|
||||
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
|
||||
parameter_string += "RNG: " + std::string(rng_type_to_str[params.rng_type]) + ", ";
|
||||
parameter_string += "Sampler: " + std::string(sample_method_str[params.sample_method]);
|
||||
if (params.schedule == KARRAS) {
|
||||
@@ -444,14 +454,29 @@ std::string get_image_params(SDParams params, int64_t seed) {
|
||||
return parameter_string;
|
||||
}
|
||||
|
||||
void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
|
||||
SDParams* params = (SDParams*)data;
|
||||
if (!params->verbose && level <= SD_LOG_DEBUG) {
|
||||
return;
|
||||
}
|
||||
if (level <= SD_LOG_INFO) {
|
||||
fprintf(stdout, log);
|
||||
fflush(stdout);
|
||||
} else {
|
||||
fprintf(stderr, log);
|
||||
fflush(stderr);
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, const char* argv[]) {
|
||||
SDParams params;
|
||||
parse_args(argc, argv, params);
|
||||
|
||||
sd_set_log_callback(sd_log_cb, (void*)¶ms);
|
||||
|
||||
if (params.verbose) {
|
||||
print_params(params);
|
||||
printf("%s", sd_get_system_info().c_str());
|
||||
set_sd_log_level(SDLogLevel::DEBUG);
|
||||
printf("%s", sd_get_system_info());
|
||||
}
|
||||
|
||||
bool vae_decode_only = true;
|
||||
@@ -482,60 +507,98 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.vae_tiling, params.lora_model_dir, params.rng_type);
|
||||
sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
|
||||
params.vae_path.c_str(),
|
||||
params.taesd_path.c_str(),
|
||||
params.lora_model_dir.c_str(),
|
||||
vae_decode_only,
|
||||
params.vae_tiling,
|
||||
true,
|
||||
params.n_threads,
|
||||
params.wtype,
|
||||
params.rng_type,
|
||||
params.schedule);
|
||||
|
||||
if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule, params.clip_skip)) {
|
||||
if (sd_ctx == NULL) {
|
||||
printf("new_sd_ctx_t failed\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::vector<uint8_t*> results;
|
||||
sd_image_t* results;
|
||||
if (params.mode == TXT2IMG) {
|
||||
results = sd.txt2img(params.prompt,
|
||||
params.negative_prompt,
|
||||
params.cfg_scale,
|
||||
params.width,
|
||||
params.height,
|
||||
params.sample_method,
|
||||
params.sample_steps,
|
||||
params.seed,
|
||||
params.batch_count);
|
||||
results = txt2img(sd_ctx,
|
||||
params.prompt.c_str(),
|
||||
params.negative_prompt.c_str(),
|
||||
params.clip_skip,
|
||||
params.cfg_scale,
|
||||
params.width,
|
||||
params.height,
|
||||
params.sample_method,
|
||||
params.sample_steps,
|
||||
params.seed,
|
||||
params.batch_count);
|
||||
} else {
|
||||
results = sd.img2img(input_image_buffer,
|
||||
params.prompt,
|
||||
params.negative_prompt,
|
||||
params.cfg_scale,
|
||||
params.width,
|
||||
params.height,
|
||||
params.sample_method,
|
||||
params.sample_steps,
|
||||
params.strength,
|
||||
params.seed);
|
||||
sd_image_t input_image = {(uint32_t)params.width,
|
||||
(uint32_t)params.height,
|
||||
3,
|
||||
input_image_buffer};
|
||||
|
||||
results = img2img(sd_ctx,
|
||||
input_image,
|
||||
params.prompt.c_str(),
|
||||
params.negative_prompt.c_str(),
|
||||
params.clip_skip,
|
||||
params.cfg_scale,
|
||||
params.width,
|
||||
params.height,
|
||||
params.sample_method,
|
||||
params.sample_steps,
|
||||
params.strength,
|
||||
params.seed,
|
||||
params.batch_count);
|
||||
}
|
||||
|
||||
if (params.esrgan_path.size() > 0) {
|
||||
// TODO: support more ESRGAN models, making it easier to set up ESRGAN models.
|
||||
/* hardcoded scale factor because just RealESRGAN_x4plus_anime_6B is compatible
|
||||
See also: https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan.py
|
||||
|
||||
To avoid this, the upscaler needs to be separated from the stable diffusion pipeline.
|
||||
However, a considerable amount of work would be required for this. It might be better
|
||||
to opt for a complete project refactoring that facilitates the easier assignment of parameters.
|
||||
*/
|
||||
params.width *= 4;
|
||||
params.height *= 4;
|
||||
}
|
||||
|
||||
if (results.size() == 0 || results.size() != params.batch_count) {
|
||||
LOG_ERROR("generate failed");
|
||||
if (results == NULL) {
|
||||
printf("generate failed\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
|
||||
if (params.esrgan_path.size() > 0) {
|
||||
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
|
||||
params.n_threads,
|
||||
params.wtype);
|
||||
|
||||
if (upscaler_ctx == NULL) {
|
||||
printf("new_upscaler_ctx failed\n");
|
||||
} else {
|
||||
for (int i = 0; i < params.batch_count; i++) {
|
||||
if (results[i].data == NULL) {
|
||||
continue;
|
||||
}
|
||||
sd_image_t upscaled_image = upscale(upscaler_ctx, results[i], upscale_factor);
|
||||
if (upscaled_image.data == NULL) {
|
||||
printf("upscale failed\n");
|
||||
continue;
|
||||
}
|
||||
free(results[i].data);
|
||||
results[i] = upscaled_image;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t last = params.output_path.find_last_of(".");
|
||||
std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path;
|
||||
for (int i = 0; i < params.batch_count; i++) {
|
||||
if (results[i].data == NULL) {
|
||||
continue;
|
||||
}
|
||||
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png";
|
||||
stbi_write_png(final_image_path.c_str(), params.width, params.height, 3, results[i], 0, get_image_params(params, params.seed + i).c_str());
|
||||
LOG_INFO("save result image to '%s'", final_image_path.c_str());
|
||||
stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
|
||||
results[i].data, 0, get_image_params(params, params.seed + i).c_str());
|
||||
printf("save result image to '%s'\n", final_image_path.c_str());
|
||||
free(results[i].data);
|
||||
results[i].data = NULL;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
Reference in New Issue
Block a user