feat: add SD2.x support (#40)

commit 31e77e1573 (parent c542a77a3f)
@@ -14,6 +14,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - Accelerated memory-efficient CPU inference
 - Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image
 - AVX, AVX2 and AVX512 support for x86 architectures
+- SD1.x and SD2.x support
 - Original `txt2img` and `img2img` mode
 - Negative prompt
 - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
@@ -60,10 +61,12 @@ git submodule update
 - download original weights (.ckpt or .safetensors). For example
   - Stable Diffusion v1.4 from https://huggingface.co/CompVis/stable-diffusion-v-1-4-original
   - Stable Diffusion v1.5 from https://huggingface.co/runwayml/stable-diffusion-v1-5
+  - Stable Diffusion v2.1 from https://huggingface.co/stabilityai/stable-diffusion-2-1
 
     ```shell
     curl -L -O https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
     # curl -L -O https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors
+    # curl -L -O https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-nonema-pruned.safetensors
    ```
 
 - convert weights to ggml model format
@@ -182,5 +185,6 @@ docker run -v /path/to/models:/models -v /path/to/output/:/output sd [args...]
 
 - [ggml](https://github.com/ggerganov/ggml)
 - [stable-diffusion](https://github.com/CompVis/stable-diffusion)
+- [stable-diffusion-stability-ai](https://github.com/Stability-AI/stablediffusion)
 - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
 - [k-diffusion](https://github.com/crowsonkb/k-diffusion)
@@ -9,6 +9,9 @@ import safetensors.torch
 this_file_dir = os.path.dirname(__file__)
 vocab_dir = this_file_dir
 
+SD1 = 0
+SD2 = 1
+
 ggml_ftype_str_to_int = {
     "f32": 0,
     "f16": 1,
@@ -155,6 +158,8 @@ unused_tensors = [
     "posterior_mean_coef1",
     "posterior_mean_coef2",
     "cond_stage_model.transformer.text_model.embeddings.position_ids",
+    "cond_stage_model.model.logit_scale",
+    "cond_stage_model.model.text_projection",
     "model_ema.decay",
     "model_ema.num_updates",
     "control_model",
@@ -162,12 +167,8 @@ unused_tensors = [
     "embedding_manager"
 ]
 
-def convert(model_path, out_type = None, out_file=None):
-    # load model
-    with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
-        clip_vocab = json.load(f)
-
-    state_dict = load_model_from_file(model_path)
+def preprocess(state_dict):
     alphas_cumprod = state_dict.get("alphas_cumprod")
     if alphas_cumprod != None:
         # print((np.abs(get_alpha_comprod().numpy() - alphas_cumprod.numpy()) < 0.000001).all())
@@ -177,10 +178,99 @@ def convert(model_path, out_type = None, out_file=None):
         alphas_cumprod = get_alpha_comprod()
     state_dict["alphas_cumprod"] = alphas_cumprod
 
+    new_state_dict = {}
+    for name in state_dict.keys():
+        # ignore unused tensors
+        if not isinstance(state_dict[name], torch.Tensor):
+            continue
+        skip = False
+        for unused_tensor in unused_tensors:
+            if name.startswith(unused_tensor):
+                skip = True
+                break
+        if skip:
+            continue
+
+        # convert open_clip to hf CLIPTextModel (for SD2.x)
+        open_clip_to_hf_clip_model = {
+            "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
+            "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
+            "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
+            "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
+        }
+        open_clip_to_hf_clip_resblock = {
+            "attn.out_proj.bias": "self_attn.out_proj.bias",
+            "attn.out_proj.weight": "self_attn.out_proj.weight",
+            "ln_1.bias": "layer_norm1.bias",
+            "ln_1.weight": "layer_norm1.weight",
+            "ln_2.bias": "layer_norm2.bias",
+            "ln_2.weight": "layer_norm2.weight",
+            "mlp.c_fc.bias": "mlp.fc1.bias",
+            "mlp.c_fc.weight": "mlp.fc1.weight",
+            "mlp.c_proj.bias": "mlp.fc2.bias",
+            "mlp.c_proj.weight": "mlp.fc2.weight",
+        }
+        open_clip_resblock_prefix = "cond_stage_model.model.transformer.resblocks."
+        hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers."
+        if name in open_clip_to_hf_clip_model:
+            new_name = open_clip_to_hf_clip_model[name]
+            new_state_dict[new_name] = state_dict[name]
+            print(f"preprocess {name} => {new_name}")
+            continue
+        if name.startswith(open_clip_resblock_prefix):
+            remain = name[len(open_clip_resblock_prefix):]
+            idx = remain.split(".")[0]
+            suffix = remain[len(idx)+1:]
+            if suffix == "attn.in_proj_weight":
+                w = state_dict[name]
+                w_q, w_k, w_v = w.chunk(3)
+                for new_suffix, new_w in zip(["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], [w_q, w_k, w_v]):
+                    new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
+                    new_state_dict[new_name] = new_w
+                    print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}")
+            elif suffix == "attn.in_proj_bias":
+                w = state_dict[name]
+                w_q, w_k, w_v = w.chunk(3)
+                for new_suffix, new_w in zip(["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"], [w_q, w_k, w_v]):
+                    new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
+                    new_state_dict[new_name] = new_w
+                    print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}")
+            else:
+                new_suffix = open_clip_to_hf_clip_resblock[suffix]
+                new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
+                new_state_dict[new_name] = state_dict[name]
+                print(f"preprocess {name} => {new_name}")
+            continue
+
+        # convert unet transformer linear to conv2d 1x1
+        if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
+            w = state_dict[name]
+            if len(state_dict[name].shape) == 2:
+                new_w = w.unsqueeze(2).unsqueeze(3)
+                new_state_dict[name] = new_w
+                print(f"preprocess {name} {w.size()} => {name} {new_w.size()}")
+                continue
+
+        new_state_dict[name] = state_dict[name]
+    return new_state_dict
+
+
+def convert(model_path, out_type = None, out_file=None):
+    # load model
+    with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
+        clip_vocab = json.load(f)
+
+    state_dict = load_model_from_file(model_path)
+    model_type = SD1
+    if "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
+        model_type = SD2
+        print("Stable Diffusion 2.x")
+    else:
+        print("Stable Diffusion 1.x")
+    state_dict = preprocess(state_dict)
 
     # output option
     if out_type == None:
-        weight = state_dict["cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight"].numpy()
+        weight = state_dict["model.diffusion_model.input_blocks.0.0.weight"].numpy()
         if weight.dtype == np.float32:
             out_type = "f32"
         elif weight.dtype == np.float16:
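For context, the key remapping in the hunk above: OpenCLIP checkpoints store each attention block's q/k/v as one fused `in_proj` tensor, while the HF `CLIPTextModel` layout the loader already understands expects three separate projections. A standalone torch sketch of the split (toy hidden size and names of my choosing, not code from this patch):

```python
import torch

hidden_size = 8  # toy value; the real SD2.x text encoder uses 1024
in_proj_weight = torch.randn(3 * hidden_size, hidden_size)  # fused q/k/v rows

# chunk(3) splits along dim 0, recovering the three projection matrices
w_q, w_k, w_v = in_proj_weight.chunk(3)
assert w_q.shape == (hidden_size, hidden_size)
assert torch.equal(w_k, in_proj_weight[hidden_size:2 * hidden_size])
```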
@@ -198,8 +288,9 @@ def convert(model_path, out_type = None, out_file=None):
     with open(out_file, "wb") as file:
         # magic: ggml in hex
         file.write(struct.pack("i", 0x67676D6C))
-        # out type
-        file.write(struct.pack("i", ggml_ftype_str_to_int[out_type]))
+        # model & file type
+        ftype = (model_type << 16) | ggml_ftype_str_to_int[out_type]
+        file.write(struct.pack("i", ftype))
 
         # vocab
         byte_encoder = bytes_to_unicode()
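The header change above packs two values into one 32-bit field: model type in the high 16 bits, ggml file type in the low 16 bits, which keeps old SD1 files valid (SD1 = 0 leaves the field unchanged). A minimal sketch of the encode/decode round trip (helper names are illustrative):

```python
import struct

SD1, SD2 = 0, 1
FTYPE_F32, FTYPE_F16 = 0, 1

def pack_header(model_type: int, ggml_ftype: int) -> bytes:
    # high 16 bits: model type; low 16 bits: ggml file type
    return struct.pack("i", (model_type << 16) | ggml_ftype)

def unpack_header(raw: bytes):
    value = struct.unpack("i", raw)[0]
    return (value >> 16) & 0xFFFF, value & 0xFFFF

assert unpack_header(pack_header(SD2, FTYPE_F16)) == (SD2, FTYPE_F16)
```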
@@ -48,6 +48,16 @@ static SDLogLevel log_level = SDLogLevel::INFO;
 
 #define TIMESTEPS 1000
 
+enum ModelType {
+    SD1 = 0,
+    SD2 = 1,
+    MODEL_TYPE_COUNT,
+};
+
+const char* model_type_to_str[] = {
+    "SD1.x",
+    "SD2.x"};
+
 /*================================================== Helper Functions ================================================*/
 
 void set_sd_log_level(SDLogLevel level) {
@@ -257,8 +267,8 @@ void image_vec_to_ggml(const std::vector<uint8_t>& vec,
     }
 }
 
-struct ggml_tensor * ggml_group_norm_32(struct ggml_context * ctx,
-                                        struct ggml_tensor * a) {
+struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
+                                       struct ggml_tensor* a) {
     return ggml_group_norm(ctx, a, 32);
 }
 
@@ -278,6 +288,7 @@ const int PAD_TOKEN_ID = 49407;
 // TODO: implement bpe
 class CLIPTokenizer {
    private:
+    ModelType model_type = SD1;
     std::map<std::string, int32_t> encoder;
     std::regex pat;
 
@@ -300,7 +311,8 @@ class CLIPTokenizer {
     }
 
    public:
-    CLIPTokenizer() = default;
+    CLIPTokenizer(ModelType model_type = SD1)
+        : model_type(model_type){};
     std::string bpe(std::string token) {
         std::string word = token + "</w>";
         if (encoder.find(word) != encoder.end()) {
@@ -321,13 +333,18 @@ class CLIPTokenizer {
         if (max_length > 0) {
             if (tokens.size() > max_length - 1) {
                 tokens.resize(max_length - 1);
+                tokens.push_back(EOS_TOKEN_ID);
             } else {
+                tokens.push_back(EOS_TOKEN_ID);
                 if (padding) {
-                    tokens.insert(tokens.end(), max_length - 1 - tokens.size(), PAD_TOKEN_ID);
+                    int pad_token_id = PAD_TOKEN_ID;
+                    if (model_type == SD2) {
+                        pad_token_id = 0;
+                    }
+                    tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id);
                 }
             }
         }
-        tokens.push_back(EOS_TOKEN_ID);
         return tokens;
     }
 
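The behavioral change here: EOS is now appended before padding rather than after, and SD2.x pads with token id 0 instead of the SD1.x pad token 49407. A hedged Python sketch of the resulting layout (the token ids are the real CLIP ones; the helper itself is made up):

```python
BOS, EOS, SD1_PAD, SD2_PAD = 49406, 49407, 49407, 0

def pad_tokens(tokens, max_length=77, sd2=False):
    # tokens already starts with BOS; truncate, close with EOS, then pad
    tokens = tokens[:max_length - 1]
    tokens.append(EOS)
    pad_id = SD2_PAD if sd2 else SD1_PAD
    tokens += [pad_id] * (max_length - len(tokens))
    return tokens

print(pad_tokens([BOS, 320, 1929], sd2=True)[:6])  # [49406, 320, 1929, 49407, 0, 0]
```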
@@ -635,7 +652,11 @@ struct ResidualAttentionBlock {
         x = ggml_mul_mat(ctx, fc1_w, x);
         x = ggml_add(ctx, ggml_repeat(ctx, fc1_b, x), x);
 
-        x = ggml_gelu_quick_inplace(ctx, x);
+        if (hidden_size == 1024) {  // SD 2.x
+            x = ggml_gelu_inplace(ctx, x);
+        } else {  // SD 1.x
+            x = ggml_gelu_quick_inplace(ctx, x);
+        }
 
         x = ggml_mul_mat(ctx, fc2_w, x);
         x = ggml_add(ctx, ggml_repeat(ctx, fc2_b, x), x);
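Background for this switch: SD1.x's CLIP uses the QuickGELU approximation `x * sigmoid(1.702 * x)`, while the OpenCLIP encoder shipped with SD2.x uses the standard GELU (ggml's `ggml_gelu` is, to my knowledge, the tanh approximation of it), so the two families need different activations; the hunk keys off `hidden_size == 1024` as an SD2.x marker. A numpy sketch of the difference, as my own illustration rather than code from this patch:

```python
import numpy as np

def gelu(x):
    # tanh approximation of the standard GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def gelu_quick(x):
    # QuickGELU, the sigmoid approximation used by SD1.x's CLIP
    return x / (1.0 + np.exp(-1.702 * x))

x = np.linspace(-3.0, 3.0, 7)
print(np.max(np.abs(gelu(x) - gelu_quick(x))))  # small but nonzero difference
```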
@@ -647,26 +668,40 @@ struct ResidualAttentionBlock {
     }
 };
 
+// SD1.x: https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
+// SD2.x: https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/config.json
 struct CLIPTextModel {
+    ModelType model_type = SD1;
     // network hparams
     int32_t vocab_size = 49408;
     int32_t max_position_embeddings = 77;
-    int32_t hidden_size = 768;
-    int32_t intermediate_size = 3072;
-    int32_t projection_dim = 768;
-    int32_t n_head = 12;  // num_attention_heads
-    int32_t num_hidden_layers = 12;
+    int32_t hidden_size = 768;         // 1024 for SD 2.x
+    int32_t intermediate_size = 3072;  // 4096 for SD 2.x
+    int32_t n_head = 12;               // num_attention_heads, 16 for SD 2.x
+    int32_t num_hidden_layers = 12;    // 24 for SD 2.x
 
     // embeddings
     struct ggml_tensor* position_ids;
     struct ggml_tensor* token_embed_weight;
     struct ggml_tensor* position_embed_weight;
     // transformer
-    ResidualAttentionBlock resblocks[12];
+    std::vector<ResidualAttentionBlock> resblocks;
     struct ggml_tensor* final_ln_w;
     struct ggml_tensor* final_ln_b;
 
-    CLIPTextModel() {
+    CLIPTextModel(ModelType model_type = SD1)
+        : model_type(model_type) {
+        if (model_type == SD2) {
+            hidden_size = 1024;
+            intermediate_size = 4096;
+            n_head = 16;
+            num_hidden_layers = 24;
+        }
+        resblocks.resize(num_hidden_layers);
+        set_resblocks_hp_params();
+    }
+
+    void set_resblocks_hp_params() {
         int d_model = hidden_size / n_head;  // 64
         for (int i = 0; i < num_hidden_layers; i++) {
             resblocks[i].d_model = d_model;
@@ -729,6 +764,9 @@ struct CLIPTextModel {
 
         // transformer
         for (int i = 0; i < num_hidden_layers; i++) {
+            if (model_type == SD2 && i == num_hidden_layers - 1) {  // layer: "penultimate"
+                break;
+            }
             x = resblocks[i].forward(ctx, x);  // [N, n_token, hidden_size]
         }
 
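Two SD2.x quirks are encoded in these hunks: the larger OpenCLIP ViT-H text-encoder hyperparameters, and the fact that SD2.x conditions on the penultimate transformer layer, so the loop stops one resblock early. A compact Python summary of the assumed configuration (values taken from the config.json files linked above):

```python
CLIP_HPARAMS = {
    "SD1.x": dict(hidden_size=768, intermediate_size=3072, n_head=12,
                  num_hidden_layers=12, cond_layer="last"),
    "SD2.x": dict(hidden_size=1024, intermediate_size=4096, n_head=16,
                  num_hidden_layers=24, cond_layer="penultimate"),
}

def layers_to_run(model):
    hp = CLIP_HPARAMS[model]
    return hp["num_hidden_layers"] - (1 if hp["cond_layer"] == "penultimate" else 0)

print(layers_to_run("SD2.x"))  # 23: the last of 24 resblocks is skipped
```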
@@ -759,9 +797,13 @@ struct FrozenCLIPEmbedder {
 
 // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
 struct FrozenCLIPEmbedderWithCustomWords {
+    ModelType model_type = SD1;
     CLIPTokenizer tokenizer;
     CLIPTextModel text_model;
 
+    FrozenCLIPEmbedderWithCustomWords(ModelType model_type = SD1)
+        : model_type(model_type), tokenizer(model_type), text_model(model_type) {}
+
     std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
                                                              size_t max_length = 0,
                                                              bool padding = false) {
@@ -793,15 +835,21 @@ struct FrozenCLIPEmbedderWithCustomWords {
             if (tokens.size() > max_length - 1) {
                 tokens.resize(max_length - 1);
                 weights.resize(max_length - 1);
+                tokens.push_back(EOS_TOKEN_ID);
+                weights.push_back(1.0);
             } else {
+                tokens.push_back(EOS_TOKEN_ID);
+                weights.push_back(1.0);
                 if (padding) {
-                    tokens.insert(tokens.end(), max_length - 1 - tokens.size(), PAD_TOKEN_ID);
-                    weights.insert(weights.end(), max_length - 1 - weights.size(), 1.0);
+                    int pad_token_id = PAD_TOKEN_ID;
+                    if (model_type == SD2) {
+                        pad_token_id = 0;
+                    }
+                    tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id);
+                    weights.insert(weights.end(), max_length - weights.size(), 1.0);
                 }
             }
         }
-        tokens.push_back(EOS_TOKEN_ID);
-        weights.push_back(1.0);
 
         // for (int i = 0; i < tokens.size(); i++) {
         //     std::cout << tokens[i] << ":" << weights[i] << ", ";
@@ -974,7 +1022,7 @@ struct SpatialTransformer {
     int n_head;             // num_heads
     int d_head;             // in_channels // n_heads
     int depth = 1;          // 1
-    int context_dim = 768;  // hidden_size
+    int context_dim = 768;  // hidden_size, 1024 for SD2.x
 
     // group norm
     struct ggml_tensor* norm_w;  // [in_channels,]
@@ -1459,6 +1507,7 @@ struct UNetModel {
     int time_embed_dim = 1280;   // model_channels*4
     int num_heads = 8;
     int num_head_channels = -1;  // channels // num_heads
+    int context_dim = 768;       // 1024 for SD2.x
 
     // network params
     struct ggml_tensor* time_embed_0_w;  // [time_embed_dim, model_channels]
@@ -1493,7 +1542,12 @@ struct UNetModel {
     struct ggml_tensor* out_2_w;  // [out_channels, model_channels, 3, 3]
     struct ggml_tensor* out_2_b;  // [out_channels, ]
 
-    UNetModel() {
+    UNetModel(ModelType model_type = SD1) {
+        if (model_type == SD2) {
+            context_dim = 1024;
+            num_head_channels = 64;
+            num_heads = -1;
+        }
         // set up hparams of blocks
 
         // input_blocks
@@ -1513,9 +1567,16 @@ struct UNetModel {
                 ch = mult * model_channels;
 
                 if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) {
+                    int n_head = num_heads;
+                    int d_head = ch / num_heads;
+                    if (num_head_channels != -1) {
+                        d_head = num_head_channels;
+                        n_head = ch / d_head;
+                    }
                     input_transformers[i][j].in_channels = ch;
-                    input_transformers[i][j].n_head = num_heads;
-                    input_transformers[i][j].d_head = ch / num_heads;
+                    input_transformers[i][j].n_head = n_head;
+                    input_transformers[i][j].d_head = d_head;
+                    input_transformers[i][j].context_dim = context_dim;
                 }
                 input_block_chans.push_back(ch);
             }
@@ -1533,9 +1594,16 @@ struct UNetModel {
         middle_block_0.emb_channels = time_embed_dim;
         middle_block_0.out_channels = ch;
 
+        int n_head = num_heads;
+        int d_head = ch / num_heads;
+        if (num_head_channels != -1) {
+            d_head = num_head_channels;
+            n_head = ch / d_head;
+        }
         middle_block_1.in_channels = ch;
-        middle_block_1.n_head = num_heads;
-        middle_block_1.d_head = ch / num_heads;
+        middle_block_1.n_head = n_head;
+        middle_block_1.d_head = d_head;
+        middle_block_1.context_dim = context_dim;
 
         middle_block_2.channels = ch;
         middle_block_2.emb_channels = time_embed_dim;
@@ -1555,9 +1623,16 @@ struct UNetModel {
                 ch = mult * model_channels;
 
                 if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) {
+                    int n_head = num_heads;
+                    int d_head = ch / num_heads;
+                    if (num_head_channels != -1) {
+                        d_head = num_head_channels;
+                        n_head = ch / d_head;
+                    }
                     output_transformers[i][j].in_channels = ch;
-                    output_transformers[i][j].n_head = num_heads;
-                    output_transformers[i][j].d_head = ch / num_heads;
+                    output_transformers[i][j].n_head = n_head;
+                    output_transformers[i][j].d_head = d_head;
+                    output_transformers[i][j].context_dim = context_dim;
                 }
 
                 if (i > 0 && j == num_res_blocks) {
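The block repeated across these three hunks implements the upstream `num_head_channels` convention: SD1.x fixes the head count (`num_heads = 8`), while SD2.x fixes the per-head width (`num_head_channels = 64`) and derives the count from the channel width. A small Python restatement (illustrative only):

```python
def attention_heads(ch, num_heads=8, num_head_channels=-1):
    # returns (n_head, d_head); SD2.x derives the count from a fixed head width
    if num_head_channels != -1:
        return ch // num_head_channels, num_head_channels
    return num_heads, ch // num_heads

print(attention_heads(640))                        # SD1.x: (8, 80)
print(attention_heads(640, num_head_channels=64))  # SD2.x: (10, 64)
```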
@@ -2584,7 +2659,8 @@ struct AutoEncoderKL {
 /*================================================= CompVisDenoiser ==================================================*/
 
 // Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py
-struct CompVisDenoiser {
+struct DiscreteSchedule {
     float alphas_cumprod[TIMESTEPS];
     float sigmas[TIMESTEPS];
     float log_sigmas[TIMESTEPS];
@@ -2602,12 +2678,6 @@ struct CompVisDenoiser {
         return result;
     }
 
-    std::pair<float, float> get_scalings(float sigma) {
-        float c_out = -sigma;
-        float c_in = 1.0f / std::sqrt(sigma * sigma + 1);
-        return std::pair<float, float>(c_in, c_out);
-    }
-
     float sigma_to_t(float sigma) {
         float log_sigma = std::log(sigma);
         std::vector<float> dists;
@@ -2641,6 +2711,29 @@ struct CompVisDenoiser {
         float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
         return std::exp(log_sigma);
     }
+
+    virtual std::vector<float> get_scalings(float sigma) = 0;
+};
+
+struct CompVisDenoiser : public DiscreteSchedule {
+    float sigma_data = 1.0f;
+
+    std::vector<float> get_scalings(float sigma) {
+        float c_out = -sigma;
+        float c_in = 1.0f / std::sqrt(sigma * sigma + sigma_data * sigma_data);
+        return {c_out, c_in};
+    }
+};
+
+struct CompVisVDenoiser : public DiscreteSchedule {
+    float sigma_data = 1.0f;
+
+    std::vector<float> get_scalings(float sigma) {
+        float c_skip = sigma_data * sigma_data / (sigma * sigma + sigma_data * sigma_data);
+        float c_out = -sigma * sigma_data / std::sqrt(sigma * sigma + sigma_data * sigma_data);
+        float c_in = 1.0f / std::sqrt(sigma * sigma + sigma_data * sigma_data);
+        return {c_skip, c_out, c_in};
+    }
 };
 
 /*=============================================== StableDiffusionGGML ================================================*/
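These scalings mirror k-diffusion's `CompVisDenoiser` (eps-prediction) and `CompVisVDenoiser` (v-prediction). In both cases the model sees `x * c_in` and the denoised estimate is `model_out * c_out + x * c_skip`, with `c_skip = 1` in the eps case. A numpy sketch of the two parameterizations, as a standalone illustration under those assumptions:

```python
import numpy as np

def eps_scalings(sigma, sigma_data=1.0):
    # (c_skip, c_out, c_in) for eps-prediction (SD1.x / non-v SD2.x)
    return 1.0, -sigma, 1.0 / np.sqrt(sigma**2 + sigma_data**2)

def v_scalings(sigma, sigma_data=1.0):
    # (c_skip, c_out, c_in) for v-prediction (SD2.x 768-v checkpoints)
    denom = sigma**2 + sigma_data**2
    return sigma_data**2 / denom, -sigma * sigma_data / np.sqrt(denom), 1.0 / np.sqrt(denom)

def denoised(x, model_out, scalings):
    c_skip, c_out, c_in = scalings  # caller feeds the model x * c_in
    return model_out * c_out + x * c_skip

print(eps_scalings(1.0), v_scalings(1.0))
```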
@@ -2666,7 +2759,7 @@ class StableDiffusionGGML {
     UNetModel diffusion_model;
     AutoEncoderKL first_stage_model;
 
-    CompVisDenoiser denoiser;
+    std::shared_ptr<DiscreteSchedule> denoiser = std::make_shared<CompVisDenoiser>();
 
     StableDiffusionGGML() = default;
 
@@ -2717,9 +2810,20 @@ class StableDiffusionGGML {
         LOG_DEBUG("loading hparams");
         // load hparams
         file.read(reinterpret_cast<char*>(&ftype), sizeof(ftype));
-        // for the big tensors, we have the option to store the data in 16-bit floats or quantized
-        // in order to save memory and also to speed up the computation
-        ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype)(ftype));
+
+        int model_type = (ftype >> 16) & 0xFFFF;
+        if (model_type >= MODEL_TYPE_COUNT) {
+            LOG_ERROR("invalid model file '%s' (bad model type value %d)", file_path.c_str(), ftype);
+            return false;
+        }
+        LOG_INFO("model type: %s", model_type_to_str[model_type]);
+
+        if (model_type == SD2) {
+            cond_stage_model = FrozenCLIPEmbedderWithCustomWords((ModelType)model_type);
+            diffusion_model = UNetModel((ModelType)model_type);
+        }
+
+        ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype)(ftype & 0xFFFF));
         LOG_INFO("ftype: %s", ggml_type_name(wtype));
         if (wtype == GGML_TYPE_COUNT) {
             LOG_ERROR("invalid model file '%s' (bad ftype value %d)", file_path.c_str(), ftype);
@@ -2840,6 +2944,7 @@ class StableDiffusionGGML {
         std::set<std::string> tensor_names_in_file;
         int64_t t0 = ggml_time_ms();
         // load weights
+        float alphas_cumprod[TIMESTEPS];
         {
             int n_tensors = 0;
             size_t total_size = 0;
@@ -2872,12 +2977,7 @@ class StableDiffusionGGML {
                 tensor_names_in_file.insert(std::string(name.data()));
 
                 if (std::string(name.data()) == "alphas_cumprod") {
-                    file.read(reinterpret_cast<char*>(denoiser.alphas_cumprod),
-                              nelements * ggml_type_size((ggml_type)ttype));
-                    for (int i = 0; i < 1000; i++) {
-                        denoiser.sigmas[i] = std::sqrt((1 - denoiser.alphas_cumprod[i]) / denoiser.alphas_cumprod[i]);
-                        denoiser.log_sigmas[i] = std::log(denoiser.sigmas[i]);
-                    }
+                    file.read(reinterpret_cast<char*>(alphas_cumprod), nelements * ggml_type_size((ggml_type)ttype));
                     continue;
                 }
 
@@ -2953,9 +3053,143 @@ class StableDiffusionGGML {
         int64_t t1 = ggml_time_ms();
         LOG_INFO("loading model from '%s' completed, taking %.2fs", file_path.c_str(), (t1 - t0) * 1.0f / 1000);
         file.close();
+
+        // check is_using_v_parameterization_for_sd2
+        bool is_using_v_parameterization = false;
+        if (model_type == SD2) {
+            struct ggml_init_params params;
+            params.mem_size = static_cast<size_t>(10 * 1024) * 1024;  // 10M
+            params.mem_buffer = NULL;
+            params.no_alloc = false;
+            params.dynamic = false;
+            struct ggml_context* ctx = ggml_init(params);
+            if (!ctx) {
+                LOG_ERROR("ggml_init() failed");
+                return false;
+            }
+            if (is_using_v_parameterization_for_sd2(ctx)) {
+                is_using_v_parameterization = true;
+            }
+        }
+
+        if (is_using_v_parameterization) {
+            denoiser = std::make_shared<CompVisVDenoiser>();
+            LOG_INFO("running in v-prediction mode");
+        } else {
+            LOG_INFO("running in eps-prediction mode");
+        }
+
+        for (int i = 0; i < TIMESTEPS; i++) {
+            denoiser->alphas_cumprod[i] = alphas_cumprod[i];
+            denoiser->sigmas[i] = std::sqrt((1 - denoiser->alphas_cumprod[i]) / denoiser->alphas_cumprod[i]);
+            denoiser->log_sigmas[i] = std::log(denoiser->sigmas[i]);
+        }
+
         return true;
     }
 
+    bool is_using_v_parameterization_for_sd2(ggml_context* res_ctx) {
+        struct ggml_tensor* x_t = ggml_new_tensor_4d(res_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
+        ggml_set_f32(x_t, 0.5);
+        struct ggml_tensor* c = ggml_new_tensor_4d(res_ctx, GGML_TYPE_F32, 1024, 2, 1, 1);
+        ggml_set_f32(c, 0.5);
+
+        size_t ctx_size = 1 * 1024 * 1024;  // 1MB
+        // calculate the amount of memory required
+        {
+            struct ggml_init_params params;
+            params.mem_size = ctx_size;
+            params.mem_buffer = NULL;
+            params.no_alloc = true;
+            params.dynamic = dynamic;
+
+            struct ggml_context* ctx = ggml_init(params);
+            if (!ctx) {
+                LOG_ERROR("ggml_init() failed");
+                return false;
+            }
+
+            ggml_set_dynamic(ctx, false);
+            struct ggml_tensor* timesteps = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);  // [N, ]
+            struct ggml_tensor* t_emb = new_timestep_embedding(ctx, timesteps, diffusion_model.model_channels);  // [N, model_channels]
+            ggml_set_dynamic(ctx, params.dynamic);
+
+            struct ggml_tensor* out = diffusion_model.forward(ctx, x_t, NULL, c, t_emb);
+            ctx_size += ggml_used_mem(ctx) + ggml_used_mem_of_data(ctx);
+
+            struct ggml_cgraph diffusion_graph = ggml_build_forward(out);
+            struct ggml_cplan cplan = ggml_graph_plan(&diffusion_graph, n_threads);
+
+            ctx_size += cplan.work_size;
+            LOG_DEBUG("diffusion context need %.2fMB static memory, with work_size needing %.2fMB",
+                      ctx_size * 1.0f / 1024 / 1024,
+                      cplan.work_size * 1.0f / 1024 / 1024);
+
+            ggml_free(ctx);
+        }
+
+        struct ggml_init_params params;
+        params.mem_size = ctx_size;
+        params.mem_buffer = NULL;
+        params.no_alloc = false;
+        params.dynamic = dynamic;
+
+        struct ggml_context* ctx = ggml_init(params);
+        if (!ctx) {
+            LOG_ERROR("ggml_init() failed");
+            return false;
+        }
+
+        ggml_set_dynamic(ctx, false);
+        struct ggml_tensor* timesteps = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);  // [N, ]
+        struct ggml_tensor* t_emb = new_timestep_embedding(ctx, timesteps, diffusion_model.model_channels);  // [N, model_channels]
+        ggml_set_dynamic(ctx, params.dynamic);
+        ggml_set_f32(timesteps, 999);
+        set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);
+
+        struct ggml_tensor* out = diffusion_model.forward(ctx, x_t, NULL, c, t_emb);
+        ggml_hold_dynamic_tensor(out);
+
+        struct ggml_cgraph diffusion_graph = ggml_build_forward(out);
+        struct ggml_cplan cplan = ggml_graph_plan(&diffusion_graph, n_threads);
+
+        ggml_set_dynamic(ctx, false);
+        struct ggml_tensor* buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
+        ggml_set_dynamic(ctx, params.dynamic);
+
+        cplan.work_data = (uint8_t*)buf->data;
+
+        int64_t t0 = ggml_time_ms();
+        ggml_graph_compute(&diffusion_graph, &cplan);
+
+        double result = 0.f;
+
+        {
+            float* vec_x = (float*)x_t->data;
+            float* vec_out = (float*)out->data;
+
+            int64_t n = ggml_nelements(out);
+
+            for (int i = 0; i < n; i++) {
+                result += ((double)vec_out[i] - (double)vec_x[i]);
+            }
+            result /= n;
+        }
+
+#ifdef GGML_PERF
+        ggml_graph_print(&diffusion_graph);
+#endif
+        int64_t t1 = ggml_time_ms();
+        LOG_INFO("check is_using_v_parameterization_for_sd2 completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
+        LOG_DEBUG("diffusion graph use %.2fMB runtime memory: static %.2fMB, dynamic %.2fMB",
+                  (ctx_size + ggml_curr_max_dynamic_size()) * 1.0f / 1024 / 1024,
+                  ctx_size * 1.0f / 1024 / 1024,
+                  ggml_curr_max_dynamic_size() * 1.0f / 1024 / 1024);
+        LOG_DEBUG("%zu bytes of dynamic memory has not been released yet", ggml_dynamic_size());
+
+        return result < -1;
+    }
+
     ggml_tensor* get_learned_condition(ggml_context* res_ctx, const std::string& text) {
         auto tokens_and_weights = cond_stage_model.tokenize(text,
                                                             cond_stage_model.text_model.max_position_embeddings,
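The heuristic above deserves a note: the loader runs one UNet forward pass at t=999 on a constant latent and treats the checkpoint as v-prediction when the mean of `out - x` falls below -1. The intuition (my reading, not stated in the patch) is that at maximum noise a v-prediction model's output is dominated by a large negative multiple of its input, while an eps-prediction model's output stays near the noise scale. A toy restatement of the decision rule, with a fake model:

```python
import numpy as np

def looks_like_v_prediction(model_out: np.ndarray, x: np.ndarray) -> bool:
    # v-prediction outputs at t=999 pull the mean difference far below zero
    return float(np.mean(model_out - x)) < -1.0

x = np.full((4, 8, 8), 0.5, dtype=np.float32)
print(looks_like_v_prediction(-30.0 * x, x))  # True  (v-like output)
print(looks_like_v_prediction(0.1 * x, x))    # False (eps-like output)
```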
@@ -3093,8 +3327,8 @@ class StableDiffusionGGML {
         size_t steps = sigmas.size() - 1;
         // x_t = load_tensor_from_file(res_ctx, "./rand0.bin");
         // print_ggml_tensor(x_t);
-        struct ggml_tensor* x_out = ggml_dup_tensor(res_ctx, x_t);
-        copy_ggml_tensor(x_out, x_t);
+        struct ggml_tensor* x = ggml_dup_tensor(res_ctx, x_t);
+        copy_ggml_tensor(x, x_t);
 
         size_t ctx_size = 1 * 1024 * 1024;  // 1MB
         // calculate the amount of memory required
@@ -3112,16 +3346,16 @@ class StableDiffusionGGML {
         }
 
         ggml_set_dynamic(ctx, false);
-        struct ggml_tensor* x = ggml_dup_tensor(ctx, x_t);
+        struct ggml_tensor* noised_input = ggml_dup_tensor(ctx, x_t);
         struct ggml_tensor* context = ggml_dup_tensor(ctx, c);
         struct ggml_tensor* timesteps = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);  // [N, ]
         struct ggml_tensor* t_emb = new_timestep_embedding(ctx, timesteps, diffusion_model.model_channels);  // [N, model_channels]
         ggml_set_dynamic(ctx, params.dynamic);
 
-        struct ggml_tensor* eps = diffusion_model.forward(ctx, x, NULL, context, t_emb);
+        struct ggml_tensor* out = diffusion_model.forward(ctx, noised_input, NULL, context, t_emb);
         ctx_size += ggml_used_mem(ctx) + ggml_used_mem_of_data(ctx);
 
-        struct ggml_cgraph diffusion_graph = ggml_build_forward(eps);
+        struct ggml_cgraph diffusion_graph = ggml_build_forward(out);
         struct ggml_cplan cplan = ggml_graph_plan(&diffusion_graph, n_threads);
 
         ctx_size += cplan.work_size;
@@ -3145,16 +3379,16 @@ class StableDiffusionGGML {
         }
 
         ggml_set_dynamic(ctx, false);
-        struct ggml_tensor* x = ggml_dup_tensor(ctx, x_t);
+        struct ggml_tensor* noised_input = ggml_dup_tensor(ctx, x_t);
         struct ggml_tensor* context = ggml_dup_tensor(ctx, c);
         struct ggml_tensor* timesteps = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);  // [N, ]
         struct ggml_tensor* t_emb = new_timestep_embedding(ctx, timesteps, diffusion_model.model_channels);  // [N, model_channels]
         ggml_set_dynamic(ctx, params.dynamic);
 
-        struct ggml_tensor* eps = diffusion_model.forward(ctx, x, NULL, context, t_emb);
-        ggml_hold_dynamic_tensor(eps);
+        struct ggml_tensor* out = diffusion_model.forward(ctx, noised_input, NULL, context, t_emb);
+        ggml_hold_dynamic_tensor(out);
 
-        struct ggml_cgraph diffusion_graph = ggml_build_forward(eps);
+        struct ggml_cgraph diffusion_graph = ggml_build_forward(out);
         struct ggml_cplan cplan = ggml_graph_plan(&diffusion_graph, n_threads);
 
         ggml_set_dynamic(ctx, false);
@@ -3163,80 +3397,129 @@ class StableDiffusionGGML {
 
         cplan.work_data = (uint8_t*)buf->data;
 
+        // x = x * sigmas[0]
+        {
+            float* vec = (float*)x->data;
+            for (int i = 0; i < ggml_nelements(x); i++) {
+                vec[i] = vec[i] * sigmas[0];
+            }
+        }
+
+        // denoise wrapper
+        ggml_set_dynamic(ctx, false);
+        struct ggml_tensor* out_cond = NULL;
+        struct ggml_tensor* out_uncond = NULL;
+        if (cfg_scale != 1.0f && uc != NULL) {
+            out_uncond = ggml_dup_tensor(ctx, x);
+        }
+        struct ggml_tensor* denoised = ggml_dup_tensor(ctx, x);
+        ggml_set_dynamic(ctx, params.dynamic);
+
+        auto denoise = [&](ggml_tensor* input, float sigma, int step) {
+            int64_t t0 = ggml_time_ms();
+
+            float c_skip = 1.0f;
+            float c_out = 1.0f;
+            float c_in = 1.0f;
+            std::vector<float> scaling = denoiser->get_scalings(sigma);
+            if (scaling.size() == 3) {  // CompVisVDenoiser
+                c_skip = scaling[0];
+                c_out = scaling[1];
+                c_in = scaling[2];
+            } else {  // CompVisDenoiser
+                c_out = scaling[0];
+                c_in = scaling[1];
+            }
+
+            float t = denoiser->sigma_to_t(sigma);
+            ggml_set_f32(timesteps, t);
+            set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);
+
+            copy_ggml_tensor(noised_input, input);
+            // noised_input = noised_input * c_in
+            {
+                float* vec = (float*)noised_input->data;
+                for (int i = 0; i < ggml_nelements(noised_input); i++) {
+                    vec[i] = vec[i] * c_in;
+                }
+            }
+
+            if (cfg_scale != 1.0 && uc != NULL) {
+                // uncond
+                copy_ggml_tensor(context, uc);
+                ggml_graph_compute(&diffusion_graph, &cplan);
+                copy_ggml_tensor(out_uncond, out);
+
+                // cond
+                copy_ggml_tensor(context, c);
+                ggml_graph_compute(&diffusion_graph, &cplan);
+
+                out_cond = out;
+
+                // out_uncond + cfg_scale * (out_cond - out_uncond)
+                {
+                    float* vec_out = (float*)out->data;
+                    float* vec_out_uncond = (float*)out_uncond->data;
+                    float* vec_out_cond = (float*)out_cond->data;
+
+                    for (int i = 0; i < ggml_nelements(out); i++) {
+                        vec_out[i] = vec_out_uncond[i] + cfg_scale * (vec_out_cond[i] - vec_out_uncond[i]);
+                    }
+                }
+            } else {
+                // cond
+                copy_ggml_tensor(context, c);
+                ggml_graph_compute(&diffusion_graph, &cplan);
+            }
+
+            // v = out, eps = out
+            // denoised = (v * c_out + input * c_skip) or (input + eps * c_out)
+            {
+                float* vec_denoised = (float*)denoised->data;
+                float* vec_input = (float*)input->data;
+                float* vec_out = (float*)out->data;
+
+                for (int i = 0; i < ggml_nelements(denoised); i++) {
+                    vec_denoised[i] = vec_out[i] * c_out + vec_input[i] * c_skip;
+                }
+            }
+
+#ifdef GGML_PERF
+            ggml_graph_print(&diffusion_graph);
+#endif
+            int64_t t1 = ggml_time_ms();
+            LOG_INFO("step %d sampling completed, taking %.2fs", step, (t1 - t0) * 1.0f / 1000);
+            LOG_DEBUG("diffusion graph use %.2fMB runtime memory: static %.2fMB, dynamic %.2fMB",
+                      (ctx_size + ggml_curr_max_dynamic_size()) * 1.0f / 1024 / 1024,
+                      ctx_size * 1.0f / 1024 / 1024,
+                      ggml_curr_max_dynamic_size() * 1.0f / 1024 / 1024);
+            LOG_DEBUG("%zu bytes of dynamic memory has not been released yet", ggml_dynamic_size());
+        };
+
         // sample_euler_ancestral
         {
             ggml_set_dynamic(ctx, false);
-            struct ggml_tensor* eps_cond = NULL;
-            struct ggml_tensor* eps_uncond = NULL;
-            struct ggml_tensor* noise = ggml_dup_tensor(ctx, x_out);
-            if (cfg_scale != 1.0f && uc != NULL) {
-                eps_uncond = ggml_dup_tensor(ctx, x_out);
-            }
-            struct ggml_tensor* d = ggml_dup_tensor(ctx, x_out);
+            struct ggml_tensor* noise = ggml_dup_tensor(ctx, x);
+            struct ggml_tensor* d = ggml_dup_tensor(ctx, x);
             ggml_set_dynamic(ctx, params.dynamic);
 
-            // x_out = x_out * sigmas[0]
-            {
-                float* vec = (float*)x_out->data;
-                for (int i = 0; i < ggml_nelements(x_out); i++) {
-                    vec[i] = vec[i] * sigmas[0];
-                }
-            }
-
             for (int i = 0; i < steps; i++) {
-                int64_t t0 = ggml_time_ms();
+                float sigma = sigmas[i];
 
-                copy_ggml_tensor(x, x_out);
+                // denoise
+                denoise(x, sigma, i + 1);
 
-                std::pair<float, float> scaling = denoiser.get_scalings(sigmas[i]);
-                float c_in = scaling.first;
-                float c_out = scaling.second;
-                float t = denoiser.sigma_to_t(sigmas[i]);
-                ggml_set_f32(timesteps, t);
-                set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);
-
-                // x = x * c_in
+                // d = (x - denoised) / sigma
                 {
-                    float* vec = (float*)x->data;
-                    for (int i = 0; i < ggml_nelements(x); i++) {
-                        vec[i] = vec[i] * c_in;
+                    float* vec_d = (float*)d->data;
+                    float* vec_x = (float*)x->data;
+                    float* vec_denoised = (float*)denoised->data;
+
+                    for (int i = 0; i < ggml_nelements(d); i++) {
+                        vec_d[i] = (vec_x[i] - vec_denoised[i]) / sigma;
                     }
                 }
 
-                /*d = (x - denoised) / sigma
-                    = (-eps_uncond * c_out - cfg_scale * (eps_cond * c_out - eps_uncond * c_out)) / sigma
-                    = eps_uncond + cfg_scale * (eps_cond - eps_uncond)*/
-                if (cfg_scale != 1.0 && uc != NULL) {
-                    // uncond
-                    copy_ggml_tensor(context, uc);
-                    ggml_graph_compute(&diffusion_graph, &cplan);
-                    copy_ggml_tensor(eps_uncond, eps);
-
-                    // cond
-                    copy_ggml_tensor(context, c);
-                    ggml_graph_compute(&diffusion_graph, &cplan);
-
-                    eps_cond = eps;
-
-                    /*d = (x - denoised) / sigma
-                        = (-eps_uncond * c_out - cfg_scale * (eps_cond * c_out - eps_uncond * c_out)) / sigma
-                        = eps_uncond + cfg_scale * (eps_cond - eps_uncond)*/
-                    {
-                        float* vec_d = (float*)d->data;
-                        float* vec_eps_uncond = (float*)eps_uncond->data;
-                        float* vec_eps_cond = (float*)eps_cond->data;
-
-                        for (int i = 0; i < ggml_nelements(d); i++) {
-                            vec_d[i] = vec_eps_uncond[i] + cfg_scale * (vec_eps_cond[i] - vec_eps_uncond[i]);
-                        }
-                    }
-                } else {
-                    // cond
-                    copy_ggml_tensor(context, c);
-                    ggml_graph_compute(&diffusion_graph, &cplan);
-                    copy_ggml_tensor(d, eps);
-                }
-
                 // get_ancestral_step
                 float sigma_up = std::min(sigmas[i + 1],
                                           std::sqrt(sigmas[i + 1] * sigmas[i + 1] * (sigmas[i] * sigmas[i] - sigmas[i + 1] * sigmas[i + 1]) / (sigmas[i] * sigmas[i])));
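The refactor above extracts a `denoise` closure so the sampler works purely in terms of denoised estimates, which is what lets eps- and v-prediction checkpoints share one Euler-ancestral loop. A compact numpy sketch of the same loop shape, mirroring k-diffusion's `sample_euler_ancestral` (`model` and `sigmas` are stand-ins, not the real graph):

```python
import numpy as np

def sample_euler_ancestral(model, x, sigmas, rng=np.random.default_rng(0)):
    # model(x, sigma) -> denoised estimate; x starts as unit noise
    x = x * sigmas[0]
    for i in range(len(sigmas) - 1):
        denoised = model(x, sigmas[i])
        d = (x - denoised) / sigmas[i]  # derivative toward the data manifold
        sigma_up = min(sigmas[i + 1],
                       np.sqrt(sigmas[i + 1] ** 2 * (sigmas[i] ** 2 - sigmas[i + 1] ** 2) / sigmas[i] ** 2))
        sigma_down = np.sqrt(sigmas[i + 1] ** 2 - sigma_up ** 2)
        x = x + d * (sigma_down - sigmas[i])  # deterministic Euler step
        if sigmas[i + 1] > 0:
            x = x + rng.standard_normal(x.shape) * sigma_up  # ancestral noise
    return x
```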
@@ -3247,9 +3530,9 @@ class StableDiffusionGGML {
                 // x = x + d * dt
                 {
                     float* vec_d = (float*)d->data;
-                    float* vec_x = (float*)x_out->data;
+                    float* vec_x = (float*)x->data;
 
-                    for (int i = 0; i < ggml_nelements(x_out); i++) {
+                    for (int i = 0; i < ggml_nelements(x); i++) {
                         vec_x[i] = vec_x[i] + vec_d[i] * dt;
                     }
                 }
@@ -3259,25 +3542,14 @@ class StableDiffusionGGML {
                     ggml_tensor_set_f32_randn(noise);
                     // noise = load_tensor_from_file(res_ctx, "./rand" + std::to_string(i+1) + ".bin");
                     {
-                        float* vec_x = (float*)x_out->data;
+                        float* vec_x = (float*)x->data;
                         float* vec_noise = (float*)noise->data;
 
-                        for (int i = 0; i < ggml_nelements(x_out); i++) {
+                        for (int i = 0; i < ggml_nelements(x); i++) {
                             vec_x[i] = vec_x[i] + vec_noise[i] * sigma_up;
                         }
                     }
                 }
-
-#ifdef GGML_PERF
-                ggml_graph_print(&diffusion_graph);
-#endif
-                int64_t t1 = ggml_time_ms();
-                LOG_INFO("step %d sampling completed, taking %.2fs", i + 1, (t1 - t0) * 1.0f / 1000);
-                LOG_DEBUG("diffusion graph use %.2fMB runtime memory: static %.2fMB, dynamic %.2fMB",
-                          (ctx_size + ggml_curr_max_dynamic_size()) * 1.0f / 1024 / 1024,
-                          ctx_size * 1.0f / 1024 / 1024,
-                          ggml_curr_max_dynamic_size() * 1.0f / 1024 / 1024);
-                LOG_DEBUG("%zu bytes of dynamic memory has not been released yet", ggml_dynamic_size());
             }
         }
 
@@ -3304,7 +3576,7 @@ class StableDiffusionGGML {
 
         ggml_free(ctx);
 
-        return x_out;
+        return x;
     }
 
     ggml_tensor* encode_first_stage(ggml_context* res_ctx, ggml_tensor* x) {
@@ -3586,7 +3858,7 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
     struct ggml_tensor* x_t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, W, H, C, 1);
     ggml_tensor_set_f32_randn(x_t);
 
-    std::vector<float> sigmas = sd->denoiser.get_sigmas(sample_steps);
+    std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
 
     LOG_INFO("start sampling");
    struct ggml_tensor* x_0 = sd->sample(ctx, x_t, c, uc, cfg_scale, sample_method, sigmas);
@@ -3642,7 +3914,7 @@ std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_i
     }
     LOG_INFO("img2img %dx%d", width, height);
 
-    std::vector<float> sigmas = sd->denoiser.get_sigmas(sample_steps);
+    std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
     size_t t_enc = static_cast<size_t>(sample_steps * strength);
     LOG_INFO("target t_enc is %zu steps", t_enc);
     std::vector<float> sigma_sched;