feat: enable controlnet and photo maker for img2img mode

leejet 2024-04-14 16:36:08 +08:00
parent ec82d5279a
commit 036ba9e6d8
5 changed files with 255 additions and 256 deletions
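For context, the sketch below shows how a caller might invoke the extended img2img() API after this change. It is an illustrative example, not code from this commit: the first thirteen arguments follow the pre-existing parameter list, the last five are the ones added here, and the concrete values, the 512x512 size, the EULER_A sampler, the "img" trigger word, and the "./id_images" directory are assumptions.

    #include "stable-diffusion.h"

    // Illustrative sketch only (not part of this commit). sd_ctx is assumed to come
    // from new_sd_ctx(); input_image_buffer and control_image_buffer are assumed to
    // hold 512x512 RGB pixels, e.g. loaded with stbi_load() as in the CLI changes below.
    sd_image_t init_image    = {512, 512, 3, input_image_buffer};
    sd_image_t control_image = {512, 512, 3, control_image_buffer};  // e.g. a canny edge map

    sd_image_t* results = img2img(sd_ctx,
                                  init_image,
                                  "a portrait photo of a man img",  // prompt ("img" assumed as PhotoMaker trigger word)
                                  "",                               // negative_prompt
                                  -1,                               // clip_skip (model default)
                                  7.0f,                             // cfg_scale
                                  512, 512,                         // width, height
                                  EULER_A,                          // sample_method
                                  20,                               // sample_steps
                                  0.75f,                            // strength
                                  42,                               // seed
                                  1,                                // batch_count
                                  &control_image,                   // new: ControlNet hint
                                  0.9f,                             // new: control_strength
                                  20.f,                             // new: style_ratio / style_strength (PhotoMaker)
                                  false,                            // new: normalize_input
                                  "./id_images");                   // new: input_id_images_path (PhotoMaker ID photos)

Passing NULL for the control image and an empty input_id_images_path should leave both new features disabled, so the call behaves like the previous img2img().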


@@ -656,13 +656,16 @@ int main(int argc, const char* argv[]) {
return 1;
}
bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
uint8_t* control_image_buffer = NULL;
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
vae_decode_only = false;
int c = 0;
input_image_buffer = stbi_load(params.input_path.c_str(), &params.width, &params.height, &c, 3);
int width = 0;
int height = 0;
input_image_buffer = stbi_load(params.input_path.c_str(), &width, &height, &c, 3);
if (input_image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str());
return 1;
@@ -672,21 +675,22 @@ int main(int argc, const char* argv[]) {
free(input_image_buffer);
return 1;
}
if (params.width <= 0) {
if (width <= 0) {
fprintf(stderr, "error: the width of image must be greater than 0\n");
free(input_image_buffer);
return 1;
}
if (params.height <= 0) {
if (height <= 0) {
fprintf(stderr, "error: the height of image must be greater than 0\n");
free(input_image_buffer);
return 1;
}
// Resize input image ...
if (params.height % 64 != 0 || params.width % 64 != 0) {
int resized_height = params.height + (64 - params.height % 64);
int resized_width = params.width + (64 - params.width % 64);
if (params.height != height || params.width != width) {
printf("resize input image from %dx%d to %dx%d\n", width, height, params.width, params.height);
int resized_height = params.height;
int resized_width = params.width;
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3);
if (resized_image_buffer == NULL) {
@@ -694,7 +698,7 @@ int main(int argc, const char* argv[]) {
free(input_image_buffer);
return 1;
}
stbir_resize(input_image_buffer, params.width, params.height, 0,
stbir_resize(input_image_buffer, width, height, 0,
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
@@ -704,8 +708,6 @@ int main(int argc, const char* argv[]) {
// Save resized result
free(input_image_buffer);
input_image_buffer = resized_image_buffer;
params.height = resized_height;
params.width = resized_width;
}
}
@@ -732,31 +734,32 @@ int main(int argc, const char* argv[]) {
return 1;
}
sd_image_t* control_image = NULL;
if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
int c = 0;
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
if (control_image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
return 1;
}
control_image = new sd_image_t{(uint32_t)params.width,
(uint32_t)params.height,
3,
control_image_buffer};
if (params.canny_preprocess) { // apply preprocessor
control_image->data = preprocess_canny(control_image->data,
control_image->width,
control_image->height,
0.08f,
0.08f,
0.8f,
1.0f,
false);
}
}
sd_image_t* results;
if (params.mode == TXT2IMG) {
sd_image_t* control_image = NULL;
if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
int c = 0;
input_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
if (input_image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
return 1;
}
control_image = new sd_image_t{(uint32_t)params.width,
(uint32_t)params.height,
3,
input_image_buffer};
if (params.canny_preprocess) { // apply preprocessor
control_image->data = preprocess_canny(control_image->data,
control_image->width,
control_image->height,
0.08f,
0.08f,
0.8f,
1.0f,
false);
}
}
results = txt2img(sd_ctx,
params.prompt.c_str(),
params.negative_prompt.c_str(),
@@ -828,7 +831,12 @@ int main(int argc, const char* argv[]) {
params.sample_steps,
params.strength,
params.seed,
params.batch_count);
params.batch_count,
control_image,
params.control_strength,
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str());
}
}
@@ -881,6 +889,8 @@ int main(int argc, const char* argv[]) {
}
free(results);
free_sd_ctx(sd_ctx);
free(control_image_buffer);
free(input_image_buffer);
return 0;
}


@@ -752,10 +752,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding(
return ggml_timestep_embedding(ctx, timesteps, dim, max_period);
}
__STATIC_INLINE__ size_t ggml_tensor_num(ggml_context * ctx) {
__STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
size_t num = 0;
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
for (ggml_tensor* t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
num++;
}
return num;
@@ -851,7 +850,7 @@ protected:
}
public:
virtual std::string get_desc() = 0;
virtual std::string get_desc() = 0;
GGMLModule(ggml_backend_t backend, ggml_type wtype = GGML_TYPE_F32)
: backend(backend), wtype(wtype) {


@@ -852,7 +852,6 @@ public:
copy_ggml_tensor(x, x_t);
struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x_t);
struct ggml_tensor* guided_hint = NULL;
bool has_unconditioned = cfg_scale != 1.0 && uc != NULL;
@@ -1536,60 +1535,35 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
free(sd_ctx);
}
sd_image_t* txt2img(sd_ctx_t* sd_ctx,
const char* prompt_c_str,
const char* negative_prompt_c_str,
int clip_skip,
float cfg_scale,
int width,
int height,
enum sample_method_t sample_method,
int sample_steps,
int64_t seed,
int batch_count,
const sd_image_t* control_cond,
float control_strength,
float style_ratio,
bool normalize_input,
const char* input_id_images_path_c_str) {
LOG_DEBUG("txt2img %dx%d", width, height);
if (sd_ctx == NULL) {
return NULL;
}
// LOG_DEBUG("%s %s %f %d %d %d", prompt_c_str, negative_prompt_c_str, cfg_scale, sample_steps, seed, batch_count);
std::string prompt(prompt_c_str);
std::string negative_prompt(negative_prompt_c_str);
std::string input_id_images_path(input_id_images_path_c_str);
// preprocess input id images
std::vector<sd_image_t*> input_id_images;
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
for (std::string img_file : img_files) {
int c = 0;
int width, height;
uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3);
if (input_image_buffer == NULL) {
LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str());
continue;
} else {
LOG_INFO("PhotoMaker loaded image from '%s'", img_file.c_str());
}
sd_image_t* input_image = NULL;
input_image = new sd_image_t{(uint32_t)width,
(uint32_t)height,
3,
input_image_buffer};
input_image = preprocess_id_image(input_image);
if (input_image == NULL) {
LOG_ERROR("preprocess input id image from '%s' failed", img_file.c_str());
continue;
}
input_id_images.push_back(input_image);
}
sd_image_t* generate_image(sd_ctx_t* sd_ctx,
struct ggml_context* work_ctx,
ggml_tensor* init_latent,
std::string prompt,
std::string negative_prompt,
int clip_skip,
float cfg_scale,
int width,
int height,
enum sample_method_t sample_method,
const std::vector<float>& sigmas,
int64_t seed,
int batch_count,
const sd_image_t* control_cond,
float control_strength,
float style_ratio,
bool normalize_input,
std::string input_id_images_path) {
if (seed < 0) {
// Generally, when using the provided command line, the seed is always >0.
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
// by a third party with a seed <0, let's incorporate randomization here.
srand((int)time(NULL));
seed = rand();
}
// extract and remove lora
int sample_steps = sigmas.size() - 1;
// Apply lora
auto result_pair = extract_and_remove_lora(prompt);
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
@@ -1605,49 +1579,50 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int64_t t1 = ggml_time_ms();
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (sd_ctx->sd->stacked_id && !sd_ctx->sd->pmid_lora->applied) {
t0 = ggml_time_ms();
sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->n_threads);
t1 = ggml_time_ms();
sd_ctx->sd->pmid_lora->applied = true;
LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->pmid_lora->free_params_buffer();
}
}
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->stacked_id) {
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
}
params.mem_size += width * height * 3 * sizeof(float);
params.mem_size *= batch_count;
params.mem_buffer = NULL;
params.no_alloc = false;
// LOG_DEBUG("mem_size %u ", params.mem_size);
struct ggml_context* work_ctx = ggml_init(params);
if (!work_ctx) {
LOG_ERROR("ggml_init() failed");
return NULL;
}
if (seed < 0) {
// Generally, when using the provided command line, the seed is always >0.
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
// by a third party with a seed <0, let's incorporate randomization here.
srand((int)time(NULL));
seed = rand();
}
// Photo Maker
std::string prompt_text_only;
ggml_tensor* init_img = NULL;
ggml_tensor* prompts_embeds = NULL;
ggml_tensor* pooled_prompts_embeds = NULL;
// ggml_tensor* class_tokens_mask = NULL;
std::vector<bool> class_tokens_mask;
if (sd_ctx->sd->stacked_id) {
if (!sd_ctx->sd->pmid_lora->applied) {
t0 = ggml_time_ms();
sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->n_threads);
t1 = ggml_time_ms();
sd_ctx->sd->pmid_lora->applied = true;
LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->pmid_lora->free_params_buffer();
}
}
// preprocess input id images
std::vector<sd_image_t*> input_id_images;
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
for (std::string img_file : img_files) {
int c = 0;
int width, height;
uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3);
if (input_image_buffer == NULL) {
LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str());
continue;
} else {
LOG_INFO("PhotoMaker loaded image from '%s'", img_file.c_str());
}
sd_image_t* input_image = NULL;
input_image = new sd_image_t{(uint32_t)width,
(uint32_t)height,
3,
input_image_buffer};
input_image = preprocess_id_image(input_image);
if (input_image == NULL) {
LOG_ERROR("preprocess input id image from '%s' failed", img_file.c_str());
continue;
}
input_id_images.push_back(input_image);
}
}
if (input_id_images.size() > 0) {
sd_ctx->sd->pmid_model->style_strength = style_ratio;
int32_t w = input_id_images[0]->width;
@@ -1682,21 +1657,22 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
prompt_text_only = sd_ctx->sd->remove_trigger_from_prompt(work_ctx, prompt);
// printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str());
prompt = prompt_text_only; //
if (sample_steps < 50) {
LOG_INFO("sampling steps increases from %d to 50 for PHOTOMAKER", sample_steps);
sample_steps = 50;
}
// if (sample_steps < 50) {
// LOG_INFO("sampling steps increases from %d to 50 for PHOTOMAKER", sample_steps);
// sample_steps = 50;
// }
} else {
LOG_WARN("Provided PhotoMaker model file, but NO input ID images");
LOG_WARN("Turn off PhotoMaker");
sd_ctx->sd->stacked_id = false;
}
for (sd_image_t* img : input_id_images) {
free(img->data);
}
input_id_images.clear();
}
for (sd_image_t* img : input_id_images) {
free(img->data);
}
input_id_images.clear();
// Get learned condition
t0 = ggml_time_ms();
auto cond_pair = sd_ctx->sd->get_learned_condition(work_ctx, prompt, clip_skip, width, height);
ggml_tensor* c = cond_pair.first;
@@ -1720,12 +1696,14 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
sd_ctx->sd->cond_stage_model->free_params_buffer();
}
// Control net hint
struct ggml_tensor* image_hint = NULL;
if (control_cond != NULL) {
image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_image_to_tensor(control_cond->data, image_hint);
}
// Sample
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
int C = 4;
int W = width / 8;
@@ -1737,22 +1715,28 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
LOG_INFO("generating image: %i/%i - seed %i", b + 1, batch_count, cur_seed);
sd_ctx->sd->rng->manual_seed(cur_seed);
struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
ggml_tensor_set_f32_randn(x_t, sd_ctx->sd->rng);
std::vector<float> sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps);
struct ggml_tensor* x_t = NULL;
struct ggml_tensor* noise = NULL;
if (init_latent == NULL) {
x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
ggml_tensor_set_f32_randn(x_t, sd_ctx->sd->rng);
} else {
x_t = init_latent;
noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng);
}
int start_merge_step = -1;
if (sd_ctx->sd->stacked_id) {
start_merge_step = int(sd_ctx->sd->pmid_model->style_strength / 100.f * sample_steps);
if (start_merge_step > 30)
start_merge_step = 30;
// if (start_merge_step > 30)
// start_merge_step = 30;
LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step);
}
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
x_t,
NULL,
noise,
c,
NULL,
c_vector,
@@ -1781,6 +1765,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int64_t t3 = ggml_time_ms();
LOG_INFO("generating %" PRId64 " latent images completed, taking %.2fs", final_latents.size(), (t3 - t1) * 1.0f / 1000);
// Decode to image
LOG_INFO("decoding %zu latents", final_latents.size());
std::vector<struct ggml_tensor*> decoded_images; // collect decoded images
for (size_t i = 0; i < final_latents.size(); i++) {
@@ -1812,9 +1797,74 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
result_images[i].data = sd_tensor_to_image(decoded_images[i]);
}
ggml_free(work_ctx);
LOG_INFO(
"txt2img completed in %.2fs",
(t4 - t0) * 1.0f / 1000);
return result_images;
}
sd_image_t* txt2img(sd_ctx_t* sd_ctx,
const char* prompt_c_str,
const char* negative_prompt_c_str,
int clip_skip,
float cfg_scale,
int width,
int height,
enum sample_method_t sample_method,
int sample_steps,
int64_t seed,
int batch_count,
const sd_image_t* control_cond,
float control_strength,
float style_ratio,
bool normalize_input,
const char* input_id_images_path_c_str) {
LOG_DEBUG("txt2img %dx%d", width, height);
if (sd_ctx == NULL) {
return NULL;
}
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->stacked_id) {
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
}
params.mem_size += width * height * 3 * sizeof(float);
params.mem_size *= batch_count;
params.mem_buffer = NULL;
params.no_alloc = false;
// LOG_DEBUG("mem_size %u ", params.mem_size);
struct ggml_context* work_ctx = ggml_init(params);
if (!work_ctx) {
LOG_ERROR("ggml_init() failed");
return NULL;
}
size_t t0 = ggml_time_ms();
std::vector<float> sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps);
sd_image_t* result_images = generate_image(sd_ctx,
work_ctx,
NULL,
prompt_c_str,
negative_prompt_c_str,
clip_skip,
cfg_scale,
width,
height,
sample_method,
sigmas,
seed,
batch_count,
control_cond,
control_strength,
style_ratio,
normalize_input,
input_id_images_path_c_str);
size_t t1 = ggml_time_ms();
LOG_INFO("txt2img completed in %.2fs", (t1 - t0) * 1.0f / 1000);
return result_images;
}
@@ -1831,59 +1881,44 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
int sample_steps,
float strength,
int64_t seed,
int batch_count) {
int batch_count,
const sd_image_t* control_cond,
float control_strength,
float style_ratio,
bool normalize_input,
const char* input_id_images_path_c_str) {
LOG_DEBUG("img2img %dx%d", width, height);
if (sd_ctx == NULL) {
return NULL;
}
std::string prompt(prompt_c_str);
std::string negative_prompt(negative_prompt_c_str);
LOG_INFO("img2img %dx%d", width, height);
std::vector<float> sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps);
size_t t_enc = static_cast<size_t>(sample_steps * strength);
LOG_INFO("target t_enc is %zu steps", t_enc);
std::vector<float> sigma_sched;
sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024) * 1024; // 10 MB
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->stacked_id) {
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
}
params.mem_size += width * height * 3 * sizeof(float) * 2;
params.mem_size *= batch_count;
params.mem_buffer = NULL;
params.no_alloc = false;
// LOG_DEBUG("mem_size %u ", params.mem_size);
// draft context
struct ggml_context* work_ctx = ggml_init(params);
if (!work_ctx) {
LOG_ERROR("ggml_init() failed");
return NULL;
}
size_t t0 = ggml_time_ms();
if (seed < 0) {
seed = (int)time(NULL);
srand((int)time(NULL));
seed = rand();
}
sd_ctx->sd->rng->manual_seed(seed);
// extract and remove lora
auto result_pair = extract_and_remove_lora(prompt);
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
for (auto& kv : lora_f2m) {
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
}
prompt = result_pair.second;
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
// load lora from file
int64_t t0 = ggml_time_ms();
sd_ctx->sd->apply_loras(lora_f2m);
int64_t t1 = ggml_time_ms();
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_image_to_tensor(init_image.data, init_img);
t0 = ggml_time_ms();
ggml_tensor* init_latent = NULL;
if (!sd_ctx->sd->use_tiny_autoencoder) {
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
@@ -1892,87 +1927,37 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
}
// print_ggml_tensor(init_latent);
t1 = ggml_time_ms();
size_t t1 = ggml_time_ms();
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
auto cond_pair = sd_ctx->sd->get_learned_condition(work_ctx, prompt, clip_skip, width, height);
ggml_tensor* c = cond_pair.first;
ggml_tensor* c_vector = cond_pair.second; // [adm_in_channels, ]
struct ggml_tensor* uc = NULL;
struct ggml_tensor* uc_vector = NULL;
if (cfg_scale != 1.0) {
bool force_zero_embeddings = false;
if (sd_ctx->sd->version == VERSION_XL && negative_prompt.size() == 0) {
force_zero_embeddings = true;
}
auto uncond_pair = sd_ctx->sd->get_learned_condition(work_ctx, negative_prompt, clip_skip, width, height, force_zero_embeddings);
uc = uncond_pair.first;
uc_vector = uncond_pair.second; // [adm_in_channels, ]
}
int64_t t2 = ggml_time_ms();
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t2 - t1);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->cond_stage_model->free_params_buffer();
}
std::vector<float> sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps);
size_t t_enc = static_cast<size_t>(sample_steps * strength);
LOG_INFO("target t_enc is %zu steps", t_enc);
std::vector<float> sigma_sched;
sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
sd_ctx->sd->rng->manual_seed(seed);
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, init_latent);
ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng);
sd_image_t* result_images = generate_image(sd_ctx,
work_ctx,
init_latent,
prompt_c_str,
negative_prompt_c_str,
clip_skip,
cfg_scale,
width,
height,
sample_method,
sigma_sched,
seed,
batch_count,
control_cond,
control_strength,
style_ratio,
normalize_input,
input_id_images_path_c_str);
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
init_latent,
noise,
c,
NULL,
c_vector,
uc,
NULL,
uc_vector,
{},
0.f,
cfg_scale,
cfg_scale,
sample_method,
sigma_sched,
-1,
NULL,
NULL);
// struct ggml_tensor *x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
// print_ggml_tensor(x_0);
int64_t t3 = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (t3 - t2) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->diffusion_model->free_params_buffer();
}
size_t t2 = ggml_time_ms();
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, x_0);
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
sd_ctx->sd->first_stage_model->free_params_buffer();
}
if (img == NULL) {
ggml_free(work_ctx);
return NULL;
}
sd_image_t* result_images = (sd_image_t*)calloc(1, sizeof(sd_image_t));
if (result_images == NULL) {
ggml_free(work_ctx);
return NULL;
}
for (size_t i = 0; i < 1; i++) {
result_images[i].width = width;
result_images[i].height = height;
result_images[i].channel = 3;
result_images[i].data = sd_tensor_to_image(img);
}
ggml_free(work_ctx);
int64_t t4 = ggml_time_ms();
LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
LOG_INFO("img2img completed in %.2fs", (t4 - t0) * 1.0f / 1000);
LOG_INFO("img2img completed in %.2fs", (t1 - t0) * 1.0f / 1000);
return result_images;
}


@@ -160,7 +160,12 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
int sample_steps,
float strength,
int64_t seed,
int batch_count);
int batch_count,
const sd_image_t* control_cond,
float control_strength,
float style_strength,
bool normalize_input,
const char* input_id_images_path);
SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
sd_image_t init_image,


@@ -201,7 +201,7 @@ struct TinyAutoEncoder : public GGMLModule {
}
bool load_from_file(const std::string& file_path) {
LOG_INFO("loading taesd from '%s'", file_path.c_str());
LOG_INFO("loading taesd from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
alloc_params_buffer();
std::map<std::string, ggml_tensor*> taesd_tensors;
taesd.get_param_tensors(taesd_tensors);