feat: introduce GGMLBlock and implement SVD(Broken) (#159)

* introduce GGMLBlock and implement SVD(Broken)

* add sdxl vae warning
This commit is contained in:
leejet
2024-02-24 20:06:39 +08:00
committed by GitHub
parent 349439f239
commit b6368868d9
20 changed files with 4137 additions and 3818 deletions

View File

@@ -43,12 +43,14 @@ const char* schedule_str[] = {
// Command-line spellings of the run modes. The entries are listed in the same
// order as the enumerators of SDMode (txt2img, img2img, img2vid, convert).
// NOTE(review): presumably indexed by SDMode when parsing the mode argument;
// confirm against parse_args before reordering either list.
const char* modes_str[] = {"txt2img", "img2img", "img2vid", "convert"};
// Run modes of the program. Values are spelled out explicitly; they are the
// same consecutive values the enumerators receive by default (0, 1, 2, ...).
// MODE_COUNT is the final enumerator, so its value equals the number of real
// modes declared before it.
enum SDMode {
    TXT2IMG    = 0,
    IMG2IMG    = 1,
    IMG2VID    = 2,
    CONVERT    = 3,
    MODE_COUNT = 4
};
@@ -71,12 +73,18 @@ struct SDParams {
std::string prompt;
std::string negative_prompt;
float min_cfg = 1.0f;
float cfg_scale = 7.0f;
int clip_skip = -1; // <= 0 represents unspecified
int width = 512;
int height = 512;
int batch_count = 1;
int video_frames = 6;
int motion_bucket_id = 127;
int fps = 6;
float augmentation_level = 0.f;
sample_method_t sample_method = EULER_A;
schedule_t schedule = DEFAULT;
int sample_steps = 20;
@@ -108,6 +116,7 @@ void print_params(SDParams params) {
printf(" strength(control): %.2f\n", params.control_strength);
printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
printf(" min_cfg: %.2f\n", params.min_cfg);
printf(" cfg_scale: %.2f\n", params.cfg_scale);
printf(" clip_skip: %d\n", params.clip_skip);
printf(" width: %d\n", params.width);
@@ -190,7 +199,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
}
}
if (mode_found == -1) {
fprintf(stderr, "error: invalid mode %s, must be one of [txt2img, img2img]\n",
fprintf(stderr,
"error: invalid mode %s, must be one of [txt2img, img2img, img2vid, convert]\n",
mode_selected);
exit(1);
}
@@ -420,7 +430,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.n_threads = get_num_physical_cores();
}
if (params.mode != CONVERT && params.prompt.length() == 0) {
if (params.mode != CONVERT && params.mode != IMG2VID && params.prompt.length() == 0) {
fprintf(stderr, "error: the following arguments are required: prompt\n");
print_usage(argc, argv);
exit(1);
@@ -432,7 +442,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
exit(1);
}
if (params.mode == IMG2IMG && params.input_path.length() == 0) {
if ((params.mode == IMG2IMG || params.mode == IMG2VID) && params.input_path.length() == 0) {
fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n");
print_usage(argc, argv);
exit(1);
@@ -539,9 +549,14 @@ int main(int argc, const char* argv[]) {
}
}
if (params.mode == IMG2VID) {
fprintf(stderr, "SVD support is broken, do not use it!!!\n");
return 1;
}
bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
if (params.mode == IMG2IMG) {
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
vae_decode_only = false;
int c = 0;
@@ -625,19 +640,57 @@ int main(int argc, const char* argv[]) {
3,
input_image_buffer};
results = img2img(sd_ctx,
input_image,
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.strength,
params.seed,
params.batch_count);
if (params.mode == IMG2VID) {
results = img2vid(sd_ctx,
input_image,
params.width,
params.height,
params.video_frames,
params.motion_bucket_id,
params.fps,
params.augmentation_level,
params.min_cfg,
params.cfg_scale,
params.sample_method,
params.sample_steps,
params.strength,
params.seed);
if (results == NULL) {
printf("generate failed\n");
free_sd_ctx(sd_ctx);
return 1;
}
size_t last = params.output_path.find_last_of(".");
std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path;
for (int i = 0; i < params.video_frames; i++) {
if (results[i].data == NULL) {
continue;
}
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png";
stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 0, get_image_params(params, params.seed + i).c_str());
printf("save result image to '%s'\n", final_image_path.c_str());
free(results[i].data);
results[i].data = NULL;
}
free(results);
free_sd_ctx(sd_ctx);
return 0;
} else {
results = img2img(sd_ctx,
input_image,
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.strength,
params.seed,
params.batch_count);
}
}
if (results == NULL) {