feat: add flux 1 lite 8B (freepik) support (#474)
* Flux Lite (Freepik) support * format code --------- Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
parent
9b1d90bc23
commit
6ea812256e
2
clip.hpp
2
clip.hpp
@ -712,7 +712,7 @@ public:
|
||||
auto text_projection = params["text_projection"];
|
||||
ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
|
||||
if (text_projection != NULL) {
|
||||
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
|
||||
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
|
||||
} else {
|
||||
LOG_DEBUG("Missing text_projection matrix, assuming identity...");
|
||||
}
|
||||
|
@ -798,7 +798,7 @@ struct SD3CLIPEmbedder : public Conditioner {
|
||||
}
|
||||
|
||||
if (chunk_idx == 0) {
|
||||
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
|
||||
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
|
||||
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
|
||||
clip_l->compute(n_threads,
|
||||
input_ids,
|
||||
@ -808,7 +808,6 @@ struct SD3CLIPEmbedder : public Conditioner {
|
||||
true,
|
||||
&pooled_l,
|
||||
work_ctx);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -848,7 +847,7 @@ struct SD3CLIPEmbedder : public Conditioner {
|
||||
}
|
||||
|
||||
if (chunk_idx == 0) {
|
||||
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
|
||||
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
|
||||
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
|
||||
clip_g->compute(n_threads,
|
||||
input_ids,
|
||||
@ -858,7 +857,6 @@ struct SD3CLIPEmbedder : public Conditioner {
|
||||
true,
|
||||
&pooled_g,
|
||||
work_ctx);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -1096,9 +1094,9 @@ struct FluxCLIPEmbedder : public Conditioner {
|
||||
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
|
||||
size_t max_token_idx = 0;
|
||||
|
||||
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
|
||||
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
|
||||
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
|
||||
|
||||
|
||||
clip_l->compute(n_threads,
|
||||
input_ids,
|
||||
0,
|
||||
@ -1107,7 +1105,6 @@ struct FluxCLIPEmbedder : public Conditioner {
|
||||
true,
|
||||
&pooled,
|
||||
work_ctx);
|
||||
|
||||
}
|
||||
|
||||
// t5
|
||||
|
3
flux.hpp
3
flux.hpp
@ -822,6 +822,9 @@ namespace Flux {
|
||||
if (version == VERSION_FLUX_SCHNELL) {
|
||||
flux_params.guidance_embed = false;
|
||||
}
|
||||
if (version == VERSION_FLUX_LITE) {
|
||||
flux_params.depth = 8;
|
||||
}
|
||||
flux = Flux(flux_params);
|
||||
flux.init(params_ctx, wtype);
|
||||
}
|
||||
|
20
model.cpp
20
model.cpp
@ -1364,15 +1364,20 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
|
||||
|
||||
SDVersion ModelLoader::get_sd_version() {
|
||||
TensorStorage token_embedding_weight;
|
||||
bool is_flux = false;
|
||||
bool is_sd3 = false;
|
||||
bool is_flux = false;
|
||||
bool is_schnell = true;
|
||||
bool is_lite = true;
|
||||
bool is_sd3 = false;
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
|
||||
return VERSION_FLUX_DEV;
|
||||
is_schnell = false;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
|
||||
is_flux = true;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.double_blocks.8") != std::string::npos) {
|
||||
is_lite = false;
|
||||
}
|
||||
if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
|
||||
return VERSION_SD3_5_2B;
|
||||
}
|
||||
@ -1403,7 +1408,14 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
}
|
||||
}
|
||||
if (is_flux) {
|
||||
return VERSION_FLUX_SCHNELL;
|
||||
if (is_schnell) {
|
||||
GGML_ASSERT(!is_lite);
|
||||
return VERSION_FLUX_SCHNELL;
|
||||
} else if (is_lite) {
|
||||
return VERSION_FLUX_LITE;
|
||||
} else {
|
||||
return VERSION_FLUX_DEV;
|
||||
}
|
||||
}
|
||||
if (is_sd3) {
|
||||
return VERSION_SD3_2B;
|
||||
|
1
model.h
1
model.h
@ -27,6 +27,7 @@ enum SDVersion {
|
||||
VERSION_FLUX_SCHNELL,
|
||||
VERSION_SD3_5_8B,
|
||||
VERSION_SD3_5_2B,
|
||||
VERSION_FLUX_LITE,
|
||||
VERSION_COUNT,
|
||||
};
|
||||
|
||||
|
@ -33,7 +33,8 @@ const char* model_version_to_str[] = {
|
||||
"Flux Dev",
|
||||
"Flux Schnell",
|
||||
"SD3.5 8B",
|
||||
"SD3.5 2B"};
|
||||
"SD3.5 2B",
|
||||
"Flux Lite 8B"};
|
||||
|
||||
const char* sampling_methods_str[] = {
|
||||
"Euler A",
|
||||
@ -291,7 +292,7 @@ public:
|
||||
}
|
||||
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
|
||||
scale_factor = 1.5305f;
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
||||
scale_factor = 0.3611;
|
||||
// TODO: shift_factor
|
||||
}
|
||||
@ -312,7 +313,7 @@ public:
|
||||
} else {
|
||||
clip_backend = backend;
|
||||
bool use_t5xxl = false;
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
||||
use_t5xxl = true;
|
||||
}
|
||||
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
|
||||
@ -326,7 +327,7 @@ public:
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
|
||||
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
|
||||
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
||||
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
|
||||
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version);
|
||||
} else {
|
||||
@ -524,7 +525,7 @@ public:
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>();
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
||||
LOG_INFO("running in Flux FLOW mode");
|
||||
float shift = 1.15f;
|
||||
if (version == VERSION_FLUX_SCHNELL) {
|
||||
@ -991,7 +992,7 @@ public:
|
||||
} else {
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
|
||||
C = 32;
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
||||
C = 32;
|
||||
}
|
||||
}
|
||||
@ -1328,7 +1329,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
|
||||
int C = 4;
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
|
||||
C = 16;
|
||||
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
|
||||
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
|
||||
C = 16;
|
||||
}
|
||||
int W = width / 8;
|
||||
@ -1450,7 +1451,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
|
||||
params.mem_size *= 3;
|
||||
}
|
||||
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
|
||||
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
|
||||
params.mem_size *= 4;
|
||||
}
|
||||
if (sd_ctx->sd->stacked_id) {
|
||||
@ -1475,7 +1476,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
|
||||
int C = 4;
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
|
||||
C = 16;
|
||||
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
|
||||
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
|
||||
C = 16;
|
||||
}
|
||||
int W = width / 8;
|
||||
@ -1483,7 +1484,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
|
||||
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
|
||||
ggml_set_f32(init_latent, 0.0609f);
|
||||
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
|
||||
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
|
||||
ggml_set_f32(init_latent, 0.1159f);
|
||||
} else {
|
||||
ggml_set_f32(init_latent, 0.f);
|
||||
@ -1553,7 +1554,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
|
||||
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
|
||||
params.mem_size *= 2;
|
||||
}
|
||||
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
|
||||
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
|
||||
params.mem_size *= 3;
|
||||
}
|
||||
if (sd_ctx->sd->stacked_id) {
|
||||
|
2
vae.hpp
2
vae.hpp
@ -457,7 +457,7 @@ public:
|
||||
bool use_video_decoder = false,
|
||||
SDVersion version = VERSION_SD1)
|
||||
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
|
||||
dd_config.z_channels = 16;
|
||||
use_quant = false;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user