fix: preprocess tensor names in tensor types map (#607)

stduhpf 2025-03-01 04:48:04 +01:00 committed by GitHub
parent fbd42b6fc1
commit 85e9a12988


@@ -558,6 +558,26 @@ std::string convert_tensor_name(std::string name) {
     return new_name;
 }
 
+void add_preprocess_tensor_storage_types(std::map<std::string, enum ggml_type>& tensor_storages_types, std::string name, enum ggml_type type) {
+    std::string new_name = convert_tensor_name(name);
+
+    if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) {
+        size_t prefix_size = new_name.find("attn.in_proj_weight");
+        std::string prefix = new_name.substr(0, prefix_size);
+        tensor_storages_types[prefix + "self_attn.q_proj.weight"] = type;
+        tensor_storages_types[prefix + "self_attn.k_proj.weight"] = type;
+        tensor_storages_types[prefix + "self_attn.v_proj.weight"] = type;
+    } else if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_bias")) {
+        size_t prefix_size = new_name.find("attn.in_proj_bias");
+        std::string prefix = new_name.substr(0, prefix_size);
+        tensor_storages_types[prefix + "self_attn.q_proj.bias"] = type;
+        tensor_storages_types[prefix + "self_attn.k_proj.bias"] = type;
+        tensor_storages_types[prefix + "self_attn.v_proj.bias"] = type;
+    } else {
+        tensor_storages_types[new_name] = type;
+    }
+}
+
 void preprocess_tensor(TensorStorage tensor_storage,
                        std::vector<TensorStorage>& processed_tensor_storages) {
     std::vector<TensorStorage> result;
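
The helper above mirrors, in the type map, the same q/k/v split the loader applies to fused OpenCLIP attention tensors, so later lookups by converted name find an entry. A minimal usage sketch (the tensor name is illustrative of the OpenCLIP checkpoint layout; the exact converted prefix depends on convert_tensor_name):

    // Hedged sketch: one fused attention weight yields three map entries.
    std::map<std::string, enum ggml_type> types;
    add_preprocess_tensor_storage_types(
        types,
        "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight",
        GGML_TYPE_F16);
    // types now maps "<converted prefix>self_attn.q_proj.weight",
    // "...k_proj.weight", and "...v_proj.weight" to GGML_TYPE_F16, matching
    // the names created when the fused tensor is split during preprocessing.

Names that are not fused CLIP attention tensors fall through to the single-entry assignment in the else branch.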
@@ -927,7 +947,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
         GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
 
         tensor_storages.push_back(tensor_storage);
-        tensor_storages_types[tensor_storage.name] = tensor_storage.type;
+        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
     }
 
     gguf_free(ctx_gguf_);
@@ -1072,7 +1092,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         }
 
         tensor_storages.push_back(tensor_storage);
-        tensor_storages_types[tensor_storage.name] = tensor_storage.type;
+        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
 
         // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
     }
@@ -1403,7 +1423,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
                 // printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
                 reader.tensor_storage.name = prefix + reader.tensor_storage.name;
                 tensor_storages.push_back(reader.tensor_storage);
-                tensor_storages_types[reader.tensor_storage.name] = reader.tensor_storage.type;
+                add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);
 
                 // LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
                 // reset
@@ -1461,10 +1481,10 @@ SDVersion ModelLoader::get_sd_version() {
     TensorStorage token_embedding_weight, input_block_weight;
     bool input_block_checked = false;
 
     bool has_multiple_encoders = false;
     bool is_unet = false;
 
     bool is_xl = false;
     bool is_flux = false;
 
 #define found_family (is_xl || is_flux)
@@ -1481,7 +1501,7 @@ SDVersion ModelLoader::get_sd_version() {
         }
         if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
             is_unet = true;
-            if(has_multiple_encoders){
+            if (has_multiple_encoders) {
                 is_xl = true;
                 if (input_block_checked) {
                     break;
@@ -1490,7 +1510,7 @@ SDVersion ModelLoader::get_sd_version() {
        }
         if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
             has_multiple_encoders = true;
-            if(is_unet){
+            if (is_unet) {
                 is_xl = true;
                 if (input_block_checked) {
                     break;
@@ -1635,11 +1655,20 @@ ggml_type ModelLoader::get_vae_wtype() {
 void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
     for (auto& pair : tensor_storages_types) {
         if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) {
+            bool found = false;
             for (auto& tensor_storage : tensor_storages) {
-                if (tensor_storage.name == pair.first) {
-                    if (tensor_should_be_converted(tensor_storage, wtype)) {
-                        pair.second = wtype;
+                std::map<std::string, ggml_type> temp;
+                add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type);
+                for (auto& preprocessed_name : temp) {
+                    if (preprocessed_name.first == pair.first) {
+                        if (tensor_should_be_converted(tensor_storage, wtype)) {
+                            pair.second = wtype;
+                        }
+                        found = true;
+                        break;
                     }
+                }
+                if (found) {
                     break;
                 }
             }
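
With the loop above, a type override is matched against the preprocessed names of each stored tensor, so a prefix-based override also reaches the split q/k/v entries that fused in_proj tensors produce; the found flag keeps the original early exit once a stored tensor accounts for the map entry. A hedged usage sketch (the loader variable and prefix string are illustrative):

    // Hypothetical call: quantize all text-encoder weights to q8_0. Fused
    // in_proj tensors are now matched via their split q/k/v names instead of
    // being skipped because no stored tensor carries the split name directly.
    model_loader.set_wtype_override(GGML_TYPE_Q8_0, "cond_stage_model.");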