Skip to content

Commit 369a553

Browse files
Fix: filtered_search shouldn't update radius (#599)
Co-authored-by: Ash Vardanian <[email protected]>
1 parent 9ab9e1f commit 369a553

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

cpp/test.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,51 @@ template <typename key_at, typename slot_at> void test_replacing_update() {
10981098
expect_eq(final_search[2].member.key, 44);
10991099
}
11001100

1101+
/**
1102+
* Tests the filtered search functionality of the index.
1103+
*/
1104+
void test_filtered_search() {
1105+
constexpr std::size_t dataset_count = 2048;
1106+
constexpr std::size_t dimensions = 32;
1107+
metric_punned_t metric(dimensions, metric_kind_t::cos_k);
1108+
1109+
std::random_device rd;
1110+
std::mt19937 gen(rd());
1111+
std::uniform_real_distribution<> dis(0.0, 1.0);
1112+
using vector_of_vectors_t = std::vector<std::vector<float>>;
1113+
1114+
vector_of_vectors_t vector_of_vectors(dataset_count);
1115+
for (auto& vector : vector_of_vectors) {
1116+
vector.resize(dimensions);
1117+
std::generate(vector.begin(), vector.end(), [&] { return dis(gen); });
1118+
}
1119+
1120+
index_dense_t index = index_dense_t::make(metric);
1121+
index.reserve(dataset_count);
1122+
for (std::size_t idx = 0; idx < dataset_count; ++idx)
1123+
index.add(idx, vector_of_vectors[idx].data());
1124+
expect_eq(index.size(), dataset_count);
1125+
1126+
{
1127+
auto predicate = [](index_dense_t::key_t key) { return key != 0; };
1128+
auto results = index.filtered_search(vector_of_vectors[0].data(), 10, predicate);
1129+
expect_eq(10, results.size()); // ! Should not contain 0
1130+
for (std::size_t i = 0; i != results.size(); ++i)
1131+
expect(0 != results[i].member.key);
1132+
}
1133+
{
1134+
auto predicate = [](index_dense_t::key_t) { return false; };
1135+
auto results = index.filtered_search(vector_of_vectors[0].data(), 10, predicate);
1136+
expect_eq(0, results.size()); // ! Should not contain 0
1137+
}
1138+
{
1139+
auto predicate = [](index_dense_t::key_t key) { return key == 10; };
1140+
auto results = index.filtered_search(vector_of_vectors[0].data(), 10, predicate);
1141+
expect_eq(1, results.size()); // ! Should not contain 0
1142+
expect_eq(10, results[0].member.key);
1143+
}
1144+
}
1145+
11011146
int main(int, char**) {
11021147
test_uint40();
11031148
test_cosine<float, std::int64_t, uint40_t>(10, 10);
@@ -1174,5 +1219,6 @@ int main(int, char**) {
11741219
test_sets<std::int64_t, slot32_t>(set_size, 20, 30);
11751220
test_strings<std::int64_t, slot32_t>();
11761221

1222+
test_filtered_search();
11771223
return 0;
11781224
}

include/usearch/index.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4178,9 +4178,10 @@ class index_gt {
41784178
// This can substantially grow our priority queue:
41794179
next.insert({-successor_dist, successor_slot});
41804180
if (is_dummy<predicate_at>() ||
4181-
predicate(member_cref_t{node_at_(successor_slot).ckey(), successor_slot}))
4181+
predicate(member_cref_t{node_at_(successor_slot).ckey(), successor_slot})) {
41824182
top.insert({successor_dist, successor_slot}, top_limit);
4183-
radius = top.top().distance;
4183+
radius = top.top().distance;
4184+
}
41844185
}
41854186
}
41864187
}

0 commit comments

Comments
 (0)