fix: detect model format base on file content
This commit is contained in:
parent
8a87b273ad
commit
f99bcd1f76
82
model.cpp
82
model.cpp
@ -14,6 +14,8 @@
|
||||
#include "ggml/ggml-backend.h"
|
||||
#include "ggml/ggml.h"
|
||||
|
||||
#define ST_HEADER_SIZE_LEN 8
|
||||
|
||||
uint64_t read_u64(uint8_t* buffer) {
|
||||
// little endian
|
||||
uint64_t value = 0;
|
||||
@ -533,17 +535,89 @@ std::map<char, int> unicode_to_byte() {
|
||||
return byte_decoder;
|
||||
}
|
||||
|
||||
bool is_zip_file(const std::string& file_path) {
|
||||
struct zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
|
||||
if (zip == NULL) {
|
||||
return false;
|
||||
}
|
||||
zip_close(zip);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_gguf_file(const std::string& file_path) {
|
||||
std::ifstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
char magic[4];
|
||||
|
||||
file.read(magic, sizeof(magic));
|
||||
if (!file) {
|
||||
return false;
|
||||
}
|
||||
for (uint32_t i = 0; i < sizeof(magic); i++) {
|
||||
if (magic[i] != GGUF_MAGIC[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_safetensors_file(const std::string& file_path) {
|
||||
std::ifstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// get file size
|
||||
file.seekg(0, file.end);
|
||||
size_t file_size_ = file.tellg();
|
||||
file.seekg(0, file.beg);
|
||||
|
||||
// read header size
|
||||
if (file_size_ <= ST_HEADER_SIZE_LEN) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint8_t header_size_buf[ST_HEADER_SIZE_LEN];
|
||||
file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN);
|
||||
if (!file) {
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t header_size_ = read_u64(header_size_buf);
|
||||
if (header_size_ >= file_size_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// read header
|
||||
std::vector<char> header_buf;
|
||||
header_buf.resize(header_size_ + 1);
|
||||
header_buf[header_size_] = '\0';
|
||||
file.read(header_buf.data(), header_size_);
|
||||
if (!file) {
|
||||
return false;
|
||||
}
|
||||
nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
|
||||
if (header_.is_discarded()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
|
||||
if (is_directory(file_path)) {
|
||||
LOG_INFO("load %s using diffusers format", file_path.c_str());
|
||||
return init_from_diffusers_file(file_path, prefix);
|
||||
} else if (ends_with(file_path, ".gguf")) {
|
||||
} else if (is_gguf_file(file_path)) {
|
||||
LOG_INFO("load %s using gguf format", file_path.c_str());
|
||||
return init_from_gguf_file(file_path, prefix);
|
||||
} else if (ends_with(file_path, ".safetensors")) {
|
||||
} else if (is_safetensors_file(file_path)) {
|
||||
LOG_INFO("load %s using safetensors format", file_path.c_str());
|
||||
return init_from_safetensors_file(file_path, prefix);
|
||||
} else if (ends_with(file_path, ".ckpt")) {
|
||||
} else if (is_zip_file(file_path)) {
|
||||
LOG_INFO("load %s using checkpoint format", file_path.c_str());
|
||||
return init_from_ckpt_file(file_path, prefix);
|
||||
} else {
|
||||
@ -593,8 +667,6 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
|
||||
|
||||
/*================================================= SafeTensorsModelLoader ==================================================*/
|
||||
|
||||
#define ST_HEADER_SIZE_LEN 8
|
||||
|
||||
ggml_type str_to_ggml_type(const std::string& dtype) {
|
||||
ggml_type ttype = GGML_TYPE_COUNT;
|
||||
if (dtype == "F16") {
|
||||
|
Loading…
Reference in New Issue
Block a user