22
22
#define STB_IMAGE_RESIZE_STATIC
23
23
#include " stb_image_resize.h"
24
24
25
+ #define IMATRIX_IMPL
26
+ #include " imatrix.hpp"
27
+ static IMatrixCollector g_collector;
28
+
25
29
const char * rng_type_to_str[] = {
26
30
" std_default" ,
27
31
" cuda" ,
@@ -129,6 +133,12 @@ struct SDParams {
129
133
float slg_scale = 0 .f;
130
134
float skip_layer_start = 0 .01f ;
131
135
float skip_layer_end = 0 .2f ;
136
+
137
+ /* Imatrix params */
138
+
139
+ std::string imatrix_out = " " ;
140
+
141
+ std::vector<std::string> imatrix_in = {};
132
142
};
133
143
134
144
void print_params (SDParams params) {
@@ -204,6 +214,8 @@ void print_usage(int argc, const char* argv[]) {
204
214
printf (" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n " );
205
215
printf (" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n " );
206
216
printf (" If not specified, the default is the type of the weight file\n " );
217
+ printf (" --imat-out [PATH] If set, compute the imatrix for this run and save it to the provided path" );
218
+ printf (" --imat-in [PATH] Use imatrix for quantization." );
207
219
printf (" --lora-model-dir [DIR] lora model directory\n " );
208
220
printf (" -i, --init-img [IMAGE] path to the input image, required by img2img\n " );
209
221
printf (" --mask [MASK] path to the mask image, required by img2img with mask\n " );
@@ -629,6 +641,18 @@ void parse_args(int argc, const char** argv, SDParams& params) {
629
641
break ;
630
642
}
631
643
params.skip_layer_end = std::stof (argv[i]);
644
+ } else if (arg == " --imat-out" ) {
645
+ if (++i >= argc) {
646
+ invalid_arg = true ;
647
+ break ;
648
+ }
649
+ params.imatrix_out = argv[i];
650
+ } else if (arg == " --imat-in" ) {
651
+ if (++i >= argc) {
652
+ invalid_arg = true ;
653
+ break ;
654
+ }
655
+ params.imatrix_in .push_back (std::string (argv[i]));
632
656
} else {
633
657
fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
634
658
print_usage (argc, argv);
@@ -787,6 +811,10 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
787
811
fflush (out_stream);
788
812
}
789
813
814
+ static bool collect_imatrix (struct ggml_tensor * t, bool ask, void * user_data) {
815
+ return g_collector.collect_imatrix (t, ask, user_data);
816
+ }
817
+
790
818
int main (int argc, const char * argv[]) {
791
819
SDParams params;
792
820
@@ -799,8 +827,21 @@ int main(int argc, const char* argv[]) {
799
827
printf (" %s" , sd_get_system_info ());
800
828
}
801
829
830
+ if (params.imatrix_out != " " ) {
831
+ sd_set_backend_eval_callback ((sd_graph_eval_callback_t )collect_imatrix, ¶ms);
832
+ }
833
+ if (params.imatrix_out != " " || params.mode == CONVERT || params.wtype != SD_TYPE_COUNT) {
834
+ setConvertImatrixCollector ((void *)&g_collector);
835
+ for (const auto & in_file : params.imatrix_in ) {
836
+ printf (" loading imatrix from '%s'\n " , in_file.c_str ());
837
+ if (!g_collector.load_imatrix (in_file.c_str ())) {
838
+ printf (" Failed to load %s\n " , in_file.c_str ());
839
+ }
840
+ }
841
+ }
842
+
802
843
if (params.mode == CONVERT) {
803
- bool success = convert (params.model_path .c_str (), params.vae_path .c_str (), params.output_path .c_str (), params.wtype , NULL );
844
+ bool success = convert (params.model_path .c_str (), params.vae_path .c_str (), params.output_path .c_str (), params.wtype );
804
845
if (!success) {
805
846
fprintf (stderr,
806
847
" convert '%s'/'%s' to '%s' failed\n " ,
@@ -1075,19 +1116,19 @@ int main(int argc, const char* argv[]) {
1075
1116
1076
1117
std::string dummy_name, ext, lc_ext;
1077
1118
bool is_jpg;
1078
- size_t last = params.output_path .find_last_of (" ." );
1119
+ size_t last = params.output_path .find_last_of (" ." );
1079
1120
size_t last_path = std::min (params.output_path .find_last_of (" /" ),
1080
1121
params.output_path .find_last_of (" \\ " ));
1081
- if (last != std::string::npos // filename has extension
1082
- && (last_path == std::string::npos || last > last_path)) {
1122
+ if (last != std::string::npos // filename has extension
1123
+ && (last_path == std::string::npos || last > last_path)) {
1083
1124
dummy_name = params.output_path .substr (0 , last);
1084
1125
ext = lc_ext = params.output_path .substr (last);
1085
1126
std::transform (ext.begin (), ext.end (), lc_ext.begin (), ::tolower);
1086
1127
is_jpg = lc_ext == " .jpg" || lc_ext == " .jpeg" || lc_ext == " .jpe" ;
1087
1128
} else {
1088
1129
dummy_name = params.output_path ;
1089
1130
ext = lc_ext = " " ;
1090
- is_jpg = false ;
1131
+ is_jpg = false ;
1091
1132
}
1092
1133
// appending ".png" to absent or unknown extension
1093
1134
if (!is_jpg && lc_ext != " .png" ) {
@@ -1099,7 +1140,7 @@ int main(int argc, const char* argv[]) {
1099
1140
continue ;
1100
1141
}
1101
1142
std::string final_image_path = i > 0 ? dummy_name + " _" + std::to_string (i + 1 ) + ext : dummy_name + ext;
1102
- if (is_jpg) {
1143
+ if (is_jpg) {
1103
1144
stbi_write_jpg (final_image_path.c_str (), results[i].width , results[i].height , results[i].channel ,
1104
1145
results[i].data , 90 , get_image_params (params, params.seed + i).c_str ());
1105
1146
printf (" save result JPEG image to '%s'\n " , final_image_path.c_str ());
@@ -1111,6 +1152,9 @@ int main(int argc, const char* argv[]) {
1111
1152
free (results[i].data );
1112
1153
results[i].data = NULL ;
1113
1154
}
1155
+ if (params.imatrix_out != " " ) {
1156
+ g_collector.save_imatrix (params.imatrix_out );
1157
+ }
1114
1158
free (results);
1115
1159
free_sd_ctx (sd_ctx);
1116
1160
free (control_image_buffer);
0 commit comments