feat: add convert api (#142)

This commit is contained in:
leejet
2024-01-14 11:43:24 +08:00
committed by GitHub
parent 2b6ec97fe2
commit 5c614e4bc2
5 changed files with 167 additions and 25 deletions

View File

@@ -42,11 +42,13 @@ const char* schedule_str[] = {
const char* modes_str[] = {
"txt2img",
"img2img",
"convert",
};
enum SDMode {
TXT2IMG,
IMG2IMG,
CONVERT,
MODE_COUNT
};
@@ -125,7 +127,7 @@ void print_usage(int argc, const char* argv[]) {
printf("\n");
printf("arguments:\n");
printf(" -h, --help show this help message and exit\n");
printf(" -M, --mode [txt2img or img2img] generation mode (default: txt2img)\n");
printf(" -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)\n");
printf(" -t, --threads N number of threads to use during computation (default: -1).\n");
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
printf(" -m, --model [MODEL] path to model\n");
@@ -384,7 +386,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.n_threads = get_num_physical_cores();
}
if (params.prompt.length() == 0) {
if (params.mode != CONVERT && params.prompt.length() == 0) {
fprintf(stderr, "error: the following arguments are required: prompt\n");
print_usage(argc, argv);
exit(1);
@@ -432,6 +434,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
srand((int)time(NULL));
params.seed = rand();
}
if (params.mode == CONVERT) {
if (params.output_path == "output.png") {
params.output_path = "output.gguf";
}
}
}
std::string get_image_params(SDParams params, int64_t seed) {
@@ -479,6 +487,24 @@ int main(int argc, const char* argv[]) {
printf("%s", sd_get_system_info());
}
if (params.mode == CONVERT) {
bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype);
if (!success) {
fprintf(stderr,
"convert '%s'/'%s' to '%s' failed\n",
params.model_path.c_str(),
params.vae_path.c_str(),
params.output_path.c_str());
return 1;
} else {
printf("convert '%s'/'%s' to '%s' success\n",
params.model_path.c_str(),
params.vae_path.c_str(),
params.output_path.c_str());
return 0;
}
}
bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
if (params.mode == IMG2IMG) {