feat: add sd3.5 support (#445)

leejet 2024-10-24 21:58:03 +08:00 committed by GitHub
parent 14206fd488
commit ac54e00760
13 changed files with 250 additions and 127 deletions

View File

@ -10,7 +10,7 @@ Inference of Stable Diffusion and Flux in pure C/C++
- Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp)
- Super lightweight and without external dependencies
- SD1.x, SD2.x, SDXL and SD3 support
- SD1.x, SD2.x, SDXL and [SD3/SD3.5](./docs/sd3.md) support
- !!!The VAE in SDXL produces NaNs under FP16, but unfortunately ggml_conv_2d only operates under FP16; hence, a parameter is needed to specify a VAE that has the FP16 NaN issue fixed. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors).
- [Flux-dev/Flux-schnell Support](./docs/flux.md)
@ -197,23 +197,24 @@ usage: ./bin/sd [arguments]
arguments:
-h, --help show this help message and exit
-M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)
-t, --threads N number of threads to use during computation (default: -1).
-t, --threads N number of threads to use during computation (default: -1)
If threads <= 0, then threads will be set to the number of CPU physical cores
-m, --model [MODEL] path to full model
--diffusion-model path to the standalone diffusion model
--clip_l path to the clip-l text encoder
--t5xxl path to the t5xxl text encoder.
--clip_g path to the clip-g text encoder
--t5xxl path to the t5xxl text encoder
--vae [VAE] path to vae
--taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
--control-net [CONTROL_PATH] path to control net model
--embd-dir [EMBEDDING_PATH] path to embeddings.
--stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings.
--input-id-images-dir [DIR] path to PHOTOMAKER input id images dir.
--embd-dir [EMBEDDING_PATH] path to embeddings
--stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings
--input-id-images-dir [DIR] path to PHOTOMAKER input id images dir
--normalize-input normalize PHOTOMAKER input id images
--upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generation; only RealESRGAN_x4plus_anime_6B is supported for now.
--upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generation; only RealESRGAN_x4plus_anime_6B is supported for now
--upscale-repeats Run the ESRGAN upscaler this many times (default 1)
--type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)
If not specified, the default is the type of the weight file.
If not specified, the default is the type of the weight file
--lora-model-dir [DIR] lora model directory
-i, --init-img [IMAGE] path to the input image, required by img2img
--control-image [IMAGE] path to image condition, control net
@ -232,13 +233,13 @@ arguments:
--steps STEPS number of sample steps (default: 20)
--rng {std_default, cuda} RNG (default: cuda)
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
-b, --batch-count COUNT number of images to generate.
-b, --batch-count COUNT number of images to generate
--schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)
--clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
--vae-tiling process vae in tiles to reduce memory usage
--vae-on-cpu keep vae in cpu (for low vram)
--clip-on-cpu keep clip in cpu (for low vram).
--clip-on-cpu keep clip in cpu (for low vram)
--control-net-cpu keep controlnet in cpu (for low vram)
--canny apply canny preprocessor (edge detection)
--color Colors the logging tags according to level
@ -253,6 +254,7 @@ arguments:
# ./bin/sd -m ../models/sd_xl_base_1.0.safetensors --vae ../models/sdxl_vae-fp16-fix.safetensors -H 1024 -W 1024 -p "a lovely cat" -v
# ./bin/sd -m ../models/sd3_medium_incl_clips_t5xxlfp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable Diffusion CPP\"' --cfg-scale 4.5 --sampling-method euler -v
# ./bin/sd --diffusion-model ../models/flux1-dev-q3_k.gguf --vae ../models/ae.sft --clip_l ../models/clip_l.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v
# ./bin/sd -m ..\models\sd3.5_large.safetensors --clip_l ..\models\clip_l.safetensors --clip_g ..\models\clip_g.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v
```
Using formats of different precisions will yield results of varying quality.
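To make that trade-off concrete, the SD3.5 example above could hypothetically be run with on-the-fly quantization via `--type` (paths assumed; lower-bit types cut memory use at some cost in quality):

```
./bin/sd -m ../models/sd3.5_large.safetensors --clip_l ../models/clip_l.safetensors --clip_g ../models/clip_g.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -H 1024 -W 1024 -p "a lovely cat" --cfg-scale 4.5 --sampling-method euler --type q8_0 -v
```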

assets/sd3.5_large.png Normal file (binary image, 1.8 MiB; not shown)

View File

@ -1001,8 +1001,8 @@ struct FluxCLIPEmbedder : public Conditioner {
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
clip_l->get_param_tensors(tensors, "text_encoders.clip_l.text_model");
t5->get_param_tensors(tensors, "text_encoders.t5xxl");
clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
}
void alloc_params_buffer() {
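Note on the prefix change above: the prefix registered by get_param_tensors has to agree with the prefix the loader prepends when a standalone encoder file is read (see the init_from_file calls later in this commit). A minimal self-contained check of that agreement; the on-disk tensor name is an assumed example of HF CLIP naming, not taken from this commit:

```
#include <cassert>
#include <string>

int main() {
    // Loader side: prefix passed to init_from_file in this commit.
    std::string loader_prefix = "text_encoders.clip_l.transformer.";
    // Assumed on-disk tensor name from a standalone HF clip_l file.
    std::string on_disk = "text_model.embeddings.token_embedding.weight";
    // Registration side: get_param_tensors prefix above, plus the same suffix.
    std::string registered = "text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight";
    assert(loader_prefix + on_disk == registered);  // the two sides must line up
    return 0;
}
```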

View File

@ -49,7 +49,7 @@ struct ExponentialSchedule : SigmaSchedule {
// Calculate step size
float log_sigma_min = std::log(sigma_min);
float log_sigma_max = std::log(sigma_max);
float step = (log_sigma_max - log_sigma_min) / (n - 1);
float step = (log_sigma_max - log_sigma_min) / (n - 1);
// Fill sigmas with exponential values
for (uint32_t i = 0; i < n; ++i) {
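For reference, a self-contained sketch of what this schedule computes (same log-space step as above, assuming n >= 2; not the project's exact code):

```
#include <cmath>
#include <vector>

// Sigmas descend geometrically from sigma_max to sigma_min over n steps.
std::vector<float> exponential_sigmas(int n, float sigma_min, float sigma_max) {
    std::vector<float> sigmas(n);
    float step = (std::log(sigma_max) - std::log(sigma_min)) / (n - 1);
    for (int i = 0; i < n; ++i) {
        sigmas[i] = std::exp(std::log(sigma_max) - step * i);  // i = 0 gives sigma_max
    }
    return sigmas;  // i = n - 1 gives sigma_min
}
```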
@ -205,7 +205,7 @@ struct AYSSchedule : SigmaSchedule {
/*
* GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main
*/
*/
struct GITSSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
if (sigma_max <= 0.0f) {
@ -221,7 +221,7 @@ struct GITSSchedule : SigmaSchedule {
// Calculate the index based on the coefficient
int index = static_cast<int>((coeff - 0.80f) / 0.05f);
// Ensure the index is within bounds
index = std::max(0, std::min(index, static_cast<int>(GITS_NOISE.size() - 1)));
index = std::max(0, std::min(index, static_cast<int>(GITS_NOISE.size() - 1)));
const std::vector<std::vector<float>>& selected_noise = *GITS_NOISE[index];
if (n <= 20) {
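The clamped linear mapping above turns a guidance coefficient into a row of the GITS lookup table; a standalone sketch with a worked value:

```
#include <algorithm>

// coeff = 0.95 -> (0.95 - 0.80) / 0.05 = 3; out-of-range values are clamped.
int gits_index(float coeff, int table_size) {
    int index = static_cast<int>((coeff - 0.80f) / 0.05f);
    return std::max(0, std::min(index, table_size - 1));
}
```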
@ -823,24 +823,24 @@ static void sample_k_diffusion(sample_method_t method,
} break;
case IPNDM: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
{
int max_order = 4;
int max_order = 4;
ggml_tensor* x_next = x;
std::vector<ggml_tensor*> buffer_model;
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
float sigma = sigmas[i];
float sigma_next = sigmas[i + 1];
ggml_tensor* x_cur = x_next;
float* vec_x_cur = (float*)x_cur->data;
float* vec_x_next = (float*)x_next->data;
float* vec_x_cur = (float*)x_cur->data;
float* vec_x_next = (float*)x_next->data;
// Denoising step
ggml_tensor* denoised = model(x_cur, sigma, i + 1);
float* vec_denoised = (float*)denoised->data;
float* vec_denoised = (float*)denoised->data;
// d_cur = (x_cur - denoised) / sigma
struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x_cur);
float* vec_d_cur = (float*)d_cur->data;
float* vec_d_cur = (float*)d_cur->data;
for (int j = 0; j < ggml_nelements(d_cur); j++) {
vec_d_cur[j] = (vec_x_cur[j] - vec_denoised[j]) / sigma;
@ -857,34 +857,31 @@ static void sample_k_diffusion(sample_method_t method,
break;
case 2: // Use one history point
{
float* vec_d_prev1 = (float*)buffer_model.back()->data;
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (3 * vec_d_cur[j] - vec_d_prev1[j]) / 2;
}
{
float* vec_d_prev1 = (float*)buffer_model.back()->data;
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (3 * vec_d_cur[j] - vec_d_prev1[j]) / 2;
}
break;
} break;
case 3: // Use two history points
{
float* vec_d_prev1 = (float*)buffer_model.back()->data;
float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (23 * vec_d_cur[j] - 16 * vec_d_prev1[j] + 5 * vec_d_prev2[j]) / 12;
}
{
float* vec_d_prev1 = (float*)buffer_model.back()->data;
float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (23 * vec_d_cur[j] - 16 * vec_d_prev1[j] + 5 * vec_d_prev2[j]) / 12;
}
break;
} break;
case 4: // Use three history points
{
float* vec_d_prev1 = (float*)buffer_model.back()->data;
float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
float* vec_d_prev3 = (float*)buffer_model[buffer_model.size() - 3]->data;
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (55 * vec_d_cur[j] - 59 * vec_d_prev1[j] + 37 * vec_d_prev2[j] - 9 * vec_d_prev3[j]) / 24;
}
{
float* vec_d_prev1 = (float*)buffer_model.back()->data;
float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
float* vec_d_prev3 = (float*)buffer_model[buffer_model.size() - 3]->data;
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (55 * vec_d_cur[j] - 59 * vec_d_prev1[j] + 37 * vec_d_prev2[j] - 9 * vec_d_prev3[j]) / 24;
}
break;
} break;
}
// Manage buffer_model
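The case weights above are the classical fixed-step Adams-Bashforth coefficients. A per-element sketch of the update, with h = sigma_next - sigma and d1..d3 the most recent history slopes (a reference illustration, not project code):

```
// Orders 1-4 of the explicit Adams-Bashforth update used by iPNDM.
float ipndm_step(float x, float h, int order,
                 float d, float d1 = 0, float d2 = 0, float d3 = 0) {
    switch (order) {
        case 1: return x + h * d;
        case 2: return x + h * (3 * d - d1) / 2;
        case 3: return x + h * (23 * d - 16 * d1 + 5 * d2) / 12;
        default: return x + h * (55 * d - 59 * d1 + 37 * d2 - 9 * d3) / 24;
    }
}
```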
@ -906,23 +903,23 @@ static void sample_k_diffusion(sample_method_t method,
ggml_tensor* x_next = x;
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
float sigma = sigmas[i];
float t_next = sigmas[i + 1];
// Denoising step
ggml_tensor* denoised = model(x, sigma, i + 1);
float* vec_denoised = (float*)denoised->data;
ggml_tensor* denoised = model(x, sigma, i + 1);
float* vec_denoised = (float*)denoised->data;
struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x);
float* vec_d_cur = (float*)d_cur->data;
float* vec_x = (float*)x->data;
float* vec_d_cur = (float*)d_cur->data;
float* vec_x = (float*)x->data;
// d_cur = (x - denoised) / sigma
for (int j = 0; j < ggml_nelements(d_cur); j++) {
vec_d_cur[j] = (vec_x[j] - vec_denoised[j]) / sigma;
}
int order = std::min(max_order, i + 1);
float h_n = t_next - sigma;
int order = std::min(max_order, i + 1);
float h_n = t_next - sigma;
float h_n_1 = (i > 0) ? (sigma - sigmas[i - 1]) : h_n;
switch (order) {
@ -941,7 +938,7 @@ static void sample_k_diffusion(sample_method_t method,
}
case 3: {
float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
float* vec_d_prev1 = (float*)buffer_model.back()->data;
float* vec_d_prev2 = (buffer_model.size() > 1) ? (float*)buffer_model[buffer_model.size() - 2]->data : vec_d_prev1;
for (int j = 0; j < ggml_nelements(x_next); j++) {
@ -951,8 +948,8 @@ static void sample_k_diffusion(sample_method_t method,
}
case 4: {
float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
float h_n_3 = (i > 2) ? (sigmas[i - 2] - sigmas[i - 3]) : h_n_2;
float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
float h_n_3 = (i > 2) ? (sigmas[i - 2] - sigmas[i - 3]) : h_n_2;
float* vec_d_prev1 = (float*)buffer_model.back()->data;
float* vec_d_prev2 = (buffer_model.size() > 1) ? (float*)buffer_model[buffer_model.size() - 2]->data : vec_d_prev1;
float* vec_d_prev3 = (buffer_model.size() > 2) ? (float*)buffer_model[buffer_model.size() - 3]->data : vec_d_prev2;
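The iPNDM_V variant reuses the same history but weights it by the actual spacing between sigmas instead of assuming uniform steps. The spacings themselves are derived as in the branches above, falling back to the newest spacing when there is not yet enough history; a self-contained sketch of just that derivation:

```
#include <vector>

// h_n is the current step; h_n_1..h_n_3 are the previous spacings.
void ipndm_v_spacings(const std::vector<float>& sigmas, int i,
                      float& h_n, float& h_n_1, float& h_n_2, float& h_n_3) {
    h_n   = sigmas[i + 1] - sigmas[i];
    h_n_1 = (i > 0) ? (sigmas[i] - sigmas[i - 1]) : h_n;
    h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
    h_n_3 = (i > 2) ? (sigmas[i - 2] - sigmas[i - 3]) : h_n_2;
}
```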

docs/sd3.md Normal file (20 lines)
View File

@ -0,0 +1,20 @@
# How to Use
## Download weights
- Download sd3.5_large from https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/sd3.5_large.safetensors
- Download clip_g from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/clip_g.safetensors
- Download clip_l from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/clip_l.safetensors
- Download t5xxl from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/t5xxl_fp16.safetensors
## Run
### SD3.5 Large
For example:
```
.\bin\Release\sd.exe -m ..\models\sd3.5_large.safetensors --clip_l ..\models\clip_l.safetensors --clip_g ..\models\clip_g.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v
```
![](../assets/sd3.5_large.png)
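The command above is for a Windows build; on Linux/macOS the equivalent invocation (paths assumed) would be:

```
./bin/sd -m ../models/sd3.5_large.safetensors --clip_l ../models/clip_l.safetensors --clip_g ../models/clip_g.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says "Stable diffusion 3.5 Large"' --cfg-scale 4.5 --sampling-method euler -v
```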

View File

@ -69,9 +69,9 @@ enum SDMode {
struct SDParams {
int n_threads = -1;
SDMode mode = TXT2IMG;
std::string model_path;
std::string clip_l_path;
std::string clip_g_path;
std::string t5xxl_path;
std::string diffusion_model_path;
std::string vae_path;
@ -128,6 +128,7 @@ void print_params(SDParams params) {
printf(" model_path: %s\n", params.model_path.c_str());
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
printf(" clip_l_path: %s\n", params.clip_l_path.c_str());
printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
printf(" vae_path: %s\n", params.vae_path.c_str());
@ -171,23 +172,24 @@ void print_usage(int argc, const char* argv[]) {
printf("arguments:\n");
printf(" -h, --help show this help message and exit\n");
printf(" -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)\n");
printf(" -t, --threads N number of threads to use during computation (default: -1).\n");
printf(" -t, --threads N number of threads to use during computation (default: -1)\n");
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
printf(" -m, --model [MODEL] path to full model\n");
printf(" --diffusion-model path to the standalone diffusion model\n");
printf(" --clip_l path to the clip-l text encoder\n");
printf(" --t5xxl path to the the t5xxl text encoder.\n");
printf(" --clip_g path to the clip-l text encoder\n");
printf(" --t5xxl path to the the t5xxl text encoder\n");
printf(" --vae [VAE] path to vae\n");
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
printf(" --control-net [CONTROL_PATH] path to control net model\n");
printf(" --embd-dir [EMBEDDING_PATH] path to embeddings.\n");
printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings.\n");
printf(" --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir.\n");
printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n");
printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings\n");
printf(" --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir\n");
printf(" --normalize-input normalize PHOTOMAKER input id images\n");
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n");
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n");
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)\n");
printf(" If not specified, the default is the type of the weight file.\n");
printf(" If not specified, the default is the type of the weight file\n");
printf(" --lora-model-dir [DIR] lora model directory\n");
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
printf(" --control-image [IMAGE] path to image condition, control net\n");
@ -206,13 +208,13 @@ void print_usage(int argc, const char* argv[]) {
printf(" --steps STEPS number of sample steps (default: 20)\n");
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
printf(" -b, --batch-count COUNT number of images to generate.\n");
printf(" -b, --batch-count COUNT number of images to generate\n");
printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n");
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
printf(" --vae-on-cpu keep vae in cpu (for low vram)\n");
printf(" --clip-on-cpu keep clip in cpu (for low vram).\n");
printf(" --clip-on-cpu keep clip in cpu (for low vram)\n");
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
printf(" --canny apply canny preprocessor (edge detection)\n");
printf(" --color Colors the logging tags according to level\n");
@ -262,6 +264,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.clip_l_path = argv[i];
} else if (arg == "--clip_g") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.clip_g_path = argv[i];
} else if (arg == "--t5xxl") {
if (++i >= argc) {
invalid_arg = true;
@ -765,6 +773,7 @@ int main(int argc, const char* argv[]) {
sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
params.clip_l_path.c_str(),
params.clip_g_path.c_str(),
params.t5xxl_path.c_str(),
params.diffusion_model_path.c_str(),
params.vae_path.c_str(),

View File

@ -368,8 +368,8 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
int64_t height = input->ne[1];
int64_t channels = input->ne[2];
int64_t img_width = output->ne[0];
int64_t img_height = output->ne[1];
int64_t img_width = output->ne[0];
int64_t img_height = output->ne[1];
GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) {
@ -380,7 +380,7 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k);
const float x_f_0 = (x > 0) ? ix / float(overlap) : 1;
const float x_f_1 = (x < (img_width - width)) ? (width - ix) / float(overlap) : 1 ;
const float x_f_1 = (x < (img_width - width)) ? (width - ix) / float(overlap) : 1;
const float y_f_0 = (y > 0) ? iy / float(overlap) : 1;
const float y_f_1 = (y < (img_height - height)) ? (height - iy) / float(overlap) : 1;
@ -390,8 +390,7 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
ggml_tensor_set_f32(
output,
old_value + new_value * ggml_smootherstep_f32(y_f) * ggml_smootherstep_f32(x_f),
x + ix, y + iy, k
);
x + ix, y + iy, k);
} else {
ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k);
}

mmdit.hpp (148 lines)
View File

@ -142,29 +142,77 @@ public:
}
};
class RMSNorm : public UnaryBlock {
protected:
int64_t hidden_size;
float eps;
void init_params(struct ggml_context* ctx, ggml_type wtype) {
params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);
}
public:
RMSNorm(int64_t hidden_size,
float eps = 1e-06f)
: hidden_size(hidden_size),
eps(eps) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
x = ggml_rms_norm(ctx, x, eps);
x = ggml_mul(ctx, x, w);
return x;
}
};
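RMSNorm, new in this file for SD3.5's qk-norm, rescales by the root mean square without subtracting the mean, then applies a learned per-channel weight. A reference sketch of the per-row math (eps matching the constructor default; not the ggml implementation):

```
#include <cmath>
#include <vector>

// y[i] = x[i] / sqrt(mean(x^2) + eps) * w[i], over one row with w.size() == x.size().
std::vector<float> rms_norm_ref(const std::vector<float>& x,
                                const std::vector<float>& w, float eps = 1e-6f) {
    float mean_sq = 0.0f;
    for (float v : x) mean_sq += v * v;
    mean_sq /= x.size();
    float inv = 1.0f / std::sqrt(mean_sq + eps);
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) y[i] = x[i] * inv * w[i];
    return y;
}
```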
class SelfAttention : public GGMLBlock {
public:
int64_t num_heads;
bool pre_only;
std::string qk_norm;
public:
SelfAttention(int64_t dim,
int64_t num_heads = 8,
bool qkv_bias = false,
bool pre_only = false)
: num_heads(num_heads), pre_only(pre_only) {
// qk_norm is always None
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
int64_t num_heads = 8,
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false)
: num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
int64_t d_head = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
if (!pre_only) {
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
}
if (qk_norm == "rms") {
blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6));
blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6));
} else if (qk_norm == "ln") {
blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6));
blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6));
}
}
std::vector<struct ggml_tensor*> pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
auto qkv = qkv_proj->forward(ctx, x);
return split_qkv(ctx, qkv);
auto qkv = qkv_proj->forward(ctx, x);
auto qkv_vec = split_qkv(ctx, qkv);
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
auto v = qkv_vec[2]; // [N, n_token, n_head*d_head]
if (qk_norm == "rms" || qk_norm == "ln") {
auto ln_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["ln_q"]);
auto ln_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["ln_k"]);
q = ln_q->forward(ctx, q);
k = ln_k->forward(ctx, k);
}
q = ggml_reshape_3d(ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]); // [N, n_token, n_head*d_head]
k = ggml_reshape_3d(ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]); // [N, n_token, n_head*d_head]
return {q, k, v};
}
struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
@ -208,16 +256,16 @@ public:
public:
DismantledBlock(int64_t hidden_size,
int64_t num_heads,
float mlp_ratio = 4.0,
bool qkv_bias = false,
bool pre_only = false)
float mlp_ratio = 4.0,
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false)
: num_heads(num_heads), pre_only(pre_only) {
// rmsnorm is always False
// scale_mod_only is always False
// swiglu is always False
// qk_norm is always False
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, pre_only));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
if (!pre_only) {
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
@ -396,12 +444,12 @@ struct JointBlock : public GGMLBlock {
public:
JointBlock(int64_t hidden_size,
int64_t num_heads,
float mlp_ratio = 4.0,
bool qkv_bias = false,
bool pre_only = false) {
// qk_norm is always False
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qkv_bias, pre_only));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qkv_bias, false));
float mlp_ratio = 4.0,
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false) {
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false));
}
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
@ -455,18 +503,20 @@ public:
struct MMDiT : public GGMLBlock {
// Diffusion model with a Transformer backbone.
protected:
SDVersion version = VERSION_SD3_2B;
int64_t input_size = -1;
int64_t patch_size = 2;
int64_t in_channels = 16;
int64_t depth = 24;
float mlp_ratio = 4.0f;
int64_t adm_in_channels = 2048;
int64_t out_channels = 16;
int64_t pos_embed_max_size = 192;
int64_t num_patchs = 36864; // 192 * 192
int64_t context_size = 4096;
SDVersion version = VERSION_SD3_2B;
int64_t input_size = -1;
int64_t patch_size = 2;
int64_t in_channels = 16;
int64_t depth = 24;
float mlp_ratio = 4.0f;
int64_t adm_in_channels = 2048;
int64_t out_channels = 16;
int64_t pos_embed_max_size = 192;
int64_t num_patchs = 36864; // 192 * 192
int64_t context_size = 4096;
int64_t context_embedder_out_dim = 1536;
int64_t hidden_size;
std::string qk_norm;
void init_params(struct ggml_context* ctx, ggml_type wtype) {
params["pos_embed"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hidden_size, num_patchs, 1);
@ -481,23 +531,36 @@ public:
// rmsnorm is always False
// scale_mod_only is always False
// swiglu is always False
// qk_norm is always None
// qkv_bias is always True
// context_processor_layers is always None
// pos_embed_scaling_factor is not used
// pos_embed_offset is not used
// context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}
if (version == VERSION_SD3_2B) {
input_size = -1;
patch_size = 2;
in_channels = 16;
depth = 24;
mlp_ratio = 4.0f;
adm_in_channels = 2048;
out_channels = 16;
pos_embed_max_size = 192;
num_patchs = 36864; // 192 * 192
context_size = 4096;
input_size = -1;
patch_size = 2;
in_channels = 16;
depth = 24;
mlp_ratio = 4.0f;
adm_in_channels = 2048;
out_channels = 16;
pos_embed_max_size = 192;
num_patchs = 36864; // 192 * 192
context_size = 4096;
context_embedder_out_dim = 1536;
} else if (version == VERSION_SD3_5_8B) {
input_size = -1;
patch_size = 2;
in_channels = 16;
depth = 38;
mlp_ratio = 4.0f;
adm_in_channels = 2048;
out_channels = 16;
pos_embed_max_size = 192;
num_patchs = 36864; // 192 * 192
context_size = 4096;
context_embedder_out_dim = 2432;
qk_norm = "rms";
}
int64_t default_out_channels = in_channels;
hidden_size = 64 * depth;
@ -510,12 +573,13 @@ public:
blocks["y_embedder"] = std::shared_ptr<GGMLBlock>(new VectorEmbedder(adm_in_channels, hidden_size));
}
blocks["context_embedder"] = std::shared_ptr<GGMLBlock>(new Linear(4096, 1536, true, true));
blocks["context_embedder"] = std::shared_ptr<GGMLBlock>(new Linear(4096, context_embedder_out_dim, true, true));
for (int i = 0; i < depth; i++) {
blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new JointBlock(hidden_size,
num_heads,
mlp_ratio,
qk_norm,
true,
i == depth - 1));
}
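The two presets are internally consistent: `hidden_size = 64 * depth`, and that product reproduces the `context_embedder_out_dim` of each preset, so the projected context matches the transformer width. A compile-time sketch of the arithmetic (not project code):

```
// SD3 2B:   depth 24 -> hidden_size 1536 (== context_embedder_out_dim 1536)
// SD3.5 8B: depth 38 -> hidden_size 2432 (== context_embedder_out_dim 2432)
static_assert(64 * 24 == 1536, "SD3 2B width");
static_assert(64 * 38 == 2432, "SD3.5 8B width");
```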

View File

@ -430,6 +430,14 @@ std::string convert_tensor_name(std::string name) {
if (starts_with(name, "diffusion_model")) {
name = "model." + name;
}
// size_t pos = name.find("lora_A");
// if (pos != std::string::npos) {
// name.replace(pos, strlen("lora_A"), "lora_up");
// }
// pos = name.find("lora_B");
// if (pos != std::string::npos) {
// name.replace(pos, strlen("lora_B"), "lora_down");
// }
std::string new_name = name;
if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) {
new_name = convert_open_clip_to_hf_clip(name);
@ -466,6 +474,9 @@ std::string convert_tensor_name(std::string name) {
if (pos != std::string::npos) {
new_name.replace(pos, strlen(".processor"), "");
}
// if (starts_with(new_name, "transformer.transformer_blocks") || starts_with(new_name, "transformer.single_transformer_blocks")) {
// new_name = "model.diffusion_model." + new_name;
// }
pos = new_name.rfind("lora");
if (pos != std::string::npos) {
std::string name_without_network_parts = new_name.substr(0, pos - 1);
@ -1354,6 +1365,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight;
bool is_flux = false;
bool is_sd3 = false;
for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
return VERSION_FLUX_DEV;
@ -1361,8 +1373,11 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true;
}
if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) {
return VERSION_SD3_5_8B;
}
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.23.") != std::string::npos) {
return VERSION_SD3_2B;
is_sd3 = true;
}
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
return VERSION_SDXL;
@ -1387,6 +1402,9 @@ SDVersion ModelLoader::get_sd_version() {
if (is_flux) {
return VERSION_FLUX_SCHNELL;
}
if (is_sd3) {
return VERSION_SD3_2B;
}
if (token_embedding_weight.ne[0] == 768) {
return VERSION_SD1;
} else if (token_embedding_weight.ne[0] == 1024) {
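The probe order matters here: `ln_q` weights exist only when qk-norm is enabled, and block index 37 exists only at depth 38, so that one tensor name uniquely identifies an SD3.5 8B checkpoint. By contrast, `joint_blocks.23.` only proves depth >= 24, so it is recorded as a flag and resolved to VERSION_SD3_2B after the more specific checks. A standalone sketch of the check (hypothetical helper, not the project API):

```
#include <string>
#include <vector>

// True only for MMDiT checkpoints with depth 38 and per-head qk rms-norm.
bool looks_like_sd3_5_8b(const std::vector<std::string>& tensor_names) {
    for (const auto& name : tensor_names)
        if (name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos)
            return true;
    return false;
}
```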

View File

@ -25,6 +25,7 @@ enum SDVersion {
VERSION_SD3_2B,
VERSION_FLUX_DEV,
VERSION_FLUX_SCHNELL,
VERSION_SD3_5_8B,
VERSION_COUNT,
};

View File

@ -31,7 +31,8 @@ const char* model_version_to_str[] = {
"SVD",
"SD3 2B",
"Flux Dev",
"Flux Schnell"};
"Flux Schnell",
"SD3.5 8B"};
const char* sampling_methods_str[] = {
"Euler A",
@ -139,6 +140,7 @@ public:
bool load_from_file(const std::string& model_path,
const std::string& clip_l_path,
const std::string& clip_g_path,
const std::string& t5xxl_path,
const std::string& diffusion_model_path,
const std::string& vae_path,
@ -167,7 +169,7 @@ public:
for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
backend = ggml_backend_vk_init(device);
}
if(!backend) {
if (!backend) {
LOG_WARN("Failed to initialize Vulkan backend");
}
#endif
@ -181,7 +183,7 @@ public:
backend = ggml_backend_cpu_init();
}
#ifdef SD_USE_FLASH_ATTENTION
#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined (SD_USE_SYCL) || defined(SD_USE_VULKAN)
#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) || defined(SD_USE_VULKAN)
LOG_WARN("Flash Attention not supported with GPU Backend");
#else
LOG_INFO("Flash Attention enabled");
@ -200,14 +202,21 @@ public:
if (clip_l_path.size() > 0) {
LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str());
if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.")) {
if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.transformer.")) {
LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str());
}
}
if (clip_g_path.size() > 0) {
LOG_INFO("loading clip_g from '%s'", clip_g_path.c_str());
if (!model_loader.init_from_file(clip_g_path, "text_encoders.clip_g.transformer.")) {
LOG_WARN("loading clip_g from '%s' failed", clip_g_path.c_str());
}
}
if (t5xxl_path.size() > 0) {
LOG_INFO("loading t5xxl from '%s'", t5xxl_path.c_str());
if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.")) {
if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.transformer.")) {
LOG_WARN("loading t5xxl from '%s' failed", t5xxl_path.c_str());
}
}
@ -279,7 +288,7 @@ public:
"try specifying SDXL VAE FP16 Fix with the --vae parameter. "
"You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
}
} else if (version == VERSION_SD3_2B) {
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
scale_factor = 1.5305f;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
scale_factor = 0.3611;
@ -302,7 +311,7 @@ public:
} else {
clip_backend = backend;
bool use_t5xxl = false;
if (version == VERSION_SD3_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
use_t5xxl = true;
}
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
@ -313,7 +322,7 @@ public:
LOG_INFO("CLIP: Using CPU backend");
clip_backend = ggml_backend_cpu_init();
}
if (version == VERSION_SD3_2B) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
@ -511,7 +520,7 @@ public:
is_using_v_parameterization = true;
}
if (version == VERSION_SD3_2B) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
@ -939,7 +948,7 @@ public:
if (use_tiny_autoencoder) {
C = 4;
} else {
if (version == VERSION_SD3_2B) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
C = 32;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
C = 32;
@ -1008,6 +1017,7 @@ struct sd_ctx_t {
sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
const char* clip_l_path_c_str,
const char* clip_g_path_c_str,
const char* t5xxl_path_c_str,
const char* diffusion_model_path_c_str,
const char* vae_path_c_str,
@ -1032,6 +1042,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
}
std::string model_path(model_path_c_str);
std::string clip_l_path(clip_l_path_c_str);
std::string clip_g_path(clip_g_path_c_str);
std::string t5xxl_path(t5xxl_path_c_str);
std::string diffusion_model_path(diffusion_model_path_c_str);
std::string vae_path(vae_path_c_str);
@ -1052,6 +1063,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
if (!sd_ctx->sd->load_from_file(model_path,
clip_l_path,
clip_g_path,
t5xxl_path_c_str,
diffusion_model_path,
vae_path,
@ -1269,7 +1281,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
// Sample
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B) {
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
C = 16;
@ -1382,7 +1394,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->version == VERSION_SD3_2B) {
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
params.mem_size *= 3;
}
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
@ -1408,7 +1420,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B) {
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
C = 16;
@ -1416,7 +1428,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int W = width / 8;
int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
if (sd_ctx->sd->version == VERSION_SD3_2B) {
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
ggml_set_f32(init_latent, 0.0609f);
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
ggml_set_f32(init_latent, 0.1159f);
@ -1477,7 +1489,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->version == VERSION_SD3_2B) {
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
params.mem_size *= 2;
}
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {

View File

@ -124,6 +124,7 @@ typedef struct sd_ctx_t sd_ctx_t;
SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
const char* clip_l_path,
const char* clip_g_path,
const char* t5xxl_path,
const char* diffusion_model_path,
const char* vae_path,

View File

@ -457,7 +457,7 @@ public:
bool use_video_decoder = false,
SDVersion version = VERSION_SD1)
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (version == VERSION_SD3_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
dd_config.z_channels = 16;
use_quant = false;
}