feat: add option to switch the sigma schedule (#51)
Concretely, this allows switching to the "Karras" schedule from the Karras et al 2022 paper, equivalent to the samplers marked as "Karras" in the AUTOMATIC1111 WebUI. This choice is in principle orthogonal to the sampler choice and can be given independently.
This commit is contained in:
parent
b6899e8fc2
commit
968fbf02aa
@ -80,6 +80,12 @@ const char* sample_method_str[] = {
|
||||
"dpm++2m",
|
||||
"dpm++2mv2"};
|
||||
|
||||
// Names of the sigma schedule overrides, same order as Schedule in stable-diffusion.h
|
||||
const char* schedule_str[] = {
|
||||
"default",
|
||||
"discrete",
|
||||
"karras"};
|
||||
|
||||
struct Option {
|
||||
int n_threads = -1;
|
||||
std::string mode = TXT2IMG;
|
||||
@ -92,6 +98,7 @@ struct Option {
|
||||
int w = 512;
|
||||
int h = 512;
|
||||
SampleMethod sample_method = EULER_A;
|
||||
Schedule schedule = DEFAULT;
|
||||
int sample_steps = 20;
|
||||
float strength = 0.75f;
|
||||
RNGType rng_type = CUDA_RNG;
|
||||
@ -111,6 +118,7 @@ struct Option {
|
||||
printf(" width: %d\n", w);
|
||||
printf(" height: %d\n", h);
|
||||
printf(" sample_method: %s\n", sample_method_str[sample_method]);
|
||||
printf(" schedule: %s\n", schedule_str[schedule]);
|
||||
printf(" sample_steps: %d\n", sample_steps);
|
||||
printf(" strength: %.2f\n", strength);
|
||||
printf(" rng: %s\n", rng_type_to_str[rng_type]);
|
||||
@ -141,6 +149,7 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" --steps STEPS number of sample steps (default: 20)\n");
|
||||
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
|
||||
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
|
||||
printf(" --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)\n");
|
||||
printf(" -v, --verbose print extra info\n");
|
||||
}
|
||||
|
||||
@ -237,6 +246,23 @@ void parse_args(int argc, const char* argv[], Option* opt) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
} else if (arg == "--schedule") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
const char* schedule_selected = argv[i];
|
||||
int schedule_found = -1;
|
||||
for (int d = 0; d < N_SCHEDULES; d++) {
|
||||
if (!strcmp(schedule_selected, schedule_str[d])) {
|
||||
schedule_found = d;
|
||||
}
|
||||
}
|
||||
if (schedule_found == -1) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
}
|
||||
opt->schedule = (Schedule)schedule_found;
|
||||
} else if (arg == "-s" || arg == "--seed") {
|
||||
if (++i >= argc) {
|
||||
invalid_arg = true;
|
||||
@ -377,7 +403,7 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
|
||||
StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type);
|
||||
if (!sd.load_from_file(opt.model_path)) {
|
||||
if (!sd.load_from_file(opt.model_path, opt.schedule)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@ -413,4 +439,4 @@ int main(int argc, const char* argv[]) {
|
||||
printf("save result image to '%s'\n", opt.output_path.c_str());
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
@ -2654,32 +2654,12 @@ struct AutoEncoderKL {
|
||||
|
||||
// Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py
|
||||
|
||||
struct DiscreteSchedule {
|
||||
struct SigmaSchedule {
|
||||
float alphas_cumprod[TIMESTEPS];
|
||||
float sigmas[TIMESTEPS];
|
||||
float log_sigmas[TIMESTEPS];
|
||||
|
||||
std::vector<float> get_sigmas(uint32_t n) {
|
||||
std::vector<float> result;
|
||||
|
||||
int t_max = TIMESTEPS - 1;
|
||||
|
||||
if (n == 0) {
|
||||
return result;
|
||||
} else if (n == 1) {
|
||||
result.push_back(t_to_sigma(t_max));
|
||||
result.push_back(0);
|
||||
return result;
|
||||
}
|
||||
|
||||
float step = static_cast<float>(t_max) / static_cast<float>(n - 1);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
float t = t_max - step * i;
|
||||
result.push_back(t_to_sigma(t));
|
||||
}
|
||||
result.push_back(0);
|
||||
return result;
|
||||
}
|
||||
virtual std::vector<float> get_sigmas(uint32_t n) = 0;
|
||||
|
||||
float sigma_to_t(float sigma) {
|
||||
float log_sigma = std::log(sigma);
|
||||
@ -2714,11 +2694,59 @@ struct DiscreteSchedule {
|
||||
float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
|
||||
return std::exp(log_sigma);
|
||||
}
|
||||
};
|
||||
|
||||
struct DiscreteSchedule : SigmaSchedule {
|
||||
std::vector<float> get_sigmas(uint32_t n) {
|
||||
std::vector<float> result;
|
||||
|
||||
int t_max = TIMESTEPS - 1;
|
||||
|
||||
if (n == 0) {
|
||||
return result;
|
||||
} else if (n == 1) {
|
||||
result.push_back(t_to_sigma(t_max));
|
||||
result.push_back(0);
|
||||
return result;
|
||||
}
|
||||
|
||||
float step = static_cast<float>(t_max) / static_cast<float>(n - 1);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
float t = t_max - step * i;
|
||||
result.push_back(t_to_sigma(t));
|
||||
}
|
||||
result.push_back(0);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
struct KarrasSchedule : SigmaSchedule {
|
||||
std::vector<float> get_sigmas(uint32_t n) {
|
||||
// These *COULD* be function arguments here,
|
||||
// but does anybody ever bother to touch them?
|
||||
float sigma_min = 0.1;
|
||||
float sigma_max = 10.;
|
||||
float rho = 7.;
|
||||
|
||||
std::vector<float> result(n + 1);
|
||||
|
||||
float min_inv_rho = pow(sigma_min, (1. / rho));
|
||||
float max_inv_rho = pow(sigma_max, (1. / rho));
|
||||
for (int i = 0; i < n; i++) {
|
||||
// Eq. (5) from Karras et al 2022
|
||||
result[i] = pow(max_inv_rho + (float)i / ((float)n - 1.) * (min_inv_rho - max_inv_rho), rho);
|
||||
}
|
||||
result[n] = 0.;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
struct Denoiser {
|
||||
std::shared_ptr<SigmaSchedule> schedule = std::make_shared<DiscreteSchedule>();
|
||||
virtual std::vector<float> get_scalings(float sigma) = 0;
|
||||
};
|
||||
|
||||
struct CompVisDenoiser : public DiscreteSchedule {
|
||||
struct CompVisDenoiser : public Denoiser {
|
||||
float sigma_data = 1.0f;
|
||||
|
||||
std::vector<float> get_scalings(float sigma) {
|
||||
@ -2728,7 +2756,7 @@ struct CompVisDenoiser : public DiscreteSchedule {
|
||||
}
|
||||
};
|
||||
|
||||
struct CompVisVDenoiser : public DiscreteSchedule {
|
||||
struct CompVisVDenoiser : public Denoiser {
|
||||
float sigma_data = 1.0f;
|
||||
|
||||
std::vector<float> get_scalings(float sigma) {
|
||||
@ -2764,7 +2792,7 @@ class StableDiffusionGGML {
|
||||
UNetModel diffusion_model;
|
||||
AutoEncoderKL first_stage_model;
|
||||
|
||||
std::shared_ptr<DiscreteSchedule> denoiser = std::make_shared<CompVisDenoiser>();
|
||||
std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();
|
||||
|
||||
StableDiffusionGGML() = default;
|
||||
|
||||
@ -2798,7 +2826,7 @@ class StableDiffusionGGML {
|
||||
}
|
||||
}
|
||||
|
||||
bool load_from_file(const std::string& file_path) {
|
||||
bool load_from_file(const std::string& file_path, Schedule schedule) {
|
||||
LOG_INFO("loading model from '%s'", file_path.c_str());
|
||||
|
||||
std::ifstream file(file_path, std::ios::binary);
|
||||
@ -3093,10 +3121,29 @@ class StableDiffusionGGML {
|
||||
LOG_INFO("running in eps-prediction mode");
|
||||
}
|
||||
|
||||
if (schedule != DEFAULT) {
|
||||
switch (schedule) {
|
||||
case DISCRETE:
|
||||
LOG_INFO("running with discrete schedule");
|
||||
denoiser->schedule = std::make_shared<DiscreteSchedule>();
|
||||
break;
|
||||
case KARRAS:
|
||||
LOG_INFO("running with Karras schedule");
|
||||
denoiser->schedule = std::make_shared<KarrasSchedule>();
|
||||
break;
|
||||
case DEFAULT:
|
||||
// Don't touch anything.
|
||||
break;
|
||||
default:
|
||||
LOG_ERROR("Unknown schedule %i", schedule);
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < TIMESTEPS; i++) {
|
||||
denoiser->alphas_cumprod[i] = alphas_cumprod[i];
|
||||
denoiser->sigmas[i] = std::sqrt((1 - denoiser->alphas_cumprod[i]) / denoiser->alphas_cumprod[i]);
|
||||
denoiser->log_sigmas[i] = std::log(denoiser->sigmas[i]);
|
||||
denoiser->schedule->alphas_cumprod[i] = alphas_cumprod[i];
|
||||
denoiser->schedule->sigmas[i] = std::sqrt((1 - denoiser->schedule->alphas_cumprod[i]) / denoiser->schedule->alphas_cumprod[i]);
|
||||
denoiser->schedule->log_sigmas[i] = std::log(denoiser->schedule->sigmas[i]);
|
||||
}
|
||||
|
||||
return true;
|
||||
@ -3445,7 +3492,7 @@ class StableDiffusionGGML {
|
||||
c_in = scaling[1];
|
||||
}
|
||||
|
||||
float t = denoiser->sigma_to_t(sigma);
|
||||
float t = denoiser->schedule->sigma_to_t(sigma);
|
||||
ggml_set_f32(timesteps, t);
|
||||
set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);
|
||||
|
||||
@ -4010,8 +4057,8 @@ StableDiffusion::StableDiffusion(int n_threads,
|
||||
rng_type);
|
||||
}
|
||||
|
||||
bool StableDiffusion::load_from_file(const std::string& file_path) {
|
||||
return sd->load_from_file(file_path);
|
||||
bool StableDiffusion::load_from_file(const std::string& file_path, Schedule s) {
|
||||
return sd->load_from_file(file_path, s);
|
||||
}
|
||||
|
||||
std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
|
||||
@ -4061,7 +4108,7 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
|
||||
struct ggml_tensor* x_t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, W, H, C, 1);
|
||||
ggml_tensor_set_f32_randn(x_t, sd->rng);
|
||||
|
||||
std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
|
||||
std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);
|
||||
|
||||
LOG_INFO("start sampling");
|
||||
struct ggml_tensor* x_0 = sd->sample(ctx, x_t, c, uc, cfg_scale, sample_method, sigmas);
|
||||
@ -4117,7 +4164,7 @@ std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_i
|
||||
}
|
||||
LOG_INFO("img2img %dx%d", width, height);
|
||||
|
||||
std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
|
||||
std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);
|
||||
size_t t_enc = static_cast<size_t>(sample_steps * strength);
|
||||
LOG_INFO("target t_enc is %zu steps", t_enc);
|
||||
std::vector<float> sigma_sched;
|
||||
|
@ -25,6 +25,13 @@ enum SampleMethod {
|
||||
N_SAMPLE_METHODS
|
||||
};
|
||||
|
||||
enum Schedule {
|
||||
DEFAULT,
|
||||
DISCRETE,
|
||||
KARRAS,
|
||||
N_SCHEDULES
|
||||
};
|
||||
|
||||
class StableDiffusionGGML;
|
||||
|
||||
class StableDiffusion {
|
||||
@ -36,7 +43,7 @@ class StableDiffusion {
|
||||
bool vae_decode_only = false,
|
||||
bool free_params_immediately = false,
|
||||
RNGType rng_type = STD_DEFAULT_RNG);
|
||||
bool load_from_file(const std::string& file_path);
|
||||
bool load_from_file(const std::string& file_path, Schedule d = DEFAULT);
|
||||
std::vector<uint8_t> txt2img(
|
||||
const std::string& prompt,
|
||||
const std::string& negative_prompt,
|
||||
|
Loading…
Reference in New Issue
Block a user