feat: support Inpaint models (#511)

This commit is contained in:
stduhpf
2024-12-28 06:04:49 +01:00
committed by GitHub
parent cc92a6a1b3
commit 8f4ab9add3
11 changed files with 383 additions and 64 deletions

View File

@@ -85,6 +85,7 @@ struct SDParams {
std::string lora_model_dir;
std::string output_path = "output.png";
std::string input_path;
std::string mask_path;
std::string control_image_path;
std::string prompt;
@@ -148,6 +149,7 @@ void print_params(SDParams params) {
printf(" normalize input image : %s\n", params.normalize_input ? "true" : "false");
printf(" output_path: %s\n", params.output_path.c_str());
printf(" init_img: %s\n", params.input_path.c_str());
printf(" mask_img: %s\n", params.mask_path.c_str());
printf(" control_image: %s\n", params.control_image_path.c_str());
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
@@ -384,6 +386,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.input_path = argv[i];
} else if (arg == "--mask") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.mask_path = argv[i];
} else if (arg == "--control-image") {
if (++i >= argc) {
invalid_arg = true;
@@ -803,6 +811,8 @@ int main(int argc, const char* argv[]) {
bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
uint8_t* control_image_buffer = NULL;
uint8_t* mask_image_buffer = NULL;
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
vae_decode_only = false;
@@ -907,6 +917,18 @@ int main(int argc, const char* argv[]) {
}
}
if (params.mask_path != "") {
int c = 0;
mask_image_buffer = stbi_load(params.mask_path.c_str(), &params.width, &params.height, &c, 1);
} else {
std::vector<uint8_t> arr(params.width * params.height, 255);
mask_image_buffer = arr.data();
}
sd_image_t mask_image = {(uint32_t)params.width,
(uint32_t)params.height,
1,
mask_image_buffer};
sd_image_t* results;
if (params.mode == TXT2IMG) {
results = txt2img(sd_ctx,
@@ -976,6 +998,7 @@ int main(int argc, const char* argv[]) {
} else {
results = img2img(sd_ctx,
input_image,
mask_image,
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,