fix: seed should be 64 bit

This commit is contained in:
leejet 2023-09-03 20:08:22 +08:00
parent e5a7aec252
commit 45842865ff
5 changed files with 10 additions and 10 deletions

View File

@ -87,7 +87,7 @@ struct Option {
int sample_steps = 20;
float strength = 0.75f;
RNGType rng_type = STD_DEFAULT_RNG;
int seed = 42;
int64_t seed = 42;
bool verbose = false;
void print() {
@ -106,7 +106,7 @@ struct Option {
printf(" sample_steps: %d\n", sample_steps);
printf(" strength: %.2f\n", strength);
printf(" rng: %s\n", rng_type_to_str[rng_type]);
printf(" seed: %d\n", seed);
printf(" seed: %ld\n", seed);
}
};
@ -233,7 +233,7 @@ void parse_args(int argc, const char* argv[], Option* opt) {
invalid_arg = true;
break;
}
opt->seed = std::stoi(argv[i]);
opt->seed = std::stoll(argv[i]);
} else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv);
exit(0);

4
rng.h
View File

@ -6,7 +6,7 @@
class RNG {
public:
virtual void manual_seed(uint32_t seed) = 0;
virtual void manual_seed(uint64_t seed) = 0;
virtual std::vector<float> randn(uint32_t n) = 0;
};
@ -15,7 +15,7 @@ class STDDefaultRNG : public RNG {
std::default_random_engine generator;
public:
void manual_seed(uint32_t seed) {
void manual_seed(uint64_t seed) {
generator.seed(seed);
}

View File

@ -93,7 +93,7 @@ class PhiloxRNG : public RNG {
this->offset = 0;
}
void manual_seed(uint32_t seed) {
void manual_seed(uint64_t seed) {
this->seed = seed;
this->offset = 0;
}

View File

@ -3823,7 +3823,7 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
int height,
SampleMethod sample_method,
int sample_steps,
int seed) {
int64_t seed) {
std::vector<uint8_t> result;
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024) * 1024; // 10M
@ -3911,7 +3911,7 @@ std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_i
SampleMethod sample_method,
int sample_steps,
float strength,
int seed) {
int64_t seed) {
std::vector<uint8_t> result;
if (init_img_vec.size() != width * height * 3) {
return result;

View File

@ -40,7 +40,7 @@ class StableDiffusion {
int height,
SampleMethod sample_method,
int sample_steps,
int seed);
int64_t seed);
std::vector<uint8_t> img2img(
const std::vector<uint8_t>& init_img,
const std::string& prompt,
@ -51,7 +51,7 @@ class StableDiffusion {
SampleMethod sample_method,
int sample_steps,
float strength,
int seed);
int64_t seed);
};
void set_sd_log_level(SDLogLevel level);