feat: Control Net support + Textual Inversion (embeddings) (#131)

* add controlnet to pipeline

* add cli params

* control strength cli param

* cli param keep controlnet in cpu

* add Textual Inversion

* add canny preprocessor

* refactor: change ggml_type_sizef to ggml_row_size

* process hint only once

* ignore case in embedding names

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
Steward Garcia
2024-01-29 09:38:51 -05:00
committed by GitHub
parent c6071fa82f
commit 36ec16ac99
20 changed files with 1823 additions and 589 deletions

View File

@@ -7,6 +7,7 @@
#include <vector>
#include "stable-diffusion.h"
#include "preprocessing.hpp"
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
@@ -60,10 +61,13 @@ struct SDParams {
std::string vae_path;
std::string taesd_path;
std::string esrgan_path;
std::string controlnet_path;
std::string embeddings_path;
sd_type_t wtype = SD_TYPE_COUNT;
std::string lora_model_dir;
std::string output_path = "output.png";
std::string input_path;
std::string control_image_path;
std::string prompt;
std::string negative_prompt;
@@ -77,24 +81,15 @@ struct SDParams {
schedule_t schedule = DEFAULT;
int sample_steps = 20;
float strength = 0.75f;
float control_strength = 0.9f;
rng_type_t rng_type = CUDA_RNG;
int64_t seed = 42;
bool verbose = false;
bool vae_tiling = false;
bool control_net_cpu = false;
bool canny_preprocess = false;
};
// Returns the final component of `path`: everything after the last '/' or
// '\\' separator. A path containing no separator is returned unchanged, and a
// path that ends in a separator yields an empty string.
static std::string sd_basename(const std::string& path) {
    // Look for both separator styles in a single pass so that mixed paths
    // such as "dir/sub\\file.png" resolve to "file.png" instead of stopping
    // at the last '/' and returning "sub\\file.png".
    const size_t pos = path.find_last_of("/\\");
    if (pos == std::string::npos) {
        return path;
    }
    return path.substr(pos + 1);
}
void print_params(SDParams params) {
printf("Option: \n");
printf(" n_threads: %d\n", params.n_threads);
@@ -104,8 +99,13 @@ void print_params(SDParams params) {
printf(" vae_path: %s\n", params.vae_path.c_str());
printf(" taesd_path: %s\n", params.taesd_path.c_str());
printf(" esrgan_path: %s\n", params.esrgan_path.c_str());
printf(" controlnet_path: %s\n", params.controlnet_path.c_str());
printf(" embeddings_path: %s\n", params.embeddings_path.c_str());
printf(" output_path: %s\n", params.output_path.c_str());
printf(" init_img: %s\n", params.input_path.c_str());
printf(" control_image: %s\n", params.control_image_path.c_str());
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
printf(" strength(control): %.2f\n", params.control_strength);
printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
printf(" cfg_scale: %.2f\n", params.cfg_scale);
@@ -133,16 +133,20 @@ void print_usage(int argc, const char* argv[]) {
printf(" -m, --model [MODEL] path to model\n");
printf(" --vae [VAE] path to vae\n");
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
printf(" --control-net [CONTROL_PATH] path to control net model\n");
printf(" --embd-dir [EMBEDDING_PATH] path to embeddings.\n");
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n");
printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n");
printf(" If not specified, the default is the type of the weight file.\n");
printf(" --lora-model-dir [DIR] lora model directory\n");
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
printf(" --control-image [IMAGE] path to image condition, control net\n");
printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
printf(" -p, --prompt [PROMPT] the prompt to render\n");
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n");
printf(" 1.0 corresponds to full destruction of information in init image\n");
printf(" -H, --height H image height, in pixel space (default: 512)\n");
printf(" -W, --width W image width, in pixel space (default: 512)\n");
@@ -156,6 +160,8 @@ void print_usage(int argc, const char* argv[]) {
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
printf(" --canny apply canny preprocessor (edge detection)\n");
printf(" -v, --verbose print extra info\n");
}
@@ -207,13 +213,25 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.taesd_path = argv[i];
} else if (arg == "--control-net") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.controlnet_path = argv[i];
} else if (arg == "--upscale-model") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.esrgan_path = argv[i];
} else if (arg == "--type") {
} else if (arg == "--embd-dir") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.embeddings_path = argv[i];
} else if (arg == "--type") {
if (++i >= argc) {
invalid_arg = true;
break;
@@ -250,6 +268,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.input_path = argv[i];
} else if (arg == "--control-image") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.control_image_path = argv[i];
} else if (arg == "-o" || arg == "--output") {
if (++i >= argc) {
invalid_arg = true;
@@ -280,6 +304,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.strength = std::stof(argv[i]);
} else if (arg == "--control-strength") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.control_strength = std::stof(argv[i]);
} else if (arg == "-H" || arg == "--height") {
if (++i >= argc) {
invalid_arg = true;
@@ -306,6 +336,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.clip_skip = std::stoi(argv[i]);
} else if (arg == "--vae-tiling") {
params.vae_tiling = true;
} else if (arg == "--control-net-cpu") {
params.control_net_cpu = true;
} else if (arg == "--canny") {
params.canny_preprocess = true;
} else if (arg == "-b" || arg == "--batch-count") {
if (++i >= argc) {
invalid_arg = true;
@@ -536,14 +570,17 @@ int main(int argc, const char* argv[]) {
sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
params.vae_path.c_str(),
params.taesd_path.c_str(),
params.controlnet_path.c_str(),
params.lora_model_dir.c_str(),
params.embeddings_path.c_str(),
vae_decode_only,
params.vae_tiling,
true,
params.n_threads,
params.wtype,
params.rng_type,
params.schedule);
params.schedule,
params.control_net_cpu);
if (sd_ctx == NULL) {
printf("new_sd_ctx_t failed\n");
@@ -552,6 +589,23 @@ int main(int argc, const char* argv[]) {
sd_image_t* results;
if (params.mode == TXT2IMG) {
sd_image_t* control_image = NULL;
if(params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
int c = 0;
input_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
if(input_image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
return 1;
}
control_image = new sd_image_t{(uint32_t)params.width,
(uint32_t)params.height,
3,
input_image_buffer};
if(params.canny_preprocess) { // apply preprocessor
LOG_INFO("Applying canny preprocessor");
control_image->data = preprocess_canny(control_image->data, control_image->width, control_image->height);
}
}
results = txt2img(sd_ctx,
params.prompt.c_str(),
params.negative_prompt.c_str(),
@@ -562,7 +616,9 @@ int main(int argc, const char* argv[]) {
params.sample_method,
params.sample_steps,
params.seed,
params.batch_count);
params.batch_count,
control_image,
params.control_strength);
} else {
sd_image_t input_image = {(uint32_t)params.width,
(uint32_t)params.height,