Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -566,9 +566,7 @@ if(NOT BUILD_CPU_ONLY)
src/neighbors/iface/iface_pq_uint8_t_int64_t.cu
src/neighbors/detail/cagra/topk_for_cagra/topk.cu
src/neighbors/dynamic_batching.cu
src/neighbors/cagra_index_wrapper.cu
src/neighbors/composite/index.cu
src/neighbors/composite/merge.cpp
$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:src/neighbors/cagra.cpp>
$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:src/neighbors/hnsw.cpp>
src/neighbors/ivf_common.cu
Expand Down
30 changes: 9 additions & 21 deletions cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/neighbors/common.hpp>
#include <cuvs/neighbors/composite/merge.hpp>
#include <cuvs/neighbors/composite/index.hpp>
#include <cuvs/neighbors/dynamic_batching.hpp>
#include <cuvs/neighbors/ivf_pq.hpp>
#include <cuvs/neighbors/nn_descent.hpp>
Expand Down Expand Up @@ -453,33 +453,21 @@ void cuvs_cagra<T, IdxT>::search_base(
} else {
if (index_params_.merge_type == CagraMergeType::kLogical) {
// TODO: index merge must happen outside of search, otherwise what are we benchmarking?
cuvs::neighbors::cagra::merge_params merge_params{cuvs::neighbors::cagra::index_params{}};
merge_params.merge_strategy = cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_LOGICAL;

// Create wrapped indices for composite merge
std::vector<std::shared_ptr<cuvs::neighbors::IndexWrapper<T, IdxT, algo_base::index_type>>>
wrapped_indices;
wrapped_indices.reserve(sub_indices_.size());
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*> cagra_indices;
cagra_indices.reserve(sub_indices_.size());
for (auto& ptr : sub_indices_) {
auto index_wrapper =
cuvs::neighbors::cagra::make_index_wrapper<T, IdxT, algo_base::index_type>(ptr.get());
wrapped_indices.push_back(index_wrapper);
cagra_indices.push_back(ptr.get());
}

raft::resources composite_handle(handle_);
size_t n_streams = wrapped_indices.size();
size_t n_streams = cagra_indices.size();
raft::resource::set_cuda_stream_pool(composite_handle,
std::make_shared<rmm::cuda_stream_pool>(n_streams));

auto merged_index =
cuvs::neighbors::composite::merge(composite_handle, merge_params, wrapped_indices);
cuvs::neighbors::filtering::none_sample_filter empty_filter;
merged_index->search(composite_handle,
search_params_,
queries_view,
neighbors_view,
distances_view,
empty_filter);
cuvs::neighbors::composite::CompositeIndex<T, IdxT, algo_base::index_type> composite(
cagra_indices);
composite.search(
composite_handle, search_params_, queries_view, neighbors_view, distances_view);
}
}
}
Expand Down
2 changes: 0 additions & 2 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3173,5 +3173,3 @@ void optimize(raft::resources const& handle,
raft::host_matrix_view<uint32_t, int64_t, raft::row_major> new_graph);

} // namespace cuvs::neighbors::cagra::helpers

#include <cuvs/neighbors/cagra_index_wrapper.hpp>
163 changes: 0 additions & 163 deletions cpp/include/cuvs/neighbors/cagra_index_wrapper.hpp

This file was deleted.

7 changes: 0 additions & 7 deletions cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,6 @@ enum class MergeStrategy {
MERGE_STRATEGY_LOGICAL = 1
};

/** Base merge parameters with polymorphic interface. */
struct merge_params {
virtual ~merge_params() = default;

virtual MergeStrategy strategy() const = 0;
};

/** @} */ // end group neighbors_index

/** Two-dimensional dataset; maybe owning, maybe compressed, maybe strided. */
Expand Down
80 changes: 41 additions & 39 deletions cpp/include/cuvs/neighbors/composite/index.hpp
Original file line number Diff line number Diff line change
@@ -1,76 +1,78 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/index_base.hpp>
#include <cuvs/neighbors/cagra.hpp>
#include <raft/core/device_mdspan.hpp>

#include <memory>
#include <vector>

namespace cuvs::neighbors::composite {

/**
* @brief Composite index made of other IndexBase implementations.
* @brief Composite index that searches multiple CAGRA sub-indices and merges results.
*
* When the composite index contains multiple sub-indices, the user can set a
* stream pool in the input raft::resource to enable parallel search across
* sub-indices for improved performance.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
*
* auto index0 = cagra::build(res, params, dataset0);
* auto index1 = cagra::build(res, params, dataset1);
*
* composite::CompositeIndex<float, uint32_t> composite({&index0, &index1});
*
* // optional: create a stream pool to enable parallel search across sub-indices
* size_t n_streams = 2;
* raft::resource::set_cuda_stream_pool(handle,
* std::make_shared<rmm::cuda_stream_pool>(n_streams));
*
* composite.search(handle, search_params, queries, neighbors, distances);
* @endcode
*/
template <typename T, typename IdxT, typename OutputIdxT = IdxT>
class CompositeIndex : public IndexBase<T, IdxT, OutputIdxT> {
class CompositeIndex {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use snake case (for eg. composite_index) for all other class names. Let's continue to use the same style standards here for consistency.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 2052e39

public:
using value_type = typename IndexBase<T, IdxT, OutputIdxT>::value_type;
using index_type = typename IndexBase<T, IdxT, OutputIdxT>::index_type;
using out_index_type = typename IndexBase<T, IdxT, OutputIdxT>::out_index_type;
using matrix_index_type = typename IndexBase<T, IdxT, OutputIdxT>::matrix_index_type;
using value_type = T;
using index_type = IdxT;
using out_index_type = OutputIdxT;
using matrix_index_type = int64_t;

using index_ptr = std::shared_ptr<IndexBase<value_type, index_type, out_index_type>>;

explicit CompositeIndex(std::vector<index_ptr> children) : children_(std::move(children)) {}
explicit CompositeIndex(std::vector<cuvs::neighbors::cagra::index<T, IdxT>*> children)
: children_(std::move(children))
{
}

/**
* @brief Search the composite index for the k nearest neighbors.
*
* When the composite index contains multiple sub-indices, the user can set a
* stream pool in the input raft::resource to enable parallel search across
* sub-indices for improved performance.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create a composite index with multiple sub-indices
* std::vector<CompositeIndex<T, IdxT>::index_ptr> sub_indices;
* // ... populate sub_indices ...
* auto composite_index = CompositeIndex<T, IdxT>(std::move(sub_indices));
*
* // optional: create a stream pool to enable parallel search across sub-indices
* // recommended stream count: min(number_of_sub_indices, 8)
* size_t n_streams = std::min(sub_indices.size(), size_t(8));
* raft::resource::set_cuda_stream_pool(handle,
* std::make_shared<rmm::cuda_stream_pool>(n_streams));
*
* // perform search with parallel sub-index execution
* composite_index.search(handle, search_params, queries, neighbors, distances);
* @endcode
* Searches each sub-index independently (optionally in parallel via stream pool),
* then selects the top-k results across all sub-indices.
*
* @param[in] handle raft resource handle
* @param[in] params search parameters
* @param[in] params CAGRA search parameters
* @param[in] queries device matrix view of query vectors [n_queries, dim]
* @param[out] neighbors device matrix view for neighbor indices [n_queries, k]
* @param[out] distances device matrix view for distances [n_queries, k]
* @param[in] filter optional filter for search results
*/
void search(
const raft::resources& handle,
const cuvs::neighbors::search_params& params,
const cuvs::neighbors::cagra::search_params& params,
raft::device_matrix_view<const value_type, matrix_index_type, raft::row_major> queries,
raft::device_matrix_view<out_index_type, matrix_index_type, raft::row_major> neighbors,
raft::device_matrix_view<float, matrix_index_type, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& filter =
cuvs::neighbors::filtering::none_sample_filter{}) const override;
cuvs::neighbors::filtering::none_sample_filter{}) const;

index_type size() const noexcept override
index_type size() const noexcept
{
index_type total = 0;
for (const auto& c : children_) {
Expand All @@ -79,14 +81,14 @@ class CompositeIndex : public IndexBase<T, IdxT, OutputIdxT> {
return total;
}

cuvs::distance::DistanceType metric() const noexcept override
cuvs::distance::DistanceType metric() const noexcept
{
return children_.empty() ? cuvs::distance::DistanceType::L2Expanded
: children_.front()->metric();
}

private:
std::vector<index_ptr> children_;
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*> children_;
};

} // namespace cuvs::neighbors::composite
Loading