Skip to content

IVF-PQ: low-precision coarse search #715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: branch-25.06
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
8c3b0aa
IVF-PQ: low-precision coarse search
achirkin Feb 21, 2025
7e239a4
Merge branch 'branch-25.04' into fea-ivf-pq-low-precision-coarse-search
achirkin Feb 25, 2025
bcc1aae
Merge branch 'branch-25.04' into fea-ivf-pq-low-precision-coarse-search
achirkin Feb 28, 2025
6271db3
Merge branch 'branch-25.04' into fea-ivf-pq-low-precision-coarse-search
achirkin Feb 28, 2025
f4c3d70
Merge branch 'branch-25.04' into fea-ivf-pq-low-precision-coarse-search
achirkin Mar 5, 2025
5a7f497
Merge branch 'branch-25.04' into fea-ivf-pq-low-precision-coarse-search
achirkin Mar 21, 2025
020e6c3
Add a few test cases to cover new parameters
achirkin Mar 21, 2025
27abf81
Relax the tests for 8-bit coarse search
achirkin Mar 21, 2025
4cb5836
Relax the tests for 8-bit coarse search
achirkin Mar 21, 2025
9ece563
Merge branch 'branch-25.04' into fea-ivf-pq-low-precision-coarse-search
achirkin Mar 25, 2025
0cf7630
Merge branch 'branch-25.04' into fea-ivf-pq-low-precision-coarse-search
achirkin Mar 31, 2025
d8cef4a
Add notes to int8_t implementation
achirkin Mar 31, 2025
47be16a
Add the new search parameters to the C API
achirkin Mar 31, 2025
3b76ae7
Merge branch 'branch-25.04' into fea-ivf-pq-low-precision-coarse-search
achirkin Apr 1, 2025
f2da0b5
Merge branch 'branch-25.04' into fea-ivf-pq-low-precision-coarse-search
achirkin Apr 3, 2025
94b5f4c
Remove redundant 'coarse' from the error message.
achirkin Apr 3, 2025
a2f41ef
Merge branch 'branch-25.06' into fea-ivf-pq-low-precision-coarse-search
achirkin Apr 16, 2025
9ef4657
Merge branch 'branch-25.06' into fea-ivf-pq-low-precision-coarse-search
achirkin Apr 16, 2025
b46f36e
Merge branch 'branch-25.06' into fea-ivf-pq-low-precision-coarse-search
achirkin Apr 17, 2025
3630820
Merge branch 'branch-25.06' into fea-ivf-pq-low-precision-coarse-search
achirkin Apr 24, 2025
1997932
Merge branch 'branch-25.06' into fea-ivf-pq-low-precision-coarse-search
achirkin Apr 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ if(BUILD_SHARED_LIBS)
src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cu
src/neighbors/ivf_pq_index.cpp
src/neighbors/ivf_pq_index.cu
src/neighbors/ivf_pq/ivf_pq_build_common.cu
src/neighbors/ivf_pq/ivf_pq_serialize.cu
src/neighbors/ivf_pq/ivf_pq_deserialize.cu
Expand Down
19 changes: 19 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,25 @@ void parse_search_param(const nlohmann::json& conf,
// set half as default
param.pq_param.lut_dtype = CUDA_R_16F;
}

if (conf.contains("coarse_search_dtype")) {
std::string type = conf.at("coarse_search_dtype");
if (type == "float") {
param.pq_param.coarse_search_dtype = CUDA_R_32F;
} else if (type == "half") {
param.pq_param.coarse_search_dtype = CUDA_R_16F;
} else if (type == "int8") {
param.pq_param.coarse_search_dtype = CUDA_R_8I;
} else {
throw std::runtime_error("coarse_search_dtype: '" + type +
"', should be either 'float', 'half' or 'int8'");
}
}

if (conf.contains("max_internal_batch_size")) {
param.pq_param.max_internal_batch_size = conf.at("max_internal_batch_size");
}

if (conf.contains("refine_ratio")) {
param.refine_ratio = conf.at("refine_ratio");
if (param.refine_ratio < 1.0f) { throw std::runtime_error("refine_ratio should be >= 1.0"); }
Expand Down
15 changes: 15 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_pq.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,21 @@ struct cuvsIvfPqSearchParams {
* performance slightly.
*/
cudaDataType_t internal_distance_dtype;
/**
* The data type to use as the GEMM element type when searching the clusters to probe.
*
* Possible values: [CUDA_R_8I, CUDA_R_16F, CUDA_R_32F].
*
* - Legacy default: CUDA_R_32F (float)
* - Recommended for performance: CUDA_R_16F (half)
* - Experimental/low-precision: CUDA_R_8I (int8_t)
* (WARNING: int8_t variant degrades recall unless data is normalized and low-dimensional)
*/
cudaDataType_t coarse_search_dtype;
/**
* Set the internal batch size to improve GPU utilization at the cost of larger memory footprint.
*/
uint32_t max_internal_batch_size;
/**
* Preferred fraction of SM's unified memory / L1 cache to be used as shared memory.
*
Expand Down
39 changes: 39 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
#include <raft/core/resources.hpp>
#include <raft/util/integer_utils.hpp>

#include <optional>
#include <tuple>
#include <variant>
#include <vector>

namespace cuvs::neighbors::ivf_pq {

/**
Expand Down Expand Up @@ -181,6 +186,22 @@ struct search_params : cuvs::neighbors::search_params {
* performance if tweaked incorrectly.
*/
double preferred_shmem_carveout = 1.0;
/**
* [Experimental] The data type to use as the GEMM element type when searching the clusters to
* probe.
*
* Possible values: [CUDA_R_8I, CUDA_R_16F, CUDA_R_32F].
*
* - Legacy default: CUDA_R_32F (float)
* - Recommended for performance: CUDA_R_16F (half)
* - Experimental/low-precision: CUDA_R_8I (int8_t)
* (WARNING: int8_t variant degrades recall unless data is normalized and low-dimensional)
*/
cudaDataType_t coarse_search_dtype = CUDA_R_32F;
/**
* Set the internal batch size to improve GPU utilization at the cost of larger memory footprint.
*/
uint32_t max_internal_batch_size = 4096;
};
/**
* @}
Expand Down Expand Up @@ -427,6 +448,11 @@ struct index : cuvs::neighbors::index {
raft::device_matrix_view<float, uint32_t, raft::row_major> rotation_matrix() noexcept;
raft::device_matrix_view<const float, uint32_t, raft::row_major> rotation_matrix() const noexcept;

raft::device_matrix_view<const int8_t, uint32_t, raft::row_major> rotation_matrix_int8(
const raft::resources& res) const;
raft::device_matrix_view<const half, uint32_t, raft::row_major> rotation_matrix_half(
const raft::resources& res) const;

/**
* Accumulated list sizes, sorted in descending order [n_lists + 1].
* The last value contains the total length of the index.
Expand All @@ -447,6 +473,11 @@ struct index : cuvs::neighbors::index {
raft::device_matrix_view<float, uint32_t, raft::row_major> centers() noexcept;
raft::device_matrix_view<const float, uint32_t, raft::row_major> centers() const noexcept;

raft::device_matrix_view<const int8_t, uint32_t, raft::row_major> centers_int8(
const raft::resources& res) const;
raft::device_matrix_view<const half, uint32_t, raft::row_major> centers_half(
const raft::resources& res) const;

/** Cluster centers corresponding to the lists in the rotated space [n_lists, rot_dim] */
raft::device_matrix_view<float, uint32_t, raft::row_major> centers_rot() noexcept;
raft::device_matrix_view<const float, uint32_t, raft::row_major> centers_rot() const noexcept;
Expand Down Expand Up @@ -485,6 +516,14 @@ struct index : cuvs::neighbors::index {
raft::device_matrix<float, uint32_t, raft::row_major> centers_rot_;
raft::device_matrix<float, uint32_t, raft::row_major> rotation_matrix_;

// Lazy-initialized low-precision variants of index members - for low-precision coarse search.
// These are never serialized and not touched during build/extend.
mutable std::optional<raft::device_matrix<int8_t, uint32_t, raft::row_major>> centers_int8_;
mutable std::optional<raft::device_matrix<half, uint32_t, raft::row_major>> centers_half_;
mutable std::optional<raft::device_matrix<int8_t, uint32_t, raft::row_major>>
rotation_matrix_int8_;
mutable std::optional<raft::device_matrix<half, uint32_t, raft::row_major>> rotation_matrix_half_;

// Computed members for accelerating search.
raft::device_vector<uint8_t*, uint32_t, raft::row_major> data_ptrs_;
raft::device_vector<IdxT*, uint32_t, raft::row_major> inds_ptrs_;
Expand Down
16 changes: 16 additions & 0 deletions cpp/src/neighbors/detail/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,22 @@ struct mapping {
/** @} */
};

template <>
template <>
HDI constexpr auto mapping<int8_t>::operator()(const uint8_t& x) const -> int8_t
{
  // Avoid overflow when converting uint8_t -> int8_t: the uint8_t range [0, 255]
  // exceeds int8_t's maximum of 127, so halve the value (dropping the lowest bit)
  // to map [0, 255] onto [0, 127] while preserving ordering.
  return static_cast<int8_t>(x >> 1);
}

template <>
template <>
HDI constexpr auto mapping<int8_t>::operator()(const float& x) const -> int8_t
{
  // Scale the float into the int8_t range and saturate at the bounds before the
  // narrowing cast, so out-of-range inputs clamp rather than overflow.
  const float scaled = x * 128.0f;
  const float saturated = scaled < -128.0f ? -128.0f : (127.0f < scaled ? 127.0f : scaled);
  return static_cast<int8_t>(saturated);
}

/**
* @brief Sets the first num bytes of the block of memory pointed by ptr to the specified value.
*
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/neighbors/detail/cagra/cagra_build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ ivf_pq_params::ivf_pq_params(raft::matrix_extent<int64_t> dataset_extents,
search_params.n_probes = std::max<uint32_t>(10, build_params.n_lists * 0.01);
search_params.lut_dtype = CUDA_R_16F;
search_params.internal_distance_dtype = CUDA_R_16F;
search_params.coarse_search_dtype = CUDA_R_16F;
search_params.max_internal_batch_size = 128 * 1024;

refinement_rate = 2;
refinement_rate = 1;
}
} // namespace cuvs::neighbors::cagra::graph_build_params
24 changes: 12 additions & 12 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,13 @@ void build_knn_graph(
const auto num_queries = dataset.extent(0);

// Use the same maximum batch size as the ivf_pq::search to avoid allocating more than needed.
constexpr uint32_t kMaxQueries = 4096;
const uint32_t max_queries = pq.search_params.max_internal_batch_size;

// Heuristic: the build_knn_graph code should use only a fraction of the workspace memory; the
// rest should be used by the ivf_pq::search. Here we say that the workspace size should be a good
// multiple of what is required for the I/O batching below.
constexpr size_t kMinWorkspaceRatio = 5;
auto desired_workspace_size = kMaxQueries * kMinWorkspaceRatio *
auto desired_workspace_size = max_queries * kMinWorkspaceRatio *
(sizeof(DataT) * dataset.extent(1) // queries (dataset batch)
+ sizeof(float) * gpu_top_k // distances
+ sizeof(int64_t) * gpu_top_k // neighbors
Expand All @@ -189,21 +189,21 @@ void build_knn_graph(
node_degree,
top_k,
gpu_top_k,
kMaxQueries,
max_queries,
pq.search_params.n_probes);

auto distances = raft::make_device_mdarray<float>(
res, workspace_mr, raft::make_extents<int64_t>(kMaxQueries, gpu_top_k));
res, workspace_mr, raft::make_extents<int64_t>(max_queries, gpu_top_k));
auto neighbors = raft::make_device_mdarray<int64_t>(
res, workspace_mr, raft::make_extents<int64_t>(kMaxQueries, gpu_top_k));
res, workspace_mr, raft::make_extents<int64_t>(max_queries, gpu_top_k));
auto refined_distances = raft::make_device_mdarray<float>(
res, workspace_mr, raft::make_extents<int64_t>(kMaxQueries, top_k));
res, workspace_mr, raft::make_extents<int64_t>(max_queries, top_k));
auto refined_neighbors = raft::make_device_mdarray<int64_t>(
res, workspace_mr, raft::make_extents<int64_t>(kMaxQueries, top_k));
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(kMaxQueries, gpu_top_k);
auto queries_host = raft::make_host_matrix<DataT, int64_t>(kMaxQueries, dataset.extent(1));
auto refined_neighbors_host = raft::make_host_matrix<int64_t, int64_t>(kMaxQueries, top_k);
auto refined_distances_host = raft::make_host_matrix<float, int64_t>(kMaxQueries, top_k);
res, workspace_mr, raft::make_extents<int64_t>(max_queries, top_k));
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(max_queries, gpu_top_k);
auto queries_host = raft::make_host_matrix<DataT, int64_t>(max_queries, dataset.extent(1));
auto refined_neighbors_host = raft::make_host_matrix<int64_t, int64_t>(max_queries, top_k);
auto refined_distances_host = raft::make_host_matrix<float, int64_t>(max_queries, top_k);

// TODO(tfeher): batched search with multiple GPUs
std::size_t num_self_included = 0;
Expand All @@ -214,7 +214,7 @@ void build_knn_graph(
dataset.data_handle(),
dataset.extent(0),
dataset.extent(1),
static_cast<int64_t>(kMaxQueries),
static_cast<int64_t>(max_queries),
raft::resource::get_cuda_stream(res),
workspace_mr);

Expand Down
4 changes: 4 additions & 0 deletions cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ void set_centers(raft::resources const& handle, index<IdxT>* index, const float*
auto stream = raft::resource::get_cuda_stream(handle);
auto* device_memory = raft::resource::get_workspace_resource(handle);

// Make sure to have trailing zeroes between dim and dim_ext;
// We rely on this to enable padded tensor gemm kernels during coarse search.
cuvs::spatial::knn::detail::utils::memzero(
index->centers().data_handle(), index->centers().size(), stream);
// combine cluster_centers and their norms
RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle(),
sizeof(float) * index->dim_ext(),
Expand Down
Loading