Skip to content

Commit ae2cbce

Browse files
committed
Refactor imatrix implementation into main example
1 parent 769d0ab commit ae2cbce

File tree

8 files changed

+88
-1197
lines changed

8 files changed

+88
-1197
lines changed

examples/CMakeLists.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
22

3-
add_subdirectory(cli)
4-
add_subdirectory(imatrix)
3+
add_subdirectory(cli)

examples/cli/main.cpp

+50-6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
#define STB_IMAGE_RESIZE_STATIC
2323
#include "stb_image_resize.h"
2424

25+
#define IMATRIX_IMPL
26+
#include "imatrix.hpp"
27+
static IMatrixCollector g_collector;
28+
2529
const char* rng_type_to_str[] = {
2630
"std_default",
2731
"cuda",
@@ -129,6 +133,12 @@ struct SDParams {
129133
float slg_scale = 0.f;
130134
float skip_layer_start = 0.01f;
131135
float skip_layer_end = 0.2f;
136+
137+
/* Imatrix params */
138+
139+
std::string imatrix_out = "";
140+
141+
std::vector<std::string> imatrix_in = {};
132142
};
133143

134144
void print_params(SDParams params) {
@@ -204,6 +214,8 @@ void print_usage(int argc, const char* argv[]) {
204214
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
205215
printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
206216
printf(" If not specified, the default is the type of the weight file\n");
217+
printf(" --imat-out [PATH] If set, compute the imatrix for this run and save it to the provided path");
218+
printf(" --imat-in [PATH] Use imatrix for quantization.");
207219
printf(" --lora-model-dir [DIR] lora model directory\n");
208220
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
209221
printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
@@ -629,6 +641,18 @@ void parse_args(int argc, const char** argv, SDParams& params) {
629641
break;
630642
}
631643
params.skip_layer_end = std::stof(argv[i]);
644+
} else if (arg == "--imat-out") {
645+
if (++i >= argc) {
646+
invalid_arg = true;
647+
break;
648+
}
649+
params.imatrix_out = argv[i];
650+
} else if (arg == "--imat-in") {
651+
if (++i >= argc) {
652+
invalid_arg = true;
653+
break;
654+
}
655+
params.imatrix_in.push_back(std::string(argv[i]));
632656
} else {
633657
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
634658
print_usage(argc, argv);
@@ -787,6 +811,10 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
787811
fflush(out_stream);
788812
}
789813

814+
// Free-function adapter around the file-level g_collector instance: forwards
// all three arguments unchanged to g_collector.collect_imatrix() and returns
// its result. It exists so a plain function pointer can be handed to
// sd_set_backend_eval_callback() in main().
// NOTE(review): the meaning of `ask` and `user_data` is defined by
// IMatrixCollector in imatrix.hpp, which is not visible here — presumably it
// follows the ggml scheduler eval-callback convention; confirm.
static bool collect_imatrix(struct ggml_tensor* t, bool ask, void* user_data) {
815+
return g_collector.collect_imatrix(t, ask, user_data);
816+
}
817+
790818
int main(int argc, const char* argv[]) {
791819
SDParams params;
792820

@@ -799,8 +827,21 @@ int main(int argc, const char* argv[]) {
799827
printf("%s", sd_get_system_info());
800828
}
801829

830+
if (params.imatrix_out != "") {
831+
sd_set_backend_eval_callback((sd_graph_eval_callback_t)collect_imatrix, &params);
832+
}
833+
if (params.imatrix_out != "" || params.mode == CONVERT || params.wtype != SD_TYPE_COUNT) {
834+
setConvertImatrixCollector((void*)&g_collector);
835+
for (const auto& in_file : params.imatrix_in) {
836+
printf("loading imatrix from '%s'\n", in_file.c_str());
837+
if (!g_collector.load_imatrix(in_file.c_str())) {
838+
printf("Failed to load %s\n", in_file.c_str());
839+
}
840+
}
841+
}
842+
802843
if (params.mode == CONVERT) {
803-
bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype,NULL);
844+
bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype);
804845
if (!success) {
805846
fprintf(stderr,
806847
"convert '%s'/'%s' to '%s' failed\n",
@@ -1075,19 +1116,19 @@ int main(int argc, const char* argv[]) {
10751116

10761117
std::string dummy_name, ext, lc_ext;
10771118
bool is_jpg;
1078-
size_t last = params.output_path.find_last_of(".");
1119+
size_t last = params.output_path.find_last_of(".");
10791120
size_t last_path = std::min(params.output_path.find_last_of("/"),
10801121
params.output_path.find_last_of("\\"));
1081-
if (last != std::string::npos // filename has extension
1082-
&& (last_path == std::string::npos || last > last_path)) {
1122+
if (last != std::string::npos // filename has extension
1123+
&& (last_path == std::string::npos || last > last_path)) {
10831124
dummy_name = params.output_path.substr(0, last);
10841125
ext = lc_ext = params.output_path.substr(last);
10851126
std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower);
10861127
is_jpg = lc_ext == ".jpg" || lc_ext == ".jpeg" || lc_ext == ".jpe";
10871128
} else {
10881129
dummy_name = params.output_path;
10891130
ext = lc_ext = "";
1090-
is_jpg = false;
1131+
is_jpg = false;
10911132
}
10921133
// appending ".png" to absent or unknown extension
10931134
if (!is_jpg && lc_ext != ".png") {
@@ -1099,7 +1140,7 @@ int main(int argc, const char* argv[]) {
10991140
continue;
11001141
}
11011142
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext;
1102-
if(is_jpg) {
1143+
if (is_jpg) {
11031144
stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
11041145
results[i].data, 90, get_image_params(params, params.seed + i).c_str());
11051146
printf("save result JPEG image to '%s'\n", final_image_path.c_str());
@@ -1111,6 +1152,9 @@ int main(int argc, const char* argv[]) {
11111152
free(results[i].data);
11121153
results[i].data = NULL;
11131154
}
1155+
if (params.imatrix_out != "") {
1156+
g_collector.save_imatrix(params.imatrix_out);
1157+
}
11141158
free(results);
11151159
free_sd_ctx(sd_ctx);
11161160
free(control_image_buffer);

examples/imatrix/CMakeLists.txt

-7
This file was deleted.

0 commit comments

Comments
 (0)