feat: add progress callback (#170)

This commit is contained in:
fszontagh 2024-03-02 10:28:41 +01:00 committed by GitHub
parent d164236b2a
commit 7be65faa7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 24 additions and 4 deletions

View File

@ -891,6 +891,10 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
LOG_ERROR("embedding '%s' failed", embd_name.c_str()); LOG_ERROR("embedding '%s' failed", embd_name.c_str());
return false; return false;
} }
if (std::find(readed_embeddings.begin(), readed_embeddings.end(), embd_name) != readed_embeddings.end()) {
LOG_DEBUG("embedding already read in: %s", embd_name.c_str());
return true;
}
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = 10 * 1024 * 1024; // max for custom embeddings 10 MB params.mem_size = 10 * 1024 * 1024; // max for custom embeddings 10 MB
params.mem_buffer = NULL; params.mem_buffer = NULL;

View File

@ -33,6 +33,7 @@ struct LoraModel : public GGMLModule {
return model_loader.get_params_mem_size(NULL); return model_loader.get_params_mem_size(NULL);
} }
bool load_from_file() { bool load_from_file() {
LOG_INFO("loading LoRA from '%s'", file_path.c_str()); LOG_INFO("loading LoRA from '%s'", file_path.c_str());
@ -55,6 +56,7 @@ struct LoraModel : public GGMLModule {
auto real = lora_tensors[name]; auto real = lora_tensors[name];
*dst_tensor = real; *dst_tensor = real;
} }
return true; return true;
}; };
@ -64,6 +66,7 @@ struct LoraModel : public GGMLModule {
dry_run = false; dry_run = false;
model_loader.load_tensors(on_new_tensor_cb, backend); model_loader.load_tensors(on_new_tensor_cb, backend);
LOG_DEBUG("finished loaded lora"); LOG_DEBUG("finished loaded lora");
return true; return true;
} }

View File

@ -92,8 +92,10 @@ enum sd_log_level_t {
}; };
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data); typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
typedef void (*sd_progress_cb_t)(int step,int steps,float time, void* data);
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data); SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
SD_API int32_t get_num_physical_cores(); SD_API int32_t get_num_physical_cores();
SD_API const char* sd_get_system_info(); SD_API const char* sd_get_system_info();

View File

@ -161,6 +161,9 @@ int32_t get_num_physical_cores() {
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
} }
static sd_progress_cb_t sd_progress_cb = NULL;
void* sd_progress_cb_data = NULL;
std::u32string utf8_to_utf32(const std::string& utf8_str) { std::u32string utf8_to_utf32(const std::string& utf8_str) {
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter; std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
return converter.from_bytes(utf8_str); return converter.from_bytes(utf8_str);
@ -205,6 +208,10 @@ std::string path_join(const std::string& p1, const std::string& p2) {
} }
void pretty_progress(int step, int steps, float time) { void pretty_progress(int step, int steps, float time) {
if (sd_progress_cb) {
sd_progress_cb(step,steps,time, sd_progress_cb_data);
return;
}
if (step == 0) { if (step == 0) {
return; return;
} }
@ -248,8 +255,9 @@ std::string trim(const std::string& s) {
return rtrim(ltrim(s)); return rtrim(ltrim(s));
} }
static sd_log_cb_t sd_log_cb = NULL; static sd_log_cb_t sd_log_cb = NULL;
void* sd_log_cb_data = NULL; void* sd_log_cb_data = NULL;
#define LOG_BUFFER_SIZE 1024 #define LOG_BUFFER_SIZE 1024
@ -286,7 +294,10 @@ void sd_set_log_callback(sd_log_cb_t cb, void* data) {
sd_log_cb = cb; sd_log_cb = cb;
sd_log_cb_data = data; sd_log_cb_data = data;
} }
void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
sd_progress_cb = cb;
sd_progress_cb_data = data;
}
const char* sd_get_system_info() { const char* sd_get_system_info() {
static char buffer[1024]; static char buffer[1024];
std::stringstream ss; std::stringstream ss;

View File

@ -6,7 +6,7 @@
/*================================================== AutoEncoderKL ===================================================*/ /*================================================== AutoEncoderKL ===================================================*/
#define VAE_GRAPH_SIZE 10240 #define VAE_GRAPH_SIZE 20480
class ResnetBlock : public UnaryBlock { class ResnetBlock : public UnaryBlock {
protected: protected: