Skip to content

Commit fdb2aae

Browse files
committed
Optimize merge
1 parent e8d3fae commit fdb2aae

File tree

2 files changed

+264
-61
lines changed

2 files changed

+264
-61
lines changed

cpp/test.cpp

Lines changed: 67 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,7 +1101,7 @@ template <typename key_at, typename slot_at> void test_replacing_update() {
11011101
/**
11021102
* @brief Tests merging.
11031103
*/
1104-
void test_merge() {
1104+
void test_merge(std::size_t base_n) {
11051105
using index_t = index_gt<>;
11061106
using distance_t = typename index_t::distance_t;
11071107
using key_t = typename index_t::key_t;
@@ -1145,24 +1145,49 @@ void test_merge() {
11451145
expect(result);
11461146
};
11471147

1148+
std::size_t n_nodes1 = base_n;
1149+
std::size_t n_nodes2 = base_n * 2;
1150+
1151+
// Prepare expected index
1152+
auto expected_index = create_index();
1153+
metric_t expected_metric;
1154+
expect(expected_index.reserve(n_nodes1 + n_nodes2));
1155+
11481156
// Prepare index 1
11491157
auto index1 = create_index();
11501158
metric_t metric1;
1151-
expect(index1.reserve(3));
1152-
add(index1, 11, 1.1f, metric1);
1153-
add(index1, 12, 2.1f, metric1);
1154-
add(index1, 13, 3.1f, metric1);
1155-
expect_eq(index1.size(), 3);
1159+
expect(index1.reserve(n_nodes1));
1160+
{
1161+
// Use static seed for easy to reproduce
1162+
std::default_random_engine engine(n_nodes1);
1163+
std::uniform_real_distribution<float> distribution(-1.0, 1.0);
1164+
for (std::size_t i = 0; i < n_nodes1; ++i) {
1165+
std::size_t key = 10000 + i;
1166+
value_t value = distribution(engine);
1167+
add(index1, key, value, metric1);
1168+
add(expected_index, key, value, expected_metric);
1169+
}
1170+
}
1171+
expect_eq(index1.size(), n_nodes1);
1172+
expect_eq(expected_index.size(), n_nodes1);
11561173

11571174
// Prepare index 2
11581175
auto index2 = create_index();
11591176
metric_t metric2;
1160-
expect(index2.reserve(4));
1161-
add(index2, 21, -1.1f, metric2);
1162-
add(index2, 22, -2.1f, metric2);
1163-
add(index2, 23, -3.1f, metric2);
1164-
add(index2, 24, -4.1f, metric2);
1165-
expect_eq(index2.size(), 4);
1177+
expect(index2.reserve(n_nodes2));
1178+
{
1179+
// Use static seed for easy to reproduce
1180+
std::default_random_engine engine(n_nodes2);
1181+
std::uniform_real_distribution<float> distribution(-1.0, 1.0);
1182+
for (std::size_t i = 0; i < n_nodes2; ++i) {
1183+
std::size_t key = 20000 + i;
1184+
value_t value = distribution(engine);
1185+
add(index2, key, value, metric2);
1186+
add(expected_index, key, value, expected_metric);
1187+
}
1188+
}
1189+
expect_eq(index2.size(), n_nodes2);
1190+
expect_eq(expected_index.size(), n_nodes1 + n_nodes2);
11661191

11671192
// Merge indexes
11681193
char const* merge_file_path = "merge.usearch";
@@ -1174,30 +1199,45 @@ void test_merge() {
11741199
auto merge_on_success = [&](member_ref_t member, value_t const& value) {
11751200
merged_metric.values[member.slot] = value;
11761201
};
1202+
1203+
// Merge index1
11771204
auto get_value1 = [&](member_cref_t member) -> value_t& { return metric1.values[member.slot]; };
11781205
expect(merged_index.merge(index1, get_value1, merged_metric, {}, merge_on_success));
1206+
expect_eq(merged_index.size(), n_nodes1);
1207+
// Assert after we merge index1
1208+
auto search = merged_index.search(0.75f, 3, merged_metric);
1209+
auto expected_search = index1.search(0.75f, 3, expected_metric);
1210+
expect_eq(search.size(), 3);
1211+
expect(search[0].distance <= expected_search[0].distance);
1212+
expect(search[1].distance <= expected_search[1].distance);
1213+
expect(search[2].distance <= expected_search[2].distance);
1214+
auto loaded_index = create_index();
1215+
loaded_index.view(merge_file_path);
1216+
search = merged_index.search(0.75f, 3, merged_metric);
1217+
1218+
// Merge index2
11791219
auto get_value2 = [&](member_cref_t member) -> value_t& { return metric2.values[member.slot]; };
11801220
expect(merged_index.merge(index2, get_value2, merged_metric, {}, merge_on_success));
1181-
1182-
// Assert
1183-
expect_eq(merged_index.size(), 7);
1184-
auto search = merged_index.search(0.75f, 3, merged_metric);
1221+
// Assert after we merge index1 and index2
1222+
expect_eq(merged_index.size(), n_nodes1 + n_nodes2);
1223+
search = merged_index.search(0.75f, 3, merged_metric);
1224+
expected_search = expected_index.search(0.75f, 3, expected_metric);
11851225
expect_eq(search.size(), 3);
1186-
expect_eq(static_cast<key_t>(search[0].member.key), 11);
1187-
expect_eq(static_cast<key_t>(search[1].member.key), 12);
1188-
expect_eq(static_cast<key_t>(search[2].member.key), 21);
1226+
expect(search[0].distance <= expected_search[0].distance);
1227+
expect(search[1].distance <= expected_search[1].distance);
1228+
expect(search[2].distance <= expected_search[2].distance);
11891229

1190-
// Re-load merged indexes
1230+
// Re-load the merged index
11911231
merged_index.reset();
11921232
merged_index.load(merge_file_path);
11931233

1194-
// Assert
1195-
expect_eq(merged_index.size(), 7);
1234+
// Assert after we reload the merged index
1235+
expect_eq(merged_index.size(), n_nodes1 + n_nodes2);
11961236
search = merged_index.search(0.75f, 3, merged_metric);
11971237
expect_eq(search.size(), 3);
1198-
expect_eq(static_cast<key_t>(search[0].member.key), 11);
1199-
expect_eq(static_cast<key_t>(search[1].member.key), 12);
1200-
expect_eq(static_cast<key_t>(search[2].member.key), 21);
1238+
expect(search[0].distance <= expected_search[0].distance);
1239+
expect(search[1].distance <= expected_search[1].distance);
1240+
expect(search[2].distance <= expected_search[2].distance);
12011241
}
12021242

12031243
int main(int, char**) {
@@ -1278,7 +1318,8 @@ int main(int, char**) {
12781318

12791319
// Test merge
12801320
std::printf("Testing merge\n");
1281-
test_merge();
1321+
test_merge(10); // Use only the 0-level layer
1322+
test_merge(1000); // Use multiple layers
12821323

12831324
return 0;
12841325
}

0 commit comments

Comments
 (0)