feat: add LoRA support

parent 536f3af672
commit 9a9f3daf8e

README.md
@@ -18,6 +18,8 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - Original `txt2img` and `img2img` mode
 - Negative prompt
 - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
+- LoRA support, same as [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora)
+- Latent Consistency Models support (LCM/LCM-LoRA)
 - Sampling method
     - `Euler A`
     - `Euler`
@@ -42,7 +44,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - [ ] Make inference faster
     - The current implementation of ggml_conv_2d is slow and has high memory usage
 - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
-- [ ] LoRA support
 - [ ] k-quants support
 
 ## Usage
@@ -125,6 +126,7 @@ arguments:
   -t, --threads N                    number of threads to use during computation (default: -1).
                                      If threads <= 0, then threads will be set to the number of CPU physical cores
   -m, --model [MODEL]                path to model
+  --lora-model-dir [DIR]             lora model directory
   -i, --init-img [IMAGE]             path to the input image, required by img2img
   -o, --output OUTPUT                path to write result image to (default: .\output.png)
   -p, --prompt [PROMPT]              the prompt to render
@@ -134,11 +136,12 @@ arguments:
                                      1.0 corresponds to full destruction of information in init image
   -H, --height H                     image height, in pixel space (default: 512)
   -W, --width W                      image width, in pixel space (default: 512)
-  --sampling-method {euler, euler_a, heun, dpm++2m, dpm++2mv2, lcm}
+  --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, lcm}
                                      sampling method (default: "euler_a")
   --steps STEPS                      number of sample steps (default: 20)
   --rng {std_default, cuda}          RNG (default: cuda)
   -s SEED, --seed SEED               RNG seed (default: 42, use random seed for < 0)
+  --schedule {discrete, karras}      Denoiser sigma schedule (default: discrete)
   -v, --verbose                      print extra info
 ```
 
@@ -167,6 +170,45 @@ Using formats of different precisions will yield results of varying quality.
   <img src="./assets/img2img_output.png" width="256x">
 </p>
 
+#### with LoRA
+
+- convert lora weights to ggml model format
+
+  ```shell
+  cd models
+  python convert.py [path to weights] --lora
+  # For example, python convert.py marblesh.safetensors --lora
+  ```
+
+- You can specify the directory where the lora weights are stored via `--lora-model-dir`. If not specified, the default is the current working directory.
+
+- LoRA is specified via the prompt, just like [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora).
+
+Here's a simple example:
+
+```
+./bin/sd -m ../models/v1-5-pruned-emaonly-ggml-model-f16.bin -p "a lovely cat<lora:marblesh:1>" --lora-model-dir ../models
+```
+
+`../models/marblesh-ggml-lora.bin` will be applied to the model.
+
+#### LCM/LCM-LoRA
+
+- Download LCM-LoRA from https://huggingface.co/latent-consistency/lcm-lora-sdv1-5
+- Specify LCM-LoRA by adding `<lora:lcm-lora-sdv1-5:1>` to the prompt
+- It's advisable to set `--cfg-scale` to `1.0` instead of the default `7.0`. For `--steps`, a range of `2-8` steps is recommended. For `--sampling-method`, `lcm`/`euler_a` is recommended.
+
+Here's a simple example:
+
+```
+./bin/sd -m ../models/v1-5-pruned-emaonly-ggml-model-f16.bin -p "a lovely cat<lora:lcm-lora-sdv1-5:1>" --steps 4 --lora-model-dir ../models -v --cfg-scale 1
+```
+
+| without LCM-LoRA (--cfg-scale 7) | with LCM-LoRA (--cfg-scale 1) |
+| ---- | ---- |
+|  |  |
+
 ### Docker
 
 #### Building using Docker
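As a concrete reading of the examples above: a tag like `<lora:marblesh:0.5>` makes the runtime look for `marblesh-ggml-lora.bin` under the `--lora-model-dir` directory and fold it into the model weights with multiplier `0.5`; the tag itself is stripped from the prompt before text conditioning.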
BIN assets/with_lcm.png (new file, 596 KiB; binary file not shown)
BIN assets/without_lcm.png (new file, 533 KiB; binary file not shown)
@@ -95,6 +95,7 @@ struct Option {
     int n_threads = -1;
     std::string mode = TXT2IMG;
     std::string model_path;
+    std::string lora_model_dir;
     std::string output_path = "output.png";
     std::string init_img;
     std::string prompt;
@@ -115,6 +116,7 @@ struct Option {
         printf("    n_threads:      %d\n", n_threads);
         printf("    mode:           %s\n", mode.c_str());
         printf("    model_path:     %s\n", model_path.c_str());
+        printf("    lora_model_dir: %s\n", lora_model_dir.c_str());
         printf("    output_path:    %s\n", output_path.c_str());
         printf("    init_img:       %s\n", init_img.c_str());
         printf("    prompt:         %s\n", prompt.c_str());
@@ -127,7 +129,7 @@ struct Option {
         printf("    sample_steps:   %d\n", sample_steps);
         printf("    strength:       %.2f\n", strength);
         printf("    rng:            %s\n", rng_type_to_str[rng_type]);
-        printf("    seed:           %ld\n", seed);
+        printf("    seed:           %lld\n", seed);
     }
 };
 
@@ -140,6 +142,7 @@ void print_usage(int argc, const char* argv[]) {
     printf("  -t, --threads N                    number of threads to use during computation (default: -1).\n");
     printf("                                     If threads <= 0, then threads will be set to the number of CPU physical cores\n");
     printf("  -m, --model [MODEL]                path to model\n");
+    printf("  --lora-model-dir [DIR]             lora model directory\n");
     printf("  -i, --init-img [IMAGE]             path to the input image, required by img2img\n");
     printf("  -o, --output OUTPUT                path to write result image to (default: .\\output.png)\n");
     printf("  -p, --prompt [PROMPT]              the prompt to render\n");
@@ -183,6 +186,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
                 break;
             }
             opt->model_path = argv[i];
+        } else if (arg == "--lora-model-dir") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            opt->lora_model_dir = argv[i];
         } else if (arg == "-i" || arg == "--init-img") {
             if (++i >= argc) {
                 invalid_arg = true;
@@ -419,7 +428,7 @@ int main(int argc, const char* argv[]) {
         init_img.assign(img_data, img_data + (opt.w * opt.h * c));
     }
 
-    StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type);
+    StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.lora_model_dir, opt.rng_type);
     if (!sd.load_from_file(opt.model_path, opt.schedule)) {
         return 1;
     }
@@ -4,6 +4,7 @@ import os
 
 import numpy as np
 import torch
+import re
 import safetensors.torch
 
 this_file_dir = os.path.dirname(__file__)
@@ -270,21 +271,107 @@ def preprocess(state_dict):
         new_state_dict[name] = w
     return new_state_dict
 
-def convert(model_path, out_type = None, out_file=None):
+re_digits = re.compile(r"\d+")
+re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_compiled = {}
+
+suffix_conversion = {
+    "attentions": {},
+    "resnets": {
+        "conv1": "in_layers_2",
+        "conv2": "out_layers_3",
+        "norm1": "in_layers_0",
+        "norm2": "out_layers_0",
+        "time_emb_proj": "emb_layers_1",
+        "conv_shortcut": "skip_connection",
+    }
+}
+
+
+def convert_diffusers_name_to_compvis(key):
+    def match(match_list, regex_text):
+        regex = re_compiled.get(regex_text)
+        if regex is None:
+            regex = re.compile(regex_text)
+            re_compiled[regex_text] = regex
+
+        r = re.match(regex, key)
+        if not r:
+            return False
+
+        match_list.clear()
+        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
+        return True
+
+    m = []
+
+    if match(m, r"lora_unet_conv_in(.*)"):
+        return f'model_diffusion_model_input_blocks_0_0{m[0]}'
+
+    if match(m, r"lora_unet_conv_out(.*)"):
+        return f'model_diffusion_model_out_2{m[0]}'
+
+    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
+        return f"model_diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"
+
+    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+        return f"model_diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
+        return f"model_diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
+
+    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+        return f"model_diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
+        return f"model_diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
+
+    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
+        return f"model_diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0] > 0 else 1}_conv"
+
+    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
+        return f"cond_stage_model_transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
+
+    return None
+
+
+def preprocess_lora(state_dict):
+    new_state_dict = {}
+    for name, w in state_dict.items():
+        if not isinstance(w, torch.Tensor):
+            continue
+        name_without_network_parts, network_part = name.split(".", 1)
+        new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
+        if new_name_without_network_parts is None:
+            raise Exception(f"unknown lora tensor: {name}")
+        new_name = new_name_without_network_parts + "." + network_part
+        print(f"preprocess {name} => {new_name}")
+        new_state_dict[new_name] = w
+    return new_state_dict
+
+
+def convert(model_path, out_type = None, out_file=None, lora=False):
     # load model
-    with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
-        clip_vocab = json.load(f)
+    if not lora:
+        with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
+            clip_vocab = json.load(f)
 
     state_dict = load_model_from_file(model_path)
-    model_type = SD1
-    if "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
+    model_type = SD1  # lora is only supported for SD1 for now
+    if not lora and "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
         model_type = SD2
         print("Stable Diffusion 2.x")
     else:
         print("Stable Diffusion 1.x")
-    state_dict = preprocess(state_dict)
+    if lora:
+        state_dict = preprocess_lora(state_dict)
+    else:
+        state_dict = preprocess(state_dict)
 
     # output option
+    if lora:
+        out_type = "f16"  # only f16 for now
     if out_type == None:
        weight = state_dict["model.diffusion_model.input_blocks.0.0.weight"].numpy()
        if weight.dtype == np.float32:
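To illustrate the name mapping (worked by hand from the regexes above, not script output): the diffusers LoRA key `lora_unet_down_blocks_0_attentions_0_proj_in` becomes `model_diffusion_model_input_blocks_1_1_proj_in`, and `lora_te_text_model_encoder_layers_0_self_attn_q_proj` becomes `cond_stage_model_transformer_text_model_encoder_layers_0_self_attn_q_proj`; `preprocess_lora` then re-attaches the network part (e.g. `.lora_up.weight`) unchanged.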
@@ -296,7 +383,10 @@ def convert(model_path, out_type = None, out_file=None):
         else:
             raise Exception("unsupported weight type %s" % weight.dtype)
     if out_file == None:
-        out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
+        if lora:
+            out_file = os.path.splitext(os.path.basename(model_path))[0] + "-ggml-lora.bin"
+        else:
+            out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
         out_file = os.path.join(os.getcwd(), out_file)
     print(f"Saving GGML compatible file to {out_file}")
 
@@ -309,14 +399,15 @@ def convert(model_path, out_type = None, out_file=None):
     file.write(struct.pack("i", ftype))
 
     # vocab
-    byte_encoder = bytes_to_unicode()
-    byte_decoder = {v: k for k, v in byte_encoder.items()}
-    file.write(struct.pack("i", len(clip_vocab)))
-    for key in clip_vocab:
-        text = bytearray([byte_decoder[c] for c in key])
-        file.write(struct.pack("i", len(text)))
-        file.write(text)
+    if not lora:
+        byte_encoder = bytes_to_unicode()
+        byte_decoder = {v: k for k, v in byte_encoder.items()}
+        file.write(struct.pack("i", len(clip_vocab)))
+        for key in clip_vocab:
+            text = bytearray([byte_decoder[c] for c in key])
+            file.write(struct.pack("i", len(text)))
+            file.write(text)
 
     # weights
     for name in state_dict.keys():
         if not isinstance(state_dict[name], torch.Tensor):
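Since the vocab block is skipped when `--lora` is set, a converted `*-ggml-lora.bin` file reduces to a small header followed by raw tensor records. The following sketch of the layout is reconstructed from the write and read code in this commit; it is not a documented spec, and the field names are mine:

```cpp
// Reconstructed layout of a *-ggml-lora.bin file (illustrative, not authoritative):
//
//   uint32_t magic;        // GGML_FILE_MAGIC
//   int32_t  ftype;        // high 16 bits: model type, low 16 bits: ggml ftype (f16 for lora)
//
//   // then one record per tensor, repeated until EOF:
//   int32_t  n_dims;             // number of dimensions (up to 4)
//   int32_t  name_length;        // tensor name length in bytes
//   int32_t  ttype;              // ggml type of the tensor data
//   int32_t  ne[n_dims];         // shape
//   char     name[name_length];  // e.g. "..._attn_q_proj.lora_up.weight"
//   uint8_t  data[];             // nelements / block_size * type_size bytes
```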
@@ -337,7 +428,7 @@ def convert(model_path, out_type = None, out_file=None):
         old_type = data.dtype
 
         ttype = "f32"
-        if n_dims == 4:
+        if n_dims == 4 and not lora:
             data = data.astype(np.float16)
             ttype = "f16"
         elif n_dims == 2 and name[-7:] == ".weight":
@@ -380,6 +471,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Stable Diffusion model to GGML compatible file format")
     parser.add_argument("--out_type", choices=["f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"], help="output format (default: based on input)")
     parser.add_argument("--out_file", help="path to write to; default: based on input and current working directory")
+    parser.add_argument("--lora", action='store_true', default=False, help="convert lora weight; default: false")
     parser.add_argument("model_path", help="model file path (*.pth, *.pt, *.ckpt, *.safetensors)")
     args = parser.parse_args()
-    convert(args.model_path, args.out_type, args.out_file)
+    convert(args.model_path, args.out_type, args.out_file, args.lora)
@@ -268,6 +268,45 @@ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
     return ggml_group_norm(ctx, a, 32);
 }
 
+std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
+    std::regex re("<lora:([^:]+):([^>]+)>");
+    std::smatch matches;
+    std::unordered_map<std::string, float> filename2multiplier;
+
+    while (std::regex_search(text, matches, re)) {
+        std::string filename = matches[1].str();
+        float multiplier = std::stof(matches[2].str());
+
+        // strip only this tag so any remaining lora tags are still found
+        text = std::regex_replace(text, re, "", std::regex_constants::format_first_only);
+
+        if (multiplier < 0.f) {
+            continue;
+        }
+        if (filename2multiplier.find(filename) == filename2multiplier.end()) {
+            filename2multiplier[filename] = multiplier;
+        } else {
+            filename2multiplier[filename] += multiplier;
+        }
+    }
+
+    return std::make_pair(filename2multiplier, text);
+}
+
+bool ends_with(const std::string& str, const std::string& ending) {
+    if (str.length() >= ending.length()) {
+        return (str.compare(str.length() - ending.length(), ending.length(), ending) == 0);
+    } else {
+        return false;
+    }
+}
+
+void replace_all_chars(std::string& str, char target, char replacement) {
+    for (size_t i = 0; i < str.length(); ++i) {
+        if (str[i] == target) {
+            str[i] = replacement;
+        }
+    }
+}
+
 /*================================================== CLIPTokenizer ===================================================*/
 
 const std::string UNK_TOKEN = "<|endoftext|>";
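A minimal usage sketch of `extract_and_remove_lora` (a hypothetical driver, not part of the commit; it assumes the function above is in scope):

```cpp
#include <cstdio>
#include <regex>
#include <string>
#include <unordered_map>
#include <utility>

// ... extract_and_remove_lora as defined above ...

int main() {
    // the <lora:NAME:MULTIPLIER> tag is parsed out and removed from the prompt
    auto [loras, prompt] = extract_and_remove_lora("a lovely cat<lora:marblesh:1>");
    for (const auto& kv : loras) {
        printf("lora '%s', multiplier %.2f\n", kv.first.c_str(), kv.second);  // lora 'marblesh', multiplier 1.00
    }
    printf("cleaned prompt: \"%s\"\n", prompt.c_str());  // cleaned prompt: "a lovely cat"
    return 0;
}
```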
@@ -2794,6 +2833,16 @@ class StableDiffusionGGML {
     UNetModel diffusion_model;
     AutoEncoderKL first_stage_model;
 
+    std::map<std::string, struct ggml_tensor*> tensors;
+
+    std::string lora_model_dir;
+    // lora_name => lora_tensor_name => tensor
+    std::map<std::string, std::map<std::string, struct ggml_tensor*>> lora_tensors;
+    // lora_name => lora_params_ctx
+    std::map<std::string, ggml_context*> lora_params_ctxs;
+    // lora_name => multiplier
+    std::unordered_map<std::string, float> curr_lora_state;
+
     std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();
 
     StableDiffusionGGML() = default;
@@ -2801,16 +2850,23 @@ class StableDiffusionGGML {
     StableDiffusionGGML(int n_threads,
                         bool vae_decode_only,
                         bool free_params_immediately,
+                        std::string lora_model_dir,
                         RNGType rng_type)
         : n_threads(n_threads),
           vae_decode_only(vae_decode_only),
-          free_params_immediately(free_params_immediately) {
+          free_params_immediately(free_params_immediately),
+          lora_model_dir(lora_model_dir) {
         first_stage_model.decode_only = vae_decode_only;
         if (rng_type == STD_DEFAULT_RNG) {
             rng = std::make_shared<STDDefaultRNG>();
         } else if (rng_type == CUDA_RNG) {
             rng = std::make_shared<PhiloxRNG>();
         }
+        if (lora_model_dir.size() > 0) {
+            if (lora_model_dir[lora_model_dir.size() - 1] != '/' && lora_model_dir[lora_model_dir.size() - 1] != '\\') {
+                this->lora_model_dir = lora_model_dir + "/";
+            }
+        }
     }
 
     ~StableDiffusionGGML() {
@@ -2826,6 +2882,13 @@ class StableDiffusionGGML {
             ggml_free(vae_params_ctx);
             vae_params_ctx = NULL;
         }
+        for (auto& kv : lora_params_ctxs) {
+            ggml_free(kv.second);
+        }
+        lora_params_ctxs.clear();
+
+        tensors.clear();
+        lora_tensors.clear();
     }
 
     bool load_from_file(const std::string& file_path, Schedule schedule) {
@@ -2963,8 +3026,6 @@ class StableDiffusionGGML {
             }
         }
 
-        std::map<std::string, struct ggml_tensor*> tensors;
-
         LOG_DEBUG("preparing memory for the weights");
         // prepare memory for the weights
         {
@@ -3255,6 +3316,306 @@ class StableDiffusionGGML {
         return result < -1;
     }
 
+    bool load_lora_from_file(const std::string& lora_name) {
+        if (lora_tensors.find(lora_name) != lora_tensors.end()) {
+            return true;
+        }
+        std::string file_path = lora_model_dir + lora_name + "-ggml-lora.bin";
+        LOG_INFO("loading lora '%s' from '%s'", lora_name.c_str(), file_path.c_str());
+
+        std::ifstream file(file_path, std::ios::binary);
+        if (!file.is_open()) {
+            LOG_ERROR("failed to open '%s'", file_path.c_str());
+            return false;
+        }
+
+        // get file size
+        file.seekg(0, file.end);
+        int file_size = (int)file.tellg();
+        file.seekg(0, file.beg);
+
+        LOG_DEBUG("'%s': %.2fMB", file_path.c_str(), file_size * 1.f / 1024 / 1024);
+
+        LOG_DEBUG("verifying magic");
+        // verify magic
+        {
+            uint32_t magic;
+            file.read(reinterpret_cast<char*>(&magic), sizeof(magic));
+            if (magic != GGML_FILE_MAGIC) {
+                LOG_ERROR("invalid model file '%s' (bad magic)", file_path.c_str());
+                return false;
+            }
+        }
+
+        LOG_DEBUG("loading hparams");
+        // load hparams
+        int32_t ftype;
+        file.read(reinterpret_cast<char*>(&ftype), sizeof(ftype));
+
+        int model_type = (ftype >> 16) & 0xFFFF;
+        if (model_type >= MODEL_TYPE_COUNT) {
+            LOG_ERROR("invalid model file '%s' (bad model type value %d)", file_path.c_str(), ftype);
+            return false;
+        }
+        LOG_INFO("lora model type: %s", model_type_to_str[model_type]);
+
+        ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype)(ftype & 0xFFFF));
+        LOG_INFO("ftype: %s", ggml_type_name(wtype));
+        if (wtype == GGML_TYPE_COUNT) {
+            LOG_ERROR("invalid model file '%s' (bad ftype value %d)", file_path.c_str(), ftype);
+            return false;
+        }
+
+        // create the ggml context for network params
+        struct ggml_init_params params;
+        size_t ctx_size = 10 * 1024 * 1024;  // 10 MB, for padding
+        ctx_size += file_size;
+        params.mem_size = ctx_size;
+        params.mem_buffer = NULL;
+        params.no_alloc = false;
+        params.dynamic = false;
+        LOG_DEBUG("lora '%s' params ctx size = % 6.2f MB", lora_name.c_str(), ctx_size / (1024.0 * 1024.0));
+        ggml_context* lora_params_ctx = ggml_init(params);
+        if (!lora_params_ctx) {
+            LOG_ERROR("ggml_init() failed");
+            return false;
+        }
+        lora_params_ctxs[lora_name] = lora_params_ctx;
+
+        std::map<std::string, struct ggml_tensor*> lora_tensor_map;
+        int64_t t0 = ggml_time_ms();
+        // load weights
+        {
+            int n_tensors = 0;
+            size_t total_size = 0;
+
+            while (true) {
+                int32_t n_dims;
+                int32_t length;
+                int32_t ttype;
+
+                file.read(reinterpret_cast<char*>(&n_dims), sizeof(n_dims));
+                file.read(reinterpret_cast<char*>(&length), sizeof(length));
+                file.read(reinterpret_cast<char*>(&ttype), sizeof(ttype));
+
+                if (file.eof()) {
+                    break;
+                }
+
+                int32_t nelements = 1;
+                int32_t ne[4] = {1, 1, 1, 1};
+                for (int i = 0; i < n_dims; ++i) {
+                    file.read(reinterpret_cast<char*>(&ne[i]), sizeof(ne[i]));
+                    nelements *= ne[i];
+                }
+
+                const size_t num_bytes = nelements / ggml_blck_size(ggml_type(ttype)) * ggml_type_size(ggml_type(ttype));
+
+                std::string name_buf(length, 0);
+                file.read(&name_buf[0], length);
+                std::string name = std::string(name_buf.data());
+
+                // LOG_DEBUG("load lora tensor %s", name.c_str());
+
+                int64_t ne64[4] = {ne[0], ne[1], ne[2], ne[3]};
+                struct ggml_tensor* tensor = ggml_new_tensor(lora_params_ctx, (ggml_type)ttype, n_dims, ne64);
+                file.read(reinterpret_cast<char*>(tensor->data), num_bytes);
+
+                lora_tensor_map[name] = tensor;
+
+                total_size += ggml_nbytes(tensor);
+            }
+        }
+        lora_tensors[lora_name] = lora_tensor_map;
+        int64_t t1 = ggml_time_ms();
+        LOG_INFO("lora '%s' params size = %.2fMB",
+                 lora_name.c_str(),
+                 ggml_used_mem(lora_params_ctx) / 1024.0 / 1024.0);
+        LOG_INFO("loading lora from '%s' completed, taking %.2fs", file_path.c_str(), (t1 - t0) * 1.0f / 1000);
+        file.close();
+        return true;
+    }
+
+    void remove_lora_params(const std::string& lora_name) {
+        if (lora_params_ctxs.find(lora_name) == lora_params_ctxs.end()) {
+            return;
+        }
+        ggml_free(lora_params_ctxs[lora_name]);
+        lora_params_ctxs.erase(lora_name);
+        lora_tensors.erase(lora_name);
+    }
+
+    void apply_lora(const std::string& lora_name, float multiplier) {
+        int64_t t0 = ggml_time_ms();
+        if (!load_lora_from_file(lora_name)) {
+            LOG_WARN("apply lora '%s' failed", lora_name.c_str());
+            return;
+        }
+
+        size_t ctx_size = 500 * 1024 * 1024;  // 500MB
+        void* mem_buffer = malloc(ctx_size);
+        if (!mem_buffer) {
+            if (free_params_immediately) {
+                remove_lora_params(lora_name);
+            }
+            LOG_ERROR("malloc() failed");
+            return;
+        }
+
+        std::map<std::string, struct ggml_tensor*>& lora_tensor_map = lora_tensors[lora_name];
+        std::set<std::string> applied_lora_tensors;
+        for (auto& kv : tensors) {
+            const std::string name = kv.first;
+            ggml_tensor* weight = kv.second;
+            std::string ending = ".weight";
+            if (!ends_with(name, ending)) {
+                continue;
+            }
+
+            // find corresponding lora tensors
+            std::string network_name = name.substr(0, name.size() - ending.size());  // remove .weight
+            replace_all_chars(network_name, '.', '_');
+            std::string lora_up_name = network_name + ".lora_up.weight";
+            std::string lora_down_name = network_name + ".lora_down.weight";
+            std::string alpha_name = network_name + ".alpha";
+            std::string scale_name = network_name + ".scale";
+
+            ggml_tensor* lora_up = NULL;
+            ggml_tensor* lora_down = NULL;
+
+            float scale = 1.0f;
+
+            if (lora_tensor_map.find(lora_up_name) != lora_tensor_map.end()) {
+                lora_up = lora_tensor_map[lora_up_name];
+            }
+
+            if (lora_tensor_map.find(lora_down_name) != lora_tensor_map.end()) {
+                lora_down = lora_tensor_map[lora_down_name];
+            }
+
+            if (lora_up == NULL || lora_down == NULL) {
+                continue;
+            }
+
+            // LOG_DEBUG("apply lora tensor %s [%ld %ld %ld %ld]", network_name.c_str(), weight->ne[0], weight->ne[1], weight->ne[2], weight->ne[3]);
+
+            applied_lora_tensors.insert(lora_up_name);
+            applied_lora_tensors.insert(lora_down_name);
+            applied_lora_tensors.insert(alpha_name);
+            applied_lora_tensors.insert(scale_name);
+
+            // calc_scale
+            int64_t dim = lora_down->ne[lora_down->n_dims - 1];
+            if (lora_tensor_map.find(scale_name) != lora_tensor_map.end()) {
+                ggml_tensor* t = lora_tensor_map[scale_name];
+                scale = ggml_get_f32_1d(t, 0);
+            } else if (lora_tensor_map.find(alpha_name) != lora_tensor_map.end()) {
+                ggml_tensor* t = lora_tensor_map[alpha_name];
+                scale = ggml_get_f32_1d(t, 0) / dim;
+            }
+
+            // LOG_DEBUG("scale: %f %ld", scale, dim);
+
+            scale = scale * multiplier;
+
+            // apply
+            {
+                struct ggml_init_params params;
+                params.mem_size = ctx_size;
+                params.mem_buffer = mem_buffer;
+                params.no_alloc = false;
+                params.dynamic = false;
+
+                struct ggml_context* ctx = ggml_init(params);
+                if (!ctx) {
+                    LOG_ERROR("ggml_init() failed");
+                    free(mem_buffer);
+                    if (free_params_immediately) {
+                        remove_lora_params(lora_name);
+                    }
+                    return;
+                }
+
+                ggml_tensor* scale_factor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+                ggml_set_f32_1d(scale_factor, 0, scale);
+                int64_t lora_up_size_0 = lora_up->ne[lora_up->n_dims - 1];
+                lora_up = ggml_reshape_2d(ctx, lora_up, ggml_nelements(lora_up) / lora_up_size_0, lora_up_size_0);
+                int64_t lora_down_size_0 = lora_down->ne[lora_down->n_dims - 1];
+                lora_down = ggml_reshape_2d(ctx, lora_down, ggml_nelements(lora_down) / lora_down_size_0, lora_down_size_0);
+
+                lora_down = ggml_cont(ctx, ggml_transpose(ctx, lora_down));
+
+                if (lora_down->type != GGML_TYPE_F32) {
+                    ggml_tensor* lora_down_f32 = ggml_new_tensor(ctx, GGML_TYPE_F32, lora_down->n_dims, lora_down->ne);
+                    lora_down = ggml_cpy_inplace(ctx, lora_down, lora_down_f32);
+                }
+
+                ggml_tensor* updown = ggml_mul_mat(ctx, lora_up, lora_down);
+                updown = ggml_cont(ctx, ggml_transpose(ctx, updown));
+                updown = ggml_reshape(ctx, updown, weight);
+
+                GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
+
+                updown = ggml_scale_inplace(ctx, updown, scale_factor);
+                ggml_tensor* final_weight;
+                final_weight = ggml_add_inplace(ctx, weight, updown);
+                final_weight = ggml_cpy_inplace(ctx, final_weight, weight);
+
+                struct ggml_cgraph* graph = ggml_build_forward_ctx(ctx, final_weight);
+
+                ggml_graph_compute_with_ctx(ctx, graph, n_threads);
+
+                // LOG_INFO("network_name '%s' ggml_used_mem size = %.2fMB",
+                //          network_name.c_str(),
+                //          ggml_used_mem(ctx) / 1024.0 / 1024.0);
+
+                ggml_free(ctx);
+            }
+        }
+        free(mem_buffer);
+
+        for (auto& kv : lora_tensor_map) {
+            if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) {
+                LOG_WARN("unused lora tensor %s", kv.first.c_str());
+            }
+        }
+
+        if (free_params_immediately) {
+            remove_lora_params(lora_name);
+        }
+
+        int64_t t1 = ggml_time_ms();
+
+        LOG_INFO("apply lora '%s:%f' completed, taking %.2fs",
+                 lora_name.c_str(),
+                 multiplier,
+                 (t1 - t0) * 1.0f / 1000);
+    }
+
+    void apply_loras(const std::unordered_map<std::string, float>& lora_state) {
+        std::unordered_map<std::string, float> lora_state_diff;
+        for (auto& kv : lora_state) {
+            const std::string& lora_name = kv.first;
+            float multiplier = kv.second;
+
+            if (curr_lora_state.find(lora_name) != curr_lora_state.end()) {
+                float curr_multiplier = curr_lora_state[lora_name];
+                float multiplier_diff = multiplier - curr_multiplier;
+                if (multiplier_diff != 0.f) {
+                    lora_state_diff[lora_name] = multiplier_diff;
+                }
+            } else {
+                lora_state_diff[lora_name] = multiplier;
+            }
+        }
+
+        for (auto& kv : lora_state_diff) {
+            apply_lora(kv.first, kv.second);
+        }
+
+        curr_lora_state = lora_state;
+    }
+
     ggml_tensor* get_learned_condition(ggml_context* res_ctx, const std::string& text) {
         auto tokens_and_weights = cond_stage_model.tokenize(text,
                                                             cond_stage_model.text_model.max_position_embeddings,
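For reference, the weight update that `apply_lora` folds into each matched tensor can be written (my notation, reconstructed from the code above) as

$$
W' = W + m \cdot s \cdot (U D), \qquad
s =
\begin{cases}
\text{scale} & \text{if a } \texttt{.scale} \text{ tensor is present} \\
\alpha / d & \text{if an } \texttt{.alpha} \text{ tensor is present} \\
1 & \text{otherwise,}
\end{cases}
$$

where $U$ and $D$ are `lora_up` and `lora_down` reshaped to 2-D, $m$ is the prompt multiplier, and $d$ is the LoRA rank (the last dimension of `lora_down`). Because the update is linear in $m$, `apply_loras` only needs to apply the difference between the requested and the currently applied multiplier:

$$
(W + m\,s\,UD) + (m' - m)\,s\,UD = W + m'\,s\,UD,
$$

so changing a multiplier between generations does not require reloading the base weights.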
@@ -4235,10 +4596,12 @@ class StableDiffusionGGML {
 StableDiffusion::StableDiffusion(int n_threads,
                                  bool vae_decode_only,
                                  bool free_params_immediately,
+                                 std::string lora_model_dir,
                                  RNGType rng_type) {
     sd = std::make_shared<StableDiffusionGGML>(n_threads,
                                                vae_decode_only,
                                                free_params_immediately,
+                                               lora_model_dir,
                                                rng_type);
 }
 
@@ -4246,8 +4609,8 @@ bool StableDiffusion::load_from_file(const std::string& file_path, Schedule s) {
     return sd->load_from_file(file_path, s);
 }
 
-std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
-                                              const std::string& negative_prompt,
+std::vector<uint8_t> StableDiffusion::txt2img(std::string prompt,
+                                              std::string negative_prompt,
                                               float cfg_scale,
                                               int width,
                                               int height,
@@ -4272,13 +4635,28 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
     }
     sd->rng->manual_seed(seed);
 
+    // extract and remove lora
+    auto result_pair = extract_and_remove_lora(prompt);
+    std::unordered_map<std::string, float> lora_f2m = result_pair.first;  // lora_name -> multiplier
+    for (auto& kv : lora_f2m) {
+        LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
+    }
+    prompt = result_pair.second;
+    LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
+
+    // load lora from file
     int64_t t0 = ggml_time_ms();
+    sd->apply_loras(lora_f2m);
+    int64_t t1 = ggml_time_ms();
+    LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
+
+    t0 = ggml_time_ms();
     ggml_tensor* c = sd->get_learned_condition(ctx, prompt);
     struct ggml_tensor* uc = NULL;
     if (cfg_scale != 1.0) {
         uc = sd->get_learned_condition(ctx, negative_prompt);
     }
-    int64_t t1 = ggml_time_ms();
+    t1 = ggml_time_ms();
     LOG_INFO("get_learned_condition completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
 
     if (sd->free_params_immediately) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_img_vec,
|
std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_img_vec,
|
||||||
const std::string& prompt,
|
std::string prompt,
|
||||||
const std::string& negative_prompt,
|
std::string negative_prompt,
|
||||||
float cfg_scale,
|
float cfg_scale,
|
||||||
int width,
|
int width,
|
||||||
int height,
|
int height,
|
||||||
@@ -4372,14 +4750,29 @@ std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_i
     }
     sd->rng->manual_seed(seed);
 
+    // extract and remove lora
+    auto result_pair = extract_and_remove_lora(prompt);
+    std::unordered_map<std::string, float> lora_f2m = result_pair.first;  // lora_name -> multiplier
+    for (auto& kv : lora_f2m) {
+        LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
+    }
+    prompt = result_pair.second;
+    LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
+
+    // load lora from file
+    int64_t t0 = ggml_time_ms();
+    sd->apply_loras(lora_f2m);
+    int64_t t1 = ggml_time_ms();
+    LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
+
     ggml_tensor* init_img = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, width, height, 3, 1);
     image_vec_to_ggml(init_img_vec, init_img);
 
-    int64_t t0 = ggml_time_ms();
+    t0 = ggml_time_ms();
     ggml_tensor* moments = sd->encode_first_stage(ctx, init_img);
     ggml_tensor* init_latent = sd->get_first_stage_encoding(ctx, moments);
     // print_ggml_tensor(init_latent);
-    int64_t t1 = ggml_time_ms();
+    t1 = ggml_time_ms();
     LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
 
     ggml_reset_curr_max_dynamic_size();  // reset counter
@@ -45,11 +45,12 @@ class StableDiffusion {
     StableDiffusion(int n_threads = -1,
                     bool vae_decode_only = false,
                     bool free_params_immediately = false,
+                    std::string lora_model_dir = "",
                     RNGType rng_type = STD_DEFAULT_RNG);
     bool load_from_file(const std::string& file_path, Schedule d = DEFAULT);
     std::vector<uint8_t> txt2img(
-        const std::string& prompt,
-        const std::string& negative_prompt,
+        std::string prompt,
+        std::string negative_prompt,
         float cfg_scale,
         int width,
         int height,
@@ -58,8 +59,8 @@ class StableDiffusion {
         int64_t seed);
     std::vector<uint8_t> img2img(
         const std::vector<uint8_t>& init_img,
-        const std::string& prompt,
-        const std::string& negative_prompt,
+        std::string prompt,
+        std::string negative_prompt,
         float cfg_scale,
         int width,
         int height,