feat: enable controlnet and photo maker for img2img mode

This commit is contained in:
leejet 2024-04-14 16:36:08 +08:00
parent ec82d5279a
commit 036ba9e6d8
5 changed files with 255 additions and 256 deletions

View File

@ -656,13 +656,16 @@ int main(int argc, const char* argv[]) {
return 1; return 1;
} }
bool vae_decode_only = true; bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL; uint8_t* input_image_buffer = NULL;
uint8_t* control_image_buffer = NULL;
if (params.mode == IMG2IMG || params.mode == IMG2VID) { if (params.mode == IMG2IMG || params.mode == IMG2VID) {
vae_decode_only = false; vae_decode_only = false;
int c = 0; int c = 0;
input_image_buffer = stbi_load(params.input_path.c_str(), &params.width, &params.height, &c, 3); int width = 0;
int height = 0;
input_image_buffer = stbi_load(params.input_path.c_str(), &width, &height, &c, 3);
if (input_image_buffer == NULL) { if (input_image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str()); fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str());
return 1; return 1;
@ -672,21 +675,22 @@ int main(int argc, const char* argv[]) {
free(input_image_buffer); free(input_image_buffer);
return 1; return 1;
} }
if (params.width <= 0) { if (width <= 0) {
fprintf(stderr, "error: the width of image must be greater than 0\n"); fprintf(stderr, "error: the width of image must be greater than 0\n");
free(input_image_buffer); free(input_image_buffer);
return 1; return 1;
} }
if (params.height <= 0) { if (height <= 0) {
fprintf(stderr, "error: the height of image must be greater than 0\n"); fprintf(stderr, "error: the height of image must be greater than 0\n");
free(input_image_buffer); free(input_image_buffer);
return 1; return 1;
} }
// Resize input image ... // Resize input image ...
if (params.height % 64 != 0 || params.width % 64 != 0) { if (params.height != height || params.width != width) {
int resized_height = params.height + (64 - params.height % 64); printf("resize input image from %dx%d to %dx%d\n", width, height, params.width, params.height);
int resized_width = params.width + (64 - params.width % 64); int resized_height = params.height;
int resized_width = params.width;
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3); uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3);
if (resized_image_buffer == NULL) { if (resized_image_buffer == NULL) {
@ -694,7 +698,7 @@ int main(int argc, const char* argv[]) {
free(input_image_buffer); free(input_image_buffer);
return 1; return 1;
} }
stbir_resize(input_image_buffer, params.width, params.height, 0, stbir_resize(input_image_buffer, width, height, 0,
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8, resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0, 3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
@ -704,8 +708,6 @@ int main(int argc, const char* argv[]) {
// Save resized result // Save resized result
free(input_image_buffer); free(input_image_buffer);
input_image_buffer = resized_image_buffer; input_image_buffer = resized_image_buffer;
params.height = resized_height;
params.width = resized_width;
} }
} }
@ -732,31 +734,32 @@ int main(int argc, const char* argv[]) {
return 1; return 1;
} }
sd_image_t* control_image = NULL;
if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
int c = 0;
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
if (control_image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
return 1;
}
control_image = new sd_image_t{(uint32_t)params.width,
(uint32_t)params.height,
3,
control_image_buffer};
if (params.canny_preprocess) { // apply preprocessor
control_image->data = preprocess_canny(control_image->data,
control_image->width,
control_image->height,
0.08f,
0.08f,
0.8f,
1.0f,
false);
}
}
sd_image_t* results; sd_image_t* results;
if (params.mode == TXT2IMG) { if (params.mode == TXT2IMG) {
sd_image_t* control_image = NULL;
if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
int c = 0;
input_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
if (input_image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
return 1;
}
control_image = new sd_image_t{(uint32_t)params.width,
(uint32_t)params.height,
3,
input_image_buffer};
if (params.canny_preprocess) { // apply preprocessor
control_image->data = preprocess_canny(control_image->data,
control_image->width,
control_image->height,
0.08f,
0.08f,
0.8f,
1.0f,
false);
}
}
results = txt2img(sd_ctx, results = txt2img(sd_ctx,
params.prompt.c_str(), params.prompt.c_str(),
params.negative_prompt.c_str(), params.negative_prompt.c_str(),
@ -828,7 +831,12 @@ int main(int argc, const char* argv[]) {
params.sample_steps, params.sample_steps,
params.strength, params.strength,
params.seed, params.seed,
params.batch_count); params.batch_count,
control_image,
params.control_strength,
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str());
} }
} }
@ -881,6 +889,8 @@ int main(int argc, const char* argv[]) {
} }
free(results); free(results);
free_sd_ctx(sd_ctx); free_sd_ctx(sd_ctx);
free(control_image_buffer);
free(input_image_buffer);
return 0; return 0;
} }

View File

@ -752,10 +752,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding(
return ggml_timestep_embedding(ctx, timesteps, dim, max_period); return ggml_timestep_embedding(ctx, timesteps, dim, max_period);
} }
__STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
__STATIC_INLINE__ size_t ggml_tensor_num(ggml_context * ctx) {
size_t num = 0; size_t num = 0;
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { for (ggml_tensor* t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
num++; num++;
} }
return num; return num;
@ -851,7 +850,7 @@ protected:
} }
public: public:
virtual std::string get_desc() = 0; virtual std::string get_desc() = 0;
GGMLModule(ggml_backend_t backend, ggml_type wtype = GGML_TYPE_F32) GGMLModule(ggml_backend_t backend, ggml_type wtype = GGML_TYPE_F32)
: backend(backend), wtype(wtype) { : backend(backend), wtype(wtype) {

View File

@ -852,7 +852,6 @@ public:
copy_ggml_tensor(x, x_t); copy_ggml_tensor(x, x_t);
struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x_t); struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x_t);
struct ggml_tensor* guided_hint = NULL;
bool has_unconditioned = cfg_scale != 1.0 && uc != NULL; bool has_unconditioned = cfg_scale != 1.0 && uc != NULL;
@ -1536,60 +1535,35 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
free(sd_ctx); free(sd_ctx);
} }
sd_image_t* txt2img(sd_ctx_t* sd_ctx, sd_image_t* generate_image(sd_ctx_t* sd_ctx,
const char* prompt_c_str, struct ggml_context* work_ctx,
const char* negative_prompt_c_str, ggml_tensor* init_latent,
int clip_skip, std::string prompt,
float cfg_scale, std::string negative_prompt,
int width, int clip_skip,
int height, float cfg_scale,
enum sample_method_t sample_method, int width,
int sample_steps, int height,
int64_t seed, enum sample_method_t sample_method,
int batch_count, const std::vector<float>& sigmas,
const sd_image_t* control_cond, int64_t seed,
float control_strength, int batch_count,
float style_ratio, const sd_image_t* control_cond,
bool normalize_input, float control_strength,
const char* input_id_images_path_c_str) { float style_ratio,
LOG_DEBUG("txt2img %dx%d", width, height); bool normalize_input,
if (sd_ctx == NULL) { std::string input_id_images_path) {
return NULL; if (seed < 0) {
} // Generally, when using the provided command line, the seed is always >0.
// LOG_DEBUG("%s %s %f %d %d %d", prompt_c_str, negative_prompt_c_str, cfg_scale, sample_steps, seed, batch_count); // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
std::string prompt(prompt_c_str); // by a third party with a seed <0, let's incorporate randomization here.
std::string negative_prompt(negative_prompt_c_str); srand((int)time(NULL));
std::string input_id_images_path(input_id_images_path_c_str); seed = rand();
// preprocess input id images
std::vector<sd_image_t*> input_id_images;
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
for (std::string img_file : img_files) {
int c = 0;
int width, height;
uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3);
if (input_image_buffer == NULL) {
LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str());
continue;
} else {
LOG_INFO("PhotoMaker loaded image from '%s'", img_file.c_str());
}
sd_image_t* input_image = NULL;
input_image = new sd_image_t{(uint32_t)width,
(uint32_t)height,
3,
input_image_buffer};
input_image = preprocess_id_image(input_image);
if (input_image == NULL) {
LOG_ERROR("preprocess input id image from '%s' failed", img_file.c_str());
continue;
}
input_id_images.push_back(input_image);
}
} }
// extract and remove lora int sample_steps = sigmas.size() - 1;
// Apply lora
auto result_pair = extract_and_remove_lora(prompt); auto result_pair = extract_and_remove_lora(prompt);
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
@ -1605,49 +1579,50 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (sd_ctx->sd->stacked_id && !sd_ctx->sd->pmid_lora->applied) { // Photo Maker
t0 = ggml_time_ms();
sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->n_threads);
t1 = ggml_time_ms();
sd_ctx->sd->pmid_lora->applied = true;
LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->pmid_lora->free_params_buffer();
}
}
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->stacked_id) {
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
}
params.mem_size += width * height * 3 * sizeof(float);
params.mem_size *= batch_count;
params.mem_buffer = NULL;
params.no_alloc = false;
// LOG_DEBUG("mem_size %u ", params.mem_size);
struct ggml_context* work_ctx = ggml_init(params);
if (!work_ctx) {
LOG_ERROR("ggml_init() failed");
return NULL;
}
if (seed < 0) {
// Generally, when using the provided command line, the seed is always >0.
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
// by a third party with a seed <0, let's incorporate randomization here.
srand((int)time(NULL));
seed = rand();
}
std::string prompt_text_only; std::string prompt_text_only;
ggml_tensor* init_img = NULL; ggml_tensor* init_img = NULL;
ggml_tensor* prompts_embeds = NULL; ggml_tensor* prompts_embeds = NULL;
ggml_tensor* pooled_prompts_embeds = NULL; ggml_tensor* pooled_prompts_embeds = NULL;
// ggml_tensor* class_tokens_mask = NULL;
std::vector<bool> class_tokens_mask; std::vector<bool> class_tokens_mask;
if (sd_ctx->sd->stacked_id) { if (sd_ctx->sd->stacked_id) {
if (!sd_ctx->sd->pmid_lora->applied) {
t0 = ggml_time_ms();
sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->n_threads);
t1 = ggml_time_ms();
sd_ctx->sd->pmid_lora->applied = true;
LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->pmid_lora->free_params_buffer();
}
}
// preprocess input id images
std::vector<sd_image_t*> input_id_images;
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
for (std::string img_file : img_files) {
int c = 0;
int width, height;
uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3);
if (input_image_buffer == NULL) {
LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str());
continue;
} else {
LOG_INFO("PhotoMaker loaded image from '%s'", img_file.c_str());
}
sd_image_t* input_image = NULL;
input_image = new sd_image_t{(uint32_t)width,
(uint32_t)height,
3,
input_image_buffer};
input_image = preprocess_id_image(input_image);
if (input_image == NULL) {
LOG_ERROR("preprocess input id image from '%s' failed", img_file.c_str());
continue;
}
input_id_images.push_back(input_image);
}
}
if (input_id_images.size() > 0) { if (input_id_images.size() > 0) {
sd_ctx->sd->pmid_model->style_strength = style_ratio; sd_ctx->sd->pmid_model->style_strength = style_ratio;
int32_t w = input_id_images[0]->width; int32_t w = input_id_images[0]->width;
@ -1682,21 +1657,22 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
prompt_text_only = sd_ctx->sd->remove_trigger_from_prompt(work_ctx, prompt); prompt_text_only = sd_ctx->sd->remove_trigger_from_prompt(work_ctx, prompt);
// printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str()); // printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str());
prompt = prompt_text_only; // prompt = prompt_text_only; //
if (sample_steps < 50) { // if (sample_steps < 50) {
LOG_INFO("sampling steps increases from %d to 50 for PHOTOMAKER", sample_steps); // LOG_INFO("sampling steps increases from %d to 50 for PHOTOMAKER", sample_steps);
sample_steps = 50; // sample_steps = 50;
} // }
} else { } else {
LOG_WARN("Provided PhotoMaker model file, but NO input ID images"); LOG_WARN("Provided PhotoMaker model file, but NO input ID images");
LOG_WARN("Turn off PhotoMaker"); LOG_WARN("Turn off PhotoMaker");
sd_ctx->sd->stacked_id = false; sd_ctx->sd->stacked_id = false;
} }
for (sd_image_t* img : input_id_images) {
free(img->data);
}
input_id_images.clear();
} }
for (sd_image_t* img : input_id_images) {
free(img->data);
}
input_id_images.clear();
// Get learned condition
t0 = ggml_time_ms(); t0 = ggml_time_ms();
auto cond_pair = sd_ctx->sd->get_learned_condition(work_ctx, prompt, clip_skip, width, height); auto cond_pair = sd_ctx->sd->get_learned_condition(work_ctx, prompt, clip_skip, width, height);
ggml_tensor* c = cond_pair.first; ggml_tensor* c = cond_pair.first;
@ -1720,12 +1696,14 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
sd_ctx->sd->cond_stage_model->free_params_buffer(); sd_ctx->sd->cond_stage_model->free_params_buffer();
} }
// Control net hint
struct ggml_tensor* image_hint = NULL; struct ggml_tensor* image_hint = NULL;
if (control_cond != NULL) { if (control_cond != NULL) {
image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_image_to_tensor(control_cond->data, image_hint); sd_image_to_tensor(control_cond->data, image_hint);
} }
// Sample
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
int C = 4; int C = 4;
int W = width / 8; int W = width / 8;
@ -1737,22 +1715,28 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
LOG_INFO("generating image: %i/%i - seed %i", b + 1, batch_count, cur_seed); LOG_INFO("generating image: %i/%i - seed %i", b + 1, batch_count, cur_seed);
sd_ctx->sd->rng->manual_seed(cur_seed); sd_ctx->sd->rng->manual_seed(cur_seed);
struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); struct ggml_tensor* x_t = NULL;
ggml_tensor_set_f32_randn(x_t, sd_ctx->sd->rng); struct ggml_tensor* noise = NULL;
if (init_latent == NULL) {
std::vector<float> sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps); x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
ggml_tensor_set_f32_randn(x_t, sd_ctx->sd->rng);
} else {
x_t = init_latent;
noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng);
}
int start_merge_step = -1; int start_merge_step = -1;
if (sd_ctx->sd->stacked_id) { if (sd_ctx->sd->stacked_id) {
start_merge_step = int(sd_ctx->sd->pmid_model->style_strength / 100.f * sample_steps); start_merge_step = int(sd_ctx->sd->pmid_model->style_strength / 100.f * sample_steps);
if (start_merge_step > 30) // if (start_merge_step > 30)
start_merge_step = 30; // start_merge_step = 30;
LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step); LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step);
} }
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
x_t, x_t,
NULL, noise,
c, c,
NULL, NULL,
c_vector, c_vector,
@ -1781,6 +1765,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int64_t t3 = ggml_time_ms(); int64_t t3 = ggml_time_ms();
LOG_INFO("generating %" PRId64 " latent images completed, taking %.2fs", final_latents.size(), (t3 - t1) * 1.0f / 1000); LOG_INFO("generating %" PRId64 " latent images completed, taking %.2fs", final_latents.size(), (t3 - t1) * 1.0f / 1000);
// Decode to image
LOG_INFO("decoding %zu latents", final_latents.size()); LOG_INFO("decoding %zu latents", final_latents.size());
std::vector<struct ggml_tensor*> decoded_images; // collect decoded images std::vector<struct ggml_tensor*> decoded_images; // collect decoded images
for (size_t i = 0; i < final_latents.size(); i++) { for (size_t i = 0; i < final_latents.size(); i++) {
@ -1812,9 +1797,74 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
result_images[i].data = sd_tensor_to_image(decoded_images[i]); result_images[i].data = sd_tensor_to_image(decoded_images[i]);
} }
ggml_free(work_ctx); ggml_free(work_ctx);
LOG_INFO(
"txt2img completed in %.2fs", return result_images;
(t4 - t0) * 1.0f / 1000); }
sd_image_t* txt2img(sd_ctx_t* sd_ctx,
const char* prompt_c_str,
const char* negative_prompt_c_str,
int clip_skip,
float cfg_scale,
int width,
int height,
enum sample_method_t sample_method,
int sample_steps,
int64_t seed,
int batch_count,
const sd_image_t* control_cond,
float control_strength,
float style_ratio,
bool normalize_input,
const char* input_id_images_path_c_str) {
LOG_DEBUG("txt2img %dx%d", width, height);
if (sd_ctx == NULL) {
return NULL;
}
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->stacked_id) {
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
}
params.mem_size += width * height * 3 * sizeof(float);
params.mem_size *= batch_count;
params.mem_buffer = NULL;
params.no_alloc = false;
// LOG_DEBUG("mem_size %u ", params.mem_size);
struct ggml_context* work_ctx = ggml_init(params);
if (!work_ctx) {
LOG_ERROR("ggml_init() failed");
return NULL;
}
size_t t0 = ggml_time_ms();
std::vector<float> sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps);
sd_image_t* result_images = generate_image(sd_ctx,
work_ctx,
NULL,
prompt_c_str,
negative_prompt_c_str,
clip_skip,
cfg_scale,
width,
height,
sample_method,
sigmas,
seed,
batch_count,
control_cond,
control_strength,
style_ratio,
normalize_input,
input_id_images_path_c_str);
size_t t1 = ggml_time_ms();
LOG_INFO("txt2img completed in %.2fs", (t1 - t0) * 1.0f / 1000);
return result_images; return result_images;
} }
@ -1831,59 +1881,44 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
int sample_steps, int sample_steps,
float strength, float strength,
int64_t seed, int64_t seed,
int batch_count) { int batch_count,
const sd_image_t* control_cond,
float control_strength,
float style_ratio,
bool normalize_input,
const char* input_id_images_path_c_str) {
LOG_DEBUG("img2img %dx%d", width, height);
if (sd_ctx == NULL) { if (sd_ctx == NULL) {
return NULL; return NULL;
} }
std::string prompt(prompt_c_str);
std::string negative_prompt(negative_prompt_c_str);
LOG_INFO("img2img %dx%d", width, height);
std::vector<float> sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps);
size_t t_enc = static_cast<size_t>(sample_steps * strength);
LOG_INFO("target t_enc is %zu steps", t_enc);
std::vector<float> sigma_sched;
sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024) * 1024; // 10 MB params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->stacked_id) {
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
}
params.mem_size += width * height * 3 * sizeof(float) * 2; params.mem_size += width * height * 3 * sizeof(float) * 2;
params.mem_size *= batch_count;
params.mem_buffer = NULL; params.mem_buffer = NULL;
params.no_alloc = false; params.no_alloc = false;
// LOG_DEBUG("mem_size %u ", params.mem_size); // LOG_DEBUG("mem_size %u ", params.mem_size);
// draft context
struct ggml_context* work_ctx = ggml_init(params); struct ggml_context* work_ctx = ggml_init(params);
if (!work_ctx) { if (!work_ctx) {
LOG_ERROR("ggml_init() failed"); LOG_ERROR("ggml_init() failed");
return NULL; return NULL;
} }
size_t t0 = ggml_time_ms();
if (seed < 0) { if (seed < 0) {
seed = (int)time(NULL); srand((int)time(NULL));
seed = rand();
} }
sd_ctx->sd->rng->manual_seed(seed); sd_ctx->sd->rng->manual_seed(seed);
// extract and remove lora
auto result_pair = extract_and_remove_lora(prompt);
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
for (auto& kv : lora_f2m) {
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
}
prompt = result_pair.second;
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
// load lora from file
int64_t t0 = ggml_time_ms();
sd_ctx->sd->apply_loras(lora_f2m);
int64_t t1 = ggml_time_ms();
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_image_to_tensor(init_image.data, init_img); sd_image_to_tensor(init_image.data, init_img);
t0 = ggml_time_ms();
ggml_tensor* init_latent = NULL; ggml_tensor* init_latent = NULL;
if (!sd_ctx->sd->use_tiny_autoencoder) { if (!sd_ctx->sd->use_tiny_autoencoder) {
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img); ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
@ -1892,87 +1927,37 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
} }
// print_ggml_tensor(init_latent); // print_ggml_tensor(init_latent);
t1 = ggml_time_ms(); size_t t1 = ggml_time_ms();
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
auto cond_pair = sd_ctx->sd->get_learned_condition(work_ctx, prompt, clip_skip, width, height); std::vector<float> sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps);
ggml_tensor* c = cond_pair.first; size_t t_enc = static_cast<size_t>(sample_steps * strength);
ggml_tensor* c_vector = cond_pair.second; // [adm_in_channels, ] LOG_INFO("target t_enc is %zu steps", t_enc);
struct ggml_tensor* uc = NULL; std::vector<float> sigma_sched;
struct ggml_tensor* uc_vector = NULL; sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
if (cfg_scale != 1.0) {
bool force_zero_embeddings = false;
if (sd_ctx->sd->version == VERSION_XL && negative_prompt.size() == 0) {
force_zero_embeddings = true;
}
auto uncond_pair = sd_ctx->sd->get_learned_condition(work_ctx, negative_prompt, clip_skip, width, height, force_zero_embeddings);
uc = uncond_pair.first;
uc_vector = uncond_pair.second; // [adm_in_channels, ]
}
int64_t t2 = ggml_time_ms();
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t2 - t1);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->cond_stage_model->free_params_buffer();
}
sd_ctx->sd->rng->manual_seed(seed); sd_image_t* result_images = generate_image(sd_ctx,
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, init_latent); work_ctx,
ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng); init_latent,
prompt_c_str,
negative_prompt_c_str,
clip_skip,
cfg_scale,
width,
height,
sample_method,
sigma_sched,
seed,
batch_count,
control_cond,
control_strength,
style_ratio,
normalize_input,
input_id_images_path_c_str);
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); size_t t2 = ggml_time_ms();
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
init_latent,
noise,
c,
NULL,
c_vector,
uc,
NULL,
uc_vector,
{},
0.f,
cfg_scale,
cfg_scale,
sample_method,
sigma_sched,
-1,
NULL,
NULL);
// struct ggml_tensor *x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
// print_ggml_tensor(x_0);
int64_t t3 = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (t3 - t2) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->diffusion_model->free_params_buffer();
}
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, x_0); LOG_INFO("img2img completed in %.2fs", (t1 - t0) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
sd_ctx->sd->first_stage_model->free_params_buffer();
}
if (img == NULL) {
ggml_free(work_ctx);
return NULL;
}
sd_image_t* result_images = (sd_image_t*)calloc(1, sizeof(sd_image_t));
if (result_images == NULL) {
ggml_free(work_ctx);
return NULL;
}
for (size_t i = 0; i < 1; i++) {
result_images[i].width = width;
result_images[i].height = height;
result_images[i].channel = 3;
result_images[i].data = sd_tensor_to_image(img);
}
ggml_free(work_ctx);
int64_t t4 = ggml_time_ms();
LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
LOG_INFO("img2img completed in %.2fs", (t4 - t0) * 1.0f / 1000);
return result_images; return result_images;
} }

View File

@ -160,7 +160,12 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
int sample_steps, int sample_steps,
float strength, float strength,
int64_t seed, int64_t seed,
int batch_count); int batch_count,
const sd_image_t* control_cond,
float control_strength,
float style_strength,
bool normalize_input,
const char* input_id_images_path);
SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
sd_image_t init_image, sd_image_t init_image,

View File

@ -201,7 +201,7 @@ struct TinyAutoEncoder : public GGMLModule {
} }
bool load_from_file(const std::string& file_path) { bool load_from_file(const std::string& file_path) {
LOG_INFO("loading taesd from '%s'", file_path.c_str()); LOG_INFO("loading taesd from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
alloc_params_buffer(); alloc_params_buffer();
std::map<std::string, ggml_tensor*> taesd_tensors; std::map<std::string, ggml_tensor*> taesd_tensors;
taesd.get_param_tensors(taesd_tensors); taesd.get_param_tensors(taesd_tensors);