@@ -1098,6 +1098,51 @@ template <typename key_at, typename slot_at> void test_replacing_update() {
1098
1098
expect_eq (final_search[2 ].member .key , 44 );
1099
1099
}
1100
1100
1101
+ /* *
1102
+ * Tests the filtered search functionality of the index.
1103
+ */
1104
+ void test_filtered_search () {
1105
+ constexpr std::size_t dataset_count = 2048 ;
1106
+ constexpr std::size_t dimensions = 32 ;
1107
+ metric_punned_t metric (dimensions, metric_kind_t ::cos_k);
1108
+
1109
+ std::random_device rd;
1110
+ std::mt19937 gen (rd ());
1111
+ std::uniform_real_distribution<> dis (0.0 , 1.0 );
1112
+ using vector_of_vectors_t = std::vector<std::vector<float >>;
1113
+
1114
+ vector_of_vectors_t vector_of_vectors (dataset_count);
1115
+ for (auto & vector : vector_of_vectors) {
1116
+ vector.resize (dimensions);
1117
+ std::generate (vector.begin (), vector.end (), [&] { return dis (gen); });
1118
+ }
1119
+
1120
+ index_dense_t index = index_dense_t::make (metric);
1121
+ index .reserve (dataset_count);
1122
+ for (std::size_t idx = 0 ; idx < dataset_count; ++idx)
1123
+ index .add (idx, vector_of_vectors[idx].data ());
1124
+ expect_eq (index .size (), dataset_count);
1125
+
1126
+ {
1127
+ auto predicate = [](index_dense_t ::key_t key) { return key != 0 ; };
1128
+ auto results = index .filtered_search (vector_of_vectors[0 ].data (), 10 , predicate);
1129
+ expect_eq (10 , results.size ()); // ! Should not contain 0
1130
+ for (std::size_t i = 0 ; i != results.size (); ++i)
1131
+ expect (0 != results[i].member .key );
1132
+ }
1133
+ {
1134
+ auto predicate = [](index_dense_t ::key_t ) { return false ; };
1135
+ auto results = index .filtered_search (vector_of_vectors[0 ].data (), 10 , predicate);
1136
+ expect_eq (0 , results.size ()); // ! Should not contain 0
1137
+ }
1138
+ {
1139
+ auto predicate = [](index_dense_t ::key_t key) { return key == 10 ; };
1140
+ auto results = index .filtered_search (vector_of_vectors[0 ].data (), 10 , predicate);
1141
+ expect_eq (1 , results.size ()); // ! Should not contain 0
1142
+ expect_eq (10 , results[0 ].member .key );
1143
+ }
1144
+ }
1145
+
1101
1146
int main (int , char **) {
1102
1147
test_uint40 ();
1103
1148
test_cosine<float , std::int64_t , uint40_t >(10 , 10 );
@@ -1174,5 +1219,6 @@ int main(int, char**) {
1174
1219
test_sets<std::int64_t , slot32_t >(set_size, 20 , 30 );
1175
1220
test_strings<std::int64_t , slot32_t >();
1176
1221
1222
+ test_filtered_search ();
1177
1223
return 0 ;
1178
1224
}
0 commit comments