fix: allow model and vae using different format
This commit is contained in:
parent
d7af2c2ba9
commit
8a87b273ad
486
model.cpp
486
model.cpp
@ -534,232 +534,29 @@ std::map<char, int> unicode_to_byte() {
|
||||
}
|
||||
|
||||
bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
|
||||
file_paths_.push_back(file_path);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ModelLoader::init_from_files(const std::vector<std::string>& file_paths) {
|
||||
for (auto& file_path : file_paths) {
|
||||
if (!init_from_file(file_path)) {
|
||||
if (is_directory(file_path)) {
|
||||
LOG_INFO("load %s using diffusers format", file_path.c_str());
|
||||
return init_from_diffusers_file(file_path, prefix);
|
||||
} else if (ends_with(file_path, ".gguf")) {
|
||||
LOG_INFO("load %s using gguf format", file_path.c_str());
|
||||
return init_from_gguf_file(file_path, prefix);
|
||||
} else if (ends_with(file_path, ".safetensors")) {
|
||||
LOG_INFO("load %s using safetensors format", file_path.c_str());
|
||||
return init_from_safetensors_file(file_path, prefix);
|
||||
} else if (ends_with(file_path, ".ckpt")) {
|
||||
LOG_INFO("load %s using checkpoint format", file_path.c_str());
|
||||
return init_from_ckpt_file(file_path, prefix);
|
||||
} else {
|
||||
LOG_WARN("unknown format %s", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
SDVersion ModelLoader::get_sd_version() {
|
||||
TensorStorage token_embedding_weight;
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
|
||||
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
|
||||
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
|
||||
tensor_storage.name == "te.text_model.embeddings.token_embedding.weight") {
|
||||
token_embedding_weight = tensor_storage;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (token_embedding_weight.ne[0] == 768) {
|
||||
return VERSION_1_x;
|
||||
} else if (token_embedding_weight.ne[0] == 1024) {
|
||||
return VERSION_2_x;
|
||||
}
|
||||
return VERSION_COUNT;
|
||||
}
|
||||
|
||||
ggml_type ModelLoader::get_sd_wtype() {
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (is_unused_tensor(tensor_storage.name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (tensor_storage.name.find(".weight") != std::string::npos &&
|
||||
tensor_storage.name.find("time_embed") != std::string::npos) {
|
||||
return tensor_storage.type;
|
||||
}
|
||||
}
|
||||
return GGML_TYPE_COUNT;
|
||||
}
|
||||
|
||||
bool ModelLoader::load_vocab(on_new_token_cb_t on_new_token_cb) {
|
||||
char* vocab_buffer = reinterpret_cast<char*>(vocab_json);
|
||||
nlohmann::json vocab = nlohmann::json::parse(vocab_buffer);
|
||||
std::map<char, int> decoder = unicode_to_byte();
|
||||
for (auto& it : vocab.items()) {
|
||||
int token_id = it.value();
|
||||
std::string token_str = it.key();
|
||||
std::string token = "";
|
||||
for (char c : token_str) {
|
||||
token += decoder[c];
|
||||
}
|
||||
on_new_token_cb(token, token_id);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
|
||||
bool success = true;
|
||||
for (size_t file_index = 0; file_index < file_paths_.size(); file_index++) {
|
||||
std::string file_path = file_paths_[file_index];
|
||||
LOG_DEBUG("loading tensors from %s", file_path.c_str());
|
||||
|
||||
std::ifstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
LOG_ERROR("failed to open '%s'", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_zip = false;
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (tensor_storage.index_in_zip >= 0) {
|
||||
is_zip = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
struct zip_t* zip = NULL;
|
||||
if (is_zip) {
|
||||
zip = zip_open(file_path.c_str(), 0, 'r');
|
||||
if (zip == NULL) {
|
||||
LOG_ERROR("failed to open zip '%s'", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint8_t> read_buffer;
|
||||
std::vector<uint8_t> convert_buffer;
|
||||
|
||||
auto read_data = [&](const TensorStorage& tensor_storage, char* buf, size_t n) {
|
||||
if (zip != NULL) {
|
||||
zip_entry_openbyindex(zip, tensor_storage.index_in_zip);
|
||||
size_t entry_size = zip_entry_size(zip);
|
||||
if (entry_size != n) {
|
||||
read_buffer.resize(entry_size);
|
||||
zip_entry_noallocread(zip, (void*)read_buffer.data(), entry_size);
|
||||
memcpy((void*)buf, (void*)(read_buffer.data() + tensor_storage.offset), n);
|
||||
} else {
|
||||
zip_entry_noallocread(zip, (void*)buf, n);
|
||||
}
|
||||
zip_entry_close(zip);
|
||||
} else {
|
||||
file.seekg(tensor_storage.offset);
|
||||
file.read(buf, n);
|
||||
if (!file) {
|
||||
LOG_ERROR("read tensor data failed: '%s'", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
std::vector<TensorStorage> processed_tensor_storages;
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (tensor_storage.file_index != file_index) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// LOG_DEBUG("%s", name.c_str());
|
||||
|
||||
if (is_unused_tensor(tensor_storage.name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
preprocess_tensor(tensor_storage, processed_tensor_storages);
|
||||
}
|
||||
|
||||
for (auto& tensor_storage : processed_tensor_storages) {
|
||||
// LOG_DEBUG("%s", name.c_str());
|
||||
|
||||
ggml_tensor* dst_tensor = NULL;
|
||||
|
||||
success = on_new_tensor_cb(tensor_storage, &dst_tensor);
|
||||
if (!success) {
|
||||
LOG_WARN("process tensor failed: '%s'", tensor_storage.name.c_str());
|
||||
break;
|
||||
}
|
||||
|
||||
if (dst_tensor == NULL) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_backend_t backend = ggml_get_backend(dst_tensor);
|
||||
|
||||
size_t nbytes_to_read = tensor_storage.nbytes_to_read();
|
||||
|
||||
if (backend == NULL || ggml_backend_is_cpu(backend)) {
|
||||
// for the CPU and Metal backend, we can copy directly into the tensor
|
||||
if (tensor_storage.type == dst_tensor->type) {
|
||||
GGML_ASSERT(ggml_nbytes(dst_tensor) == nbytes_to_read);
|
||||
read_data(tensor_storage, (char*)dst_tensor->data, nbytes_to_read);
|
||||
|
||||
if (tensor_storage.is_bf16) {
|
||||
// inplace op
|
||||
bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
|
||||
}
|
||||
} else {
|
||||
read_buffer.resize(tensor_storage.nbytes());
|
||||
read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);
|
||||
|
||||
if (tensor_storage.is_bf16) {
|
||||
// inplace op
|
||||
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
|
||||
}
|
||||
|
||||
convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
|
||||
dst_tensor->type, (int)tensor_storage.nelements());
|
||||
}
|
||||
} else {
|
||||
read_buffer.resize(tensor_storage.nbytes());
|
||||
read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);
|
||||
|
||||
if (tensor_storage.is_bf16) {
|
||||
// inplace op
|
||||
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
|
||||
}
|
||||
|
||||
if (tensor_storage.type == dst_tensor->type) {
|
||||
// copy to device memory
|
||||
ggml_backend_tensor_set(dst_tensor, read_buffer.data(), 0, ggml_nbytes(dst_tensor));
|
||||
} else {
|
||||
// convert first, then copy to device memory
|
||||
convert_buffer.resize(ggml_nbytes(dst_tensor));
|
||||
convert_tensor((void*)read_buffer.data(), tensor_storage.type,
|
||||
(void*)convert_buffer.data(), dst_tensor->type,
|
||||
(int)tensor_storage.nelements());
|
||||
ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (zip != NULL) {
|
||||
zip_close(zip);
|
||||
}
|
||||
|
||||
if (!success) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return success;
|
||||
}
|
||||
|
||||
int64_t ModelLoader::cal_mem_size() {
|
||||
int64_t mem_size = 0;
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (is_unused_tensor(tensor_storage.name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mem_size += tensor_storage.nbytes();
|
||||
mem_size += GGML_MEM_ALIGN * 2; // for lora alphas
|
||||
}
|
||||
|
||||
return mem_size + 10 * 1024 * 1024;
|
||||
}
|
||||
|
||||
/*================================================= GGUFModelLoader ==================================================*/
|
||||
|
||||
bool GGUFModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
|
||||
LOG_INFO("loading model from '%s'", file_path.c_str());
|
||||
ModelLoader::init_from_file(file_path, prefix);
|
||||
bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::string& prefix) {
|
||||
LOG_DEBUG("init from '%s'", file_path.c_str());
|
||||
file_paths_.push_back(file_path);
|
||||
size_t file_index = file_paths_.size() - 1;
|
||||
|
||||
gguf_context* ctx_gguf_ = NULL;
|
||||
@ -811,8 +608,9 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
|
||||
}
|
||||
|
||||
// https://huggingface.co/docs/safetensors/index
|
||||
bool SafeTensorsModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
|
||||
ModelLoader::init_from_file(file_path, prefix);
|
||||
bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) {
|
||||
LOG_DEBUG("init from '%s'", file_path.c_str());
|
||||
file_paths_.push_back(file_path);
|
||||
size_t file_index = file_paths_.size() - 1;
|
||||
std::ifstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
@ -913,21 +711,18 @@ bool SafeTensorsModelLoader::init_from_file(const std::string& file_path, const
|
||||
|
||||
/*================================================= DiffusersModelLoader ==================================================*/
|
||||
|
||||
bool DiffusersModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
|
||||
if (!is_directory(file_path)) {
|
||||
return SafeTensorsModelLoader::init_from_file(file_path, prefix);
|
||||
}
|
||||
bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
|
||||
std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
|
||||
std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
|
||||
std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
|
||||
|
||||
if (!SafeTensorsModelLoader::init_from_file(unet_path, "unet.")) {
|
||||
if (!init_from_safetensors_file(unet_path, "unet.")) {
|
||||
return false;
|
||||
}
|
||||
if (!SafeTensorsModelLoader::init_from_file(vae_path, "vae.")) {
|
||||
if (!init_from_safetensors_file(vae_path, "vae.")) {
|
||||
return false;
|
||||
}
|
||||
if (!SafeTensorsModelLoader::init_from_file(clip_path, "te.")) {
|
||||
if (!init_from_safetensors_file(clip_path, "te.")) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@ -1127,7 +922,7 @@ int find_char(uint8_t* buffer, int len, char c) {
|
||||
|
||||
#define MAX_STRING_BUFFER 512
|
||||
|
||||
bool CkptModelLoader::parse_data_pkl(uint8_t* buffer,
|
||||
bool ModelLoader::parse_data_pkl(uint8_t* buffer,
|
||||
size_t buffer_size,
|
||||
zip_t* zip,
|
||||
std::string dir,
|
||||
@ -1250,8 +1045,9 @@ bool CkptModelLoader::parse_data_pkl(uint8_t* buffer,
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CkptModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
|
||||
ModelLoader::init_from_file(file_path, prefix);
|
||||
bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::string& prefix) {
|
||||
LOG_DEBUG("init from '%s'", file_path.c_str());
|
||||
file_paths_.push_back(file_path);
|
||||
size_t file_index = file_paths_.size() - 1;
|
||||
|
||||
struct zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
|
||||
@ -1284,29 +1080,213 @@ bool CkptModelLoader::init_from_file(const std::string& file_path, const std::st
|
||||
return true;
|
||||
}
|
||||
|
||||
/*================================================= init_model_loader_from_file ==================================================*/
|
||||
SDVersion ModelLoader::get_sd_version() {
|
||||
TensorStorage token_embedding_weight;
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
|
||||
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
|
||||
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
|
||||
tensor_storage.name == "te.text_model.embeddings.token_embedding.weight") {
|
||||
token_embedding_weight = tensor_storage;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (token_embedding_weight.ne[0] == 768) {
|
||||
return VERSION_1_x;
|
||||
} else if (token_embedding_weight.ne[0] == 1024) {
|
||||
return VERSION_2_x;
|
||||
}
|
||||
return VERSION_COUNT;
|
||||
}
|
||||
|
||||
ModelLoader* init_model_loader_from_file(const std::string& file_path) {
|
||||
ModelLoader* model_loader = NULL;
|
||||
if (is_directory(file_path)) {
|
||||
LOG_DEBUG("load %s using diffusers format", file_path.c_str());
|
||||
model_loader = new DiffusersModelLoader();
|
||||
} else if (ends_with(file_path, ".gguf")) {
|
||||
LOG_DEBUG("load %s using gguf format", file_path.c_str());
|
||||
model_loader = new GGUFModelLoader();
|
||||
} else if (ends_with(file_path, ".safetensors")) {
|
||||
LOG_DEBUG("load %s using safetensors format", file_path.c_str());
|
||||
model_loader = new SafeTensorsModelLoader();
|
||||
} else if (ends_with(file_path, ".ckpt")) {
|
||||
LOG_DEBUG("load %s using checkpoint format", file_path.c_str());
|
||||
model_loader = new CkptModelLoader();
|
||||
ggml_type ModelLoader::get_sd_wtype() {
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (is_unused_tensor(tensor_storage.name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (tensor_storage.name.find(".weight") != std::string::npos &&
|
||||
tensor_storage.name.find("time_embed") != std::string::npos) {
|
||||
return tensor_storage.type;
|
||||
}
|
||||
}
|
||||
return GGML_TYPE_COUNT;
|
||||
}
|
||||
|
||||
bool ModelLoader::load_vocab(on_new_token_cb_t on_new_token_cb) {
|
||||
char* vocab_buffer = reinterpret_cast<char*>(vocab_json);
|
||||
nlohmann::json vocab = nlohmann::json::parse(vocab_buffer);
|
||||
std::map<char, int> decoder = unicode_to_byte();
|
||||
for (auto& it : vocab.items()) {
|
||||
int token_id = it.value();
|
||||
std::string token_str = it.key();
|
||||
std::string token = "";
|
||||
for (char c : token_str) {
|
||||
token += decoder[c];
|
||||
}
|
||||
on_new_token_cb(token, token_id);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
|
||||
bool success = true;
|
||||
for (size_t file_index = 0; file_index < file_paths_.size(); file_index++) {
|
||||
std::string file_path = file_paths_[file_index];
|
||||
LOG_DEBUG("loading tensors from %s", file_path.c_str());
|
||||
|
||||
std::ifstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
LOG_ERROR("failed to open '%s'", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_zip = false;
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (tensor_storage.file_index != file_index) {
|
||||
continue;
|
||||
}
|
||||
if (tensor_storage.index_in_zip >= 0) {
|
||||
is_zip = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
struct zip_t* zip = NULL;
|
||||
if (is_zip) {
|
||||
zip = zip_open(file_path.c_str(), 0, 'r');
|
||||
if (zip == NULL) {
|
||||
LOG_ERROR("failed to open zip '%s'", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint8_t> read_buffer;
|
||||
std::vector<uint8_t> convert_buffer;
|
||||
|
||||
auto read_data = [&](const TensorStorage& tensor_storage, char* buf, size_t n) {
|
||||
if (zip != NULL) {
|
||||
zip_entry_openbyindex(zip, tensor_storage.index_in_zip);
|
||||
size_t entry_size = zip_entry_size(zip);
|
||||
if (entry_size != n) {
|
||||
read_buffer.resize(entry_size);
|
||||
zip_entry_noallocread(zip, (void*)read_buffer.data(), entry_size);
|
||||
memcpy((void*)buf, (void*)(read_buffer.data() + tensor_storage.offset), n);
|
||||
} else {
|
||||
LOG_DEBUG("unknown format %s", file_path.c_str());
|
||||
return NULL;
|
||||
zip_entry_noallocread(zip, (void*)buf, n);
|
||||
}
|
||||
if (!model_loader->init_from_file(file_path)) {
|
||||
delete model_loader;
|
||||
model_loader = NULL;
|
||||
zip_entry_close(zip);
|
||||
} else {
|
||||
file.seekg(tensor_storage.offset);
|
||||
file.read(buf, n);
|
||||
if (!file) {
|
||||
LOG_ERROR("read tensor data failed: '%s'", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
return model_loader;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
std::vector<TensorStorage> processed_tensor_storages;
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (tensor_storage.file_index != file_index) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// LOG_DEBUG("%s", name.c_str());
|
||||
|
||||
if (is_unused_tensor(tensor_storage.name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
preprocess_tensor(tensor_storage, processed_tensor_storages);
|
||||
}
|
||||
|
||||
for (auto& tensor_storage : processed_tensor_storages) {
|
||||
// LOG_DEBUG("%s", name.c_str());
|
||||
|
||||
ggml_tensor* dst_tensor = NULL;
|
||||
|
||||
success = on_new_tensor_cb(tensor_storage, &dst_tensor);
|
||||
if (!success) {
|
||||
LOG_WARN("process tensor failed: '%s'", tensor_storage.name.c_str());
|
||||
break;
|
||||
}
|
||||
|
||||
if (dst_tensor == NULL) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_backend_t backend = ggml_get_backend(dst_tensor);
|
||||
|
||||
size_t nbytes_to_read = tensor_storage.nbytes_to_read();
|
||||
|
||||
if (backend == NULL || ggml_backend_is_cpu(backend)) {
|
||||
// for the CPU and Metal backend, we can copy directly into the tensor
|
||||
if (tensor_storage.type == dst_tensor->type) {
|
||||
GGML_ASSERT(ggml_nbytes(dst_tensor) == nbytes_to_read);
|
||||
read_data(tensor_storage, (char*)dst_tensor->data, nbytes_to_read);
|
||||
|
||||
if (tensor_storage.is_bf16) {
|
||||
// inplace op
|
||||
bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
|
||||
}
|
||||
} else {
|
||||
read_buffer.resize(tensor_storage.nbytes());
|
||||
read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);
|
||||
|
||||
if (tensor_storage.is_bf16) {
|
||||
// inplace op
|
||||
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
|
||||
}
|
||||
|
||||
convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
|
||||
dst_tensor->type, (int)tensor_storage.nelements());
|
||||
}
|
||||
} else {
|
||||
read_buffer.resize(tensor_storage.nbytes());
|
||||
read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);
|
||||
|
||||
if (tensor_storage.is_bf16) {
|
||||
// inplace op
|
||||
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
|
||||
}
|
||||
|
||||
if (tensor_storage.type == dst_tensor->type) {
|
||||
// copy to device memory
|
||||
ggml_backend_tensor_set(dst_tensor, read_buffer.data(), 0, ggml_nbytes(dst_tensor));
|
||||
} else {
|
||||
// convert first, then copy to device memory
|
||||
convert_buffer.resize(ggml_nbytes(dst_tensor));
|
||||
convert_tensor((void*)read_buffer.data(), tensor_storage.type,
|
||||
(void*)convert_buffer.data(), dst_tensor->type,
|
||||
(int)tensor_storage.nelements());
|
||||
ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (zip != NULL) {
|
||||
zip_close(zip);
|
||||
}
|
||||
|
||||
if (!success) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return success;
|
||||
}
|
||||
|
||||
int64_t ModelLoader::cal_mem_size() {
|
||||
int64_t mem_size = 0;
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (is_unused_tensor(tensor_storage.name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mem_size += tensor_storage.nbytes();
|
||||
mem_size += GGML_MEM_ALIGN * 2; // for lora alphas
|
||||
}
|
||||
|
||||
return mem_size + 10 * 1024 * 1024;
|
||||
}
|
42
model.h
42
model.h
@ -98,29 +98,6 @@ protected:
|
||||
std::vector<std::string> file_paths_;
|
||||
std::vector<TensorStorage> tensor_storages;
|
||||
|
||||
public:
|
||||
virtual bool init_from_file(const std::string& file_path, const std::string& prefix = "");
|
||||
virtual bool init_from_files(const std::vector<std::string>& file_paths);
|
||||
virtual SDVersion get_sd_version();
|
||||
virtual ggml_type get_sd_wtype();
|
||||
virtual bool load_vocab(on_new_token_cb_t on_new_token_cb);
|
||||
virtual bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb);
|
||||
virtual int64_t cal_mem_size();
|
||||
virtual ~ModelLoader() = default;
|
||||
};
|
||||
|
||||
class GGUFModelLoader : public ModelLoader {
|
||||
public:
|
||||
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
|
||||
};
|
||||
|
||||
class SafeTensorsModelLoader : public ModelLoader {
|
||||
public:
|
||||
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
|
||||
};
|
||||
|
||||
class CkptModelLoader : public ModelLoader {
|
||||
private:
|
||||
bool parse_data_pkl(uint8_t* buffer,
|
||||
size_t buffer_size,
|
||||
zip_t* zip,
|
||||
@ -128,15 +105,18 @@ private:
|
||||
size_t file_index,
|
||||
const std::string& prefix);
|
||||
|
||||
bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
|
||||
|
||||
public:
|
||||
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
|
||||
SDVersion get_sd_version();
|
||||
ggml_type get_sd_wtype();
|
||||
bool load_vocab(on_new_token_cb_t on_new_token_cb);
|
||||
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb);
|
||||
int64_t cal_mem_size();
|
||||
~ModelLoader() = default;
|
||||
};
|
||||
|
||||
class DiffusersModelLoader : public SafeTensorsModelLoader {
|
||||
public:
|
||||
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
|
||||
};
|
||||
|
||||
ModelLoader* init_model_loader_from_file(const std::string& file_path);
|
||||
|
||||
#endif // __MODEL_H__
|
@ -3281,9 +3281,10 @@ struct LoraModel {
|
||||
bool load(ggml_backend_t backend_, std::string file_path) {
|
||||
backend = backend_;
|
||||
LOG_INFO("loading LoRA from '%s'", file_path.c_str());
|
||||
std::shared_ptr<ModelLoader> model_loader = std::shared_ptr<ModelLoader>(init_model_loader_from_file(file_path));
|
||||
ModelLoader model_loader;
|
||||
;
|
||||
|
||||
if (!model_loader) {
|
||||
if (!model_loader.init_from_file(file_path)) {
|
||||
LOG_ERROR("init lora model loader from file failed: '%s'", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
@ -3299,10 +3300,10 @@ struct LoraModel {
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_type wtype = model_loader->get_sd_wtype();
|
||||
ggml_type wtype = model_loader.get_sd_wtype();
|
||||
|
||||
LOG_DEBUG("calculating buffer size");
|
||||
int64_t memory_buffer_size = model_loader->cal_mem_size();
|
||||
int64_t memory_buffer_size = model_loader.cal_mem_size();
|
||||
LOG_DEBUG("lora params backend buffer size = % 6.2f MB", memory_buffer_size / (1024.0 * 1024.0));
|
||||
|
||||
params_buffer_lora = ggml_backend_alloc_buffer(backend, memory_buffer_size);
|
||||
@ -3320,7 +3321,7 @@ struct LoraModel {
|
||||
return true;
|
||||
};
|
||||
|
||||
model_loader->load_tensors(on_new_tensor_cb);
|
||||
model_loader.load_tensors(on_new_tensor_cb);
|
||||
|
||||
LOG_DEBUG("finished loaded lora");
|
||||
ggml_allocr_free(alloc);
|
||||
@ -3664,21 +3665,21 @@ public:
|
||||
#endif
|
||||
#endif
|
||||
LOG_INFO("loading model from '%s'", model_path.c_str());
|
||||
std::shared_ptr<ModelLoader> model_loader = std::shared_ptr<ModelLoader>(init_model_loader_from_file(model_path));
|
||||
ModelLoader model_loader;
|
||||
|
||||
if (!model_loader) {
|
||||
if (!model_loader.init_from_file(model_path)) {
|
||||
LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (vae_path.size() > 0) {
|
||||
LOG_INFO("loading vae from '%s'", vae_path.c_str());
|
||||
if (!model_loader->init_from_file(vae_path, "vae.")) {
|
||||
if (!model_loader.init_from_file(vae_path, "vae.")) {
|
||||
LOG_WARN("loading vae from '%s' failed", vae_path.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
SDVersion version = model_loader->get_sd_version();
|
||||
SDVersion version = model_loader.get_sd_version();
|
||||
if (version == VERSION_COUNT) {
|
||||
LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str());
|
||||
return false;
|
||||
@ -3687,7 +3688,7 @@ public:
|
||||
diffusion_model = UNetModel(version);
|
||||
LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]);
|
||||
if (wtype == GGML_TYPE_COUNT) {
|
||||
model_data_type = model_loader->get_sd_wtype();
|
||||
model_data_type = model_loader.get_sd_wtype();
|
||||
} else {
|
||||
model_data_type = wtype;
|
||||
}
|
||||
@ -3697,7 +3698,7 @@ public:
|
||||
auto add_token = [&](const std::string& token, int32_t token_id) {
|
||||
cond_stage_model.tokenizer.add_token(token, token_id);
|
||||
};
|
||||
bool success = model_loader->load_vocab(add_token);
|
||||
bool success = model_loader.load_vocab(add_token);
|
||||
if (!success) {
|
||||
LOG_ERROR("get vocab from file failed: '%s'", model_path.c_str());
|
||||
return false;
|
||||
@ -3794,7 +3795,7 @@ public:
|
||||
|
||||
// print_ggml_tensor(alphas_cumprod_tensor);
|
||||
|
||||
success = model_loader->load_tensors(on_new_tensor_cb);
|
||||
success = model_loader.load_tensors(on_new_tensor_cb);
|
||||
if (!success) {
|
||||
LOG_ERROR("load tensors from file failed");
|
||||
ggml_free(ctx);
|
||||
|
@ -1,8 +1,8 @@
|
||||
#ifndef __STABLE_DIFFUSION_H__
|
||||
#define __STABLE_DIFFUSION_H__
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
enum RNGType {
|
||||
|
Loading…
Reference in New Issue
Block a user