feat: support Inpaint models (#511)

This commit is contained in:
stduhpf 2024-12-28 06:04:49 +01:00 committed by GitHub
parent cc92a6a1b3
commit 8f4ab9add3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 383 additions and 64 deletions

View File

@ -61,18 +61,18 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
SDVersion version = VERSION_SD1,
PMVersion pv = PM_VERSION_1,
int clip_skip = -1)
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir) {
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
if (clip_skip <= 0) {
clip_skip = 1;
if (version == VERSION_SD2 || version == VERSION_SDXL) {
if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) {
clip_skip = 2;
}
}
if (version == VERSION_SD1) {
if (sd_version_is_sd1(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip);
} else if (version == VERSION_SD2) {
} else if (sd_version_is_sd2(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip);
} else if (version == VERSION_SDXL) {
} else if (sd_version_is_sdxl(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
}
@ -80,35 +80,35 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
void set_clip_skip(int clip_skip) {
text_model->set_clip_skip(clip_skip);
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->set_clip_skip(clip_skip);
}
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model");
}
}
void alloc_params_buffer() {
text_model->alloc_params_buffer();
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->alloc_params_buffer();
}
}
void free_params_buffer() {
text_model->free_params_buffer();
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->free_params_buffer();
}
}
size_t get_params_buffer_size() {
size_t buffer_size = text_model->get_params_buffer_size();
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
buffer_size += text_model2->get_params_buffer_size();
}
return buffer_size;
@ -402,7 +402,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
struct ggml_tensor* input_ids2 = NULL;
size_t max_token_idx = 0;
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID);
if (it != chunk_tokens.end()) {
std::fill(std::next(it), chunk_tokens.end(), 0);
@ -427,7 +427,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
false,
&chunk_hidden_states1,
work_ctx);
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->compute(n_threads,
input_ids2,
0,
@ -486,7 +486,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
ggml_tensor* vec = NULL;
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
int out_dim = 256;
vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels);
// [0:1280]

View File

@ -34,11 +34,11 @@ public:
ControlNetBlock(SDVersion version = VERSION_SD1)
: version(version) {
if (version == VERSION_SD2) {
if (sd_version_is_sd2(version)) {
context_dim = 1024;
num_head_channels = 64;
num_heads = -1;
} else if (version == VERSION_SDXL) {
} else if (sd_version_is_sdxl(version)) {
context_dim = 2048;
attention_resolutions = {4, 2};
channel_mult = {1, 2, 4};
@ -58,7 +58,7 @@ public:
// time_embed_1 is nn.SiLU()
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
if (version == VERSION_SDXL || version == VERSION_SVD) {
if (sd_version_is_sdxl(version) || version == VERSION_SVD) {
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
// label_emb_1 is nn.SiLU()
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));

View File

@ -133,8 +133,9 @@ struct FluxModel : public DiffusionModel {
FluxModel(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
bool flash_attn = false)
: flux(backend, tensor_types, "model.diffusion_model", flash_attn) {
SDVersion version = VERSION_FLUX,
bool flash_attn = false)
: flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
}
void alloc_params_buffer() {
@ -174,7 +175,7 @@ struct FluxModel : public DiffusionModel {
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
}
};

View File

@ -85,6 +85,7 @@ struct SDParams {
std::string lora_model_dir;
std::string output_path = "output.png";
std::string input_path;
std::string mask_path;
std::string control_image_path;
std::string prompt;
@ -148,6 +149,7 @@ void print_params(SDParams params) {
printf(" normalize input image : %s\n", params.normalize_input ? "true" : "false");
printf(" output_path: %s\n", params.output_path.c_str());
printf(" init_img: %s\n", params.input_path.c_str());
printf(" mask_img: %s\n", params.mask_path.c_str());
printf(" control_image: %s\n", params.control_image_path.c_str());
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
@ -384,6 +386,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.input_path = argv[i];
} else if (arg == "--mask") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.mask_path = argv[i];
} else if (arg == "--control-image") {
if (++i >= argc) {
invalid_arg = true;
@ -803,6 +811,8 @@ int main(int argc, const char* argv[]) {
bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
uint8_t* control_image_buffer = NULL;
uint8_t* mask_image_buffer = NULL;
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
vae_decode_only = false;
@ -907,6 +917,18 @@ int main(int argc, const char* argv[]) {
}
}
if (params.mask_path != "") {
int c = 0;
mask_image_buffer = stbi_load(params.mask_path.c_str(), &params.width, &params.height, &c, 1);
} else {
std::vector<uint8_t> arr(params.width * params.height, 255);
mask_image_buffer = arr.data();
}
sd_image_t mask_image = {(uint32_t)params.width,
(uint32_t)params.height,
1,
mask_image_buffer};
sd_image_t* results;
if (params.mode == TXT2IMG) {
results = txt2img(sd_ctx,
@ -976,6 +998,7 @@ int main(int argc, const char* argv[]) {
} else {
results = img2img(sd_ctx,
input_image,
mask_image,
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,

View File

@ -490,6 +490,7 @@ namespace Flux {
struct FluxParams {
int64_t in_channels = 64;
int64_t out_channels = 64;
int64_t vec_in_dim = 768;
int64_t context_in_dim = 4096;
int64_t hidden_size = 3072;
@ -642,8 +643,7 @@ namespace Flux {
Flux() {}
Flux(FluxParams params)
: params(params) {
int64_t out_channels = params.in_channels;
int64_t pe_dim = params.hidden_size / params.num_heads;
int64_t pe_dim = params.hidden_size / params.num_heads;
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
@ -669,7 +669,7 @@ namespace Flux {
params.flash_attn));
}
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, out_channels));
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels));
}
struct ggml_tensor* patchify(struct ggml_context* ctx,
@ -789,6 +789,7 @@ namespace Flux {
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor* pe,
@ -797,6 +798,7 @@ namespace Flux {
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// timestep: (N,) tensor of diffusion timesteps
// context: (N, L, D)
// c_concat: NULL, or for (N,C+M, H, W) for Fill
// y: (N, adm_in_channels) tensor of class labels
// guidance: (N,)
// pe: (L, d_head/2, 2, 2)
@ -806,6 +808,7 @@ namespace Flux {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t C = x->ne[2];
int64_t patch_size = 2;
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
@ -814,6 +817,19 @@ namespace Flux {
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
if (c_concat != NULL) {
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
masked = patchify(ctx, masked, patch_size);
mask = patchify(ctx, mask, patch_size);
img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
}
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
@ -834,12 +850,16 @@ namespace Flux {
FluxRunner(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
const std::string prefix = "",
SDVersion version = VERSION_FLUX,
bool flash_attn = false)
: GGMLRunner(backend) {
flux_params.flash_attn = flash_attn;
flux_params.guidance_embed = false;
flux_params.depth = 0;
flux_params.depth_single_blocks = 0;
if (version == VERSION_FLUX_FILL) {
flux_params.in_channels = 384;
}
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
if (tensor_name.find("model.diffusion_model.") == std::string::npos)
@ -886,14 +906,18 @@ namespace Flux {
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
std::vector<int> skip_layers = std::vector<int>()) {
GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
x = to_backend(x);
context = to_backend(context);
x = to_backend(x);
context = to_backend(context);
if (c_concat != NULL) {
c_concat = to_backend(c_concat);
}
y = to_backend(y);
timesteps = to_backend(timesteps);
if (flux_params.guidance_embed) {
@ -913,6 +937,7 @@ namespace Flux {
x,
timesteps,
context,
c_concat,
y,
guidance,
pe,
@ -927,6 +952,7 @@ namespace Flux {
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor** output = NULL,
@ -938,7 +964,7 @@ namespace Flux {
// y: [N, adm_in_channels] or [1, adm_in_channels]
// guidance: [N, ]
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, y, guidance, skip_layers);
return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@ -978,7 +1004,7 @@ namespace Flux {
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
compute(8, x, timesteps, context, y, guidance, &out, work_ctx);
compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);

View File

@ -290,6 +290,42 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data,
}
}
__STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,
struct ggml_tensor* output,
bool scale = true) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
GGML_ASSERT(channels == 1 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
float value = *(image_data + iy * width * channels + ix);
if (scale) {
value /= 255.f;
}
ggml_tensor_set_f32(output, value, ix, iy);
}
}
}
__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
struct ggml_tensor* mask,
struct ggml_tensor* output) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
GGML_ASSERT(output->type == GGML_TYPE_F32);
for (int ix = 0; ix < width; ix++) {
for (int iy = 0; iy < height; iy++) {
float m = ggml_tensor_get_f32(mask, ix, iy);
for (int k = 0; k < channels; k++) {
float value = ((float)(m < 254.5/255)) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
}
}
__STATIC_INLINE__ void sd_mul_images_to_tensor(const uint8_t* image_data,
struct ggml_tensor* output,
int idx,
@ -1144,7 +1180,6 @@ public:
}
#endif
ggml_backend_graph_compute(backend, gf);
#ifdef GGML_PERF
ggml_graph_print(gf);
#endif

View File

@ -1458,24 +1458,49 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
}
SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight;
for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
return VERSION_FLUX;
}
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
return VERSION_SD3;
}
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
return VERSION_SDXL;
}
if (tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
return VERSION_SDXL;
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
return VERSION_SVD;
}
TensorStorage token_embedding_weight, input_block_weight;
bool input_block_checked = false;
bool has_multiple_encoders = false;
bool is_unet = false;
bool is_xl = false;
bool is_flux = false;
#define found_family (is_xl || is_flux)
for (auto& tensor_storage : tensor_storages) {
if (!found_family) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true;
if (input_block_checked) {
break;
}
}
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
return VERSION_SD3;
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
is_unet = true;
if(has_multiple_encoders){
is_xl = true;
if (input_block_checked) {
break;
}
}
}
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
has_multiple_encoders = true;
if(is_unet){
is_xl = true;
if (input_block_checked) {
break;
}
}
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
return VERSION_SVD;
}
}
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
@ -1485,11 +1510,39 @@ SDVersion ModelLoader::get_sd_version() {
token_embedding_weight = tensor_storage;
// break;
}
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight") {
input_block_weight = tensor_storage;
input_block_checked = true;
if (found_family) {
break;
}
}
}
bool is_inpaint = input_block_weight.ne[2] == 9;
if (is_xl) {
if (is_inpaint) {
return VERSION_SDXL_INPAINT;
}
return VERSION_SDXL;
}
if (is_flux) {
is_inpaint = input_block_weight.ne[0] == 384;
if (is_inpaint) {
return VERSION_FLUX_FILL;
}
return VERSION_FLUX;
}
if (token_embedding_weight.ne[0] == 768) {
if (is_inpaint) {
return VERSION_SD1_INPAINT;
}
return VERSION_SD1;
} else if (token_embedding_weight.ne[0] == 1024) {
if (is_inpaint) {
return VERSION_SD2_INPAINT;
}
return VERSION_SD2;
}
return VERSION_COUNT;

34
model.h
View File

@ -19,16 +19,20 @@
enum SDVersion {
VERSION_SD1,
VERSION_SD1_INPAINT,
VERSION_SD2,
VERSION_SD2_INPAINT,
VERSION_SDXL,
VERSION_SDXL_INPAINT,
VERSION_SVD,
VERSION_SD3,
VERSION_FLUX,
VERSION_FLUX_FILL,
VERSION_COUNT,
};
static inline bool sd_version_is_flux(SDVersion version) {
if (version == VERSION_FLUX) {
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
return true;
}
return false;
@ -41,6 +45,34 @@ static inline bool sd_version_is_sd3(SDVersion version) {
return false;
}
static inline bool sd_version_is_sd1(SDVersion version) {
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT) {
return true;
}
return false;
}
static inline bool sd_version_is_sd2(SDVersion version) {
if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT) {
return true;
}
return false;
}
static inline bool sd_version_is_sdxl(SDVersion version) {
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT) {
return true;
}
return false;
}
static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
return true;
}
return false;
}
static inline bool sd_version_is_dit(SDVersion version) {
if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
return true;

View File

@ -26,11 +26,15 @@
const char* model_version_to_str[] = {
"SD 1.x",
"SD 1.x Inpaint",
"SD 2.x",
"SD 2.x Inpaint",
"SDXL",
"SDXL Inpaint",
"SVD",
"SD3.x",
"Flux"};
"Flux",
"Flux Fill"};
const char* sampling_methods_str[] = {
"Euler A",
@ -263,7 +267,7 @@ public:
model_loader.set_wtype_override(wtype);
}
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
vae_wtype = GGML_TYPE_F32;
model_loader.set_wtype_override(GGML_TYPE_F32, "vae.");
}
@ -275,7 +279,7 @@ public:
LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
scale_factor = 0.13025f;
if (vae_path.size() == 0 && taesd_path.size() == 0) {
LOG_WARN(
@ -329,7 +333,7 @@ public:
diffusion_model = std::make_shared<MMDiTModel>(backend, model_loader.tensor_storages_types);
} else if (sd_version_is_flux(version)) {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, diffusion_flash_attn);
diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);
} else {
if (id_embeddings_path.find("v2") != std::string::npos) {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2);
@ -517,8 +521,8 @@ public:
// check is_using_v_parameterization_for_sd2
bool is_using_v_parameterization = false;
if (version == VERSION_SD2) {
if (is_using_v_parameterization_for_sd2(ctx)) {
if (sd_version_is_sd2(version)) {
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
is_using_v_parameterization = true;
}
} else if (version == VERSION_SVD) {
@ -592,7 +596,7 @@ public:
return true;
}
bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx) {
bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) {
struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
ggml_set_f32(x_t, 0.5);
struct ggml_tensor* c = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 1024, 2, 1, 1);
@ -600,9 +604,13 @@ public:
struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1);
ggml_set_f32(timesteps, 999);
struct ggml_tensor* concat = is_inpaint ? ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 5, 1) : NULL;
ggml_set_f32(concat, 0);
int64_t t0 = ggml_time_ms();
struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t);
diffusion_model->compute(n_threads, x_t, timesteps, c, NULL, NULL, NULL, -1, {}, 0.f, &out);
diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, -1, {}, 0.f, &out);
diffusion_model->free_compute_buffer();
double result = 0.f;
@ -785,7 +793,20 @@ public:
std::vector<int> skip_layers = {},
float slg_scale = 0,
float skip_layer_start = 0.01,
float skip_layer_end = 0.2) {
float skip_layer_end = 0.2,
ggml_tensor* noise_mask = nullptr) {
LOG_DEBUG("Sample");
struct ggml_init_params params;
size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]);
for (int i = 1; i < 4; i++) {
data_size *= init_latent->ne[i];
}
data_size += 1024;
params.mem_size = data_size * 3;
params.mem_buffer = NULL;
params.no_alloc = false;
ggml_context* tmp_ctx = ggml_init(params);
size_t steps = sigmas.size() - 1;
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
// print_ggml_tensor(noise);
@ -944,6 +965,19 @@ public:
pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
}
if (noise_mask != nullptr) {
for (int64_t x = 0; x < denoised->ne[0]; x++) {
for (int64_t y = 0; y < denoised->ne[1]; y++) {
float mask = ggml_tensor_get_f32(noise_mask, x, y);
for (int64_t k = 0; k < denoised->ne[2]; k++) {
float init = ggml_tensor_get_f32(init_latent, x, y, k);
float den = ggml_tensor_get_f32(denoised, x, y, k);
ggml_tensor_set_f32(denoised, init + mask * (den - init), x, y, k);
}
}
}
}
return denoised;
};
@ -1167,7 +1201,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
std::vector<int> skip_layers = {},
float slg_scale = 0,
float skip_layer_start = 0.01,
float skip_layer_end = 0.2) {
float skip_layer_end = 0.2,
ggml_tensor* masked_image = NULL) {
if (seed < 0) {
// Generally, when using the provided command line, the seed is always >0.
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@ -1317,7 +1352,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
SDCondition uncond;
if (cfg_scale != 1.0) {
bool force_zero_embeddings = false;
if (sd_ctx->sd->version == VERSION_SDXL && negative_prompt.size() == 0) {
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0) {
force_zero_embeddings = true;
}
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
@ -1354,6 +1389,39 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
int W = width / 8;
int H = height / 8;
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
ggml_tensor* noise_mask = nullptr;
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
if (masked_image == NULL) {
int64_t mask_channels = 1;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
mask_channels = 8 * 8; // flatten the whole mask
}
// no mask, set the whole image as masked
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
for (int64_t x = 0; x < masked_image->ne[0]; x++) {
for (int64_t y = 0; y < masked_image->ne[1]; y++) {
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
// TODO: this might be wrong
for (int64_t c = 0; c < init_latent->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 0, x, y, c);
}
for (int64_t c = init_latent->ne[2]; c < masked_image->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 1, x, y, c);
}
} else {
ggml_tensor_set_f32(masked_image, 1, x, y, 0);
for (int64_t c = 1; c < masked_image->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 0, x, y, c);
}
}
}
}
}
cond.c_concat = masked_image;
uncond.c_concat = masked_image;
} else {
noise_mask = masked_image;
}
for (int b = 0; b < batch_count; b++) {
int64_t sampling_start = ggml_time_ms();
int64_t cur_seed = seed + b;
@ -1389,7 +1457,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
skip_layers,
slg_scale,
skip_layer_start,
skip_layer_end);
skip_layer_end,
noise_mask);
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
// print_ggml_tensor(x_0);
int64_t sampling_end = ggml_time_ms();
@ -1511,6 +1581,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
ggml_set_f32(init_latent, 0.f);
}
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask");
}
sd_image_t* result_images = generate_image(sd_ctx,
work_ctx,
init_latent,
@ -1544,6 +1618,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
sd_image_t* img2img(sd_ctx_t* sd_ctx,
sd_image_t init_image,
sd_image_t mask,
const char* prompt_c_str,
const char* negative_prompt_c_str,
int clip_skip,
@ -1583,7 +1658,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
if (sd_ctx->sd->stacked_id) {
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
}
params.mem_size += width * height * 3 * sizeof(float) * 2;
params.mem_size += width * height * 3 * sizeof(float) * 3;
params.mem_size *= batch_count;
params.mem_buffer = NULL;
params.no_alloc = false;
@ -1604,7 +1679,70 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
sd_ctx->sd->rng->manual_seed(seed);
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
ggml_tensor* mask_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1);
sd_mask_to_tensor(mask.data, mask_img);
sd_image_to_tensor(init_image.data, init_img);
ggml_tensor* masked_image;
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
int64_t mask_channels = 1;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
mask_channels = 8 * 8; // flatten the whole mask
}
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_apply_mask(init_img, mask_img, masked_img);
ggml_tensor* masked_image_0 = NULL;
if (!sd_ctx->sd->use_tiny_autoencoder) {
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
masked_image_0 = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
} else {
masked_image_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
}
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], mask_channels + masked_image_0->ne[2], 1);
for (int ix = 0; ix < masked_image_0->ne[0]; ix++) {
for (int iy = 0; iy < masked_image_0->ne[1]; iy++) {
int mx = ix * 8;
int my = iy * 8;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
for (int k = 0; k < masked_image_0->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
ggml_tensor_set_f32(masked_image, v, ix, iy, k);
}
// "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
for (int x = 0; x < 8; x++) {
for (int y = 0; y < 8; y++) {
float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
// TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?)
// python code was using "b (h 8) (w 8) -> b (8 8) h w"
ggml_tensor_set_f32(masked_image, m, ix, iy, masked_image_0->ne[2] + x * 8 + y);
}
}
} else {
float m = ggml_tensor_get_f32(mask_img, mx, my);
ggml_tensor_set_f32(masked_image, m, ix, iy, 0);
for (int k = 0; k < masked_image_0->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
ggml_tensor_set_f32(masked_image, v, ix, iy, k + mask_channels);
}
}
}
}
} else {
// LOG_WARN("Inpainting with a base model is not great");
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);
for (int ix = 0; ix < masked_image->ne[0]; ix++) {
for (int iy = 0; iy < masked_image->ne[1]; iy++) {
int mx = ix * 8;
int my = iy * 8;
float m = ggml_tensor_get_f32(mask_img, mx, my);
ggml_tensor_set_f32(masked_image, m, ix, iy);
}
}
}
ggml_tensor* init_latent = NULL;
if (!sd_ctx->sd->use_tiny_autoencoder) {
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
@ -1612,12 +1750,15 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
} else {
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
}
print_ggml_tensor(init_latent, true);
size_t t1 = ggml_time_ms();
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
size_t t_enc = static_cast<size_t>(sample_steps * strength);
if (t_enc == sample_steps)
t_enc--;
LOG_INFO("target t_enc is %zu steps", t_enc);
std::vector<float> sigma_sched;
sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
@ -1644,7 +1785,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
skip_layers_vec,
slg_scale,
skip_layer_start,
skip_layer_end);
skip_layer_end,
masked_image);
size_t t2 = ggml_time_ms();

View File

@ -174,6 +174,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
sd_image_t init_image,
sd_image_t mask_image,
const char* prompt,
const char* negative_prompt,
int clip_skip,

View File

@ -166,6 +166,7 @@ public:
// ldm.modules.diffusionmodules.openaimodel.UNetModel
class UnetModelBlock : public GGMLBlock {
protected:
static std::map<std::string, enum ggml_type> empty_tensor_types;
SDVersion version = VERSION_SD1;
// network hparams
int in_channels = 4;
@ -183,13 +184,13 @@ public:
int model_channels = 320;
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
UnetModelBlock(SDVersion version = VERSION_SD1, bool flash_attn = false)
UnetModelBlock(SDVersion version = VERSION_SD1, std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types, bool flash_attn = false)
: version(version) {
if (version == VERSION_SD2) {
if (sd_version_is_sd2(version)) {
context_dim = 1024;
num_head_channels = 64;
num_heads = -1;
} else if (version == VERSION_SDXL) {
} else if (sd_version_is_sdxl(version)) {
context_dim = 2048;
attention_resolutions = {4, 2};
channel_mult = {1, 2, 4};
@ -204,6 +205,10 @@ public:
num_head_channels = 64;
num_heads = -1;
}
if (sd_version_is_inpaint(version)) {
in_channels = 9;
}
// dims is always 2
// use_temporal_attention is always True for SVD
@ -211,7 +216,7 @@ public:
// time_embed_1 is nn.SiLU()
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
if (version == VERSION_SDXL || version == VERSION_SVD) {
if (sd_version_is_sdxl(version) || version == VERSION_SVD) {
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
// label_emb_1 is nn.SiLU()
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
@ -536,7 +541,7 @@ struct UNetModelRunner : public GGMLRunner {
const std::string prefix,
SDVersion version = VERSION_SD1,
bool flash_attn = false)
: GGMLRunner(backend), unet(version, flash_attn) {
: GGMLRunner(backend), unet(version, tensor_types, flash_attn) {
unet.init(params_ctx, tensor_types, prefix);
}
@ -566,6 +571,7 @@ struct UNetModelRunner : public GGMLRunner {
context = to_backend(context);
y = to_backend(y);
timesteps = to_backend(timesteps);
c_concat = to_backend(c_concat);
for (int i = 0; i < controls.size(); i++) {
controls[i] = to_backend(controls[i]);