feat: add progress callback (#170)
This commit is contained in:
parent
d164236b2a
commit
7be65faa7c
4
clip.hpp
4
clip.hpp
@ -891,6 +891,10 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
|
|||||||
LOG_ERROR("embedding '%s' failed", embd_name.c_str());
|
LOG_ERROR("embedding '%s' failed", embd_name.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
if (std::find(readed_embeddings.begin(), readed_embeddings.end(), embd_name) != readed_embeddings.end()) {
|
||||||
|
LOG_DEBUG("embedding already read in: %s", embd_name.c_str());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
params.mem_size = 10 * 1024 * 1024; // max for custom embeddings 10 MB
|
params.mem_size = 10 * 1024 * 1024; // max for custom embeddings 10 MB
|
||||||
params.mem_buffer = NULL;
|
params.mem_buffer = NULL;
|
||||||
|
3
lora.hpp
3
lora.hpp
@ -33,6 +33,7 @@ struct LoraModel : public GGMLModule {
|
|||||||
return model_loader.get_params_mem_size(NULL);
|
return model_loader.get_params_mem_size(NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
bool load_from_file() {
|
bool load_from_file() {
|
||||||
LOG_INFO("loading LoRA from '%s'", file_path.c_str());
|
LOG_INFO("loading LoRA from '%s'", file_path.c_str());
|
||||||
|
|
||||||
@ -55,6 +56,7 @@ struct LoraModel : public GGMLModule {
|
|||||||
auto real = lora_tensors[name];
|
auto real = lora_tensors[name];
|
||||||
*dst_tensor = real;
|
*dst_tensor = real;
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -64,6 +66,7 @@ struct LoraModel : public GGMLModule {
|
|||||||
dry_run = false;
|
dry_run = false;
|
||||||
model_loader.load_tensors(on_new_tensor_cb, backend);
|
model_loader.load_tensors(on_new_tensor_cb, backend);
|
||||||
|
|
||||||
|
|
||||||
LOG_DEBUG("finished loaded lora");
|
LOG_DEBUG("finished loaded lora");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -92,8 +92,10 @@ enum sd_log_level_t {
|
|||||||
};
|
};
|
||||||
|
|
||||||
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
|
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
|
||||||
|
typedef void (*sd_progress_cb_t)(int step,int steps,float time, void* data);
|
||||||
|
|
||||||
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
|
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
|
||||||
|
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
|
||||||
SD_API int32_t get_num_physical_cores();
|
SD_API int32_t get_num_physical_cores();
|
||||||
SD_API const char* sd_get_system_info();
|
SD_API const char* sd_get_system_info();
|
||||||
|
|
||||||
|
17
util.cpp
17
util.cpp
@ -161,6 +161,9 @@ int32_t get_num_physical_cores() {
|
|||||||
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
|
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static sd_progress_cb_t sd_progress_cb = NULL;
|
||||||
|
void* sd_progress_cb_data = NULL;
|
||||||
|
|
||||||
std::u32string utf8_to_utf32(const std::string& utf8_str) {
|
std::u32string utf8_to_utf32(const std::string& utf8_str) {
|
||||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
|
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
|
||||||
return converter.from_bytes(utf8_str);
|
return converter.from_bytes(utf8_str);
|
||||||
@ -205,6 +208,10 @@ std::string path_join(const std::string& p1, const std::string& p2) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void pretty_progress(int step, int steps, float time) {
|
void pretty_progress(int step, int steps, float time) {
|
||||||
|
if (sd_progress_cb) {
|
||||||
|
sd_progress_cb(step,steps,time, sd_progress_cb_data);
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (step == 0) {
|
if (step == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -248,8 +255,9 @@ std::string trim(const std::string& s) {
|
|||||||
return rtrim(ltrim(s));
|
return rtrim(ltrim(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
static sd_log_cb_t sd_log_cb = NULL;
|
static sd_log_cb_t sd_log_cb = NULL;
|
||||||
void* sd_log_cb_data = NULL;
|
void* sd_log_cb_data = NULL;
|
||||||
|
|
||||||
|
|
||||||
#define LOG_BUFFER_SIZE 1024
|
#define LOG_BUFFER_SIZE 1024
|
||||||
|
|
||||||
@ -286,7 +294,10 @@ void sd_set_log_callback(sd_log_cb_t cb, void* data) {
|
|||||||
sd_log_cb = cb;
|
sd_log_cb = cb;
|
||||||
sd_log_cb_data = data;
|
sd_log_cb_data = data;
|
||||||
}
|
}
|
||||||
|
void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
|
||||||
|
sd_progress_cb = cb;
|
||||||
|
sd_progress_cb_data = data;
|
||||||
|
}
|
||||||
const char* sd_get_system_info() {
|
const char* sd_get_system_info() {
|
||||||
static char buffer[1024];
|
static char buffer[1024];
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
|
2
vae.hpp
2
vae.hpp
@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
/*================================================== AutoEncoderKL ===================================================*/
|
/*================================================== AutoEncoderKL ===================================================*/
|
||||||
|
|
||||||
#define VAE_GRAPH_SIZE 10240
|
#define VAE_GRAPH_SIZE 20480
|
||||||
|
|
||||||
class ResnetBlock : public UnaryBlock {
|
class ResnetBlock : public UnaryBlock {
|
||||||
protected:
|
protected:
|
||||||
|
Loading…
Reference in New Issue
Block a user