diff --git a/README.md b/README.md
index 96b9e73..5b19fc9 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,8 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
- Original `txt2img` and `img2img` mode
- Negative prompt
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
+- LoRA support, same as [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora)
+- Latent Consistency Models support(LCM/LCM-LoRA)
- Sampling method
- `Euler A`
- `Euler`
@@ -42,7 +44,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
- [ ] Make inference faster
- The current implementation of ggml_conv_2d is slow and has high memory usage
- [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
-- [ ] LoRA support
- [ ] k-quants support
## Usage
@@ -125,6 +126,7 @@ arguments:
-t, --threads N number of threads to use during computation (default: -1).
If threads <= 0, then threads will be set to the number of CPU physical cores
-m, --model [MODEL] path to model
+ --lora-model-dir [DIR] lora model directory
-i, --init-img [IMAGE] path to the input image, required by img2img
-o, --output OUTPUT path to write result image to (default: .\output.png)
-p, --prompt [PROMPT] the prompt to render
@@ -134,11 +136,12 @@ arguments:
1.0 corresponds to full destruction of information in init image
-H, --height H image height, in pixel space (default: 512)
-W, --width W image width, in pixel space (default: 512)
- --sampling-method {euler, euler_a, heun, dpm++2m, dpm++2mv2, lcm}
+ --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, lcm}
sampling method (default: "euler_a")
--steps STEPS number of sample steps (default: 20)
--rng {std_default, cuda} RNG (default: cuda)
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
+ --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)
-v, --verbose print extra info
```
@@ -167,6 +170,45 @@ Using formats of different precisions will yield results of varying quality.
+#### with LoRA
+
+- convert lora weights to ggml model format
+
+ ```shell
+ cd models
+ python convert.py [path to weights] --lora
+ # For example, python convert.py marblesh.safetensors
+ ```
+
+- You can specify the directory where the lora weights are stored via `--lora-model-dir`. If not specified, the default is the current working directory.
+
+- LoRA is specified via prompt, just like [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora).
+
+Here's a simple example:
+
+```
+./bin/sd -m ../models/v1-5-pruned-emaonly-ggml-model-f16.bin -p "a lovely cat" --lora-model-dir ../models
+```
+
+`../models/marblesh-ggml-lora.bin` will be applied to the model
+
+#### LCM/LCM-LoRA
+
+- Download LCM-LoRA form https://huggingface.co/latent-consistency/lcm-lora-sdv1-5
+- Specify LCM-LoRA by adding `` to prompt
+- It's advisable to set `--cfg-scale` to `1.0` instead of the default `7.0`. For `--steps`, a range of `2-8` steps is recommended. For `--sampling-method`, `lcm`/`euler_a` is recommended.
+
+Here's a simple example:
+
+```
+./bin/sd -m ../models/v1-5-pruned-emaonly-ggml-model-f16.bin -p "a lovely cat" --steps 4 --lora-model-dir ../models -v --cfg-scale 1
+```
+
+| without LCM-LoRA (--cfg-scale 7) | with LCM-LoRA (--cfg-scale 1) |
+| ---- |---- |
+|  | |
+
+
### Docker
#### Building using Docker
diff --git a/assets/with_lcm.png b/assets/with_lcm.png
new file mode 100644
index 0000000..70e2c70
Binary files /dev/null and b/assets/with_lcm.png differ
diff --git a/assets/without_lcm.png b/assets/without_lcm.png
new file mode 100644
index 0000000..145ab94
Binary files /dev/null and b/assets/without_lcm.png differ
diff --git a/examples/main.cpp b/examples/main.cpp
index cfcb6e3..b97035a 100644
--- a/examples/main.cpp
+++ b/examples/main.cpp
@@ -95,6 +95,7 @@ struct Option {
int n_threads = -1;
std::string mode = TXT2IMG;
std::string model_path;
+ std::string lora_model_dir;
std::string output_path = "output.png";
std::string init_img;
std::string prompt;
@@ -115,6 +116,7 @@ struct Option {
printf(" n_threads: %d\n", n_threads);
printf(" mode: %s\n", mode.c_str());
printf(" model_path: %s\n", model_path.c_str());
+ printf(" lora_model_dir: %s\n", lora_model_dir.c_str());
printf(" output_path: %s\n", output_path.c_str());
printf(" init_img: %s\n", init_img.c_str());
printf(" prompt: %s\n", prompt.c_str());
@@ -127,7 +129,7 @@ struct Option {
printf(" sample_steps: %d\n", sample_steps);
printf(" strength: %.2f\n", strength);
printf(" rng: %s\n", rng_type_to_str[rng_type]);
- printf(" seed: %ld\n", seed);
+ printf(" seed: %lld\n", seed);
}
};
@@ -140,6 +142,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" -t, --threads N number of threads to use during computation (default: -1).\n");
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
printf(" -m, --model [MODEL] path to model\n");
+ printf(" --lora-model-dir [DIR] lora model directory\n");
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
printf(" -o, --output OUTPUT path to write result image to (default: .\\output.png)\n");
printf(" -p, --prompt [PROMPT] the prompt to render\n");
@@ -183,6 +186,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
break;
}
opt->model_path = argv[i];
+ } else if (arg == "--lora-model-dir") {
+ if (++i >= argc) {
+ invalid_arg = true;
+ break;
+ }
+ opt->lora_model_dir = argv[i];
} else if (arg == "-i" || arg == "--init-img") {
if (++i >= argc) {
invalid_arg = true;
@@ -419,7 +428,7 @@ int main(int argc, const char* argv[]) {
init_img.assign(img_data, img_data + (opt.w * opt.h * c));
}
- StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type);
+ StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.lora_model_dir, opt.rng_type);
if (!sd.load_from_file(opt.model_path, opt.schedule)) {
return 1;
}
diff --git a/models/convert.py b/models/convert.py
index 8ef2fcc..503b10d 100644
--- a/models/convert.py
+++ b/models/convert.py
@@ -4,6 +4,7 @@ import os
import numpy as np
import torch
+import re
import safetensors.torch
this_file_dir = os.path.dirname(__file__)
@@ -270,21 +271,107 @@ def preprocess(state_dict):
new_state_dict[name] = w
return new_state_dict
-def convert(model_path, out_type = None, out_file=None):
+re_digits = re.compile(r"\d+")
+re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_compiled = {}
+
+suffix_conversion = {
+ "attentions": {},
+ "resnets": {
+ "conv1": "in_layers_2",
+ "conv2": "out_layers_3",
+ "norm1": "in_layers_0",
+ "norm2": "out_layers_0",
+ "time_emb_proj": "emb_layers_1",
+ "conv_shortcut": "skip_connection",
+ }
+}
+
+
+def convert_diffusers_name_to_compvis(key):
+ def match(match_list, regex_text):
+ regex = re_compiled.get(regex_text)
+ if regex is None:
+ regex = re.compile(regex_text)
+ re_compiled[regex_text] = regex
+
+ r = re.match(regex, key)
+ if not r:
+ return False
+
+ match_list.clear()
+ match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
+ return True
+
+ m = []
+
+ if match(m, r"lora_unet_conv_in(.*)"):
+ return f'model_diffusion_model_input_blocks_0_0{m[0]}'
+
+ if match(m, r"lora_unet_conv_out(.*)"):
+ return f'model_diffusion_model_out_2{m[0]}'
+
+ if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
+ return f"model_diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"
+
+ if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+ return f"model_diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+ if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
+ return f"model_diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
+
+ if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+ return f"model_diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+ if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
+ return f"model_diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
+
+ if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
+ return f"model_diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
+
+ if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
+ return f"cond_stage_model_transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
+
+ return None
+
+def preprocess_lora(state_dict):
+ new_state_dict = {}
+ for name, w in state_dict.items():
+ if not isinstance(w, torch.Tensor):
+ continue
+ name_without_network_parts, network_part = name.split(".", 1)
+ new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
+ if new_name_without_network_parts == None:
+ raise Exception(f"unknown lora tensor: {name}")
+ new_name = new_name_without_network_parts + "." + network_part
+ print(f"preprocess {name} => {new_name}")
+ new_state_dict[new_name] = w
+ return new_state_dict
+
+def convert(model_path, out_type = None, out_file=None, lora=False):
# load model
- with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
- clip_vocab = json.load(f)
-
+ if not lora:
+ with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
+ clip_vocab = json.load(f)
+
state_dict = load_model_from_file(model_path)
- model_type = SD1
- if "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
+ model_type = SD1 # lora only for SD1 now
+ if not lora and "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
model_type = SD2
print("Stable diffuison 2.x")
else:
print("Stable diffuison 1.x")
- state_dict = preprocess(state_dict)
+ if lora:
+ state_dict = preprocess_lora(state_dict)
+ else:
+ state_dict = preprocess(state_dict)
# output option
+ if lora:
+ out_type = "f16" # only f16 for now
if out_type == None:
weight = state_dict["model.diffusion_model.input_blocks.0.0.weight"].numpy()
if weight.dtype == np.float32:
@@ -296,7 +383,10 @@ def convert(model_path, out_type = None, out_file=None):
else:
raise Exception("unsupported weight type %s" % weight.dtype)
if out_file == None:
- out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
+ if lora:
+ out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-lora.bin"
+ else:
+ out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
out_file = os.path.join(os.getcwd(), out_file)
print(f"Saving GGML compatible file to {out_file}")
@@ -309,14 +399,15 @@ def convert(model_path, out_type = None, out_file=None):
file.write(struct.pack("i", ftype))
# vocab
- byte_encoder = bytes_to_unicode()
- byte_decoder = {v: k for k, v in byte_encoder.items()}
- file.write(struct.pack("i", len(clip_vocab)))
- for key in clip_vocab:
- text = bytearray([byte_decoder[c] for c in key])
- file.write(struct.pack("i", len(text)))
- file.write(text)
-
+ if not lora:
+ byte_encoder = bytes_to_unicode()
+ byte_decoder = {v: k for k, v in byte_encoder.items()}
+ file.write(struct.pack("i", len(clip_vocab)))
+ for key in clip_vocab:
+ text = bytearray([byte_decoder[c] for c in key])
+ file.write(struct.pack("i", len(text)))
+ file.write(text)
+
# weights
for name in state_dict.keys():
if not isinstance(state_dict[name], torch.Tensor):
@@ -337,7 +428,7 @@ def convert(model_path, out_type = None, out_file=None):
old_type = data.dtype
ttype = "f32"
- if n_dims == 4:
+ if n_dims == 4 and not lora:
data = data.astype(np.float16)
ttype = "f16"
elif n_dims == 2 and name[-7:] == ".weight":
@@ -380,6 +471,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Stable Diffuison model to GGML compatible file format")
parser.add_argument("--out_type", choices=["f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"], help="output format (default: based on input)")
parser.add_argument("--out_file", help="path to write to; default: based on input and current working directory")
+ parser.add_argument("--lora", action='store_true', default = False, help="convert lora weight; default: false")
parser.add_argument("model_path", help="model file path (*.pth, *.pt, *.ckpt, *.safetensors)")
args = parser.parse_args()
- convert(args.model_path, args.out_type, args.out_file)
+ convert(args.model_path, args.out_type, args.out_file, args.lora)
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 0a461b8..8339501 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -268,6 +268,45 @@ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
return ggml_group_norm(ctx, a, 32);
}
+std::pair, std::string> extract_and_remove_lora(std::string text) {
+ std::regex re("]+)>");
+ std::smatch matches;
+ std::unordered_map filename2multiplier;
+
+ while (std::regex_search(text, matches, re)) {
+ std::string filename = matches[1].str();
+ float multiplier = std::stof(matches[2].str());
+ if (multiplier < 0.f) {
+ continue;
+ }
+ if (filename2multiplier.find(filename) == filename2multiplier.end()) {
+ filename2multiplier[filename] = multiplier;
+ } else {
+ filename2multiplier[filename] += multiplier;
+ }
+
+ text = std::regex_replace(text, re, "");
+ }
+
+ return std::make_pair(filename2multiplier, text);
+}
+
+bool ends_with(const std::string& str, const std::string& ending) {
+ if (str.length() >= ending.length()) {
+ return (str.compare(str.length() - ending.length(), ending.length(), ending) == 0);
+ } else {
+ return false;
+ }
+}
+
+void replace_all_chars(std::string& str, char target, char replacement) {
+ for (size_t i = 0; i < str.length(); ++i) {
+ if (str[i] == target) {
+ str[i] = replacement;
+ }
+ }
+}
+
/*================================================== CLIPTokenizer ===================================================*/
const std::string UNK_TOKEN = "<|endoftext|>";
@@ -2794,6 +2833,16 @@ class StableDiffusionGGML {
UNetModel diffusion_model;
AutoEncoderKL first_stage_model;
+ std::map tensors;
+
+ std::string lora_model_dir;
+ // lora_name => lora_tensor_name => tensor
+ std::map> lora_tensors;
+ // lora_name => lora_params_ctx
+ std::map lora_params_ctxs;
+ // lora_name => multiplier
+ std::unordered_map curr_lora_state;
+
std::shared_ptr denoiser = std::make_shared();
StableDiffusionGGML() = default;
@@ -2801,16 +2850,23 @@ class StableDiffusionGGML {
StableDiffusionGGML(int n_threads,
bool vae_decode_only,
bool free_params_immediately,
+ std::string lora_model_dir,
RNGType rng_type)
: n_threads(n_threads),
vae_decode_only(vae_decode_only),
- free_params_immediately(free_params_immediately) {
+ free_params_immediately(free_params_immediately),
+ lora_model_dir(lora_model_dir) {
first_stage_model.decode_only = vae_decode_only;
if (rng_type == STD_DEFAULT_RNG) {
rng = std::make_shared();
} else if (rng_type == CUDA_RNG) {
rng = std::make_shared();
}
+ if (lora_model_dir.size() > 0) {
+ if (lora_model_dir[lora_model_dir.size() - 1] != '/' && lora_model_dir[lora_model_dir.size() - 1] != '\\') {
+ this->lora_model_dir = lora_model_dir + "/";
+ }
+ }
}
~StableDiffusionGGML() {
@@ -2826,6 +2882,13 @@ class StableDiffusionGGML {
ggml_free(vae_params_ctx);
vae_params_ctx = NULL;
}
+ for (auto& kv : lora_params_ctxs) {
+ ggml_free(kv.second);
+ }
+ lora_params_ctxs.clear();
+
+ tensors.clear();
+ lora_tensors.clear();
}
bool load_from_file(const std::string& file_path, Schedule schedule) {
@@ -2963,8 +3026,6 @@ class StableDiffusionGGML {
}
}
- std::map tensors;
-
LOG_DEBUG("preparing memory for the weights");
// prepare memory for the weights
{
@@ -3255,6 +3316,306 @@ class StableDiffusionGGML {
return result < -1;
}
+ bool load_lora_from_file(const std::string& lora_name) {
+ if (lora_tensors.find(lora_name) != lora_tensors.end()) {
+ return true;
+ }
+ std::string file_path = lora_model_dir + lora_name + "-ggml-lora.bin";
+ LOG_INFO("loading lora '%s' from '%s'", lora_name.c_str(), file_path.c_str());
+
+ std::ifstream file(file_path, std::ios::binary);
+ if (!file.is_open()) {
+ LOG_ERROR("failed to open '%s'", file_path.c_str());
+ return false;
+ }
+
+ // get file size
+ file.seekg(0, file.end);
+ int file_size = (int)file.tellg();
+ file.seekg(0, file.beg);
+
+ LOG_DEBUG("'%s': %.2fMB", file_path.c_str(), file_size * 1.f / 1024 / 1024);
+
+ LOG_DEBUG("verifying magic");
+ // verify magic
+ {
+ uint32_t magic;
+ file.read(reinterpret_cast(&magic), sizeof(magic));
+ if (magic != GGML_FILE_MAGIC) {
+ LOG_ERROR("invalid model file '%s' (bad magic)", file_path.c_str());
+ return false;
+ }
+ }
+
+ LOG_DEBUG("loading hparams");
+ // load hparams
+ file.read(reinterpret_cast(&ftype), sizeof(ftype));
+
+ int model_type = (ftype >> 16) & 0xFFFF;
+ if (model_type >= MODEL_TYPE_COUNT) {
+ LOG_ERROR("invalid model file '%s' (bad model type value %d)", file_path.c_str(), ftype);
+ return false;
+ }
+ LOG_INFO("lora model type: %s", model_type_to_str[model_type]);
+
+ ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype)(ftype & 0xFFFF));
+ LOG_INFO("ftype: %s", ggml_type_name(wtype));
+ if (wtype == GGML_TYPE_COUNT) {
+ LOG_ERROR("invalid model file '%s' (bad ftype value %d)", file_path.c_str(), ftype);
+ return false;
+ }
+
+ // create the ggml context for network params
+ struct ggml_init_params params;
+ size_t ctx_size = 10 * 1024 * 1024; // 10 MB, for padding
+ ctx_size += file_size;
+ params.mem_size = ctx_size;
+ params.mem_buffer = NULL;
+ params.no_alloc = false;
+ params.dynamic = false;
+ LOG_DEBUG("lora '%s' params ctx size = % 6.2f MB", lora_name.c_str(), ctx_size / (1024.0 * 1024.0));
+ ggml_context* lora_params_ctx = ggml_init(params);
+ if (!lora_params_ctx) {
+ LOG_ERROR("ggml_init() failed");
+ return false;
+ }
+ lora_params_ctxs[lora_name] = lora_params_ctx;
+
+ std::map lora_tensor_map;
+ int64_t t0 = ggml_time_ms();
+ // load weights
+ {
+ int n_tensors = 0;
+ size_t total_size = 0;
+
+ while (true) {
+ int32_t n_dims;
+ int32_t length;
+ int32_t ttype;
+
+ file.read(reinterpret_cast(&n_dims), sizeof(n_dims));
+ file.read(reinterpret_cast(&length), sizeof(length));
+ file.read(reinterpret_cast(&ttype), sizeof(ttype));
+
+ if (file.eof()) {
+ break;
+ }
+
+ int32_t nelements = 1;
+ int32_t ne[4] = {1, 1, 1, 1};
+ for (int i = 0; i < n_dims; ++i) {
+ file.read(reinterpret_cast(&ne[i]), sizeof(ne[i]));
+ nelements *= ne[i];
+ }
+
+ const size_t num_bytes = nelements / ggml_blck_size(ggml_type(ttype)) * ggml_type_size(ggml_type(ttype));
+
+ std::string name_buf(length, 0);
+ file.read(&name_buf[0], length);
+ std::string name = std::string(name_buf.data());
+
+ // LOG_DEBUG("load lora tensor %s", name.c_str());
+
+ int64_t ne64[4] = {ne[0], ne[1], ne[2], ne[3]};
+ struct ggml_tensor* tensor = ggml_new_tensor(lora_params_ctx, (ggml_type)ttype, n_dims, ne64);
+ file.read(reinterpret_cast(tensor->data), num_bytes);
+
+ lora_tensor_map[name] = tensor;
+
+ total_size += ggml_nbytes(tensor);
+ }
+ }
+ lora_tensors[lora_name] = lora_tensor_map;
+ int64_t t1 = ggml_time_ms();
+ LOG_INFO("lora '%s' params size = %.2fMB",
+ lora_name.c_str(),
+ ggml_used_mem(lora_params_ctx) / 1024.0 / 1024.0);
+ LOG_INFO("loading lora from '%s' completed, taking %.2fs", file_path.c_str(), (t1 - t0) * 1.0f / 1000);
+ file.close();
+ return true;
+ }
+
+ void remove_lora_params(const std::string& lora_name) {
+ if (lora_params_ctxs.find(lora_name) == lora_params_ctxs.end()) {
+ return;
+ }
+ ggml_free(lora_params_ctxs[lora_name]);
+ lora_params_ctxs.erase(lora_name);
+ lora_tensors.erase(lora_name);
+ }
+
+ void apply_lora(const std::string& lora_name, float multiplier) {
+ int64_t t0 = ggml_time_ms();
+ if (!load_lora_from_file(lora_name)) {
+ std::string file_path = lora_model_dir + lora_name + "-ggml-lora.bin";
+ LOG_WARN("apply lora '%s' failed", lora_name.c_str());
+ return;
+ }
+
+ size_t ctx_size = 500 * 1024 * 1024; // 500MB
+ void* mem_buffer = malloc(ctx_size);
+ if (!mem_buffer) {
+ if (free_params_immediately) {
+ remove_lora_params(lora_name);
+ }
+ LOG_ERROR("malloc() failed");
+ return;
+ }
+
+ std::map& lora_tensor_map = lora_tensors[lora_name];
+ std::set applied_lora_tensors;
+ for (auto& kv : tensors) {
+ const std::string name = kv.first;
+ ggml_tensor* weight = kv.second;
+ std::string ending = ".weight";
+ if (!ends_with(name, ending)) {
+ continue;
+ }
+
+ // find corresponding lora tensors
+ std::string network_name = name.substr(0, name.size() - ending.size()); // remove .weight
+ replace_all_chars(network_name, '.', '_');
+ std::string lora_up_name = network_name + ".lora_up.weight";
+ std::string lora_down_name = network_name + ".lora_down.weight";
+ std::string alpha_name = network_name + ".alpha";
+ std::string scale_name = network_name + ".scale";
+
+ ggml_tensor* lora_up = NULL;
+ ggml_tensor* lora_down = NULL;
+
+ float scale = 1.0f;
+
+ if (lora_tensor_map.find(lora_up_name) != lora_tensor_map.end()) {
+ lora_up = lora_tensor_map[lora_up_name];
+ }
+
+ if (lora_tensor_map.find(lora_down_name) != lora_tensor_map.end()) {
+ lora_down = lora_tensor_map[lora_down_name];
+ }
+
+ if (lora_up == NULL || lora_down == NULL) {
+ continue;
+ }
+
+ // LOG_DEBUG("apply lora tensor %s [%ld %ld %ld %ld]", network_name.c_str(), weight->ne[0], weight->ne[1], weight->ne[2], weight->ne[3]);
+
+ applied_lora_tensors.insert(lora_up_name);
+ applied_lora_tensors.insert(lora_down_name);
+ applied_lora_tensors.insert(alpha_name);
+ applied_lora_tensors.insert(scale_name);
+
+ // calc_scale
+ int64_t dim = lora_down->ne[lora_down->n_dims - 1];
+ if (lora_tensor_map.find(scale_name) != lora_tensor_map.end()) {
+ ggml_tensor* t = lora_tensor_map[scale_name];
+ scale = ggml_get_f32_1d(t, 0);
+ } else if (lora_tensor_map.find(alpha_name) != lora_tensor_map.end()) {
+ ggml_tensor* t = lora_tensor_map[alpha_name];
+ scale = ggml_get_f32_1d(t, 0) / dim;
+ }
+
+ // LOG_DEBUG("scale: %f %ld", scale, dim);
+
+ scale = scale * multiplier;
+
+ // apply
+ {
+ struct ggml_init_params params;
+ params.mem_size = ctx_size;
+ params.mem_buffer = mem_buffer;
+ params.no_alloc = false;
+ params.dynamic = false;
+
+ struct ggml_context* ctx = ggml_init(params);
+ if (!ctx) {
+ LOG_ERROR("ggml_init() failed");
+ free(mem_buffer);
+ if (free_params_immediately) {
+ remove_lora_params(lora_name);
+ }
+ return;
+ }
+
+ ggml_tensor* scale_factor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+ ggml_set_f32_1d(scale_factor, 0, scale);
+ int64_t lora_up_size_0 = lora_up->ne[lora_up->n_dims - 1];
+ lora_up = ggml_reshape_2d(ctx, lora_up, ggml_nelements(lora_up) / lora_up_size_0, lora_up_size_0);
+ int64_t lora_down_size_0 = lora_down->ne[lora_down->n_dims - 1];
+ lora_down = ggml_reshape_2d(ctx, lora_down, ggml_nelements(lora_down) / lora_down_size_0, lora_down_size_0);
+
+ lora_down = ggml_cont(ctx, ggml_transpose(ctx, lora_down));
+
+ if (lora_down->type != GGML_TYPE_F32) {
+ ggml_tensor* lora_down_f32 = ggml_new_tensor(ctx, GGML_TYPE_F32, lora_down->n_dims, lora_down->ne);
+ lora_down = ggml_cpy_inplace(ctx, lora_down, lora_down_f32);
+ }
+
+ ggml_tensor* updown = ggml_mul_mat(ctx, lora_up, lora_down);
+ updown = ggml_cont(ctx, ggml_transpose(ctx, updown));
+ updown = ggml_reshape(ctx, updown, weight);
+
+ GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
+
+ updown = ggml_scale_inplace(ctx, updown, scale_factor);
+ ggml_tensor* final_weight;
+ final_weight = ggml_add_inplace(ctx, weight, updown);
+ final_weight = ggml_cpy_inplace(ctx, final_weight, weight);
+
+ struct ggml_cgraph* graph = ggml_build_forward_ctx(ctx, final_weight);
+
+ ggml_graph_compute_with_ctx(ctx, graph, n_threads);
+
+ // LOG_INFO("network_name '%s' ggml_used_mem size = %.2fMB",
+ // network_name.c_str(),
+ // ggml_used_mem(ctx) / 1024.0 / 1024.0);
+
+ ggml_free(ctx);
+ }
+ }
+ free(mem_buffer);
+
+ for (auto& kv : lora_tensor_map) {
+ if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) {
+ LOG_WARN("unused lora tensor %s", kv.first.c_str());
+ }
+ }
+
+ if (free_params_immediately) {
+ remove_lora_params(lora_name);
+ }
+
+ int64_t t1 = ggml_time_ms();
+
+ LOG_INFO("apply lora '%s:%f' completed, taking %.2fs",
+ lora_name.c_str(),
+ multiplier,
+ (t1 - t0) * 1.0f / 1000);
+ }
+
+ void apply_loras(const std::unordered_map& lora_state) {
+ std::unordered_map lora_state_diff;
+ for (auto& kv : lora_state) {
+ const std::string& lora_name = kv.first;
+ float multiplier = kv.second;
+
+ if (curr_lora_state.find(lora_name) != curr_lora_state.end()) {
+ float curr_multiplier = curr_lora_state[lora_name];
+ float multiplier_diff = multiplier - curr_multiplier;
+ if (multiplier_diff != 0.f) {
+ lora_state_diff[lora_name] = multiplier_diff;
+ }
+ } else {
+ lora_state_diff[lora_name] = multiplier;
+ }
+ }
+
+ for (auto& kv : lora_state_diff) {
+ apply_lora(kv.first, kv.second);
+ }
+
+ curr_lora_state = lora_state;
+ }
+
ggml_tensor* get_learned_condition(ggml_context* res_ctx, const std::string& text) {
auto tokens_and_weights = cond_stage_model.tokenize(text,
cond_stage_model.text_model.max_position_embeddings,
@@ -4235,10 +4596,12 @@ class StableDiffusionGGML {
StableDiffusion::StableDiffusion(int n_threads,
bool vae_decode_only,
bool free_params_immediately,
+ std::string lora_model_dir,
RNGType rng_type) {
sd = std::make_shared(n_threads,
vae_decode_only,
free_params_immediately,
+ lora_model_dir,
rng_type);
}
@@ -4246,8 +4609,8 @@ bool StableDiffusion::load_from_file(const std::string& file_path, Schedule s) {
return sd->load_from_file(file_path, s);
}
-std::vector StableDiffusion::txt2img(const std::string& prompt,
- const std::string& negative_prompt,
+std::vector StableDiffusion::txt2img(std::string prompt,
+ std::string negative_prompt,
float cfg_scale,
int width,
int height,
@@ -4272,13 +4635,28 @@ std::vector StableDiffusion::txt2img(const std::string& prompt,
}
sd->rng->manual_seed(seed);
+ // extract and remote lora
+ auto result_pair = extract_and_remove_lora(prompt);
+ std::unordered_map lora_f2m = result_pair.first; // lora_name -> multiplier
+ for (auto& kv : lora_f2m) {
+ LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
+ }
+ prompt = result_pair.second;
+ LOG_DEBUG("prompt after extract and remote lora: \"%s\"", prompt.c_str());
+
+ // load lora from file
int64_t t0 = ggml_time_ms();
+ sd->apply_loras(lora_f2m);
+ int64_t t1 = ggml_time_ms();
+ LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
+
+ t0 = ggml_time_ms();
ggml_tensor* c = sd->get_learned_condition(ctx, prompt);
struct ggml_tensor* uc = NULL;
if (cfg_scale != 1.0) {
uc = sd->get_learned_condition(ctx, negative_prompt);
}
- int64_t t1 = ggml_time_ms();
+ t1 = ggml_time_ms();
LOG_INFO("get_learned_condition completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (sd->free_params_immediately) {
@@ -4334,8 +4712,8 @@ std::vector StableDiffusion::txt2img(const std::string& prompt,
}
std::vector StableDiffusion::img2img(const std::vector& init_img_vec,
- const std::string& prompt,
- const std::string& negative_prompt,
+ std::string prompt,
+ std::string negative_prompt,
float cfg_scale,
int width,
int height,
@@ -4372,14 +4750,29 @@ std::vector StableDiffusion::img2img(const std::vector& init_i
}
sd->rng->manual_seed(seed);
+ // extract and remote lora
+ auto result_pair = extract_and_remove_lora(prompt);
+ std::unordered_map lora_f2m = result_pair.first; // lora_name -> multiplier
+ for (auto& kv : lora_f2m) {
+ LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
+ }
+ prompt = result_pair.second;
+ LOG_DEBUG("prompt after extract and remote lora: \"%s\"", prompt.c_str());
+
+ // load lora from file
+ int64_t t0 = ggml_time_ms();
+ sd->apply_loras(lora_f2m);
+ int64_t t1 = ggml_time_ms();
+ LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
+
ggml_tensor* init_img = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, width, height, 3, 1);
image_vec_to_ggml(init_img_vec, init_img);
- int64_t t0 = ggml_time_ms();
+ t0 = ggml_time_ms();
ggml_tensor* moments = sd->encode_first_stage(ctx, init_img);
ggml_tensor* init_latent = sd->get_first_stage_encoding(ctx, moments);
// print_ggml_tensor(init_latent);
- int64_t t1 = ggml_time_ms();
+ t1 = ggml_time_ms();
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
ggml_reset_curr_max_dynamic_size(); // reset counter
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 14381db..ed8cd1f 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -45,11 +45,12 @@ class StableDiffusion {
StableDiffusion(int n_threads = -1,
bool vae_decode_only = false,
bool free_params_immediately = false,
+ std::string lora_model_dir = "",
RNGType rng_type = STD_DEFAULT_RNG);
bool load_from_file(const std::string& file_path, Schedule d = DEFAULT);
std::vector txt2img(
- const std::string& prompt,
- const std::string& negative_prompt,
+ std::string prompt,
+ std::string negative_prompt,
float cfg_scale,
int width,
int height,
@@ -58,8 +59,8 @@ class StableDiffusion {
int64_t seed);
std::vector img2img(
const std::vector& init_img,
- const std::string& prompt,
- const std::string& negative_prompt,
+ std::string prompt,
+ std::string negative_prompt,
float cfg_scale,
int width,
int height,