feat: add CUDA RNG
This commit is contained in:
parent
31e77e1573
commit
e5a7aec252
@ -20,6 +20,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
|
||||
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
|
||||
- Sampling method
|
||||
- `Euler A`
|
||||
- Cross-platform reproducibility (`--rng cuda`, consistent with the `stable-diffusion-webui GPU RNG`)
|
||||
- Supported platforms
|
||||
- Linux
|
||||
- Mac OS
|
||||
@ -35,8 +36,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
|
||||
- [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
|
||||
- [ ] LoRA support
|
||||
- [ ] k-quants support
|
||||
- [ ] Cross-platform reproducibility (perhaps ensuring consistency with the original SD)
|
||||
- [ ] Adapting to more weight formats
|
||||
|
||||
## Usage
|
||||
|
||||
|
@ -67,6 +67,11 @@ int32_t get_num_physical_cores() {
|
||||
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
|
||||
}
|
||||
|
||||
const char* rng_type_to_str[] = {
|
||||
"std_default",
|
||||
"cuda",
|
||||
};
|
||||
|
||||
struct Option {
|
||||
int n_threads = -1;
|
||||
std::string mode = TXT2IMG;
|
||||
@ -81,6 +86,7 @@ struct Option {
|
||||
SampleMethod sample_method = EULAR_A;
|
||||
int sample_steps = 20;
|
||||
float strength = 0.75f;
|
||||
RNGType rng_type = STD_DEFAULT_RNG;
|
||||
int seed = 42;
|
||||
bool verbose = false;
|
||||
|
||||
@ -99,6 +105,7 @@ struct Option {
|
||||
printf(" sample_method: %s\n", "eular a");
|
||||
printf(" sample_steps: %d\n", sample_steps);
|
||||
printf(" strength: %.2f\n", strength);
|
||||
printf(" rng: %s\n", rng_type_to_str[rng_type]);
|
||||
printf(" seed: %d\n", seed);
|
||||
}
|
||||
};
|
||||
@ -123,6 +130,7 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" -W, --width W image width, in pixel space (default: 512)\n");
|
||||
printf(" --sample-method SAMPLE_METHOD sample method (default: \"eular a\")\n");
|
||||
printf(" --steps STEPS number of sample steps (default: 20)\n");
|
||||
printf(" --rng {std_default, cuda} RNG (default: std_default)\n");
|
||||
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
|
||||
printf(" -v, --verbose print extra info\n");
|
||||
}
|
||||
@ -206,6 +214,20 @@ void parse_args(int argc, const char* argv[], Option* opt) {
|
||||
break;
|
||||
}
|
||||
opt->sample_steps = std::stoi(argv[i]);
|
||||
} else if (arg == "--rng") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
std::string rng_type_str = argv[i];
|
||||
if (rng_type_str == "std_default") {
|
||||
opt->rng_type = STD_DEFAULT_RNG;
|
||||
} else if (rng_type_str == "cuda") {
|
||||
opt->rng_type = CUDA_RNG;
|
||||
} else {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
} else if (arg == "-s" || arg == "--seed") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
@ -328,7 +350,7 @@ int main(int argc, const char* argv[]) {
|
||||
init_img.assign(img_data, img_data + (opt.w * opt.h * c));
|
||||
}
|
||||
|
||||
StableDiffusion sd(opt.n_threads, vae_decode_only, true);
|
||||
StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type);
|
||||
if (!sd.load_from_file(opt.model_path)) {
|
||||
return 1;
|
||||
}
|
||||
|
35
rng.h
Normal file
35
rng.h
Normal file
@ -0,0 +1,35 @@
|
||||
#ifndef __RNG_H__
|
||||
#define __RNG_H__
|
||||
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
class RNG {
|
||||
public:
|
||||
virtual void manual_seed(uint32_t seed) = 0;
|
||||
virtual std::vector<float> randn(uint32_t n) = 0;
|
||||
};
|
||||
|
||||
class STDDefaultRNG : public RNG {
|
||||
private:
|
||||
std::default_random_engine generator;
|
||||
|
||||
public:
|
||||
void manual_seed(uint32_t seed) {
|
||||
generator.seed(seed);
|
||||
}
|
||||
|
||||
std::vector<float> randn(uint32_t n) {
|
||||
std::vector<float> result;
|
||||
float mean = 0.0;
|
||||
float stddev = 1.0;
|
||||
std::normal_distribution<float> distribution(mean, stddev);
|
||||
for (int i = 0; i < n; i++) {
|
||||
float random_number = distribution(generator);
|
||||
result.push_back(random_number);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
#endif // __RNG_H__
|
125
rng_philox.h
Normal file
125
rng_philox.h
Normal file
@ -0,0 +1,125 @@
|
||||
#ifndef __RNG_PHILOX_H__
|
||||
#define __RNG_PHILOX_H__
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "rng.h"
|
||||
|
||||
// RNG imitiating torch cuda randn on CPU.
|
||||
// Port from: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/5ef669de080814067961f28357256e8fe27544f4/modules/rng_philox.py
|
||||
class PhiloxRNG : public RNG {
|
||||
private:
|
||||
uint64_t seed;
|
||||
uint32_t offset;
|
||||
|
||||
private:
|
||||
std::vector<uint32_t> philox_m = {0xD2511F53, 0xCD9E8D57};
|
||||
std::vector<uint32_t> philox_w = {0x9E3779B9, 0xBB67AE85};
|
||||
float two_pow32_inv = 2.3283064e-10;
|
||||
float two_pow32_inv_2pi = 2.3283064e-10 * 6.2831855;
|
||||
|
||||
std::vector<uint32_t> uint32(uint64_t x) {
|
||||
std::vector<uint32_t> result(2);
|
||||
result[0] = static_cast<uint32_t>(x & 0xFFFFFFFF);
|
||||
result[1] = static_cast<uint32_t>(x >> 32);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<std::vector<uint32_t>> uint32(const std::vector<uint64_t>& x) {
|
||||
int N = x.size();
|
||||
std::vector<std::vector<uint32_t>> result(2, std::vector<uint32_t>(N));
|
||||
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[0][i] = static_cast<uint32_t>(x[i] & 0xFFFFFFFF);
|
||||
result[1][i] = static_cast<uint32_t>(x[i] >> 32);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// A single round of the Philox 4x32 random number generator.
|
||||
void philox4_round(std::vector<std::vector<uint32_t>>& counter,
|
||||
const std::vector<std::vector<uint32_t>>& key) {
|
||||
uint32_t N = counter[0].size();
|
||||
for (uint32_t i = 0; i < N; i++) {
|
||||
std::vector<uint32_t> v1 = uint32(static_cast<uint64_t>(counter[0][i]) * static_cast<uint64_t>(philox_m[0]));
|
||||
std::vector<uint32_t> v2 = uint32(static_cast<uint64_t>(counter[2][i]) * static_cast<uint64_t>(philox_m[1]));
|
||||
|
||||
counter[0][i] = v2[1] ^ counter[1][i] ^ key[0][i];
|
||||
counter[1][i] = v2[0];
|
||||
counter[2][i] = v1[1] ^ counter[3][i] ^ key[1][i];
|
||||
counter[3][i] = v1[0];
|
||||
}
|
||||
}
|
||||
|
||||
// Generates 32-bit random numbers using the Philox 4x32 random number generator.
|
||||
// Parameters:
|
||||
// counter : A 4xN array of 32-bit integers representing the counter values (offset into generation).
|
||||
// key : A 2xN array of 32-bit integers representing the key values (seed).
|
||||
// rounds : The number of rounds to perform.
|
||||
// Returns:
|
||||
// std::vector<std::vector<uint32_t>>: A 4xN array of 32-bit integers containing the generated random numbers.
|
||||
std::vector<std::vector<uint32_t>> philox4_32(std::vector<std::vector<uint32_t>>& counter,
|
||||
std::vector<std::vector<uint32_t>>& key,
|
||||
int rounds = 10) {
|
||||
uint32_t N = counter[0].size();
|
||||
for (int i = 0; i < rounds - 1; ++i) {
|
||||
philox4_round(counter, key);
|
||||
|
||||
for (uint32_t j = 0; j < N; ++j) {
|
||||
key[0][j] += philox_w[0];
|
||||
key[1][j] += philox_w[1];
|
||||
}
|
||||
}
|
||||
|
||||
philox4_round(counter, key);
|
||||
return counter;
|
||||
}
|
||||
|
||||
float box_muller(float x, float y) {
|
||||
float u = x * two_pow32_inv + two_pow32_inv / 2;
|
||||
float v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2;
|
||||
|
||||
float s = sqrt(-2.0 * log(u));
|
||||
|
||||
float r1 = s * sin(v);
|
||||
return r1;
|
||||
}
|
||||
|
||||
public:
|
||||
PhiloxRNG(uint64_t seed = 0) {
|
||||
this->seed = seed;
|
||||
this->offset = 0;
|
||||
}
|
||||
|
||||
void manual_seed(uint32_t seed) {
|
||||
this->seed = seed;
|
||||
this->offset = 0;
|
||||
}
|
||||
|
||||
std::vector<float> randn(uint32_t n) {
|
||||
std::vector<std::vector<uint32_t>> counter(4, std::vector<uint32_t>(n, 0));
|
||||
for (uint32_t i = 0; i < n; i++) {
|
||||
counter[0][i] = this->offset;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < n; i++) {
|
||||
counter[2][i] = i;
|
||||
}
|
||||
this->offset += 1;
|
||||
|
||||
std::vector<uint64_t> key(n, this->seed);
|
||||
std::vector<std::vector<uint32_t>> key_uint32 = uint32(key);
|
||||
|
||||
std::vector<std::vector<uint32_t>> g = philox4_32(counter, key_uint32);
|
||||
|
||||
std::vector<float> result;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
result.push_back(box_muller(g[0][i], g[1][i]));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
#endif // __RNG_PHILOX_H__
|
@ -15,6 +15,8 @@
|
||||
|
||||
#include "ggml/ggml.h"
|
||||
#include "stable-diffusion.h"
|
||||
#include "rng.h"
|
||||
#include "rng_philox.h"
|
||||
|
||||
static SDLogLevel log_level = SDLogLevel::INFO;
|
||||
|
||||
@ -117,19 +119,11 @@ ggml_tensor* load_tensor_from_file(ggml_context* ctx, const std::string& file_pa
|
||||
return tensor;
|
||||
}
|
||||
|
||||
static std::default_random_engine generator;
|
||||
|
||||
void set_random_seed(int seed) {
|
||||
generator.seed(seed);
|
||||
}
|
||||
|
||||
void ggml_tensor_set_f32_randn(struct ggml_tensor* tensor) {
|
||||
float mean = 0.0;
|
||||
float stddev = 1.0;
|
||||
std::normal_distribution<float> distribution(mean, stddev);
|
||||
for (int i = 0; i < ggml_nelements(tensor); i++) {
|
||||
float random_number = distribution(generator);
|
||||
ggml_set_f32_1d(tensor, i, random_number);
|
||||
void ggml_tensor_set_f32_randn(struct ggml_tensor* tensor, std::shared_ptr<RNG> rng) {
|
||||
uint32_t n = ggml_nelements(tensor);
|
||||
std::vector<float> random_numbers = rng->randn(n);
|
||||
for (int i = 0; i < n; i++) {
|
||||
ggml_set_f32_1d(tensor, i, random_numbers[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -2747,6 +2741,8 @@ class StableDiffusionGGML {
|
||||
bool dynamic = true;
|
||||
bool vae_decode_only = false;
|
||||
bool free_params_immediately = false;
|
||||
|
||||
std::shared_ptr<RNG> rng = std::make_shared<STDDefaultRNG>();
|
||||
int32_t ftype = 1;
|
||||
int n_threads = -1;
|
||||
float scale_factor = 0.18215f;
|
||||
@ -2765,11 +2761,17 @@ class StableDiffusionGGML {
|
||||
|
||||
StableDiffusionGGML(int n_threads,
|
||||
bool vae_decode_only,
|
||||
bool free_params_immediately)
|
||||
bool free_params_immediately,
|
||||
RNGType rng_type)
|
||||
: n_threads(n_threads),
|
||||
vae_decode_only(vae_decode_only),
|
||||
free_params_immediately(free_params_immediately) {
|
||||
first_stage_model.decode_only = vae_decode_only;
|
||||
if (rng_type == STD_DEFAULT_RNG) {
|
||||
rng = std::make_shared<STDDefaultRNG>();
|
||||
} else if (rng_type == CUDA_RNG) {
|
||||
rng = std::make_shared<PhiloxRNG>();
|
||||
}
|
||||
}
|
||||
|
||||
~StableDiffusionGGML() {
|
||||
@ -3539,7 +3541,7 @@ class StableDiffusionGGML {
|
||||
|
||||
if (sigmas[i + 1] > 0) {
|
||||
// x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
ggml_tensor_set_f32_randn(noise);
|
||||
ggml_tensor_set_f32_randn(noise, rng);
|
||||
// noise = load_tensor_from_file(res_ctx, "./rand" + std::to_string(i+1) + ".bin");
|
||||
{
|
||||
float* vec_x = (float*)x->data;
|
||||
@ -3674,7 +3676,7 @@ class StableDiffusionGGML {
|
||||
ggml_tensor* latent = ggml_new_tensor_4d(res_ctx, moments->type, moments->ne[0],
|
||||
moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
|
||||
struct ggml_tensor* noise = ggml_dup_tensor(res_ctx, latent);
|
||||
ggml_tensor_set_f32_randn(noise);
|
||||
ggml_tensor_set_f32_randn(noise, rng);
|
||||
// noise = load_tensor_from_file(res_ctx, "noise.bin");
|
||||
{
|
||||
float mean = 0;
|
||||
@ -3802,10 +3804,12 @@ class StableDiffusionGGML {
|
||||
|
||||
StableDiffusion::StableDiffusion(int n_threads,
|
||||
bool vae_decode_only,
|
||||
bool free_params_immediately) {
|
||||
bool free_params_immediately,
|
||||
RNGType rng_type) {
|
||||
sd = std::make_shared<StableDiffusionGGML>(n_threads,
|
||||
vae_decode_only,
|
||||
free_params_immediately);
|
||||
free_params_immediately,
|
||||
rng_type);
|
||||
}
|
||||
|
||||
bool StableDiffusion::load_from_file(const std::string& file_path) {
|
||||
@ -3835,7 +3839,7 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
|
||||
if (seed < 0) {
|
||||
seed = (int)time(NULL);
|
||||
}
|
||||
set_random_seed(seed);
|
||||
sd->rng->manual_seed(seed);
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
ggml_tensor* c = sd->get_learned_condition(ctx, prompt);
|
||||
@ -3856,7 +3860,7 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
|
||||
int W = width / 8;
|
||||
int H = height / 8;
|
||||
struct ggml_tensor* x_t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, W, H, C, 1);
|
||||
ggml_tensor_set_f32_randn(x_t);
|
||||
ggml_tensor_set_f32_randn(x_t, sd->rng);
|
||||
|
||||
std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
|
||||
|
||||
@ -3935,7 +3939,7 @@ std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_i
|
||||
if (seed < 0) {
|
||||
seed = (int)time(NULL);
|
||||
}
|
||||
set_random_seed(seed);
|
||||
sd->rng->manual_seed(seed);
|
||||
|
||||
ggml_tensor* init_img = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, width, height, 3, 1);
|
||||
image_vec_to_ggml(init_img_vec, init_img);
|
||||
|
@ -4,13 +4,18 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
enum class SDLogLevel {
|
||||
enum SDLogLevel {
|
||||
DEBUG,
|
||||
INFO,
|
||||
WARN,
|
||||
ERROR
|
||||
};
|
||||
|
||||
enum RNGType {
|
||||
STD_DEFAULT_RNG,
|
||||
CUDA_RNG
|
||||
};
|
||||
|
||||
enum SampleMethod {
|
||||
EULAR_A,
|
||||
};
|
||||
@ -24,7 +29,8 @@ class StableDiffusion {
|
||||
public:
|
||||
StableDiffusion(int n_threads = -1,
|
||||
bool vae_decode_only = false,
|
||||
bool free_params_immediately = false);
|
||||
bool free_params_immediately = false,
|
||||
RNGType rng_type = STD_DEFAULT_RNG);
|
||||
bool load_from_file(const std::string& file_path);
|
||||
std::vector<uint8_t> txt2img(
|
||||
const std::string& prompt,
|
||||
|
Loading…
Reference in New Issue
Block a user