fix: allow model and vae to use different formats

leejet 2023-12-03 17:12:04 +08:00
parent d7af2c2ba9
commit 8a87b273ad
4 changed files with 266 additions and 305 deletions
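
The change folds the per-format loader subclasses (GGUFModelLoader, SafeTensorsModelLoader, CkptModelLoader, DiffusersModelLoader) into ModelLoader itself, which now dispatches on the path: a directory is treated as diffusers format, and .gguf, .safetensors, and .ckpt are picked by extension. Since one loader instance can be initialized from several files in turn, the main model and the VAE no longer need to share a format. A minimal usage sketch of the new API (paths are hypothetical placeholders, error handling trimmed):

    #include "model.h"

    // One loader, two formats: a .ckpt main model plus a .safetensors VAE.
    bool load_split_model() {
        ModelLoader model_loader;
        if (!model_loader.init_from_file("sd-v1-4.ckpt")) {  // checkpoint format
            return false;
        }
        // tensors from this file are registered under the "vae." prefix
        if (!model_loader.init_from_file("my-vae.safetensors", "vae.")) {
            return false;
        }
        return model_loader.get_sd_version() != VERSION_COUNT;
    }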

model.cpp (502 changed lines)

@@ -534,232 +534,29 @@ std::map<char, int> unicode_to_byte() {
 }
 
 bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
-    file_paths_.push_back(file_path);
-    return true;
-}
-
-bool ModelLoader::init_from_files(const std::vector<std::string>& file_paths) {
-    for (auto& file_path : file_paths) {
-        if (!init_from_file(file_path)) {
-            return false;
-        }
-    }
-    return true;
-}
-
-SDVersion ModelLoader::get_sd_version() {
-    TensorStorage token_embedding_weight;
-    for (auto& tensor_storage : tensor_storages) {
-        if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
-            tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
-            tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
-            tensor_storage.name == "te.text_model.embeddings.token_embedding.weight") {
-            token_embedding_weight = tensor_storage;
-            break;
-        }
-    }
-    if (token_embedding_weight.ne[0] == 768) {
-        return VERSION_1_x;
-    } else if (token_embedding_weight.ne[0] == 1024) {
-        return VERSION_2_x;
-    }
-    return VERSION_COUNT;
-}
-
-ggml_type ModelLoader::get_sd_wtype() {
-    for (auto& tensor_storage : tensor_storages) {
-        if (is_unused_tensor(tensor_storage.name)) {
-            continue;
-        }
-
-        if (tensor_storage.name.find(".weight") != std::string::npos &&
-            tensor_storage.name.find("time_embed") != std::string::npos) {
-            return tensor_storage.type;
-        }
-    }
-    return GGML_TYPE_COUNT;
-}
-
-bool ModelLoader::load_vocab(on_new_token_cb_t on_new_token_cb) {
-    char* vocab_buffer = reinterpret_cast<char*>(vocab_json);
-    nlohmann::json vocab = nlohmann::json::parse(vocab_buffer);
-    std::map<char, int> decoder = unicode_to_byte();
-    for (auto& it : vocab.items()) {
-        int token_id = it.value();
-        std::string token_str = it.key();
-        std::string token = "";
-        for (char c : token_str) {
-            token += decoder[c];
-        }
-        on_new_token_cb(token, token_id);
-    }
-    return true;
-}
-
-bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
-    bool success = true;
-    for (size_t file_index = 0; file_index < file_paths_.size(); file_index++) {
-        std::string file_path = file_paths_[file_index];
-        LOG_DEBUG("loading tensors from %s", file_path.c_str());
-
-        std::ifstream file(file_path, std::ios::binary);
-        if (!file.is_open()) {
-            LOG_ERROR("failed to open '%s'", file_path.c_str());
-            return false;
-        }
-
-        bool is_zip = false;
-        for (auto& tensor_storage : tensor_storages) {
-            if (tensor_storage.index_in_zip >= 0) {
-                is_zip = true;
-                break;
-            }
-        }
-
-        struct zip_t* zip = NULL;
-        if (is_zip) {
-            zip = zip_open(file_path.c_str(), 0, 'r');
-            if (zip == NULL) {
-                LOG_ERROR("failed to open zip '%s'", file_path.c_str());
-                return false;
-            }
-        }
-
-        std::vector<uint8_t> read_buffer;
-        std::vector<uint8_t> convert_buffer;
-
-        auto read_data = [&](const TensorStorage& tensor_storage, char* buf, size_t n) {
-            if (zip != NULL) {
-                zip_entry_openbyindex(zip, tensor_storage.index_in_zip);
-                size_t entry_size = zip_entry_size(zip);
-                if (entry_size != n) {
-                    read_buffer.resize(entry_size);
-                    zip_entry_noallocread(zip, (void*)read_buffer.data(), entry_size);
-                    memcpy((void*)buf, (void*)(read_buffer.data() + tensor_storage.offset), n);
-                } else {
-                    zip_entry_noallocread(zip, (void*)buf, n);
-                }
-                zip_entry_close(zip);
-            } else {
-                file.seekg(tensor_storage.offset);
-                file.read(buf, n);
-                if (!file) {
-                    LOG_ERROR("read tensor data failed: '%s'", file_path.c_str());
-                    return false;
-                }
-            }
-            return true;
-        };
-
-        std::vector<TensorStorage> processed_tensor_storages;
-        for (auto& tensor_storage : tensor_storages) {
-            if (tensor_storage.file_index != file_index) {
-                continue;
-            }
-            // LOG_DEBUG("%s", name.c_str());
-            if (is_unused_tensor(tensor_storage.name)) {
-                continue;
-            }
-            preprocess_tensor(tensor_storage, processed_tensor_storages);
-        }
-
-        for (auto& tensor_storage : processed_tensor_storages) {
-            // LOG_DEBUG("%s", name.c_str());
-            ggml_tensor* dst_tensor = NULL;
-
-            success = on_new_tensor_cb(tensor_storage, &dst_tensor);
-            if (!success) {
-                LOG_WARN("process tensor failed: '%s'", tensor_storage.name.c_str());
-                break;
-            }
-
-            if (dst_tensor == NULL) {
-                continue;
-            }
-
-            ggml_backend_t backend = ggml_get_backend(dst_tensor);
-
-            size_t nbytes_to_read = tensor_storage.nbytes_to_read();
-
-            if (backend == NULL || ggml_backend_is_cpu(backend)) {
-                // for the CPU and Metal backend, we can copy directly into the tensor
-                if (tensor_storage.type == dst_tensor->type) {
-                    GGML_ASSERT(ggml_nbytes(dst_tensor) == nbytes_to_read);
-                    read_data(tensor_storage, (char*)dst_tensor->data, nbytes_to_read);
-
-                    if (tensor_storage.is_bf16) {
-                        // inplace op
-                        bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
-                    }
-                } else {
-                    read_buffer.resize(tensor_storage.nbytes());
-                    read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);
-
-                    if (tensor_storage.is_bf16) {
-                        // inplace op
-                        bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
-                    }
-
-                    convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
-                                   dst_tensor->type, (int)tensor_storage.nelements());
-                }
-            } else {
-                read_buffer.resize(tensor_storage.nbytes());
-                read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);
-
-                if (tensor_storage.is_bf16) {
-                    // inplace op
-                    bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
-                }
-
-                if (tensor_storage.type == dst_tensor->type) {
-                    // copy to device memory
-                    ggml_backend_tensor_set(dst_tensor, read_buffer.data(), 0, ggml_nbytes(dst_tensor));
-                } else {
-                    // convert first, then copy to device memory
-                    convert_buffer.resize(ggml_nbytes(dst_tensor));
-                    convert_tensor((void*)read_buffer.data(), tensor_storage.type,
-                                   (void*)convert_buffer.data(), dst_tensor->type,
-                                   (int)tensor_storage.nelements());
-                    ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor));
-                }
-            }
-        }
-
-        if (zip != NULL) {
-            zip_close(zip);
-        }
-
-        if (!success) {
-            break;
-        }
-    }
-    return success;
-}
-
-int64_t ModelLoader::cal_mem_size() {
-    int64_t mem_size = 0;
-    for (auto& tensor_storage : tensor_storages) {
-        if (is_unused_tensor(tensor_storage.name)) {
-            continue;
-        }
-        mem_size += tensor_storage.nbytes();
-        mem_size += GGML_MEM_ALIGN * 2;  // for lora alphas
-    }
-
-    return mem_size + 10 * 1024 * 1024;
+    if (is_directory(file_path)) {
+        LOG_INFO("load %s using diffusers format", file_path.c_str());
+        return init_from_diffusers_file(file_path, prefix);
+    } else if (ends_with(file_path, ".gguf")) {
+        LOG_INFO("load %s using gguf format", file_path.c_str());
+        return init_from_gguf_file(file_path, prefix);
+    } else if (ends_with(file_path, ".safetensors")) {
+        LOG_INFO("load %s using safetensors format", file_path.c_str());
+        return init_from_safetensors_file(file_path, prefix);
+    } else if (ends_with(file_path, ".ckpt")) {
+        LOG_INFO("load %s using checkpoint format", file_path.c_str());
+        return init_from_ckpt_file(file_path, prefix);
+    } else {
+        LOG_WARN("unknown format %s", file_path.c_str());
+        return false;
+    }
 }
 
 /*================================================= GGUFModelLoader ==================================================*/
 
-bool GGUFModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
-    LOG_INFO("loading model from '%s'", file_path.c_str());
-    ModelLoader::init_from_file(file_path, prefix);
+bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::string& prefix) {
+    LOG_DEBUG("init from '%s'", file_path.c_str());
+    file_paths_.push_back(file_path);
     size_t file_index = file_paths_.size() - 1;
 
     gguf_context* ctx_gguf_ = NULL;
@@ -811,8 +608,9 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
 }
 
 // https://huggingface.co/docs/safetensors/index
-bool SafeTensorsModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
-    ModelLoader::init_from_file(file_path, prefix);
+bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) {
+    LOG_DEBUG("init from '%s'", file_path.c_str());
+    file_paths_.push_back(file_path);
     size_t file_index = file_paths_.size() - 1;
     std::ifstream file(file_path, std::ios::binary);
     if (!file.is_open()) {
@@ -913,21 +711,18 @@ bool SafeTensorsModelLoader::init_from_file(const std::string& file_path, const
 
 /*================================================= DiffusersModelLoader ==================================================*/
 
-bool DiffusersModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
-    if (!is_directory(file_path)) {
-        return SafeTensorsModelLoader::init_from_file(file_path, prefix);
-    }
+bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
     std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
     std::string vae_path  = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
    std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
 
-    if (!SafeTensorsModelLoader::init_from_file(unet_path, "unet.")) {
+    if (!init_from_safetensors_file(unet_path, "unet.")) {
         return false;
     }
-    if (!SafeTensorsModelLoader::init_from_file(vae_path, "vae.")) {
+    if (!init_from_safetensors_file(vae_path, "vae.")) {
         return false;
     }
-    if (!SafeTensorsModelLoader::init_from_file(clip_path, "te.")) {
+    if (!init_from_safetensors_file(clip_path, "te.")) {
         return false;
     }
     return true;
@@ -1127,12 +922,12 @@ int find_char(uint8_t* buffer, int len, char c) {
 
 #define MAX_STRING_BUFFER 512
 
-bool CkptModelLoader::parse_data_pkl(uint8_t* buffer,
-                                     size_t buffer_size,
-                                     zip_t* zip,
-                                     std::string dir,
-                                     size_t file_index,
-                                     const std::string& prefix) {
+bool ModelLoader::parse_data_pkl(uint8_t* buffer,
+                                 size_t buffer_size,
+                                 zip_t* zip,
+                                 std::string dir,
+                                 size_t file_index,
+                                 const std::string& prefix) {
     uint8_t* buffer_end = buffer + buffer_size;
     if (buffer[0] == 0x80) {  // proto
         if (buffer[1] != 2) {
@@ -1250,8 +1045,9 @@ bool CkptModelLoader::parse_data_pkl(uint8_t* buffer,
     return true;
 }
 
-bool CkptModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
-    ModelLoader::init_from_file(file_path, prefix);
+bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::string& prefix) {
+    LOG_DEBUG("init from '%s'", file_path.c_str());
+    file_paths_.push_back(file_path);
     size_t file_index = file_paths_.size() - 1;
 
     struct zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
@@ -1284,29 +1080,213 @@ bool CkptModelLoader::init_from_file(const std::string& file_path, const std::st
     return true;
 }
 
-/*================================================= init_model_loader_from_file ==================================================*/
-
-ModelLoader* init_model_loader_from_file(const std::string& file_path) {
-    ModelLoader* model_loader = NULL;
-    if (is_directory(file_path)) {
-        LOG_DEBUG("load %s using diffusers format", file_path.c_str());
-        model_loader = new DiffusersModelLoader();
-    } else if (ends_with(file_path, ".gguf")) {
-        LOG_DEBUG("load %s using gguf format", file_path.c_str());
-        model_loader = new GGUFModelLoader();
-    } else if (ends_with(file_path, ".safetensors")) {
-        LOG_DEBUG("load %s using safetensors format", file_path.c_str());
-        model_loader = new SafeTensorsModelLoader();
-    } else if (ends_with(file_path, ".ckpt")) {
-        LOG_DEBUG("load %s using checkpoint format", file_path.c_str());
-        model_loader = new CkptModelLoader();
-    } else {
-        LOG_DEBUG("unknown format %s", file_path.c_str());
-        return NULL;
-    }
-    if (!model_loader->init_from_file(file_path)) {
-        delete model_loader;
-        model_loader = NULL;
-    }
-    return model_loader;
-}
+SDVersion ModelLoader::get_sd_version() {
+    TensorStorage token_embedding_weight;
+    for (auto& tensor_storage : tensor_storages) {
+        if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
+            tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
+            tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
+            tensor_storage.name == "te.text_model.embeddings.token_embedding.weight") {
+            token_embedding_weight = tensor_storage;
+            break;
+        }
+    }
+    if (token_embedding_weight.ne[0] == 768) {
+        return VERSION_1_x;
+    } else if (token_embedding_weight.ne[0] == 1024) {
+        return VERSION_2_x;
+    }
+    return VERSION_COUNT;
+}
+
+ggml_type ModelLoader::get_sd_wtype() {
+    for (auto& tensor_storage : tensor_storages) {
+        if (is_unused_tensor(tensor_storage.name)) {
+            continue;
+        }
+
+        if (tensor_storage.name.find(".weight") != std::string::npos &&
+            tensor_storage.name.find("time_embed") != std::string::npos) {
+            return tensor_storage.type;
+        }
+    }
+    return GGML_TYPE_COUNT;
+}
+
+bool ModelLoader::load_vocab(on_new_token_cb_t on_new_token_cb) {
+    char* vocab_buffer = reinterpret_cast<char*>(vocab_json);
+    nlohmann::json vocab = nlohmann::json::parse(vocab_buffer);
+    std::map<char, int> decoder = unicode_to_byte();
+    for (auto& it : vocab.items()) {
+        int token_id = it.value();
+        std::string token_str = it.key();
+        std::string token = "";
+        for (char c : token_str) {
+            token += decoder[c];
+        }
+        on_new_token_cb(token, token_id);
+    }
+    return true;
+}
+
+bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
+    bool success = true;
+    for (size_t file_index = 0; file_index < file_paths_.size(); file_index++) {
+        std::string file_path = file_paths_[file_index];
+        LOG_DEBUG("loading tensors from %s", file_path.c_str());
+
+        std::ifstream file(file_path, std::ios::binary);
+        if (!file.is_open()) {
+            LOG_ERROR("failed to open '%s'", file_path.c_str());
+            return false;
+        }
+
+        bool is_zip = false;
+        for (auto& tensor_storage : tensor_storages) {
+            if (tensor_storage.file_index != file_index) {
+                continue;
+            }
+            if (tensor_storage.index_in_zip >= 0) {
+                is_zip = true;
+                break;
+            }
+        }
+
+        struct zip_t* zip = NULL;
+        if (is_zip) {
+            zip = zip_open(file_path.c_str(), 0, 'r');
+            if (zip == NULL) {
+                LOG_ERROR("failed to open zip '%s'", file_path.c_str());
+                return false;
+            }
+        }
+
+        std::vector<uint8_t> read_buffer;
+        std::vector<uint8_t> convert_buffer;
+
+        auto read_data = [&](const TensorStorage& tensor_storage, char* buf, size_t n) {
+            if (zip != NULL) {
+                zip_entry_openbyindex(zip, tensor_storage.index_in_zip);
+                size_t entry_size = zip_entry_size(zip);
+                if (entry_size != n) {
+                    read_buffer.resize(entry_size);
+                    zip_entry_noallocread(zip, (void*)read_buffer.data(), entry_size);
+                    memcpy((void*)buf, (void*)(read_buffer.data() + tensor_storage.offset), n);
+                } else {
+                    zip_entry_noallocread(zip, (void*)buf, n);
+                }
+                zip_entry_close(zip);
+            } else {
+                file.seekg(tensor_storage.offset);
+                file.read(buf, n);
+                if (!file) {
+                    LOG_ERROR("read tensor data failed: '%s'", file_path.c_str());
+                    return false;
+                }
+            }
+            return true;
+        };
+
+        std::vector<TensorStorage> processed_tensor_storages;
+        for (auto& tensor_storage : tensor_storages) {
+            if (tensor_storage.file_index != file_index) {
+                continue;
+            }
+            // LOG_DEBUG("%s", name.c_str());
+            if (is_unused_tensor(tensor_storage.name)) {
+                continue;
+            }
+            preprocess_tensor(tensor_storage, processed_tensor_storages);
+        }
+
+        for (auto& tensor_storage : processed_tensor_storages) {
+            // LOG_DEBUG("%s", name.c_str());
+            ggml_tensor* dst_tensor = NULL;
+
+            success = on_new_tensor_cb(tensor_storage, &dst_tensor);
+            if (!success) {
+                LOG_WARN("process tensor failed: '%s'", tensor_storage.name.c_str());
+                break;
+            }
+
+            if (dst_tensor == NULL) {
+                continue;
+            }
+
+            ggml_backend_t backend = ggml_get_backend(dst_tensor);
+
+            size_t nbytes_to_read = tensor_storage.nbytes_to_read();
+
+            if (backend == NULL || ggml_backend_is_cpu(backend)) {
+                // for the CPU and Metal backend, we can copy directly into the tensor
+                if (tensor_storage.type == dst_tensor->type) {
+                    GGML_ASSERT(ggml_nbytes(dst_tensor) == nbytes_to_read);
+                    read_data(tensor_storage, (char*)dst_tensor->data, nbytes_to_read);
+
+                    if (tensor_storage.is_bf16) {
+                        // inplace op
+                        bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
+                    }
+                } else {
+                    read_buffer.resize(tensor_storage.nbytes());
+                    read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);
+
+                    if (tensor_storage.is_bf16) {
+                        // inplace op
+                        bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                    }
+
+                    convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
+                                   dst_tensor->type, (int)tensor_storage.nelements());
+                }
+            } else {
+                read_buffer.resize(tensor_storage.nbytes());
+                read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);
+
+                if (tensor_storage.is_bf16) {
+                    // inplace op
+                    bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                }
+
+                if (tensor_storage.type == dst_tensor->type) {
+                    // copy to device memory
+                    ggml_backend_tensor_set(dst_tensor, read_buffer.data(), 0, ggml_nbytes(dst_tensor));
+                } else {
+                    // convert first, then copy to device memory
+                    convert_buffer.resize(ggml_nbytes(dst_tensor));
+                    convert_tensor((void*)read_buffer.data(), tensor_storage.type,
+                                   (void*)convert_buffer.data(), dst_tensor->type,
+                                   (int)tensor_storage.nelements());
+                    ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor));
+                }
+            }
+        }
+
+        if (zip != NULL) {
+            zip_close(zip);
+        }
+
+        if (!success) {
+            break;
+        }
+    }
+    return success;
+}
+
+int64_t ModelLoader::cal_mem_size() {
+    int64_t mem_size = 0;
+    for (auto& tensor_storage : tensor_storages) {
+        if (is_unused_tensor(tensor_storage.name)) {
+            continue;
+        }
+        mem_size += tensor_storage.nbytes();
+        mem_size += GGML_MEM_ALIGN * 2;  // for lora alphas
+    }
+
+    return mem_size + 10 * 1024 * 1024;
+}
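
The format dispatch in init_from_file relies on two small helpers, is_directory and ends_with, whose definitions sit outside this diff. A rough sketch of what they do (an illustrative assumption, not the repository's actual implementation):

    #include <algorithm>
    #include <filesystem>
    #include <string>

    // Assumed behavior of the helpers used by ModelLoader::init_from_file.
    static bool is_directory(const std::string& path) {
        return std::filesystem::is_directory(path);  // C++17
    }

    static bool ends_with(const std::string& str, const std::string& ending) {
        if (ending.size() > str.size()) {
            return false;
        }
        // compare the suffix from the back
        return std::equal(ending.rbegin(), ending.rend(), str.rbegin());
    }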

model.h (42 changed lines)

@@ -98,29 +98,6 @@ protected:
     std::vector<std::string> file_paths_;
     std::vector<TensorStorage> tensor_storages;
 
-public:
-    virtual bool init_from_file(const std::string& file_path, const std::string& prefix = "");
-    virtual bool init_from_files(const std::vector<std::string>& file_paths);
-    virtual SDVersion get_sd_version();
-    virtual ggml_type get_sd_wtype();
-    virtual bool load_vocab(on_new_token_cb_t on_new_token_cb);
-    virtual bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb);
-    virtual int64_t cal_mem_size();
-    virtual ~ModelLoader() = default;
-};
-
-class GGUFModelLoader : public ModelLoader {
-public:
-    bool init_from_file(const std::string& file_path, const std::string& prefix = "");
-};
-
-class SafeTensorsModelLoader : public ModelLoader {
-public:
-    bool init_from_file(const std::string& file_path, const std::string& prefix = "");
-};
-
-class CkptModelLoader : public ModelLoader {
-private:
     bool parse_data_pkl(uint8_t* buffer,
                         size_t buffer_size,
                         zip_t* zip,
@@ -128,15 +105,18 @@ private:
                         size_t file_index,
                         const std::string& prefix);
 
+    bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
+    bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
+    bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = "");
+    bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
+
 public:
     bool init_from_file(const std::string& file_path, const std::string& prefix = "");
+    SDVersion get_sd_version();
+    ggml_type get_sd_wtype();
+    bool load_vocab(on_new_token_cb_t on_new_token_cb);
+    bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb);
+    int64_t cal_mem_size();
+    ~ModelLoader() = default;
 };
 
-class DiffusersModelLoader : public SafeTensorsModelLoader {
-public:
-    bool init_from_file(const std::string& file_path, const std::string& prefix = "");
-};
-
-ModelLoader* init_model_loader_from_file(const std::string& file_path);
-
 #endif  // __MODEL_H__

stable-diffusion.cpp

@@ -3281,9 +3281,10 @@ struct LoraModel {
     bool load(ggml_backend_t backend_, std::string file_path) {
         backend = backend_;
         LOG_INFO("loading LoRA from '%s'", file_path.c_str());
-        std::shared_ptr<ModelLoader> model_loader = std::shared_ptr<ModelLoader>(init_model_loader_from_file(file_path));
-        ;
-        if (!model_loader) {
+        ModelLoader model_loader;
+
+        if (!model_loader.init_from_file(file_path)) {
             LOG_ERROR("init lora model loader from file failed: '%s'", file_path.c_str());
             return false;
         }
@@ -3299,10 +3300,10 @@ struct LoraModel {
             return false;
         }
 
-        ggml_type wtype = model_loader->get_sd_wtype();
+        ggml_type wtype = model_loader.get_sd_wtype();
 
         LOG_DEBUG("calculating buffer size");
-        int64_t memory_buffer_size = model_loader->cal_mem_size();
+        int64_t memory_buffer_size = model_loader.cal_mem_size();
         LOG_DEBUG("lora params backend buffer size = % 6.2f MB", memory_buffer_size / (1024.0 * 1024.0));
 
         params_buffer_lora = ggml_backend_alloc_buffer(backend, memory_buffer_size);
@@ -3320,7 +3321,7 @@ struct LoraModel {
             return true;
         };
 
-        model_loader->load_tensors(on_new_tensor_cb);
+        model_loader.load_tensors(on_new_tensor_cb);
         LOG_DEBUG("finished loaded lora");
 
         ggml_allocr_free(alloc);
@@ -3664,21 +3665,21 @@ public:
 #endif
 #endif
         LOG_INFO("loading model from '%s'", model_path.c_str());
-        std::shared_ptr<ModelLoader> model_loader = std::shared_ptr<ModelLoader>(init_model_loader_from_file(model_path));
-        if (!model_loader) {
+        ModelLoader model_loader;
+        if (!model_loader.init_from_file(model_path)) {
             LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str());
             return false;
         }
 
         if (vae_path.size() > 0) {
             LOG_INFO("loading vae from '%s'", vae_path.c_str());
-            if (!model_loader->init_from_file(vae_path, "vae.")) {
+            if (!model_loader.init_from_file(vae_path, "vae.")) {
                 LOG_WARN("loading vae from '%s' failed", vae_path.c_str());
             }
         }
 
-        SDVersion version = model_loader->get_sd_version();
+        SDVersion version = model_loader.get_sd_version();
         if (version == VERSION_COUNT) {
             LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str());
             return false;
@@ -3687,7 +3688,7 @@ public:
         diffusion_model = UNetModel(version);
         LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]);
         if (wtype == GGML_TYPE_COUNT) {
-            model_data_type = model_loader->get_sd_wtype();
+            model_data_type = model_loader.get_sd_wtype();
         } else {
             model_data_type = wtype;
         }
@@ -3697,7 +3698,7 @@ public:
         auto add_token = [&](const std::string& token, int32_t token_id) {
             cond_stage_model.tokenizer.add_token(token, token_id);
         };
-        bool success = model_loader->load_vocab(add_token);
+        bool success = model_loader.load_vocab(add_token);
         if (!success) {
             LOG_ERROR("get vocab from file failed: '%s'", model_path.c_str());
             return false;
@@ -3794,7 +3795,7 @@ public:
         // print_ggml_tensor(alphas_cumprod_tensor);
 
-        success = model_loader->load_tensors(on_new_tensor_cb);
+        success = model_loader.load_tensors(on_new_tensor_cb);
         if (!success) {
             LOG_ERROR("load tensors from file failed");
             ggml_free(ctx);

stable-diffusion.h

@@ -1,8 +1,8 @@
 #ifndef __STABLE_DIFFUSION_H__
 #define __STABLE_DIFFUSION_H__
 
-#include <string>
 #include <memory>
+#include <string>
 #include <vector>
 
 enum RNGType {