feat: add sd3.5 support (#445)
This commit is contained in:
parent
14206fd488
commit
ac54e00760
22
README.md
22
README.md
@ -10,7 +10,7 @@ Inference of Stable Diffusion and Flux in pure C/C++
|
||||
|
||||
- Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp)
|
||||
- Super lightweight and without external dependencies
|
||||
- SD1.x, SD2.x, SDXL and SD3 support
|
||||
- SD1.x, SD2.x, SDXL and [SD3/SD3.5](./docs/sd3.md) support
|
||||
- !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors).
|
||||
- [Flux-dev/Flux-schnell Support](./docs/flux.md)
|
||||
|
||||
@ -197,23 +197,24 @@ usage: ./bin/sd [arguments]
|
||||
arguments:
|
||||
-h, --help show this help message and exit
|
||||
-M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)
|
||||
-t, --threads N number of threads to use during computation (default: -1).
|
||||
-t, --threads N number of threads to use during computation (default: -1)
|
||||
If threads <= 0, then threads will be set to the number of CPU physical cores
|
||||
-m, --model [MODEL] path to full model
|
||||
--diffusion-model path to the standalone diffusion model
|
||||
--clip_l path to the clip-l text encoder
|
||||
--t5xxl path to the the t5xxl text encoder.
|
||||
--clip_g path to the clip-l text encoder
|
||||
--t5xxl path to the the t5xxl text encoder
|
||||
--vae [VAE] path to vae
|
||||
--taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
|
||||
--control-net [CONTROL_PATH] path to control net model
|
||||
--embd-dir [EMBEDDING_PATH] path to embeddings.
|
||||
--stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings.
|
||||
--input-id-images-dir [DIR] path to PHOTOMAKER input id images dir.
|
||||
--embd-dir [EMBEDDING_PATH] path to embeddings
|
||||
--stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings
|
||||
--input-id-images-dir [DIR] path to PHOTOMAKER input id images dir
|
||||
--normalize-input normalize PHOTOMAKER input id images
|
||||
--upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.
|
||||
--upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
|
||||
--upscale-repeats Run the ESRGAN upscaler this many times (default 1)
|
||||
--type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)
|
||||
If not specified, the default is the type of the weight file.
|
||||
If not specified, the default is the type of the weight file
|
||||
--lora-model-dir [DIR] lora model directory
|
||||
-i, --init-img [IMAGE] path to the input image, required by img2img
|
||||
--control-image [IMAGE] path to image condition, control net
|
||||
@ -232,13 +233,13 @@ arguments:
|
||||
--steps STEPS number of sample steps (default: 20)
|
||||
--rng {std_default, cuda} RNG (default: cuda)
|
||||
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
|
||||
-b, --batch-count COUNT number of images to generate.
|
||||
-b, --batch-count COUNT number of images to generate
|
||||
--schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)
|
||||
--clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
|
||||
<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
|
||||
--vae-tiling process vae in tiles to reduce memory usage
|
||||
--vae-on-cpu keep vae in cpu (for low vram)
|
||||
--clip-on-cpu keep clip in cpu (for low vram).
|
||||
--clip-on-cpu keep clip in cpu (for low vram)
|
||||
--control-net-cpu keep controlnet in cpu (for low vram)
|
||||
--canny apply canny preprocessor (edge detection)
|
||||
--color Colors the logging tags according to level
|
||||
@ -253,6 +254,7 @@ arguments:
|
||||
# ./bin/sd -m ../models/sd_xl_base_1.0.safetensors --vae ../models/sdxl_vae-fp16-fix.safetensors -H 1024 -W 1024 -p "a lovely cat" -v
|
||||
# ./bin/sd -m ../models/sd3_medium_incl_clips_t5xxlfp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable Diffusion CPP\"' --cfg-scale 4.5 --sampling-method euler -v
|
||||
# ./bin/sd --diffusion-model ../models/flux1-dev-q3_k.gguf --vae ../models/ae.sft --clip_l ../models/clip_l.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v
|
||||
# ./bin/sd -m ..\models\sd3.5_large.safetensors --clip_l ..\models\clip_l.safetensors --clip_g ..\models\clip_g.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v
|
||||
```
|
||||
|
||||
Using formats of different precisions will yield results of varying quality.
|
||||
|
BIN
assets/sd3.5_large.png
Normal file
BIN
assets/sd3.5_large.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.8 MiB |
@ -1001,8 +1001,8 @@ struct FluxCLIPEmbedder : public Conditioner {
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
|
||||
clip_l->get_param_tensors(tensors, "text_encoders.clip_l.text_model");
|
||||
t5->get_param_tensors(tensors, "text_encoders.t5xxl");
|
||||
clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
|
||||
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
|
||||
}
|
||||
|
||||
void alloc_params_buffer() {
|
||||
|
11
denoiser.hpp
11
denoiser.hpp
@ -205,7 +205,7 @@ struct AYSSchedule : SigmaSchedule {
|
||||
|
||||
/*
|
||||
* GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main
|
||||
*/
|
||||
*/
|
||||
struct GITSSchedule : SigmaSchedule {
|
||||
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
|
||||
if (sigma_max <= 0.0f) {
|
||||
@ -862,8 +862,7 @@ static void sample_k_diffusion(sample_method_t method,
|
||||
for (int j = 0; j < ggml_nelements(x_next); j++) {
|
||||
vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (3 * vec_d_cur[j] - vec_d_prev1[j]) / 2;
|
||||
}
|
||||
}
|
||||
break;
|
||||
} break;
|
||||
|
||||
case 3: // Use two history points
|
||||
{
|
||||
@ -872,8 +871,7 @@ static void sample_k_diffusion(sample_method_t method,
|
||||
for (int j = 0; j < ggml_nelements(x_next); j++) {
|
||||
vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (23 * vec_d_cur[j] - 16 * vec_d_prev1[j] + 5 * vec_d_prev2[j]) / 12;
|
||||
}
|
||||
}
|
||||
break;
|
||||
} break;
|
||||
|
||||
case 4: // Use three history points
|
||||
{
|
||||
@ -883,8 +881,7 @@ static void sample_k_diffusion(sample_method_t method,
|
||||
for (int j = 0; j < ggml_nelements(x_next); j++) {
|
||||
vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (55 * vec_d_cur[j] - 59 * vec_d_prev1[j] + 37 * vec_d_prev2[j] - 9 * vec_d_prev3[j]) / 24;
|
||||
}
|
||||
}
|
||||
break;
|
||||
} break;
|
||||
}
|
||||
|
||||
// Manage buffer_model
|
||||
|
20
docs/sd3.md
Normal file
20
docs/sd3.md
Normal file
@ -0,0 +1,20 @@
|
||||
# How to Use
|
||||
|
||||
## Download weights
|
||||
|
||||
- Download sd3.5_large from https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/sd3.5_large.safetensors
|
||||
- Download clip_g from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/clip_g.safetensors
|
||||
- Download clip_l from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/clip_l.safetensors
|
||||
- Download t5xxl from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/t5xxl_fp16.safetensors
|
||||
|
||||
|
||||
## Run
|
||||
|
||||
### SD3.5 Large
|
||||
For example:
|
||||
|
||||
```
|
||||
.\bin\Release\sd.exe -m ..\models\sd3.5_large.safetensors --clip_l ..\models\clip_l.safetensors --clip_g ..\models\clip_g.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v
|
||||
```
|
||||
|
||||

|
@ -69,9 +69,9 @@ enum SDMode {
|
||||
struct SDParams {
|
||||
int n_threads = -1;
|
||||
SDMode mode = TXT2IMG;
|
||||
|
||||
std::string model_path;
|
||||
std::string clip_l_path;
|
||||
std::string clip_g_path;
|
||||
std::string t5xxl_path;
|
||||
std::string diffusion_model_path;
|
||||
std::string vae_path;
|
||||
@ -128,6 +128,7 @@ void print_params(SDParams params) {
|
||||
printf(" model_path: %s\n", params.model_path.c_str());
|
||||
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
|
||||
printf(" clip_l_path: %s\n", params.clip_l_path.c_str());
|
||||
printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
|
||||
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
|
||||
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
|
||||
printf(" vae_path: %s\n", params.vae_path.c_str());
|
||||
@ -171,23 +172,24 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf("arguments:\n");
|
||||
printf(" -h, --help show this help message and exit\n");
|
||||
printf(" -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)\n");
|
||||
printf(" -t, --threads N number of threads to use during computation (default: -1).\n");
|
||||
printf(" -t, --threads N number of threads to use during computation (default: -1)\n");
|
||||
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
|
||||
printf(" -m, --model [MODEL] path to full model\n");
|
||||
printf(" --diffusion-model path to the standalone diffusion model\n");
|
||||
printf(" --clip_l path to the clip-l text encoder\n");
|
||||
printf(" --t5xxl path to the the t5xxl text encoder.\n");
|
||||
printf(" --clip_g path to the clip-l text encoder\n");
|
||||
printf(" --t5xxl path to the the t5xxl text encoder\n");
|
||||
printf(" --vae [VAE] path to vae\n");
|
||||
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
|
||||
printf(" --control-net [CONTROL_PATH] path to control net model\n");
|
||||
printf(" --embd-dir [EMBEDDING_PATH] path to embeddings.\n");
|
||||
printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings.\n");
|
||||
printf(" --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir.\n");
|
||||
printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n");
|
||||
printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings\n");
|
||||
printf(" --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir\n");
|
||||
printf(" --normalize-input normalize PHOTOMAKER input id images\n");
|
||||
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n");
|
||||
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n");
|
||||
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
|
||||
printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)\n");
|
||||
printf(" If not specified, the default is the type of the weight file.\n");
|
||||
printf(" If not specified, the default is the type of the weight file\n");
|
||||
printf(" --lora-model-dir [DIR] lora model directory\n");
|
||||
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
|
||||
printf(" --control-image [IMAGE] path to image condition, control net\n");
|
||||
@ -206,13 +208,13 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" --steps STEPS number of sample steps (default: 20)\n");
|
||||
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
|
||||
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
|
||||
printf(" -b, --batch-count COUNT number of images to generate.\n");
|
||||
printf(" -b, --batch-count COUNT number of images to generate\n");
|
||||
printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n");
|
||||
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
|
||||
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
|
||||
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
|
||||
printf(" --vae-on-cpu keep vae in cpu (for low vram)\n");
|
||||
printf(" --clip-on-cpu keep clip in cpu (for low vram).\n");
|
||||
printf(" --clip-on-cpu keep clip in cpu (for low vram)\n");
|
||||
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
|
||||
printf(" --canny apply canny preprocessor (edge detection)\n");
|
||||
printf(" --color Colors the logging tags according to level\n");
|
||||
@ -262,6 +264,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
break;
|
||||
}
|
||||
params.clip_l_path = argv[i];
|
||||
} else if (arg == "--clip_g") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
params.clip_g_path = argv[i];
|
||||
} else if (arg == "--t5xxl") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
@ -765,6 +773,7 @@ int main(int argc, const char* argv[]) {
|
||||
|
||||
sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
|
||||
params.clip_l_path.c_str(),
|
||||
params.clip_g_path.c_str(),
|
||||
params.t5xxl_path.c_str(),
|
||||
params.diffusion_model_path.c_str(),
|
||||
params.vae_path.c_str(),
|
||||
|
@ -380,7 +380,7 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
|
||||
float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k);
|
||||
|
||||
const float x_f_0 = (x > 0) ? ix / float(overlap) : 1;
|
||||
const float x_f_1 = (x < (img_width - width)) ? (width - ix) / float(overlap) : 1 ;
|
||||
const float x_f_1 = (x < (img_width - width)) ? (width - ix) / float(overlap) : 1;
|
||||
const float y_f_0 = (y > 0) ? iy / float(overlap) : 1;
|
||||
const float y_f_1 = (y < (img_height - height)) ? (height - iy) / float(overlap) : 1;
|
||||
|
||||
@ -390,8 +390,7 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
|
||||
ggml_tensor_set_f32(
|
||||
output,
|
||||
old_value + new_value * ggml_smootherstep_f32(y_f) * ggml_smootherstep_f32(x_f),
|
||||
x + ix, y + iy, k
|
||||
);
|
||||
x + ix, y + iy, k);
|
||||
} else {
|
||||
ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k);
|
||||
}
|
||||
|
84
mmdit.hpp
84
mmdit.hpp
@ -142,29 +142,77 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class RMSNorm : public UnaryBlock {
|
||||
protected:
|
||||
int64_t hidden_size;
|
||||
float eps;
|
||||
|
||||
void init_params(struct ggml_context* ctx, ggml_type wtype) {
|
||||
params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);
|
||||
}
|
||||
|
||||
public:
|
||||
RMSNorm(int64_t hidden_size,
|
||||
float eps = 1e-06f)
|
||||
: hidden_size(hidden_size),
|
||||
eps(eps) {}
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
||||
struct ggml_tensor* w = params["weight"];
|
||||
x = ggml_rms_norm(ctx, x, eps);
|
||||
x = ggml_mul(ctx, x, w);
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
class SelfAttention : public GGMLBlock {
|
||||
public:
|
||||
int64_t num_heads;
|
||||
bool pre_only;
|
||||
std::string qk_norm;
|
||||
|
||||
public:
|
||||
SelfAttention(int64_t dim,
|
||||
int64_t num_heads = 8,
|
||||
std::string qk_norm = "",
|
||||
bool qkv_bias = false,
|
||||
bool pre_only = false)
|
||||
: num_heads(num_heads), pre_only(pre_only) {
|
||||
// qk_norm is always None
|
||||
: num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
|
||||
int64_t d_head = dim / num_heads;
|
||||
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
|
||||
if (!pre_only) {
|
||||
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||
}
|
||||
if (qk_norm == "rms") {
|
||||
blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6));
|
||||
blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6));
|
||||
} else if (qk_norm == "ln") {
|
||||
blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6));
|
||||
blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<struct ggml_tensor*> pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
|
||||
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
|
||||
|
||||
auto qkv = qkv_proj->forward(ctx, x);
|
||||
return split_qkv(ctx, qkv);
|
||||
auto qkv_vec = split_qkv(ctx, qkv);
|
||||
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
|
||||
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
|
||||
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
|
||||
auto v = qkv_vec[2]; // [N, n_token, n_head*d_head]
|
||||
|
||||
if (qk_norm == "rms" || qk_norm == "ln") {
|
||||
auto ln_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["ln_q"]);
|
||||
auto ln_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["ln_k"]);
|
||||
q = ln_q->forward(ctx, q);
|
||||
k = ln_k->forward(ctx, k);
|
||||
}
|
||||
|
||||
q = ggml_reshape_3d(ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]); // [N, n_token, n_head*d_head]
|
||||
k = ggml_reshape_3d(ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]); // [N, n_token, n_head*d_head]
|
||||
|
||||
return {q, k, v};
|
||||
}
|
||||
|
||||
struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
|
||||
@ -209,15 +257,15 @@ public:
|
||||
DismantledBlock(int64_t hidden_size,
|
||||
int64_t num_heads,
|
||||
float mlp_ratio = 4.0,
|
||||
std::string qk_norm = "",
|
||||
bool qkv_bias = false,
|
||||
bool pre_only = false)
|
||||
: num_heads(num_heads), pre_only(pre_only) {
|
||||
// rmsnorm is always Flase
|
||||
// scale_mod_only is always Flase
|
||||
// swiglu is always Flase
|
||||
// qk_norm is always Flase
|
||||
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
|
||||
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, pre_only));
|
||||
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
|
||||
|
||||
if (!pre_only) {
|
||||
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
|
||||
@ -397,11 +445,11 @@ public:
|
||||
JointBlock(int64_t hidden_size,
|
||||
int64_t num_heads,
|
||||
float mlp_ratio = 4.0,
|
||||
std::string qk_norm = "",
|
||||
bool qkv_bias = false,
|
||||
bool pre_only = false) {
|
||||
// qk_norm is always Flase
|
||||
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qkv_bias, pre_only));
|
||||
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qkv_bias, false));
|
||||
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
|
||||
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false));
|
||||
}
|
||||
|
||||
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
|
||||
@ -466,7 +514,9 @@ protected:
|
||||
int64_t pos_embed_max_size = 192;
|
||||
int64_t num_patchs = 36864; // 192 * 192
|
||||
int64_t context_size = 4096;
|
||||
int64_t context_embedder_out_dim = 1536;
|
||||
int64_t hidden_size;
|
||||
std::string qk_norm;
|
||||
|
||||
void init_params(struct ggml_context* ctx, ggml_type wtype) {
|
||||
params["pos_embed"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hidden_size, num_patchs, 1);
|
||||
@ -481,7 +531,6 @@ public:
|
||||
// rmsnorm is alwalys False
|
||||
// scale_mod_only is alwalys False
|
||||
// swiglu is alwalys False
|
||||
// qk_norm is always None
|
||||
// qkv_bias is always True
|
||||
// context_processor_layers is always None
|
||||
// pos_embed_scaling_factor is not used
|
||||
@ -498,6 +547,20 @@ public:
|
||||
pos_embed_max_size = 192;
|
||||
num_patchs = 36864; // 192 * 192
|
||||
context_size = 4096;
|
||||
context_embedder_out_dim = 1536;
|
||||
} else if (version == VERSION_SD3_5_8B) {
|
||||
input_size = -1;
|
||||
patch_size = 2;
|
||||
in_channels = 16;
|
||||
depth = 38;
|
||||
mlp_ratio = 4.0f;
|
||||
adm_in_channels = 2048;
|
||||
out_channels = 16;
|
||||
pos_embed_max_size = 192;
|
||||
num_patchs = 36864; // 192 * 192
|
||||
context_size = 4096;
|
||||
context_embedder_out_dim = 2432;
|
||||
qk_norm = "rms";
|
||||
}
|
||||
int64_t default_out_channels = in_channels;
|
||||
hidden_size = 64 * depth;
|
||||
@ -510,12 +573,13 @@ public:
|
||||
blocks["y_embedder"] = std::shared_ptr<GGMLBlock>(new VectorEmbedder(adm_in_channels, hidden_size));
|
||||
}
|
||||
|
||||
blocks["context_embedder"] = std::shared_ptr<GGMLBlock>(new Linear(4096, 1536, true, true));
|
||||
blocks["context_embedder"] = std::shared_ptr<GGMLBlock>(new Linear(4096, context_embedder_out_dim, true, true));
|
||||
|
||||
for (int i = 0; i < depth; i++) {
|
||||
blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new JointBlock(hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio,
|
||||
qk_norm,
|
||||
true,
|
||||
i == depth - 1));
|
||||
}
|
||||
|
20
model.cpp
20
model.cpp
@ -430,6 +430,14 @@ std::string convert_tensor_name(std::string name) {
|
||||
if (starts_with(name, "diffusion_model")) {
|
||||
name = "model." + name;
|
||||
}
|
||||
// size_t pos = name.find("lora_A");
|
||||
// if (pos != std::string::npos) {
|
||||
// name.replace(pos, strlen("lora_A"), "lora_up");
|
||||
// }
|
||||
// pos = name.find("lora_B");
|
||||
// if (pos != std::string::npos) {
|
||||
// name.replace(pos, strlen("lora_B"), "lora_down");
|
||||
// }
|
||||
std::string new_name = name;
|
||||
if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) {
|
||||
new_name = convert_open_clip_to_hf_clip(name);
|
||||
@ -466,6 +474,9 @@ std::string convert_tensor_name(std::string name) {
|
||||
if (pos != std::string::npos) {
|
||||
new_name.replace(pos, strlen(".processor"), "");
|
||||
}
|
||||
// if (starts_with(new_name, "transformer.transformer_blocks") || starts_with(new_name, "transformer.single_transformer_blocks")) {
|
||||
// new_name = "model.diffusion_model." + new_name;
|
||||
// }
|
||||
pos = new_name.rfind("lora");
|
||||
if (pos != std::string::npos) {
|
||||
std::string name_without_network_parts = new_name.substr(0, pos - 1);
|
||||
@ -1354,6 +1365,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
|
||||
SDVersion ModelLoader::get_sd_version() {
|
||||
TensorStorage token_embedding_weight;
|
||||
bool is_flux = false;
|
||||
bool is_sd3 = false;
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
|
||||
return VERSION_FLUX_DEV;
|
||||
@ -1361,8 +1373,11 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
|
||||
is_flux = true;
|
||||
}
|
||||
if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) {
|
||||
return VERSION_SD3_5_8B;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.23.") != std::string::npos) {
|
||||
return VERSION_SD3_2B;
|
||||
is_sd3 = true;
|
||||
}
|
||||
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
|
||||
return VERSION_SDXL;
|
||||
@ -1387,6 +1402,9 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
if (is_flux) {
|
||||
return VERSION_FLUX_SCHNELL;
|
||||
}
|
||||
if (is_sd3) {
|
||||
return VERSION_SD3_2B;
|
||||
}
|
||||
if (token_embedding_weight.ne[0] == 768) {
|
||||
return VERSION_SD1;
|
||||
} else if (token_embedding_weight.ne[0] == 1024) {
|
||||
|
1
model.h
1
model.h
@ -25,6 +25,7 @@ enum SDVersion {
|
||||
VERSION_SD3_2B,
|
||||
VERSION_FLUX_DEV,
|
||||
VERSION_FLUX_SCHNELL,
|
||||
VERSION_SD3_5_8B,
|
||||
VERSION_COUNT,
|
||||
};
|
||||
|
||||
|
@ -31,7 +31,8 @@ const char* model_version_to_str[] = {
|
||||
"SVD",
|
||||
"SD3 2B",
|
||||
"Flux Dev",
|
||||
"Flux Schnell"};
|
||||
"Flux Schnell",
|
||||
"SD3.5 8B"};
|
||||
|
||||
const char* sampling_methods_str[] = {
|
||||
"Euler A",
|
||||
@ -139,6 +140,7 @@ public:
|
||||
|
||||
bool load_from_file(const std::string& model_path,
|
||||
const std::string& clip_l_path,
|
||||
const std::string& clip_g_path,
|
||||
const std::string& t5xxl_path,
|
||||
const std::string& diffusion_model_path,
|
||||
const std::string& vae_path,
|
||||
@ -167,7 +169,7 @@ public:
|
||||
for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
|
||||
backend = ggml_backend_vk_init(device);
|
||||
}
|
||||
if(!backend) {
|
||||
if (!backend) {
|
||||
LOG_WARN("Failed to initialize Vulkan backend");
|
||||
}
|
||||
#endif
|
||||
@ -181,7 +183,7 @@ public:
|
||||
backend = ggml_backend_cpu_init();
|
||||
}
|
||||
#ifdef SD_USE_FLASH_ATTENTION
|
||||
#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined (SD_USE_SYCL) || defined(SD_USE_VULKAN)
|
||||
#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) || defined(SD_USE_VULKAN)
|
||||
LOG_WARN("Flash Attention not supported with GPU Backend");
|
||||
#else
|
||||
LOG_INFO("Flash Attention enabled");
|
||||
@ -200,14 +202,21 @@ public:
|
||||
|
||||
if (clip_l_path.size() > 0) {
|
||||
LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str());
|
||||
if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.")) {
|
||||
if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.transformer.")) {
|
||||
LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (clip_g_path.size() > 0) {
|
||||
LOG_INFO("loading clip_g from '%s'", clip_g_path.c_str());
|
||||
if (!model_loader.init_from_file(clip_g_path, "text_encoders.clip_g.transformer.")) {
|
||||
LOG_WARN("loading clip_g from '%s' failed", clip_g_path.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (t5xxl_path.size() > 0) {
|
||||
LOG_INFO("loading t5xxl from '%s'", t5xxl_path.c_str());
|
||||
if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.")) {
|
||||
if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.transformer.")) {
|
||||
LOG_WARN("loading t5xxl from '%s' failed", t5xxl_path.c_str());
|
||||
}
|
||||
}
|
||||
@ -279,7 +288,7 @@ public:
|
||||
"try specifying SDXL VAE FP16 Fix with the --vae parameter. "
|
||||
"You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
|
||||
}
|
||||
} else if (version == VERSION_SD3_2B) {
|
||||
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
|
||||
scale_factor = 1.5305f;
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
scale_factor = 0.3611;
|
||||
@ -302,7 +311,7 @@ public:
|
||||
} else {
|
||||
clip_backend = backend;
|
||||
bool use_t5xxl = false;
|
||||
if (version == VERSION_SD3_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
use_t5xxl = true;
|
||||
}
|
||||
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
|
||||
@ -313,7 +322,7 @@ public:
|
||||
LOG_INFO("CLIP: Using CPU backend");
|
||||
clip_backend = ggml_backend_cpu_init();
|
||||
}
|
||||
if (version == VERSION_SD3_2B) {
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
|
||||
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
|
||||
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
@ -511,7 +520,7 @@ public:
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
|
||||
if (version == VERSION_SD3_2B) {
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>();
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
@ -939,7 +948,7 @@ public:
|
||||
if (use_tiny_autoencoder) {
|
||||
C = 4;
|
||||
} else {
|
||||
if (version == VERSION_SD3_2B) {
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
|
||||
C = 32;
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
C = 32;
|
||||
@ -1008,6 +1017,7 @@ struct sd_ctx_t {
|
||||
|
||||
sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
|
||||
const char* clip_l_path_c_str,
|
||||
const char* clip_g_path_c_str,
|
||||
const char* t5xxl_path_c_str,
|
||||
const char* diffusion_model_path_c_str,
|
||||
const char* vae_path_c_str,
|
||||
@ -1032,6 +1042,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
|
||||
}
|
||||
std::string model_path(model_path_c_str);
|
||||
std::string clip_l_path(clip_l_path_c_str);
|
||||
std::string clip_g_path(clip_g_path_c_str);
|
||||
std::string t5xxl_path(t5xxl_path_c_str);
|
||||
std::string diffusion_model_path(diffusion_model_path_c_str);
|
||||
std::string vae_path(vae_path_c_str);
|
||||
@ -1052,6 +1063,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
|
||||
|
||||
if (!sd_ctx->sd->load_from_file(model_path,
|
||||
clip_l_path,
|
||||
clip_g_path,
|
||||
t5xxl_path_c_str,
|
||||
diffusion_model_path,
|
||||
vae_path,
|
||||
@ -1269,7 +1281,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
|
||||
// Sample
|
||||
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
|
||||
int C = 4;
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B) {
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
|
||||
C = 16;
|
||||
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
|
||||
C = 16;
|
||||
@ -1382,7 +1394,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B) {
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
|
||||
params.mem_size *= 3;
|
||||
}
|
||||
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
|
||||
@ -1408,7 +1420,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
|
||||
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
|
||||
|
||||
int C = 4;
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B) {
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
|
||||
C = 16;
|
||||
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
|
||||
C = 16;
|
||||
@ -1416,7 +1428,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
|
||||
int W = width / 8;
|
||||
int H = height / 8;
|
||||
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B) {
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
|
||||
ggml_set_f32(init_latent, 0.0609f);
|
||||
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
|
||||
ggml_set_f32(init_latent, 0.1159f);
|
||||
@ -1477,7 +1489,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B) {
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
|
||||
params.mem_size *= 2;
|
||||
}
|
||||
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
|
||||
|
@ -124,6 +124,7 @@ typedef struct sd_ctx_t sd_ctx_t;
|
||||
|
||||
SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
|
||||
const char* clip_l_path,
|
||||
const char* clip_g_path,
|
||||
const char* t5xxl_path,
|
||||
const char* diffusion_model_path,
|
||||
const char* vae_path,
|
||||
|
2
vae.hpp
2
vae.hpp
@ -457,7 +457,7 @@ public:
|
||||
bool use_video_decoder = false,
|
||||
SDVersion version = VERSION_SD1)
|
||||
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
|
||||
if (version == VERSION_SD3_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
dd_config.z_channels = 16;
|
||||
use_quant = false;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user