diff --git a/README.md b/README.md
index 96b9e73..5b19fc9 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,8 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - Original `txt2img` and `img2img` mode
 - Negative prompt
 - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
+- LoRA support, same as [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora)
+- Latent Consistency Models support (LCM/LCM-LoRA)
 - Sampling method
     - `Euler A`
     - `Euler`
@@ -42,7 +44,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - [ ] Make inference faster
     - The current implementation of ggml_conv_2d is slow and has high memory usage
 - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
-- [ ] LoRA support
 - [ ] k-quants support
 
 ## Usage
@@ -125,6 +126,7 @@ arguments:
   -t, --threads N                    number of threads to use during computation (default: -1).
                                      If threads <= 0, then threads will be set to the number of CPU physical cores
   -m, --model [MODEL]                path to model
+  --lora-model-dir [DIR]             lora model directory
   -i, --init-img [IMAGE]             path to the input image, required by img2img
   -o, --output OUTPUT                path to write result image to (default: .\output.png)
   -p, --prompt [PROMPT]              the prompt to render
@@ -134,11 +136,12 @@ arguments:
                                      1.0 corresponds to full destruction of information in init image
   -H, --height H                     image height, in pixel space (default: 512)
   -W, --width W                      image width, in pixel space (default: 512)
-  --sampling-method {euler, euler_a, heun, dpm++2m, dpm++2mv2, lcm}
+  --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, lcm}
                                      sampling method (default: "euler_a")
   --steps STEPS                      number of sample steps (default: 20)
   --rng {std_default, cuda}          RNG (default: cuda)
   -s SEED, --seed SEED               RNG seed (default: 42, use random seed for < 0)
+  --schedule {discrete, karras}      Denoiser sigma schedule (default: discrete)
   -v, --verbose                      print extra info
 ```
@@ -167,6 +170,45 @@
 Using formats of different precisions will yield results of varying quality.

+#### with LoRA
+
+- convert lora weights to ggml model format
+
+  ```shell
+  cd models
+  python convert.py [path to weights] --lora
+  # For example, python convert.py marblesh.safetensors
+  ```
+
+- You can specify the directory where the lora weights are stored via `--lora-model-dir`. If not specified, the default is the current working directory.
+
+- LoRA is specified via prompt, just like [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora).
+
+Here's a simple example:
+
+```
+./bin/sd -m ../models/v1-5-pruned-emaonly-ggml-model-f16.bin -p "a lovely cat<lora:marblesh:1>" --lora-model-dir ../models
+```
+
+`../models/marblesh-ggml-lora.bin` will be applied to the model
+
+#### LCM/LCM-LoRA
+
+- Download LCM-LoRA from https://huggingface.co/latent-consistency/lcm-lora-sdv1-5
+- Specify LCM-LoRA by adding `<lora:lcm-lora-sdv1-5:1>` to prompt
+- It's advisable to set `--cfg-scale` to `1.0` instead of the default `7.0`. For `--steps`, a range of `2-8` steps is recommended. For `--sampling-method`, `lcm`/`euler_a` is recommended.
+
+Here's a simple example:
+
+```
+./bin/sd -m ../models/v1-5-pruned-emaonly-ggml-model-f16.bin -p "a lovely cat<lora:lcm-lora-sdv1-5:1>" --steps 4 --lora-model-dir ../models -v --cfg-scale 1
+```
+
+| without LCM-LoRA (--cfg-scale 7) | with LCM-LoRA (--cfg-scale 1) |
+| ---- | ---- |
+| ![](./assets/without_lcm.png) | ![](./assets/with_lcm.png) |
+
 
 ### Docker
 
 #### Building using Docker
diff --git a/assets/with_lcm.png b/assets/with_lcm.png
new file mode 100644
index 0000000..70e2c70
Binary files /dev/null and b/assets/with_lcm.png differ
diff --git a/assets/without_lcm.png b/assets/without_lcm.png
new file mode 100644
index 0000000..145ab94
Binary files /dev/null and b/assets/without_lcm.png differ
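Taken together, the README hunks above amount to a two-step workflow; here is a minimal end-to-end sketch, reusing the README's own `marblesh` example name (any SD1.x LoRA works the same way):

```shell
# 1. Convert the LoRA to ggml format; convert.py writes the output
#    (marblesh-ggml-lora.bin) to the current working directory.
cd models
python convert.py marblesh.safetensors --lora
cd ..

# 2. Reference the LoRA from the prompt; <lora:marblesh:1> makes sd look for
#    marblesh-ggml-lora.bin inside --lora-model-dir and apply it at weight 1.
./bin/sd -m models/v1-5-pruned-emaonly-ggml-model-f16.bin \
         -p "a lovely cat<lora:marblesh:1>" \
         --lora-model-dir models
```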
"--init-img") { if (++i >= argc) { invalid_arg = true; @@ -419,7 +428,7 @@ int main(int argc, const char* argv[]) { init_img.assign(img_data, img_data + (opt.w * opt.h * c)); } - StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type); + StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.lora_model_dir, opt.rng_type); if (!sd.load_from_file(opt.model_path, opt.schedule)) { return 1; } diff --git a/models/convert.py b/models/convert.py index 8ef2fcc..503b10d 100644 --- a/models/convert.py +++ b/models/convert.py @@ -4,6 +4,7 @@ import os import numpy as np import torch +import re import safetensors.torch this_file_dir = os.path.dirname(__file__) @@ -270,21 +271,107 @@ def preprocess(state_dict): new_state_dict[name] = w return new_state_dict -def convert(model_path, out_type = None, out_file=None): +re_digits = re.compile(r"\d+") +re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_compiled = {} + +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "norm1": "in_layers_0", + "norm2": "out_layers_0", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + } +} + + +def convert_diffusers_name_to_compvis(key): + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, r"lora_unet_conv_in(.*)"): + return f'model_diffusion_model_input_blocks_0_0{m[0]}' + + if match(m, r"lora_unet_conv_out(.*)"): + return f'model_diffusion_model_out_2{m[0]}' + + if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"): + return f"model_diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}" + + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"model_diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return f"model_diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" + + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"model_diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"model_diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" + + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"model_diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" + + if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): + return f"cond_stage_model_transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + return None + +def preprocess_lora(state_dict): + new_state_dict = {} + for name, w in state_dict.items(): + if not isinstance(w, torch.Tensor): + continue + name_without_network_parts, network_part = name.split(".", 1) + new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts) + if new_name_without_network_parts == None: + raise Exception(f"unknown lora tensor: {name}") + new_name = new_name_without_network_parts + "." 
diff --git a/models/convert.py b/models/convert.py
index 8ef2fcc..503b10d 100644
--- a/models/convert.py
+++ b/models/convert.py
@@ -4,6 +4,7 @@ import os
 
 import numpy as np
 import torch
+import re
 import safetensors.torch
 
 this_file_dir = os.path.dirname(__file__)
@@ -270,21 +271,107 @@ def preprocess(state_dict):
         new_state_dict[name] = w
     return new_state_dict
 
-def convert(model_path, out_type = None, out_file=None):
+re_digits = re.compile(r"\d+")
+re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_compiled = {}
+
+suffix_conversion = {
+    "attentions": {},
+    "resnets": {
+        "conv1": "in_layers_2",
+        "conv2": "out_layers_3",
+        "norm1": "in_layers_0",
+        "norm2": "out_layers_0",
+        "time_emb_proj": "emb_layers_1",
+        "conv_shortcut": "skip_connection",
+    }
+}
+
+
+def convert_diffusers_name_to_compvis(key):
+    def match(match_list, regex_text):
+        regex = re_compiled.get(regex_text)
+        if regex is None:
+            regex = re.compile(regex_text)
+            re_compiled[regex_text] = regex
+
+        r = re.match(regex, key)
+        if not r:
+            return False
+
+        match_list.clear()
+        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
+        return True
+
+    m = []
+
+    if match(m, r"lora_unet_conv_in(.*)"):
+        return f'model_diffusion_model_input_blocks_0_0{m[0]}'
+
+    if match(m, r"lora_unet_conv_out(.*)"):
+        return f'model_diffusion_model_out_2{m[0]}'
+
+    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
+        return f"model_diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"
+
+    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+        return f"model_diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
+        return f"model_diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
+
+    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+        return f"model_diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
+        return f"model_diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
+
+    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
+        return f"model_diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0] > 0 else 1}_conv"
+
+    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
+        return f"cond_stage_model_transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
+
+    return None
+
+def preprocess_lora(state_dict):
+    new_state_dict = {}
+    for name, w in state_dict.items():
+        if not isinstance(w, torch.Tensor):
+            continue
+        name_without_network_parts, network_part = name.split(".", 1)
+        new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
+        if new_name_without_network_parts == None:
+            raise Exception(f"unknown lora tensor: {name}")
+        new_name = new_name_without_network_parts + "." + network_part
+        print(f"preprocess {name} => {new_name}")
+        new_state_dict[new_name] = w
+    return new_state_dict
+
+def convert(model_path, out_type = None, out_file=None, lora=False):
     # load model
-    with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
-        clip_vocab = json.load(f)
-
+    if not lora:
+        with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
+            clip_vocab = json.load(f)
+
     state_dict = load_model_from_file(model_path)
-    model_type = SD1
-    if "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
+    model_type = SD1  # lora only for SD1 now
+    if not lora and "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
         model_type = SD2
         print("Stable diffuison 2.x")
     else:
         print("Stable diffuison 1.x")
 
-    state_dict = preprocess(state_dict)
+    if lora:
+        state_dict = preprocess_lora(state_dict)
+    else:
+        state_dict = preprocess(state_dict)
 
     # output option
+    if lora:
+        out_type = "f16"  # only f16 for now
     if out_type == None:
         weight = state_dict["model.diffusion_model.input_blocks.0.0.weight"].numpy()
         if weight.dtype == np.float32:
@@ -296,7 +383,10 @@
         else:
             raise Exception("unsupported weight type %s" % weight.dtype)
     if out_file == None:
-        out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
+        if lora:
+            out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-lora.bin"
+        else:
+            out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
     out_file = os.path.join(os.getcwd(), out_file)
 
     print(f"Saving GGML compatible file to {out_file}")
@@ -309,14 +399,15 @@
         file.write(struct.pack("i", ftype))
 
     # vocab
-    byte_encoder = bytes_to_unicode()
-    byte_decoder = {v: k for k, v in byte_encoder.items()}
-    file.write(struct.pack("i", len(clip_vocab)))
-    for key in clip_vocab:
-        text = bytearray([byte_decoder[c] for c in key])
-        file.write(struct.pack("i", len(text)))
-        file.write(text)
-
+    if not lora:
+        byte_encoder = bytes_to_unicode()
+        byte_decoder = {v: k for k, v in byte_encoder.items()}
+        file.write(struct.pack("i", len(clip_vocab)))
+        for key in clip_vocab:
+            text = bytearray([byte_decoder[c] for c in key])
+            file.write(struct.pack("i", len(text)))
+            file.write(text)
+
     # weights
     for name in state_dict.keys():
         if not isinstance(state_dict[name], torch.Tensor):
@@ -337,7 +428,7 @@
         old_type = data.dtype
 
         ttype = "f32"
-        if n_dims == 4:
+        if n_dims == 4 and not lora:
             data = data.astype(np.float16)
             ttype = "f16"
         elif n_dims == 2 and name[-7:] == ".weight":
@@ -380,6 +471,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Stable Diffuison model to GGML compatible file format")
     parser.add_argument("--out_type", choices=["f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"], help="output format (default: based on input)")
     parser.add_argument("--out_file", help="path to write to; default: based on input and current working directory")
+    parser.add_argument("--lora", action='store_true', default = False, help="convert lora weight; default: false")
     parser.add_argument("model_path", help="model file path (*.pth, *.pt, *.ckpt, *.safetensors)")
     args = parser.parse_args()
-    convert(args.model_path, args.out_type, args.out_file)
+    convert(args.model_path, args.out_type, args.out_file, args.lora)
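The `convert_diffusers_name_to_compvis` hunk above is the subtlest part of the converter, and it is easiest to review with concrete keys. A small sanity-check sketch (assumes it is run from `models/` so this PR's `convert.py` is importable, with the converter's dependencies installed):

```python
# Map a few known diffusers-style LoRA key prefixes to CompVis names and
# print the result; expected outputs follow from the regexes above.
from convert import convert_diffusers_name_to_compvis

for key in [
    "lora_unet_down_blocks_0_attentions_0_proj_in",
    "lora_unet_up_blocks_3_resnets_2_conv1",
    "lora_te_text_model_encoder_layers_0_self_attn_q_proj",
]:
    print(key, "=>", convert_diffusers_name_to_compvis(key))

# expected:
#   ... => model_diffusion_model_input_blocks_1_1_proj_in
#   ... => model_diffusion_model_output_blocks_11_0_in_layers_2
#   ... => cond_stage_model_transformer_text_model_encoder_layers_0_self_attn_q_proj
```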
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 0a461b8..8339501 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -268,6 +268,45 @@ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
     return ggml_group_norm(ctx, a, 32);
 }
 
+std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
+    std::regex re("<lora:([^:]+):([^>]+)>");
+    std::smatch matches;
+    std::unordered_map<std::string, float> filename2multiplier;
+
+    while (std::regex_search(text, matches, re)) {
+        std::string filename = matches[1].str();
+        float multiplier = std::stof(matches[2].str());
+        if (multiplier < 0.f) {
+            continue;
+        }
+        if (filename2multiplier.find(filename) == filename2multiplier.end()) {
+            filename2multiplier[filename] = multiplier;
+        } else {
+            filename2multiplier[filename] += multiplier;
+        }
+
+        text = std::regex_replace(text, re, "");
+    }
+
+    return std::make_pair(filename2multiplier, text);
+}
+
+bool ends_with(const std::string& str, const std::string& ending) {
+    if (str.length() >= ending.length()) {
+        return (str.compare(str.length() - ending.length(), ending.length(), ending) == 0);
+    } else {
+        return false;
+    }
+}
+
+void replace_all_chars(std::string& str, char target, char replacement) {
+    for (size_t i = 0; i < str.length(); ++i) {
+        if (str[i] == target) {
+            str[i] = replacement;
+        }
+    }
+}
+
 /*================================================== CLIPTokenizer ===================================================*/
 
 const std::string UNK_TOKEN = "<|endoftext|>";
@@ -2794,6 +2833,16 @@ class StableDiffusionGGML {
     UNetModel diffusion_model;
     AutoEncoderKL first_stage_model;
 
+    std::map<std::string, struct ggml_tensor*> tensors;
+
+    std::string lora_model_dir;
+    // lora_name => lora_tensor_name => tensor
+    std::map<std::string, std::map<std::string, struct ggml_tensor*>> lora_tensors;
+    // lora_name => lora_params_ctx
+    std::map<std::string, struct ggml_context*> lora_params_ctxs;
+    // lora_name => multiplier
+    std::unordered_map<std::string, float> curr_lora_state;
+
     std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();
 
     StableDiffusionGGML() = default;
@@ -2801,16 +2850,23 @@ class StableDiffusionGGML {
     StableDiffusionGGML(int n_threads,
                         bool vae_decode_only,
                         bool free_params_immediately,
+                        std::string lora_model_dir,
                         RNGType rng_type)
         : n_threads(n_threads),
           vae_decode_only(vae_decode_only),
-          free_params_immediately(free_params_immediately) {
+          free_params_immediately(free_params_immediately),
+          lora_model_dir(lora_model_dir) {
         first_stage_model.decode_only = vae_decode_only;
         if (rng_type == STD_DEFAULT_RNG) {
            rng = std::make_shared<STDDefaultRNG>();
         } else if (rng_type == CUDA_RNG) {
            rng = std::make_shared<PhiloxRNG>();
         }
+        if (lora_model_dir.size() > 0) {
+            if (lora_model_dir[lora_model_dir.size() - 1] != '/' && lora_model_dir[lora_model_dir.size() - 1] != '\\') {
+                this->lora_model_dir = lora_model_dir + "/";
+            }
+        }
     }
 
     ~StableDiffusionGGML() {
@@ -2826,6 +2882,13 @@ class StableDiffusionGGML {
             ggml_free(vae_params_ctx);
             vae_params_ctx = NULL;
         }
+        for (auto& kv : lora_params_ctxs) {
+            ggml_free(kv.second);
+        }
+        lora_params_ctxs.clear();
+
+        tensors.clear();
+        lora_tensors.clear();
     }
 
     bool load_from_file(const std::string& file_path, Schedule schedule) {
@@ -2963,8 +3026,6 @@ class StableDiffusionGGML {
             }
         }
 
-        std::map<std::string, struct ggml_tensor*> tensors;
-
         LOG_DEBUG("preparing memory for the weights");
         // prepare memory for the weights
         {
@@ -3255,6 +3316,306 @@ class StableDiffusionGGML {
         return result < -1;
     }
 
+    bool load_lora_from_file(const std::string& lora_name) {
+        if (lora_tensors.find(lora_name) != lora_tensors.end()) {
+            return true;
+        }
+        std::string file_path = lora_model_dir + lora_name + "-ggml-lora.bin";
+        LOG_INFO("loading lora '%s' from '%s'", lora_name.c_str(), file_path.c_str());
+
+        std::ifstream file(file_path, std::ios::binary);
+        if (!file.is_open()) {
+            LOG_ERROR("failed to open '%s'", file_path.c_str());
+            return false;
+        }
+
+        // get file size
+        file.seekg(0, file.end);
+        int file_size = (int)file.tellg();
+        file.seekg(0, file.beg);
+
+        LOG_DEBUG("'%s': %.2fMB", file_path.c_str(), file_size * 1.f / 1024 / 1024);
+
+        LOG_DEBUG("verifying magic");
+        // verify magic
+        {
+            uint32_t magic;
+            file.read(reinterpret_cast<char*>(&magic), sizeof(magic));
+            if (magic != GGML_FILE_MAGIC) {
+                LOG_ERROR("invalid model file '%s' (bad magic)", file_path.c_str());
+                return false;
+            }
+        }
+
+        LOG_DEBUG("loading hparams");
+        // load hparams
+        int32_t ftype;
+        file.read(reinterpret_cast<char*>(&ftype), sizeof(ftype));
+
+        int model_type = (ftype >> 16) & 0xFFFF;
+        if (model_type >= MODEL_TYPE_COUNT) {
+            LOG_ERROR("invalid model file '%s' (bad model type value %d)", file_path.c_str(), ftype);
+            return false;
+        }
+        LOG_INFO("lora model type: %s", model_type_to_str[model_type]);
+
+        ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype)(ftype & 0xFFFF));
+        LOG_INFO("ftype: %s", ggml_type_name(wtype));
+        if (wtype == GGML_TYPE_COUNT) {
+            LOG_ERROR("invalid model file '%s' (bad ftype value %d)", file_path.c_str(), ftype);
+            return false;
+        }
+
+        // create the ggml context for network params
+        struct ggml_init_params params;
+        size_t ctx_size = 10 * 1024 * 1024;  // 10 MB, for padding
+        ctx_size += file_size;
+        params.mem_size = ctx_size;
+        params.mem_buffer = NULL;
+        params.no_alloc = false;
+        params.dynamic = false;
+        LOG_DEBUG("lora '%s' params ctx size = % 6.2f MB", lora_name.c_str(), ctx_size / (1024.0 * 1024.0));
+        ggml_context* lora_params_ctx = ggml_init(params);
+        if (!lora_params_ctx) {
+            LOG_ERROR("ggml_init() failed");
+            return false;
+        }
+        lora_params_ctxs[lora_name] = lora_params_ctx;
+
+        std::map<std::string, struct ggml_tensor*> lora_tensor_map;
+        int64_t t0 = ggml_time_ms();
+        // load weights
+        {
+            int n_tensors = 0;
+            size_t total_size = 0;
+
+            while (true) {
+                int32_t n_dims;
+                int32_t length;
+                int32_t ttype;
+
+                file.read(reinterpret_cast<char*>(&n_dims), sizeof(n_dims));
+                file.read(reinterpret_cast<char*>(&length), sizeof(length));
+                file.read(reinterpret_cast<char*>(&ttype), sizeof(ttype));
+
+                if (file.eof()) {
+                    break;
+                }
+
+                int32_t nelements = 1;
+                int32_t ne[4] = {1, 1, 1, 1};
+                for (int i = 0; i < n_dims; ++i) {
+                    file.read(reinterpret_cast<char*>(&ne[i]), sizeof(ne[i]));
+                    nelements *= ne[i];
+                }
+
+                const size_t num_bytes = nelements / ggml_blck_size(ggml_type(ttype)) * ggml_type_size(ggml_type(ttype));
+
+                std::string name_buf(length, 0);
+                file.read(&name_buf[0], length);
+                std::string name = std::string(name_buf.data());
+
+                // LOG_DEBUG("load lora tensor %s", name.c_str());
+
+                int64_t ne64[4] = {ne[0], ne[1], ne[2], ne[3]};
+                struct ggml_tensor* tensor = ggml_new_tensor(lora_params_ctx, (ggml_type)ttype, n_dims, ne64);
+                file.read(reinterpret_cast<char*>(tensor->data), num_bytes);
+
+                lora_tensor_map[name] = tensor;
+
+                total_size += ggml_nbytes(tensor);
+            }
+        }
+        lora_tensors[lora_name] = lora_tensor_map;
+        int64_t t1 = ggml_time_ms();
+        LOG_INFO("lora '%s' params size = %.2fMB",
+                 lora_name.c_str(),
+                 ggml_used_mem(lora_params_ctx) / 1024.0 / 1024.0);
+        LOG_INFO("loading lora from '%s' completed, taking %.2fs", file_path.c_str(), (t1 - t0) * 1.0f / 1000);
+        file.close();
+        return true;
+    }
+
+    void remove_lora_params(const std::string& lora_name) {
+        if (lora_params_ctxs.find(lora_name) == lora_params_ctxs.end()) {
+            return;
+        }
+        ggml_free(lora_params_ctxs[lora_name]);
+        lora_params_ctxs.erase(lora_name);
+        lora_tensors.erase(lora_name);
+    }
+
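The `apply_lora` function that follows merges each LoRA pair directly into the base weights. In matrix form, with $U$ = `lora_up` and $D$ = `lora_down` (both flattened to 2-D), $r$ the rank taken from `lora_down`'s trailing dimension, and $m$ the multiplier parsed from the prompt:

```latex
W' = W + m \cdot s \cdot (U D), \qquad
s =
\begin{cases}
  \text{scale} & \text{if a \texttt{.scale} tensor exists} \\
  \alpha / r   & \text{else if an \texttt{.alpha} tensor exists} \\
  1            & \text{otherwise.}
\end{cases}
```

Because the update is added in place, `apply_loras` later undoes or adjusts a LoRA by applying it again with the (possibly negative) multiplier difference.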
+    void apply_lora(const std::string& lora_name, float multiplier) {
+        int64_t t0 = ggml_time_ms();
+        if (!load_lora_from_file(lora_name)) {
+            std::string file_path = lora_model_dir + lora_name + "-ggml-lora.bin";
+            LOG_WARN("apply lora '%s' failed", lora_name.c_str());
+            return;
+        }
+
+        size_t ctx_size = 500 * 1024 * 1024;  // 500MB
+        void* mem_buffer = malloc(ctx_size);
+        if (!mem_buffer) {
+            if (free_params_immediately) {
+                remove_lora_params(lora_name);
+            }
+            LOG_ERROR("malloc() failed");
+            return;
+        }
+
+        std::map<std::string, struct ggml_tensor*>& lora_tensor_map = lora_tensors[lora_name];
+        std::set<std::string> applied_lora_tensors;
+        for (auto& kv : tensors) {
+            const std::string name = kv.first;
+            ggml_tensor* weight = kv.second;
+            std::string ending = ".weight";
+            if (!ends_with(name, ending)) {
+                continue;
+            }
+
+            // find corresponding lora tensors
+            std::string network_name = name.substr(0, name.size() - ending.size());  // remove .weight
+            replace_all_chars(network_name, '.', '_');
+            std::string lora_up_name = network_name + ".lora_up.weight";
+            std::string lora_down_name = network_name + ".lora_down.weight";
+            std::string alpha_name = network_name + ".alpha";
+            std::string scale_name = network_name + ".scale";
+
+            ggml_tensor* lora_up = NULL;
+            ggml_tensor* lora_down = NULL;
+
+            float scale = 1.0f;
+
+            if (lora_tensor_map.find(lora_up_name) != lora_tensor_map.end()) {
+                lora_up = lora_tensor_map[lora_up_name];
+            }
+
+            if (lora_tensor_map.find(lora_down_name) != lora_tensor_map.end()) {
+                lora_down = lora_tensor_map[lora_down_name];
+            }
+
+            if (lora_up == NULL || lora_down == NULL) {
+                continue;
+            }
+
+            // LOG_DEBUG("apply lora tensor %s [%ld %ld %ld %ld]", network_name.c_str(), weight->ne[0], weight->ne[1], weight->ne[2], weight->ne[3]);
+
+            applied_lora_tensors.insert(lora_up_name);
+            applied_lora_tensors.insert(lora_down_name);
+            applied_lora_tensors.insert(alpha_name);
+            applied_lora_tensors.insert(scale_name);
+
+            // calc_scale
+            int64_t dim = lora_down->ne[lora_down->n_dims - 1];
+            if (lora_tensor_map.find(scale_name) != lora_tensor_map.end()) {
+                ggml_tensor* t = lora_tensor_map[scale_name];
+                scale = ggml_get_f32_1d(t, 0);
+            } else if (lora_tensor_map.find(alpha_name) != lora_tensor_map.end()) {
+                ggml_tensor* t = lora_tensor_map[alpha_name];
+                scale = ggml_get_f32_1d(t, 0) / dim;
+            }
+
+            // LOG_DEBUG("scale: %f %ld", scale, dim);
+
+            scale = scale * multiplier;
+
+            // apply
+            {
+                struct ggml_init_params params;
+                params.mem_size = ctx_size;
+                params.mem_buffer = mem_buffer;
+                params.no_alloc = false;
+                params.dynamic = false;
+
+                struct ggml_context* ctx = ggml_init(params);
+                if (!ctx) {
+                    LOG_ERROR("ggml_init() failed");
+                    free(mem_buffer);
+                    if (free_params_immediately) {
+                        remove_lora_params(lora_name);
+                    }
+                    return;
+                }
+
+                ggml_tensor* scale_factor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+                ggml_set_f32_1d(scale_factor, 0, scale);
+                int64_t lora_up_size_0 = lora_up->ne[lora_up->n_dims - 1];
+                lora_up = ggml_reshape_2d(ctx, lora_up, ggml_nelements(lora_up) / lora_up_size_0, lora_up_size_0);
+                int64_t lora_down_size_0 = lora_down->ne[lora_down->n_dims - 1];
+                lora_down = ggml_reshape_2d(ctx, lora_down, ggml_nelements(lora_down) / lora_down_size_0, lora_down_size_0);
+
+                lora_down = ggml_cont(ctx, ggml_transpose(ctx, lora_down));
+
+                if (lora_down->type != GGML_TYPE_F32) {
+                    ggml_tensor* lora_down_f32 = ggml_new_tensor(ctx, GGML_TYPE_F32, lora_down->n_dims, lora_down->ne);
+                    lora_down = ggml_cpy_inplace(ctx, lora_down, lora_down_f32);
+                }
+
+                ggml_tensor* updown = ggml_mul_mat(ctx, lora_up, lora_down);
+                updown = ggml_cont(ctx, ggml_transpose(ctx, updown));
+                updown = ggml_reshape(ctx, updown, weight);
+
+                GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
+
+                updown = ggml_scale_inplace(ctx, updown, scale_factor);
+                ggml_tensor* final_weight;
+                final_weight = ggml_add_inplace(ctx, weight, updown);
+                final_weight = ggml_cpy_inplace(ctx, final_weight, weight);
+
+                struct ggml_cgraph* graph = ggml_build_forward_ctx(ctx, final_weight);
+
+                ggml_graph_compute_with_ctx(ctx, graph, n_threads);
+
+                // LOG_INFO("network_name '%s' ggml_used_mem size = %.2fMB",
+                //          network_name.c_str(),
+                //          ggml_used_mem(ctx) / 1024.0 / 1024.0);
+
+                ggml_free(ctx);
+            }
+        }
+        free(mem_buffer);
+
+        for (auto& kv : lora_tensor_map) {
+            if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) {
+                LOG_WARN("unused lora tensor %s", kv.first.c_str());
+            }
+        }
+
+        if (free_params_immediately) {
+            remove_lora_params(lora_name);
+        }
+
+        int64_t t1 = ggml_time_ms();
+
+        LOG_INFO("apply lora '%s:%f' completed, taking %.2fs",
+                 lora_name.c_str(),
+                 multiplier,
+                 (t1 - t0) * 1.0f / 1000);
+    }
+
+    void apply_loras(const std::unordered_map<std::string, float>& lora_state) {
+        std::unordered_map<std::string, float> lora_state_diff;
+        for (auto& kv : lora_state) {
+            const std::string& lora_name = kv.first;
+            float multiplier = kv.second;
+
+            if (curr_lora_state.find(lora_name) != curr_lora_state.end()) {
+                float curr_multiplier = curr_lora_state[lora_name];
+                float multiplier_diff = multiplier - curr_multiplier;
+                if (multiplier_diff != 0.f) {
+                    lora_state_diff[lora_name] = multiplier_diff;
+                }
+            } else {
+                lora_state_diff[lora_name] = multiplier;
+            }
+        }
+
+        for (auto& kv : lora_state_diff) {
+            apply_lora(kv.first, kv.second);
+        }
+
+        curr_lora_state = lora_state;
+    }
+
     ggml_tensor* get_learned_condition(ggml_context* res_ctx, const std::string& text) {
         auto tokens_and_weights = cond_stage_model.tokenize(text,
                                                             cond_stage_model.text_model.max_position_embeddings,
@@ -4235,10 +4596,12 @@ class StableDiffusionGGML {
 StableDiffusion::StableDiffusion(int n_threads,
                                  bool vae_decode_only,
                                  bool free_params_immediately,
+                                 std::string lora_model_dir,
                                  RNGType rng_type) {
     sd = std::make_shared<StableDiffusionGGML>(n_threads,
                                                vae_decode_only,
                                                free_params_immediately,
+                                               lora_model_dir,
                                                rng_type);
 }
@@ -4246,8 +4609,8 @@ bool StableDiffusion::load_from_file(const std::string& file_path, Schedule s) {
     return sd->load_from_file(file_path, s);
 }
 
-std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
-                                              const std::string& negative_prompt,
+std::vector<uint8_t> StableDiffusion::txt2img(std::string prompt,
+                                              std::string negative_prompt,
                                               float cfg_scale,
                                               int width,
                                               int height,
@@ -4272,13 +4635,28 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
     }
     sd->rng->manual_seed(seed);
 
+    // extract and remove lora
+    auto result_pair = extract_and_remove_lora(prompt);
+    std::unordered_map<std::string, float> lora_f2m = result_pair.first;  // lora_name -> multiplier
+    for (auto& kv : lora_f2m) {
+        LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
+    }
+    prompt = result_pair.second;
+    LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
+
+    // load lora from file
     int64_t t0 = ggml_time_ms();
+    sd->apply_loras(lora_f2m);
+    int64_t t1 = ggml_time_ms();
+    LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
+
+    t0 = ggml_time_ms();
     ggml_tensor* c = sd->get_learned_condition(ctx, prompt);
     struct ggml_tensor* uc = NULL;
     if (cfg_scale != 1.0) {
         uc = sd->get_learned_condition(ctx, negative_prompt);
     }
-    int64_t t1 = ggml_time_ms();
+    t1 = ggml_time_ms();
     LOG_INFO("get_learned_condition completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
 
     if (sd->free_params_immediately) {
@@ -4334,8 +4712,8 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
 }
 
 std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_img_vec,
-                                              const std::string& prompt,
-                                              const std::string& negative_prompt,
+                                              std::string prompt,
+                                              std::string negative_prompt,
                                               float cfg_scale,
                                               int width,
                                               int height,
@@ -4372,14 +4750,29 @@ std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_i
     }
     sd->rng->manual_seed(seed);
 
+    // extract and remove lora
+    auto result_pair = extract_and_remove_lora(prompt);
+    std::unordered_map<std::string, float> lora_f2m = result_pair.first;  // lora_name -> multiplier
+    for (auto& kv : lora_f2m) {
+        LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
+    }
+    prompt = result_pair.second;
+    LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
+
+    // load lora from file
+    int64_t t0 = ggml_time_ms();
+    sd->apply_loras(lora_f2m);
+    int64_t t1 = ggml_time_ms();
+    LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
+
     ggml_tensor* init_img = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, width, height, 3, 1);
     image_vec_to_ggml(init_img_vec, init_img);
 
-    int64_t t0 = ggml_time_ms();
+    t0 = ggml_time_ms();
     ggml_tensor* moments = sd->encode_first_stage(ctx, init_img);
     ggml_tensor* init_latent = sd->get_first_stage_encoding(ctx, moments);
     // print_ggml_tensor(init_latent);
-    int64_t t1 = ggml_time_ms();
+    t1 = ggml_time_ms();
     LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
 
     ggml_reset_curr_max_dynamic_size();  // reset counter
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 14381db..ed8cd1f 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -45,11 +45,12 @@ class StableDiffusion {
     StableDiffusion(int n_threads = -1,
                     bool vae_decode_only = false,
                     bool free_params_immediately = false,
+                    std::string lora_model_dir = "",
                     RNGType rng_type = STD_DEFAULT_RNG);
     bool load_from_file(const std::string& file_path, Schedule d = DEFAULT);
     std::vector<uint8_t> txt2img(
-        const std::string& prompt,
-        const std::string& negative_prompt,
+        std::string prompt,
+        std::string negative_prompt,
         float cfg_scale,
         int width,
         int height,
@@ -58,8 +59,8 @@ class StableDiffusion {
         int64_t seed);
     std::vector<uint8_t> img2img(
         const std::vector<uint8_t>& init_img,
-        const std::string& prompt,
-        const std::string& negative_prompt,
+        std::string prompt,
+        std::string negative_prompt,
         float cfg_scale,
         int width,
         int height,