From 31e77e15730ab1adb43426afe5162ee5c430d037 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 3 Sep 2023 16:00:33 +0800 Subject: [PATCH] feat: add SD2.x support (#40) --- README.md | 4 + models/convert.py | 109 ++++++++- stable-diffusion.cpp | 534 ++++++++++++++++++++++++++++++++----------- 3 files changed, 507 insertions(+), 140 deletions(-) diff --git a/README.md b/README.md index af6094c..443d8c5 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - Accelerated memory-efficient CPU inference - Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image - AVX, AVX2 and AVX512 support for x86 architectures +- SD1.x and SD2.x support - Original `txt2img` and `img2img` mode - Negative prompt - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now) @@ -60,10 +61,12 @@ git submodule update - download original weights(.ckpt or .safetensors). For example - Stable Diffusion v1.4 from https://huggingface.co/CompVis/stable-diffusion-v-1-4-original - Stable Diffusion v1.5 from https://huggingface.co/runwayml/stable-diffusion-v1-5 + - Stable Diffusion v2.1 from https://huggingface.co/stabilityai/stable-diffusion-2-1 ```shell curl -L -O https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt # curl -L -O https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors + # curl -L -O https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-nonema-pruned.safetensors ``` - convert weights to ggml model format @@ -182,5 +185,6 @@ docker run -v /path/to/models:/models -v /path/to/output/:/output sd [args...]
- [ggml](https://github.com/ggerganov/ggml) - [stable-diffusion](https://github.com/CompVis/stable-diffusion) +- [stable-diffusion-stability-ai](https://github.com/Stability-AI/stablediffusion) - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) - [k-diffusion](https://github.com/crowsonkb/k-diffusion) diff --git a/models/convert.py b/models/convert.py index 5ad4587..b324e20 100644 --- a/models/convert.py +++ b/models/convert.py @@ -9,6 +9,9 @@ import safetensors.torch this_file_dir = os.path.dirname(__file__) vocab_dir = this_file_dir +SD1 = 0 +SD2 = 1 + ggml_ftype_str_to_int = { "f32": 0, "f16": 1, @@ -155,6 +158,8 @@ unused_tensors = [ "posterior_mean_coef1", "posterior_mean_coef2", "cond_stage_model.transformer.text_model.embeddings.position_ids", + "cond_stage_model.model.logit_scale", + "cond_stage_model.model.text_projection", "model_ema.decay", "model_ema.num_updates", "control_model", @@ -162,12 +167,8 @@ unused_tensors = [ "embedding_manager" ] -def convert(model_path, out_type = None, out_file=None): - # load model - with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f: - clip_vocab = json.load(f) - - state_dict = load_model_from_file(model_path) + +def preprocess(state_dict): alphas_cumprod = state_dict.get("alphas_cumprod") if alphas_cumprod != None: # print((np.abs(get_alpha_comprod().numpy() - alphas_cumprod.numpy()) < 0.000001).all()) @@ -176,11 +177,100 @@ def convert(model_path, out_type = None, out_file=None): print("no alphas_cumprod in file, generate new one") alphas_cumprod = get_alpha_comprod() state_dict["alphas_cumprod"] = alphas_cumprod + + new_state_dict = {} + for name in state_dict.keys(): + # ignore unused tensors + if not isinstance(state_dict[name], torch.Tensor): + continue + skip = False + for unused_tensor in unused_tensors: + if name.startswith(unused_tensor): + skip = True + break + if skip: + continue + + # convert open_clip to hf CLIPTextModel (for SD2.x) + open_clip_to_hf_clip_model = { + "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias", + "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight", + "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight", + "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight", + } + open_clip_to_hk_clip_resblock = { + "attn.out_proj.bias": "self_attn.out_proj.bias", + "attn.out_proj.weight": "self_attn.out_proj.weight", + "ln_1.bias": "layer_norm1.bias", + "ln_1.weight": "layer_norm1.weight", + "ln_2.bias": "layer_norm2.bias", + "ln_2.weight": "layer_norm2.weight", + "mlp.c_fc.bias": "mlp.fc1.bias", + "mlp.c_fc.weight": "mlp.fc1.weight", + "mlp.c_proj.bias": "mlp.fc2.bias", + "mlp.c_proj.weight": "mlp.fc2.weight", + } + open_clip_resblock_prefix = "cond_stage_model.model.transformer.resblocks." + hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers." 
+ if name in open_clip_to_hf_clip_model: + new_name = open_clip_to_hf_clip_model[name] + new_state_dict[new_name] = state_dict[name] + print(f"preprocess {name} => {new_name}") + continue + if name.startswith(open_clip_resblock_prefix): + remain = name[len(open_clip_resblock_prefix):] + idx = remain.split(".")[0] + suffix = remain[len(idx)+1:] + if suffix == "attn.in_proj_weight": + w = state_dict[name] + w_q, w_k, w_v = w.chunk(3) + for new_suffix, new_w in zip(["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], [w_q, w_k, w_v]): + new_name = hf_clip_resblock_prefix + idx + "." + new_suffix + new_state_dict[new_name] = new_w + print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}") + elif suffix == "attn.in_proj_bias": + w = state_dict[name] + w_q, w_k, w_v = w.chunk(3) + for new_suffix, new_w in zip(["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"], [w_q, w_k, w_v]): + new_name = hf_clip_resblock_prefix + idx + "." + new_suffix + new_state_dict[new_name] = new_w + print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}") + else: + new_suffix = open_clip_to_hk_clip_resblock[suffix] + new_name = hf_clip_resblock_prefix + idx + "." + new_suffix + new_state_dict[new_name] = state_dict[name] + print(f"preprocess {name} => {new_name}") + continue + + # convert unet transformer linear to conv2d 1x1 + if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")): + w = state_dict[name] + if len(state_dict[name].shape) == 2: + new_w = w.unsqueeze(2).unsqueeze(3) + new_state_dict[name] = new_w + print(f"preprocess {name} {w.size()} => {name} {new_w.size()}") + continue + new_state_dict[name] = state_dict[name] + return new_state_dict + +def convert(model_path, out_type = None, out_file=None): + # load model + with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f: + clip_vocab = json.load(f) + + state_dict = load_model_from_file(model_path) + model_type = SD1 + if "cond_stage_model.model.token_embedding.weight" in state_dict.keys(): + model_type = SD2 + print("Stable diffuison 2.x") + else: + print("Stable diffuison 1.x") + state_dict = preprocess(state_dict) # output option if out_type == None: - weight = state_dict["cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight"].numpy() + weight = state_dict["model.diffusion_model.input_blocks.0.0.weight"].numpy() if weight.dtype == np.float32: out_type = "f32" elif weight.dtype == np.float16: @@ -198,8 +288,9 @@ def convert(model_path, out_type = None, out_file=None): with open(out_file, "wb") as file: # magic: ggml in hex file.write(struct.pack("i", 0x67676D6C)) - # out type - file.write(struct.pack("i", ggml_ftype_str_to_int[out_type])) + # model & file type + ftype = (model_type << 16) | ggml_ftype_str_to_int[out_type] + file.write(struct.pack("i", ftype)) # vocab byte_encoder = bytes_to_unicode() diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 9d455f0..48dd429 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -48,6 +48,16 @@ static SDLogLevel log_level = SDLogLevel::INFO; #define TIMESTEPS 1000 +enum ModelType { + SD1 = 0, + SD2 = 1, + MODEL_TYPE_COUNT, +}; + +const char* model_type_to_str[] = { + "SD1.x", + "SD2.x"}; + /*================================================== Helper Functions ================================================*/ void set_sd_log_level(SDLogLevel level) { @@ -257,8 +267,8 @@ void image_vec_to_ggml(const std::vector& vec, } 
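Note on the CLIP weight remapping above: `preprocess()` renames the SD2.x open_clip text-encoder tensors to the HF `CLIPTextModel` layout and splits the fused `attn.in_proj_weight` / `attn.in_proj_bias` into separate q/k/v projections. A minimal standalone sketch of that split, using toy dimensions (the SD2.x OpenCLIP text encoder actually uses hidden_size = 1024) and the target names from the mapping tables above:

```python
import torch

hidden_size = 8          # toy size; SD2.x uses 1024
resblock_idx = "0"
hf_prefix = "cond_stage_model.transformer.text_model.encoder.layers."

# fused QKV weight as stored by open_clip: [3 * hidden_size, hidden_size]
in_proj_weight = torch.randn(3 * hidden_size, hidden_size)

# same split preprocess() performs: chunk along dim 0 into q, k, v
w_q, w_k, w_v = in_proj_weight.chunk(3)
for suffix, w in zip(["self_attn.q_proj.weight",
                      "self_attn.k_proj.weight",
                      "self_attn.v_proj.weight"], [w_q, w_k, w_v]):
    print(hf_prefix + resblock_idx + "." + suffix, tuple(w.shape))  # [hidden_size, hidden_size] each
```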
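The header change above packs the model type into the previously unused high 16 bits of the ftype word written after the GGML magic; the loader in stable-diffusion.cpp (further down in this patch) recovers both halves with shifts and masks. A small sketch of that round trip (the helper names are mine):

```python
SD1, SD2 = 0, 1
ggml_ftype_str_to_int = {"f32": 0, "f16": 1}  # subset shown at the top of convert.py

def pack_ftype(model_type: int, out_type: str) -> int:
    # convert.py: model type in the high 16 bits, ggml file type in the low 16 bits
    return (model_type << 16) | ggml_ftype_str_to_int[out_type]

def unpack_ftype(ftype: int):
    # stable-diffusion.cpp: model_type = (ftype >> 16) & 0xFFFF, wtype from ftype & 0xFFFF
    return (ftype >> 16) & 0xFFFF, ftype & 0xFFFF

packed = pack_ftype(SD2, "f16")
assert unpack_ftype(packed) == (SD2, 1)
```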
} -struct ggml_tensor * ggml_group_norm_32(struct ggml_context * ctx, - struct ggml_tensor * a) { +struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx, + struct ggml_tensor* a) { return ggml_group_norm(ctx, a, 32); } @@ -278,6 +288,7 @@ const int PAD_TOKEN_ID = 49407; // TODO: implement bpe class CLIPTokenizer { private: + ModelType model_type = SD1; std::map encoder; std::regex pat; @@ -300,7 +311,8 @@ class CLIPTokenizer { } public: - CLIPTokenizer() = default; + CLIPTokenizer(ModelType model_type = SD1) + : model_type(model_type){}; std::string bpe(std::string token) { std::string word = token + ""; if (encoder.find(word) != encoder.end()) { @@ -321,13 +333,18 @@ class CLIPTokenizer { if (max_length > 0) { if (tokens.size() > max_length - 1) { tokens.resize(max_length - 1); + tokens.push_back(EOS_TOKEN_ID); } else { + tokens.push_back(EOS_TOKEN_ID); if (padding) { - tokens.insert(tokens.end(), max_length - 1 - tokens.size(), PAD_TOKEN_ID); + int pad_token_id = PAD_TOKEN_ID; + if (model_type == SD2) { + pad_token_id = 0; + } + tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id); } } } - tokens.push_back(EOS_TOKEN_ID); return tokens; } @@ -635,7 +652,11 @@ struct ResidualAttentionBlock { x = ggml_mul_mat(ctx, fc1_w, x); x = ggml_add(ctx, ggml_repeat(ctx, fc1_b, x), x); - x = ggml_gelu_quick_inplace(ctx, x); + if (hidden_size == 1024) { // SD 2.x + x = ggml_gelu_inplace(ctx, x); + } else { // SD 1.x + x = ggml_gelu_quick_inplace(ctx, x); + } x = ggml_mul_mat(ctx, fc2_w, x); x = ggml_add(ctx, ggml_repeat(ctx, fc2_b, x), x); @@ -647,26 +668,40 @@ struct ResidualAttentionBlock { } }; +// SD1.x: https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json +// SD2.x: https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/config.json struct CLIPTextModel { + ModelType model_type = SD1; // network hparams int32_t vocab_size = 49408; int32_t max_position_embeddings = 77; - int32_t hidden_size = 768; - int32_t intermediate_size = 3072; - int32_t projection_dim = 768; - int32_t n_head = 12; // num_attention_heads - int32_t num_hidden_layers = 12; + int32_t hidden_size = 768; // 1024 for SD 2.x + int32_t intermediate_size = 3072; // 4096 for SD 2.x + int32_t n_head = 12; // num_attention_heads, 16 for SD 2.x + int32_t num_hidden_layers = 12; // 24 for SD 2.x // embeddings struct ggml_tensor* position_ids; struct ggml_tensor* token_embed_weight; struct ggml_tensor* position_embed_weight; // transformer - ResidualAttentionBlock resblocks[12]; + std::vector resblocks; struct ggml_tensor* final_ln_w; struct ggml_tensor* final_ln_b; - CLIPTextModel() { + CLIPTextModel(ModelType model_type = SD1) + : model_type(model_type) { + if (model_type == SD2) { + hidden_size = 1024; + intermediate_size = 4096; + n_head = 16; + num_hidden_layers = 24; + } + resblocks.resize(num_hidden_layers); + set_resblocks_hp_params(); + } + + void set_resblocks_hp_params() { int d_model = hidden_size / n_head; // 64 for (int i = 0; i < num_hidden_layers; i++) { resblocks[i].d_model = d_model; @@ -729,6 +764,9 @@ struct CLIPTextModel { // transformer for (int i = 0; i < num_hidden_layers; i++) { + if (model_type == SD2 && i == num_hidden_layers - 1) { // layer: "penultimate" + break; + } x = resblocks[i].forward(ctx, x); // [N, n_token, hidden_size] } @@ -759,9 +797,13 @@ struct FrozenCLIPEmbedder { // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283 struct 
FrozenCLIPEmbedderWithCustomWords { + ModelType model_type = SD1; CLIPTokenizer tokenizer; CLIPTextModel text_model; + FrozenCLIPEmbedderWithCustomWords(ModelType model_type = SD1) + : model_type(model_type), tokenizer(model_type), text_model(model_type) {} + std::pair, std::vector> tokenize(std::string text, size_t max_length = 0, bool padding = false) { @@ -793,15 +835,21 @@ struct FrozenCLIPEmbedderWithCustomWords { if (tokens.size() > max_length - 1) { tokens.resize(max_length - 1); weights.resize(max_length - 1); + tokens.push_back(EOS_TOKEN_ID); + weights.push_back(1.0); } else { + tokens.push_back(EOS_TOKEN_ID); + weights.push_back(1.0); if (padding) { - tokens.insert(tokens.end(), max_length - 1 - tokens.size(), PAD_TOKEN_ID); - weights.insert(weights.end(), max_length - 1 - weights.size(), 1.0); + int pad_token_id = PAD_TOKEN_ID; + if (model_type == SD2) { + pad_token_id = 0; + } + tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id); + weights.insert(weights.end(), max_length - weights.size(), 1.0); } } } - tokens.push_back(EOS_TOKEN_ID); - weights.push_back(1.0); // for (int i = 0; i < tokens.size(); i++) { // std::cout << tokens[i] << ":" << weights[i] << ", "; @@ -974,7 +1022,7 @@ struct SpatialTransformer { int n_head; // num_heads int d_head; // in_channels // n_heads int depth = 1; // 1 - int context_dim = 768; // hidden_size + int context_dim = 768; // hidden_size, 1024 for SD2.x // group norm struct ggml_tensor* norm_w; // [in_channels,] @@ -1459,6 +1507,7 @@ struct UNetModel { int time_embed_dim = 1280; // model_channels*4 int num_heads = 8; int num_head_channels = -1; // channels // num_heads + int context_dim = 768; // 1024 for SD2.x // network params struct ggml_tensor* time_embed_0_w; // [time_embed_dim, model_channels] @@ -1493,7 +1542,12 @@ struct UNetModel { struct ggml_tensor* out_2_w; // [out_channels, model_channels, 3, 3] struct ggml_tensor* out_2_b; // [out_channels, ] - UNetModel() { + UNetModel(ModelType model_type = SD1) { + if (model_type == SD2) { + context_dim = 1024; + num_head_channels = 64; + num_heads = -1; + } // set up hparams of blocks // input_blocks @@ -1513,9 +1567,16 @@ struct UNetModel { ch = mult * model_channels; if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) { + int n_head = num_heads; + int d_head = ch / num_heads; + if (num_head_channels != -1) { + d_head = num_head_channels; + n_head = ch / d_head; + } input_transformers[i][j].in_channels = ch; - input_transformers[i][j].n_head = num_heads; - input_transformers[i][j].d_head = ch / num_heads; + input_transformers[i][j].n_head = n_head; + input_transformers[i][j].d_head = d_head; + input_transformers[i][j].context_dim = context_dim; } input_block_chans.push_back(ch); } @@ -1533,9 +1594,16 @@ struct UNetModel { middle_block_0.emb_channels = time_embed_dim; middle_block_0.out_channels = ch; + int n_head = num_heads; + int d_head = ch / num_heads; + if (num_head_channels != -1) { + d_head = num_head_channels; + n_head = ch / d_head; + } middle_block_1.in_channels = ch; - middle_block_1.n_head = num_heads; - middle_block_1.d_head = ch / num_heads; + middle_block_1.n_head = n_head; + middle_block_1.d_head = d_head; + middle_block_1.context_dim = context_dim; middle_block_2.channels = ch; middle_block_2.emb_channels = time_embed_dim; @@ -1555,9 +1623,16 @@ struct UNetModel { ch = mult * model_channels; if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) { + int n_head = 
num_heads; + int d_head = ch / num_heads; + if (num_head_channels != -1) { + d_head = num_head_channels; + n_head = ch / d_head; + } output_transformers[i][j].in_channels = ch; - output_transformers[i][j].n_head = num_heads; - output_transformers[i][j].d_head = ch / num_heads; + output_transformers[i][j].n_head = n_head; + output_transformers[i][j].d_head = d_head; + output_transformers[i][j].context_dim = context_dim; } if (i > 0 && j == num_res_blocks) { @@ -2584,7 +2659,8 @@ struct AutoEncoderKL { /*================================================= CompVisDenoiser ==================================================*/ // Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py -struct CompVisDenoiser { + +struct DiscreteSchedule { float alphas_cumprod[TIMESTEPS]; float sigmas[TIMESTEPS]; float log_sigmas[TIMESTEPS]; @@ -2602,12 +2678,6 @@ struct CompVisDenoiser { return result; } - std::pair get_scalings(float sigma) { - float c_out = -sigma; - float c_in = 1.0f / std::sqrt(sigma * sigma + 1); - return std::pair(c_in, c_out); - } - float sigma_to_t(float sigma) { float log_sigma = std::log(sigma); std::vector dists; @@ -2641,6 +2711,29 @@ struct CompVisDenoiser { float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]; return std::exp(log_sigma); } + + virtual std::vector get_scalings(float sigma) = 0; +}; + +struct CompVisDenoiser : public DiscreteSchedule { + float sigma_data = 1.0f; + + std::vector get_scalings(float sigma) { + float c_out = -sigma; + float c_in = 1.0f / std::sqrt(sigma * sigma + sigma_data * sigma_data); + return {c_out, c_in}; + } +}; + +struct CompVisVDenoiser : public DiscreteSchedule { + float sigma_data = 1.0f; + + std::vector get_scalings(float sigma) { + float c_skip = sigma_data * sigma_data / (sigma * sigma + sigma_data * sigma_data); + float c_out = -sigma * sigma_data / std::sqrt(sigma * sigma + sigma_data * sigma_data); + float c_in = 1.0f / std::sqrt(sigma * sigma + sigma_data * sigma_data); + return {c_skip, c_out, c_in}; + } }; /*=============================================== StableDiffusionGGML ================================================*/ @@ -2666,7 +2759,7 @@ class StableDiffusionGGML { UNetModel diffusion_model; AutoEncoderKL first_stage_model; - CompVisDenoiser denoiser; + std::shared_ptr denoiser = std::make_shared(); StableDiffusionGGML() = default; @@ -2717,9 +2810,20 @@ class StableDiffusionGGML { LOG_DEBUG("loading hparams"); // load hparams file.read(reinterpret_cast(&ftype), sizeof(ftype)); - // for the big tensors, we have the option to store the data in 16-bit floats or quantized - // in order to save memory and also to speed up the computation - ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype)(ftype)); + + int model_type = (ftype >> 16) & 0xFFFF; + if (model_type >= MODEL_TYPE_COUNT) { + LOG_ERROR("invalid model file '%s' (bad model type value %d)", file_path.c_str(), ftype); + return false; + } + LOG_INFO("model type: %s", model_type_to_str[model_type]); + + if (model_type == SD2) { + cond_stage_model = FrozenCLIPEmbedderWithCustomWords((ModelType)model_type); + diffusion_model = UNetModel((ModelType)model_type); + } + + ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype)(ftype & 0xFFFF)); LOG_INFO("ftype: %s", ggml_type_name(wtype)); if (wtype == GGML_TYPE_COUNT) { LOG_ERROR("invalid model file '%s' (bad ftype value %d)", file_path.c_str(), ftype); @@ -2840,6 +2944,7 @@ class StableDiffusionGGML { std::set tensor_names_in_file; int64_t t0 = ggml_time_ms(); // load weights + 
float alphas_cumprod[TIMESTEPS]; { int n_tensors = 0; size_t total_size = 0; @@ -2872,12 +2977,7 @@ class StableDiffusionGGML { tensor_names_in_file.insert(std::string(name.data())); if (std::string(name.data()) == "alphas_cumprod") { - file.read(reinterpret_cast(denoiser.alphas_cumprod), - nelements * ggml_type_size((ggml_type)ttype)); - for (int i = 0; i < 1000; i++) { - denoiser.sigmas[i] = std::sqrt((1 - denoiser.alphas_cumprod[i]) / denoiser.alphas_cumprod[i]); - denoiser.log_sigmas[i] = std::log(denoiser.sigmas[i]); - } + file.read(reinterpret_cast(alphas_cumprod), nelements * ggml_type_size((ggml_type)ttype)); continue; } @@ -2953,9 +3053,143 @@ class StableDiffusionGGML { int64_t t1 = ggml_time_ms(); LOG_INFO("loading model from '%s' completed, taking %.2fs", file_path.c_str(), (t1 - t0) * 1.0f / 1000); file.close(); + + // check is_using_v_parameterization_for_sd2 + bool is_using_v_parameterization = false; + if (model_type == SD2) { + struct ggml_init_params params; + params.mem_size = static_cast(10 * 1024) * 1024; // 10M + params.mem_buffer = NULL; + params.no_alloc = false; + params.dynamic = false; + struct ggml_context* ctx = ggml_init(params); + if (!ctx) { + LOG_ERROR("ggml_init() failed"); + return false; + } + if (is_using_v_parameterization_for_sd2(ctx)) { + is_using_v_parameterization = true; + } + } + + if (is_using_v_parameterization) { + denoiser = std::make_shared(); + LOG_INFO("running in v-prediction mode"); + } else { + LOG_INFO("running in eps-prediction mode"); + } + + for (int i = 0; i < TIMESTEPS; i++) { + denoiser->alphas_cumprod[i] = alphas_cumprod[i]; + denoiser->sigmas[i] = std::sqrt((1 - denoiser->alphas_cumprod[i]) / denoiser->alphas_cumprod[i]); + denoiser->log_sigmas[i] = std::log(denoiser->sigmas[i]); + } + return true; } + bool is_using_v_parameterization_for_sd2(ggml_context* res_ctx) { + struct ggml_tensor* x_t = ggml_new_tensor_4d(res_ctx, GGML_TYPE_F32, 8, 8, 4, 1); + ggml_set_f32(x_t, 0.5); + struct ggml_tensor* c = ggml_new_tensor_4d(res_ctx, GGML_TYPE_F32, 1024, 2, 1, 1); + ggml_set_f32(c, 0.5); + + size_t ctx_size = 1 * 1024 * 1024; // 1MB + // calculate the amount of memory required + { + struct ggml_init_params params; + params.mem_size = ctx_size; + params.mem_buffer = NULL; + params.no_alloc = true; + params.dynamic = dynamic; + + struct ggml_context* ctx = ggml_init(params); + if (!ctx) { + LOG_ERROR("ggml_init() failed"); + return false; + } + + ggml_set_dynamic(ctx, false); + struct ggml_tensor* timesteps = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); // [N, ] + struct ggml_tensor* t_emb = new_timestep_embedding(ctx, timesteps, diffusion_model.model_channels); // [N, model_channels] + ggml_set_dynamic(ctx, params.dynamic); + + struct ggml_tensor* out = diffusion_model.forward(ctx, x_t, NULL, c, t_emb); + ctx_size += ggml_used_mem(ctx) + ggml_used_mem_of_data(ctx); + + struct ggml_cgraph diffusion_graph = ggml_build_forward(out); + struct ggml_cplan cplan = ggml_graph_plan(&diffusion_graph, n_threads); + + ctx_size += cplan.work_size; + LOG_DEBUG("diffusion context need %.2fMB static memory, with work_size needing %.2fMB", + ctx_size * 1.0f / 1024 / 1024, + cplan.work_size * 1.0f / 1024 / 1024); + + ggml_free(ctx); + } + + struct ggml_init_params params; + params.mem_size = ctx_size; + params.mem_buffer = NULL; + params.no_alloc = false; + params.dynamic = dynamic; + + struct ggml_context* ctx = ggml_init(params); + if (!ctx) { + LOG_ERROR("ggml_init() failed"); + return false; + } + + ggml_set_dynamic(ctx, false); + struct ggml_tensor* 
timesteps = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); // [N, ] + struct ggml_tensor* t_emb = new_timestep_embedding(ctx, timesteps, diffusion_model.model_channels); // [N, model_channels] + ggml_set_dynamic(ctx, params.dynamic); + ggml_set_f32(timesteps, 999); + set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels); + + struct ggml_tensor* out = diffusion_model.forward(ctx, x_t, NULL, c, t_emb); + ggml_hold_dynamic_tensor(out); + + struct ggml_cgraph diffusion_graph = ggml_build_forward(out); + struct ggml_cplan cplan = ggml_graph_plan(&diffusion_graph, n_threads); + + ggml_set_dynamic(ctx, false); + struct ggml_tensor* buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size); + ggml_set_dynamic(ctx, params.dynamic); + + cplan.work_data = (uint8_t*)buf->data; + + int64_t t0 = ggml_time_ms(); + ggml_graph_compute(&diffusion_graph, &cplan); + + double result = 0.f; + + { + float* vec_x = (float*)x_t->data; + float* vec_out = (float*)out->data; + + int64_t n = ggml_nelements(out); + + for (int i = 0; i < n; i++) { + result += ((double)vec_out[i] - (double)vec_x[i]); + } + result /= n; + } + +#ifdef GGML_PERF + ggml_graph_print(&diffusion_graph); +#endif + int64_t t1 = ggml_time_ms(); + LOG_INFO("check is_using_v_parameterization_for_sd2 completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); + LOG_DEBUG("diffusion graph use %.2fMB runtime memory: static %.2fMB, dynamic %.2fMB", + (ctx_size + ggml_curr_max_dynamic_size()) * 1.0f / 1024 / 1024, + ctx_size * 1.0f / 1024 / 1024, + ggml_curr_max_dynamic_size() * 1.0f / 1024 / 1024); + LOG_DEBUG("%zu bytes of dynamic memory has not been released yet", ggml_dynamic_size()); + + return result < -1; + } + ggml_tensor* get_learned_condition(ggml_context* res_ctx, const std::string& text) { auto tokens_and_weights = cond_stage_model.tokenize(text, cond_stage_model.text_model.max_position_embeddings, @@ -3093,8 +3327,8 @@ class StableDiffusionGGML { size_t steps = sigmas.size() - 1; // x_t = load_tensor_from_file(res_ctx, "./rand0.bin"); // print_ggml_tensor(x_t); - struct ggml_tensor* x_out = ggml_dup_tensor(res_ctx, x_t); - copy_ggml_tensor(x_out, x_t); + struct ggml_tensor* x = ggml_dup_tensor(res_ctx, x_t); + copy_ggml_tensor(x, x_t); size_t ctx_size = 1 * 1024 * 1024; // 1MB // calculate the amount of memory required @@ -3112,16 +3346,16 @@ class StableDiffusionGGML { } ggml_set_dynamic(ctx, false); - struct ggml_tensor* x = ggml_dup_tensor(ctx, x_t); + struct ggml_tensor* noised_input = ggml_dup_tensor(ctx, x_t); struct ggml_tensor* context = ggml_dup_tensor(ctx, c); struct ggml_tensor* timesteps = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); // [N, ] struct ggml_tensor* t_emb = new_timestep_embedding(ctx, timesteps, diffusion_model.model_channels); // [N, model_channels] ggml_set_dynamic(ctx, params.dynamic); - struct ggml_tensor* eps = diffusion_model.forward(ctx, x, NULL, context, t_emb); + struct ggml_tensor* out = diffusion_model.forward(ctx, noised_input, NULL, context, t_emb); ctx_size += ggml_used_mem(ctx) + ggml_used_mem_of_data(ctx); - struct ggml_cgraph diffusion_graph = ggml_build_forward(eps); + struct ggml_cgraph diffusion_graph = ggml_build_forward(out); struct ggml_cplan cplan = ggml_graph_plan(&diffusion_graph, n_threads); ctx_size += cplan.work_size; @@ -3145,16 +3379,16 @@ class StableDiffusionGGML { } ggml_set_dynamic(ctx, false); - struct ggml_tensor* x = ggml_dup_tensor(ctx, x_t); + struct ggml_tensor* noised_input = ggml_dup_tensor(ctx, x_t); struct ggml_tensor* context = ggml_dup_tensor(ctx, c); struct 
ggml_tensor* timesteps = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); // [N, ] struct ggml_tensor* t_emb = new_timestep_embedding(ctx, timesteps, diffusion_model.model_channels); // [N, model_channels] ggml_set_dynamic(ctx, params.dynamic); - struct ggml_tensor* eps = diffusion_model.forward(ctx, x, NULL, context, t_emb); - ggml_hold_dynamic_tensor(eps); + struct ggml_tensor* out = diffusion_model.forward(ctx, noised_input, NULL, context, t_emb); + ggml_hold_dynamic_tensor(out); - struct ggml_cgraph diffusion_graph = ggml_build_forward(eps); + struct ggml_cgraph diffusion_graph = ggml_build_forward(out); struct ggml_cplan cplan = ggml_graph_plan(&diffusion_graph, n_threads); ggml_set_dynamic(ctx, false); @@ -3163,80 +3397,129 @@ class StableDiffusionGGML { cplan.work_data = (uint8_t*)buf->data; + // x = x * sigmas[0] + { + float* vec = (float*)x->data; + for (int i = 0; i < ggml_nelements(x); i++) { + vec[i] = vec[i] * sigmas[0]; + } + } + + // denoise wrapper + ggml_set_dynamic(ctx, false); + struct ggml_tensor* out_cond = NULL; + struct ggml_tensor* out_uncond = NULL; + if (cfg_scale != 1.0f && uc != NULL) { + out_uncond = ggml_dup_tensor(ctx, x); + } + struct ggml_tensor* denoised = ggml_dup_tensor(ctx, x); + ggml_set_dynamic(ctx, params.dynamic); + + auto denoise = [&](ggml_tensor* input, float sigma, int step) { + int64_t t0 = ggml_time_ms(); + + float c_skip = 1.0f; + float c_out = 1.0f; + float c_in = 1.0f; + std::vector scaling = denoiser->get_scalings(sigma); + if (scaling.size() == 3) { // CompVisVDenoiser + c_skip = scaling[0]; + c_out = scaling[1]; + c_in = scaling[2]; + } else { // CompVisDenoiser + c_out = scaling[0]; + c_in = scaling[1]; + } + + float t = denoiser->sigma_to_t(sigma); + ggml_set_f32(timesteps, t); + set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels); + + copy_ggml_tensor(noised_input, input); + // noised_input = noised_input * c_in + { + float* vec = (float*)noised_input->data; + for (int i = 0; i < ggml_nelements(noised_input); i++) { + vec[i] = vec[i] * c_in; + } + } + + if (cfg_scale != 1.0 && uc != NULL) { + // uncond + copy_ggml_tensor(context, uc); + ggml_graph_compute(&diffusion_graph, &cplan); + copy_ggml_tensor(out_uncond, out); + + // cond + copy_ggml_tensor(context, c); + ggml_graph_compute(&diffusion_graph, &cplan); + + out_cond = out; + + // out_uncond + cfg_scale * (out_cond - out_uncond) + { + float* vec_out = (float*)out->data; + float* vec_out_uncond = (float*)out_uncond->data; + float* vec_out_cond = (float*)out_cond->data; + + for (int i = 0; i < ggml_nelements(out); i++) { + vec_out[i] = vec_out_uncond[i] + cfg_scale * (vec_out_cond[i] - vec_out_uncond[i]); + } + } + } else { + // cond + copy_ggml_tensor(context, c); + ggml_graph_compute(&diffusion_graph, &cplan); + } + + // v = out, eps = out + // denoised = (v * c_out + input * c_skip) or (input + eps * c_out) + { + float* vec_denoised = (float*)denoised->data; + float* vec_input = (float*)input->data; + float* vec_out = (float*)out->data; + + for (int i = 0; i < ggml_nelements(denoised); i++) { + vec_denoised[i] = vec_out[i] * c_out + vec_input[i] * c_skip; + } + } + +#ifdef GGML_PERF + ggml_graph_print(&diffusion_graph); +#endif + int64_t t1 = ggml_time_ms(); + LOG_INFO("step %d sampling completed, taking %.2fs", step, (t1 - t0) * 1.0f / 1000); + LOG_DEBUG("diffusion graph use %.2fMB runtime memory: static %.2fMB, dynamic %.2fMB", + (ctx_size + ggml_curr_max_dynamic_size()) * 1.0f / 1024 / 1024, + ctx_size * 1.0f / 1024 / 1024, + ggml_curr_max_dynamic_size() * 1.0f 
/ 1024 / 1024); + LOG_DEBUG("%zu bytes of dynamic memory has not been released yet", ggml_dynamic_size()); + }; + // sample_euler_ancestral { ggml_set_dynamic(ctx, false); - struct ggml_tensor* eps_cond = NULL; - struct ggml_tensor* eps_uncond = NULL; - struct ggml_tensor* noise = ggml_dup_tensor(ctx, x_out); - if (cfg_scale != 1.0f && uc != NULL) { - eps_uncond = ggml_dup_tensor(ctx, x_out); - } - struct ggml_tensor* d = ggml_dup_tensor(ctx, x_out); + struct ggml_tensor* noise = ggml_dup_tensor(ctx, x); + struct ggml_tensor* d = ggml_dup_tensor(ctx, x); ggml_set_dynamic(ctx, params.dynamic); - // x_out = x_out * sigmas[0] - { - float* vec = (float*)x_out->data; - for (int i = 0; i < ggml_nelements(x_out); i++) { - vec[i] = vec[i] * sigmas[0]; - } - } - for (int i = 0; i < steps; i++) { - int64_t t0 = ggml_time_ms(); + float sigma = sigmas[i]; - copy_ggml_tensor(x, x_out); + // denoise + denoise(x, sigma, i + 1); - std::pair scaling = denoiser.get_scalings(sigmas[i]); - float c_in = scaling.first; - float c_out = scaling.second; - float t = denoiser.sigma_to_t(sigmas[i]); - ggml_set_f32(timesteps, t); - set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels); - - // x = x * c_in + // d = (x - denoised) / sigma { - float* vec = (float*)x->data; - for (int i = 0; i < ggml_nelements(x); i++) { - vec[i] = vec[i] * c_in; + float* vec_d = (float*)d->data; + float* vec_x = (float*)x->data; + float* vec_denoised = (float*)denoised->data; + + for (int i = 0; i < ggml_nelements(d); i++) { + vec_d[i] = (vec_x[i] - vec_denoised[i]) / sigma; } } - /*d = (x - denoised) / sigma - = (-eps_uncond * c_out - cfg_scale * (eps_cond * c_out - eps_uncond * c_out)) / sigma - = eps_uncond + cfg_scale * (eps_cond - eps_uncond)*/ - if (cfg_scale != 1.0 && uc != NULL) { - // uncond - copy_ggml_tensor(context, uc); - ggml_graph_compute(&diffusion_graph, &cplan); - copy_ggml_tensor(eps_uncond, eps); - - // cond - copy_ggml_tensor(context, c); - ggml_graph_compute(&diffusion_graph, &cplan); - - eps_cond = eps; - - /*d = (x - denoised) / sigma - = (-eps_uncond * c_out - cfg_scale * (eps_cond * c_out - eps_uncond * c_out)) / sigma - = eps_uncond + cfg_scale * (eps_cond - eps_uncond)*/ - { - float* vec_d = (float*)d->data; - float* vec_eps_uncond = (float*)eps_uncond->data; - float* vec_eps_cond = (float*)eps_cond->data; - - for (int i = 0; i < ggml_nelements(d); i++) { - vec_d[i] = vec_eps_uncond[i] + cfg_scale * (vec_eps_cond[i] - vec_eps_uncond[i]); - } - } - } else { - // cond - copy_ggml_tensor(context, c); - ggml_graph_compute(&diffusion_graph, &cplan); - copy_ggml_tensor(d, eps); - } - // get_ancestral_step float sigma_up = std::min(sigmas[i + 1], std::sqrt(sigmas[i + 1] * sigmas[i + 1] * (sigmas[i] * sigmas[i] - sigmas[i + 1] * sigmas[i + 1]) / (sigmas[i] * sigmas[i]))); @@ -3247,9 +3530,9 @@ class StableDiffusionGGML { // x = x + d * dt { float* vec_d = (float*)d->data; - float* vec_x = (float*)x_out->data; + float* vec_x = (float*)x->data; - for (int i = 0; i < ggml_nelements(x_out); i++) { + for (int i = 0; i < ggml_nelements(x); i++) { vec_x[i] = vec_x[i] + vec_d[i] * dt; } } @@ -3259,25 +3542,14 @@ class StableDiffusionGGML { ggml_tensor_set_f32_randn(noise); // noise = load_tensor_from_file(res_ctx, "./rand" + std::to_string(i+1) + ".bin"); { - float* vec_x = (float*)x_out->data; + float* vec_x = (float*)x->data; float* vec_noise = (float*)noise->data; - for (int i = 0; i < ggml_nelements(x_out); i++) { + for (int i = 0; i < ggml_nelements(x); i++) { vec_x[i] = vec_x[i] + vec_noise[i] * 
sigma_up; } } } - -#ifdef GGML_PERF - ggml_graph_print(&diffusion_graph); -#endif - int64_t t1 = ggml_time_ms(); - LOG_INFO("step %d sampling completed, taking %.2fs", i + 1, (t1 - t0) * 1.0f / 1000); - LOG_DEBUG("diffusion graph use %.2fMB runtime memory: static %.2fMB, dynamic %.2fMB", - (ctx_size + ggml_curr_max_dynamic_size()) * 1.0f / 1024 / 1024, - ctx_size * 1.0f / 1024 / 1024, - ggml_curr_max_dynamic_size() * 1.0f / 1024 / 1024); - LOG_DEBUG("%zu bytes of dynamic memory has not been released yet", ggml_dynamic_size()); } } @@ -3304,7 +3576,7 @@ class StableDiffusionGGML { ggml_free(ctx); - return x_out; + return x; } ggml_tensor* encode_first_stage(ggml_context* res_ctx, ggml_tensor* x) { @@ -3586,7 +3858,7 @@ std::vector StableDiffusion::txt2img(const std::string& prompt, struct ggml_tensor* x_t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, W, H, C, 1); ggml_tensor_set_f32_randn(x_t); - std::vector sigmas = sd->denoiser.get_sigmas(sample_steps); + std::vector sigmas = sd->denoiser->get_sigmas(sample_steps); LOG_INFO("start sampling"); struct ggml_tensor* x_0 = sd->sample(ctx, x_t, c, uc, cfg_scale, sample_method, sigmas); @@ -3642,7 +3914,7 @@ std::vector StableDiffusion::img2img(const std::vector& init_i } LOG_INFO("img2img %dx%d", width, height); - std::vector sigmas = sd->denoiser.get_sigmas(sample_steps); + std::vector sigmas = sd->denoiser->get_sigmas(sample_steps); size_t t_enc = static_cast(sample_steps * strength); LOG_INFO("target t_enc is %zu steps", t_enc); std::vector sigma_sched;
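A note on the tokenizer changes in this patch: the EOS token is now appended before padding (instead of after), and SD2.x pads with token id 0 rather than PAD_TOKEN_ID (49407). A small Python sketch of the resulting layout (BOS handling, which lives elsewhere in the file, is omitted; the function name is mine):

```python
EOS_TOKEN_ID = 49407
PAD_TOKEN_ID = 49407   # SD1.x pads with the EOS id; SD2.x pads with 0
SD1, SD2 = 0, 1

def finish_tokens(tokens, max_length=77, model_type=SD1, padding=True):
    # mirrors the truncation / EOS / padding logic of CLIPTokenizer::tokenize after this patch
    if len(tokens) > max_length - 1:
        tokens = tokens[:max_length - 1]
    tokens = tokens + [EOS_TOKEN_ID]
    if padding and len(tokens) < max_length:
        pad_id = 0 if model_type == SD2 else PAD_TOKEN_ID
        tokens += [pad_id] * (max_length - len(tokens))
    return tokens

print(finish_tokens([320, 1125], max_length=8, model_type=SD2))
# [320, 1125, 49407, 0, 0, 0, 0, 0]
```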
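The UNetModel changes switch SD2.x from a fixed num_heads = 8 to num_head_channels = 64 (with num_heads = -1), so each transformer block derives its head count from its channel width; the same derivation is repeated for the input, middle and output blocks. A sketch of that arithmetic (the function name is mine; ch = 320 corresponds to the first level, given time_embed_dim = model_channels * 4 = 1280 above):

```python
def transformer_heads(ch, num_heads=8, num_head_channels=-1):
    # UNetModel: if num_head_channels is set, d_head is fixed and n_head follows from ch
    if num_head_channels != -1:
        d_head = num_head_channels
        n_head = ch // d_head
    else:
        n_head = num_heads
        d_head = ch // num_heads
    return n_head, d_head

print(transformer_heads(320))                        # SD1.x: (8, 40)
print(transformer_heads(320, num_head_channels=64))  # SD2.x: (5, 64)
```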
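The denoiser refactor splits the k-diffusion wrapper into a DiscreteSchedule with two sets of scalings: CompVisDenoiser for eps-prediction and CompVisVDenoiser for v-prediction, selected by is_using_v_parameterization_for_sd2(), which runs the UNet once at t = 999 on a constant latent and treats mean(out - x) < -1 as v-prediction. A Python restatement of the scaling math and of how the denoise wrapper combines it (sigma_data is 1.0 in both structs; the function names are mine):

```python
import math

SIGMA_DATA = 1.0

def eps_scalings(sigma):
    # CompVisDenoiser::get_scalings -> {c_out, c_in}
    return -sigma, 1.0 / math.sqrt(sigma**2 + SIGMA_DATA**2)

def v_scalings(sigma):
    # CompVisVDenoiser::get_scalings -> {c_skip, c_out, c_in}
    c_skip = SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)
    c_out = -sigma * SIGMA_DATA / math.sqrt(sigma**2 + SIGMA_DATA**2)
    c_in = 1.0 / math.sqrt(sigma**2 + SIGMA_DATA**2)
    return c_skip, c_out, c_in

def denoised(model_out, model_in, sigma, v_prediction):
    # the denoise wrapper: the UNet sees model_in * c_in, then
    # denoised = out * c_out + input * c_skip (c_skip stays 1.0 for eps models)
    if v_prediction:
        c_skip, c_out, _c_in = v_scalings(sigma)
    else:
        c_out, _c_in = eps_scalings(sigma)
        c_skip = 1.0
    return model_out * c_out + model_in * c_skip
```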
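Finally, the sampler loop is rebuilt around the denoise wrapper, so each Euler-ancestral step only needs the current latent, the denoised estimate, and the two neighbouring sigmas. A compact sketch of one step (the sigma_down line sits outside the visible hunks, so it is written here as the standard k-diffusion get_ancestral_step; the function name is mine and x / denoised are ndarrays):

```python
import numpy as np

def euler_ancestral_step(x, denoised, sigma, sigma_next, rng=np.random):
    # get_ancestral_step
    sigma_up = min(sigma_next,
                   np.sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2))
    sigma_down = np.sqrt(sigma_next**2 - sigma_up**2)

    d = (x - denoised) / sigma          # derivative, as in the loop above
    dt = sigma_down - sigma
    x = x + d * dt                      # Euler step towards sigma_down
    # ancestral noise; sigma_up is 0 on the final step (sigma_next == 0), so nothing is added there
    return x + rng.standard_normal(x.shape) * sigma_up
```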