diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 60cab78b2..e919be49d 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -425,7 +425,7 @@ if(BUILD_SHARED_LIBS) src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cu - src/neighbors/ivf_pq_index.cpp + src/neighbors/ivf_pq_index.cu src/neighbors/ivf_pq/ivf_pq_build_common.cu src/neighbors/ivf_pq/ivf_pq_serialize.cu src/neighbors/ivf_pq/ivf_pq_deserialize.cu diff --git a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h index 7617bfa66..48a0f39ab 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h @@ -154,6 +154,25 @@ void parse_search_param(const nlohmann::json& conf, // set half as default param.pq_param.lut_dtype = CUDA_R_16F; } + + if (conf.contains("coarse_search_dtype")) { + std::string type = conf.at("coarse_search_dtype"); + if (type == "float") { + param.pq_param.coarse_search_dtype = CUDA_R_32F; + } else if (type == "half") { + param.pq_param.coarse_search_dtype = CUDA_R_16F; + } else if (type == "int8") { + param.pq_param.coarse_search_dtype = CUDA_R_8I; + } else { + throw std::runtime_error("coarse_search_dtype: '" + type + + "', should be either 'float', 'half' or 'int8'"); + } + } + + if (conf.contains("max_internal_batch_size")) { + param.pq_param.max_internal_batch_size = conf.at("max_internal_batch_size"); + } + if (conf.contains("refine_ratio")) { param.refine_ratio = conf.at("refine_ratio"); if (param.refine_ratio < 1.0f) { throw std::runtime_error("refine_ratio should be >= 1.0"); } diff --git a/cpp/include/cuvs/neighbors/ivf_pq.h b/cpp/include/cuvs/neighbors/ivf_pq.h index 08fe600ae..66f4a86fa 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.h +++ b/cpp/include/cuvs/neighbors/ivf_pq.h @@ -178,6 +178,21 @@ struct cuvsIvfPqSearchParams { * performance slightly. */ cudaDataType_t internal_distance_dtype; + /** + * The data type to use as the GEMM element type when searching the clusters to probe. + * + * Possible values: [CUDA_R_8I, CUDA_R_16F, CUDA_R_32F]. + * + * - Legacy default: CUDA_R_32F (float) + * - Recommended for performance: CUDA_R_16F (half) + * - Experimental/low-precision: CUDA_R_8I (int8_t) + * (WARNING: int8_t variant degrades recall unless data is normalized and low-dimensional) + */ + cudaDataType_t coarse_search_dtype; + /** + * Set the internal batch size to improve GPU utilization at the cost of larger memory footprint. + */ + uint32_t max_internal_batch_size; /** * Preferred fraction of SM's unified memory / L1 cache to be used as shared memory. * diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 73f9a831b..5efe3775f 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -27,6 +27,11 @@ #include #include +#include +#include +#include +#include + namespace cuvs::neighbors::ivf_pq { /** @@ -181,6 +186,22 @@ struct search_params : cuvs::neighbors::search_params { * performance if tweaked incorrectly. */ double preferred_shmem_carveout = 1.0; + /** + * [Experimental] The data type to use as the GEMM element type when searching the clusters to + * probe. + * + * Possible values: [CUDA_R_8I, CUDA_R_16F, CUDA_R_32F]. 
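For reference, a minimal sketch of setting the two new knobs from C++; the benchmark JSON keys `coarse_search_dtype` and `max_internal_batch_size` parsed above map onto the same-named fields of `cuvs::neighbors::ivf_pq::search_params`. The values below are illustrative, not tuned recommendations.

```cpp
#include <cuvs/neighbors/ivf_pq.hpp>

cuvs::neighbors::ivf_pq::search_params make_search_params()
{
  cuvs::neighbors::ivf_pq::search_params sp;
  sp.n_probes                = 64;
  sp.lut_dtype               = CUDA_R_16F;
  sp.internal_distance_dtype = CUDA_R_16F;
  sp.coarse_search_dtype     = CUDA_R_16F;  // new: GEMM dtype for selecting clusters to probe (default CUDA_R_32F)
  sp.max_internal_batch_size = 8192;        // new: internal query batch size (default 4096)
  return sp;
}
```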
+ * + * - Legacy default: CUDA_R_32F (float) + * - Recommended for performance: CUDA_R_16F (half) + * - Experimental/low-precision: CUDA_R_8I (int8_t) + * (WARNING: int8_t variant degrades recall unless data is normalized and low-dimensional) + */ + cudaDataType_t coarse_search_dtype = CUDA_R_32F; + /** + * Set the internal batch size to improve GPU utilization at the cost of larger memory footprint. + */ + uint32_t max_internal_batch_size = 4096; }; /** * @} @@ -427,6 +448,11 @@ struct index : cuvs::neighbors::index { raft::device_matrix_view rotation_matrix() noexcept; raft::device_matrix_view rotation_matrix() const noexcept; + raft::device_matrix_view rotation_matrix_int8( + const raft::resources& res) const; + raft::device_matrix_view rotation_matrix_half( + const raft::resources& res) const; + /** * Accumulated list sizes, sorted in descending order [n_lists + 1]. * The last value contains the total length of the index. @@ -447,6 +473,11 @@ struct index : cuvs::neighbors::index { raft::device_matrix_view centers() noexcept; raft::device_matrix_view centers() const noexcept; + raft::device_matrix_view centers_int8( + const raft::resources& res) const; + raft::device_matrix_view centers_half( + const raft::resources& res) const; + /** Cluster centers corresponding to the lists in the rotated space [n_lists, rot_dim] */ raft::device_matrix_view centers_rot() noexcept; raft::device_matrix_view centers_rot() const noexcept; @@ -485,6 +516,14 @@ struct index : cuvs::neighbors::index { raft::device_matrix centers_rot_; raft::device_matrix rotation_matrix_; + // Lazy-initialized low-precision variants of index members - for low-precision coarse search. + // These are never serialized and not touched during build/extend. + mutable std::optional> centers_int8_; + mutable std::optional> centers_half_; + mutable std::optional> + rotation_matrix_int8_; + mutable std::optional> rotation_matrix_half_; + // Computed members for accelerating search. raft::device_vector data_ptrs_; raft::device_vector inds_ptrs_; diff --git a/cpp/src/neighbors/detail/ann_utils.cuh b/cpp/src/neighbors/detail/ann_utils.cuh index 149eea3f1..4c62d5bac 100644 --- a/cpp/src/neighbors/detail/ann_utils.cuh +++ b/cpp/src/neighbors/detail/ann_utils.cuh @@ -195,6 +195,22 @@ struct mapping { /** @} */ }; +template <> +template <> +HDI constexpr auto mapping::operator()(const uint8_t& x) const -> int8_t +{ + // Avoid overflows when converting uint8_t -> int_8 + return static_cast(x >> 1); +} + +template <> +template <> +HDI constexpr auto mapping::operator()(const float& x) const -> int8_t +{ + // Carefully clamp floats if out-of-bounds. + return static_cast(std::clamp(x * 128.0f, -128.0f, 127.0f)); +} + /** * @brief Sets the first num bytes of the block of memory pointed by ptr to the specified value. 
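The two `mapping<int8_t>` specializations above quantize inputs into the int8 coarse-search domain. A standalone host-side sketch of the same formulas (illustrative only, not the library code):

```cpp
#include <algorithm>
#include <cstdint>

// uint8_t -> int8_t: halve the value so 0..255 maps into 0..127 without overflowing the signed range.
inline int8_t uint8_to_int8(uint8_t x) { return static_cast<int8_t>(x >> 1); }

// float -> int8_t: scale by 128 and clamp; assumes roughly unit-normalized inputs.
inline int8_t float_to_int8(float x)
{
  return static_cast<int8_t>(std::clamp(x * 128.0f, -128.0f, 127.0f));
}
```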
* diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cpp b/cpp/src/neighbors/detail/cagra/cagra_build.cpp index 490dc0f30..0fdfd1bcf 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cpp +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cpp @@ -29,7 +29,9 @@ ivf_pq_params::ivf_pq_params(raft::matrix_extent dataset_extents, search_params.n_probes = std::max(10, build_params.n_lists * 0.01); search_params.lut_dtype = CUDA_R_16F; search_params.internal_distance_dtype = CUDA_R_16F; + search_params.coarse_search_dtype = CUDA_R_16F; + search_params.max_internal_batch_size = 128 * 1024; - refinement_rate = 2; + refinement_rate = 1; } } // namespace cuvs::neighbors::cagra::graph_build_params diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 7c749cf6e..861c02561 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -164,13 +164,13 @@ void build_knn_graph( const auto num_queries = dataset.extent(0); // Use the same maximum batch size as the ivf_pq::search to avoid allocating more than needed. - constexpr uint32_t kMaxQueries = 4096; + const uint32_t max_queries = pq.search_params.max_internal_batch_size; // Heuristic: the build_knn_graph code should use only a fraction of the workspace memory; the // rest should be used by the ivf_pq::search. Here we say that the workspace size should be a good // multiple of what is required for the I/O batching below. constexpr size_t kMinWorkspaceRatio = 5; - auto desired_workspace_size = kMaxQueries * kMinWorkspaceRatio * + auto desired_workspace_size = max_queries * kMinWorkspaceRatio * (sizeof(DataT) * dataset.extent(1) // queries (dataset batch) + sizeof(float) * gpu_top_k // distances + sizeof(int64_t) * gpu_top_k // neighbors @@ -189,21 +189,21 @@ void build_knn_graph( node_degree, top_k, gpu_top_k, - kMaxQueries, + max_queries, pq.search_params.n_probes); auto distances = raft::make_device_mdarray( - res, workspace_mr, raft::make_extents(kMaxQueries, gpu_top_k)); + res, workspace_mr, raft::make_extents(max_queries, gpu_top_k)); auto neighbors = raft::make_device_mdarray( - res, workspace_mr, raft::make_extents(kMaxQueries, gpu_top_k)); + res, workspace_mr, raft::make_extents(max_queries, gpu_top_k)); auto refined_distances = raft::make_device_mdarray( - res, workspace_mr, raft::make_extents(kMaxQueries, top_k)); + res, workspace_mr, raft::make_extents(max_queries, top_k)); auto refined_neighbors = raft::make_device_mdarray( - res, workspace_mr, raft::make_extents(kMaxQueries, top_k)); - auto neighbors_host = raft::make_host_matrix(kMaxQueries, gpu_top_k); - auto queries_host = raft::make_host_matrix(kMaxQueries, dataset.extent(1)); - auto refined_neighbors_host = raft::make_host_matrix(kMaxQueries, top_k); - auto refined_distances_host = raft::make_host_matrix(kMaxQueries, top_k); + res, workspace_mr, raft::make_extents(max_queries, top_k)); + auto neighbors_host = raft::make_host_matrix(max_queries, gpu_top_k); + auto queries_host = raft::make_host_matrix(max_queries, dataset.extent(1)); + auto refined_neighbors_host = raft::make_host_matrix(max_queries, top_k); + auto refined_distances_host = raft::make_host_matrix(max_queries, top_k); // TODO(tfeher): batched search with multiple GPUs std::size_t num_self_included = 0; @@ -214,7 +214,7 @@ void build_knn_graph( dataset.data_handle(), dataset.extent(0), dataset.extent(1), - static_cast(kMaxQueries), + static_cast(max_queries), raft::resource::get_cuda_stream(res), 
workspace_mr); diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 4e2410405..2ee0335f6 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -239,6 +239,10 @@ void set_centers(raft::resources const& handle, index* index, const float* auto stream = raft::resource::get_cuda_stream(handle); auto* device_memory = raft::resource::get_workspace_resource(handle); + // Make sure to have trailing zeroes between dim and dim_ext; + // We rely on this to enable padded tensor gemm kernels during coarse search. + cuvs::spatial::knn::detail::utils::memzero( + index->centers().data_handle(), index->centers().size(), stream); // combine cluster_centers and their norms RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle(), sizeof(float) * index->dim_ext(), diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh index 05bb99353..feba4e7ae 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -131,19 +132,17 @@ void select_clusters(raft::resources const& handle, handle, float_queries_view, [queries, dim, dim_ext, norm_factor] __device__(uint32_t ix) { uint32_t col = ix % dim_ext; uint32_t row = ix / dim_ext; - return col < dim ? utils::mapping{}(queries[col + dim * row]) : norm_factor; + if (col < dim) { return utils::mapping{}(queries[col + dim * row]); } + return col == dim ? norm_factor : 0.0f; }); float alpha; float beta; - uint32_t gemm_k = dim; switch (metric) { case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2Expanded: { - alpha = -2.0; - beta = 0.0; - gemm_k = dim + 1; - RAFT_EXPECTS(gemm_k <= dim_ext, "unexpected gemm_k or dim_ext"); + alpha = -2.0; + beta = 0.0; } break; case cuvs::distance::DistanceType::CosineExpanded: case cuvs::distance::DistanceType::InnerProduct: { @@ -158,7 +157,7 @@ void select_clusters(raft::resources const& handle, false, n_lists, n_queries, - gemm_k, + dim_ext, &alpha, cluster_centers, dim_ext, @@ -180,6 +179,178 @@ void select_clusters(raft::resources const& handle, true); } +template +void select_clusters(raft::resources const& handle, + uint32_t* clusters_to_probe, // [n_queries, n_probes] + int8_t* float_queries, // [n_queries, dim_ext] + uint32_t n_queries, + uint32_t n_probes, + uint32_t n_lists, + uint32_t dim, + uint32_t dim_ext, + cuvs::distance::DistanceType metric, + const T* queries, // [n_queries, dim] + const int8_t* cluster_centers, // [n_lists, dim_ext] + rmm::mr::device_memory_resource* mr) +{ + raft::common::nvtx::range fun_scope( + "ivf_pq::search::select_clusters(n_probes = %u, n_queries = %u, n_lists = %u, dim = %u)", + n_probes, + n_queries, + n_lists, + dim); + auto stream = raft::resource::get_cuda_stream(handle); + int8_t norm_factor; + switch (metric) { + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2Expanded: norm_factor = -128; break; + case cuvs::distance::DistanceType::CosineExpanded: + case cuvs::distance::DistanceType::InnerProduct: norm_factor = 0; break; + default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); + } + auto float_queries_view = + raft::make_device_vector_view(float_queries, dim_ext * n_queries); + raft::linalg::map_offset( + handle, float_queries_view, [queries, dim, dim_ext, norm_factor] __device__(uint32_t ix) { + uint32_t col = ix % 
dim_ext; + uint32_t row = ix / dim_ext; + if (col < dim) { return utils::mapping{}(queries[col + dim * row]); } + auto m = dim_ext - dim; + // see 'NOTE: maximizing the range and the precision of int8_t GEMM' in ivf_pq_index.cu + if (m == 1 || col > dim) { return norm_factor; } // times `y` (higher bits) + return static_cast(1 - m); // times `z` (lower bits) + }); + + using dist_type = int32_t; + dist_type alpha; + dist_type beta; + switch (metric) { + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2Expanded: { + alpha = -2; + beta = 0; + } break; + case cuvs::distance::DistanceType::CosineExpanded: + case cuvs::distance::DistanceType::InnerProduct: { + alpha = -1; + beta = 0; + } break; + default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); + } + rmm::device_uvector qc_distances(n_queries * n_lists, stream, mr); + raft::linalg::gemm(handle, + true, + false, + n_lists, + n_queries, + dim_ext, + &alpha, + cluster_centers, + dim_ext, + float_queries, + dim_ext, + &beta, + qc_distances.data(), + n_lists, + stream); + + // Select neighbor clusters for each query. + rmm::device_uvector cluster_dists(n_queries * n_probes, stream, mr); + // cuvs::selection::select_k lacks uint32_t-as-a-value support at the moment + raft::matrix::select_k( + handle, + raft::make_device_matrix_view( + qc_distances.data(), n_queries, n_lists), + std::nullopt, + raft::make_device_matrix_view(cluster_dists.data(), n_queries, n_probes), + raft::make_device_matrix_view(clusters_to_probe, n_queries, n_probes), + true); +} + +template +void select_clusters(raft::resources const& handle, + uint32_t* clusters_to_probe, // [n_queries, n_probes] + half* float_queries, // [n_queries, dim_ext] + uint32_t n_queries, + uint32_t n_probes, + uint32_t n_lists, + uint32_t dim, + uint32_t dim_ext, + cuvs::distance::DistanceType metric, + const T* queries, // [n_queries, dim] + const half* cluster_centers, // [n_lists, dim_ext] + rmm::mr::device_memory_resource* mr) +{ + raft::common::nvtx::range fun_scope( + "ivf_pq::search::select_clusters(n_probes = %u, n_queries = %u, n_lists = %u, dim = %u)", + n_probes, + n_queries, + n_lists, + dim); + auto stream = raft::resource::get_cuda_stream(handle); + half norm_factor; + switch (metric) { + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2Expanded: norm_factor = 1.0 / -2.0; break; + case cuvs::distance::DistanceType::CosineExpanded: + case cuvs::distance::DistanceType::InnerProduct: norm_factor = 0; break; + default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); + } + auto float_queries_view = + raft::make_device_vector_view(float_queries, dim_ext * n_queries); + raft::linalg::map_offset( + handle, float_queries_view, [queries, dim, dim_ext, norm_factor] __device__(uint32_t ix) { + uint32_t col = ix % dim_ext; + uint32_t row = ix / dim_ext; + if (col < dim) { return utils::mapping{}(queries[col + dim * row]); } + return col == dim ? 
norm_factor : half(0); + }); + + using dist_type = half; + dist_type alpha; + dist_type beta; + switch (metric) { + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2Expanded: { + alpha = -2.0; + beta = 0.0; + } break; + case cuvs::distance::DistanceType::CosineExpanded: + case cuvs::distance::DistanceType::InnerProduct: { + alpha = -1.0; + beta = 0.0; + } break; + default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); + } + rmm::device_uvector qc_distances(n_queries * n_lists, stream, mr); + raft::linalg::gemm(handle, + true, + false, + n_lists, + n_queries, + dim_ext, + &alpha, + cluster_centers, + dim_ext, + float_queries, + dim_ext, + &beta, + qc_distances.data(), + n_lists, + stream); + + // Select neighbor clusters for each query. + rmm::device_uvector cluster_dists(n_queries * n_probes, stream, mr); + cuvs::selection::select_k( + handle, + raft::make_device_matrix_view( + qc_distances.data(), n_queries, n_lists), + std::nullopt, + raft::make_device_matrix_view(cluster_dists.data(), n_queries, n_probes), + raft::make_device_matrix_view(clusters_to_probe, n_queries, n_probes), + true); +} + /** * An approximation to the number of times each cluster appears in a batched sample. * @@ -607,8 +778,23 @@ inline auto get_max_batch_size(raft::resources const& res, return max_batch_size; } -/** Maximum number of queries ivf_pq::search can process in one batch. */ -constexpr uint32_t kMaxQueries = 4096; +template +inline auto get_rotation_matrix(const raft::resources& res, const index& index) + -> raft::device_matrix_view +{ + if constexpr (std::is_same_v) { return index.rotation_matrix(); } + if constexpr (std::is_same_v) { return index.rotation_matrix_half(res); } + if constexpr (std::is_same_v) { return index.rotation_matrix_int8(res); } +} + +template +inline auto get_centers(const raft::resources& res, const index& index) + -> raft::device_matrix_view +{ + if constexpr (std::is_same_v) { return index.centers(); } + if constexpr (std::is_same_v) { return index.centers_half(res); } + if constexpr (std::is_same_v) { return index.centers_int8(res); } +} /** See raft::spatial::knn::ivf_pq::search docs */ template (handle, index).extent(1) + : index.dim_ext(); auto n_probes = std::min(params.n_probes, index.n_lists()); uint32_t max_samples = 0; @@ -678,10 +867,24 @@ inline void search(raft::resources const& handle, auto mr = raft::resource::get_workspace_resource(handle); // Maximum number of query vectors to search at the same time. - const auto max_queries = std::min(std::max(n_queries, 1), kMaxQueries); - auto max_batch_size = get_max_batch_size(handle, k, n_probes, max_queries, max_samples); - - rmm::device_uvector float_queries(max_queries * dim_ext, stream, mr); + const auto max_queries = + std::min(std::max(n_queries, 1), params.max_internal_batch_size); + auto max_batch_size = get_max_batch_size(handle, k, n_probes, max_queries, max_samples); + + using some_query_t = std:: + variant, rmm::device_uvector, rmm::device_uvector>; + some_query_t gemm_queries( + params.coarse_search_dtype == CUDA_R_32F + ? std::move(some_query_t{ + std::in_place_type_t>{}, max_queries * dim_ext, stream, mr}) + : params.coarse_search_dtype == CUDA_R_16F + ? std::move(some_query_t{ + std::in_place_type_t>{}, max_queries * dim_ext, stream, mr}) + : params.coarse_search_dtype == CUDA_R_8I + ? 
std::move(some_query_t{ + std::in_place_type_t>{}, max_queries * dim_ext, stream, mr}) + : throw raft::logic_error("Unsupported coarse_search_dtype (only CUDA_R_32F, " + "CUDA_R_16F, and CUDA_R_8I are supported)")); rmm::device_uvector rot_queries(max_queries * index.rot_dim(), stream, mr); rmm::device_uvector clusters_to_probe(max_queries * n_probes, stream, mr); @@ -694,37 +897,49 @@ inline void search(raft::resources const& handle, raft::common::nvtx::range batch_scope( "ivf_pq::search-batch(queries: %u - %u)", offset_q, offset_q + queries_batch); - select_clusters(handle, - clusters_to_probe.data(), - float_queries.data(), - queries_batch, - n_probes, - index.n_lists(), - dim, - dim_ext, - index.metric(), - queries + static_cast(dim) * offset_q, - index.centers().data_handle(), - mr); + std::visit( + [&](auto&& gemm_qs) { + using gemm_type = std::remove_reference_t; + using value_type = std::remove_cv_t; + return select_clusters(handle, + clusters_to_probe.data(), + gemm_qs.data(), + queries_batch, + n_probes, + index.n_lists(), + dim, + dim_ext, + index.metric(), + queries + static_cast(dim) * offset_q, + get_centers(handle, index).data_handle(), + mr); + }, + gemm_queries); // Rotate queries - float alpha = 1.0; - float beta = 0.0; - raft::linalg::gemm(handle, - true, - false, - index.rot_dim(), - queries_batch, - dim, - &alpha, - index.rotation_matrix().data_handle(), - dim, - float_queries.data(), - dim_ext, - &beta, - rot_queries.data(), - index.rot_dim(), - stream); + std::visit( + [&](auto&& gemm_qs) { + using gemm_type = std::remove_reference_t; + using value_type = std::remove_cv_t; + float alpha = std::is_same_v ? 1.0 / 128.0 / 128.0 : 1.0; + float beta = 0.0; + raft::linalg::gemm(handle, + true, + false, + index.rot_dim(), + queries_batch, + dim, + &alpha, + get_rotation_matrix(handle, index).data_handle(), + dim, + gemm_qs.data(), + dim_ext, + &beta, + rot_queries.data(), + index.rot_dim(), + stream); + }, + gemm_queries); if (index.metric() == distance::DistanceType::CosineExpanded) { auto rot_queries_view = raft::make_device_matrix_view( rot_queries.data(), max_queries, index.rot_dim()); diff --git a/cpp/src/neighbors/ivf_pq_c.cpp b/cpp/src/neighbors/ivf_pq_c.cpp index 00ede91c0..6548155ef 100755 --- a/cpp/src/neighbors/ivf_pq_c.cpp +++ b/cpp/src/neighbors/ivf_pq_c.cpp @@ -79,6 +79,8 @@ void _search(cuvsResources_t res, search_params.lut_dtype = params.lut_dtype; search_params.internal_distance_dtype = params.internal_distance_dtype; search_params.preferred_shmem_carveout = params.preferred_shmem_carveout; + search_params.coarse_search_dtype = params.coarse_search_dtype; + search_params.max_internal_batch_size = params.max_internal_batch_size; using queries_mdspan_type = raft::device_matrix_view; using neighbors_mdspan_type = raft::device_matrix_view; @@ -246,6 +248,8 @@ extern "C" cuvsError_t cuvsIvfPqSearchParamsCreate(cuvsIvfPqSearchParams_t* para *params = new cuvsIvfPqSearchParams{.n_probes = 20, .lut_dtype = CUDA_R_32F, .internal_distance_dtype = CUDA_R_32F, + .coarse_search_dtype = CUDA_R_32F, + .max_internal_batch_size = 4096, .preferred_shmem_carveout = 1.0}; }); } diff --git a/cpp/src/neighbors/ivf_pq_index.cpp b/cpp/src/neighbors/ivf_pq_index.cu similarity index 68% rename from cpp/src/neighbors/ivf_pq_index.cpp rename to cpp/src/neighbors/ivf_pq_index.cu index 8f4e5b331..1bd6b6291 100644 --- a/cpp/src/neighbors/ivf_pq_index.cpp +++ b/cpp/src/neighbors/ivf_pq_index.cu @@ -16,6 +16,14 @@ #include +#include "detail/ann_utils.cuh" + +#include +#include +#include + 
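The same two knobs are exposed through the C API extended above (`cuvsIvfPqSearchParams`). A minimal usage sketch, with error-code handling and the matching destroy call omitted:

```cpp
#include <cuvs/neighbors/ivf_pq.h>

void configure_coarse_search(cuvsIvfPqSearchParams_t* params)
{
  cuvsIvfPqSearchParamsCreate(params);              // fills in the defaults listed above
  (*params)->coarse_search_dtype     = CUDA_R_16F;  // new field (default CUDA_R_32F)
  (*params)->max_internal_batch_size = 8192;        // new field (default 4096)
}
```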
+#include + namespace cuvs::neighbors::ivf_pq { index_params index_params::from_dataset(raft::matrix_extent dataset, cuvs::distance::DistanceType metric) @@ -339,6 +347,109 @@ uint32_t index::calculate_pq_dim(uint32_t dim) return r; } +template +raft::device_matrix_view index::rotation_matrix_int8( + const raft::resources& res) const +{ + if (!rotation_matrix_int8_.has_value()) { + rotation_matrix_int8_.emplace( + raft::make_device_mdarray(res, rotation_matrix().extents())); + raft::linalg::map(res, + rotation_matrix_int8_->view(), + cuvs::spatial::knn::detail::utils::mapping{}, + rotation_matrix()); + } + return rotation_matrix_int8_->view(); +} + +template +raft::device_matrix_view index::centers_int8( + const raft::resources& res) const +{ + if (!centers_int8_.has_value()) { + uint32_t n_lists = this->n_lists(); + uint32_t dim = this->dim(); + uint32_t dim_ext = this->dim_ext(); + uint32_t dim_ext_int8 = raft::round_up_safe(dim + 2, 16u); + centers_int8_.emplace(raft::make_device_matrix(res, n_lists, dim_ext_int8)); + auto* inputs = centers().data_handle(); + /* NOTE: maximizing the range and the precision of int8_t GEMM + + int8_t has a very limited range [-128, 127], which is problematic when storing both vectors and + their squared norms in one place. + + We map all dimensions by multiplying by 128. But that means we need to multiply the squared norm + component by `128^2`, which we cannot afford, since it most likely overflows. + So, a naive mapping would be: + ``` + [c_1 * 128, c_2, * 128, ...., c_(dim-1) * 128, n2 * 128 * 128, 0 ... 0] + • [q_1 * 128, q_2 * 128, ..., q_(dim-1)*128, -0.5, 0, ... 0] + ``` + + Which is at first can be improved by moving one 128 to the query side: + ``` + [c_1 * 128, c_2, * 128, ...., c_(dim-1) * 128, n2 * 128, 0 ... 0] + • [q_1 * 128, q_2 * 128, ..., q_(dim-1)*128, -64, 0, ... 0] + ``` + + Yet this still only works for vectors with L2 norms not bigger than one and has a rather awful + granularity of 64. To improve both the range and the precision, we count the number of available + slots `m > 2` and decompose the squared norm, such that: + ``` + 0.5 * 128 * n2 = 64 * n2 = 128 * z + (m - 1) * y + ``` + where `y` maximizes the available range while `z` encodes the rounding error. + Then we get following dot product during the coarse search: + ``` + [c_1 * 128, c_2, * 128, ...., c_(dim-1) * 128, z, y, ... y] + • [q_1 * 128, q_2 * 128, ..., q_(dim-1)*128, 1 - m, -128, ... -128] + ``` + `m` is maximum 16, so we get the coefficient much lower than the naive 64 on the query side; and + it is limited by the range we can cover (the squared norm must be within `m * 2` before + normalization). 
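+
+   A worked example with hypothetical numbers: take m = 6 padding slots and a center with squared
+   norm n2 = 1.37. Then y = n2 * 64 / (m - 1) = 17.536, round(y) = 18, and z = (y - round(y)) * 128
+   = -59.4, stored as -59. On the query side the padding slots hold (1 - m) = -5 and -128, so they
+   contribute z * (1 - m) + round(y) * (-128) * (m - 1) = 295 - 11520 = -11225 to the dot product,
+   which is close to -0.5 * 128 * 128 * n2 = -11223, i.e. the `-0.5 * squared-norm` term of the
+   float/half path scaled by 128^2.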
+ */ + raft::linalg::map_offset( + res, centers_int8_->view(), [dim, dim_ext, dim_ext_int8, inputs] __device__(uint32_t ix) { + uint32_t col = ix % dim_ext_int8; + uint32_t row = ix / dim_ext_int8; + if (col < dim) { + return static_cast( + std::clamp(inputs[col + row * dim_ext] * 128.0f, -128.0f, 127.f)); + } + auto x = inputs[row * dim_ext + dim]; + auto c = 64.0f / static_cast(dim_ext_int8 - dim - 1); + auto y = std::clamp(x * c, -128.0f, 127.f); + auto z = std::clamp((y - std::round(y)) * 128.0f, -128.0f, 127.f); + if (col > dim) { return static_cast(std::round(y)); } + return static_cast(z); + }); + } + return centers_int8_->view(); +} + +template +raft::device_matrix_view index::rotation_matrix_half( + const raft::resources& res) const +{ + if (!rotation_matrix_half_.has_value()) { + rotation_matrix_half_.emplace( + raft::make_device_mdarray(res, rotation_matrix().extents())); + raft::linalg::map(res, rotation_matrix_half_->view(), raft::cast_op{}, rotation_matrix()); + } + return rotation_matrix_half_->view(); +} + +template +raft::device_matrix_view index::centers_half( + const raft::resources& res) const +{ + if (!centers_half_.has_value()) { + centers_half_.emplace(raft::make_device_mdarray(res, centers().extents())); + raft::linalg::map(res, centers_half_->view(), raft::cast_op{}, centers()); + } + return centers_half_->view(); +} + template struct index; } // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/tests/neighbors/ann_ivf_pq.cuh b/cpp/tests/neighbors/ann_ivf_pq.cuh index 0ebe604f8..97f252bff 100644 --- a/cpp/tests/neighbors/ann_ivf_pq.cuh +++ b/cpp/tests/neighbors/ann_ivf_pq.cuh @@ -99,6 +99,9 @@ inline auto operator<<(std::ostream& os, const ivf_pq_inputs& p) -> std::ostream PRINT_DIFF_V(.search_params.lut_dtype, cuvs::neighbors::print_dtype{p.search_params.lut_dtype}); PRINT_DIFF_V(.search_params.internal_distance_dtype, cuvs::neighbors::print_dtype{p.search_params.internal_distance_dtype}); + PRINT_DIFF_V(.search_params.coarse_search_dtype, + cuvs::neighbors::print_dtype{p.search_params.coarse_search_dtype}); + PRINT_DIFF(.search_params.max_internal_batch_size); os << "}"; return os; } @@ -849,6 +852,17 @@ inline auto enum_variety() -> test_cases_t x.search_params.lut_dtype = CUDA_R_8U; x.min_recall = 0.84; }); + ADD_CASE({ + x.search_params.coarse_search_dtype = CUDA_R_16F; + x.min_recall = 0.86; + }); + ADD_CASE({ + x.search_params.coarse_search_dtype = CUDA_R_8I; + // 8-bit coarse search is experimental and there's no go guarantee of any recall + // if the data is not normalized. Especially for L2, because we store vector norms alongside the + // cluster centers. 
+ x.min_recall = 0.1; + }); ADD_CASE({ x.search_params.internal_distance_dtype = CUDA_R_32F; @@ -859,6 +873,12 @@ inline auto enum_variety() -> test_cases_t x.search_params.lut_dtype = CUDA_R_16F; x.min_recall = 0.86; }); + ADD_CASE({ + x.search_params.internal_distance_dtype = CUDA_R_16F; + x.search_params.lut_dtype = CUDA_R_16F; + x.search_params.coarse_search_dtype = CUDA_R_16F; + x.min_recall = 0.86; + }); return xs; } @@ -994,6 +1014,34 @@ inline auto special_cases() -> test_cases_t x.search_params.n_probes = 50; }); + // Test large max_internal_batch_size + ADD_CASE({ + x.num_db_vecs = 500000; + x.dim = 100; + x.num_queries = 128 * 1024 * 1024; + x.k = 10; + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; + x.index_params.pq_dim = 10; + x.index_params.pq_bits = 8; + x.index_params.n_lists = 1024; + x.search_params.n_probes = 50; + x.search_params.max_internal_batch_size = 64 * 1024 * 1024; + }); + + // Test small max_internal_batch_size + ADD_CASE({ + x.num_db_vecs = 500000; + x.dim = 100; + x.num_queries = 128 * 1024 * 1024; + x.k = 10; + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; + x.index_params.pq_dim = 10; + x.index_params.pq_bits = 8; + x.index_params.n_lists = 1024; + x.search_params.n_probes = 50; + x.search_params.max_internal_batch_size = 1024 * 1024; + }); + ADD_CASE({ x.num_db_vecs = 10000; x.dim = 16;
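For context on the search-path change above (ivf_pq_search.cuh): the coarse-search query buffer is now allocated as a `std::variant` keyed on `coarse_search_dtype` and dispatched with `std::visit`. A self-contained, host-only sketch of that dispatch pattern, using `std::vector` in place of `rmm::device_uvector` and omitting the half variant for brevity:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <library_types.h>  // cudaDataType_t, CUDA_R_32F, CUDA_R_8I
#include <stdexcept>
#include <type_traits>
#include <variant>
#include <vector>

using some_query_t = std::variant<std::vector<float>, std::vector<int8_t>>;

some_query_t make_query_buffer(cudaDataType_t dtype, std::size_t n)
{
  if (dtype == CUDA_R_32F) { return std::vector<float>(n); }
  if (dtype == CUDA_R_8I) { return std::vector<int8_t>(n); }
  throw std::runtime_error("unsupported coarse_search_dtype");
}

int main()
{
  auto queries = make_query_buffer(CUDA_R_8I, 4096);
  std::visit(
    [](auto&& buf) {
      // The deduced element type picks the matching typed code path (GEMM / select_clusters).
      using value_type = typename std::remove_reference_t<decltype(buf)>::value_type;
      std::printf("coarse-search element size: %zu byte(s)\n", sizeof(value_type));
    },
    queries);
  return 0;
}
```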