fix: support more SDXL LoRA names (#216)
* apply pmid lora only once for multiple txt2img calls * add better support for SDXL LoRA * fix for some sdxl lora, like lcm-lora-xl --------- Co-authored-by: bssrdf <bssrdf@gmail.com> Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
parent
646e77638e
commit
afea457eda
@ -688,7 +688,7 @@ int main(int argc, const char* argv[]) {
|
|||||||
int resized_height = params.height + (64 - params.height % 64);
|
int resized_height = params.height + (64 - params.height % 64);
|
||||||
int resized_width = params.width + (64 - params.width % 64);
|
int resized_width = params.width + (64 - params.width % 64);
|
||||||
|
|
||||||
uint8_t *resized_image_buffer = (uint8_t *)malloc(resized_height * resized_width * 3);
|
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3);
|
||||||
if (resized_image_buffer == NULL) {
|
if (resized_image_buffer == NULL) {
|
||||||
fprintf(stderr, "error: allocate memory for resize input image\n");
|
fprintf(stderr, "error: allocate memory for resize input image\n");
|
||||||
free(input_image_buffer);
|
free(input_image_buffer);
|
||||||
@ -699,8 +699,7 @@ int main(int argc, const char* argv[]) {
|
|||||||
3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
|
3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
|
||||||
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
|
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
|
||||||
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
|
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
|
||||||
STBIR_COLORSPACE_SRGB, nullptr
|
STBIR_COLORSPACE_SRGB, nullptr);
|
||||||
);
|
|
||||||
|
|
||||||
// Save resized result
|
// Save resized result
|
||||||
free(input_image_buffer);
|
free(input_image_buffer);
|
||||||
|
11
lora.hpp
11
lora.hpp
@ -91,10 +91,15 @@ struct LoraModel : public GGMLModule {
|
|||||||
k_tensor = k_tensor.substr(0, k_pos);
|
k_tensor = k_tensor.substr(0, k_pos);
|
||||||
replace_all_chars(k_tensor, '.', '_');
|
replace_all_chars(k_tensor, '.', '_');
|
||||||
// LOG_DEBUG("k_tensor %s", k_tensor.c_str());
|
// LOG_DEBUG("k_tensor %s", k_tensor.c_str());
|
||||||
if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") { // fix for SDXL
|
|
||||||
k_tensor = "model_diffusion_model_output_blocks_2_1_conv";
|
|
||||||
}
|
|
||||||
std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight";
|
std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight";
|
||||||
|
if (lora_tensors.find(lora_up_name) == lora_tensors.end()) {
|
||||||
|
if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") {
|
||||||
|
// fix for some sdxl lora, like lcm-lora-xl
|
||||||
|
k_tensor = "model_diffusion_model_output_blocks_2_1_conv";
|
||||||
|
lora_up_name = "lora." + k_tensor + ".lora_up.weight";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::string lora_down_name = "lora." + k_tensor + ".lora_down.weight";
|
std::string lora_down_name = "lora." + k_tensor + ".lora_down.weight";
|
||||||
std::string alpha_name = "lora." + k_tensor + ".alpha";
|
std::string alpha_name = "lora." + k_tensor + ".alpha";
|
||||||
std::string scale_name = "lora." + k_tensor + ".scale";
|
std::string scale_name = "lora." + k_tensor + ".scale";
|
||||||
|
17
model.cpp
17
model.cpp
@ -211,6 +211,8 @@ std::string convert_sdxl_lora_name(std::string tensor_name) {
|
|||||||
{"unet", "model_diffusion_model"},
|
{"unet", "model_diffusion_model"},
|
||||||
{"te2", "cond_stage_model_1_transformer"},
|
{"te2", "cond_stage_model_1_transformer"},
|
||||||
{"te1", "cond_stage_model_transformer"},
|
{"te1", "cond_stage_model_transformer"},
|
||||||
|
{"text_encoder_2", "cond_stage_model_1_transformer"},
|
||||||
|
{"text_encoder", "cond_stage_model_transformer"},
|
||||||
};
|
};
|
||||||
for (auto& pair_i : sdxl_lora_name_lookup) {
|
for (auto& pair_i : sdxl_lora_name_lookup) {
|
||||||
if (tensor_name.compare(0, pair_i.first.length(), pair_i.first) == 0) {
|
if (tensor_name.compare(0, pair_i.first.length(), pair_i.first) == 0) {
|
||||||
@ -446,18 +448,25 @@ std::string convert_tensor_name(const std::string& name) {
|
|||||||
} else {
|
} else {
|
||||||
new_name = name;
|
new_name = name;
|
||||||
}
|
}
|
||||||
} else if (contains(name, "lora_up") || contains(name, "lora_down") || contains(name, "lora.up") || contains(name, "lora.down")) {
|
} else if (contains(name, "lora_up") || contains(name, "lora_down") ||
|
||||||
|
contains(name, "lora.up") || contains(name, "lora.down") ||
|
||||||
|
contains(name, "lora_linear")) {
|
||||||
size_t pos = new_name.find(".processor");
|
size_t pos = new_name.find(".processor");
|
||||||
if (pos != std::string::npos) {
|
if (pos != std::string::npos) {
|
||||||
new_name.replace(pos, strlen(".processor"), "");
|
new_name.replace(pos, strlen(".processor"), "");
|
||||||
}
|
}
|
||||||
pos = new_name.find_last_of('_');
|
pos = new_name.rfind("lora");
|
||||||
if (pos != std::string::npos) {
|
if (pos != std::string::npos) {
|
||||||
std::string name_without_network_parts = new_name.substr(0, pos);
|
std::string name_without_network_parts = new_name.substr(0, pos - 1);
|
||||||
std::string network_part = new_name.substr(pos + 1);
|
std::string network_part = new_name.substr(pos);
|
||||||
// LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str());
|
// LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str());
|
||||||
std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.');
|
std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.');
|
||||||
|
new_key = convert_sdxl_lora_name(new_key);
|
||||||
replace_all_chars(new_key, '.', '_');
|
replace_all_chars(new_key, '.', '_');
|
||||||
|
size_t npos = network_part.rfind("_linear_layer");
|
||||||
|
if (npos != std::string::npos) {
|
||||||
|
network_part.replace(npos, strlen("_linear_layer"), "");
|
||||||
|
}
|
||||||
if (starts_with(network_part, "lora.")) {
|
if (starts_with(network_part, "lora.")) {
|
||||||
network_part = "lora_" + network_part.substr(5);
|
network_part = "lora_" + network_part.substr(5);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user