feat: add CUDA RNG
This commit is contained in:
@@ -67,6 +67,11 @@ int32_t get_num_physical_cores() {
|
||||
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
|
||||
}
|
||||
|
||||
const char* rng_type_to_str[] = {
|
||||
"std_default",
|
||||
"cuda",
|
||||
};
|
||||
|
||||
struct Option {
|
||||
int n_threads = -1;
|
||||
std::string mode = TXT2IMG;
|
||||
@@ -81,6 +86,7 @@ struct Option {
|
||||
SampleMethod sample_method = EULAR_A;
|
||||
int sample_steps = 20;
|
||||
float strength = 0.75f;
|
||||
RNGType rng_type = STD_DEFAULT_RNG;
|
||||
int seed = 42;
|
||||
bool verbose = false;
|
||||
|
||||
@@ -99,6 +105,7 @@ struct Option {
|
||||
printf(" sample_method: %s\n", "eular a");
|
||||
printf(" sample_steps: %d\n", sample_steps);
|
||||
printf(" strength: %.2f\n", strength);
|
||||
printf(" rng: %s\n", rng_type_to_str[rng_type]);
|
||||
printf(" seed: %d\n", seed);
|
||||
}
|
||||
};
|
||||
@@ -123,6 +130,7 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" -W, --width W image width, in pixel space (default: 512)\n");
|
||||
printf(" --sample-method SAMPLE_METHOD sample method (default: \"eular a\")\n");
|
||||
printf(" --steps STEPS number of sample steps (default: 20)\n");
|
||||
printf(" --rng {std_default, cuda} RNG (default: std_default)\n");
|
||||
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
|
||||
printf(" -v, --verbose print extra info\n");
|
||||
}
|
||||
@@ -206,6 +214,20 @@ void parse_args(int argc, const char* argv[], Option* opt) {
|
||||
break;
|
||||
}
|
||||
opt->sample_steps = std::stoi(argv[i]);
|
||||
} else if (arg == "--rng") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
std::string rng_type_str = argv[i];
|
||||
if (rng_type_str == "std_default") {
|
||||
opt->rng_type = STD_DEFAULT_RNG;
|
||||
} else if (rng_type_str == "cuda") {
|
||||
opt->rng_type = CUDA_RNG;
|
||||
} else {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
} else if (arg == "-s" || arg == "--seed") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
@@ -328,7 +350,7 @@ int main(int argc, const char* argv[]) {
|
||||
init_img.assign(img_data, img_data + (opt.w * opt.h * c));
|
||||
}
|
||||
|
||||
StableDiffusion sd(opt.n_threads, vae_decode_only, true);
|
||||
StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type);
|
||||
if (!sd.load_from_file(opt.model_path)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user