UMAP knn build with cuVS all-neighbors #6563

Draft
wants to merge 23 commits into base: branch-25.06
4 changes: 2 additions & 2 deletions cpp/cmake/thirdparty/get_cuvs.cmake
@@ -48,8 +48,8 @@ function(find_and_configure_cuvs)
BUILD_EXPORT_SET cuml-exports
INSTALL_EXPORT_SET cuml-exports
CPM_ARGS
GIT_REPOSITORY https://github.com/${PKG_FORK}/cuvs.git
GIT_TAG ${PKG_PINNED_TAG}
GIT_REPOSITORY https://github.com/jinsolp/cuvs.git
Contributor Author:
Needs to be reverted once this PR in cuvs is merged

GIT_TAG snmg-batching
SOURCE_SUBDIR cpp
EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL}
OPTIONS
22 changes: 17 additions & 5 deletions cpp/include/cuml/manifold/umapparams.h
@@ -19,18 +19,30 @@
#include <cuml/common/callback.hpp>
#include <cuml/common/logger.hpp>

#include <raft/neighbors/nn_descent_types.hpp>

#include <cuvs/distance/distance.hpp>

namespace ML {

using nn_index_params = raft::neighbors::experimental::nn_descent::index_params;
namespace graph_build_params {

struct nn_descent_params {
// not directly using cuvs::neighbors::nn_descent::index_params to distinguish UMAP-exposed NN
// Descent parameters
size_t graph_degree = 64;
size_t max_iterations = 20;
};

struct build_params {
size_t n_nearest_clusters = 2;
size_t n_clusters = 4;
nn_descent_params nn_descent_params;
};
} // namespace graph_build_params

class UMAPParams {
public:
enum MetricType { EUCLIDEAN, CATEGORICAL };
enum graph_build_algo { BRUTE_FORCE_KNN, NN_DESCENT };
enum graph_build_algo { BRUTE_FORCE_KNN, NN_DESCENT, IVFPQ };

/**
* The number of neighbors to use to approximate geodesic distance.
@@ -150,7 +162,7 @@ class UMAPParams {
*/
graph_build_algo build_algo = graph_build_algo::BRUTE_FORCE_KNN;

nn_index_params nn_descent_params = {};
graph_build_params::build_params build_params;

/**
* The number of nearest neighbors to use to construct the target simplicial
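For orientation, here is a rough sketch (not part of this diff) of how the new graph_build_params::build_params fields are expected to surface in the Python build_kwds dict; the field names come from the struct above, while the nested layout and Python-side defaults follow the umap.pyx docstring further down in this PR:

# Illustrative only. In C++, build_params defaults to n_nearest_clusters=2 and
# n_clusters=4; the Python layer passes n_clusters=1 unless batching is requested,
# and defaults graph_degree to n_neighbors.
example_build_kwds = {
    "n_clusters": 1,
    "n_nearest_clusters": 2,
    "nn_descent": {"graph_degree": 64, "max_iterations": 20},
}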
133 changes: 60 additions & 73 deletions cpp/src/umap/knn_graph/algo.cuh
@@ -27,20 +27,16 @@
#include <raft/core/host_mdspan.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/slice.cuh>
#include <raft/neighbors/nn_descent.cuh>
#include <raft/neighbors/nn_descent_types.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/util/cudart_utils.hpp>

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/all_neighbors.hpp>
#include <cuvs/neighbors/brute_force.hpp>
#include <stdint.h>

#include <iostream>

namespace NNDescent = raft::neighbors::experimental::nn_descent;

namespace UMAPAlgo {
namespace kNNGraph {
namespace Algo {
@@ -58,30 +54,6 @@ void launcher(const raft::handle_t& handle,
const ML::UMAPParams* params,
cudaStream_t stream);

// Functor to post-process distances as L2Sqrt*
template <typename value_idx, typename value_t = float>
struct DistancePostProcessSqrt : NNDescent::DistEpilogue<value_idx, value_t> {
DI value_t operator()(value_t value, value_idx row, value_idx col) const { return sqrtf(value); }
};

auto get_graph_nnd(const raft::handle_t& handle,
const ML::manifold_dense_inputs_t<float>& inputs,
const ML::UMAPParams* params)
{
auto epilogue = DistancePostProcessSqrt<int64_t, float>{};
cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputs.X));
float* ptr = reinterpret_cast<float*>(attr.devicePointer);
if (ptr != nullptr) {
auto dataset =
raft::make_device_matrix_view<const float, int64_t>(inputs.X, inputs.n, inputs.d);
return NNDescent::build<float, int64_t>(handle, params->nn_descent_params, dataset, epilogue);
} else {
auto dataset = raft::make_host_matrix_view<const float, int64_t>(inputs.X, inputs.n, inputs.d);
return NNDescent::build<float, int64_t>(handle, params->nn_descent_params, dataset, epilogue);
}
}

// Instantiation for dense inputs, int64_t indices
template <>
inline void launcher(const raft::handle_t& handle,
Expand All @@ -92,12 +64,14 @@ inline void launcher(const raft::handle_t& handle,
const ML::UMAPParams* params,
cudaStream_t stream)
{
cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputsA.X));
float* ptr = reinterpret_cast<float*>(attr.devicePointer);
bool data_on_device = ptr != nullptr;

if (params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN) {
cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputsA.X));
float* ptr = reinterpret_cast<float*>(attr.devicePointer);
auto idx = [&]() {
if (ptr != nullptr) { // inputsA on device
auto idx = [&]() {
if (data_on_device) { // inputsA on device
return cuvs::neighbors::brute_force::build(
handle,
{params->metric, params->p},
@@ -115,46 +89,59 @@ inline void launcher(const raft::handle_t& handle,
raft::make_device_matrix_view<const float, int64_t>(inputsB.X, inputsB.n, inputsB.d),
raft::make_device_matrix_view<int64_t, int64_t>(out.knn_indices, inputsB.n, n_neighbors),
raft::make_device_matrix_view<float, int64_t>(out.knn_dists, inputsB.n, n_neighbors));
} else { // nn_descent
// TODO: use nndescent from cuvs
RAFT_EXPECTS(static_cast<size_t>(n_neighbors) <= params->nn_descent_params.graph_degree,
"n_neighbors should be smaller than the graph degree computed by nn descent");
RAFT_EXPECTS(params->nn_descent_params.return_distances,
"return_distances for nn descent should be set to true to be used for UMAP");

auto graph = get_graph_nnd(handle, inputsA, params);

// `graph.graph()` is a host array (n x graph_degree).
// Slice and copy to a temporary host array (n x n_neighbors), then copy
// that to the output device array `out.knn_indices` (n x n_neighbors).
// TODO: force graph_degree = n_neighbors so the temporary host array and
// slice isn't necessary.
auto temp_indices_h = raft::make_host_matrix<int64_t, int64_t>(inputsA.n, n_neighbors);
size_t graph_degree = params->nn_descent_params.graph_degree;
#pragma omp parallel for
for (size_t i = 0; i < static_cast<size_t>(inputsA.n); i++) {
for (int j = 0; j < n_neighbors; j++) {
auto target = temp_indices_h.data_handle();
auto source = graph.graph().data_handle();
target[i * n_neighbors + j] = source[i * graph_degree + j];
}
} else { // use an approximate nearest neighbors algorithm
auto all_neighbors_params_umap = params->build_params;

auto all_neighbors_params = cuvs::neighbors::all_neighbors::index_params{};
all_neighbors_params.n_nearest_clusters = all_neighbors_params_umap.n_nearest_clusters;
all_neighbors_params.n_clusters = all_neighbors_params_umap.n_clusters;
all_neighbors_params.metric = params->metric;

if (params->build_algo == ML::UMAPParams::graph_build_algo::NN_DESCENT) {
RAFT_EXPECTS(all_neighbors_params_umap.nn_descent_params.graph_degree >= n_neighbors,
"NN Descent graph_degree should be larger than or equal to n_neighbors");
auto nn_descent_params =
cuvs::neighbors::all_neighbors::graph_build_params::nn_descent_params{};
nn_descent_params.graph_degree = all_neighbors_params_umap.nn_descent_params.graph_degree;
nn_descent_params.intermediate_graph_degree = nn_descent_params.graph_degree * 1.5;
nn_descent_params.max_iterations = all_neighbors_params_umap.nn_descent_params.max_iterations;
nn_descent_params.metric = params->metric;
all_neighbors_params.graph_build_params = nn_descent_params;
} else { // IVFPQ
auto ivfpq_params = cuvs::neighbors::all_neighbors::graph_build_params::ivf_pq_params{};
// heuristically good ivfpq n_lists parameter
ivfpq_params.build_params.n_lists =
std::max(5u,
static_cast<uint32_t>(inputsA.n * all_neighbors_params.n_nearest_clusters /
(5000 * all_neighbors_params.n_clusters)));
ivfpq_params.build_params.metric = params->metric;
all_neighbors_params.graph_build_params = ivfpq_params;
}
raft::copy(handle,
raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors),
temp_indices_h.view());

// `graph.distances()` is a device array (n x graph_degree).
// Slice and copy to the output device array `out.knn_dists` (n x n_neighbors).
// TODO: force graph_degree = n_neighbors so this slice isn't necessary.
raft::matrix::slice_coordinates coords{static_cast<int64_t>(0),
static_cast<int64_t>(0),
static_cast<int64_t>(inputsA.n),
static_cast<int64_t>(n_neighbors)};
raft::matrix::slice<float, int64_t, raft::row_major>(

auto tmp_indices_h = raft::make_host_matrix<int64_t, int64_t>(inputsA.n, n_neighbors);
cuvs::neighbors::all_neighbors::index<int64_t, float> idx{
handle,
raft::make_const_mdspan(graph.distances().value()),
raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors),
coords);
tmp_indices_h.view(),
raft::make_device_matrix_view<float, int64_t>(out.knn_dists, inputsA.n, n_neighbors)};

if (data_on_device) { // inputsA on device
cuvs::neighbors::all_neighbors::build(
handle,
raft::make_device_matrix_view<const float, int64_t>(inputsA.X, inputsA.n, inputsA.d),
all_neighbors_params,
idx);
} else { // inputsA on host
cuvs::neighbors::all_neighbors::build(
handle,
raft::make_host_matrix_view<const float, int64_t>(inputsA.X, inputsA.n, inputsA.d),
all_neighbors_params,
idx);
}

raft::copy(out.knn_indices,
idx.graph().data_handle(),
inputsA.n * n_neighbors,
raft::resource::get_cuda_stream(handle));
}
}

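For reference, the "heuristically good" n_lists choice in the IVFPQ branch above works out to roughly one IVF list per ~5000 device-resident rows; a minimal Python sketch of that arithmetic (the constants 5 and 5000 come from the diff, the per-batch interpretation is an assumption):

# Sketch of the n_lists heuristic used for the IVF-PQ all-neighbors build above.
def heuristic_n_lists(n_rows, n_nearest_clusters=2, n_clusters=4):
    # Roughly this many rows are resident on the device at once when batching.
    rows_per_batch = n_rows * n_nearest_clusters / n_clusters
    return max(5, int(rows_per_batch / 5000))

# e.g. 1M rows, n_clusters=10, n_nearest_clusters=2 -> ~200k rows per batch -> 40 lists
print(heuristic_n_lists(1_000_000, n_nearest_clusters=2, n_clusters=10))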
4 changes: 2 additions & 2 deletions cpp/src/umap/umap.cuh
@@ -227,7 +227,7 @@ inline void _transform(const raft::handle_t& handle,
float* transformed)
{
RAFT_EXPECTS(params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN,
"build algo nn_descent not supported for transform()");
"build algo nn_descent and ivfpq not supported for transform()");
manifold_dense_inputs_t<float> inputs(X, nullptr, n, d);
manifold_dense_inputs_t<float> orig_inputs(orig_X, nullptr, orig_n, d);
UMAPAlgo::_transform<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, nnz_t, TPB_X>(
@@ -253,7 +253,7 @@ inline void _transform_sparse(const raft::handle_t& handle,
float* transformed)
{
RAFT_EXPECTS(params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN,
"build algo nn_descent not supported for transform()");
"build algo nn_descent and ivfpq not supported for transform()");
manifold_sparse_inputs_t<knn_indices_sparse_t, float> inputs(
indptr, indices, data, nullptr, nnz, n, d);
manifold_sparse_inputs_t<knn_indices_sparse_t, float> orig_x_inputs(
63 changes: 38 additions & 25 deletions python/cuml/cuml/manifold/umap.pyx
@@ -288,13 +288,17 @@ class UMAP(UniversalBase,
:ref:`output-data-type-configuration` for more info.
build_algo: string (default='auto')
How to build the knn graph. Supported build algorithms are ['auto', 'brute_force_knn',
'nn_descent']. 'auto' chooses to run with brute force knn if number of data rows is
'nn_descent', 'ivfpq']. 'auto' chooses to run with brute force knn if number of data rows is
smaller than or equal to 50K. Otherwise, runs with nn descent.
build_kwds: dict (optional, default=None)
Build algorithm argument {'nnd_graph_degree': 64, 'nnd_intermediate_graph_degree': 128,
'nnd_max_iterations': 20, 'nnd_termination_threshold': 0.0001, 'nnd_return_distances': True,
'nnd_n_clusters': 1}
Note that nnd_n_clusters > 1 will result in batch-building with NN Descent.
Build algorithm arguments. Default values are: {'n_clusters': 1, 'n_nearest_clusters': 2, 'nn_descent': {'graph_degree': n_neighbors, 'max_iterations': 20}}.
"n_clusters": int (default=1). Number of clusters to split the data into when building the knn graph. Increasing this uses less device memory at the cost of accuracy. When using n_clusters > 1, the data should be on host; the default (n_clusters=1) places the entire dataset in device memory.
"n_nearest_clusters": int (default=2). Number of nearest clusters each data point is assigned to. Only valid when n_clusters > 1.
"nn_descent": dict (default={"graph_degree": n_neighbors, "max_iterations": 20}). Arguments used when build_algo="nn_descent". graph_degree should be larger than or equal to n_neighbors; increasing graph_degree and max_iterations may improve accuracy.
[Hint1]: the ratio n_nearest_clusters / n_clusters determines device memory usage. Approximately (n_nearest_clusters / n_clusters) * num_rows_in_entire_data rows are placed in device memory at once.
E.g. between (n_nearest_clusters / n_clusters) = 2/10 and 2/20, the latter uses less device memory.
[Hint2]: a larger n_nearest_clusters gives better accuracy for the final all-neighbors knn graph.
E.g. at similar device memory usage, (n_nearest_clusters / n_clusters) = 4/20 gives better accuracy than 2/10 at the cost of performance.
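A minimal usage sketch of the parameters described above (illustrative, not part of the docstring; the data is a host NumPy array since n_clusters > 1):

import numpy as np
from cuml import UMAP

X = np.random.random((100_000, 64)).astype(np.float32)  # placeholder host data

# Batched NN Descent build: the data is split into 10 clusters, each point is
# assigned to its 2 nearest clusters, so roughly 2/10 of the rows sit on the
# GPU at once.
reducer = UMAP(
    n_neighbors=16,
    build_algo="nn_descent",
    build_kwds={
        "n_clusters": 10,
        "n_nearest_clusters": 2,
        "nn_descent": {"graph_degree": 32, "max_iterations": 20},
    },
)
embedding = reducer.fit_transform(X)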

Notes
-----
@@ -433,7 +437,7 @@ class UMAP(UniversalBase,

self.precomputed_knn = extract_knn_infos(precomputed_knn, n_neighbors)

if build_algo == "auto" or build_algo == "brute_force_knn" or build_algo == "nn_descent":
if build_algo in ["auto", "brute_force_knn", "nn_descent", "ivfpq"]:
if self.deterministic and build_algo == "auto":
# TODO: for now, users should be able to see the same results as previous version
# (i.e. running brute force knn) when they explicitly pass random_state
Expand All @@ -444,10 +448,14 @@ class UMAP(UniversalBase,
else:
self.build_algo = build_algo
else:
raise Exception("Invalid build algo: {}. Only support auto, brute_force_knn and nn_descent" % build_algo)
raise Exception("Invalid build algo: {}. Only support auto, brute_force_knn, nn_descent and ivfpq" % build_algo)

self.build_kwds = build_kwds

# for deprecation notice
if self.build_kwds is not None and any(key.startswith("nnd_") for key in self.build_kwds):
raise Exception("build_kwds no longer supports nnd_* arguments. Please refer to docs for detailed configurations.")

def validate_hyperparams(self):

if self.min_dist > self.spread:
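Since nnd_* keys now raise (see the deprecation check above), older code has to move to the nested build_kwds layout; a hedged before/after sketch, with the legacy key names taken from the removed docstring:

from cuml import UMAP

# Old style (pre this PR) -- now rejected:
#   UMAP(build_algo="nn_descent",
#        build_kwds={"nnd_graph_degree": 64, "nnd_max_iterations": 20, "nnd_n_clusters": 4})

# Roughly equivalent configuration with the new nested build_kwds:
reducer = UMAP(
    build_algo="nn_descent",
    build_kwds={
        "n_clusters": 4,
        "n_nearest_clusters": 2,
        "nn_descent": {"graph_degree": 64, "max_iterations": 20},
    },
)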
@@ -494,20 +502,25 @@
else:
umap_params.p = <float>self.metric_kwds.get('p')

if self.build_algo == "brute_force_knn":
umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN
else:
umap_params.build_algo = graph_build_algo.NN_DESCENT
build_kwds = self.build_kwds or {}
umap_params.nn_descent_params.graph_degree = <uint64_t> build_kwds.get("nnd_graph_degree", 64)
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> build_kwds.get("nnd_intermediate_graph_degree", 128)
umap_params.nn_descent_params.max_iterations = <uint64_t> build_kwds.get("nnd_max_iterations", 20)
umap_params.nn_descent_params.termination_threshold = <float> build_kwds.get("nnd_termination_threshold", 0.0001)
umap_params.nn_descent_params.return_distances = <bool> build_kwds.get("nnd_return_distances", True)
umap_params.nn_descent_params.n_clusters = <uint64_t> build_kwds.get("nnd_n_clusters", 1)
# Forward metric & metric_kwds to nn_descent
umap_params.nn_descent_params.metric = <RaftDistanceType> umap_params.metric
umap_params.nn_descent_params.metric_arg = umap_params.p
if self.build_algo == "brute_force_knn":
umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN
else:
build_kwds = self.build_kwds or {}
umap_params.build_params.n_clusters = <uint64_t> build_kwds.get("n_clusters", 1)
umap_params.build_params.n_nearest_clusters = <uint64_t> build_kwds.get("n_nearest_clusters", 2)
if umap_params.build_params.n_clusters > 1 and umap_params.build_params.n_nearest_clusters >= umap_params.build_params.n_clusters:
raise Exception("If n_clusters > 1, then n_nearest_clusters should be strictly smaller than n_clusters.")
if self.build_algo == "nn_descent":
umap_params.build_algo = graph_build_algo.NN_DESCENT
nnd_build_kwds = build_kwds.get("nn_descent", {})
umap_params.build_params.nn_descent_params.graph_degree = <uint64_t> nnd_build_kwds.get("graph_degree", self.n_neighbors)
umap_params.build_params.nn_descent_params.max_iterations = <uint64_t> nnd_build_kwds.get("max_iterations", 20)
if umap_params.build_params.nn_descent_params.graph_degree < self.n_neighbors:
logger.warn("to use nn descent as the build algo, graph_degree should be larger than or equal to n_neigbors. setting graph_degree to n_neighbors.")
umap_params.build_params.nn_descent_params.graph_degree = self.n_neighbors

else: # ivfpq
umap_params.build_algo = graph_build_algo.IVFPQ

cdef uintptr_t callback_ptr = 0
if self.callback:
@@ -562,8 +575,8 @@ class UMAP(UniversalBase,
self.n_rows, self.n_dims = self._raw_data.shape
self.sparse_fit = True
self._sparse_data = True
if self.build_algo == "nn_descent":
raise ValueError("NN Descent does not support sparse inputs")
if self.build_algo == "nn_descent" or self.build_algo == "ivfpq":
raise ValueError("Building knn graph with NN Descent or IVFPQ does not support sparse inputs")

# Handle dense inputs
else:
@@ -795,8 +808,8 @@ class UMAP(UniversalBase,

cdef uintptr_t _embed_ptr = self.embedding_.ptr

# NN Descent doesn't support transform yet
if self.build_algo == "nn_descent" or self.build_algo == "auto":
# transform is only supported using brute force
if self.build_algo != "brute_force_knn":
self.build_algo = "brute_force_knn"
logger.info("Transform can only be run with brute force. Using brute force.")

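Because transform() only runs with brute force (see the guard above and the matching RAFT_EXPECTS in umap.cuh), an estimator fit with an approximate build algo switches internally before transforming; a small sketch of the expected behaviour (illustrative, with placeholder data):

import numpy as np
from cuml import UMAP

X_train = np.random.random((2000, 32)).astype(np.float32)
X_new = np.random.random((100, 32)).astype(np.float32)

reducer = UMAP(build_algo="nn_descent").fit(X_train)  # approximate knn graph for fit()
# transform() is only supported with brute force: the estimator logs an info
# message and falls back to brute_force_knn internally before embedding X_new.
new_embedding = reducer.transform(X_new)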