fix: seed should be 64 bit

This commit is contained in:
leejet 2023-09-03 20:08:22 +08:00
parent e5a7aec252
commit 45842865ff
5 changed files with 10 additions and 10 deletions

View File

@ -87,7 +87,7 @@ struct Option {
int sample_steps = 20; int sample_steps = 20;
float strength = 0.75f; float strength = 0.75f;
RNGType rng_type = STD_DEFAULT_RNG; RNGType rng_type = STD_DEFAULT_RNG;
int seed = 42; int64_t seed = 42;
bool verbose = false; bool verbose = false;
void print() { void print() {
@ -106,7 +106,7 @@ struct Option {
printf(" sample_steps: %d\n", sample_steps); printf(" sample_steps: %d\n", sample_steps);
printf(" strength: %.2f\n", strength); printf(" strength: %.2f\n", strength);
printf(" rng: %s\n", rng_type_to_str[rng_type]); printf(" rng: %s\n", rng_type_to_str[rng_type]);
printf(" seed: %d\n", seed); printf(" seed: %ld\n", seed);
} }
}; };
@ -233,7 +233,7 @@ void parse_args(int argc, const char* argv[], Option* opt) {
invalid_arg = true; invalid_arg = true;
break; break;
} }
opt->seed = std::stoi(argv[i]); opt->seed = std::stoll(argv[i]);
} else if (arg == "-h" || arg == "--help") { } else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv); print_usage(argc, argv);
exit(0); exit(0);

4
rng.h
View File

@ -6,7 +6,7 @@
class RNG { class RNG {
public: public:
virtual void manual_seed(uint32_t seed) = 0; virtual void manual_seed(uint64_t seed) = 0;
virtual std::vector<float> randn(uint32_t n) = 0; virtual std::vector<float> randn(uint32_t n) = 0;
}; };
@ -15,7 +15,7 @@ class STDDefaultRNG : public RNG {
std::default_random_engine generator; std::default_random_engine generator;
public: public:
void manual_seed(uint32_t seed) { void manual_seed(uint64_t seed) {
generator.seed(seed); generator.seed(seed);
} }

View File

@ -93,7 +93,7 @@ class PhiloxRNG : public RNG {
this->offset = 0; this->offset = 0;
} }
void manual_seed(uint32_t seed) { void manual_seed(uint64_t seed) {
this->seed = seed; this->seed = seed;
this->offset = 0; this->offset = 0;
} }

View File

@ -3823,7 +3823,7 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
int height, int height,
SampleMethod sample_method, SampleMethod sample_method,
int sample_steps, int sample_steps,
int seed) { int64_t seed) {
std::vector<uint8_t> result; std::vector<uint8_t> result;
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024) * 1024; // 10M params.mem_size = static_cast<size_t>(10 * 1024) * 1024; // 10M
@ -3911,7 +3911,7 @@ std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_i
SampleMethod sample_method, SampleMethod sample_method,
int sample_steps, int sample_steps,
float strength, float strength,
int seed) { int64_t seed) {
std::vector<uint8_t> result; std::vector<uint8_t> result;
if (init_img_vec.size() != width * height * 3) { if (init_img_vec.size() != width * height * 3) {
return result; return result;

View File

@ -40,7 +40,7 @@ class StableDiffusion {
int height, int height,
SampleMethod sample_method, SampleMethod sample_method,
int sample_steps, int sample_steps,
int seed); int64_t seed);
std::vector<uint8_t> img2img( std::vector<uint8_t> img2img(
const std::vector<uint8_t>& init_img, const std::vector<uint8_t>& init_img,
const std::string& prompt, const std::string& prompt,
@ -51,7 +51,7 @@ class StableDiffusion {
SampleMethod sample_method, SampleMethod sample_method,
int sample_steps, int sample_steps,
float strength, float strength,
int seed); int64_t seed);
}; };
void set_sd_log_level(SDLogLevel level); void set_sd_log_level(SDLogLevel level);