feat: implement ESRGAN upscaler + Metal Backend (#104)

* add esrgan upscaler

* add sd_tiling

* support metal backend

* add clip_skip

---------

Co-authored-by: leejet <leejet714@gmail.com>
Authored by Steward Garcia on 2023-12-28 10:46:48 -05:00, committed by GitHub
parent 0e64238e4c
commit 004dfbef27
8 changed files with 915 additions and 39 deletions

CMakeLists.txt

@@ -25,7 +25,9 @@ endif()
 #option(SD_BUILD_TESTS "sd: build tests" ${SD_STANDALONE})
 option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE})
 option(SD_CUBLAS "sd: cuda backend" OFF)
+option(SD_METAL "sd: metal backend" OFF)
 option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF)
+option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
 option(BUILD_SHARED_LIBS "sd: build shared libs" OFF)
 #option(SD_BUILD_SERVER "sd: build server example" ON)
@@ -33,6 +35,15 @@ if(SD_CUBLAS)
     message("Use CUBLAS as backend stable-diffusion")
     set(GGML_CUBLAS ON)
     add_definitions(-DSD_USE_CUBLAS)
+    if(SD_FAST_SOFTMAX)
+        set(GGML_CUDA_FAST_SOFTMAX ON)
+    endif()
+endif()
+
+if(SD_METAL)
+    message("Use Metal as backend stable-diffusion")
+    set(GGML_METAL ON)
+    add_definitions(-DSD_USE_METAL)
 endif()

 if(SD_FLASH_ATTN)
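The `-DSD_USE_METAL` definition added above is the compile-time switch that gates the Metal-specific paths later in this diff (see model.cpp below). A minimal sketch of the pattern, assuming the standard ggml backend API; the actual backend setup lives in stable-diffusion.cpp, whose diff is suppressed further down:

```cpp
#include "ggml/ggml-backend.h"
#ifdef SD_USE_METAL
#include "ggml-metal.h"
#endif

// Choose the compute backend according to how the project was configured.
static ggml_backend_t init_compute_backend() {
#ifdef SD_USE_METAL
    return ggml_backend_metal_init();  // GPU via Metal (built with -DSD_METAL=ON)
#else
    return ggml_backend_cpu_init();    // default CPU backend
#endif
}
```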

README.md

@@ -17,7 +17,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - Accelerated memory-efficient CPU inference
 - Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image, enabling Flash Attention just requires ~1.8GB.
 - AVX, AVX2 and AVX512 support for x86 architectures
-- Full CUDA backend for GPU acceleration.
+- Full CUDA and Metal backend for GPU acceleration.
 - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAE models
 - No need to convert to `.ggml` or `.gguf` anymore!
 - Flash Attention for memory usage optimization (only cpu for now)
@@ -27,6 +27,8 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - LoRA support, same as [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora)
 - Latent Consistency Models support (LCM/LCM-LoRA)
 - Faster and memory efficient latent decoding with [TAESD](https://github.com/madebyollin/taesd)
+- Upscale images generated with [ESRGAN](https://github.com/xinntao/Real-ESRGAN)
+- VAE tiling processing to reduce memory usage
 - Sampling method
   - `Euler A`
   - `Euler`
@@ -51,7 +53,8 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
   - The current implementation of ggml_conv_2d is slow and has high memory usage
   - Implement Winograd Convolution 2D for 3x3 kernel filtering
 - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
-- [ ] Implement [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN/tree/master) upscaler
+- [ ] Implement Textual Inversion (embeddings)
+- [ ] Implement Inpainting support
 - [ ] k-quants support

 ## Usage
@@ -112,6 +115,15 @@ cmake .. -DSD_CUBLAS=ON
 cmake --build . --config Release
 ```

+##### Using Metal
+
+Using Metal makes the computation run on the GPU. Currently, there are some issues with Metal when performing operations on very large matrices, making it highly inefficient. Performance improvements are expected in the near future.
+
+```
+cmake .. -DSD_METAL=ON
+cmake --build . --config Release
+```
+
 ### Using Flash Attention

 Enabling flash attention reduces memory usage by at least 400 MB. At the moment, it is not supported when CUBLAS is enabled because the kernel implementation is missing.
@@ -124,7 +136,7 @@ cmake --build . --config Release
 ### Run

 ```
-usage: sd [arguments]
+usage: ./bin/sd [arguments]

 arguments:
   -h, --help                         show this help message and exit
@@ -134,6 +146,7 @@ arguments:
   -m, --model [MODEL]                path to model
   --vae [VAE]                        path to vae
   --taesd [TAESD_PATH]               path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
+  --upscale-model [ESRGAN_PATH]      path to esrgan model. Upscale images after generation; only RealESRGAN_x4plus_anime_6B is supported for now.
   --type [TYPE]                      weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)
                                      If not specified, the default is the type of the weight file.
   --lora-model-dir [DIR]             lora model directory
@@ -153,6 +166,8 @@ arguments:
   -s SEED, --seed SEED               RNG seed (default: 42, use random seed for < 0)
   -b, --batch-count COUNT            number of images to generate.
   --schedule {discrete, karras}      Denoiser sigma schedule (default: discrete)
+  --clip-skip N                      number of CLIP model layers to skip (default: -1, unspecified)
+  --vae-tiling                       process vae in tiles to reduce memory usage
   -v, --verbose                      print extra info
 ```
@@ -240,6 +255,16 @@ curl -L -O https://huggingface.co/madebyollin/taesd/blob/main/diffusion_pytorch_
 sd -m ../models/v1-5-pruned-emaonly.safetensors -p "a lovely cat" --taesd ../models/diffusion_pytorch_model.safetensors
 ```

+## Using ESRGAN to upscale results
+
+You can use ESRGAN to upscale the generated images. At the moment, only the [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth) model is supported. Support for more models of this architecture will be added soon.
+
+- Specify the model path using the `--upscale-model PATH` parameter. Example:
+
+```bash
+sd -m ../models/v1-5-pruned-emaonly.safetensors -p "a lovely cat" --upscale-model ../models/RealESRGAN_x4plus_anime_6B.pth
+```
+
 ### Docker

 #### Building using Docker

examples/cli/main.cpp

@@ -59,6 +59,7 @@ struct SDParams {
     std::string model_path;
     std::string vae_path;
     std::string taesd_path;
+    std::string esrgan_path;
     ggml_type wtype = GGML_TYPE_COUNT;
     std::string lora_model_dir;
     std::string output_path = "output.png";
@@ -67,6 +68,7 @@ struct SDParams {
     std::string prompt;
     std::string negative_prompt;
     float cfg_scale = 7.0f;
+    int clip_skip = -1;  // <= 0 represents unspecified
     int width = 512;
     int height = 512;
     int batch_count = 1;
@@ -78,6 +80,7 @@ struct SDParams {
     RNGType rng_type = CUDA_RNG;
     int64_t seed = 42;
     bool verbose = false;
+    bool vae_tiling = false;
 };

 void print_params(SDParams params) {
@@ -88,11 +91,13 @@ void print_params(SDParams params) {
     printf("    wtype: %s\n", params.wtype < GGML_TYPE_COUNT ? ggml_type_name(params.wtype) : "unspecified");
     printf("    vae_path: %s\n", params.vae_path.c_str());
     printf("    taesd_path: %s\n", params.taesd_path.c_str());
+    printf("    esrgan_path: %s\n", params.esrgan_path.c_str());
     printf("    output_path: %s\n", params.output_path.c_str());
     printf("    init_img: %s\n", params.input_path.c_str());
     printf("    prompt: %s\n", params.prompt.c_str());
     printf("    negative_prompt: %s\n", params.negative_prompt.c_str());
     printf("    cfg_scale: %.2f\n", params.cfg_scale);
+    printf("    clip_skip: %d\n", params.clip_skip);
     printf("    width: %d\n", params.width);
     printf("    height: %d\n", params.height);
     printf("    sample_method: %s\n", sample_method_str[params.sample_method]);
@@ -102,6 +107,7 @@ void print_params(SDParams params) {
     printf("    rng: %s\n", rng_type_to_str[params.rng_type]);
     printf("    seed: %ld\n", params.seed);
     printf("    batch_count: %d\n", params.batch_count);
+    printf("    vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
 }

 void print_usage(int argc, const char* argv[]) {
@@ -115,6 +121,7 @@ void print_usage(int argc, const char* argv[]) {
     printf("  -m, --model [MODEL]                path to model\n");
     printf("  --vae [VAE]                        path to vae\n");
     printf("  --taesd [TAESD_PATH]               path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
+    printf("  --upscale-model [ESRGAN_PATH]      path to esrgan model. Upscale images after generation; only RealESRGAN_x4plus_anime_6B is supported for now.\n");
     printf("  --type [TYPE]                      weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n");
     printf("                                     If not specified, the default is the type of the weight file.\n");
     printf("  --lora-model-dir [DIR]             lora model directory\n");
@@ -134,6 +141,9 @@ void print_usage(int argc, const char* argv[]) {
     printf("  -s SEED, --seed SEED               RNG seed (default: 42, use random seed for < 0)\n");
     printf("  -b, --batch-count COUNT            number of images to generate.\n");
     printf("  --schedule {discrete, karras}      Denoiser sigma schedule (default: discrete)\n");
+    printf("  --clip-skip N                      ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
+    printf("                                     <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
+    printf("  --vae-tiling                       process vae in tiles to reduce memory usage\n");
     printf("  -v, --verbose                      print extra info\n");
 }
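The help text above spells out the clip-skip semantics: `<= 0` means unspecified and is resolved to 1 for SD1.x or 2 for SD2.x at load time. A hypothetical sketch of that resolution step; the `VERSION_*` enumerator names are assumptions, only `SDVersion` itself appears in this diff (in model.h below):

```cpp
// Resolve an unspecified --clip-skip value to a per-version default.
// SDVersion comes from ModelLoader::get_sd_version(); enumerator names assumed.
enum SDVersion { VERSION_1_x, VERSION_2_x };

int resolve_clip_skip(int clip_skip, SDVersion version) {
    if (clip_skip <= 0) {                       // <= 0 represents "unspecified"
        return version == VERSION_2_x ? 2 : 1;  // SD2.x skips one extra CLIP layer
    }
    return clip_skip;                           // honor an explicit --clip-skip N
}
```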
@@ -185,6 +195,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                 break;
             }
             params.taesd_path = argv[i];
+        } else if (arg == "--upscale-model") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.esrgan_path = argv[i];
         } else if (arg == "--type") {
             if (++i >= argc) {
                 invalid_arg = true;
@@ -270,6 +286,14 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                 break;
             }
             params.sample_steps = std::stoi(argv[i]);
+        } else if (arg == "--clip-skip") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.clip_skip = std::stoi(argv[i]);
+        } else if (arg == "--vae-tiling") {
+            params.vae_tiling = true;
         } else if (arg == "-b" || arg == "--batch-count") {
             if (++i >= argc) {
                 invalid_arg = true;
@@ -458,9 +482,9 @@ int main(int argc, const char* argv[]) {
         }
     }

-    StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, true, params.lora_model_dir, params.rng_type);
+    StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.vae_tiling, params.lora_model_dir, params.rng_type);

-    if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule)) {
+    if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule, params.clip_skip)) {
         return 1;
     }
@@ -488,6 +512,19 @@ int main(int argc, const char* argv[]) {
                               params.seed);
     }

+    if (params.esrgan_path.size() > 0) {
+        // TODO: support more ESRGAN models, making it easier to set up ESRGAN models.
+        /* The scale factor is hardcoded because only RealESRGAN_x4plus_anime_6B is compatible.
+           See also: https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan.py
+           To avoid this, the upscaler needs to be separated from the stable diffusion pipeline.
+           However, a considerable amount of work would be required for this. It might be better
+           to opt for a complete project refactoring that facilitates the easier assignment of parameters.
+        */
+        params.width *= 4;
+        params.height *= 4;
+    }
+
     if (results.size() == 0 || results.size() != params.batch_count) {
         LOG_ERROR("generate failed");
         return 1;
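Because only RealESRGAN_x4plus_anime_6B is compatible, the scale factor is fixed at 4 and the two multiplications above quadruple each output dimension. A quick worked example of the effect:

```cpp
// A 512x512 generation upscaled with the fixed x4 factor is saved as
// 2048x2048, i.e. a 16x larger RGB buffer.
int width = 512, height = 512;
width *= 4;   // 2048
height *= 4;  // 2048
size_t rgb_bytes = (size_t)width * height * 3;  // 12,582,912 bytes (~12 MiB)
```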

ggml (submodule)

@@ -1 +1 @@
-Subproject commit 70474c6890c015b53dc10a2300ae35246cc73589
+Subproject commit a0c2ec77a5ef8e630aff65bc535d13b9805cb929

model.cpp

@@ -14,6 +14,10 @@
 #include "ggml/ggml-backend.h"
 #include "ggml/ggml.h"

+#ifdef SD_USE_METAL
+#include "ggml-metal.h"
+#endif
+
 #define ST_HEADER_SIZE_LEN 8

 uint64_t read_u64(uint8_t* buffer) {
@@ -1197,7 +1201,7 @@ std::string ModelLoader::load_merges() {
     return merges_utf8_str;
 }

-bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
+bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend) {
     bool success = true;
     for (size_t file_index = 0; file_index < file_paths_.size(); file_index++) {
         std::string file_path = file_paths_[file_index];
@@ -1285,11 +1289,13 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
                 continue;
             }

-            ggml_backend_t backend = ggml_get_backend(dst_tensor);
             size_t nbytes_to_read = tensor_storage.nbytes_to_read();

-            if (backend == NULL || ggml_backend_is_cpu(backend)) {
+            if (dst_tensor->buffer == NULL || ggml_backend_is_cpu(backend)
+#ifdef SD_USE_METAL
+                || ggml_backend_is_metal(backend)
+#endif
+            ) {
                 // for the CPU and Metal backend, we can copy directly into the tensor
                 if (tensor_storage.type == dst_tensor->type) {
                     GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes());

model.h

@@ -116,7 +116,7 @@ public:
     SDVersion get_sd_version();
     ggml_type get_sd_wtype();
     std::string load_merges();
-    bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb);
+    bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend);
     int64_t cal_mem_size(ggml_backend_t backend);
     ~ModelLoader() = default;
 };

stable-diffusion.cpp (file diff suppressed because it is too large)
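Most of this PR's 915 added lines live in this suppressed file, including the new `sd_tiling` code path. As a rough, hypothetical illustration of what VAE tiling does (not the actual implementation; all names and values below are assumed), the latent is decoded in overlapping tiles so the decoder's activation memory scales with the tile size rather than the full image:

```cpp
// Hypothetical sketch of tiled VAE decoding. Latent/Image, decode_tile()
// and blend_tile() stand in for the real types, decoder call and seam blending.
void decode_tiled(const Latent& latent, Image& output, int latent_w, int latent_h) {
    const int tile_size = 32;  // tile edge in latent units (value assumed)
    const int overlap   = 8;   // overlapped border, blended to hide seams
    for (int y = 0; y < latent_h; y += tile_size - overlap) {
        for (int x = 0; x < latent_w; x += tile_size - overlap) {
            Image tile = decode_tile(latent, x, y, tile_size);  // VAE decode of one tile
            blend_tile(output, tile, x, y, overlap);            // feather the overlap region
        }
    }
}
```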

stable-diffusion.h

@@ -4,6 +4,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include "ggml/ggml.h"
 #include "ggml/ggml.h"
@@ -41,14 +42,17 @@ public:
     StableDiffusion(int n_threads = -1,
                     bool vae_decode_only = false,
                     std::string taesd_path = "",
+                    std::string esrgan_path = "",
                     bool free_params_immediately = false,
+                    bool vae_tiling = false,
                     std::string lora_model_dir = "",
                     RNGType rng_type = STD_DEFAULT_RNG);
     bool load_from_file(const std::string& model_path,
                         const std::string& vae_path,
                         ggml_type wtype,
-                        Schedule d = DEFAULT);
+                        Schedule d = DEFAULT,
+                        int clip_skip = -1);
     std::vector<uint8_t*> txt2img(
         std::string prompt,
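Putting the header changes together with the main.cpp call sites above, the updated API can be driven as in this sketch. Only parameters visible in this diff are used; the header name and the elided `txt2img` arguments are assumptions:

```cpp
#include "stable-diffusion.h"  // header name assumed

int main() {
    // New parameters from this PR: esrgan_path, vae_tiling, clip_skip.
    StableDiffusion sd(/*n_threads=*/-1,
                       /*vae_decode_only=*/true,
                       /*taesd_path=*/"",
                       /*esrgan_path=*/"models/RealESRGAN_x4plus_anime_6B.pth",
                       /*free_params_immediately=*/true,
                       /*vae_tiling=*/true,
                       /*lora_model_dir=*/"",
                       STD_DEFAULT_RNG);
    if (!sd.load_from_file("models/v1-5-pruned-emaonly.safetensors",
                           /*vae_path=*/"",
                           GGML_TYPE_F16,
                           /*schedule=*/DEFAULT,
                           /*clip_skip=*/-1)) {  // -1: resolve per model version
        return 1;
    }
    // ... call sd.txt2img(prompt, ...) with the remaining sampling arguments,
    //     which are truncated in this diff ...
    return 0;
}
```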