feat: add option to switch the sigma schedule (#51)

Concretely, this allows switching to the "Karras" schedule from the
Karras et al. (2022) paper, equivalent to the samplers marked as "Karras"
in the AUTOMATIC1111 WebUI. This choice is in principle orthogonal to
the sampler choice and can be given independently.
This commit is contained in:
Urs Ganse 2023-09-08 19:02:07 +03:00 committed by GitHub
parent b6899e8fc2
commit 968fbf02aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 117 additions and 37 deletions

View File

@ -80,6 +80,12 @@ const char* sample_method_str[] = {
"dpm++2m",
"dpm++2mv2"};
// Command-line names of the sigma schedule overrides.
// The table is indexed by the Schedule enum from stable-diffusion.h, so the
// entries must stay in the same order as that enum.
const char* schedule_str[] = {
    "default",   // DEFAULT
    "discrete",  // DISCRETE
    "karras",    // KARRAS
};
struct Option {
int n_threads = -1;
std::string mode = TXT2IMG;
@ -92,6 +98,7 @@ struct Option {
int w = 512;
int h = 512;
SampleMethod sample_method = EULER_A;
Schedule schedule = DEFAULT;
int sample_steps = 20;
float strength = 0.75f;
RNGType rng_type = CUDA_RNG;
@ -111,6 +118,7 @@ struct Option {
printf(" width: %d\n", w);
printf(" height: %d\n", h);
printf(" sample_method: %s\n", sample_method_str[sample_method]);
printf(" schedule: %s\n", schedule_str[schedule]);
printf(" sample_steps: %d\n", sample_steps);
printf(" strength: %.2f\n", strength);
printf(" rng: %s\n", rng_type_to_str[rng_type]);
@ -141,6 +149,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --steps STEPS number of sample steps (default: 20)\n");
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
printf(" --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)\n");
printf(" -v, --verbose print extra info\n");
}
@ -237,6 +246,23 @@ void parse_args(int argc, const char* argv[], Option* opt) {
invalid_arg = true;
break;
}
} else if (arg == "--schedule") {
if (++i >= argc) {
invalid_arg = true;
break;
}
const char* schedule_selected = argv[i];
int schedule_found = -1;
for (int d = 0; d < N_SCHEDULES; d++) {
if (!strcmp(schedule_selected, schedule_str[d])) {
schedule_found = d;
}
}
if (schedule_found == -1) {
invalid_arg = true;
break;
}
opt->schedule = (Schedule)schedule_found;
} else if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
invalid_arg = true;
@ -377,7 +403,7 @@ int main(int argc, const char* argv[]) {
}
StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type);
if (!sd.load_from_file(opt.model_path)) {
if (!sd.load_from_file(opt.model_path, opt.schedule)) {
return 1;
}
@ -413,4 +439,4 @@ int main(int argc, const char* argv[]) {
printf("save result image to '%s'\n", opt.output_path.c_str());
return 0;
}
}

View File

@ -2654,32 +2654,12 @@ struct AutoEncoderKL {
// Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py
struct DiscreteSchedule {
struct SigmaSchedule {
float alphas_cumprod[TIMESTEPS];
float sigmas[TIMESTEPS];
float log_sigmas[TIMESTEPS];
std::vector<float> get_sigmas(uint32_t n) {
std::vector<float> result;
int t_max = TIMESTEPS - 1;
if (n == 0) {
return result;
} else if (n == 1) {
result.push_back(t_to_sigma(t_max));
result.push_back(0);
return result;
}
float step = static_cast<float>(t_max) / static_cast<float>(n - 1);
for (int i = 0; i < n; ++i) {
float t = t_max - step * i;
result.push_back(t_to_sigma(t));
}
result.push_back(0);
return result;
}
virtual std::vector<float> get_sigmas(uint32_t n) = 0;
float sigma_to_t(float sigma) {
float log_sigma = std::log(sigma);
@ -2714,11 +2694,59 @@ struct DiscreteSchedule {
float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
return std::exp(log_sigma);
}
};
// Sigma schedule that samples the model's timesteps at evenly spaced positions.
struct DiscreteSchedule : SigmaSchedule {
    // Returns n sigmas taken at evenly spaced timesteps from TIMESTEPS - 1 down
    // to 0 (converted via t_to_sigma), followed by a terminating 0.
    //   n == 0 -> empty vector
    //   n == 1 -> { sigma at the last timestep, 0 }
    std::vector<float> get_sigmas(uint32_t n) {
        std::vector<float> sigmas;
        if (n == 0) {
            return sigmas;
        }

        const int last_t = TIMESTEPS - 1;
        if (n > 1) {
            // Distance between consecutive sampled timesteps.
            const float stride = static_cast<float>(last_t) / static_cast<float>(n - 1);
            for (uint32_t k = 0; k < n; ++k) {
                const float t = last_t - stride * k;
                sigmas.push_back(t_to_sigma(t));
            }
        } else {
            // A single step starts directly at the highest timestep.
            sigmas.push_back(t_to_sigma(last_t));
        }

        // Every non-empty schedule ends at sigma = 0.
        sigmas.push_back(0);
        return sigmas;
    }
};
// Sigma schedule following Eq. (5) of Karras et al. 2022.
struct KarrasSchedule : SigmaSchedule {
    // Returns n sigmas descending from sigma_max to sigma_min, spaced according
    // to Eq. (5) of Karras et al. 2022, followed by a terminating 0
    // (n + 1 values in total).
    //   n == 0 -> { 0 }
    //   n == 1 -> { sigma_max, 0 }
    std::vector<float> get_sigmas(uint32_t n) {
        // These *COULD* be function arguments here,
        // but does anybody ever bother to touch them?
        const float sigma_min = 0.1;
        const float sigma_max = 10.;
        const float rho       = 7.;

        std::vector<float> result(n + 1);
        const float min_inv_rho = std::pow(sigma_min, (1. / rho));
        const float max_inv_rho = std::pow(sigma_max, (1. / rho));

        // The interpolation parameter is i / (n - 1). For a single step that
        // would be 0 / 0 and produce a NaN sigma, so fall back to a denominator
        // of 1: the only step (i == 0) then lands on sigma_max.
        const double last_index = (n > 1) ? (static_cast<float>(n) - 1.) : 1.;

        for (uint32_t i = 0; i < n; i++) {
            // Eq. (5) from Karras et al 2022
            result[i] = std::pow(max_inv_rho + static_cast<float>(i) / last_index * (min_inv_rho - max_inv_rho), rho);
        }
        result[n] = 0.;
        return result;
    }
};
// Abstract base for denoisers: owns the sigma schedule and exposes the
// per-sigma scalings computed by the concrete denoiser.
struct Denoiser {
    // Sigma schedule used by this denoiser. Starts out as a DiscreteSchedule;
    // it is a shared_ptr to the SigmaSchedule base so another schedule
    // implementation can be assigned later.
    std::shared_ptr<SigmaSchedule> schedule = std::make_shared<DiscreteSchedule>();

    // Polymorphic base: make destruction through a Denoiser pointer well-defined.
    virtual ~Denoiser() = default;

    // Returns the scaling factors for the given sigma; the contents are
    // defined by the concrete denoiser.
    virtual std::vector<float> get_scalings(float sigma) = 0;
};
struct CompVisDenoiser : public DiscreteSchedule {
struct CompVisDenoiser : public Denoiser {
float sigma_data = 1.0f;
std::vector<float> get_scalings(float sigma) {
@ -2728,7 +2756,7 @@ struct CompVisDenoiser : public DiscreteSchedule {
}
};
struct CompVisVDenoiser : public DiscreteSchedule {
struct CompVisVDenoiser : public Denoiser {
float sigma_data = 1.0f;
std::vector<float> get_scalings(float sigma) {
@ -2764,7 +2792,7 @@ class StableDiffusionGGML {
UNetModel diffusion_model;
AutoEncoderKL first_stage_model;
std::shared_ptr<DiscreteSchedule> denoiser = std::make_shared<CompVisDenoiser>();
std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();
StableDiffusionGGML() = default;
@ -2798,7 +2826,7 @@ class StableDiffusionGGML {
}
}
bool load_from_file(const std::string& file_path) {
bool load_from_file(const std::string& file_path, Schedule schedule) {
LOG_INFO("loading model from '%s'", file_path.c_str());
std::ifstream file(file_path, std::ios::binary);
@ -3093,10 +3121,29 @@ class StableDiffusionGGML {
LOG_INFO("running in eps-prediction mode");
}
if (schedule != DEFAULT) {
switch (schedule) {
case DISCRETE:
LOG_INFO("running with discrete schedule");
denoiser->schedule = std::make_shared<DiscreteSchedule>();
break;
case KARRAS:
LOG_INFO("running with Karras schedule");
denoiser->schedule = std::make_shared<KarrasSchedule>();
break;
case DEFAULT:
// Don't touch anything.
break;
default:
LOG_ERROR("Unknown schedule %i", schedule);
abort();
}
}
for (int i = 0; i < TIMESTEPS; i++) {
denoiser->alphas_cumprod[i] = alphas_cumprod[i];
denoiser->sigmas[i] = std::sqrt((1 - denoiser->alphas_cumprod[i]) / denoiser->alphas_cumprod[i]);
denoiser->log_sigmas[i] = std::log(denoiser->sigmas[i]);
denoiser->schedule->alphas_cumprod[i] = alphas_cumprod[i];
denoiser->schedule->sigmas[i] = std::sqrt((1 - denoiser->schedule->alphas_cumprod[i]) / denoiser->schedule->alphas_cumprod[i]);
denoiser->schedule->log_sigmas[i] = std::log(denoiser->schedule->sigmas[i]);
}
return true;
@ -3445,7 +3492,7 @@ class StableDiffusionGGML {
c_in = scaling[1];
}
float t = denoiser->sigma_to_t(sigma);
float t = denoiser->schedule->sigma_to_t(sigma);
ggml_set_f32(timesteps, t);
set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);
@ -4010,8 +4057,8 @@ StableDiffusion::StableDiffusion(int n_threads,
rng_type);
}
bool StableDiffusion::load_from_file(const std::string& file_path) {
return sd->load_from_file(file_path);
bool StableDiffusion::load_from_file(const std::string& file_path, Schedule s) {
return sd->load_from_file(file_path, s);
}
std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
@ -4061,7 +4108,7 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
struct ggml_tensor* x_t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, W, H, C, 1);
ggml_tensor_set_f32_randn(x_t, sd->rng);
std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);
LOG_INFO("start sampling");
struct ggml_tensor* x_0 = sd->sample(ctx, x_t, c, uc, cfg_scale, sample_method, sigmas);
@ -4117,7 +4164,7 @@ std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_i
}
LOG_INFO("img2img %dx%d", width, height);
std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);
size_t t_enc = static_cast<size_t>(sample_steps * strength);
LOG_INFO("target t_enc is %zu steps", t_enc);
std::vector<float> sigma_sched;

View File

@ -25,6 +25,13 @@ enum SampleMethod {
N_SAMPLE_METHODS
};
// Sigma schedule override that can be requested when loading a model.
enum Schedule {
    DEFAULT = 0,  // no override: keep whatever schedule the denoiser already has
    DISCRETE,     // use the discrete schedule
    KARRAS,       // use the Karras schedule
    N_SCHEDULES   // number of schedules above; not a selectable value
};
class StableDiffusionGGML;
class StableDiffusion {
@ -36,7 +43,7 @@ class StableDiffusion {
bool vae_decode_only = false,
bool free_params_immediately = false,
RNGType rng_type = STD_DEFAULT_RNG);
bool load_from_file(const std::string& file_path);
bool load_from_file(const std::string& file_path, Schedule d = DEFAULT);
std::vector<uint8_t> txt2img(
const std::string& prompt,
const std::string& negative_prompt,