refactor: add some sd vesion helper functions
This commit is contained in:
parent
1c168d98a5
commit
b5f4932696
@ -44,7 +44,7 @@ struct Conditioner {
|
|||||||
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
|
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
|
||||||
struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||||
SDVersion version = VERSION_SD1;
|
SDVersion version = VERSION_SD1;
|
||||||
PMVersion pm_version = VERSION_1;
|
PMVersion pm_version = PM_VERSION_1;
|
||||||
CLIPTokenizer tokenizer;
|
CLIPTokenizer tokenizer;
|
||||||
ggml_type wtype;
|
ggml_type wtype;
|
||||||
std::shared_ptr<CLIPTextModelRunner> text_model;
|
std::shared_ptr<CLIPTextModelRunner> text_model;
|
||||||
@ -60,7 +60,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
ggml_type wtype,
|
ggml_type wtype,
|
||||||
const std::string& embd_dir,
|
const std::string& embd_dir,
|
||||||
SDVersion version = VERSION_SD1,
|
SDVersion version = VERSION_SD1,
|
||||||
PMVersion pv = VERSION_1,
|
PMVersion pv = PM_VERSION_1,
|
||||||
int clip_skip = -1)
|
int clip_skip = -1)
|
||||||
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
|
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
|
||||||
if (clip_skip <= 0) {
|
if (clip_skip <= 0) {
|
||||||
@ -270,7 +270,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
std::vector<int> clean_input_ids_tmp;
|
std::vector<int> clean_input_ids_tmp;
|
||||||
for (uint32_t i = 0; i < class_token_index[0]; i++)
|
for (uint32_t i = 0; i < class_token_index[0]; i++)
|
||||||
clean_input_ids_tmp.push_back(clean_input_ids[i]);
|
clean_input_ids_tmp.push_back(clean_input_ids[i]);
|
||||||
for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++)
|
for (uint32_t i = 0; i < (pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++)
|
||||||
clean_input_ids_tmp.push_back(class_token);
|
clean_input_ids_tmp.push_back(class_token);
|
||||||
for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
|
for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
|
||||||
clean_input_ids_tmp.push_back(clean_input_ids[i]);
|
clean_input_ids_tmp.push_back(clean_input_ids[i]);
|
||||||
@ -286,7 +286,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
// weights.insert(weights.begin(), 1.0);
|
// weights.insert(weights.begin(), 1.0);
|
||||||
|
|
||||||
tokenizer.pad_tokens(tokens, weights, max_length, padding);
|
tokenizer.pad_tokens(tokens, weights, max_length, padding);
|
||||||
int offset = pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
|
int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
|
||||||
for (uint32_t i = 0; i < tokens.size(); i++) {
|
for (uint32_t i = 0; i < tokens.size(); i++) {
|
||||||
// if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs
|
// if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs
|
||||||
if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs
|
if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs
|
||||||
|
25
model.h
25
model.h
@ -31,9 +31,30 @@ enum SDVersion {
|
|||||||
VERSION_COUNT,
|
VERSION_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static inline bool sd_version_is_flux(SDVersion version) {
|
||||||
|
if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline bool sd_version_is_sd3(SDVersion version) {
|
||||||
|
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline bool sd_version_is_dit(SDVersion version) {
|
||||||
|
if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
enum PMVersion {
|
enum PMVersion {
|
||||||
VERSION_1,
|
PM_VERSION_1,
|
||||||
VERSION_2,
|
PM_VERSION_2,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TensorStorage {
|
struct TensorStorage {
|
||||||
|
16
pmid.hpp
16
pmid.hpp
@ -608,7 +608,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
|
|||||||
struct PhotoMakerIDEncoder : public GGMLRunner {
|
struct PhotoMakerIDEncoder : public GGMLRunner {
|
||||||
public:
|
public:
|
||||||
SDVersion version = VERSION_SDXL;
|
SDVersion version = VERSION_SDXL;
|
||||||
PMVersion pm_version = VERSION_1;
|
PMVersion pm_version = PM_VERSION_1;
|
||||||
PhotoMakerIDEncoderBlock id_encoder;
|
PhotoMakerIDEncoderBlock id_encoder;
|
||||||
PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2;
|
PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2;
|
||||||
float style_strength;
|
float style_strength;
|
||||||
@ -623,14 +623,14 @@ public:
|
|||||||
std::vector<float> zeros_right;
|
std::vector<float> zeros_right;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = VERSION_1, float sty = 20.f)
|
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = PM_VERSION_1, float sty = 20.f)
|
||||||
: GGMLRunner(backend, wtype),
|
: GGMLRunner(backend, wtype),
|
||||||
version(version),
|
version(version),
|
||||||
pm_version(pm_v),
|
pm_version(pm_v),
|
||||||
style_strength(sty) {
|
style_strength(sty) {
|
||||||
if (pm_version == VERSION_1) {
|
if (pm_version == PM_VERSION_1) {
|
||||||
id_encoder.init(params_ctx, wtype);
|
id_encoder.init(params_ctx, wtype);
|
||||||
} else if (pm_version == VERSION_2) {
|
} else if (pm_version == PM_VERSION_2) {
|
||||||
id_encoder2.init(params_ctx, wtype);
|
id_encoder2.init(params_ctx, wtype);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -644,9 +644,9 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||||
if (pm_version == VERSION_1)
|
if (pm_version == PM_VERSION_1)
|
||||||
id_encoder.get_param_tensors(tensors, prefix);
|
id_encoder.get_param_tensors(tensors, prefix);
|
||||||
else if (pm_version == VERSION_2)
|
else if (pm_version == PM_VERSION_2)
|
||||||
id_encoder2.get_param_tensors(tensors, prefix);
|
id_encoder2.get_param_tensors(tensors, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -734,14 +734,14 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
struct ggml_tensor* updated_prompt_embeds = NULL;
|
struct ggml_tensor* updated_prompt_embeds = NULL;
|
||||||
if (pm_version == VERSION_1)
|
if (pm_version == PM_VERSION_1)
|
||||||
updated_prompt_embeds = id_encoder.forward(ctx0,
|
updated_prompt_embeds = id_encoder.forward(ctx0,
|
||||||
id_pixel_values_d,
|
id_pixel_values_d,
|
||||||
prompt_embeds_d,
|
prompt_embeds_d,
|
||||||
class_tokens_mask_d,
|
class_tokens_mask_d,
|
||||||
class_tokens_mask_pos,
|
class_tokens_mask_pos,
|
||||||
left, right);
|
left, right);
|
||||||
else if (pm_version == VERSION_2)
|
else if (pm_version == PM_VERSION_2)
|
||||||
updated_prompt_embeds = id_encoder2.forward(ctx0,
|
updated_prompt_embeds = id_encoder2.forward(ctx0,
|
||||||
id_pixel_values_d,
|
id_pixel_values_d,
|
||||||
prompt_embeds_d,
|
prompt_embeds_d,
|
||||||
|
@ -286,9 +286,9 @@ public:
|
|||||||
"try specifying SDXL VAE FP16 Fix with the --vae parameter. "
|
"try specifying SDXL VAE FP16 Fix with the --vae parameter. "
|
||||||
"You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
|
"You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
|
||||||
}
|
}
|
||||||
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
|
} else if (sd_version_is_sd3(version)) {
|
||||||
scale_factor = 1.5305f;
|
scale_factor = 1.5305f;
|
||||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
} else if (sd_version_is_flux(version)) {
|
||||||
scale_factor = 0.3611;
|
scale_factor = 0.3611;
|
||||||
// TODO: shift_factor
|
// TODO: shift_factor
|
||||||
}
|
}
|
||||||
@ -309,7 +309,7 @@ public:
|
|||||||
} else {
|
} else {
|
||||||
clip_backend = backend;
|
clip_backend = backend;
|
||||||
bool use_t5xxl = false;
|
bool use_t5xxl = false;
|
||||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
if (sd_version_is_dit(version)) {
|
||||||
use_t5xxl = true;
|
use_t5xxl = true;
|
||||||
}
|
}
|
||||||
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
|
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
|
||||||
@ -323,18 +323,18 @@ public:
|
|||||||
if (diffusion_flash_attn) {
|
if (diffusion_flash_attn) {
|
||||||
LOG_INFO("Using flash attention in the diffusion model");
|
LOG_INFO("Using flash attention in the diffusion model");
|
||||||
}
|
}
|
||||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
|
if (sd_version_is_sd3(version)) {
|
||||||
if (diffusion_flash_attn) {
|
if (diffusion_flash_attn) {
|
||||||
LOG_WARN("flash attention in this diffusion model is currently unsupported!");
|
LOG_WARN("flash attention in this diffusion model is currently unsupported!");
|
||||||
}
|
}
|
||||||
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
|
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
|
||||||
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
|
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
|
||||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
} else if (sd_version_is_flux(version)) {
|
||||||
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
|
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
|
||||||
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn);
|
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn);
|
||||||
} else {
|
} else {
|
||||||
if (id_embeddings_path.find("v2") != std::string::npos) {
|
if (id_embeddings_path.find("v2") != std::string::npos) {
|
||||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version, VERSION_2);
|
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version, PM_VERSION_2);
|
||||||
} else {
|
} else {
|
||||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version);
|
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version);
|
||||||
}
|
}
|
||||||
@ -373,7 +373,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (id_embeddings_path.find("v2") != std::string::npos) {
|
if (id_embeddings_path.find("v2") != std::string::npos) {
|
||||||
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version, VERSION_2);
|
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version, PM_VERSION_2);
|
||||||
LOG_INFO("using PhotoMaker Version 2");
|
LOG_INFO("using PhotoMaker Version 2");
|
||||||
} else {
|
} else {
|
||||||
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version);
|
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version);
|
||||||
@ -527,10 +527,10 @@ public:
|
|||||||
is_using_v_parameterization = true;
|
is_using_v_parameterization = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
|
if (sd_version_is_sd3(version)) {
|
||||||
LOG_INFO("running in FLOW mode");
|
LOG_INFO("running in FLOW mode");
|
||||||
denoiser = std::make_shared<DiscreteFlowDenoiser>();
|
denoiser = std::make_shared<DiscreteFlowDenoiser>();
|
||||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
} else if (sd_version_is_flux(version)) {
|
||||||
LOG_INFO("running in Flux FLOW mode");
|
LOG_INFO("running in Flux FLOW mode");
|
||||||
float shift = 1.15f;
|
float shift = 1.15f;
|
||||||
if (version == VERSION_FLUX_SCHNELL) {
|
if (version == VERSION_FLUX_SCHNELL) {
|
||||||
@ -804,7 +804,7 @@ public:
|
|||||||
out_uncond = ggml_dup_tensor(work_ctx, x);
|
out_uncond = ggml_dup_tensor(work_ctx, x);
|
||||||
}
|
}
|
||||||
if (has_skiplayer) {
|
if (has_skiplayer) {
|
||||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
if (sd_version_is_dit(version)) {
|
||||||
out_skip = ggml_dup_tensor(work_ctx, x);
|
out_skip = ggml_dup_tensor(work_ctx, x);
|
||||||
} else {
|
} else {
|
||||||
has_skiplayer = false;
|
has_skiplayer = false;
|
||||||
@ -995,9 +995,9 @@ public:
|
|||||||
if (use_tiny_autoencoder) {
|
if (use_tiny_autoencoder) {
|
||||||
C = 4;
|
C = 4;
|
||||||
} else {
|
} else {
|
||||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
|
if (sd_version_is_sd3(version)) {
|
||||||
C = 32;
|
C = 32;
|
||||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
} else if (sd_version_is_flux(version)) {
|
||||||
C = 32;
|
C = 32;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1214,7 +1214,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
|
|||||||
}
|
}
|
||||||
// preprocess input id images
|
// preprocess input id images
|
||||||
std::vector<sd_image_t*> input_id_images;
|
std::vector<sd_image_t*> input_id_images;
|
||||||
bool pmv2 = sd_ctx->sd->pmid_model->get_version() == VERSION_2;
|
bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2;
|
||||||
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
|
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
|
||||||
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
|
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
|
||||||
for (std::string img_file : img_files) {
|
for (std::string img_file : img_files) {
|
||||||
@ -1343,9 +1343,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
|
|||||||
// Sample
|
// Sample
|
||||||
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
|
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
|
||||||
int C = 4;
|
int C = 4;
|
||||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
|
if (sd_version_is_sd3(sd_ctx->sd->version)) {
|
||||||
C = 16;
|
C = 16;
|
||||||
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
|
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
||||||
C = 16;
|
C = 16;
|
||||||
}
|
}
|
||||||
int W = width / 8;
|
int W = width / 8;
|
||||||
@ -1464,10 +1464,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
|
|||||||
|
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
|
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
|
||||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
|
if (sd_version_is_sd3(sd_ctx->sd->version)) {
|
||||||
params.mem_size *= 3;
|
params.mem_size *= 3;
|
||||||
}
|
}
|
||||||
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
|
if (sd_version_is_flux(sd_ctx->sd->version)) {
|
||||||
params.mem_size *= 4;
|
params.mem_size *= 4;
|
||||||
}
|
}
|
||||||
if (sd_ctx->sd->stacked_id) {
|
if (sd_ctx->sd->stacked_id) {
|
||||||
@ -1490,17 +1490,17 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
|
|||||||
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
|
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
|
||||||
|
|
||||||
int C = 4;
|
int C = 4;
|
||||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
|
if (sd_version_is_sd3(sd_ctx->sd->version)) {
|
||||||
C = 16;
|
C = 16;
|
||||||
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
|
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
||||||
C = 16;
|
C = 16;
|
||||||
}
|
}
|
||||||
int W = width / 8;
|
int W = width / 8;
|
||||||
int H = height / 8;
|
int H = height / 8;
|
||||||
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
|
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
|
||||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
|
if (sd_version_is_sd3(sd_ctx->sd->version)) {
|
||||||
ggml_set_f32(init_latent, 0.0609f);
|
ggml_set_f32(init_latent, 0.0609f);
|
||||||
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
|
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
||||||
ggml_set_f32(init_latent, 0.1159f);
|
ggml_set_f32(init_latent, 0.1159f);
|
||||||
} else {
|
} else {
|
||||||
ggml_set_f32(init_latent, 0.f);
|
ggml_set_f32(init_latent, 0.f);
|
||||||
@ -1567,10 +1567,10 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
|
|||||||
|
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
|
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
|
||||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
|
if (sd_version_is_sd3(sd_ctx->sd->version)) {
|
||||||
params.mem_size *= 2;
|
params.mem_size *= 2;
|
||||||
}
|
}
|
||||||
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
|
if (sd_version_is_flux(sd_ctx->sd->version)) {
|
||||||
params.mem_size *= 3;
|
params.mem_size *= 3;
|
||||||
}
|
}
|
||||||
if (sd_ctx->sd->stacked_id) {
|
if (sd_ctx->sd->stacked_id) {
|
||||||
|
2
vae.hpp
2
vae.hpp
@ -459,7 +459,7 @@ public:
|
|||||||
bool use_video_decoder = false,
|
bool use_video_decoder = false,
|
||||||
SDVersion version = VERSION_SD1)
|
SDVersion version = VERSION_SD1)
|
||||||
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
|
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
|
||||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
if (sd_version_is_dit(version)) {
|
||||||
dd_config.z_channels = 16;
|
dd_config.z_channels = 16;
|
||||||
use_quant = false;
|
use_quant = false;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user