feat: add sd3.5 medium and skip layer guidance support (#451)

* mmdit-x

* add support for sd3.5 medium

* add skip layer guidance support (mmdit only)

* ignore slg if slg_scale is zero (optimization)

* init out_skip once

* slg support for flux (experimental)

* warn if version doesn't support slg

* refactor slg cli args

* set default slg_scale to 0 (oops)

* format code

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
stduhpf
2024-11-23 04:15:31 +01:00
committed by GitHub
parent ac54e00760
commit 65fa646684
9 changed files with 414 additions and 79 deletions

View File

@@ -119,6 +119,11 @@ struct SDParams {
bool canny_preprocess = false;
bool color = false;
int upscale_repeats = 1;
std::vector<int> skip_layers = {7, 8, 9};
float slg_scale = 0.;
float skip_layer_start = 0.01;
float skip_layer_end = 0.2;
};
void print_params(SDParams params) {
@@ -151,6 +156,7 @@ void print_params(SDParams params) {
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
printf(" min_cfg: %.2f\n", params.min_cfg);
printf(" cfg_scale: %.2f\n", params.cfg_scale);
printf(" slg_scale: %.2f\n", params.slg_scale);
printf(" guidance: %.2f\n", params.guidance);
printf(" clip_skip: %d\n", params.clip_skip);
printf(" width: %d\n", params.width);
@@ -197,6 +203,12 @@ void print_usage(int argc, const char* argv[]) {
printf(" -p, --prompt [PROMPT] the prompt to render\n");
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
// Skip layer guidance (SLG) options. The flag spellings printed here must match
// the strings compared in parse_args, which uses hyphens, not underscores.
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n");
printf(" --skip-layer-start START SLG enabling point: (default: 0.01)\n");
printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n");
printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n");
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n");
printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n");
@@ -534,6 +546,61 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.verbose = true;
} else if (arg == "--color") {
params.color = true;
} else if (arg == "--slg-scale") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.slg_scale = std::stof(argv[i]);
} else if (arg == "--skip-layers") {
    // Expects a bracketed list of layer indices, e.g. "[7,8,9]". The list may
    // arrive as one argv entry or split over several (an unquoted `[7, 8, 9]`).
    if (++i >= argc) {
        invalid_arg = true;
        break;
    }
    if (argv[i][0] != '[') {
        invalid_arg = true;
        break;
    }
    std::string layers_str = argv[i];
    // Keep appending argv entries until the closing bracket shows up.
    while (layers_str.back() != ']') {
        if (++i >= argc) {
            invalid_arg = true;
            break;  // only leaves this while loop; handled right below
        }
        layers_str += " " + std::string(argv[i]);
    }
    if (invalid_arg) {
        // The list was never closed: report it instead of parsing a truncated list.
        break;
    }
    // Drop the surrounding '[' and ']'.
    layers_str = layers_str.substr(1, layers_str.size() - 2);
    std::regex regex("[, ]+");
    std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1);
    std::sregex_token_iterator end;
    std::vector<int> layers;
    for (; iter != end; ++iter) {
        const std::string token = iter->str();
        if (token.empty()) {
            // A leading separator (e.g. "[ 7, 8 ]") yields an empty first token.
            continue;
        }
        try {
            std::size_t consumed = 0;
            const int layer      = std::stoi(token, &consumed);
            if (consumed != token.size()) {
                // Trailing garbage such as "7x" is not a valid layer index.
                invalid_arg = true;
                break;
            }
            layers.push_back(layer);
        } catch (const std::exception&) {
            // std::stoi throws std::invalid_argument or std::out_of_range.
            invalid_arg = true;
            break;
        }
    }
    if (invalid_arg) {
        break;
    }
    // Only overwrite the defaults once the whole list has parsed cleanly.
    params.skip_layers = layers;
} else if (arg == "--skip-layer-start") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.skip_layer_start = std::stof(argv[i]);
} else if (arg == "--skip-layer-end") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.skip_layer_end = std::stof(argv[i]);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
print_usage(argc, argv);
@@ -624,6 +691,16 @@ std::string get_image_params(SDParams params, int64_t seed) {
}
parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", ";
parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", ";
// Record the skip layer guidance (SLG) settings only when SLG is in effect,
// i.e. the scale is non-zero and at least one layer is listed.
if (params.slg_scale != 0 && params.skip_layers.size() != 0) {
    parameter_string += "SLG scale: " + std::to_string(params.slg_scale) + ", ";
    parameter_string += "Skip layers: [";
    bool first_layer = true;
    for (const auto& layer : params.skip_layers) {
        // Separator goes between entries only, so the list reads "[7, 8, 9]".
        if (!first_layer) {
            parameter_string += ", ";
        }
        parameter_string += std::to_string(layer);
        first_layer = false;
    }
    parameter_string += "], ";
    parameter_string += "Skip layer start: " + std::to_string(params.skip_layer_start) + ", ";
    parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", ";
}
parameter_string += "Guidance: " + std::to_string(params.guidance) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
@@ -840,7 +917,11 @@ int main(int argc, const char* argv[]) {
params.control_strength,
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str());
params.input_id_images_path.c_str(),
params.skip_layers,
params.slg_scale,
params.skip_layer_start,
params.skip_layer_end);
} else {
sd_image_t input_image = {(uint32_t)params.width,
(uint32_t)params.height,