feat: add option to switch the sigma schedule (#51)
Concretely, this allows switching to the "Karras" schedule from the Karras et al 2022 paper, equivalent to the samplers marked as "Karras" in the AUTOMATIC1111 WebUI. This choice is in principle orthogonal to the sampler choice and can be given independently.
This commit is contained in:
parent
b6899e8fc2
commit
968fbf02aa
@ -80,6 +80,12 @@ const char* sample_method_str[] = {
|
|||||||
"dpm++2m",
|
"dpm++2m",
|
||||||
"dpm++2mv2"};
|
"dpm++2mv2"};
|
||||||
|
|
||||||
|
// Names of the sigma schedule overrides, same order as Schedule in stable-diffusion.h
|
||||||
|
const char* schedule_str[] = {
|
||||||
|
"default",
|
||||||
|
"discrete",
|
||||||
|
"karras"};
|
||||||
|
|
||||||
struct Option {
|
struct Option {
|
||||||
int n_threads = -1;
|
int n_threads = -1;
|
||||||
std::string mode = TXT2IMG;
|
std::string mode = TXT2IMG;
|
||||||
@ -92,6 +98,7 @@ struct Option {
|
|||||||
int w = 512;
|
int w = 512;
|
||||||
int h = 512;
|
int h = 512;
|
||||||
SampleMethod sample_method = EULER_A;
|
SampleMethod sample_method = EULER_A;
|
||||||
|
Schedule schedule = DEFAULT;
|
||||||
int sample_steps = 20;
|
int sample_steps = 20;
|
||||||
float strength = 0.75f;
|
float strength = 0.75f;
|
||||||
RNGType rng_type = CUDA_RNG;
|
RNGType rng_type = CUDA_RNG;
|
||||||
@ -111,6 +118,7 @@ struct Option {
|
|||||||
printf(" width: %d\n", w);
|
printf(" width: %d\n", w);
|
||||||
printf(" height: %d\n", h);
|
printf(" height: %d\n", h);
|
||||||
printf(" sample_method: %s\n", sample_method_str[sample_method]);
|
printf(" sample_method: %s\n", sample_method_str[sample_method]);
|
||||||
|
printf(" schedule: %s\n", schedule_str[schedule]);
|
||||||
printf(" sample_steps: %d\n", sample_steps);
|
printf(" sample_steps: %d\n", sample_steps);
|
||||||
printf(" strength: %.2f\n", strength);
|
printf(" strength: %.2f\n", strength);
|
||||||
printf(" rng: %s\n", rng_type_to_str[rng_type]);
|
printf(" rng: %s\n", rng_type_to_str[rng_type]);
|
||||||
@ -141,6 +149,7 @@ void print_usage(int argc, const char* argv[]) {
|
|||||||
printf(" --steps STEPS number of sample steps (default: 20)\n");
|
printf(" --steps STEPS number of sample steps (default: 20)\n");
|
||||||
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
|
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
|
||||||
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
|
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
|
||||||
|
printf(" --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)\n");
|
||||||
printf(" -v, --verbose print extra info\n");
|
printf(" -v, --verbose print extra info\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -237,6 +246,23 @@ void parse_args(int argc, const char* argv[], Option* opt) {
|
|||||||
invalid_arg = true;
|
invalid_arg = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
} else if (arg == "--schedule") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_arg = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
const char* schedule_selected = argv[i];
|
||||||
|
int schedule_found = -1;
|
||||||
|
for (int d = 0; d < N_SCHEDULES; d++) {
|
||||||
|
if (!strcmp(schedule_selected, schedule_str[d])) {
|
||||||
|
schedule_found = d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (schedule_found == -1) {
|
||||||
|
invalid_arg = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
opt->schedule = (Schedule)schedule_found;
|
||||||
} else if (arg == "-s" || arg == "--seed") {
|
} else if (arg == "-s" || arg == "--seed") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_arg = true;
|
invalid_arg = true;
|
||||||
@ -377,7 +403,7 @@ int main(int argc, const char* argv[]) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type);
|
StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type);
|
||||||
if (!sd.load_from_file(opt.model_path)) {
|
if (!sd.load_from_file(opt.model_path, opt.schedule)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2654,32 +2654,12 @@ struct AutoEncoderKL {
|
|||||||
|
|
||||||
// Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py
|
// Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py
|
||||||
|
|
||||||
struct DiscreteSchedule {
|
struct SigmaSchedule {
|
||||||
float alphas_cumprod[TIMESTEPS];
|
float alphas_cumprod[TIMESTEPS];
|
||||||
float sigmas[TIMESTEPS];
|
float sigmas[TIMESTEPS];
|
||||||
float log_sigmas[TIMESTEPS];
|
float log_sigmas[TIMESTEPS];
|
||||||
|
|
||||||
std::vector<float> get_sigmas(uint32_t n) {
|
virtual std::vector<float> get_sigmas(uint32_t n) = 0;
|
||||||
std::vector<float> result;
|
|
||||||
|
|
||||||
int t_max = TIMESTEPS - 1;
|
|
||||||
|
|
||||||
if (n == 0) {
|
|
||||||
return result;
|
|
||||||
} else if (n == 1) {
|
|
||||||
result.push_back(t_to_sigma(t_max));
|
|
||||||
result.push_back(0);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
float step = static_cast<float>(t_max) / static_cast<float>(n - 1);
|
|
||||||
for (int i = 0; i < n; ++i) {
|
|
||||||
float t = t_max - step * i;
|
|
||||||
result.push_back(t_to_sigma(t));
|
|
||||||
}
|
|
||||||
result.push_back(0);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
float sigma_to_t(float sigma) {
|
float sigma_to_t(float sigma) {
|
||||||
float log_sigma = std::log(sigma);
|
float log_sigma = std::log(sigma);
|
||||||
@ -2714,11 +2694,59 @@ struct DiscreteSchedule {
|
|||||||
float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
|
float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
|
||||||
return std::exp(log_sigma);
|
return std::exp(log_sigma);
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DiscreteSchedule : SigmaSchedule {
|
||||||
|
std::vector<float> get_sigmas(uint32_t n) {
|
||||||
|
std::vector<float> result;
|
||||||
|
|
||||||
|
int t_max = TIMESTEPS - 1;
|
||||||
|
|
||||||
|
if (n == 0) {
|
||||||
|
return result;
|
||||||
|
} else if (n == 1) {
|
||||||
|
result.push_back(t_to_sigma(t_max));
|
||||||
|
result.push_back(0);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
float step = static_cast<float>(t_max) / static_cast<float>(n - 1);
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
float t = t_max - step * i;
|
||||||
|
result.push_back(t_to_sigma(t));
|
||||||
|
}
|
||||||
|
result.push_back(0);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct KarrasSchedule : SigmaSchedule {
|
||||||
|
std::vector<float> get_sigmas(uint32_t n) {
|
||||||
|
// These *COULD* be function arguments here,
|
||||||
|
// but does anybody ever bother to touch them?
|
||||||
|
float sigma_min = 0.1;
|
||||||
|
float sigma_max = 10.;
|
||||||
|
float rho = 7.;
|
||||||
|
|
||||||
|
std::vector<float> result(n + 1);
|
||||||
|
|
||||||
|
float min_inv_rho = pow(sigma_min, (1. / rho));
|
||||||
|
float max_inv_rho = pow(sigma_max, (1. / rho));
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
// Eq. (5) from Karras et al 2022
|
||||||
|
result[i] = pow(max_inv_rho + (float)i / ((float)n - 1.) * (min_inv_rho - max_inv_rho), rho);
|
||||||
|
}
|
||||||
|
result[n] = 0.;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Denoiser {
|
||||||
|
std::shared_ptr<SigmaSchedule> schedule = std::make_shared<DiscreteSchedule>();
|
||||||
virtual std::vector<float> get_scalings(float sigma) = 0;
|
virtual std::vector<float> get_scalings(float sigma) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CompVisDenoiser : public DiscreteSchedule {
|
struct CompVisDenoiser : public Denoiser {
|
||||||
float sigma_data = 1.0f;
|
float sigma_data = 1.0f;
|
||||||
|
|
||||||
std::vector<float> get_scalings(float sigma) {
|
std::vector<float> get_scalings(float sigma) {
|
||||||
@ -2728,7 +2756,7 @@ struct CompVisDenoiser : public DiscreteSchedule {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CompVisVDenoiser : public DiscreteSchedule {
|
struct CompVisVDenoiser : public Denoiser {
|
||||||
float sigma_data = 1.0f;
|
float sigma_data = 1.0f;
|
||||||
|
|
||||||
std::vector<float> get_scalings(float sigma) {
|
std::vector<float> get_scalings(float sigma) {
|
||||||
@ -2764,7 +2792,7 @@ class StableDiffusionGGML {
|
|||||||
UNetModel diffusion_model;
|
UNetModel diffusion_model;
|
||||||
AutoEncoderKL first_stage_model;
|
AutoEncoderKL first_stage_model;
|
||||||
|
|
||||||
std::shared_ptr<DiscreteSchedule> denoiser = std::make_shared<CompVisDenoiser>();
|
std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();
|
||||||
|
|
||||||
StableDiffusionGGML() = default;
|
StableDiffusionGGML() = default;
|
||||||
|
|
||||||
@ -2798,7 +2826,7 @@ class StableDiffusionGGML {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool load_from_file(const std::string& file_path) {
|
bool load_from_file(const std::string& file_path, Schedule schedule) {
|
||||||
LOG_INFO("loading model from '%s'", file_path.c_str());
|
LOG_INFO("loading model from '%s'", file_path.c_str());
|
||||||
|
|
||||||
std::ifstream file(file_path, std::ios::binary);
|
std::ifstream file(file_path, std::ios::binary);
|
||||||
@ -3093,10 +3121,29 @@ class StableDiffusionGGML {
|
|||||||
LOG_INFO("running in eps-prediction mode");
|
LOG_INFO("running in eps-prediction mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (schedule != DEFAULT) {
|
||||||
|
switch (schedule) {
|
||||||
|
case DISCRETE:
|
||||||
|
LOG_INFO("running with discrete schedule");
|
||||||
|
denoiser->schedule = std::make_shared<DiscreteSchedule>();
|
||||||
|
break;
|
||||||
|
case KARRAS:
|
||||||
|
LOG_INFO("running with Karras schedule");
|
||||||
|
denoiser->schedule = std::make_shared<KarrasSchedule>();
|
||||||
|
break;
|
||||||
|
case DEFAULT:
|
||||||
|
// Don't touch anything.
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
LOG_ERROR("Unknown schedule %i", schedule);
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < TIMESTEPS; i++) {
|
for (int i = 0; i < TIMESTEPS; i++) {
|
||||||
denoiser->alphas_cumprod[i] = alphas_cumprod[i];
|
denoiser->schedule->alphas_cumprod[i] = alphas_cumprod[i];
|
||||||
denoiser->sigmas[i] = std::sqrt((1 - denoiser->alphas_cumprod[i]) / denoiser->alphas_cumprod[i]);
|
denoiser->schedule->sigmas[i] = std::sqrt((1 - denoiser->schedule->alphas_cumprod[i]) / denoiser->schedule->alphas_cumprod[i]);
|
||||||
denoiser->log_sigmas[i] = std::log(denoiser->sigmas[i]);
|
denoiser->schedule->log_sigmas[i] = std::log(denoiser->schedule->sigmas[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@ -3445,7 +3492,7 @@ class StableDiffusionGGML {
|
|||||||
c_in = scaling[1];
|
c_in = scaling[1];
|
||||||
}
|
}
|
||||||
|
|
||||||
float t = denoiser->sigma_to_t(sigma);
|
float t = denoiser->schedule->sigma_to_t(sigma);
|
||||||
ggml_set_f32(timesteps, t);
|
ggml_set_f32(timesteps, t);
|
||||||
set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);
|
set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);
|
||||||
|
|
||||||
@ -4010,8 +4057,8 @@ StableDiffusion::StableDiffusion(int n_threads,
|
|||||||
rng_type);
|
rng_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool StableDiffusion::load_from_file(const std::string& file_path) {
|
bool StableDiffusion::load_from_file(const std::string& file_path, Schedule s) {
|
||||||
return sd->load_from_file(file_path);
|
return sd->load_from_file(file_path, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
|
std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
|
||||||
@ -4061,7 +4108,7 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
|
|||||||
struct ggml_tensor* x_t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, W, H, C, 1);
|
struct ggml_tensor* x_t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, W, H, C, 1);
|
||||||
ggml_tensor_set_f32_randn(x_t, sd->rng);
|
ggml_tensor_set_f32_randn(x_t, sd->rng);
|
||||||
|
|
||||||
std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
|
std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);
|
||||||
|
|
||||||
LOG_INFO("start sampling");
|
LOG_INFO("start sampling");
|
||||||
struct ggml_tensor* x_0 = sd->sample(ctx, x_t, c, uc, cfg_scale, sample_method, sigmas);
|
struct ggml_tensor* x_0 = sd->sample(ctx, x_t, c, uc, cfg_scale, sample_method, sigmas);
|
||||||
@ -4117,7 +4164,7 @@ std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_i
|
|||||||
}
|
}
|
||||||
LOG_INFO("img2img %dx%d", width, height);
|
LOG_INFO("img2img %dx%d", width, height);
|
||||||
|
|
||||||
std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
|
std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);
|
||||||
size_t t_enc = static_cast<size_t>(sample_steps * strength);
|
size_t t_enc = static_cast<size_t>(sample_steps * strength);
|
||||||
LOG_INFO("target t_enc is %zu steps", t_enc);
|
LOG_INFO("target t_enc is %zu steps", t_enc);
|
||||||
std::vector<float> sigma_sched;
|
std::vector<float> sigma_sched;
|
||||||
|
@ -25,6 +25,13 @@ enum SampleMethod {
|
|||||||
N_SAMPLE_METHODS
|
N_SAMPLE_METHODS
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum Schedule {
|
||||||
|
DEFAULT,
|
||||||
|
DISCRETE,
|
||||||
|
KARRAS,
|
||||||
|
N_SCHEDULES
|
||||||
|
};
|
||||||
|
|
||||||
class StableDiffusionGGML;
|
class StableDiffusionGGML;
|
||||||
|
|
||||||
class StableDiffusion {
|
class StableDiffusion {
|
||||||
@ -36,7 +43,7 @@ class StableDiffusion {
|
|||||||
bool vae_decode_only = false,
|
bool vae_decode_only = false,
|
||||||
bool free_params_immediately = false,
|
bool free_params_immediately = false,
|
||||||
RNGType rng_type = STD_DEFAULT_RNG);
|
RNGType rng_type = STD_DEFAULT_RNG);
|
||||||
bool load_from_file(const std::string& file_path);
|
bool load_from_file(const std::string& file_path, Schedule d = DEFAULT);
|
||||||
std::vector<uint8_t> txt2img(
|
std::vector<uint8_t> txt2img(
|
||||||
const std::string& prompt,
|
const std::string& prompt,
|
||||||
const std::string& negative_prompt,
|
const std::string& negative_prompt,
|
||||||
|
Loading…
Reference in New Issue
Block a user