Skip to content

Commit 8e9a78c

Browse files
authored
Add JIT-LTO based filter UDF support for CAGRA (#2132)
This PR adds a first version of low-level JIT-LTO filter UDF support for CAGRA search in cuVS C++. Users can provide CUDA device source for a predicate with this ABI: ```cpp __device__ bool cuvs_filter_udf(uint32_t query_id, source_index_t source_id, void* filter_data); ``` The predicate returns `true` to allow a candidate and `false` to reject it. `filter_data` is not result data; it is an optional opaque device-accessible context pointer that lets the predicate read user metadata such as tenant ids, timestamps, language masks, ACL bitmaps, or query-specific thresholds. Results are still written to the normal CAGRA `neighbors` and `distances` outputs. This gives CAGRA a path to support runtime metadata predicates without ahead-of-time template explosion, while keeping existing `none` and `bitset` filters on their static JIT-LTO fragment paths. ## What Changed - Added `cuvs::neighbors::filtering::udf_filter`. - Added `FilterType::UDF`. - Added CAGRA dispatch for `udf_filter`, including `filtering_rate` fallback behavior. - Added a dynamic JIT-LTO sample-filter fragment generator for UDF source. - Reworked CAGRA JIT sample-filter payload plumbing from bitset-only payloads to a generic payload that supports: - no filter - bitset filter - UDF filter context pointer - query id offsets for batched/multi-kernel paths - Updated single-CTA, multi-CTA, and multi-kernel CAGRA JIT paths to call the linked sample filter uniformly. - Added docs for UDF behavior, caveats, and non-goals. - Added a standalone C++ example showing three metadata-style UDFs. ## About `filter_data` `filter_data` is the mechanism for passing runtime metadata into the device predicate. It is optional: simple predicates can ignore it or use `nullptr`. For example, this UDF needs no context: ```cpp __device__ bool cuvs_filter_udf(uint32_t, source_index_t source_id, void*) { return source_id >= 704; } ``` For metadata filters, callers pass a device pointer to a user-defined context struct: ```cpp struct tenant_filter_context { const uint32_t* row_tenant_ids; const uint32_t* query_tenant_ids; }; ``` Then the UDF casts `filter_data` back to that type: ```cpp __device__ bool tenant_filter(uint32_t query_id, source_index_t source_id, void* filter_data) { auto* ctx = static_cast<const tenant_filter_context*>(filter_data); return ctx->row_tenant_ids[source_id] == ctx->query_tenant_ids[query_id]; } ``` The pointer and anything it points to must be device-accessible and remain valid for the duration of the search. Future typed wrappers or expression builders could hide this cast, but the internal ABI can stay stable as: ```cpp bool predicate(uint32_t query_id, source_index_t source_id, void* filter_data); ``` ## Example The standalone example builds one CAGRA index and runs three different UDF predicates over the same metadata context. First, the caller defines a host-side context type. This same layout is repeated in the UDF source string so device code knows how to interpret `filter_data`: ```cpp struct metadata_filter_context { const uint32_t* row_tenant_ids; const int64_t* row_timestamps; const uint32_t* row_language_ids; const uint64_t* row_acl_masks; const uint32_t* query_tenant_ids; const int64_t* query_min_timestamps; const uint64_t* query_allowed_language_masks; const uint64_t* query_permission_masks; }; ``` The UDF source contains the device predicates. Each predicate receives the same `filter_data` pointer and casts it to `metadata_filter_context`: ```cpp std::string source = R"cpp( struct metadata_filter_context { const uint32_t* row_tenant_ids; const int64_t* row_timestamps; const uint32_t* row_language_ids; const uint64_t* row_acl_masks; const uint32_t* query_tenant_ids; const int64_t* query_min_timestamps; const uint64_t* query_allowed_language_masks; const uint64_t* query_permission_masks; }; __device__ bool tenant_filter(uint32_t query_id, source_index_t source_id, void* filter_data) { auto* ctx = static_cast<const metadata_filter_context*>(filter_data); return ctx->row_tenant_ids[source_id] == ctx->query_tenant_ids[query_id]; } __device__ bool timestamp_filter(uint32_t query_id, source_index_t source_id, void* filter_data) { auto* ctx = static_cast<const metadata_filter_context*>(filter_data); return ctx->row_timestamps[source_id] >= ctx->query_min_timestamps[query_id]; } __device__ bool language_acl_filter(uint32_t query_id, source_index_t source_id, void* filter_data) { auto* ctx = static_cast<const metadata_filter_context*>(filter_data); const auto language_bit = uint64_t{1} << ctx->row_language_ids[source_id]; const bool language_ok = (ctx->query_allowed_language_masks[query_id] & language_bit) != 0; const bool acl_ok = (ctx->row_acl_masks[source_id] & ctx->query_permission_masks[query_id]) != 0; return language_ok && acl_ok; } )cpp"; ``` The caller owns the metadata arrays. They must be copied to device memory before search: ```cpp auto row_tenant_ids_device = raft::make_device_vector<uint32_t, int64_t>(res, n_rows); auto row_timestamps_device = raft::make_device_vector<int64_t, int64_t>(res, n_rows); auto row_language_ids_device = raft::make_device_vector<uint32_t, int64_t>(res, n_rows); auto row_acl_masks_device = raft::make_device_vector<uint64_t, int64_t>(res, n_rows); auto query_tenant_ids_device = raft::make_device_vector<uint32_t, int64_t>(res, n_queries); auto query_min_timestamps_device = raft::make_device_vector<int64_t, int64_t>(res, n_queries); auto query_allowed_language_masks_device = raft::make_device_vector<uint64_t, int64_t>(res, n_queries); auto query_permission_masks_device = raft::make_device_vector<uint64_t, int64_t>(res, n_queries); // Copy host metadata into these device arrays before launching search. ``` Then the caller builds one device-resident context struct whose fields point at those device arrays: ```cpp auto context_device = raft::make_device_vector<metadata_filter_context, int64_t>(res, 1); metadata_filter_context host_context{ row_tenant_ids_device.data_handle(), row_timestamps_device.data_handle(), row_language_ids_device.data_handle(), row_acl_masks_device.data_handle(), query_tenant_ids_device.data_handle(), query_min_timestamps_device.data_handle(), query_allowed_language_masks_device.data_handle(), query_permission_masks_device.data_handle() }; // Copy the struct itself to device memory. The pointer to this device struct // is what we pass as filter_data. raft::copy(context_device.data_handle(), &host_context, 1, raft::resource::get_cuda_stream(res)); raft::resource::sync_stream(res); ``` Finally, each `udf_filter` selects which device function to link by name. All three filters reuse the same source and the same `filter_data` context: ```cpp auto tenant_filter = cuvs::neighbors::filtering::udf_filter( source, context_device.data_handle(), // filter_data 0.75f, // estimated rejected fraction "tenant-filter-v1", // cache key "tenant_filter"); // device function name auto timestamp_filter = cuvs::neighbors::filtering::udf_filter( source, context_device.data_handle(), 0.50f, "timestamp-filter-v1", "timestamp_filter"); auto language_acl_filter = cuvs::neighbors::filtering::udf_filter( source, context_device.data_handle(), 0.875f, "language-acl-filter-v1", "language_acl_filter"); ``` Example output: ```text Building CAGRA index tenant_filter first query neighbors: 3364 3488 260 772 3292 3344 2052 660 timestamp_filter first query neighbors: 4070 3942 3364 3629 3797 3488 3292 3495 language_acl_filter first query neighbors: 3488 3344 2480 2520 2304 2176 1360 400 All CAGRA filter UDF examples produced valid filtered neighbors. ``` ## Validation Focused CAGRA UDF test passes across all CAGRA search algorithms: ```text NEIGHBORS_ANN_CAGRA_FILTER_UDF_TEST ``` Coverage includes: - accept-all UDF matches no-filter results - reject-all UDF returns no valid dataset row ids - high `filtering_rate` UDF returns only accepted rows - invalid source throws during compile/search setup - repeated same-cache-key searches match - UDF threshold predicate matches equivalent bitset filter - query-specific tenant metadata works across SINGLE_CTA, MULTI_CTA, and MULTI_KERNEL Broader CAGRA regression set also passed: ```text NEIGHBORS_ANN_CAGRA_(FILTER_UDF_TEST|FLOAT_UINT32|INT8_UINT32|UINT8_UINT32|HALF_UINT32|TEST_BUGS) ``` ## Notes This PR intentionally keeps the v1 API low-level: users provide CUDA source plus an optional `void*` device context. Typed wrappers, expression builders, or write-once host/device ergonomics can be layered on later without changing the internal ABI. Filter UDFs are candidate-validity predicates only. They do not control CAGRA graph traversal, distance computation, PQ/VPQ internals, or result selection. ## Benchmarks Performance benchmarking is still TODO for this PR. The expected regression risk for existing functionality is low because `none` and `bitset` filters still use static JIT-LTO fragments, and the UDF dynamic fragment path is only used for `udf_filter`. However, the CAGRA JIT kernel payload/signature plumbing changed from bitset-only payloads to a generic sample-filter payload, so we should still validate no-filter and bitset search performance against `main`. Proposed planned benchmark coverage: - no-filter CAGRA search vs `main` - bitset-filtered CAGRA search vs `main` - equivalent bitset predicate vs UDF predicate - first-call UDF compile/link latency - warm-cache repeated UDF search latency - representative CAGRA algorithms/configs, especially SINGLE_CTA, MULTI_CTA, and MULTI_KERNEL where applicable This should give us both existing-functionality regression coverage and a baseline for the new UDF path. Authors: - Dante Gama Dessavre (https://github.com/dantegd) Approvers: - Divye Gala (https://github.com/divyegala) URL: #2132
1 parent 78135be commit 8e9a78c

32 files changed

Lines changed: 1374 additions & 184 deletions

cpp/include/cuvs/detail/jit_lto/common_fragments.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ struct tag_i8 {};
1414
struct tag_u8 {};
1515
struct tag_filter_none {};
1616
struct tag_filter_bitset {};
17+
struct tag_filter_udf {};
1718

1819
struct tag_bitset_u32 {};
1920

cpp/include/cuvs/neighbors/cagra.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,10 @@ struct search_params : cuvs::neighbors::search_params {
346346
/**
347347
* A parameter indicating the rate of nodes to be filtered-out, when filtering is used.
348348
* The value must be equal to or greater than 0.0 and less than 1.0. Default value is
349-
* negative, in which case the filtering rate is automatically calculated.
349+
* negative, in which case the filtering rate is automatically calculated when possible.
350+
* For `filtering::udf_filter`, CAGRA uses `udf_filter::filtering_rate` when this value is
351+
* negative. If both values are negative, CAGRA assumes 0.0 because a UDF's selectivity cannot be
352+
* inferred from the source string.
350353
*/
351354
float filtering_rate = -1.0;
352355
};

cpp/include/cuvs/neighbors/common.hpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424

2525
#include <memory>
2626
#include <numeric>
27+
#include <string>
2728
#include <type_traits>
29+
#include <utility>
2830

2931
#ifdef __cpp_lib_bitops
3032
#include <bit>
@@ -495,7 +497,7 @@ namespace filtering {
495497
* @{
496498
*/
497499

498-
enum class FilterType { None, Bitmap, Bitset };
500+
enum class FilterType { None, Bitmap, Bitset, UDF };
499501

500502
struct base_filter {
501503
~base_filter() = default;
@@ -615,6 +617,46 @@ struct bitset_filter : public base_filter {
615617
void to_csr(raft::resources const& handle, csr_matrix_t& csr);
616618
};
617619

620+
/**
621+
* @brief JIT-LTO user-defined filter predicate.
622+
*
623+
* The source must define a device function named by @c function_name with signature:
624+
*
625+
* @code{.cpp}
626+
* __device__ bool cuvs_filter_udf(uint32_t query_id, source_index_t source_id, void* filter_data);
627+
* @endcode
628+
*
629+
* Return @c true to allow a source vector to appear in the results and @c false to reject it.
630+
* @c filter_data is passed through unchanged and must point to device-accessible memory when the
631+
* UDF dereferences it. CAGRA currently provides @c source_index_t as @c uint32_t in the generated
632+
* JIT fragment.
633+
*/
634+
struct udf_filter : public base_filter {
635+
/** CUDA C++ source containing the device predicate. */
636+
std::string source;
637+
/** Opaque device-accessible pointer passed to the predicate. */
638+
void* filter_data = nullptr;
639+
/** Estimated fraction of rows rejected by the predicate, or negative if unknown. */
640+
float filtering_rate = -1.0f;
641+
/** Device function name to call from the generated CAGRA sample filter. */
642+
std::string function_name = "cuvs_filter_udf";
643+
644+
udf_filter() = default;
645+
646+
explicit udf_filter(std::string source,
647+
void* filter_data = nullptr,
648+
float filtering_rate = -1.0f,
649+
std::string function_name = "cuvs_filter_udf")
650+
: source(std::move(source)),
651+
filter_data(filter_data),
652+
filtering_rate(filtering_rate),
653+
function_name(std::move(function_name))
654+
{
655+
}
656+
657+
FilterType get_filter_type() const override { return FilterType::UDF; }
658+
};
659+
618660
/** @} */ // end group neighbors_filtering
619661

620662
/**

cpp/src/neighbors/cagra.cuh

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
#include <rmm/cuda_stream_view.hpp>
2828

29+
#include <algorithm>
30+
2931
namespace cuvs::neighbors::cagra {
3032

3133
// Member function implementations for cagra::index
@@ -380,6 +382,25 @@ void search(raft::resources const& res,
380382
auto sample_filter_copy = sample_filter;
381383
return search_with_filtering<T, IdxT, decltype(sample_filter_copy), OutputIdxT>(
382384
res, params_copy, idx, queries, neighbors, distances, sample_filter_copy);
385+
} catch (const std::bad_cast&) {
386+
}
387+
388+
try {
389+
auto& sample_filter =
390+
dynamic_cast<const cuvs::neighbors::filtering::udf_filter&>(sample_filter_ref);
391+
search_params params_copy = params;
392+
if (params.filtering_rate < 0.0) {
393+
const float min_filtering_rate = 0.0f;
394+
const float max_filtering_rate = 0.999f;
395+
params_copy.filtering_rate =
396+
sample_filter.filtering_rate < 0.0f
397+
? 0.0f
398+
: std::min(std::max(sample_filter.filtering_rate, min_filtering_rate),
399+
max_filtering_rate);
400+
}
401+
auto sample_filter_copy = sample_filter;
402+
return search_with_filtering<T, IdxT, decltype(sample_filter_copy), OutputIdxT>(
403+
res, params_copy, idx, queries, neighbors, distances, sample_filter_copy);
383404
} catch (const std::bad_cast&) {
384405
RAFT_FAIL("Unsupported sample filter type");
385406
}
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights
3+
* reserved. SPDX-License-Identifier: Apache-2.0
4+
*/
5+
#pragma once
6+
7+
#include "../../sample_filter.cuh" // public filter types
8+
#include "../sample_filter_data.cuh"
9+
#include "jit_lto_kernels/cagra_filter_payload.cuh"
10+
11+
#include <raft/core/error.hpp>
12+
13+
#include <cuda_runtime_api.h>
14+
15+
#include <cstddef>
16+
#include <cstdint>
17+
#include <cstring>
18+
#include <list>
19+
#include <mutex>
20+
#include <type_traits>
21+
#include <unordered_map>
22+
23+
namespace cuvs::neighbors::cagra::detail {
24+
25+
template <typename PayloadT>
26+
std::uint64_t cagra_payload_hash(PayloadT const& payload)
27+
{
28+
static_assert(std::is_trivially_copyable_v<PayloadT>);
29+
constexpr std::uint64_t kOffset = 1469598103934665603ull;
30+
constexpr std::uint64_t kPrime = 1099511628211ull;
31+
auto const* bytes = reinterpret_cast<unsigned char const*>(&payload);
32+
std::uint64_t hash = kOffset;
33+
for (std::size_t i = 0; i < sizeof(PayloadT); ++i) {
34+
hash ^= bytes[i];
35+
hash *= kPrime;
36+
}
37+
return hash;
38+
}
39+
40+
template <typename PayloadT>
41+
struct cagra_device_payload_owner {
42+
struct state {
43+
PayloadT host_payload{};
44+
PayloadT* device_payload{nullptr};
45+
cudaStream_t stream{};
46+
cudaEvent_t ready_event{};
47+
int device{-1};
48+
std::mutex mutex;
49+
50+
explicit state(PayloadT payload) : host_payload(payload) {}
51+
52+
~state() noexcept
53+
{
54+
if (device_payload != nullptr) {
55+
RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(device_payload, stream));
56+
}
57+
if (ready_event != nullptr) { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(ready_event)); }
58+
}
59+
60+
PayloadT* dev_ptr(cudaStream_t cuda_stream)
61+
{
62+
std::lock_guard<std::mutex> lock(mutex);
63+
if (device_payload == nullptr) {
64+
RAFT_CUDA_TRY(cudaGetDevice(&device));
65+
RAFT_CUDA_TRY(cudaMallocAsync(
66+
reinterpret_cast<void**>(&device_payload), sizeof(PayloadT), cuda_stream));
67+
RAFT_CUDA_TRY(cudaMemcpyAsync(
68+
device_payload, &host_payload, sizeof(PayloadT), cudaMemcpyHostToDevice, cuda_stream));
69+
RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming));
70+
RAFT_CUDA_TRY(cudaEventRecord(ready_event, cuda_stream));
71+
stream = cuda_stream;
72+
} else {
73+
RAFT_CUDA_TRY(cudaStreamWaitEvent(cuda_stream, ready_event, 0));
74+
}
75+
return device_payload;
76+
}
77+
};
78+
79+
// PayloadT is copied to device by value. Pointer fields inside PayloadT are shallow-copied and
80+
// must already point to device-addressable memory that remains valid while the cached payload is
81+
// usable.
82+
struct cache_key {
83+
std::uint64_t payload_hash{};
84+
int device{};
85+
86+
bool operator==(cache_key const& other) const
87+
{
88+
return payload_hash == other.payload_hash && device == other.device;
89+
}
90+
};
91+
92+
struct cache_key_hash {
93+
std::size_t operator()(cache_key const& key) const
94+
{
95+
auto seed = static_cast<std::size_t>(key.payload_hash);
96+
seed ^= static_cast<std::size_t>(key.device) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
97+
return seed;
98+
}
99+
};
100+
101+
cagra_device_payload_owner() = default;
102+
103+
void* dev_ptr(PayloadT payload, cudaStream_t stream) const
104+
{
105+
int device{};
106+
RAFT_CUDA_TRY(cudaGetDevice(&device));
107+
108+
// Keep cached payload copies for process lifetime to avoid per-search allocation/copy churn.
109+
// Cross-stream reuse is ordered by each state's ready_event before kernels consume the pointer.
110+
const auto key = cache_key{cagra_payload_hash(payload), device};
111+
state* selected_state{};
112+
{
113+
std::lock_guard<std::mutex> lock(cache_mutex_);
114+
auto& entries = cache_[key];
115+
for (auto& cached : entries) {
116+
if (std::memcmp(&cached.host_payload, &payload, sizeof(PayloadT)) == 0) {
117+
selected_state = &cached;
118+
break;
119+
}
120+
}
121+
if (selected_state == nullptr) {
122+
entries.emplace_back(payload);
123+
selected_state = &entries.back();
124+
}
125+
}
126+
127+
return selected_state->dev_ptr(stream);
128+
}
129+
130+
private:
131+
mutable std::mutex cache_mutex_;
132+
mutable std::unordered_map<cache_key, std::list<state>, cache_key_hash> cache_;
133+
};
134+
135+
template <typename T>
136+
struct is_bitset_filter : std::false_type {};
137+
138+
template <typename bitset_t, typename index_t>
139+
struct is_bitset_filter<::cuvs::neighbors::filtering::bitset_filter<bitset_t, index_t>>
140+
: std::true_type {};
141+
142+
template <typename T>
143+
struct is_udf_filter : std::false_type {};
144+
145+
template <>
146+
struct is_udf_filter<::cuvs::neighbors::filtering::udf_filter> : std::true_type {};
147+
148+
template <typename SourceIndexT, typename FilterT>
149+
::cuvs::neighbors::detail::bitset_filter_data_t<SourceIndexT> make_cagra_bitset_filter_storage(
150+
const FilterT& filter)
151+
{
152+
const auto bitset_view = filter.view();
153+
return ::cuvs::neighbors::detail::bitset_filter_data_t<SourceIndexT>{
154+
const_cast<std::uint32_t*>(bitset_view.data()),
155+
static_cast<SourceIndexT>(bitset_view.size()),
156+
static_cast<SourceIndexT>(bitset_view.get_original_nbits())};
157+
}
158+
159+
template <typename PayloadT>
160+
void* get_cagra_device_payload(PayloadT payload, cudaStream_t stream)
161+
{
162+
static cagra_device_payload_owner<PayloadT> owner;
163+
return owner.dev_ptr(payload, stream);
164+
}
165+
166+
template <typename SourceIndexT, typename FilterT>
167+
void* make_cagra_bitset_filter_payload(const FilterT& filter, cudaStream_t stream)
168+
{
169+
return get_cagra_device_payload(make_cagra_bitset_filter_storage<SourceIndexT>(filter), stream);
170+
}
171+
172+
template <typename SourceIndexT, typename FilterT>
173+
void fill_cagra_sample_filter(cagra_sample_filter<SourceIndexT>& out,
174+
const FilterT& filter,
175+
cudaStream_t stream)
176+
{
177+
using DecayedFilter = std::decay_t<FilterT>;
178+
if constexpr (is_bitset_filter<DecayedFilter>::value) {
179+
out.filter_data = make_cagra_bitset_filter_payload<SourceIndexT>(filter, stream);
180+
} else if constexpr (is_udf_filter<DecayedFilter>::value) {
181+
out.filter_data = filter.filter_data;
182+
}
183+
}
184+
185+
template <typename SourceIndexT, typename FilterT>
186+
std::uint64_t cagra_filter_payload_hash(const FilterT& filter)
187+
{
188+
using DecayedFilter = std::decay_t<FilterT>;
189+
if constexpr (is_bitset_filter<DecayedFilter>::value) {
190+
return cagra_payload_hash(make_cagra_bitset_filter_storage<SourceIndexT>(filter));
191+
} else if constexpr (requires { filter.filter; }) {
192+
return cagra_filter_payload_hash<SourceIndexT>(filter.filter);
193+
} else {
194+
return 0;
195+
}
196+
}
197+
198+
template <typename FilterT>
199+
void* cagra_filter_data_ptr(const FilterT& filter)
200+
{
201+
using DecayedFilter = std::decay_t<FilterT>;
202+
if constexpr (is_udf_filter<DecayedFilter>::value) {
203+
return filter.filter_data;
204+
} else if constexpr (requires { filter.filter; }) {
205+
return cagra_filter_data_ptr(filter.filter);
206+
} else {
207+
return nullptr;
208+
}
209+
}
210+
211+
template <typename SampleFilterT>
212+
std::uint32_t cagra_filter_query_id_offset(const SampleFilterT& sample_filter)
213+
{
214+
if constexpr (requires {
215+
sample_filter.filter;
216+
sample_filter.offset;
217+
}) {
218+
return sample_filter.offset;
219+
} else {
220+
return 0;
221+
}
222+
}
223+
224+
/// Host: fill @ref cagra_sample_filter from a CAGRA filter object.
225+
template <typename SourceIndexT, typename SampleFilterT>
226+
cagra_sample_filter<SourceIndexT> extract_cagra_sample_filter(const SampleFilterT& sample_filter,
227+
cudaStream_t stream)
228+
{
229+
cagra_sample_filter<SourceIndexT> out;
230+
if constexpr (requires {
231+
sample_filter.filter;
232+
sample_filter.offset;
233+
}) {
234+
out.query_id_offset = sample_filter.offset;
235+
fill_cagra_sample_filter(out, sample_filter.filter, stream);
236+
} else {
237+
fill_cagra_sample_filter(out, sample_filter, stream);
238+
}
239+
return out;
240+
}
241+
242+
/// Host: find UDF compile/link metadata only. Query offsets stay in the runtime payload produced
243+
/// by @ref extract_cagra_sample_filter and are applied before calling the linked sample_filter.
244+
template <typename SampleFilterT>
245+
const ::cuvs::neighbors::filtering::udf_filter* get_cagra_udf_filter(
246+
const SampleFilterT& sample_filter)
247+
{
248+
using DecayedFilter = std::decay_t<SampleFilterT>;
249+
if constexpr (is_udf_filter<DecayedFilter>::value) {
250+
return &sample_filter;
251+
} else if constexpr (requires { sample_filter.filter; }) {
252+
return get_cagra_udf_filter(sample_filter.filter);
253+
} else {
254+
return nullptr;
255+
}
256+
}
257+
258+
} // namespace cuvs::neighbors::cagra::detail

0 commit comments

Comments
 (0)