Skip to content

Commit 21e5f20

Browse files
committed
bloom filter
1 parent 8e9a78c commit 21e5f20

14 files changed

Lines changed: 272 additions & 4 deletions

cpp/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,9 @@ if(NOT BUILD_CPU_ONLY)
367367
"$<$<COMPILE_LANGUAGE:CUDA>:${CUVS_CUDA_FLAGS}>"
368368
)
369369
target_compile_features(jit_lto_kernel_usage_requirements INTERFACE cuda_std_20)
370-
target_link_libraries(jit_lto_kernel_usage_requirements INTERFACE rmm::rmm raft::raft CCCL::CCCL)
370+
target_link_libraries(
371+
jit_lto_kernel_usage_requirements INTERFACE rmm::rmm raft::raft CCCL::CCCL cuco::cuco
372+
)
371373

372374
block(PROPAGATE jit_lto_files)
373375
set(jit_lto_files)

cpp/include/cuvs/detail/jit_lto/common_fragments.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ struct tag_i8 {};
1414
struct tag_u8 {};
1515
struct tag_filter_none {};
1616
struct tag_filter_bitset {};
17+
struct tag_filter_bloom_filter {};
1718
struct tag_filter_udf {};
1819

1920
struct tag_bitset_u32 {};

cpp/include/cuvs/neighbors/common.hpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ namespace filtering {
497497
* @{
498498
*/
499499

500-
enum class FilterType { None, Bitmap, Bitset, UDF };
500+
enum class FilterType { None, Bitmap, Bitset, Bloom, UDF };
501501

502502
struct base_filter {
503503
~base_filter() = default;
@@ -617,6 +617,32 @@ struct bitset_filter : public base_filter {
617617
void to_csr(raft::resources const& handle, csr_matrix_t& csr);
618618
};
619619

620+
/**
621+
* @brief Filter CAGRA candidates with a global @c cuco bloom filter over the index.
622+
*
623+
* Build the filter once on the host with bulk @c add() over the allowed dataset row ids, obtain a
624+
* @c ref() from the owning @c cuco::bloom_filter, copy that ref to device memory, and pass the
625+
* device pointer as @c filter_data. The linked JIT-LTO fragment probes the same filter for every
626+
* query and candidate, similar to @ref bitset_filter but with probabilistic membership tests.
627+
*
628+
* Bloom filters have no false negatives: if a row was inserted, @c contains returns @c true. False
629+
* positives are possible, so highly selective predicates may still need a bitset or UDF for exact
630+
* filtering.
631+
*/
632+
struct bloom_filter : public base_filter {
633+
void* filter_data{nullptr};
634+
float filtering_rate{-1.0f};
635+
636+
bloom_filter() = default;
637+
638+
explicit bloom_filter(void* filter_data, float filtering_rate = -1.0f)
639+
: filter_data(filter_data), filtering_rate(filtering_rate)
640+
{
641+
}
642+
643+
FilterType get_filter_type() const override { return FilterType::Bloom; }
644+
};
645+
620646
/**
621647
* @brief JIT-LTO user-defined filter predicate.
622648
*

cpp/src/neighbors/cagra.cuh

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,25 @@ void search(raft::resources const& res,
385385
} catch (const std::bad_cast&) {
386386
}
387387

388+
try {
389+
auto& sample_filter =
390+
dynamic_cast<const cuvs::neighbors::filtering::bloom_filter&>(sample_filter_ref);
391+
search_params params_copy = params;
392+
if (params.filtering_rate < 0.0) {
393+
const float min_filtering_rate = 0.0f;
394+
const float max_filtering_rate = 0.999f;
395+
params_copy.filtering_rate =
396+
sample_filter.filtering_rate < 0.0f
397+
? 0.0f
398+
: std::min(std::max(sample_filter.filtering_rate, min_filtering_rate),
399+
max_filtering_rate);
400+
}
401+
auto sample_filter_copy = sample_filter;
402+
return search_with_filtering<T, IdxT, decltype(sample_filter_copy), OutputIdxT>(
403+
res, params_copy, idx, queries, neighbors, distances, sample_filter_copy);
404+
} catch (const std::bad_cast&) {
405+
}
406+
388407
try {
389408
auto& sample_filter =
390409
dynamic_cast<const cuvs::neighbors::filtering::udf_filter&>(sample_filter_ref);

cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,12 @@ template <typename bitset_t, typename index_t>
139139
struct is_bitset_filter<::cuvs::neighbors::filtering::bitset_filter<bitset_t, index_t>>
140140
: std::true_type {};
141141

142+
template <typename T>
143+
struct is_bloom_filter : std::false_type {};
144+
145+
template <>
146+
struct is_bloom_filter<::cuvs::neighbors::filtering::bloom_filter> : std::true_type {};
147+
142148
template <typename T>
143149
struct is_udf_filter : std::false_type {};
144150

@@ -177,6 +183,8 @@ void fill_cagra_sample_filter(cagra_sample_filter<SourceIndexT>& out,
177183
using DecayedFilter = std::decay_t<FilterT>;
178184
if constexpr (is_bitset_filter<DecayedFilter>::value) {
179185
out.filter_data = make_cagra_bitset_filter_payload<SourceIndexT>(filter, stream);
186+
} else if constexpr (is_bloom_filter<DecayedFilter>::value) {
187+
out.filter_data = filter.filter_data;
180188
} else if constexpr (is_udf_filter<DecayedFilter>::value) {
181189
out.filter_data = filter.filter_data;
182190
}
@@ -199,7 +207,7 @@ template <typename FilterT>
199207
void* cagra_filter_data_ptr(const FilterT& filter)
200208
{
201209
using DecayedFilter = std::decay_t<FilterT>;
202-
if constexpr (is_udf_filter<DecayedFilter>::value) {
210+
if constexpr (is_bloom_filter<DecayedFilter>::value || is_udf_filter<DecayedFilter>::value) {
203211
return filter.filter_data;
204212
} else if constexpr (requires { filter.filter; }) {
205213
return cagra_filter_data_ptr(filter.filter);

cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_impl.cuh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
#include "../../sample_filter_data.cuh"
1111

12+
#include <cuco/bloom_filter_ref.cuh>
13+
1214
#include <raft/core/bitset.cuh>
1315

1416
#include <cstdint>
@@ -38,4 +40,15 @@ __device__ bool sample_filter_bitset_impl(uint32_t /*query_id*/,
3840
return view.test(node_id);
3941
}
4042

43+
template <typename SourceIndexT, typename Key = std::uint32_t>
44+
__device__ bool sample_filter_bloom_filter_impl(uint32_t /*query_id*/,
45+
SourceIndexT node_id,
46+
void* filter_data)
47+
{
48+
if (filter_data == nullptr) { return true; }
49+
50+
auto* data = static_cast<bloom_filter_data_t<Key>*>(filter_data);
51+
return data->filter.contains(static_cast<Key>(node_id));
52+
}
53+
4154
} // namespace cuvs::neighbors::detail

cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_matrix.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"filter_name": ["none", "bitset"],
2+
"filter_name": ["none", "bitset", "bloom_filter"],
33
"_bitset": [
44
{
55
"bitset_type": "uint32_t",

cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cu.in

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ namespace {
1111
using data_t = @data_type@;
1212
using bitset_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<
1313
cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t>>;
14+
using bloom_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<
15+
cuvs::neighbors::filtering::bloom_filter>;
1416
using udf_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<
1517
cuvs::neighbors::filtering::udf_filter>;
1618

@@ -22,6 +24,7 @@ instantiate_kernel_selection(data_t,
2224
float,
2325
cuvs::neighbors::filtering::none_sample_filter);
2426
instantiate_kernel_selection(data_t, uint32_t, float, bitset_filter_t);
27+
instantiate_kernel_selection(data_t, uint32_t, float, bloom_filter_t);
2528
instantiate_kernel_selection(data_t, uint32_t, float, udf_filter_t);
2629

2730
} // namespace cuvs::neighbors::cagra::detail::multi_cta_search

cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ namespace {
1111
using data_t = @data_type@;
1212
using bitset_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<
1313
cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t>>;
14+
using bloom_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<
15+
cuvs::neighbors::filtering::bloom_filter>;
1416
using udf_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<
1517
cuvs::neighbors::filtering::udf_filter>;
1618

@@ -22,6 +24,7 @@ instantiate_kernel_selection(data_t,
2224
float,
2325
cuvs::neighbors::filtering::none_sample_filter);
2426
instantiate_kernel_selection(data_t, uint32_t, float, bitset_filter_t);
27+
instantiate_kernel_selection(data_t, uint32_t, float, bloom_filter_t);
2528
instantiate_kernel_selection(data_t, uint32_t, float, udf_filter_t);
2629

2730
} // namespace cuvs::neighbors::cagra::detail::single_cta_search

cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ std::uint64_t cagra_sample_filter_type_id(const SampleFilterT& sample_filter)
8080
{
8181
using DecayedFilter = std::decay_t<SampleFilterT>;
8282
if constexpr (is_udf_filter<DecayedFilter>::value) {
83+
return 3;
84+
} else if constexpr (is_bloom_filter<DecayedFilter>::value) {
8385
return 2;
8486
} else if constexpr (is_bitset_filter<DecayedFilter>::value) {
8587
return 1;

0 commit comments

Comments
 (0)