diff --git a/cpp/test.cpp b/cpp/test.cpp index bcafbc8c..20f5553e 100644 --- a/cpp/test.cpp +++ b/cpp/test.cpp @@ -1098,6 +1098,51 @@ template void test_replacing_update() { expect_eq(final_search[2].member.key, 44); } +/** + * Tests the filtered search functionality of the index. + */ +void test_filtered_search() { + constexpr std::size_t dataset_count = 2048; + constexpr std::size_t dimensions = 32; + metric_punned_t metric(dimensions, metric_kind_t::cos_k); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(0.0, 1.0); + using vector_of_vectors_t = std::vector>; + + vector_of_vectors_t vector_of_vectors(dataset_count); + for (auto& vector : vector_of_vectors) { + vector.resize(dimensions); + std::generate(vector.begin(), vector.end(), [&] { return dis(gen); }); + } + + index_dense_t index = index_dense_t::make(metric); + index.reserve(dataset_count); + for (std::size_t idx = 0; idx < dataset_count; ++idx) + index.add(idx, vector_of_vectors[idx].data()); + expect_eq(index.size(), dataset_count); + + { + auto predicate = [](index_dense_t::key_t key) { return key != 0; }; + auto results = index.filtered_search(vector_of_vectors[0].data(), 10, predicate); + expect_eq(10, results.size()); // ! Should not contain 0 + for (std::size_t i = 0; i != results.size(); ++i) + expect(0 != results[i].member.key); + } + { + auto predicate = [](index_dense_t::key_t) { return false; }; + auto results = index.filtered_search(vector_of_vectors[0].data(), 10, predicate); + expect_eq(0, results.size()); // ! Should not contain 0 + } + { + auto predicate = [](index_dense_t::key_t key) { return key == 10; }; + auto results = index.filtered_search(vector_of_vectors[0].data(), 10, predicate); + expect_eq(1, results.size()); // ! Should not contain 0 + expect_eq(10, results[0].member.key); + } +} + int main(int, char**) { test_uint40(); test_cosine(10, 10); @@ -1174,5 +1219,6 @@ int main(int, char**) { test_sets(set_size, 20, 30); test_strings(); + test_filtered_search(); return 0; } diff --git a/include/usearch/index.hpp b/include/usearch/index.hpp index 3922ae23..17a089c0 100644 --- a/include/usearch/index.hpp +++ b/include/usearch/index.hpp @@ -4178,9 +4178,10 @@ class index_gt { // This can substantially grow our priority queue: next.insert({-successor_dist, successor_slot}); if (is_dummy() || - predicate(member_cref_t{node_at_(successor_slot).ckey(), successor_slot})) + predicate(member_cref_t{node_at_(successor_slot).ckey(), successor_slot})) { top.insert({successor_dist, successor_slot}, top_limit); - radius = top.top().distance; + radius = top.top().distance; + } } } }