feat: add flux 1 lite 8B (freepik) support (#474)

* Flux Lite (Freepik) support

* format code

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
stduhpf 2024-11-23 04:41:30 +01:00 committed by GitHub
parent 9b1d90bc23
commit 6ea812256e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 38 additions and 24 deletions

View File

@ -712,7 +712,7 @@ public:
auto text_projection = params["text_projection"];
ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
if (text_projection != NULL) {
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
} else {
LOG_DEBUG("Missing text_projection matrix, assuming identity...");
}

View File

@ -798,7 +798,7 @@ struct SD3CLIPEmbedder : public Conditioner {
}
if (chunk_idx == 0) {
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
clip_l->compute(n_threads,
input_ids,
@ -808,7 +808,6 @@ struct SD3CLIPEmbedder : public Conditioner {
true,
&pooled_l,
work_ctx);
}
}
@ -848,7 +847,7 @@ struct SD3CLIPEmbedder : public Conditioner {
}
if (chunk_idx == 0) {
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
clip_g->compute(n_threads,
input_ids,
@ -858,7 +857,6 @@ struct SD3CLIPEmbedder : public Conditioner {
true,
&pooled_g,
work_ctx);
}
}
@ -1096,9 +1094,9 @@ struct FluxCLIPEmbedder : public Conditioner {
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
size_t max_token_idx = 0;
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
clip_l->compute(n_threads,
input_ids,
0,
@ -1107,7 +1105,6 @@ struct FluxCLIPEmbedder : public Conditioner {
true,
&pooled,
work_ctx);
}
// t5

View File

@ -822,6 +822,9 @@ namespace Flux {
if (version == VERSION_FLUX_SCHNELL) {
flux_params.guidance_embed = false;
}
if (version == VERSION_FLUX_LITE) {
flux_params.depth = 8;
}
flux = Flux(flux_params);
flux.init(params_ctx, wtype);
}

View File

@ -1364,15 +1364,20 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight;
bool is_flux = false;
bool is_sd3 = false;
bool is_flux = false;
bool is_schnell = true;
bool is_lite = true;
bool is_sd3 = false;
for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
return VERSION_FLUX_DEV;
is_schnell = false;
}
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true;
}
if (tensor_storage.name.find("model.diffusion_model.double_blocks.8") != std::string::npos) {
is_lite = false;
}
if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
return VERSION_SD3_5_2B;
}
@ -1403,7 +1408,14 @@ SDVersion ModelLoader::get_sd_version() {
}
}
if (is_flux) {
return VERSION_FLUX_SCHNELL;
if (is_schnell) {
GGML_ASSERT(!is_lite);
return VERSION_FLUX_SCHNELL;
} else if (is_lite) {
return VERSION_FLUX_LITE;
} else {
return VERSION_FLUX_DEV;
}
}
if (is_sd3) {
return VERSION_SD3_2B;

View File

@ -27,6 +27,7 @@ enum SDVersion {
VERSION_FLUX_SCHNELL,
VERSION_SD3_5_8B,
VERSION_SD3_5_2B,
VERSION_FLUX_LITE,
VERSION_COUNT,
};

View File

@ -33,7 +33,8 @@ const char* model_version_to_str[] = {
"Flux Dev",
"Flux Schnell",
"SD3.5 8B",
"SD3.5 2B"};
"SD3.5 2B",
"Flux Lite 8B"};
const char* sampling_methods_str[] = {
"Euler A",
@ -291,7 +292,7 @@ public:
}
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
scale_factor = 1.5305f;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
scale_factor = 0.3611;
// TODO: shift_factor
}
@ -312,7 +313,7 @@ public:
} else {
clip_backend = backend;
bool use_t5xxl = false;
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
use_t5xxl = true;
}
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
@ -326,7 +327,7 @@ public:
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version);
} else {
@ -524,7 +525,7 @@ public:
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
LOG_INFO("running in Flux FLOW mode");
float shift = 1.15f;
if (version == VERSION_FLUX_SCHNELL) {
@ -991,7 +992,7 @@ public:
} else {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
C = 32;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
C = 32;
}
}
@ -1328,7 +1329,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
C = 16;
}
int W = width / 8;
@ -1450,7 +1451,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
params.mem_size *= 3;
}
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
params.mem_size *= 4;
}
if (sd_ctx->sd->stacked_id) {
@ -1475,7 +1476,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
C = 16;
}
int W = width / 8;
@ -1483,7 +1484,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
ggml_set_f32(init_latent, 0.0609f);
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
ggml_set_f32(init_latent, 0.1159f);
} else {
ggml_set_f32(init_latent, 0.f);
@ -1553,7 +1554,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
params.mem_size *= 2;
}
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
params.mem_size *= 3;
}
if (sd_ctx->sd->stacked_id) {

View File

@ -457,7 +457,7 @@ public:
bool use_video_decoder = false,
SDVersion version = VERSION_SD1)
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
dd_config.z_channels = 16;
use_quant = false;
}