refactor: add some sd version helper functions

This commit is contained in:
leejet 2024-11-23 13:02:44 +08:00
parent 1c168d98a5
commit b5f4932696
5 changed files with 59 additions and 38 deletions

View File

@ -44,7 +44,7 @@ struct Conditioner {
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283 // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
SDVersion version = VERSION_SD1; SDVersion version = VERSION_SD1;
PMVersion pm_version = VERSION_1; PMVersion pm_version = PM_VERSION_1;
CLIPTokenizer tokenizer; CLIPTokenizer tokenizer;
ggml_type wtype; ggml_type wtype;
std::shared_ptr<CLIPTextModelRunner> text_model; std::shared_ptr<CLIPTextModelRunner> text_model;
@ -60,7 +60,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
ggml_type wtype, ggml_type wtype,
const std::string& embd_dir, const std::string& embd_dir,
SDVersion version = VERSION_SD1, SDVersion version = VERSION_SD1,
PMVersion pv = VERSION_1, PMVersion pv = PM_VERSION_1,
int clip_skip = -1) int clip_skip = -1)
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) { : version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
if (clip_skip <= 0) { if (clip_skip <= 0) {
@ -270,7 +270,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
std::vector<int> clean_input_ids_tmp; std::vector<int> clean_input_ids_tmp;
for (uint32_t i = 0; i < class_token_index[0]; i++) for (uint32_t i = 0; i < class_token_index[0]; i++)
clean_input_ids_tmp.push_back(clean_input_ids[i]); clean_input_ids_tmp.push_back(clean_input_ids[i]);
for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++) for (uint32_t i = 0; i < (pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++)
clean_input_ids_tmp.push_back(class_token); clean_input_ids_tmp.push_back(class_token);
for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++) for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
clean_input_ids_tmp.push_back(clean_input_ids[i]); clean_input_ids_tmp.push_back(clean_input_ids[i]);
@ -286,7 +286,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
// weights.insert(weights.begin(), 1.0); // weights.insert(weights.begin(), 1.0);
tokenizer.pad_tokens(tokens, weights, max_length, padding); tokenizer.pad_tokens(tokens, weights, max_length, padding);
int offset = pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs; int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
for (uint32_t i = 0; i < tokens.size(); i++) { for (uint32_t i = 0; i < tokens.size(); i++) {
// if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs // if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs
if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs

25
model.h
View File

@ -31,9 +31,30 @@ enum SDVersion {
VERSION_COUNT, VERSION_COUNT,
}; };
// Reports whether `version` is one of the Flux variants:
// VERSION_FLUX_DEV, VERSION_FLUX_SCHNELL or VERSION_FLUX_LITE.
// Every other SDVersion value yields false.
static inline bool sd_version_is_flux(SDVersion version) {
    switch (version) {
        case VERSION_FLUX_DEV:
        case VERSION_FLUX_SCHNELL:
        case VERSION_FLUX_LITE:
            return true;
        default:
            return false;
    }
}
// Reports whether `version` is one of the SD3 variants:
// VERSION_SD3_2B, VERSION_SD3_5_8B or VERSION_SD3_5_2B.
// Every other SDVersion value yields false.
static inline bool sd_version_is_sd3(SDVersion version) {
    return version == VERSION_SD3_2B ||
           version == VERSION_SD3_5_8B ||
           version == VERSION_SD3_5_2B;
}
// Reports whether `version` belongs to either family covered by this helper:
// true when sd_version_is_flux() or sd_version_is_sd3() accepts it.
static inline bool sd_version_is_dit(SDVersion version) {
    // Flux is tested first; the SD3 check only runs when that one fails.
    return sd_version_is_flux(version) || sd_version_is_sd3(version);
}
enum PMVersion { enum PMVersion {
VERSION_1, PM_VERSION_1,
VERSION_2, PM_VERSION_2,
}; };
struct TensorStorage { struct TensorStorage {

View File

@ -608,7 +608,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
struct PhotoMakerIDEncoder : public GGMLRunner { struct PhotoMakerIDEncoder : public GGMLRunner {
public: public:
SDVersion version = VERSION_SDXL; SDVersion version = VERSION_SDXL;
PMVersion pm_version = VERSION_1; PMVersion pm_version = PM_VERSION_1;
PhotoMakerIDEncoderBlock id_encoder; PhotoMakerIDEncoderBlock id_encoder;
PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2; PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2;
float style_strength; float style_strength;
@ -623,14 +623,14 @@ public:
std::vector<float> zeros_right; std::vector<float> zeros_right;
public: public:
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = VERSION_1, float sty = 20.f) PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = PM_VERSION_1, float sty = 20.f)
: GGMLRunner(backend, wtype), : GGMLRunner(backend, wtype),
version(version), version(version),
pm_version(pm_v), pm_version(pm_v),
style_strength(sty) { style_strength(sty) {
if (pm_version == VERSION_1) { if (pm_version == PM_VERSION_1) {
id_encoder.init(params_ctx, wtype); id_encoder.init(params_ctx, wtype);
} else if (pm_version == VERSION_2) { } else if (pm_version == PM_VERSION_2) {
id_encoder2.init(params_ctx, wtype); id_encoder2.init(params_ctx, wtype);
} }
} }
@ -644,9 +644,9 @@ public:
} }
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) { void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
if (pm_version == VERSION_1) if (pm_version == PM_VERSION_1)
id_encoder.get_param_tensors(tensors, prefix); id_encoder.get_param_tensors(tensors, prefix);
else if (pm_version == VERSION_2) else if (pm_version == PM_VERSION_2)
id_encoder2.get_param_tensors(tensors, prefix); id_encoder2.get_param_tensors(tensors, prefix);
} }
@ -734,14 +734,14 @@ public:
} }
} }
struct ggml_tensor* updated_prompt_embeds = NULL; struct ggml_tensor* updated_prompt_embeds = NULL;
if (pm_version == VERSION_1) if (pm_version == PM_VERSION_1)
updated_prompt_embeds = id_encoder.forward(ctx0, updated_prompt_embeds = id_encoder.forward(ctx0,
id_pixel_values_d, id_pixel_values_d,
prompt_embeds_d, prompt_embeds_d,
class_tokens_mask_d, class_tokens_mask_d,
class_tokens_mask_pos, class_tokens_mask_pos,
left, right); left, right);
else if (pm_version == VERSION_2) else if (pm_version == PM_VERSION_2)
updated_prompt_embeds = id_encoder2.forward(ctx0, updated_prompt_embeds = id_encoder2.forward(ctx0,
id_pixel_values_d, id_pixel_values_d,
prompt_embeds_d, prompt_embeds_d,

View File

@ -286,9 +286,9 @@ public:
"try specifying SDXL VAE FP16 Fix with the --vae parameter. " "try specifying SDXL VAE FP16 Fix with the --vae parameter. "
"You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors"); "You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
} }
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { } else if (sd_version_is_sd3(version)) {
scale_factor = 1.5305f; scale_factor = 1.5305f;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { } else if (sd_version_is_flux(version)) {
scale_factor = 0.3611; scale_factor = 0.3611;
// TODO: shift_factor // TODO: shift_factor
} }
@ -309,7 +309,7 @@ public:
} else { } else {
clip_backend = backend; clip_backend = backend;
bool use_t5xxl = false; bool use_t5xxl = false;
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { if (sd_version_is_dit(version)) {
use_t5xxl = true; use_t5xxl = true;
} }
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
@ -323,18 +323,18 @@ public:
if (diffusion_flash_attn) { if (diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model"); LOG_INFO("Using flash attention in the diffusion model");
} }
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { if (sd_version_is_sd3(version)) {
if (diffusion_flash_attn) { if (diffusion_flash_attn) {
LOG_WARN("flash attention in this diffusion model is currently unsupported!"); LOG_WARN("flash attention in this diffusion model is currently unsupported!");
} }
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype); cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version); diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { } else if (sd_version_is_flux(version)) {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype); cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn); diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn);
} else { } else {
if (id_embeddings_path.find("v2") != std::string::npos) { if (id_embeddings_path.find("v2") != std::string::npos) {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version, VERSION_2); cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version, PM_VERSION_2);
} else { } else {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version); cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version);
} }
@ -373,7 +373,7 @@ public:
} }
if (id_embeddings_path.find("v2") != std::string::npos) { if (id_embeddings_path.find("v2") != std::string::npos) {
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version, VERSION_2); pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version, PM_VERSION_2);
LOG_INFO("using PhotoMaker Version 2"); LOG_INFO("using PhotoMaker Version 2");
} else { } else {
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version); pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version);
@ -527,10 +527,10 @@ public:
is_using_v_parameterization = true; is_using_v_parameterization = true;
} }
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { if (sd_version_is_sd3(version)) {
LOG_INFO("running in FLOW mode"); LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>(); denoiser = std::make_shared<DiscreteFlowDenoiser>();
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { } else if (sd_version_is_flux(version)) {
LOG_INFO("running in Flux FLOW mode"); LOG_INFO("running in Flux FLOW mode");
float shift = 1.15f; float shift = 1.15f;
if (version == VERSION_FLUX_SCHNELL) { if (version == VERSION_FLUX_SCHNELL) {
@ -804,7 +804,7 @@ public:
out_uncond = ggml_dup_tensor(work_ctx, x); out_uncond = ggml_dup_tensor(work_ctx, x);
} }
if (has_skiplayer) { if (has_skiplayer) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { if (sd_version_is_dit(version)) {
out_skip = ggml_dup_tensor(work_ctx, x); out_skip = ggml_dup_tensor(work_ctx, x);
} else { } else {
has_skiplayer = false; has_skiplayer = false;
@ -995,9 +995,9 @@ public:
if (use_tiny_autoencoder) { if (use_tiny_autoencoder) {
C = 4; C = 4;
} else { } else {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { if (sd_version_is_sd3(version)) {
C = 32; C = 32;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { } else if (sd_version_is_flux(version)) {
C = 32; C = 32;
} }
} }
@ -1214,7 +1214,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
} }
// preprocess input id images // preprocess input id images
std::vector<sd_image_t*> input_id_images; std::vector<sd_image_t*> input_id_images;
bool pmv2 = sd_ctx->sd->pmid_model->get_version() == VERSION_2; bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2;
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) { if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path); std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
for (std::string img_file : img_files) { for (std::string img_file : img_files) {
@ -1343,9 +1343,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
// Sample // Sample
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
int C = 4; int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { if (sd_version_is_sd3(sd_ctx->sd->version)) {
C = 16; C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { } else if (sd_version_is_flux(sd_ctx->sd->version)) {
C = 16; C = 16;
} }
int W = width / 8; int W = width / 8;
@ -1464,10 +1464,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { if (sd_version_is_sd3(sd_ctx->sd->version)) {
params.mem_size *= 3; params.mem_size *= 3;
} }
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { if (sd_version_is_flux(sd_ctx->sd->version)) {
params.mem_size *= 4; params.mem_size *= 4;
} }
if (sd_ctx->sd->stacked_id) { if (sd_ctx->sd->stacked_id) {
@ -1490,17 +1490,17 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
int C = 4; int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { if (sd_version_is_sd3(sd_ctx->sd->version)) {
C = 16; C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { } else if (sd_version_is_flux(sd_ctx->sd->version)) {
C = 16; C = 16;
} }
int W = width / 8; int W = width / 8;
int H = height / 8; int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { if (sd_version_is_sd3(sd_ctx->sd->version)) {
ggml_set_f32(init_latent, 0.0609f); ggml_set_f32(init_latent, 0.0609f);
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { } else if (sd_version_is_flux(sd_ctx->sd->version)) {
ggml_set_f32(init_latent, 0.1159f); ggml_set_f32(init_latent, 0.1159f);
} else { } else {
ggml_set_f32(init_latent, 0.f); ggml_set_f32(init_latent, 0.f);
@ -1567,10 +1567,10 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { if (sd_version_is_sd3(sd_ctx->sd->version)) {
params.mem_size *= 2; params.mem_size *= 2;
} }
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { if (sd_version_is_flux(sd_ctx->sd->version)) {
params.mem_size *= 3; params.mem_size *= 3;
} }
if (sd_ctx->sd->stacked_id) { if (sd_ctx->sd->stacked_id) {

View File

@ -459,7 +459,7 @@ public:
bool use_video_decoder = false, bool use_video_decoder = false,
SDVersion version = VERSION_SD1) SDVersion version = VERSION_SD1)
: decode_only(decode_only), use_video_decoder(use_video_decoder) { : decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { if (sd_version_is_dit(version)) {
dd_config.z_channels = 16; dd_config.z_channels = 16;
use_quant = false; use_quant = false;
} }