diff --git a/conditioner.hpp b/conditioner.hpp index 47fd3eb..9a63009 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -44,7 +44,7 @@ struct Conditioner { // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283 struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { SDVersion version = VERSION_SD1; - PMVersion pm_version = VERSION_1; + PMVersion pm_version = PM_VERSION_1; CLIPTokenizer tokenizer; ggml_type wtype; std::shared_ptr text_model; @@ -60,7 +60,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { ggml_type wtype, const std::string& embd_dir, SDVersion version = VERSION_SD1, - PMVersion pv = VERSION_1, + PMVersion pv = PM_VERSION_1, int clip_skip = -1) : version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) { if (clip_skip <= 0) { @@ -270,7 +270,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { std::vector clean_input_ids_tmp; for (uint32_t i = 0; i < class_token_index[0]; i++) clean_input_ids_tmp.push_back(clean_input_ids[i]); - for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++) + for (uint32_t i = 0; i < (pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++) clean_input_ids_tmp.push_back(class_token); for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++) clean_input_ids_tmp.push_back(clean_input_ids[i]); @@ -286,7 +286,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { // weights.insert(weights.begin(), 1.0); tokenizer.pad_tokens(tokens, weights, max_length, padding); - int offset = pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs; + int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs; for (uint32_t i = 0; i < tokens.size(); i++) { // if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs diff --git a/model.h b/model.h index 552a2cc..b7e3b3a 100644 --- a/model.h +++ b/model.h @@ -31,9 +31,30 @@ enum SDVersion { VERSION_COUNT, }; +static inline bool sd_version_is_flux(SDVersion version) { + if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { + return true; + } + return false; +} + +static inline bool sd_version_is_sd3(SDVersion version) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { + return true; + } + return false; +} + +static inline bool sd_version_is_dit(SDVersion version) { + if (sd_version_is_flux(version) || sd_version_is_sd3(version)) { + return true; + } + return false; +} + enum PMVersion { - VERSION_1, - VERSION_2, + PM_VERSION_1, + PM_VERSION_2, }; struct TensorStorage { diff --git a/pmid.hpp b/pmid.hpp index b8555eb..defb4f0 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -608,7 +608,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo struct PhotoMakerIDEncoder : public GGMLRunner { public: SDVersion version = VERSION_SDXL; - PMVersion pm_version = VERSION_1; + PMVersion pm_version = PM_VERSION_1; PhotoMakerIDEncoderBlock id_encoder; PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2; float style_strength; @@ -623,14 +623,14 @@ public: std::vector zeros_right; public: - PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = VERSION_1, float sty = 20.f) + PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = PM_VERSION_1, float sty = 20.f) : GGMLRunner(backend, wtype), version(version), pm_version(pm_v), style_strength(sty) { - if (pm_version == VERSION_1) { + if (pm_version == PM_VERSION_1) { id_encoder.init(params_ctx, wtype); - } else if (pm_version == VERSION_2) { + } else if (pm_version == PM_VERSION_2) { id_encoder2.init(params_ctx, wtype); } } @@ -644,9 +644,9 @@ public: } void get_param_tensors(std::map& tensors, const std::string prefix) { - if (pm_version == VERSION_1) + if (pm_version == PM_VERSION_1) id_encoder.get_param_tensors(tensors, prefix); - else if (pm_version == VERSION_2) + else if (pm_version == PM_VERSION_2) id_encoder2.get_param_tensors(tensors, prefix); } @@ -734,14 +734,14 @@ public: } } struct ggml_tensor* updated_prompt_embeds = NULL; - if (pm_version == VERSION_1) + if (pm_version == PM_VERSION_1) updated_prompt_embeds = id_encoder.forward(ctx0, id_pixel_values_d, prompt_embeds_d, class_tokens_mask_d, class_tokens_mask_pos, left, right); - else if (pm_version == VERSION_2) + else if (pm_version == PM_VERSION_2) updated_prompt_embeds = id_encoder2.forward(ctx0, id_pixel_values_d, prompt_embeds_d, diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index c722b65..5024b5f 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -286,9 +286,9 @@ public: "try specifying SDXL VAE FP16 Fix with the --vae parameter. " "You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors"); } - } else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { + } else if (sd_version_is_sd3(version)) { scale_factor = 1.5305f; - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { + } else if (sd_version_is_flux(version)) { scale_factor = 0.3611; // TODO: shift_factor } @@ -309,7 +309,7 @@ public: } else { clip_backend = backend; bool use_t5xxl = false; - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { + if (sd_version_is_dit(version)) { use_t5xxl = true; } if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { @@ -323,18 +323,18 @@ public: if (diffusion_flash_attn) { LOG_INFO("Using flash attention in the diffusion model"); } - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { + if (sd_version_is_sd3(version)) { if (diffusion_flash_attn) { LOG_WARN("flash attention in this diffusion model is currently unsupported!"); } cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { + } else if (sd_version_is_flux(version)) { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); diffusion_model = std::make_shared(backend, diffusion_model_wtype, version, diffusion_flash_attn); } else { if (id_embeddings_path.find("v2") != std::string::npos) { - cond_stage_model = std::make_shared(clip_backend, conditioner_wtype, embeddings_path, version, VERSION_2); + cond_stage_model = std::make_shared(clip_backend, conditioner_wtype, embeddings_path, version, PM_VERSION_2); } else { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype, embeddings_path, version); } @@ -373,7 +373,7 @@ public: } if (id_embeddings_path.find("v2") != std::string::npos) { - pmid_model = std::make_shared(backend, model_wtype, version, VERSION_2); + pmid_model = std::make_shared(backend, model_wtype, version, PM_VERSION_2); LOG_INFO("using PhotoMaker Version 2"); } else { pmid_model = std::make_shared(backend, model_wtype, version); @@ -527,10 +527,10 @@ public: is_using_v_parameterization = true; } - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { + if (sd_version_is_sd3(version)) { LOG_INFO("running in FLOW mode"); denoiser = std::make_shared(); - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { + } else if (sd_version_is_flux(version)) { LOG_INFO("running in Flux FLOW mode"); float shift = 1.15f; if (version == VERSION_FLUX_SCHNELL) { @@ -804,7 +804,7 @@ public: out_uncond = ggml_dup_tensor(work_ctx, x); } if (has_skiplayer) { - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (sd_version_is_dit(version)) { out_skip = ggml_dup_tensor(work_ctx, x); } else { has_skiplayer = false; @@ -995,9 +995,9 @@ public: if (use_tiny_autoencoder) { C = 4; } else { - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { + if (sd_version_is_sd3(version)) { C = 32; - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { + } else if (sd_version_is_flux(version)) { C = 32; } } @@ -1214,7 +1214,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, } // preprocess input id images std::vector input_id_images; - bool pmv2 = sd_ctx->sd->pmid_model->get_version() == VERSION_2; + bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2; if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) { std::vector img_files = get_files_from_dir(input_id_images_path); for (std::string img_file : img_files) { @@ -1343,9 +1343,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, // Sample std::vector final_latents; // collect latents to decode int C = 4; - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { C = 16; - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { + } else if (sd_version_is_flux(sd_ctx->sd->version)) { C = 16; } int W = width / 8; @@ -1464,10 +1464,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { params.mem_size *= 3; } - if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { + if (sd_version_is_flux(sd_ctx->sd->version)) { params.mem_size *= 4; } if (sd_ctx->sd->stacked_id) { @@ -1490,17 +1490,17 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); int C = 4; - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { C = 16; - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { + } else if (sd_version_is_flux(sd_ctx->sd->version)) { C = 16; } int W = width / 8; int H = height / 8; ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { ggml_set_f32(init_latent, 0.0609f); - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { + } else if (sd_version_is_flux(sd_ctx->sd->version)) { ggml_set_f32(init_latent, 0.1159f); } else { ggml_set_f32(init_latent, 0.f); @@ -1567,10 +1567,10 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { params.mem_size *= 2; } - if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { + if (sd_version_is_flux(sd_ctx->sd->version)) { params.mem_size *= 3; } if (sd_ctx->sd->stacked_id) { diff --git a/vae.hpp b/vae.hpp index c32846a..0c7d84f 100644 --- a/vae.hpp +++ b/vae.hpp @@ -459,7 +459,7 @@ public: bool use_video_decoder = false, SDVersion version = VERSION_SD1) : decode_only(decode_only), use_video_decoder(use_video_decoder) { - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { + if (sd_version_is_dit(version)) { dd_config.z_channels = 16; use_quant = false; }