feat: remove type restrictions (#489)

This commit is contained in:
stduhpf 2024-11-30 07:22:15 +01:00 committed by GitHub
parent 7ce63e740c
commit 9148b980be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -196,7 +196,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --normalize-input normalize PHOTOMAKER input id images\n"); printf(" --normalize-input normalize PHOTOMAKER input id images\n");
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n"); printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n");
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n"); printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)\n"); printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
printf(" If not specified, the default is the type of the weight file\n"); printf(" If not specified, the default is the type of the weight file\n");
printf(" --lora-model-dir [DIR] lora model directory\n"); printf(" --lora-model-dir [DIR] lora model directory\n");
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n"); printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
@ -346,30 +346,30 @@ void parse_args(int argc, const char** argv, SDParams& params) {
invalid_arg = true; invalid_arg = true;
break; break;
} }
std::string type = argv[i]; std::string type = argv[i];
if (type == "f32") { bool found = false;
params.wtype = SD_TYPE_F32; std::string valid_types = "";
} else if (type == "f16") { for (size_t i = 0; i < SD_TYPE_COUNT; i++) {
params.wtype = SD_TYPE_F16; auto trait = ggml_get_type_traits((ggml_type)i);
} else if (type == "q4_0") { std::string name(trait->type_name);
params.wtype = SD_TYPE_Q4_0; if (name == "f32" || trait->to_float && trait->type_size) {
} else if (type == "q4_1") { if (i)
params.wtype = SD_TYPE_Q4_1; valid_types += ", ";
} else if (type == "q5_0") { valid_types += name;
params.wtype = SD_TYPE_Q5_0; if (type == name) {
} else if (type == "q5_1") { if (ggml_quantize_requires_imatrix((ggml_type)i)) {
params.wtype = SD_TYPE_Q5_1; printf("\033[35;1m[WARNING]\033[0m: type %s requires imatrix to work properly. A dummy imatrix will be used, expect poor quality.\n", trait->type_name);
} else if (type == "q8_0") { }
params.wtype = SD_TYPE_Q8_0; params.wtype = (enum sd_type_t)i;
} else if (type == "q2_k") { found = true;
params.wtype = SD_TYPE_Q2_K; break;
} else if (type == "q3_k") { }
params.wtype = SD_TYPE_Q3_K; }
} else if (type == "q4_k") { }
params.wtype = SD_TYPE_Q4_K; if (!found) {
} else { fprintf(stderr, "error: invalid weight format %s, must be one of [%s]\n",
fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k]\n", type.c_str(),
type.c_str()); valid_types.c_str());
exit(1); exit(1);
} }
} else if (arg == "--lora-model-dir") { } else if (arg == "--lora-model-dir") {