feat: support Inpaint models (#511)
This commit is contained in:
@@ -85,6 +85,7 @@ struct SDParams {
|
||||
std::string lora_model_dir;
|
||||
std::string output_path = "output.png";
|
||||
std::string input_path;
|
||||
std::string mask_path;
|
||||
std::string control_image_path;
|
||||
|
||||
std::string prompt;
|
||||
@@ -148,6 +149,7 @@ void print_params(SDParams params) {
|
||||
printf(" normalize input image : %s\n", params.normalize_input ? "true" : "false");
|
||||
printf(" output_path: %s\n", params.output_path.c_str());
|
||||
printf(" init_img: %s\n", params.input_path.c_str());
|
||||
printf(" mask_img: %s\n", params.mask_path.c_str());
|
||||
printf(" control_image: %s\n", params.control_image_path.c_str());
|
||||
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
|
||||
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
|
||||
@@ -384,6 +386,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
break;
|
||||
}
|
||||
params.input_path = argv[i];
|
||||
} else if (arg == "--mask") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
params.mask_path = argv[i];
|
||||
} else if (arg == "--control-image") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
@@ -803,6 +811,8 @@ int main(int argc, const char* argv[]) {
|
||||
bool vae_decode_only = true;
|
||||
uint8_t* input_image_buffer = NULL;
|
||||
uint8_t* control_image_buffer = NULL;
|
||||
uint8_t* mask_image_buffer = NULL;
|
||||
|
||||
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
|
||||
vae_decode_only = false;
|
||||
|
||||
@@ -907,6 +917,18 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
if (params.mask_path != "") {
|
||||
int c = 0;
|
||||
mask_image_buffer = stbi_load(params.mask_path.c_str(), ¶ms.width, ¶ms.height, &c, 1);
|
||||
} else {
|
||||
std::vector<uint8_t> arr(params.width * params.height, 255);
|
||||
mask_image_buffer = arr.data();
|
||||
}
|
||||
sd_image_t mask_image = {(uint32_t)params.width,
|
||||
(uint32_t)params.height,
|
||||
1,
|
||||
mask_image_buffer};
|
||||
|
||||
sd_image_t* results;
|
||||
if (params.mode == TXT2IMG) {
|
||||
results = txt2img(sd_ctx,
|
||||
@@ -976,6 +998,7 @@ int main(int argc, const char* argv[]) {
|
||||
} else {
|
||||
results = img2img(sd_ctx,
|
||||
input_image,
|
||||
mask_image,
|
||||
params.prompt.c_str(),
|
||||
params.negative_prompt.c_str(),
|
||||
params.clip_skip,
|
||||
|
||||
Reference in New Issue
Block a user