feat: add CUDA RNG

This commit is contained in:
leejet
2023-09-03 19:24:07 +08:00
parent 31e77e1573
commit e5a7aec252
6 changed files with 217 additions and 26 deletions

View File

@@ -67,6 +67,11 @@ int32_t get_num_physical_cores() {
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
}
// Printable names for the selectable RNG back-ends; indexed with an RNGType
// value when the options are printed. The same spellings are accepted by the
// --rng command-line flag.
// NOTE(review): entry order presumably mirrors the RNGType enumerator order
// (STD_DEFAULT_RNG, CUDA_RNG) — confirm against the enum definition.
const char* rng_type_to_str[] = {"std_default", "cuda"};
struct Option {
int n_threads = -1;
std::string mode = TXT2IMG;
@@ -81,6 +86,7 @@ struct Option {
SampleMethod sample_method = EULAR_A;
int sample_steps = 20;
float strength = 0.75f;
RNGType rng_type = STD_DEFAULT_RNG;
int seed = 42;
bool verbose = false;
@@ -99,6 +105,7 @@ struct Option {
printf(" sample_method: %s\n", "eular a");
printf(" sample_steps: %d\n", sample_steps);
printf(" strength: %.2f\n", strength);
printf(" rng: %s\n", rng_type_to_str[rng_type]);
printf(" seed: %d\n", seed);
}
};
@@ -123,6 +130,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" -W, --width W image width, in pixel space (default: 512)\n");
printf(" --sample-method SAMPLE_METHOD sample method (default: \"eular a\")\n");
printf(" --steps STEPS number of sample steps (default: 20)\n");
printf(" --rng {std_default, cuda} RNG (default: std_default)\n");
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
printf(" -v, --verbose print extra info\n");
}
@@ -206,6 +214,20 @@ void parse_args(int argc, const char* argv[], Option* opt) {
break;
}
opt->sample_steps = std::stoi(argv[i]);
} else if (arg == "--rng") {
if (++i >= argc) {
invalid_arg = true;
break;
}
std::string rng_type_str = argv[i];
if (rng_type_str == "std_default") {
opt->rng_type = STD_DEFAULT_RNG;
} else if (rng_type_str == "cuda") {
opt->rng_type = CUDA_RNG;
} else {
invalid_arg = true;
break;
}
} else if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
invalid_arg = true;
@@ -328,7 +350,7 @@ int main(int argc, const char* argv[]) {
init_img.assign(img_data, img_data + (opt.w * opt.h * c));
}
StableDiffusion sd(opt.n_threads, vae_decode_only, true);
StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type);
if (!sd.load_from_file(opt.model_path)) {
return 1;
}