Skip to content

Commit 56e77ae

Browse files
committed
Add support for benchmarking merge
1 parent fdb2aae commit 56e77ae

File tree

2 files changed

+189
-63
lines changed

2 files changed

+189
-63
lines changed

cpp/bench.cpp

Lines changed: 186 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -249,7 +249,7 @@ struct running_stats_printer_t {
249249
timestamp_t time = std::chrono::high_resolution_clock::now();
250250
std::size_t duration = std::chrono::duration_cast<std::chrono::nanoseconds>(time - start_time).count();
251251
float vectors_per_second = count * 1e9 / duration;
252-
std::printf("\r\33[2K100 %% completed, %.0f vectors/s\n", vectors_per_second);
252+
std::printf("\r\33[2K100 %% completed, %.0f vectors/s, %.1f s\n", vectors_per_second, duration / 1e9);
253253
}
254254

255255
void refresh(std::size_t step = 1024 * 32) {
@@ -286,8 +286,9 @@ struct running_stats_printer_t {
286286
}
287287
};
288288

289-
template <typename index_at, typename vector_id_at, typename scalar_at>
290-
void index_many(index_at& index, std::size_t n, vector_id_at const* ids, scalar_at const* vectors, std::size_t dims) {
289+
template <typename index_at, typename vector_id_at, typename scalar_at, typename add_at>
290+
void index_many(index_at& index, std::size_t n, vector_id_at const* ids, scalar_at const* vectors, std::size_t dims,
291+
add_at&& add) {
291292

292293
running_stats_printer_t printer{n, "Indexing"};
293294

@@ -300,7 +301,7 @@ void index_many(index_at& index, std::size_t n, vector_id_at const* ids, scalar_
300301
config.thread = omp_get_thread_num();
301302
#endif
302303
span_gt<scalar_at const> vector{vectors + dims * i, dims};
303-
index.add(ids[i], vector, config.thread);
304+
add(index, ids[i], vector, config);
304305
printer.progress++;
305306
if (config.thread == 0)
306307
printer.refresh();
@@ -310,10 +311,10 @@ void index_many(index_at& index, std::size_t n, vector_id_at const* ids, scalar_
310311
printer.print();
311312
}
312313

313-
template <typename index_at, typename vector_id_at, typename scalar_at, typename distance_at>
314+
template <typename index_at, typename vector_id_at, typename scalar_at, typename distance_at, typename search_at>
314315
void search_many( //
315316
index_at& index, std::size_t n, scalar_at const* vectors, std::size_t dims, std::size_t wanted, vector_id_at* ids,
316-
distance_at* distances) {
317+
distance_at* distances, search_at&& search) {
317318

318319
std::string name = "Search " + std::to_string(wanted);
319320
running_stats_printer_t printer{n, name.c_str()};
@@ -327,7 +328,8 @@ void search_many( //
327328
config.thread = omp_get_thread_num();
328329
#endif
329330
span_gt<scalar_at const> vector{vectors + dims * i, dims};
330-
index.search(vector, wanted, config.thread).dump_to(ids + wanted * i, distances + wanted * i);
331+
typename index_at::search_result_t search_result = search(index, vector, wanted, config);
332+
search_result.dump_to(ids + wanted * i, distances + wanted * i);
331333
printer.progress++;
332334
if (config.thread == 0)
333335
printer.refresh();
@@ -337,8 +339,8 @@ void search_many( //
337339
printer.print();
338340
}
339341

340-
template <typename dataset_at, typename index_at> //
341-
static void single_shot(dataset_at& dataset, index_at& index, bool construct = true) {
342+
template <typename dataset_at, typename index_at, typename add_at, typename search_at> //
343+
static void single_shot(dataset_at& dataset, index_at& index, bool construct, add_at&& add, search_at&& search) {
342344
using distance_t = typename index_at::distance_t;
343345
constexpr default_key_t missing_key = std::numeric_limits<default_key_t>::max();
344346

@@ -348,14 +350,14 @@ static void single_shot(dataset_at& dataset, index_at& index, bool construct = t
348350
// Perform insertions, evaluate speed
349351
std::vector<default_key_t> ids(dataset.vectors_count());
350352
std::iota(ids.begin(), ids.end(), 0);
351-
index_many(index, dataset.vectors_count(), ids.data(), dataset.vector(0), dataset.dimensions());
353+
index_many(index, dataset.vectors_count(), ids.data(), dataset.vector(0), dataset.dimensions(), add);
352354
}
353355

354356
// Perform search, evaluate speed
355357
std::vector<default_key_t> found_neighbors(dataset.queries_count() * dataset.neighborhood_size());
356358
std::vector<distance_t> found_distances(dataset.queries_count() * dataset.neighborhood_size());
357359
search_many(index, dataset.queries_count(), dataset.query(0), dataset.dimensions(), dataset.neighborhood_size(),
358-
found_neighbors.data(), found_distances.data());
360+
found_neighbors.data(), found_distances.data(), search);
359361

360362
// Evaluate search quality
361363
std::size_t recall_at_1 = 0, recall_full = 0;
@@ -369,43 +371,45 @@ static void single_shot(dataset_at& dataset, index_at& index, bool construct = t
369371
std::printf("Recall@1 %.2f %%\n", recall_at_1 * 100.f / dataset.queries_count());
370372
std::printf("Recall %.2f %%\n", recall_full * 100.f / dataset.queries_count());
371373

372-
// Perform joins
373-
std::vector<default_key_t> man_to_woman(dataset.vectors_count());
374-
std::vector<default_key_t> woman_to_man(dataset.vectors_count());
375-
std::size_t join_attempts = 0;
376-
{
377-
index_at& men = index;
378-
index_at women = index.copy();
379-
std::fill(man_to_woman.begin(), man_to_woman.end(), missing_key);
380-
std::fill(woman_to_man.begin(), woman_to_man.end(), missing_key);
374+
if constexpr (!std::is_same_v<index_at, index_gt<>>) {
375+
// Perform joins
376+
std::vector<default_key_t> man_to_woman(dataset.vectors_count());
377+
std::vector<default_key_t> woman_to_man(dataset.vectors_count());
378+
std::size_t join_attempts = 0;
381379
{
382-
executor_default_t executor(index.limits().threads());
383-
running_stats_printer_t printer{1, "Join"};
384-
join_result_t result = join( //
385-
men, women, index_join_config_t{executor.size()}, //
386-
man_to_woman.data(), woman_to_man.data(), //
387-
executor, [&](std::size_t progress, std::size_t total) {
388-
if (progress % 1000 == 0)
389-
printer.print(progress, total);
390-
return true;
391-
});
392-
// Refresh once again to show 100% completion
393-
printer.print();
394-
join_attempts = result.visited_members;
380+
index_at& men = index;
381+
index_at women = index.copy();
382+
std::fill(man_to_woman.begin(), man_to_woman.end(), missing_key);
383+
std::fill(woman_to_man.begin(), woman_to_man.end(), missing_key);
384+
{
385+
executor_default_t executor(index.limits().threads());
386+
running_stats_printer_t printer{1, "Join"};
387+
join_result_t result = join( //
388+
men, women, index_join_config_t{executor.size()}, //
389+
man_to_woman.data(), woman_to_man.data(), //
390+
executor, [&](std::size_t progress, std::size_t total) {
391+
if (progress % 1000 == 0)
392+
printer.print(progress, total);
393+
return true;
394+
});
395+
// Refresh once again to show 100% completion
396+
printer.print();
397+
join_attempts = result.visited_members;
398+
}
395399
}
396-
}
397-
// Evaluate join quality
398-
std::size_t recall_join = 0, unmatched_count = 0;
399-
for (std::size_t i = 0; i != index.size(); ++i) {
400-
recall_join += man_to_woman[i] == static_cast<default_key_t>(i);
401-
unmatched_count += man_to_woman[i] == missing_key;
402-
}
403-
std::printf("Recall Joins %.2f %%\n", recall_join * 100.f / index.size());
404-
std::printf("Unmatched %.2f %% (%zu items)\n", unmatched_count * 100.f / index.size(), unmatched_count);
405-
std::printf("Proposals %.2f / man (%zu total)\n", join_attempts * 1.f / index.size(), join_attempts);
400+
// Evaluate join quality
401+
std::size_t recall_join = 0, unmatched_count = 0;
402+
for (std::size_t i = 0; i != index.size(); ++i) {
403+
recall_join += man_to_woman[i] == static_cast<default_key_t>(i);
404+
unmatched_count += man_to_woman[i] == missing_key;
405+
}
406+
std::printf("Recall Joins %.2f %%\n", recall_join * 100.f / index.size());
407+
std::printf("Unmatched %.2f %% (%zu items)\n", unmatched_count * 100.f / index.size(), unmatched_count);
408+
std::printf("Proposals %.2f / man (%zu total)\n", join_attempts * 1.f / index.size(), join_attempts);
406409

407-
std::printf("------------\n");
408-
std::printf("\n");
410+
std::printf("------------\n");
411+
std::printf("\n");
412+
}
409413
}
410414

411415
void handler(int sig) {
@@ -468,6 +472,9 @@ struct args_t {
468472

469473
bool big = false;
470474

475+
bool index_gt_api = false;
476+
std::size_t chunks_to_merge = 0;
477+
471478
bool quantize_bf16 = false;
472479
bool quantize_f16 = false;
473480
bool quantize_i8 = false;
@@ -516,6 +523,8 @@ struct args_t {
516523
template <typename index_at, typename dataset_at> //
517524
void run_punned(dataset_at& dataset, args_t const& args, index_dense_config_t config, index_limits_t limits) {
518525

526+
using scalar_t = typename dataset_at::scalar_t;
527+
519528
scalar_kind_t quantization = args.quantization();
520529
std::printf("-- Quantization: %s\n", scalar_kind_name(quantization));
521530

@@ -528,31 +537,141 @@ void run_punned(dataset_at& dataset, args_t const& args, index_dense_config_t co
528537
std::printf("-- Hardware acceleration: %s\n", index.metric().isa_name());
529538
std::printf("Will benchmark in-memory\n");
530539

531-
single_shot(dataset, index, true);
540+
auto add = [&](index_at& index, default_key_t id, span_gt<scalar_t const> vector, index_update_config_t config) {
541+
index.add(id, vector, config.thread);
542+
};
543+
auto search = [&](index_at& index, span_gt<scalar_t const> vector, std::size_t wanted,
544+
index_search_config_t config) { return index.search(vector, wanted, config.thread); };
545+
546+
single_shot(dataset, index, true, add, search);
532547
index.save(args.path_output.c_str());
533548

534549
std::printf("Will benchmark an on-disk view\n");
535550

536551
index_at index_view = index.fork();
537552
index_view.view(args.path_output.c_str());
538-
single_shot(dataset, index_view, false);
553+
single_shot(dataset, index_view, false, add, search);
539554
}
540555

541556
template <typename index_at, typename dataset_at> //
542557
void run_typed(dataset_at& dataset, args_t const& args, index_config_t config, index_limits_t limits) {
558+
using distance_t = typename index_at::distance_t;
559+
using member_ref_t = typename index_at::member_ref_t;
560+
using member_cref_t = typename index_at::member_cref_t;
561+
using member_citerator_t = typename index_at::member_citerator_t;
562+
563+
using scalar_t = typename dataset_at::scalar_t;
564+
565+
scalar_kind_t quantization = args.quantization();
566+
std::printf("-- Quantization: %s\n", scalar_kind_name(quantization));
567+
568+
metric_kind_t kind = args.metric();
569+
std::printf("-- Metric: %s\n", metric_kind_name(kind));
570+
571+
metric_punned_t metric_punned(dataset.dimensions(), kind, quantization);
572+
buffer_gt<byte_t const*> values(limits.members);
573+
574+
std::printf("-- Hardware acceleration: %s\n", metric_punned.isa_name());
575+
576+
class metric_t {
577+
metric_punned_t metric_;
578+
buffer_gt<byte_t const*>& values_;
579+
580+
public:
581+
metric_t(metric_punned_t metric, buffer_gt<byte_t const*>& values) noexcept
582+
: metric_(metric), values_(values) {}
583+
584+
inline distance_t operator()(byte_t const* a, member_cref_t b) const noexcept { return f(a, v(b)); }
585+
inline distance_t operator()(member_cref_t a, member_cref_t b) const noexcept { return f(v(a), v(b)); }
586+
587+
inline distance_t operator()(byte_t const* a, member_citerator_t b) const noexcept { return f(a, v(b)); }
588+
inline distance_t operator()(member_citerator_t a, member_citerator_t b) const noexcept {
589+
return f(v(a), v(b));
590+
}
591+
592+
inline distance_t operator()(byte_t const* a, byte_t const* b) const noexcept { return f(a, b); }
593+
594+
inline byte_t const* v(member_cref_t m) const noexcept { return values_[get_slot(m)]; }
595+
inline byte_t const* v(member_citerator_t m) const noexcept { return values_[get_slot(m)]; }
596+
inline distance_t f(byte_t const* a, byte_t const* b) const noexcept { return metric_(a, b); }
597+
};
598+
metric_t metric{metric_punned, values};
599+
600+
auto add = [&](index_at& index, default_key_t id, span_gt<scalar_t const> vector, index_update_config_t& config) {
601+
byte_t const* vector_data = reinterpret_cast<byte_t const*>(vector.data());
602+
auto on_success = [&](member_ref_t member) { values[member.slot] = vector_data; };
603+
index.add(id, vector_data, metric, config, on_success);
604+
};
605+
auto search = [&](index_at& index, span_gt<scalar_t const> vector, std::size_t wanted,
606+
index_search_config_t config) {
607+
byte_t const* vector_data = reinterpret_cast<byte_t const*>(vector.data());
608+
return index.search(vector_data, wanted, metric, config);
609+
};
543610

544611
index_at index(config);
545-
index.reserve(limits);
546-
std::printf("Will benchmark in-memory\n");
612+
if (args.chunks_to_merge > 0) {
613+
index.save(args.path_output.c_str());
614+
memory_mapped_file_t output{args.path_output.c_str(), true};
615+
index.load(std::move(output));
616+
index.reserve(limits);
617+
std::printf("Will benchmark merge: %zu\n", args.chunks_to_merge);
547618

548-
single_shot(dataset, index, true);
549-
index.save(args.path_output.c_str());
619+
std::printf("\n");
620+
std::printf("------------\n");
550621

551-
std::printf("Will benchmark an on-disk view\n");
622+
{
623+
// Perform insertions, evaluate speed
624+
std::vector<default_key_t> ids(dataset.vectors_count());
625+
std::iota(ids.begin(), ids.end(), 0);
626+
std::vector<index_at> subindexes;
627+
std::vector<buffer_gt<byte_t const*>> subvalues;
628+
std::size_t offset = 0;
629+
std::size_t chunk_rows = dataset.vectors_count() / args.chunks_to_merge;
630+
for (std::size_t i = 0; i != args.chunks_to_merge; ++i) {
631+
subindexes.emplace_back(config);
632+
index_at& subindex = subindexes[i];
633+
std::size_t n = std::min(chunk_rows, dataset.vectors_count() - offset);
634+
subindex.reserve(n);
635+
subvalues.emplace_back(n);
636+
buffer_gt<byte_t const*>& subvs = subvalues[i];
637+
metric_t submetric{metric_punned, subvs};
638+
auto subadd = [&](index_at& index, default_key_t id, span_gt<scalar_t const> vector,
639+
index_update_config_t& config) {
640+
byte_t const* vector_data = reinterpret_cast<byte_t const*>(vector.data());
641+
auto on_success = [&](member_ref_t member) { subvs[member.slot] = vector_data; };
642+
index.add(id, vector_data, submetric, config, on_success);
643+
};
644+
index_many(subindex, n, ids.data() + offset, dataset.vector(offset), dataset.dimensions(), subadd);
645+
offset += chunk_rows;
646+
}
647+
{
648+
running_stats_printer_t printer{dataset.vectors_count(), "Merging"};
649+
auto merge_on_success = [&](member_ref_t member, byte_t const* value) { values[member.slot] = value; };
650+
for (std::size_t i = 0; i != args.chunks_to_merge; ++i) {
651+
buffer_gt<byte_t const*>& subvs = subvalues[i];
652+
auto get_value = [&](member_cref_t member) { return subvs[member.slot]; };
653+
index.merge(subindexes[i], get_value, metric, {}, merge_on_success);
654+
printer.progress += subindexes[i].size();
655+
printer.refresh();
656+
}
657+
printer.print();
658+
}
659+
}
660+
// Perform searches, evaluate speed
661+
single_shot(dataset, index, false, add, search);
662+
} else {
663+
index.reserve(limits);
664+
std::printf("Will benchmark in-memory\n");
552665

553-
index_at index_view = index.fork();
554-
index_view.view(args.path_output.c_str());
555-
single_shot(dataset, index_view, false);
666+
single_shot(dataset, index, true, add, search);
667+
index.save(args.path_output.c_str());
668+
669+
std::printf("Will benchmark an on-disk view\n");
670+
671+
index_at index_view = index.fork();
672+
index_view.view(args.path_output.c_str());
673+
single_shot(dataset, index_view, false, add, search);
674+
}
556675
}
557676

558677
template <typename dataset_scalar_at> void bench_with_args(args_t const& args) {
@@ -583,14 +702,18 @@ template <typename dataset_scalar_at> void bench_with_args(args_t const& args) {
583702
std::printf("-- Expansion @ Add: %zu\n", config.expansion_add);
584703
std::printf("-- Expansion @ Search: %zu\n", config.expansion_search);
585704

586-
if (args.big)
705+
if (args.index_gt_api) {
706+
run_typed<index_gt<>>(dataset, args, config, limits);
707+
} else {
708+
if (args.big)
587709
#ifdef USEARCH_64BIT_ENV
588-
run_punned<index_dense_gt<default_key_t, uint40_t>>(dataset, args, config, limits);
710+
run_punned<index_dense_gt<default_key_t, uint40_t>>(dataset, args, config, limits);
589711
#else
590-
std::printf("Error: Don't use 40 bit identifiers in 32bit environment\n");
712+
std::printf("Error: Don't use 40 bit identifiers in 32bit environment\n");
591713
#endif
592-
else
593-
run_punned<index_dense_gt<default_key_t, std::uint32_t>>(dataset, args, config, limits);
714+
else
715+
run_punned<index_dense_gt<default_key_t, std::uint32_t>>(dataset, args, config, limits);
716+
}
594717
}
595718

596719
int main(int argc, char** argv) {
@@ -613,6 +736,9 @@ int main(int argc, char** argv) {
613736
(option("--expansion-search") & value("integer", args.expansion_search)).doc("Affects search depth"),
614737
(option("--rows-skip") & value("integer", args.vectors_to_skip)).doc("Number of vectors to skip"),
615738
(option("--rows-take") & value("integer", args.vectors_to_take)).doc("Number of vectors to take"),
739+
option("--index-gt-api").set(args.index_gt_api).doc("Use index_gt<> API not index_dense_gt API"),
740+
(option("--chunks-to-merge") & value("integer", args.chunks_to_merge))
741+
.doc("Number of chunks to merge. This requires --index-gt-api"),
616742
( //
617743
option("-bf16", "--bf16quant").set(args.quantize_bf16).doc("Enable `bf16_t` quantization") |
618744
option("-f16", "--f16quant").set(args.quantize_f16).doc("Enable `f16_t` quantization") |

include/usearch/index.hpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3133,7 +3133,7 @@ class index_gt {
31333133
if (size() == 0 && same_connectivity) {
31343134
// If the base index is empty, we just copy nodes in the target index.
31353135
for (member_cref_t target_member : index) {
3136-
auto& value = get_value(target_member);
3136+
auto value = get_value(target_member);
31373137
const std::size_t target_slot = get_slot(target_member);
31383138
const std::size_t base_slot = target_slot;
31393139
node_t target_node = index.node_at_(target_slot);
@@ -3178,7 +3178,7 @@ class index_gt {
31783178
// index map for the 2nd path
31793179
buffer_gt<std::size_t> target_slot_to_base(index.size());
31803180
for (member_cref_t target_member : index) {
3181-
auto& value = get_value(target_member);
3181+
auto value = get_value(target_member);
31823182
const std::size_t target_slot = get_slot(target_member);
31833183
node_t target_node = index.node_at_(target_slot);
31843184
level_t level = target_node.level();
@@ -3247,7 +3247,7 @@ class index_gt {
32473247
} else {
32483248
// Add all values in the target index to the base index.
32493249
for (const auto& member : index) {
3250-
auto& value = get_value(member);
3250+
auto value = get_value(member);
32513251
auto merge_callback = [&](member_ref_t m) { callback(m, value); };
32523252
add_result_t result = add(get_key(member), value, metric, update_config, merge_callback, prefetch);
32533253
if (!result) {

0 commit comments

Comments
 (0)