@@ -249,7 +249,7 @@ struct running_stats_printer_t {
249249 timestamp_t time = std::chrono::high_resolution_clock::now ();
250250 std::size_t duration = std::chrono::duration_cast<std::chrono::nanoseconds>(time - start_time).count ();
251251 float vectors_per_second = count * 1e9 / duration;
252- std::printf (" \r\33 [2K100 %% completed, %.0f vectors/s\n " , vectors_per_second);
252+ std::printf (" \r\33 [2K100 %% completed, %.0f vectors/s, %.1f s \n " , vectors_per_second, duration / 1e9 );
253253 }
254254
255255 void refresh (std::size_t step = 1024 * 32 ) {
@@ -286,8 +286,9 @@ struct running_stats_printer_t {
286286 }
287287};
288288
289- template <typename index_at, typename vector_id_at, typename scalar_at>
290- void index_many (index_at& index, std::size_t n, vector_id_at const * ids, scalar_at const * vectors, std::size_t dims) {
289+ template <typename index_at, typename vector_id_at, typename scalar_at, typename add_at>
290+ void index_many (index_at& index, std::size_t n, vector_id_at const * ids, scalar_at const * vectors, std::size_t dims,
291+ add_at&& add) {
291292
292293 running_stats_printer_t printer{n, " Indexing" };
293294
@@ -300,7 +301,7 @@ void index_many(index_at& index, std::size_t n, vector_id_at const* ids, scalar_
300301 config.thread = omp_get_thread_num ();
301302#endif
302303 span_gt<scalar_at const > vector{vectors + dims * i, dims};
303- index. add (ids[i], vector, config. thread );
304+ add (index, ids[i], vector, config);
304305 printer.progress ++;
305306 if (config.thread == 0 )
306307 printer.refresh ();
@@ -310,10 +311,10 @@ void index_many(index_at& index, std::size_t n, vector_id_at const* ids, scalar_
310311 printer.print ();
311312}
312313
313- template <typename index_at, typename vector_id_at, typename scalar_at, typename distance_at>
314+ template <typename index_at, typename vector_id_at, typename scalar_at, typename distance_at, typename search_at >
314315void search_many ( //
315316 index_at& index, std::size_t n, scalar_at const * vectors, std::size_t dims, std::size_t wanted, vector_id_at* ids,
316- distance_at* distances) {
317+ distance_at* distances, search_at&& search ) {
317318
318319 std::string name = " Search " + std::to_string (wanted);
319320 running_stats_printer_t printer{n, name.c_str ()};
@@ -327,7 +328,8 @@ void search_many( //
327328 config.thread = omp_get_thread_num ();
328329#endif
329330 span_gt<scalar_at const > vector{vectors + dims * i, dims};
330- index.search (vector, wanted, config.thread ).dump_to (ids + wanted * i, distances + wanted * i);
331+ typename index_at::search_result_t search_result = search (index, vector, wanted, config);
332+ search_result.dump_to (ids + wanted * i, distances + wanted * i);
331333 printer.progress ++;
332334 if (config.thread == 0 )
333335 printer.refresh ();
@@ -337,8 +339,8 @@ void search_many( //
337339 printer.print ();
338340}
339341
340- template <typename dataset_at, typename index_at> //
341- static void single_shot (dataset_at& dataset, index_at& index, bool construct = true ) {
342+ template <typename dataset_at, typename index_at, typename add_at, typename search_at > //
343+ static void single_shot (dataset_at& dataset, index_at& index, bool construct, add_at&& add, search_at&& search ) {
342344 using distance_t = typename index_at::distance_t ;
343345 constexpr default_key_t missing_key = std::numeric_limits<default_key_t >::max ();
344346
@@ -348,14 +350,14 @@ static void single_shot(dataset_at& dataset, index_at& index, bool construct = t
348350 // Perform insertions, evaluate speed
349351 std::vector<default_key_t > ids (dataset.vectors_count ());
350352 std::iota (ids.begin (), ids.end (), 0 );
351- index_many (index, dataset.vectors_count (), ids.data (), dataset.vector (0 ), dataset.dimensions ());
353+ index_many (index, dataset.vectors_count (), ids.data (), dataset.vector (0 ), dataset.dimensions (), add );
352354 }
353355
354356 // Perform search, evaluate speed
355357 std::vector<default_key_t > found_neighbors (dataset.queries_count () * dataset.neighborhood_size ());
356358 std::vector<distance_t > found_distances (dataset.queries_count () * dataset.neighborhood_size ());
357359 search_many (index, dataset.queries_count (), dataset.query (0 ), dataset.dimensions (), dataset.neighborhood_size (),
358- found_neighbors.data (), found_distances.data ());
360+ found_neighbors.data (), found_distances.data (), search );
359361
360362 // Evaluate search quality
361363 std::size_t recall_at_1 = 0 , recall_full = 0 ;
@@ -369,43 +371,45 @@ static void single_shot(dataset_at& dataset, index_at& index, bool construct = t
369371 std::printf (" Recall@1 %.2f %%\n " , recall_at_1 * 100 .f / dataset.queries_count ());
370372 std::printf (" Recall %.2f %%\n " , recall_full * 100 .f / dataset.queries_count ());
371373
372- // Perform joins
373- std::vector<default_key_t > man_to_woman (dataset.vectors_count ());
374- std::vector<default_key_t > woman_to_man (dataset.vectors_count ());
375- std::size_t join_attempts = 0 ;
376- {
377- index_at& men = index;
378- index_at women = index.copy ();
379- std::fill (man_to_woman.begin (), man_to_woman.end (), missing_key);
380- std::fill (woman_to_man.begin (), woman_to_man.end (), missing_key);
374+ if constexpr (!std::is_same_v<index_at, index_gt<>>) {
375+ // Perform joins
376+ std::vector<default_key_t > man_to_woman (dataset.vectors_count ());
377+ std::vector<default_key_t > woman_to_man (dataset.vectors_count ());
378+ std::size_t join_attempts = 0 ;
381379 {
382- executor_default_t executor (index.limits ().threads ());
383- running_stats_printer_t printer{1 , " Join" };
384- join_result_t result = join ( //
385- men, women, index_join_config_t {executor.size ()}, //
386- man_to_woman.data (), woman_to_man.data (), //
387- executor, [&](std::size_t progress, std::size_t total) {
388- if (progress % 1000 == 0 )
389- printer.print (progress, total);
390- return true ;
391- });
392- // Refresh once again to show 100% completion
393- printer.print ();
394- join_attempts = result.visited_members ;
380+ index_at& men = index;
381+ index_at women = index.copy ();
382+ std::fill (man_to_woman.begin (), man_to_woman.end (), missing_key);
383+ std::fill (woman_to_man.begin (), woman_to_man.end (), missing_key);
384+ {
385+ executor_default_t executor (index.limits ().threads ());
386+ running_stats_printer_t printer{1 , " Join" };
387+ join_result_t result = join ( //
388+ men, women, index_join_config_t {executor.size ()}, //
389+ man_to_woman.data (), woman_to_man.data (), //
390+ executor, [&](std::size_t progress, std::size_t total) {
391+ if (progress % 1000 == 0 )
392+ printer.print (progress, total);
393+ return true ;
394+ });
395+ // Refresh once again to show 100% completion
396+ printer.print ();
397+ join_attempts = result.visited_members ;
398+ }
395399 }
396- }
397- // Evaluate join quality
398- std::size_t recall_join = 0 , unmatched_count = 0 ;
399- for (std::size_t i = 0 ; i != index.size (); ++i) {
400- recall_join += man_to_woman[i] == static_cast <default_key_t >(i);
401- unmatched_count += man_to_woman[i] == missing_key;
402- }
403- std::printf (" Recall Joins %.2f %%\n " , recall_join * 100 .f / index.size ());
404- std::printf (" Unmatched %.2f %% (%zu items)\n " , unmatched_count * 100 .f / index.size (), unmatched_count);
405- std::printf (" Proposals %.2f / man (%zu total)\n " , join_attempts * 1 .f / index.size (), join_attempts);
400+ // Evaluate join quality
401+ std::size_t recall_join = 0 , unmatched_count = 0 ;
402+ for (std::size_t i = 0 ; i != index.size (); ++i) {
403+ recall_join += man_to_woman[i] == static_cast <default_key_t >(i);
404+ unmatched_count += man_to_woman[i] == missing_key;
405+ }
406+ std::printf (" Recall Joins %.2f %%\n " , recall_join * 100 .f / index.size ());
407+ std::printf (" Unmatched %.2f %% (%zu items)\n " , unmatched_count * 100 .f / index.size (), unmatched_count);
408+ std::printf (" Proposals %.2f / man (%zu total)\n " , join_attempts * 1 .f / index.size (), join_attempts);
406409
407- std::printf (" ------------\n " );
408- std::printf (" \n " );
410+ std::printf (" ------------\n " );
411+ std::printf (" \n " );
412+ }
409413}
410414
411415void handler (int sig) {
@@ -468,6 +472,9 @@ struct args_t {
468472
469473 bool big = false ;
470474
475+ bool index_gt_api = false ;
476+ std::size_t chunks_to_merge = 0 ;
477+
471478 bool quantize_bf16 = false ;
472479 bool quantize_f16 = false ;
473480 bool quantize_i8 = false ;
@@ -516,6 +523,8 @@ struct args_t {
516523template <typename index_at, typename dataset_at> //
517524void run_punned (dataset_at& dataset, args_t const & args, index_dense_config_t config, index_limits_t limits) {
518525
526+ using scalar_t = typename dataset_at::scalar_t ;
527+
519528 scalar_kind_t quantization = args.quantization ();
520529 std::printf (" -- Quantization: %s\n " , scalar_kind_name (quantization));
521530
@@ -528,31 +537,141 @@ void run_punned(dataset_at& dataset, args_t const& args, index_dense_config_t co
528537 std::printf (" -- Hardware acceleration: %s\n " , index.metric ().isa_name ());
529538 std::printf (" Will benchmark in-memory\n " );
530539
531- single_shot (dataset, index, true );
540+ auto add = [&](index_at& index, default_key_t id, span_gt<scalar_t const > vector, index_update_config_t config) {
541+ index.add (id, vector, config.thread );
542+ };
543+ auto search = [&](index_at& index, span_gt<scalar_t const > vector, std::size_t wanted,
544+ index_search_config_t config) { return index.search (vector, wanted, config.thread ); };
545+
546+ single_shot (dataset, index, true , add, search);
532547 index.save (args.path_output .c_str ());
533548
534549 std::printf (" Will benchmark an on-disk view\n " );
535550
536551 index_at index_view = index.fork ();
537552 index_view.view (args.path_output .c_str ());
538- single_shot (dataset, index_view, false );
553+ single_shot (dataset, index_view, false , add, search );
539554}
540555
541556template <typename index_at, typename dataset_at> //
542557void run_typed (dataset_at& dataset, args_t const & args, index_config_t config, index_limits_t limits) {
558+ using distance_t = typename index_at::distance_t ;
559+ using member_ref_t = typename index_at::member_ref_t ;
560+ using member_cref_t = typename index_at::member_cref_t ;
561+ using member_citerator_t = typename index_at::member_citerator_t ;
562+
563+ using scalar_t = typename dataset_at::scalar_t ;
564+
565+ scalar_kind_t quantization = args.quantization ();
566+ std::printf (" -- Quantization: %s\n " , scalar_kind_name (quantization));
567+
568+ metric_kind_t kind = args.metric ();
569+ std::printf (" -- Metric: %s\n " , metric_kind_name (kind));
570+
571+ metric_punned_t metric_punned (dataset.dimensions (), kind, quantization);
572+ buffer_gt<byte_t const *> values (limits.members );
573+
574+ std::printf (" -- Hardware acceleration: %s\n " , metric_punned.isa_name ());
575+
576+ class metric_t {
577+ metric_punned_t metric_;
578+ buffer_gt<byte_t const *>& values_;
579+
580+ public:
581+ metric_t (metric_punned_t metric, buffer_gt<byte_t const *>& values) noexcept
582+ : metric_(metric), values_(values) {}
583+
584+ inline distance_t operator ()(byte_t const * a, member_cref_t b) const noexcept { return f (a, v (b)); }
585+ inline distance_t operator ()(member_cref_t a, member_cref_t b) const noexcept { return f (v (a), v (b)); }
586+
587+ inline distance_t operator ()(byte_t const * a, member_citerator_t b) const noexcept { return f (a, v (b)); }
588+ inline distance_t operator ()(member_citerator_t a, member_citerator_t b) const noexcept {
589+ return f (v (a), v (b));
590+ }
591+
592+ inline distance_t operator ()(byte_t const * a, byte_t const * b) const noexcept { return f (a, b); }
593+
594+ inline byte_t const * v (member_cref_t m) const noexcept { return values_[get_slot (m)]; }
595+ inline byte_t const * v (member_citerator_t m) const noexcept { return values_[get_slot (m)]; }
596+ inline distance_t f (byte_t const * a, byte_t const * b) const noexcept { return metric_ (a, b); }
597+ };
598+ metric_t metric{metric_punned, values};
599+
600+ auto add = [&](index_at& index, default_key_t id, span_gt<scalar_t const > vector, index_update_config_t & config) {
601+ byte_t const * vector_data = reinterpret_cast <byte_t const *>(vector.data ());
602+ auto on_success = [&](member_ref_t member) { values[member.slot ] = vector_data; };
603+ index.add (id, vector_data, metric, config, on_success);
604+ };
605+ auto search = [&](index_at& index, span_gt<scalar_t const > vector, std::size_t wanted,
606+ index_search_config_t config) {
607+ byte_t const * vector_data = reinterpret_cast <byte_t const *>(vector.data ());
608+ return index.search (vector_data, wanted, metric, config);
609+ };
543610
544611 index_at index (config);
545- index.reserve (limits);
546- std::printf (" Will benchmark in-memory\n " );
612+ if (args.chunks_to_merge > 0 ) {
613+ index.save (args.path_output .c_str ());
614+ memory_mapped_file_t output{args.path_output .c_str (), true };
615+ index.load (std::move (output));
616+ index.reserve (limits);
617+ std::printf (" Will benchmark merge: %zu\n " , args.chunks_to_merge );
547618
548- single_shot (dataset, index, true );
549- index. save (args. path_output . c_str () );
619+ std::printf ( " \n " );
620+ std::printf ( " ------------ \n " );
550621
551- std::printf (" Will benchmark an on-disk view\n " );
622+ {
623+ // Perform insertions, evaluate speed
624+ std::vector<default_key_t > ids (dataset.vectors_count ());
625+ std::iota (ids.begin (), ids.end (), 0 );
626+ std::vector<index_at> subindexes;
627+ std::vector<buffer_gt<byte_t const *>> subvalues;
628+ std::size_t offset = 0 ;
629+ std::size_t chunk_rows = dataset.vectors_count () / args.chunks_to_merge ;
630+ for (std::size_t i = 0 ; i != args.chunks_to_merge ; ++i) {
631+ subindexes.emplace_back (config);
632+ index_at& subindex = subindexes[i];
633+ std::size_t n = std::min (chunk_rows, dataset.vectors_count () - offset);
634+ subindex.reserve (n);
635+ subvalues.emplace_back (n);
636+ buffer_gt<byte_t const *>& subvs = subvalues[i];
637+ metric_t submetric{metric_punned, subvs};
638+ auto subadd = [&](index_at& index, default_key_t id, span_gt<scalar_t const > vector,
639+ index_update_config_t & config) {
640+ byte_t const * vector_data = reinterpret_cast <byte_t const *>(vector.data ());
641+ auto on_success = [&](member_ref_t member) { subvs[member.slot ] = vector_data; };
642+ index.add (id, vector_data, submetric, config, on_success);
643+ };
644+ index_many (subindex, n, ids.data () + offset, dataset.vector (offset), dataset.dimensions (), subadd);
645+ offset += chunk_rows;
646+ }
647+ {
648+ running_stats_printer_t printer{dataset.vectors_count (), " Merging" };
649+ auto merge_on_success = [&](member_ref_t member, byte_t const * value) { values[member.slot ] = value; };
650+ for (std::size_t i = 0 ; i != args.chunks_to_merge ; ++i) {
651+ buffer_gt<byte_t const *>& subvs = subvalues[i];
652+ auto get_value = [&](member_cref_t member) { return subvs[member.slot ]; };
653+ index.merge (subindexes[i], get_value, metric, {}, merge_on_success);
654+ printer.progress += subindexes[i].size ();
655+ printer.refresh ();
656+ }
657+ printer.print ();
658+ }
659+ }
660+ // Perform searches, evaluate speed
661+ single_shot (dataset, index, false , add, search);
662+ } else {
663+ index.reserve (limits);
664+ std::printf (" Will benchmark in-memory\n " );
552665
553- index_at index_view = index.fork ();
554- index_view.view (args.path_output .c_str ());
555- single_shot (dataset, index_view, false );
666+ single_shot (dataset, index, true , add, search);
667+ index.save (args.path_output .c_str ());
668+
669+ std::printf (" Will benchmark an on-disk view\n " );
670+
671+ index_at index_view = index.fork ();
672+ index_view.view (args.path_output .c_str ());
673+ single_shot (dataset, index_view, false , add, search);
674+ }
556675}
557676
558677template <typename dataset_scalar_at> void bench_with_args (args_t const & args) {
@@ -583,14 +702,18 @@ template <typename dataset_scalar_at> void bench_with_args(args_t const& args) {
583702 std::printf (" -- Expansion @ Add: %zu\n " , config.expansion_add );
584703 std::printf (" -- Expansion @ Search: %zu\n " , config.expansion_search );
585704
586- if (args.big )
705+ if (args.index_gt_api ) {
706+ run_typed<index_gt<>>(dataset, args, config, limits);
707+ } else {
708+ if (args.big )
587709#ifdef USEARCH_64BIT_ENV
588- run_punned<index_dense_gt<default_key_t , uint40_t >>(dataset, args, config, limits);
710+ run_punned<index_dense_gt<default_key_t , uint40_t >>(dataset, args, config, limits);
589711#else
590- std::printf (" Error: Don't use 40 bit identifiers in 32bit environment\n " );
712+ std::printf (" Error: Don't use 40 bit identifiers in 32bit environment\n " );
591713#endif
592- else
593- run_punned<index_dense_gt<default_key_t , std::uint32_t >>(dataset, args, config, limits);
714+ else
715+ run_punned<index_dense_gt<default_key_t , std::uint32_t >>(dataset, args, config, limits);
716+ }
594717}
595718
596719int main (int argc, char ** argv) {
@@ -613,6 +736,9 @@ int main(int argc, char** argv) {
613736 (option (" --expansion-search" ) & value (" integer" , args.expansion_search )).doc (" Affects search depth" ),
614737 (option (" --rows-skip" ) & value (" integer" , args.vectors_to_skip )).doc (" Number of vectors to skip" ),
615738 (option (" --rows-take" ) & value (" integer" , args.vectors_to_take )).doc (" Number of vectors to take" ),
739+ option (" --index-gt-api" ).set (args.index_gt_api ).doc (" Use index_gt<> API not index_dense_gt API" ),
740+ (option (" --chunks-to-merge" ) & value (" integer" , args.chunks_to_merge ))
741+ .doc (" Number of chunks to merge. This requires --index-gt-api" ),
616742 ( //
617743 option (" -bf16" , " --bf16quant" ).set (args.quantize_bf16 ).doc (" Enable `bf16_t` quantization" ) |
618744 option (" -f16" , " --f16quant" ).set (args.quantize_f16 ).doc (" Enable `f16_t` quantization" ) |
0 commit comments