stable-diffusion.cpp/denoiser.hpp

1385 lines
57 KiB
C++

#ifndef __DENOISER_HPP__
#define __DENOISER_HPP__
#include "ggml_extend.hpp"
#include "gits_noise.inl"
/*================================================= CompVisDenoiser ==================================================*/
// Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py
#define TIMESTEPS 1000
#define FLUX_TIMESTEPS 1000
// Abstract strategy that produces the list of noise levels (sigmas) a sampler
// steps through. The schedules in this file typically return the requested
// noise levels followed by a terminating 0.
struct SigmaSchedule {
    // Model version tag; read by schedules that select version-specific tables.
    int version = 0;
    // Callback that maps a (possibly fractional) timestep to its sigma.
    typedef std::function<float(float)> t_to_sigma_t;

    // Schedules are owned and used through base pointers, so destruction
    // through the base type must be well defined.
    virtual ~SigmaSchedule() = default;

    // n          - number of sampling steps requested
    // sigma_min  - smallest noise level of the model
    // sigma_max  - largest noise level of the model
    // t_to_sigma - timestep -> sigma conversion supplied by the denoiser
    virtual std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) = 0;
};
// Picks n timesteps evenly spaced from TIMESTEPS - 1 down to 0, converts each
// one to a sigma through the supplied callback and appends a final 0.
struct DiscreteSchedule : SigmaSchedule {
    // Returns an empty vector for n == 0 and {t_to_sigma(TIMESTEPS - 1), 0}
    // for n == 1; sigma_min and sigma_max are not used by this schedule.
    std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
        std::vector<float> sigmas;
        if (n == 0) {
            return sigmas;
        }

        const int last_timestep = TIMESTEPS - 1;
        if (n == 1) {
            sigmas.push_back(t_to_sigma((float)last_timestep));
            sigmas.push_back(0);
            return sigmas;
        }

        // Distance between two consecutive timesteps of the schedule.
        const float stride = static_cast<float>(last_timestep) / static_cast<float>(n - 1);
        uint32_t k         = 0;
        while (k < n) {
            float timestep = last_timestep - stride * k;
            sigmas.push_back(t_to_sigma(timestep));
            ++k;
        }
        sigmas.push_back(0);
        return sigmas;
    }
};
// Noise levels spaced evenly in log-space between sigma_max and sigma_min,
// followed by a terminating 0.
struct ExponentialSchedule : SigmaSchedule {
    // Returns n sigmas (descending when sigma_max > sigma_min) plus a final 0.
    // n == 0 yields {0}; n == 1 yields {sigma_max, 0}. t_to_sigma is unused.
    std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
        std::vector<float> sigmas;

        const float log_sigma_min = std::log(sigma_min);
        const float log_sigma_max = std::log(sigma_max);

        // With a single step there is no interval to subdivide; dividing by
        // (n - 1) == 0 would turn the only sigma into NaN, so use a zero step
        // and emit sigma_max alone.
        const float step = (n > 1) ? (log_sigma_max - log_sigma_min) / (n - 1) : 0.0f;

        // Fill sigmas with exponential values
        for (uint32_t i = 0; i < n; ++i) {
            float sigma = std::exp(log_sigma_max - step * i);
            sigmas.push_back(sigma);
        }
        sigmas.push_back(0.0f);
        return sigmas;
    }
};
/* interp and linear_interp adapted from dpilger26's NumCpp library:
* https://github.com/dpilger26/NumCpp/tree/5e40aab74d14e257d65d3dc385c9ff9e2120c60e */
/* Linear blend between two values: perc == 0 yields `left`, perc == 1 yields
 * `right`, values in between mix the two proportionally. */
constexpr double interp(double left, double right, double perc) noexcept {
    return ((1. - perc) * left) + (perc * right);
}
/* This will make the assumption that the reference x and y values are
* already sorted in ascending order because they are being generated as
* such in the calling function */
/* Piecewise-linear interpolation of the curve (ref_x, ref_y) at the positions
 * in new_x.
 *
 * new_x  - query positions, expected in ascending order; the first and last
 *          entries are clamped into the range of ref_x
 * ref_x  - reference positions, expected in ascending order
 * ref_y  - reference values, same length as ref_x
 *
 * Returns one value per entry of new_x. On a length mismatch between ref_x and
 * ref_y an error is logged and a zero-filled vector is returned; a zero-filled
 * vector is also returned when there are no reference points. */
std::vector<double> linear_interp(std::vector<float> new_x,
                                  const std::vector<float> ref_x,
                                  const std::vector<float> ref_y) {
    const size_t len_x = new_x.size();
    std::vector<double> new_y(len_x);

    if (ref_x.size() != ref_y.size()) {
        LOG_ERROR("Linear Interpolation Failed: length mismatch");
        return new_y;
    }
    /* Nothing to evaluate, or nothing to evaluate against */
    if (len_x == 0 || ref_x.empty()) {
        return new_y;
    }
    /* A single reference point defines a constant curve */
    if (ref_x.size() == 1) {
        for (size_t k = 0; k < len_x; k++) {
            new_y[k] = ref_y[0];
        }
        return new_y;
    }

    /* Adjusted bounds checking to ensure new_x is within ref_x range */
    if (new_x[0] < ref_x[0]) {
        new_x[0] = ref_x[0];
    }
    if (new_x.back() > ref_x.back()) {
        new_x.back() = ref_x.back();
    }

    /* Index of the last valid segment [ref_x[j], ref_x[j + 1]] */
    const size_t last_segment = ref_x.size() - 2;
    size_t i                  = 0;
    size_t j                  = 0;
    while (i < len_x) {
        const bool outside = (ref_x[j] > new_x[i]) || (new_x[i] > ref_x[j + 1]);
        /* Advance to the next segment, but never past the last one so that
         * ref_x[j + 1] always stays in bounds */
        if (outside && j < last_segment) {
            j++;
            continue;
        }
        const double width = static_cast<double>(ref_x[j + 1] - ref_x[j]);
        double perc        = 0.0;
        if (width != 0.0) {
            perc = static_cast<double>(new_x[i] - ref_x[j]) / width;
        }
        /* Only reachable for queries that no segment contains (unsorted or
         * out-of-range input): clamp to the segment instead of extrapolating */
        if (outside) {
            if (perc < 0.0) {
                perc = 0.0;
            } else if (perc > 1.0) {
                perc = 1.0;
            }
        }
        new_y[i] = interp(ref_y[j], ref_y[j + 1], perc);
        i++;
    }
    return new_y;
}
/* Returns num_points values starting at `start` and advancing by a constant
 * increment chosen so that the last value lands on `end` (up to floating-point
 * rounding of the repeated addition).
 *
 * num_points == 0 gives an empty vector and num_points == 1 gives {start}. */
std::vector<float> linear_space(const float start, const float end, const size_t num_points) {
    std::vector<float> result(num_points);
    if (num_points == 0) {
        return result;
    }
    result[0] = start;
    if (num_points == 1) {
        /* No interval to subdivide: skip the increment, which would otherwise
         * be computed as a division by zero */
        return result;
    }
    const float inc = (end - start) / (static_cast<float>(num_points - 1));
    for (size_t i = 1; i < num_points; i++) {
        result[i] = result[i - 1] + inc;
    }
    return result;
}
/* Resamples a list of noise levels to new_len entries by interpolating
 * linearly in log-space.
 *
 * sigma_in - reference noise levels; the code reverses them before
 *            interpolating and reverses the result back, so the output keeps
 *            the input's ordering. Values are passed through std::log, so
 *            they are presumably expected to be strictly positive.
 * new_len  - number of noise levels to produce
 *
 * Returns new_len values. If sigma_in is empty there is nothing to
 * interpolate and a zero-filled vector of length new_len is returned. */
std::vector<float> log_linear_interpolation(std::vector<float> sigma_in,
                                            const size_t new_len) {
    const size_t s_len = sigma_in.size();
    std::vector<float> results(new_len);
    /* Guard the degenerate cases here so the helpers below are never asked to
     * read the first element of an empty vector */
    if (s_len == 0 || new_len == 0) {
        return results;
    }

    std::vector<float> x_vals = linear_space(0.f, 1.f, s_len);
    std::vector<float> y_vals(s_len);

    /* Reverses the input array to be ascending instead of descending,
     * also hits it with a log, it is log-linear interpolation after all */
    for (size_t i = 0; i < s_len; i++) {
        y_vals[i] = std::log(sigma_in[s_len - i - 1]);
    }

    std::vector<float> new_x_vals  = linear_space(0.f, 1.f, new_len);
    std::vector<double> new_y_vals = linear_interp(new_x_vals, x_vals, y_vals);

    /* Undo the log and the reversal */
    for (size_t i = 0; i < new_len; i++) {
        results[i] = static_cast<float>(std::exp(new_y_vals[new_len - i - 1]));
    }
    return results;
}
/*
https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
*/
// "Align Your Steps" schedule: starts from a pre-computed 11-entry table of
// noise levels chosen by model version and stretches it to the requested
// number of steps with log-linear interpolation.
struct AYSSchedule : SigmaSchedule {
    // Returns n + 1 values whose last entry is forced to 0. For a version
    // without a table an error is logged and n + 1 zeros are returned.
    // sigma_min, sigma_max and t_to_sigma are not used by this schedule.
    std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
        const std::vector<float> noise_levels[] = {
            /* SD1.5 */
            {14.6146412293f, 6.4745760956f, 3.8636745985f, 2.6946151520f,
             1.8841921177f, 1.3943805092f, 0.9642583904f, 0.6523686016f,
             0.3977456272f, 0.1515232662f, 0.0291671582f},
            /* SDXL */
            {14.6146412293f, 6.3184485287f, 3.7681790315f, 2.1811480769f,
             1.3405244945f, 0.8620721141f, 0.5550693289f, 0.3798540708f,
             0.2332364134f, 0.1114188177f, 0.0291671582f},
            /* SVD */
            {700.00f, 54.5f, 15.886f, 7.977f, 4.248f, 1.789f, 0.981f, 0.403f,
             0.173f, 0.034f, 0.002f},
        };

        std::vector<float> results(n + 1);

        /* Pick the reference table that matches the model version */
        const std::vector<float>* reference = nullptr;
        switch (version) {
            case VERSION_SD2:
                LOG_WARN("AYS not designed for SD2.X models");
                /* fallthrough */
            case VERSION_SD1:
                LOG_INFO("AYS using SD1.5 noise levels");
                reference = &noise_levels[0];
                break;
            case VERSION_SDXL:
                LOG_INFO("AYS using SDXL noise levels");
                reference = &noise_levels[1];
                break;
            case VERSION_SVD:
                LOG_INFO("AYS using SVD noise levels");
                reference = &noise_levels[2];
                break;
            default:
                LOG_ERROR("Version not compatable with AYS scheduler");
                return results;
        }

        /* Use the table verbatim when it already has the right length,
         * otherwise stretch it with log-linear interpolation */
        if ((n + 1) == reference->size()) {
            results = *reference;
        } else {
            results = log_linear_interpolation(*reference, n + 1);
        }

        /* Not sure if this is strictly necessary */
        results[n] = 0.0f;
        return results;
    }
};
/*
* GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main
*/
// GITS schedule: looks up pre-computed noise levels in the GITS_NOISE tables
// (one table per coefficient, one row per step count starting at 2 steps) and
// falls back to log-linear interpolation of the last row when no usable row
// exists for the requested step count.
struct GITSSchedule : SigmaSchedule {
    // Returns n + 1 values whose last entry is forced to 0.
    // An empty vector is returned when sigma_max <= 0, when n == 0, or when the
    // selected table has no rows. n == 1 has no table row, so {sigma_max, 0}
    // is returned. sigma_min and t_to_sigma are not used by this schedule.
    std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
        if (sigma_max <= 0.0f) {
            return std::vector<float>{};
        }
        // The tables start at 2 steps; without these guards `n - 2` would wrap
        // around and index far out of bounds.
        if (n == 0) {
            return std::vector<float>{};
        }
        if (n == 1) {
            return std::vector<float>{sigma_max, 0.0f};
        }

        // Assume coeff is provided (replace 1.20 with your dynamic coeff)
        float coeff = 1.20f;  // Default coefficient

        // Normalize coeff to the closest value in the array (0.80 to 1.50)
        coeff = std::round(coeff * 20.0f) / 20.0f;  // Round to the nearest 0.05

        // Calculate the index based on the coefficient. Round instead of
        // truncating: float error can leave the quotient just below the
        // intended integer (e.g. 3.9999998 for coeff = 1.00).
        int index = static_cast<int>(std::round((coeff - 0.80f) / 0.05f));

        // Ensure the index is within bounds
        index = std::max(0, std::min(index, static_cast<int>(GITS_NOISE.size() - 1)));
        const std::vector<std::vector<float>>& selected_noise = *GITS_NOISE[index];
        if (selected_noise.empty()) {
            return std::vector<float>{};
        }

        const size_t wanted = static_cast<size_t>(n) + 1;
        const size_t row    = static_cast<size_t>(n) - 2;

        std::vector<float> sigmas;
        // Use the table row only if it exists and has exactly n + 1 entries,
        // so that writing sigmas[n] below is always in bounds.
        if (n <= 20 && row < selected_noise.size() && selected_noise[row].size() == wanted) {
            sigmas = selected_noise[row];
        } else {
            sigmas = log_linear_interpolation(selected_noise.back(), n + 1);
        }

        sigmas[n] = 0.0f;
        return sigmas;
    }
};
// Noise schedule of Karras et al. (2022), Eq. (5): interpolates linearly
// between sigma_max^(1/rho) and sigma_min^(1/rho) and raises the result to the
// power rho.
struct KarrasSchedule : SigmaSchedule {
    // Returns n sigmas followed by a terminating 0. n == 1 yields the sigma at
    // the start of the ramp (sigma_max up to rounding) followed by 0.
    // t_to_sigma is not used by this schedule.
    std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
        // These *COULD* be function arguments here,
        // but does anybody ever bother to touch them?
        float rho = 7.f;

        std::vector<float> result(n + 1);

        float min_inv_rho = pow(sigma_min, (1.f / rho));
        float max_inv_rho = pow(sigma_max, (1.f / rho));
        for (uint32_t i = 0; i < n; i++) {
            // Position along the schedule in [0, 1]. A single step has no
            // interval to subdivide; without this guard the ramp is 0 / 0 and
            // the only sigma becomes NaN.
            float ramp = (n > 1) ? (float)i / ((float)n - 1.f) : 0.f;
            // Eq. (5) from Karras et al 2022
            result[i] = pow(max_inv_rho + ramp * (min_inv_rho - max_inv_rho), rho);
        }
        result[n] = 0.;
        return result;
    }
};
// Interface shared by the model parameterizations in this file: conversion
// between sigmas and timesteps, the scaling factors applied around the model
// call, and how noise is mixed into a latent.
struct Denoiser {
    // Strategy used by get_sigmas(); defaults to the discrete schedule.
    std::shared_ptr<SigmaSchedule> schedule = std::make_shared<DiscreteSchedule>();

    // Denoisers are used through base pointers, so destruction through the
    // base type must be well defined.
    virtual ~Denoiser() = default;

    virtual float sigma_min()               = 0;
    virtual float sigma_max()               = 0;
    virtual float sigma_to_t(float sigma)   = 0;
    virtual float t_to_sigma(float t)       = 0;
    // The implementations in this file return {c_skip, c_out, c_in}.
    virtual std::vector<float> get_scalings(float sigma) = 0;
    // The implementations in this file modify `noise` and `latent` in place
    // and return `latent`.
    virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0;
    virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent)              = 0;

    // Asks the configured schedule for the sigmas of an n-step run, handing it
    // this denoiser's sigma range and timestep -> sigma conversion.
    virtual std::vector<float> get_sigmas(uint32_t n) {
        auto bound_t_to_sigma = [this](float t) -> float { return this->t_to_sigma(t); };
        return schedule->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
    }
};
// Denoiser backed by a table of TIMESTEPS discrete noise levels. Conversions
// between sigma and timestep interpolate linearly in log-sigma between
// neighbouring table entries.
struct CompVisDenoiser : public Denoiser {
    // Per-timestep noise levels and their logarithms. They are not initialized
    // in this struct and must be filled in before any method is called.
    float sigmas[TIMESTEPS];
    float log_sigmas[TIMESTEPS];

    float sigma_data = 1.0f;

    float sigma_min() {
        return sigmas[0];
    }

    float sigma_max() {
        return sigmas[TIMESTEPS - 1];
    }

    // Maps a sigma to a fractional timestep in [0, TIMESTEPS - 1].
    float sigma_to_t(float sigma) {
        float log_sigma = std::log(sigma);

        // Count the table entries whose log-sigma does not exceed log_sigma;
        // done on the fly so no temporary buffer is allocated per call.
        int low_idx = 0;
        for (float log_sigma_val : log_sigmas) {
            if (log_sigma - log_sigma_val >= 0) {
                low_idx++;
            }
        }

        // Lower neighbour, clamped so that high_idx stays inside the table.
        low_idx      = std::min(std::max(low_idx - 1, 0), TIMESTEPS - 2);
        int high_idx = low_idx + 1;

        float low  = log_sigmas[low_idx];
        float high = log_sigmas[high_idx];
        float w    = (low - log_sigma) / (low - high);
        w          = std::max(0.f, std::min(1.f, w));
        float t    = (1.0f - w) * low_idx + w * high_idx;

        return t;
    }

    // Maps a (fractional) timestep to a sigma by interpolating log_sigmas.
    float t_to_sigma(float t) {
        // Keep both neighbour indices inside the table. Callers compute t with
        // float arithmetic, which can land marginally outside
        // [0, TIMESTEPS - 1] (e.g. a tiny negative value for the last step).
        t = std::max(0.0f, std::min(t, static_cast<float>(TIMESTEPS - 1)));

        int low_idx     = static_cast<int>(std::floor(t));
        int high_idx    = static_cast<int>(std::ceil(t));
        float w         = t - static_cast<float>(low_idx);
        float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
        return std::exp(log_sigma);
    }

    // Returns {c_skip, c_out, c_in} for the given sigma.
    std::vector<float> get_scalings(float sigma) {
        float c_skip = 1.0f;
        float c_out  = -sigma;
        float c_in   = 1.0f / std::sqrt(sigma * sigma + sigma_data * sigma_data);
        return {c_skip, c_out, c_in};
    }

    // this function will modify noise/latent
    // latent <- latent + sigma * noise
    ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) {
        ggml_tensor_scale(noise, sigma);
        ggml_tensor_add(latent, noise);
        return latent;
    }

    // Nothing to undo for this parameterization: returns latent unchanged.
    ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) {
        return latent;
    }
};
// Variant of CompVisDenoiser that only replaces the scaling factors; all
// three factors depend on sigma^2 + sigma_data^2.
struct CompVisVDenoiser : public CompVisDenoiser {
    // Returns {c_skip, c_out, c_in} for the given sigma.
    std::vector<float> get_scalings(float sigma) {
        const float total_var = sigma * sigma + sigma_data * sigma_data;
        const float total_std = std::sqrt(total_var);

        const float skip_scale = sigma_data * sigma_data / total_var;
        const float out_scale  = -sigma * sigma_data / total_std;
        const float in_scale   = 1.0f / total_std;
        return {skip_scale, out_scale, in_scale};
    }
};
// Applies the shift alpha * t / (1 + (alpha - 1) * t) to t.
// alpha == 1 is the identity and returns t untouched.
float time_snr_shift(float alpha, float t) {
    if (alpha != 1.0f) {
        const float scaled = alpha * t;
        const float denom  = 1 + (alpha - 1) * t;
        return scaled / denom;
    }
    return t;
}
// Flow-style denoiser whose noise levels come from time_snr_shift applied to
// the normalized timestep.
struct DiscreteFlowDenoiser : public Denoiser {
    // Table of noise levels, filled by set_parameters().
    float sigmas[TIMESTEPS];
    float shift      = 3.0f;
    float sigma_data = 1.0f;

    DiscreteFlowDenoiser() {
        set_parameters();
    }

    // Rebuilds the sigma table from the current `shift`.
    // NOTE(review): entry k stores t_to_sigma(k + 1), and t_to_sigma adds 1
    // again, so the last entry corresponds to (TIMESTEPS + 1) / 1000 and lies
    // slightly above 1 -- confirm this offset is intended.
    void set_parameters() {
        for (int idx = 0; idx < TIMESTEPS; idx++) {
            sigmas[idx] = t_to_sigma(idx + 1);
        }
    }

    float sigma_min() {
        return sigmas[0];
    }

    float sigma_max() {
        return sigmas[TIMESTEPS - 1];
    }

    float sigma_to_t(float sigma) {
        return sigma * 1000.f;
    }

    // Timestep t is shifted by one before being normalized by 1000.
    float t_to_sigma(float t) {
        const float shifted_t = t + 1;
        return time_snr_shift(shift, shifted_t / 1000.f);
    }

    // Returns {c_skip, c_out, c_in} for the given sigma.
    std::vector<float> get_scalings(float sigma) {
        const float skip_scale = 1.0f;
        const float out_scale  = -sigma;
        const float in_scale   = 1.0f;
        return {skip_scale, out_scale, in_scale};
    }

    // this function will modify noise/latent
    // latent <- (1 - sigma) * latent + sigma * noise
    ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) {
        ggml_tensor_scale(noise, sigma);
        ggml_tensor_scale(latent, 1.0f - sigma);
        ggml_tensor_add(latent, noise);
        return latent;
    }

    // latent <- latent / (1 - sigma), modified in place.
    // NOTE(review): sigma == 1 divides by zero here -- confirm callers never
    // pass it.
    ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) {
        const float inverse_keep = 1.0f / (1.0f - sigma);
        ggml_tensor_scale(latent, inverse_keep);
        return latent;
    }
};
// Computes exp(mu) / (exp(mu) + (1 / t - 1)^sigma).
float flux_time_shift(float mu, float sigma, float t) {
    const float exp_mu = std::exp(mu);
    const double tail  = std::pow((1.0 / t - 1.0), sigma);
    return exp_mu / (exp_mu + tail);
}
// Flow-style denoiser whose noise levels come from flux_time_shift applied to
// the normalized timestep.
struct FluxFlowDenoiser : public Denoiser {
    // Table of noise levels, filled by set_parameters().
    float sigmas[TIMESTEPS];
    float shift      = 1.15f;
    float sigma_data = 1.0f;

    FluxFlowDenoiser(float shift = 1.15f) {
        set_parameters(shift);
    }

    // Stores the new shift and rebuilds the sigma table from it.
    void set_parameters(float shift = 1.15f) {
        this->shift = shift;
        for (int i = 1; i < TIMESTEPS + 1; i++) {
            // The original `i / TIMESTEPS * TIMESTEPS` was evaluated in integer
            // arithmetic and collapsed to 0 for every i < TIMESTEPS, leaving
            // 999 identical table entries. Pass the timestep itself as a
            // float, matching how DiscreteFlowDenoiser fills its table.
            // NOTE(review): t_to_sigma adds 1 again, so the last entry
            // corresponds to (TIMESTEPS + 1) / TIMESTEPS -- confirm this
            // offset is intended.
            sigmas[i - 1] = t_to_sigma(static_cast<float>(i));
        }
    }

    float sigma_min() {
        return sigmas[0];
    }

    float sigma_max() {
        return sigmas[TIMESTEPS - 1];
    }

    // In this parameterization the timestep handed to the model is sigma.
    float sigma_to_t(float sigma) {
        return sigma;
    }

    // Timestep t is shifted by one before being normalized by TIMESTEPS.
    float t_to_sigma(float t) {
        t = t + 1;
        return flux_time_shift(shift, 1.0f, t / TIMESTEPS);
    }

    // Returns {c_skip, c_out, c_in} for the given sigma.
    std::vector<float> get_scalings(float sigma) {
        float c_skip = 1.0f;
        float c_out  = -sigma;
        float c_in   = 1.0f;
        return {c_skip, c_out, c_in};
    }

    // this function will modify noise/latent
    // latent <- (1 - sigma) * latent + sigma * noise
    ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) {
        ggml_tensor_scale(noise, sigma);
        ggml_tensor_scale(latent, 1.0f - sigma);
        ggml_tensor_add(latent, noise);
        return latent;
    }

    // latent <- latent / (1 - sigma), modified in place.
    // NOTE(review): sigma == 1 divides by zero here -- confirm callers never
    // pass it.
    ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) {
        ggml_tensor_scale(latent, 1.0f / (1.0f - sigma));
        return latent;
    }
};
// Callback that runs the model for one evaluation. The samplers below call it
// as model(x, sigma, step): the current latent, the noise level to denoise at,
// and a step number (1-based; some samplers pass a negated value for an
// intermediate evaluation). It returns the denoised tensor.
typedef std::function<ggml_tensor*(ggml_tensor*, float, int)> denoise_cb_t;
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
static void sample_k_diffusion(sample_method_t method,
denoise_cb_t model,
ggml_context* work_ctx,
ggml_tensor* x,
std::vector<float> sigmas,
std::shared_ptr<RNG> rng,
float eta) {
size_t steps = sigmas.size() - 1;
// sample_euler_ancestral
switch (method) {
case EULER_A: {
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
// denoise
ggml_tensor* denoised = model(x, sigma, i + 1);
// d = (x - denoised) / sigma
{
float* vec_d = (float*)d->data;
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
for (int i = 0; i < ggml_nelements(d); i++) {
vec_d[i] = (vec_x[i] - vec_denoised[i]) / sigma;
}
}
// get_ancestral_step
float sigma_up = std::min(sigmas[i + 1],
std::sqrt(sigmas[i + 1] * sigmas[i + 1] * (sigmas[i] * sigmas[i] - sigmas[i + 1] * sigmas[i + 1]) / (sigmas[i] * sigmas[i])));
float sigma_down = std::sqrt(sigmas[i + 1] * sigmas[i + 1] - sigma_up * sigma_up);
// Euler method
float dt = sigma_down - sigmas[i];
// x = x + d * dt
{
float* vec_d = (float*)d->data;
float* vec_x = (float*)x->data;
for (int i = 0; i < ggml_nelements(x); i++) {
vec_x[i] = vec_x[i] + vec_d[i] * dt;
}
}
if (sigmas[i + 1] > 0) {
// x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
ggml_tensor_set_f32_randn(noise, rng);
// noise = load_tensor_from_file(work_ctx, "./rand" + std::to_string(i+1) + ".bin");
{
float* vec_x = (float*)x->data;
float* vec_noise = (float*)noise->data;
for (int i = 0; i < ggml_nelements(x); i++) {
vec_x[i] = vec_x[i] + vec_noise[i] * sigma_up;
}
}
}
}
} break;
case EULER: // Implemented without any sigma churn
{
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
// denoise
ggml_tensor* denoised = model(x, sigma, i + 1);
// d = (x - denoised) / sigma
{
float* vec_d = (float*)d->data;
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(d); j++) {
vec_d[j] = (vec_x[j] - vec_denoised[j]) / sigma;
}
}
float dt = sigmas[i + 1] - sigma;
// x = x + d * dt
{
float* vec_d = (float*)d->data;
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_x[j] + vec_d[j] * dt;
}
}
}
} break;
case HEUN: {
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
for (int i = 0; i < steps; i++) {
// denoise
ggml_tensor* denoised = model(x, sigmas[i], -(i + 1));
// d = (x - denoised) / sigma
{
float* vec_d = (float*)d->data;
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_d[j] = (vec_x[j] - vec_denoised[j]) / sigmas[i];
}
}
float dt = sigmas[i + 1] - sigmas[i];
if (sigmas[i + 1] == 0) {
// Euler step
// x = x + d * dt
float* vec_d = (float*)d->data;
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_x[j] + vec_d[j] * dt;
}
} else {
// Heun step
float* vec_d = (float*)d->data;
float* vec_d2 = (float*)d->data;
float* vec_x = (float*)x->data;
float* vec_x2 = (float*)x2->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x2[j] = vec_x[j] + vec_d[j] * dt;
}
ggml_tensor* denoised = model(x2, sigmas[i + 1], i + 1);
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
float d2 = (vec_x2[j] - vec_denoised[j]) / sigmas[i + 1];
vec_d[j] = (vec_d[j] + d2) / 2;
vec_x[j] = vec_x[j] + vec_d[j] * dt;
}
}
}
} break;
case DPM2: {
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
for (int i = 0; i < steps; i++) {
// denoise
ggml_tensor* denoised = model(x, sigmas[i], i + 1);
// d = (x - denoised) / sigma
{
float* vec_d = (float*)d->data;
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_d[j] = (vec_x[j] - vec_denoised[j]) / sigmas[i];
}
}
if (sigmas[i + 1] == 0) {
// Euler step
// x = x + d * dt
float dt = sigmas[i + 1] - sigmas[i];
float* vec_d = (float*)d->data;
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_x[j] + vec_d[j] * dt;
}
} else {
// DPM-Solver-2
float sigma_mid = exp(0.5f * (log(sigmas[i]) + log(sigmas[i + 1])));
float dt_1 = sigma_mid - sigmas[i];
float dt_2 = sigmas[i + 1] - sigmas[i];
float* vec_d = (float*)d->data;
float* vec_x = (float*)x->data;
float* vec_x2 = (float*)x2->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x2[j] = vec_x[j] + vec_d[j] * dt_1;
}
ggml_tensor* denoised = model(x2, sigma_mid, i + 1);
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
float d2 = (vec_x2[j] - vec_denoised[j]) / sigma_mid;
vec_x[j] = vec_x[j] + d2 * dt_2;
}
}
}
} break;
case DPMPP2S_A: {
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
for (int i = 0; i < steps; i++) {
// denoise
ggml_tensor* denoised = model(x, sigmas[i], i + 1);
// get_ancestral_step
float sigma_up = std::min(sigmas[i + 1],
std::sqrt(sigmas[i + 1] * sigmas[i + 1] * (sigmas[i] * sigmas[i] - sigmas[i + 1] * sigmas[i + 1]) / (sigmas[i] * sigmas[i])));
float sigma_down = std::sqrt(sigmas[i + 1] * sigmas[i + 1] - sigma_up * sigma_up);
auto t_fn = [](float sigma) -> float { return -log(sigma); };
auto sigma_fn = [](float t) -> float { return exp(-t); };
if (sigma_down == 0) {
// Euler step
float* vec_d = (float*)d->data;
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(d); j++) {
vec_d[j] = (vec_x[j] - vec_denoised[j]) / sigmas[i];
}
// TODO: If sigma_down == 0, isn't this wrong?
// But
// https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L525
// has this exactly the same way.
float dt = sigma_down - sigmas[i];
for (int j = 0; j < ggml_nelements(d); j++) {
vec_x[j] = vec_x[j] + vec_d[j] * dt;
}
} else {
// DPM-Solver++(2S)
float t = t_fn(sigmas[i]);
float t_next = t_fn(sigma_down);
float h = t_next - t;
float s = t + 0.5f * h;
float* vec_d = (float*)d->data;
float* vec_x = (float*)x->data;
float* vec_x2 = (float*)x2->data;
float* vec_denoised = (float*)denoised->data;
// First half-step
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x2[j] = (sigma_fn(s) / sigma_fn(t)) * vec_x[j] - (exp(-h * 0.5f) - 1) * vec_denoised[j];
}
ggml_tensor* denoised = model(x2, sigmas[i + 1], i + 1);
// Second half-step
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = (sigma_fn(t_next) / sigma_fn(t)) * vec_x[j] - (exp(-h) - 1) * vec_denoised[j];
}
}
// Noise addition
if (sigmas[i + 1] > 0) {
ggml_tensor_set_f32_randn(noise, rng);
{
float* vec_x = (float*)x->data;
float* vec_noise = (float*)noise->data;
for (int i = 0; i < ggml_nelements(x); i++) {
vec_x[i] = vec_x[i] + vec_noise[i] * sigma_up;
}
}
}
}
} break;
case DPMPP2M: // DPM++ (2M) from Karras et al (2022)
{
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
auto t_fn = [](float sigma) -> float { return -log(sigma); };
for (int i = 0; i < steps; i++) {
// denoise
ggml_tensor* denoised = model(x, sigmas[i], i + 1);
float t = t_fn(sigmas[i]);
float t_next = t_fn(sigmas[i + 1]);
float h = t_next - t;
float a = sigmas[i + 1] / sigmas[i];
float b = exp(-h) - 1.f;
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
float* vec_old_denoised = (float*)old_denoised->data;
if (i == 0 || sigmas[i + 1] == 0) {
// Simpler step for the edge cases
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = a * vec_x[j] - b * vec_denoised[j];
}
} else {
float h_last = t - t_fn(sigmas[i - 1]);
float r = h_last / h;
for (int j = 0; j < ggml_nelements(x); j++) {
float denoised_d = (1.f + 1.f / (2.f * r)) * vec_denoised[j] - (1.f / (2.f * r)) * vec_old_denoised[j];
vec_x[j] = a * vec_x[j] - b * denoised_d;
}
}
// old_denoised = denoised
for (int j = 0; j < ggml_nelements(x); j++) {
vec_old_denoised[j] = vec_denoised[j];
}
}
} break;
case DPMPP2Mv2: // Modified DPM++ (2M) from https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
{
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
auto t_fn = [](float sigma) -> float { return -log(sigma); };
for (int i = 0; i < steps; i++) {
// denoise
ggml_tensor* denoised = model(x, sigmas[i], i + 1);
float t = t_fn(sigmas[i]);
float t_next = t_fn(sigmas[i + 1]);
float h = t_next - t;
float a = sigmas[i + 1] / sigmas[i];
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
float* vec_old_denoised = (float*)old_denoised->data;
if (i == 0 || sigmas[i + 1] == 0) {
// Simpler step for the edge cases
float b = exp(-h) - 1.f;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = a * vec_x[j] - b * vec_denoised[j];
}
} else {
float h_last = t - t_fn(sigmas[i - 1]);
float h_min = std::min(h_last, h);
float h_max = std::max(h_last, h);
float r = h_max / h_min;
float h_d = (h_max + h_min) / 2.f;
float b = exp(-h_d) - 1.f;
for (int j = 0; j < ggml_nelements(x); j++) {
float denoised_d = (1.f + 1.f / (2.f * r)) * vec_denoised[j] - (1.f / (2.f * r)) * vec_old_denoised[j];
vec_x[j] = a * vec_x[j] - b * denoised_d;
}
}
// old_denoised = denoised
for (int j = 0; j < ggml_nelements(x); j++) {
vec_old_denoised[j] = vec_denoised[j];
}
}
} break;
case IPNDM: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
{
int max_order = 4;
ggml_tensor* x_next = x;
std::vector<ggml_tensor*> buffer_model;
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
float sigma_next = sigmas[i + 1];
ggml_tensor* x_cur = x_next;
float* vec_x_cur = (float*)x_cur->data;
float* vec_x_next = (float*)x_next->data;
// Denoising step
ggml_tensor* denoised = model(x_cur, sigma, i + 1);
float* vec_denoised = (float*)denoised->data;
// d_cur = (x_cur - denoised) / sigma
struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x_cur);
float* vec_d_cur = (float*)d_cur->data;
for (int j = 0; j < ggml_nelements(d_cur); j++) {
vec_d_cur[j] = (vec_x_cur[j] - vec_denoised[j]) / sigma;
}
int order = std::min(max_order, i + 1);
// Calculate vec_x_next based on the order
switch (order) {
case 1: // First Euler step
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * vec_d_cur[j];
}
break;
case 2: // Use one history point
{
float* vec_d_prev1 = (float*)buffer_model.back()->data;
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (3 * vec_d_cur[j] - vec_d_prev1[j]) / 2;
}
} break;
case 3: // Use two history points
{
float* vec_d_prev1 = (float*)buffer_model.back()->data;
float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (23 * vec_d_cur[j] - 16 * vec_d_prev1[j] + 5 * vec_d_prev2[j]) / 12;
}
} break;
case 4: // Use three history points
{
float* vec_d_prev1 = (float*)buffer_model.back()->data;
float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
float* vec_d_prev3 = (float*)buffer_model[buffer_model.size() - 3]->data;
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (55 * vec_d_cur[j] - 59 * vec_d_prev1[j] + 37 * vec_d_prev2[j] - 9 * vec_d_prev3[j]) / 24;
}
} break;
}
// Manage buffer_model
if (buffer_model.size() == max_order - 1) {
// Shift elements to the left
for (int k = 0; k < max_order - 2; k++) {
buffer_model[k] = buffer_model[k + 1];
}
buffer_model.back() = d_cur; // Replace the last element with d_cur
} else {
buffer_model.push_back(d_cur);
}
}
} break;
case IPNDM_V: // iPNDM_v sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
{
int max_order = 4;
std::vector<ggml_tensor*> buffer_model;
ggml_tensor* x_next = x;
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
float t_next = sigmas[i + 1];
// Denoising step
ggml_tensor* denoised = model(x, sigma, i + 1);
float* vec_denoised = (float*)denoised->data;
struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x);
float* vec_d_cur = (float*)d_cur->data;
float* vec_x = (float*)x->data;
// d_cur = (x - denoised) / sigma
for (int j = 0; j < ggml_nelements(d_cur); j++) {
vec_d_cur[j] = (vec_x[j] - vec_denoised[j]) / sigma;
}
int order = std::min(max_order, i + 1);
float h_n = t_next - sigma;
float h_n_1 = (i > 0) ? (sigma - sigmas[i - 1]) : h_n;
switch (order) {
case 1: // First Euler step
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x[j] += vec_d_cur[j] * h_n;
}
break;
case 2: {
float* vec_d_prev1 = (float*)buffer_model.back()->data;
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x[j] += h_n * ((2 + (h_n / h_n_1)) * vec_d_cur[j] - (h_n / h_n_1) * vec_d_prev1[j]) / 2;
}
break;
}
case 3: {
float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
float* vec_d_prev1 = (float*)buffer_model.back()->data;
float* vec_d_prev2 = (buffer_model.size() > 1) ? (float*)buffer_model[buffer_model.size() - 2]->data : vec_d_prev1;
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x[j] += h_n * ((23 * vec_d_cur[j] - 16 * vec_d_prev1[j] + 5 * vec_d_prev2[j]) / 12);
}
break;
}
case 4: {
float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
float h_n_3 = (i > 2) ? (sigmas[i - 2] - sigmas[i - 3]) : h_n_2;
float* vec_d_prev1 = (float*)buffer_model.back()->data;
float* vec_d_prev2 = (buffer_model.size() > 1) ? (float*)buffer_model[buffer_model.size() - 2]->data : vec_d_prev1;
float* vec_d_prev3 = (buffer_model.size() > 2) ? (float*)buffer_model[buffer_model.size() - 3]->data : vec_d_prev2;
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x[j] += h_n * ((55 * vec_d_cur[j] - 59 * vec_d_prev1[j] + 37 * vec_d_prev2[j] - 9 * vec_d_prev3[j]) / 24);
}
break;
}
}
// Manage buffer_model
if (buffer_model.size() == max_order - 1) {
buffer_model.erase(buffer_model.begin());
}
buffer_model.push_back(d_cur);
// Prepare the next d tensor
d_cur = ggml_dup_tensor(work_ctx, x_next);
}
} break;
case LCM: // Latent Consistency Models
{
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
// denoise
ggml_tensor* denoised = model(x, sigma, i + 1);
// x = denoised
{
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_denoised[j];
}
}
if (sigmas[i + 1] > 0) {
// x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
ggml_tensor_set_f32_randn(noise, rng);
// noise = load_tensor_from_file(res_ctx, "./rand" + std::to_string(i+1) + ".bin");
{
float* vec_x = (float*)x->data;
float* vec_noise = (float*)noise->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_x[j] + sigmas[i + 1] * vec_noise[j];
}
}
}
}
} break;
case DDIM_TRAILING: // Denoising Diffusion Implicit Models
// with the "trailing" timestep spacing
{
// See J. Song et al., "Denoising Diffusion Implicit
// Models", arXiv:2010.02502 [cs.LG]
//
// DDIM itself needs alphas_cumprod (DDPM, J. Ho et al.,
// arXiv:2006.11239 [cs.LG] with k-diffusion's start and
// end beta) (which unfortunately k-diffusion's data
// structure hides from the denoiser), and the sigmas are
// also needed to invert the behavior of CompVisDenoiser
// (k-diffusion's LMSDiscreteScheduler)
float beta_start = 0.00085f;
float beta_end = 0.0120f;
std::vector<double> alphas_cumprod;
std::vector<double> compvis_sigmas;
alphas_cumprod.reserve(TIMESTEPS);
compvis_sigmas.reserve(TIMESTEPS);
for (int i = 0; i < TIMESTEPS; i++) {
alphas_cumprod[i] =
(i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
(1.0f -
std::pow(sqrtf(beta_start) +
(sqrtf(beta_end) - sqrtf(beta_start)) *
((float)i / (TIMESTEPS - 1)), 2));
compvis_sigmas[i] =
std::sqrt((1 - alphas_cumprod[i]) /
alphas_cumprod[i]);
}
struct ggml_tensor* pred_original_sample =
ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* variance_noise =
ggml_dup_tensor(work_ctx, x);
for (int i = 0; i < steps; i++) {
// The "trailing" DDIM timestep, see S. Lin et al.,
// "Common Diffusion Noise Schedules and Sample Steps
// are Flawed", arXiv:2305.08891 [cs], p. 4, Table
// 2. Most variables below follow Diffusers naming
//
// Diffuser naming vs. Song et al. (2010), p. 5, (12)
// and p. 16, (16) (<variable name> -> <name in
// paper>):
//
// - pred_noise_t -> epsilon_theta^(t)(x_t)
// - pred_original_sample -> f_theta^(t)(x_t) or x_0
// - std_dev_t -> sigma_t (not the LMS sigma)
// - eta -> eta (set to 0 at the moment)
// - pred_sample_direction -> "direction pointing to
// x_t"
// - pred_prev_sample -> "x_t-1"
int timestep =
roundf(TIMESTEPS -
i * ((float)TIMESTEPS / steps)) - 1;
// 1. get previous step value (=t-1)
int prev_timestep = timestep - TIMESTEPS / steps;
// The sigma here is chosen to cause the
// CompVisDenoiser to produce t = timestep
float sigma = compvis_sigmas[timestep];
if (i == 0) {
// The function add_noise initializes x to
// Diffusers' latents * sigma (as in Diffusers'
// pipeline) or sample * sigma (Diffusers'
// scheduler), where this sigma = init_noise_sigma
// in Diffusers. For DDPM and DDIM however,
// init_noise_sigma = 1. But the k-diffusion
// model() also evaluates F_theta(c_in(sigma) x;
// ...) instead of the bare U-net F_theta, with
// c_in = 1 / sqrt(sigma^2 + 1), as defined in
// T. Karras et al., "Elucidating the Design Space
// of Diffusion-Based Generative Models",
// arXiv:2206.00364 [cs.CV], p. 3, Table 1. Hence
// the first call has to be prescaled as x <- x /
// (c_in * sigma) with the k-diffusion pipeline
// and CompVisDenoiser.
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] *= std::sqrt(sigma * sigma + 1) /
sigma;
}
}
else {
// For the subsequent steps after the first one,
// at this point x = latents or x = sample, and
// needs to be prescaled with x <- sample / c_in
// to compensate for model() applying the scale
// c_in before the U-net F_theta
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] *= std::sqrt(sigma * sigma + 1);
}
}
// Note (also noise_pred in Diffuser's pipeline)
// model_output = model() is the D(x, sigma) as
// defined in Karras et al. (2022), p. 3, Table 1 and
// p. 8 (7), compare also p. 38 (226) therein.
struct ggml_tensor* model_output =
model(x, sigma, i + 1);
// Here model_output is still the k-diffusion denoiser
// output, not the U-net output F_theta(c_in(sigma) x;
// ...) in Karras et al. (2022), whereas Diffusers'
// model_output is F_theta(...). Recover the actual
// model_output, which is also referred to as the
// "Karras ODE derivative" d or d_cur in several
// samplers above.
{
float* vec_x = (float*)x->data;
float* vec_model_output =
(float*)model_output->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_model_output[j] =
(vec_x[j] - vec_model_output[j]) *
(1 / sigma);
}
}
// 2. compute alphas, betas
float alpha_prod_t = alphas_cumprod[timestep];
// Note final_alpha_cumprod = alphas_cumprod[0] due to
// trailing timestep spacing
float alpha_prod_t_prev = prev_timestep >= 0 ?
alphas_cumprod[prev_timestep] : alphas_cumprod[0];
float beta_prod_t = 1 - alpha_prod_t;
// 3. compute predicted original sample from predicted
// noise also called "predicted x_0" of formula (12)
// from https://arxiv.org/pdf/2010.02502.pdf
{
float* vec_x = (float*)x->data;
float* vec_model_output =
(float*)model_output->data;
float* vec_pred_original_sample =
(float*)pred_original_sample->data;
// Note the substitution of latents or sample = x
// * c_in = x / sqrt(sigma^2 + 1)
for (int j = 0; j < ggml_nelements(x); j++) {
vec_pred_original_sample[j] =
(vec_x[j] / std::sqrt(sigma * sigma + 1) -
std::sqrt(beta_prod_t) *
vec_model_output[j]) *
(1 / std::sqrt(alpha_prod_t));
}
}
// Assuming the "epsilon" prediction type, where below
// pred_epsilon = model_output is inserted, and is not
// defined/copied explicitly.
//
// 5. compute variance: "sigma_t(eta)" -> see formula
// (16)
//
// sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) *
// sqrt(1 - alpha_t/alpha_t-1)
float beta_prod_t_prev = 1 - alpha_prod_t_prev;
float variance = (beta_prod_t_prev / beta_prod_t) *
(1 - alpha_prod_t / alpha_prod_t_prev);
float std_dev_t = eta * std::sqrt(variance);
// 6. compute "direction pointing to x_t" of formula
// (12) from https://arxiv.org/pdf/2010.02502.pdf
// 7. compute x_t without "random noise" of formula
// (12) from https://arxiv.org/pdf/2010.02502.pdf
{
float* vec_model_output = (float*)model_output->data;
float* vec_pred_original_sample =
(float*)pred_original_sample->data;
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
// Two step inner loop without an explicit
// tensor
float pred_sample_direction =
std::sqrt(1 - alpha_prod_t_prev -
std::pow(std_dev_t, 2)) *
vec_model_output[j];
vec_x[j] = std::sqrt(alpha_prod_t_prev) *
vec_pred_original_sample[j] +
pred_sample_direction;
}
}
if (eta > 0) {
ggml_tensor_set_f32_randn(variance_noise, rng);
float* vec_variance_noise =
(float*)variance_noise->data;
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] += std_dev_t * vec_variance_noise[j];
}
}
// See the note above: x = latents or sample here, and
// is not scaled by the c_in. For the final output
// this is correct, but for subsequent iterations, x
// needs to be prescaled again, since k-diffusion's
// model() differs from the bare U-net F_theta by the
// factor c_in.
}
} break;
case TCD:  // Strategic Stochastic Sampling (Algorithm 4) in
           // Trajectory Consistency Distillation
{
    // See J. Zheng et al., "Trajectory Consistency
    // Distillation: Improved Latent Consistency Distillation
    // by Semi-Linear Consistency Function with Trajectory
    // Mapping", arXiv:2402.19159 [cs.CV]
    //
    // Outline of this case:
    //   * build per-timestep tables of the cumulative alpha
    //     product and the matching CompVis sigma,
    //   * for every step: prescale x, evaluate model(), turn
    //     its output into a noise prediction, form the
    //     predicted original sample, move x to timestep_s,
    //     and (when eta > 0 and this is not the last step)
    //     mix fresh Gaussian noise into x.
    // x is updated in place throughout.
    const float beta_start = 0.00085f;
    const float beta_end   = 0.0120f;
    // The tables are written and read by index below, so they
    // must actually hold TIMESTEPS elements. reserve() would
    // only raise the capacity and leave size() == 0, turning
    // every operator[] access into an out-of-bounds access.
    std::vector<double> alphas_cumprod(TIMESTEPS);
    std::vector<double> compvis_sigmas(TIMESTEPS);
    for (int i = 0; i < TIMESTEPS; i++) {
        // beta_i is interpolated linearly in sqrt(beta)
        // between beta_start and beta_end; the cumulative
        // product of (1 - beta_i) is accumulated in double.
        alphas_cumprod[i] =
            (i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
            (1.0f -
             std::pow(sqrtf(beta_start) +
                          (sqrtf(beta_end) - sqrtf(beta_start)) *
                              ((float)i / (TIMESTEPS - 1)),
                      2));
        compvis_sigmas[i] =
            std::sqrt((1 - alphas_cumprod[i]) /
                      alphas_cumprod[i]);
    }
    int original_steps = 50;
    // Scratch tensors with the same layout as x.
    struct ggml_tensor* pred_original_sample =
        ggml_dup_tensor(work_ctx, x);
    struct ggml_tensor* noise =
        ggml_dup_tensor(work_ctx, x);
    // Element count of x, shared by every per-element loop.
    const int64_t n_elements = ggml_nelements(x);
    for (int i = 0; i < steps; i++) {
        // Analytic form for TCD timesteps
        int timestep = TIMESTEPS - 1 -
                       (TIMESTEPS / original_steps) *
                           (int)floor(i * ((float)original_steps / steps));
        // 1. get previous step value
        int prev_timestep = i >= steps - 1 ? 0 :
            TIMESTEPS - 1 - (TIMESTEPS / original_steps) *
                                (int)floor((i + 1) *
                                           ((float)original_steps / steps));
        // Here timestep_s is tau_n' in Algorithm 4. The _s
        // notation appears to be that from C. Lu,
        // "DPM-Solver: A Fast ODE Solver for Diffusion
        // Probabilistic Model Sampling in Around 10 Steps",
        // arXiv:2206.00927 [cs.LG], but this notation is not
        // continued in Algorithm 4, where _n' is used.
        //
        // NOTE(review): eta > 1 would make timestep_s negative
        // whenever prev_timestep > 0 and thus index the table
        // out of range; presumably callers keep eta in [0, 1]
        // -- confirm.
        int timestep_s =
            (int)floor((1 - eta) * prev_timestep);
        // Begin k-diffusion specific workaround for
        // evaluating F_theta(x; ...) from D(x, sigma), same
        // as in DDIM (and see there for detailed comments)
        float sigma = static_cast<float>(compvis_sigmas[timestep]);
        // 1 / c_in = sqrt(sigma^2 + 1)
        const float inv_c_in = std::sqrt(sigma * sigma + 1);
        float* vec_x = (float*)x->data;
        {
            // The first step additionally divides by sigma;
            // later steps only undo c_in.
            const float prescale =
                (i == 0) ? inv_c_in / sigma : inv_c_in;
            for (int64_t j = 0; j < n_elements; j++) {
                vec_x[j] *= prescale;
            }
        }
        struct ggml_tensor* model_output =
            model(x, sigma, i + 1);
        float* vec_model_output = (float*)model_output->data;
        {
            // Convert the denoiser output into (x - D) / sigma,
            // stored in place in model_output.
            const float inv_sigma = 1 / sigma;
            for (int64_t j = 0; j < n_elements; j++) {
                vec_model_output[j] =
                    (vec_x[j] - vec_model_output[j]) *
                    inv_sigma;
            }
        }
        // 2. compute alphas, betas
        //
        // When comparing TCD with DDPM/DDIM note that Zheng
        // et al. (2024) follows the DPM-Solver notation for
        // alpha. One can find the following comment in the
        // original DPM-Solver code
        // (https://github.com/LuChengTHU/dpm-solver/):
        // "**Important**: Please pay special attention for
        // the args for `alphas_cumprod`: The `alphas_cumprod`
        // is the \hat{alpha_n} arrays in the notations of
        // DDPM. [...] Therefore, the notation \hat{alpha_n}
        // is different from the notation alpha_t in
        // DPM-Solver. In fact, we have alpha_{t_n} =
        // \sqrt{\hat{alpha_n}}, [...]"
        float alpha_prod_t = static_cast<float>(alphas_cumprod[timestep]);
        float beta_prod_t  = 1 - alpha_prod_t;
        // Note final_alpha_cumprod = alphas_cumprod[0] since
        // TCD is always "trailing"
        float alpha_prod_t_prev = static_cast<float>(
            prev_timestep >= 0 ? alphas_cumprod[prev_timestep]
                               : alphas_cumprod[0]);
        // The subscript _s are the only portion in this
        // section (2) unique to TCD
        float alpha_prod_s = static_cast<float>(alphas_cumprod[timestep_s]);
        float beta_prod_s  = 1 - alpha_prod_s;
        // 3. Compute the predicted noised sample x_s based on
        // the model parameterization
        //
        // This section is also exactly the same as DDIM
        float* vec_pred_original_sample =
            (float*)pred_original_sample->data;
        {
            const float sqrt_beta_prod_t = std::sqrt(beta_prod_t);
            const float inv_sqrt_alpha_prod_t =
                1 / std::sqrt(alpha_prod_t);
            // x / inv_c_in recovers the unscaled sample
            // (x * c_in) that the formula expects.
            for (int64_t j = 0; j < n_elements; j++) {
                vec_pred_original_sample[j] =
                    (vec_x[j] / inv_c_in -
                     sqrt_beta_prod_t *
                         vec_model_output[j]) *
                    inv_sqrt_alpha_prod_t;
            }
        }
        // This consistency function step can be difficult to
        // decipher from Algorithm 4, as it is simply stated
        // using a consistency function. This step is the
        // modified DDIM, i.e. p. 8 (32) in Zheng et
        // al. (2024), with eta set to 0 (see the paragraph
        // immediately thereafter that states this somewhat
        // obliquely).
        {
            const float sqrt_alpha_prod_s = std::sqrt(alpha_prod_s);
            const float sqrt_beta_prod_s  = std::sqrt(beta_prod_s);
            for (int64_t j = 0; j < n_elements; j++) {
                // Substituting x = pred_noised_sample and
                // pred_epsilon = model_output
                vec_x[j] =
                    sqrt_alpha_prod_s *
                        vec_pred_original_sample[j] +
                    sqrt_beta_prod_s *
                        vec_model_output[j];
            }
        }
        // 4. Sample and inject noise z ~ N(0, I) for
        // MultiStep Inference. Noise is not used on the final
        // timestep of the timestep schedule. This also means
        // that noise is not used for one-step sampling. Eta
        // (referred to as "gamma" in the paper) was
        // introduced to control the stochasticity in every
        // step. When eta = 0, it represents deterministic
        // sampling, whereas eta = 1 indicates full stochastic
        // sampling.
        if (eta > 0 && i != steps - 1) {
            // In this case, x is still pred_noised_sample,
            // continue in-place
            ggml_tensor_set_f32_randn(noise, rng);
            float* vec_noise = (float*)noise->data;
            const float alpha_ratio = alpha_prod_t_prev / alpha_prod_s;
            const float keep_scale  = std::sqrt(alpha_ratio);
            const float noise_scale = std::sqrt(1 - alpha_ratio);
            for (int64_t j = 0; j < n_elements; j++) {
                // Corresponding to (35) in Zheng et
                // al. (2024), substituting x =
                // pred_noised_sample
                vec_x[j] =
                    keep_scale * vec_x[j] +
                    noise_scale * vec_noise[j];
            }
        }
    }
} break;
default:
// Fallback of the sample-method switch: no case label matched
// `method`, so log the offending value and terminate the process
// with abort().
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
abort();
}
}
#endif // __DENOISER_HPP__