diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3557b8c..158eeed 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -7,7 +7,7 @@ set(SD_TARGET sd)
add_subdirectory(ggml)
add_library(${SD_LIB} stable-diffusion.h stable-diffusion.cpp)
-add_executable(${SD_TARGET} main.cpp stb_image_write.h)
+add_executable(${SD_TARGET} main.cpp stb_image.h stb_image_write.h)
target_link_libraries(${SD_LIB} PUBLIC ggml)
target_link_libraries(${SD_TARGET} ${SD_LIB})
diff --git a/README.md b/README.md
index a2b02b5..dfc146d 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
- 4-bit, 5-bit and 8-bit integer quantization support
- Accelerated memory-efficient CPU inference
- AVX, AVX2 and AVX512 support for x86 architectures
-- Original `txt2img` mode
+- Original `txt2img` and `img2img` mode
- Negative prompt
- Sampling method
- `Euler A`
@@ -24,7 +24,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
### TODO
-- [ ] Original `img2img` mode
- [ ] More sampling methods
- [ ] GPU support
- [ ] Make inference faster
@@ -97,13 +96,17 @@ usage: ./sd [arguments]
arguments:
-h, --help show this help message and exit
+ -M, --mode [txt2img or img2img] generation mode (default: txt2img)
-t, --threads N number of threads to use during computation (default: -1).
- If threads <= 0, then threads will be set to the number of CPU cores
+ If threads <= 0, then threads will be set to the number of CPU physical cores
-m, --model [MODEL] path to model
+ -i, --init-img [IMAGE] path to the input image, required by img2img
-o, --output OUTPUT path to write result image to (default: .\output.png)
-p, --prompt [PROMPT] the prompt to render
-n, --negative-prompt PROMPT the negative prompt (default: "")
--cfg-scale SCALE unconditional guidance scale: (default: 7.0)
+ --strength STRENGTH strength for noising/unnoising (default: 0.75)
+ 1.0 corresponds to full destruction of information in init image
-H, --height H image height, in pixel space (default: 512)
-W, --width W image width, in pixel space (default: 512)
--sample-method SAMPLE_METHOD sample method (default: "eular a")
@@ -112,7 +115,7 @@ arguments:
-v, --verbose print extra info
```
-For example
+#### txt2img example
```
./sd -m ../models/sd-v1-4-ggml-model-f16.bin -p "a lovely cat"
@@ -124,6 +127,19 @@ Using formats of different precisions will yield results of varying quality.
| ---- |---- |---- |---- |---- |---- |---- |
|  | | | | | | |
+#### img2img example
+
+- `./output.png` is the image generated from the above txt2img pipeline
+
+
+```
+./sd --mode img2img -m ../models/sd-v1-4-ggml-model-f16.bin -p "cat with blue eyes" -i ./output.png -o ./img2img_output.png --strength 0.4
+```
+
+
+
+
+
## Memory/Disk Requirements
| precision | f32 | f16 |q8_0 |q5_0 |q5_1 |q4_0 |q4_1 |
diff --git a/assets/img2img_output.png b/assets/img2img_output.png
new file mode 100644
index 0000000..80579a1
Binary files /dev/null and b/assets/img2img_output.png differ
diff --git a/main.cpp b/main.cpp
index 5622482..8f9d1e8 100644
--- a/main.cpp
+++ b/main.cpp
@@ -8,13 +8,16 @@
#include "stable-diffusion.h"
+#define STB_IMAGE_IMPLEMENTATION
+#include "stb_image.h"
+
#define STB_IMAGE_WRITE_IMPLEMENTATION
#define STB_IMAGE_WRITE_STATIC
#include "stb_image_write.h"
#if defined(__APPLE__) && defined(__MACH__)
-#include
#include
+#include
#endif
#if !defined(_WIN32)
@@ -22,6 +25,9 @@
#include
#endif
+#define TXT2IMG "txt2img"
+#define IMG2IMG "img2img"
+
// get_num_physical_cores is copy from
// https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
// LICENSE: https://github.com/ggerganov/llama.cpp/blob/master/LICENSE
@@ -63,8 +69,10 @@ int32_t get_num_physical_cores() {
struct Option {
int n_threads = -1;
+ std::string mode = TXT2IMG;
std::string model_path;
std::string output_path = "output.png";
+ std::string init_img;
std::string prompt;
std::string negative_prompt;
float cfg_scale = 7.0f;
@@ -72,14 +80,17 @@ struct Option {
int h = 512;
SampleMethod sample_method = EULAR_A;
int sample_steps = 20;
+ float strength = 0.75f;
int seed = 42;
bool verbose = false;
void print() {
printf("Option: \n");
printf(" n_threads: %d\n", n_threads);
+ printf(" mode: %s\n", mode.c_str());
printf(" model_path: %s\n", model_path.c_str());
printf(" output_path: %s\n", output_path.c_str());
+ printf(" init_img: %s\n", init_img.c_str());
printf(" prompt: %s\n", prompt.c_str());
printf(" negative_prompt: %s\n", negative_prompt.c_str());
printf(" cfg_scale: %.2f\n", cfg_scale);
@@ -87,6 +98,7 @@ struct Option {
printf(" height: %d\n", h);
printf(" sample_method: %s\n", "eular a");
printf(" sample_steps: %d\n", sample_steps);
+ printf(" strength: %.2f\n", strength);
printf(" seed: %d\n", seed);
}
};
@@ -96,13 +108,17 @@ void print_usage(int argc, const char* argv[]) {
printf("\n");
printf("arguments:\n");
printf(" -h, --help show this help message and exit\n");
+ printf(" -M, --mode [txt2img or img2img] generation mode (default: txt2img)\n");
printf(" -t, --threads N number of threads to use during computation (default: -1).\n");
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
printf(" -m, --model [MODEL] path to model\n");
+ printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
printf(" -o, --output OUTPUT path to write result image to (default: .\\output.png)\n");
printf(" -p, --prompt [PROMPT] the prompt to render\n");
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
+ printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
+ printf(" 1.0 corresponds to full destruction of information in init image\n");
printf(" -H, --height H image height, in pixel space (default: 512)\n");
printf(" -W, --width W image width, in pixel space (default: 512)\n");
printf(" --sample-method SAMPLE_METHOD sample method (default: \"eular a\")\n");
@@ -123,12 +139,25 @@ void parse_args(int argc, const char* argv[], Option* opt) {
break;
}
opt->n_threads = std::stoi(argv[i]);
+ } else if (arg == "-M" || arg == "--mode") {
+ if (++i >= argc) {
+ invalid_arg = true;
+ break;
+ }
+ opt->mode = argv[i];
+
} else if (arg == "-m" || arg == "--model") {
if (++i >= argc) {
invalid_arg = true;
break;
}
opt->model_path = argv[i];
+ } else if (arg == "-i" || arg == "--init-img") {
+ if (++i >= argc) {
+ invalid_arg = true;
+ break;
+ }
+ opt->init_img = argv[i];
} else if (arg == "-o" || arg == "--output") {
if (++i >= argc) {
invalid_arg = true;
@@ -153,6 +182,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
break;
}
opt->cfg_scale = std::stof(argv[i]);
+ } else if (arg == "--strength") {
+ if (++i >= argc) {
+ invalid_arg = true;
+ break;
+ }
+ opt->strength = std::stof(argv[i]);
} else if (arg == "-H" || arg == "--height") {
if (++i >= argc) {
invalid_arg = true;
@@ -198,6 +233,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
opt->n_threads = get_num_physical_cores();
}
+ if (opt->mode != TXT2IMG && opt->mode != IMG2IMG) {
+ fprintf(stderr, "error: invalid mode %s, must be one of ['%s', '%s']\n",
+ opt->mode.c_str(), TXT2IMG, IMG2IMG);
+ exit(1);
+ }
+
if (opt->prompt.length() == 0) {
fprintf(stderr, "error: the following arguments are required: prompt\n");
print_usage(argc, argv);
@@ -210,6 +251,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
exit(1);
}
+ if (opt->mode == IMG2IMG && opt->init_img.length() == 0) {
+ fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n");
+ print_usage(argc, argv);
+ exit(1);
+ }
+
if (opt->output_path.length() == 0) {
fprintf(stderr, "error: the following arguments are required: output_path\n");
print_usage(argc, argv);
@@ -230,6 +277,11 @@ void parse_args(int argc, const char* argv[], Option* opt) {
fprintf(stderr, "error: the sample_steps must be greater than 0\n");
exit(1);
}
+
+ if (opt->strength < 0.f || opt->strength > 1.f) {
+ fprintf(stderr, "error: can only work with strength in [0.0, 1.0]\n");
+ exit(1);
+ }
}
int main(int argc, const char* argv[]) {
@@ -242,19 +294,66 @@ int main(int argc, const char* argv[]) {
set_sd_log_level(SDLogLevel::DEBUG);
}
- StableDiffusion sd(opt.n_threads);
+ bool vae_decode_only = true;
+ std::vector init_img;
+ if (opt.mode == IMG2IMG) {
+ vae_decode_only = false;
+
+ int c = 0;
+ unsigned char* img_data = stbi_load(opt.init_img.c_str(), &opt.w, &opt.h, &c, 3);
+ if (img_data == NULL) {
+ fprintf(stderr, "load image from '%s' failed\n", opt.init_img.c_str());
+ return 1;
+ }
+ if (c != 3) {
+ fprintf(stderr, "input image must be a 3 channels RGB image, but got %d channels\n", c);
+ free(img_data);
+ return 1;
+ }
+ if (opt.w <= 0 || opt.w % 32 != 0) {
+ fprintf(stderr, "error: the width of image must be a multiple of 32\n");
+ free(img_data);
+ return 1;
+ }
+ if (opt.h <= 0 || opt.h % 32 != 0) {
+ fprintf(stderr, "error: the height of image must be a multiple of 32\n");
+ free(img_data);
+ return 1;
+ }
+ init_img.assign(img_data, img_data + (opt.w * opt.h * c));
+ }
+ StableDiffusion sd(opt.n_threads, vae_decode_only);
if (!sd.load_from_file(opt.model_path)) {
return 1;
}
- std::vector img = sd.txt2img(opt.prompt,
- opt.negative_prompt,
- opt.cfg_scale,
- opt.w,
- opt.h,
- opt.sample_method,
- opt.sample_steps,
- opt.seed);
+ std::vector img;
+ if (opt.mode == TXT2IMG) {
+ img = sd.txt2img(opt.prompt,
+ opt.negative_prompt,
+ opt.cfg_scale,
+ opt.w,
+ opt.h,
+ opt.sample_method,
+ opt.sample_steps,
+ opt.seed);
+ } else {
+ img = sd.img2img(init_img,
+ opt.prompt,
+ opt.negative_prompt,
+ opt.cfg_scale,
+ opt.w,
+ opt.h,
+ opt.sample_method,
+ opt.sample_steps,
+ opt.strength,
+ opt.seed);
+ }
+
+ if (img.size() == 0) {
+ fprintf(stderr, "generate failed\n");
+ return 1;
+ }
stbi_write_png(opt.output_path.c_str(), opt.w, opt.h, 3, img.data(), 0);
printf("save result image to '%s'\n", opt.output_path.c_str());
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index c0f9db7..5820a63 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -1,7 +1,9 @@
#include
#include
+#include
#include
#include
+#include
#include