From 1a03b193152bc0d4f0c3303b2e8cacaa1d936628 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Fri, 23 Aug 2024 20:34:47 +0000 Subject: [PATCH] putting data on host using mst optimize --- cpp/include/cuml/cluster/hdbscan.hpp | 25 +- cpp/src/hdbscan/detail/mst_opt.cuh | 1245 ++++++++++++++++++ cpp/src/hdbscan/detail/reachability.cuh | 490 ++++++- cpp/src/hdbscan/hdbscan.cu | 6 +- cpp/src/hdbscan/runner.h | 6 +- python/cuml/cuml/cluster/hdbscan/hdbscan.pyx | 111 +- python/cuml/cuml/tests/test_hdbscan.py | 57 +- 7 files changed, 1839 insertions(+), 101 deletions(-) create mode 100644 cpp/src/hdbscan/detail/mst_opt.cuh diff --git a/cpp/include/cuml/cluster/hdbscan.hpp b/cpp/include/cuml/cluster/hdbscan.hpp index eb1223fd88..0b98aeca86 100644 --- a/cpp/include/cuml/cluster/hdbscan.hpp +++ b/cpp/include/cuml/cluster/hdbscan.hpp @@ -18,6 +18,7 @@ #include #include +#include #include @@ -27,6 +28,8 @@ namespace ML { namespace HDBSCAN { namespace Common { +using nn_index_params = raft::neighbors::experimental::nn_descent::index_params; + /** * The Condensed hierarchicy is represented by an edge list with * parents as the source vertices, children as the destination, @@ -134,6 +137,7 @@ class CondensedHierarchy { }; enum CLUSTER_SELECTION_METHOD { EOM = 0, LEAF = 1 }; +enum GRAPH_BUILD_ALGO { BRUTE_FORCE_KNN = 0, NN_DESCENT = 1 }; class RobustSingleLinkageParams { public: @@ -151,6 +155,8 @@ class RobustSingleLinkageParams { class HDBSCANParams : public RobustSingleLinkageParams { public: CLUSTER_SELECTION_METHOD cluster_selection_method = CLUSTER_SELECTION_METHOD::EOM; + GRAPH_BUILD_ALGO build_algo = GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN; + nn_index_params nn_descent_params = {}; }; /** @@ -495,14 +501,19 @@ namespace HDBSCAN::HELPER { * @param n number of columns in X * @param metric distance metric to use * @param min_samples minimum number of samples to use for computing core distances + * @param build_algo build algo for building the knn graph (default: brute_force_knn) + * @param build_params build parameters for build_algo */ -void compute_core_dists(const raft::handle_t& handle, - const float* X, - float* core_dists, - size_t m, - size_t n, - raft::distance::DistanceType metric, - int min_samples); +void compute_core_dists( + const raft::handle_t& handle, + const float* X, + float* core_dists, + size_t m, + size_t n, + raft::distance::DistanceType metric, + int min_samples, + HDBSCAN::Common::GRAPH_BUILD_ALGO build_algo = HDBSCAN::Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN, + HDBSCAN::Common::nn_index_params build_params = Common::nn_index_params{}); /** * @brief Compute the map from final, normalize labels to the labels in the CondensedHierarchy diff --git a/cpp/src/hdbscan/detail/mst_opt.cuh b/cpp/src/hdbscan/detail/mst_opt.cuh new file mode 100644 index 0000000000..b1bdcd0603 --- /dev/null +++ b/cpp/src/hdbscan/detail/mst_opt.cuh @@ -0,0 +1,1245 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#include + +namespace NNDescent = raft::neighbors::experimental::nn_descent; + +namespace ML { +namespace HDBSCAN { +namespace detail { +namespace Reachability { +// unnamed namespace to avoid multiple definition error +namespace { +inline double cur_time(void) +{ + struct timeval tv; + gettimeofday(&tv, NULL); + return ((double)tv.tv_sec + (double)tv.tv_usec * 1e-6); +} + +template +__device__ inline void swap(T& val1, T& val2) +{ + T val0 = val1; + val1 = val2; + val2 = val0; +} + +template +__device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2, bool ascending) +{ + if (key1 == key2) { return false; } + if ((key1 > key2) == ascending) { + swap(key1, key2); + swap(val1, val2); + return true; + } + return false; +} + +template +RAFT_KERNEL kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] + const uint32_t graph_size, + const uint32_t graph_degree, + const uint32_t degree, + const uint32_t batch_size, + const uint32_t batch_id, + uint8_t* const detour_count, // [graph_chunk_size, graph_degree] + uint32_t* const num_no_detour_edges, // [graph_size] + uint64_t* const stats) +{ + __shared__ uint32_t smem_num_detour[MAX_DEGREE]; + uint64_t* const num_retain = stats; + uint64_t* const num_full = stats + 1; + + const uint64_t nid = blockIdx.x + (batch_size * batch_id); + if (nid >= graph_size) { return; } + for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { + smem_num_detour[k] = 0; + } + __syncthreads(); + + const uint64_t iA = nid; + if (iA >= graph_size) { return; } + + // count number of detours (A->D->B) + for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) { + const uint64_t iD = knn_graph[kAD + (graph_degree * iA)]; + for (uint32_t kDB = threadIdx.x; kDB < graph_degree; kDB += blockDim.x) { + const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)graph_degree * iD)]; + for (uint32_t kAB = kAD + 1; kAB < graph_degree; kAB++) { + // if ( kDB < kAB ) + { + const uint64_t iB = knn_graph[kAB + (graph_degree * iA)]; + if (iB == iB_candidate) { + atomicAdd(smem_num_detour + kAB, 1); + break; + } + } + } + } + __syncthreads(); + } + + uint32_t num_edges_no_detour = 0; + for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { + detour_count[k + (graph_degree * iA)] = min(smem_num_detour[k], (uint32_t)255); + if (smem_num_detour[k] == 0) { num_edges_no_detour++; } + } + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 2); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 4); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 8); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 16); + num_edges_no_detour = min(num_edges_no_detour, degree); + + if (threadIdx.x == 0) { + num_no_detour_edges[iA] = num_edges_no_detour; + atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour); + if (num_edges_no_detour >= degree) { atomicAdd((unsigned long long int*)num_full, 1); } + } +} + +template +RAFT_KERNEL kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_size] + IdxT* const rev_graph, // [size, degree] + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + 
const uint32_t degree) +{ + const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); + const uint32_t tnum = blockDim.x * gridDim.x; + + for (uint32_t src_id = tid; src_id < graph_size; src_id += tnum) { + const IdxT dest_id = dest_nodes[src_id]; + if (dest_id >= graph_size) continue; + + const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); + if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = src_id; } + } +} + +template +__device__ __host__ LabelT get_root_label(IdxT i, const LabelT* label) +{ + LabelT l = label[i]; + while (l != label[l]) { + l = label[l]; + } + return l; +} + +template +RAFT_KERNEL kern_mst_opt_update_graph(IdxT* mst_graph, // [graph_size, graph_degree] + const IdxT* candidate_edges, // [graph_size] + IdxT* outgoing_num_edges, // [graph_size] + IdxT* incoming_num_edges, // [graph_size] + const IdxT* outgoing_max_edges, // [graph_size] + const IdxT* incoming_max_edges, // [graph_size] + const IdxT* label, // [graph_size] + const uint32_t graph_size, + const uint32_t graph_degree, + uint64_t* stats) +{ + const uint64_t i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i >= graph_size) return; + + int ret = 0; // 0: No edge, 1: Direct edge, 2: Alternate edge, 3: Failure + + if (outgoing_num_edges[i] >= outgoing_max_edges[i]) return; + uint64_t j = candidate_edges[i]; + if (j >= graph_size) return; + const uint32_t ri = get_root_label(i, label); + const uint32_t rj = get_root_label(j, label); + if (ri == rj) return; + + // Try to add a direct edge to destination node with different label. + if (incoming_num_edges[j] < incoming_max_edges[j]) { + ret = 1; + // Check to avoid duplication + for (uint64_t kj = 0; kj < graph_degree; kj++) { + uint64_t l = mst_graph[(graph_degree * j) + kj]; + if (l >= graph_size) continue; + const uint32_t rl = get_root_label(l, label); + if (ri == rl) { + ret = 0; + break; + } + } + if (ret == 0) return; + + ret = 0; + auto kj = atomicAdd(incoming_num_edges + j, (IdxT)1); + if (kj < incoming_max_edges[j]) { + auto ki = outgoing_num_edges[i]++; + mst_graph[(graph_degree * (i)) + ki] = j; // outgoing + mst_graph[(graph_degree * (j + 1)) - 1 - kj] = i; // incoming + ret = 1; + } + } + if (ret > 0) { + atomicAdd((unsigned long long int*)stats + ret, 1); + return; + } + + // Try to add an edge to an alternate node instead + ret = 3; + for (uint64_t kj = 0; kj < graph_degree; kj++) { + uint64_t l = mst_graph[(graph_degree * (j + 1)) - 1 - kj]; + if (l >= graph_size) continue; + uint32_t rl = get_root_label(l, label); + if (ri == rl) { + ret = 0; + break; + } + if (incoming_num_edges[l] >= incoming_max_edges[l]) continue; + + // Check to avoid duplication + for (uint64_t kl = 0; kl < graph_degree; kl++) { + uint64_t m = mst_graph[(graph_degree * l) + kl]; + if (m > graph_size) continue; + uint32_t rm = get_root_label(m, label); + if (ri == rm) { + ret = 0; + break; + } + } + if (ret == 0) { break; } + + auto kl = atomicAdd(incoming_num_edges + l, (IdxT)1); + if (kl < incoming_max_edges[l]) { + auto ki = outgoing_num_edges[i]++; + mst_graph[(graph_degree * (i)) + ki] = l; // outgoing + mst_graph[(graph_degree * (l + 1)) - 1 - kl] = i; // incoming + ret = 2; + break; + } + } + if (ret > 0) { atomicAdd((unsigned long long int*)stats + ret, 1); } +} + +template +RAFT_KERNEL kern_mst_opt_labeling(IdxT* label, // [graph_size] + const IdxT* mst_graph, // [graph_size, graph_degree] + const uint32_t graph_size, + const uint32_t graph_degree, + uint64_t* stats) +{ + const uint64_t i = threadIdx.x + (blockDim.x * blockIdx.x); + if 
(i >= graph_size) return; + + __shared__ uint32_t smem_updated[1]; + if (threadIdx.x == 0) { smem_updated[0] = 0; } + __syncthreads(); + + for (uint64_t ki = 0; ki < graph_degree; ki++) { + uint64_t j = mst_graph[(graph_degree * i) + ki]; + if (j >= graph_size) continue; + + IdxT li = label[i]; + IdxT ri = get_root_label(i, label); + if (ri < li) { atomicMin(label + i, ri); } + IdxT lj = label[j]; + IdxT rj = get_root_label(j, label); + if (rj < lj) { atomicMin(label + j, rj); } + if (ri == rj) continue; + + if (ri > rj) { + atomicCAS(label + i, ri, rj); + } else if (rj > ri) { + atomicCAS(label + j, rj, ri); + } + smem_updated[0] = 1; + } + + __syncthreads(); + if ((threadIdx.x == 0) && (smem_updated[0] > 0)) { stats[0] = 1; } +} + +template +RAFT_KERNEL kern_mst_opt_cluster_size(IdxT* cluster_size, // [graph_size] + const IdxT* label, // [graph_size] + const uint32_t graph_size, + uint64_t* stats) +{ + const uint64_t i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i >= graph_size) return; + + __shared__ uint64_t smem_num_clusters[1]; + if (threadIdx.x == 0) { smem_num_clusters[0] = 0; } + __syncthreads(); + + IdxT ri = get_root_label(i, label); + if (ri == i) { + atomicAdd((unsigned long long int*)smem_num_clusters, 1); + } else { + atomicAdd(cluster_size + ri, cluster_size[i]); + cluster_size[i] = 0; + } + + __syncthreads(); + if ((threadIdx.x == 0) && (smem_num_clusters[0] > 0)) { + atomicAdd((unsigned long long int*)stats, (unsigned long long int)(smem_num_clusters[0])); + } +} + +template +RAFT_KERNEL kern_mst_opt_postprocessing(IdxT* outgoing_num_edges, // [graph_size] + IdxT* incoming_num_edges, // [graph_size] + IdxT* outgoing_max_edges, // [graph_size] + IdxT* incoming_max_edges, // [graph_size] + const IdxT* cluster_size, // [graph_size] + const uint32_t graph_size, + const uint32_t graph_degree, + uint64_t* stats) +{ + const uint64_t i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i >= graph_size) return; + + __shared__ uint64_t smem_cluster_size_min[1]; + __shared__ uint64_t smem_cluster_size_max[1]; + __shared__ uint64_t smem_total_outgoing_edges[1]; + __shared__ uint64_t smem_total_incoming_edges[1]; + if (threadIdx.x == 0) { + smem_cluster_size_min[0] = stats[0]; + smem_cluster_size_max[0] = stats[1]; + smem_total_outgoing_edges[0] = 0; + smem_total_incoming_edges[0] = 0; + } + __syncthreads(); + + // Adjust incoming_num_edges + if (incoming_num_edges[i] > incoming_max_edges[i]) { + incoming_num_edges[i] = incoming_max_edges[i]; + } + + // Calculate min/max of cluster_size + if (cluster_size[i] > 0) { + if (smem_cluster_size_min[0] > cluster_size[i]) { + atomicMin((unsigned long long int*)smem_cluster_size_min, + (unsigned long long int)(cluster_size[i])); + } + if (smem_cluster_size_max[0] < cluster_size[i]) { + atomicMax((unsigned long long int*)smem_cluster_size_max, + (unsigned long long int)(cluster_size[i])); + } + } + + // Calculate total number of outgoing/incoming edges + atomicAdd((unsigned long long int*)smem_total_outgoing_edges, + (unsigned long long int)(outgoing_num_edges[i])); + atomicAdd((unsigned long long int*)smem_total_incoming_edges, + (unsigned long long int)(incoming_num_edges[i])); + + // Adjust incoming/outgoing_max_edges + if (outgoing_num_edges[i] == outgoing_max_edges[i]) { + if (outgoing_num_edges[i] + incoming_num_edges[i] < graph_degree) { + outgoing_max_edges[i] += 1; + incoming_max_edges[i] -= 1; + } + } + + __syncthreads(); + if (threadIdx.x == 0) { + atomicMin((unsigned long long int*)stats + 0, + (unsigned long long 
int)(smem_cluster_size_min[0])); + atomicMax((unsigned long long int*)stats + 1, + (unsigned long long int)(smem_cluster_size_max[0])); + atomicAdd((unsigned long long int*)stats + 2, + (unsigned long long int)(smem_total_outgoing_edges[0])); + atomicAdd((unsigned long long int*)stats + 3, + (unsigned long long int)(smem_total_incoming_edges[0])); + } +} + +template +uint64_t pos_in_array(T val, const T* array, uint64_t num) +{ + for (uint64_t i = 0; i < num; i++) { + if (val == array[i]) { return i; } + } + return num; +} + +template +void shift_array(T* array, uint64_t num) +{ + for (uint64_t i = num; i > 0; i--) { + array[i] = array[i - 1]; + } +} +} // namespace + +template +void mst_opt_update_graph(IdxT* mst_graph_ptr, + IdxT* candidate_edges_ptr, + IdxT* outgoing_num_edges_ptr, + IdxT* incoming_num_edges_ptr, + IdxT* outgoing_max_edges_ptr, + IdxT* incoming_max_edges_ptr, + IdxT* label_ptr, + IdxT graph_size, + uint32_t mst_graph_degree, + uint64_t k, + int& num_direct, + int& num_alternate, + int& num_failure) +{ +#pragma omp parallel for reduction(+ : num_direct, num_alternate, num_failure) + for (uint64_t ii = 0; ii < graph_size; ii++) { + uint64_t i = ii; + if (k % 2 == 0) { i = graph_size - (ii + 1); } + int ret = 0; // 0: No edge, 1: Direct edge, 2: Alternate edge, 3: Failure + + if (outgoing_num_edges_ptr[i] >= outgoing_max_edges_ptr[i]) continue; + uint64_t j = candidate_edges_ptr[i]; + if (j >= graph_size) continue; + if (label_ptr[i] == label_ptr[j]) continue; + + // Try to add a direct edge to destination node with different label. + if (incoming_num_edges_ptr[j] < incoming_max_edges_ptr[j]) { + ret = 1; + // Check to avoid duplication + for (uint64_t kj = 0; kj < mst_graph_degree; kj++) { + uint64_t l = mst_graph_ptr[(mst_graph_degree * j) + kj]; + if (l >= graph_size) continue; + if (label_ptr[i] == label_ptr[l]) { + ret = 0; + break; + } + } + if (ret == 0) continue; + + // Use atomic to avoid conflicts, since 'incoming_num_edges_ptr[j]' + // can be updated by other threads. + ret = 0; + uint32_t kj; +#pragma omp atomic capture + kj = incoming_num_edges_ptr[j]++; + if (kj < incoming_max_edges_ptr[j]) { + auto ki = outgoing_num_edges_ptr[i]++; + mst_graph_ptr[(mst_graph_degree * (i)) + ki] = j; // OUT + mst_graph_ptr[(mst_graph_degree * (j + 1)) - 1 - kj] = i; // IN + ret = 1; + } + } + if (ret == 1) { + num_direct += 1; + continue; + } + + // Try to add an edge to an alternate node instead + ret = 3; + for (uint64_t kj = 0; kj < mst_graph_degree; kj++) { + uint64_t l = mst_graph_ptr[(mst_graph_degree * (j + 1)) - 1 - kj]; + if (l >= graph_size) continue; + if (label_ptr[i] == label_ptr[l]) { + ret = 0; + break; + } + if (incoming_num_edges_ptr[l] >= incoming_max_edges_ptr[l]) continue; + + // Check to avoid duplication + for (uint64_t kl = 0; kl < mst_graph_degree; kl++) { + uint64_t m = mst_graph_ptr[(mst_graph_degree * l) + kl]; + if (m > graph_size) continue; + if (label_ptr[i] == label_ptr[m]) { + ret = 0; + break; + } + } + if (ret == 0) { break; } + + // Use atomic to avoid conflicts, since 'incoming_num_edges_ptr[l]' + // can be updated by other threads. 
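+        // Note: '#pragma omp atomic capture' below reads the old counter value into 'kl'
+        // and increments 'incoming_num_edges_ptr[l]' in a single atomic step, mirroring
+        // the atomicAdd used in the GPU kernel kern_mst_opt_update_graph.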
+ uint32_t kl; +#pragma omp atomic capture + kl = incoming_num_edges_ptr[l]++; + if (kl < incoming_max_edges_ptr[l]) { + auto ki = outgoing_num_edges_ptr[i]++; + mst_graph_ptr[(mst_graph_degree * (i)) + ki] = l; // OUT + mst_graph_ptr[(mst_graph_degree * (l + 1)) - 1 - kl] = i; // IN + ret = 2; + break; + } + } + if (ret == 2) { + num_alternate += 1; + } else if (ret == 3) { + num_failure += 1; + } + } +} + +// +// Create approximate MSTs with kNN graphs as input to guarantee connectivity of search graphs +// +// * Since there is an upper limit to the degree of a graph for search, what is created is a +// degree-constraied MST. +// * The number of edges is not a minimum because strict MST is not required. Therefore, it is +// an approximate MST. +// * If the input kNN graph is disconnected, random connection is added to the largest cluster. +// +template +void mst_optimization(raft::resources const& res, + raft::host_matrix_view input_graph, + raft::host_matrix_view output_graph, + raft::host_vector_view mst_graph_num_edges, + bool use_gpu = true) +{ + const double time_mst_opt_start = cur_time(); + + const IdxT graph_size = input_graph.extent(0); + const uint32_t input_graph_degree = input_graph.extent(1); + const uint32_t output_graph_degree = output_graph.extent(1); + auto input_graph_ptr = input_graph.data_handle(); + auto output_graph_ptr = output_graph.data_handle(); + auto mst_graph_num_edges_ptr = mst_graph_num_edges.data_handle(); + + // Allocate temporal arrays + const uint32_t mst_graph_degree = output_graph_degree; + auto mst_graph = raft::make_host_matrix(graph_size, mst_graph_degree); + auto outgoing_max_edges = raft::make_host_vector(graph_size); + auto incoming_max_edges = raft::make_host_vector(graph_size); + auto outgoing_num_edges = raft::make_host_vector(graph_size); + auto incoming_num_edges = raft::make_host_vector(graph_size); + auto label = raft::make_host_vector(graph_size); + auto cluster_size = raft::make_host_vector(graph_size); + auto candidate_edges = raft::make_host_vector(graph_size); + auto mst_graph_ptr = mst_graph.data_handle(); + auto outgoing_max_edges_ptr = outgoing_max_edges.data_handle(); + auto incoming_max_edges_ptr = incoming_max_edges.data_handle(); + auto outgoing_num_edges_ptr = outgoing_num_edges.data_handle(); + auto incoming_num_edges_ptr = incoming_num_edges.data_handle(); + auto label_ptr = label.data_handle(); + auto cluster_size_ptr = cluster_size.data_handle(); + auto candidate_edges_ptr = candidate_edges.data_handle(); + + // Initialize arrays +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + for (uint64_t k = 0; k < mst_graph_degree; k++) { + // mst_graph_ptr[(mst_graph_degree * i) + k] = graph_size; + mst_graph(i, k) = graph_size; + } + outgoing_max_edges_ptr[i] = 2; + incoming_max_edges_ptr[i] = mst_graph_degree - outgoing_max_edges_ptr[i]; + outgoing_num_edges_ptr[i] = 0; + incoming_num_edges_ptr[i] = 0; + label_ptr[i] = i; + cluster_size_ptr[i] = 1; + } + + // Allocate arrays on GPU + uint32_t d_graph_size = graph_size; + if (!use_gpu) { + // (*) If GPU is not used, arrays of size 0 are created. 
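+    // With extent 0, the device mdarrays created below hold no device memory, so the
+    // CPU (OpenMP) path can proceed without reserving space on the GPU.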
+ d_graph_size = 0; + } + auto d_mst_graph_num_edges = raft::make_device_vector(res, d_graph_size); + auto d_mst_graph = raft::make_device_matrix(res, d_graph_size, mst_graph_degree); + auto d_outgoing_max_edges = raft::make_device_vector(res, d_graph_size); + auto d_incoming_max_edges = raft::make_device_vector(res, d_graph_size); + auto d_outgoing_num_edges = raft::make_device_vector(res, d_graph_size); + auto d_incoming_num_edges = raft::make_device_vector(res, d_graph_size); + auto d_label = raft::make_device_vector(res, d_graph_size); + auto d_cluster_size = raft::make_device_vector(res, d_graph_size); + auto d_candidate_edges = raft::make_device_vector(res, d_graph_size); + auto d_mst_graph_num_edges_ptr = d_mst_graph_num_edges.data_handle(); + auto d_mst_graph_ptr = d_mst_graph.data_handle(); + auto d_outgoing_max_edges_ptr = d_outgoing_max_edges.data_handle(); + auto d_incoming_max_edges_ptr = d_incoming_max_edges.data_handle(); + auto d_outgoing_num_edges_ptr = d_outgoing_num_edges.data_handle(); + auto d_incoming_num_edges_ptr = d_incoming_num_edges.data_handle(); + auto d_label_ptr = d_label.data_handle(); + auto d_cluster_size_ptr = d_cluster_size.data_handle(); + auto d_candidate_edges_ptr = d_candidate_edges.data_handle(); + + constexpr int stats_size = 4; + auto stats = raft::make_host_vector(stats_size); + auto d_stats = raft::make_device_vector(res, stats_size); + auto stats_ptr = stats.data_handle(); + auto d_stats_ptr = d_stats.data_handle(); + + if (use_gpu) { + raft::copy(d_mst_graph_ptr, + mst_graph_ptr, + (size_t)graph_size * mst_graph_degree, + raft::resource::get_cuda_stream(res)); + raft::copy(d_outgoing_num_edges_ptr, + outgoing_num_edges_ptr, + (size_t)graph_size, + raft::resource::get_cuda_stream(res)); + raft::copy(d_incoming_num_edges_ptr, + incoming_num_edges_ptr, + (size_t)graph_size, + raft::resource::get_cuda_stream(res)); + raft::copy(d_outgoing_max_edges_ptr, + outgoing_max_edges_ptr, + (size_t)graph_size, + raft::resource::get_cuda_stream(res)); + raft::copy(d_incoming_max_edges_ptr, + incoming_max_edges_ptr, + (size_t)graph_size, + raft::resource::get_cuda_stream(res)); + raft::copy(d_label_ptr, label_ptr, (size_t)graph_size, raft::resource::get_cuda_stream(res)); + raft::copy(d_cluster_size_ptr, + cluster_size_ptr, + (size_t)graph_size, + raft::resource::get_cuda_stream(res)); + } + + IdxT num_clusters = 0; + IdxT num_clusters_pre = graph_size; + IdxT cluster_size_min = graph_size; + IdxT cluster_size_max = 0; + for (uint64_t k = 0; k <= input_graph_degree; k++) { + int num_direct = 0; + int num_alternate = 0; + int num_failure = 0; + + // 1. Prepare candidate edges + if (k == input_graph_degree) { + // If the number of clusters does not converge to 1, then edges are + // made from all nodes not belonging to the main cluster to any node + // in the main cluster. 
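+      // The target node for each such edge is chosen further below by stepping through
+      // node ids with a fixed stride (j = (j + 97) % graph_size) until a node labeled
+      // with the main cluster is found, which tends to spread the forced connections
+      // over different main-cluster nodes rather than a single hub.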
+ raft::copy(cluster_size_ptr, + d_cluster_size_ptr, + (size_t)graph_size, + raft::resource::get_cuda_stream(res)); + raft::copy(label_ptr, d_label_ptr, (size_t)graph_size, raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + uint32_t main_cluster_label = graph_size; +#pragma omp parallel for reduction(min : main_cluster_label) + for (uint64_t i = 0; i < graph_size; i++) { + if ((cluster_size_ptr[i] == cluster_size_max) && (main_cluster_label > i)) { + main_cluster_label = i; + } + } +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + candidate_edges_ptr[i] = graph_size; + if (label_ptr[i] == main_cluster_label) continue; + uint64_t j = i; + while (label_ptr[j] != main_cluster_label) { + constexpr uint32_t ofst = 97; + j = (j + ofst) % graph_size; + } + candidate_edges_ptr[i] = j; + } + } else { + // Copy rank-k edges from the input knn graph to 'candidate_edges' +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + candidate_edges_ptr[i] = input_graph_ptr[k + (input_graph_degree * i)]; + } + } + // 2. Update MST graph + // * Try to add candidate edges to MST graph + if (use_gpu) { + raft::copy(d_candidate_edges_ptr, + candidate_edges_ptr, + graph_size, + raft::resource::get_cuda_stream(res)); + stats_ptr[0] = 0; + stats_ptr[1] = num_direct; + stats_ptr[2] = num_alternate; + stats_ptr[3] = num_failure; + raft::copy(d_stats_ptr, stats_ptr, 4, raft::resource::get_cuda_stream(res)); + constexpr uint64_t n_threads = 256; + const dim3 threads(n_threads, 1, 1); + const dim3 blocks(raft::ceildiv(graph_size, n_threads), 1, 1); + kern_mst_opt_update_graph<<>>( + d_mst_graph_ptr, + d_candidate_edges_ptr, + d_outgoing_num_edges_ptr, + d_incoming_num_edges_ptr, + d_outgoing_max_edges_ptr, + d_incoming_max_edges_ptr, + d_label_ptr, + graph_size, + mst_graph_degree, + d_stats_ptr); + + raft::copy(stats_ptr, d_stats_ptr, 4, raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + num_direct = stats_ptr[1]; + num_alternate = stats_ptr[2]; + num_failure = stats_ptr[3]; + } else { + mst_opt_update_graph(mst_graph_ptr, + candidate_edges_ptr, + outgoing_num_edges_ptr, + incoming_num_edges_ptr, + outgoing_max_edges_ptr, + incoming_max_edges_ptr, + label_ptr, + graph_size, + mst_graph_degree, + k, + num_direct, + num_alternate, + num_failure); + } + // 3. Labeling + uint32_t flag_update = 1; + while (flag_update) { + flag_update = 0; + if (use_gpu) { + stats_ptr[0] = flag_update; + raft::copy(d_stats_ptr, stats_ptr, 1, raft::resource::get_cuda_stream(res)); + + constexpr uint64_t n_threads = 256; + const dim3 threads(n_threads, 1, 1); + const dim3 blocks((graph_size + n_threads - 1) / n_threads, 1, 1); + kern_mst_opt_labeling<<>>( + d_label_ptr, d_mst_graph_ptr, graph_size, mst_graph_degree, d_stats_ptr); + + raft::copy(stats_ptr, d_stats_ptr, 1, raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + flag_update = stats_ptr[0]; + } else { +#pragma omp parallel for reduction(+ : flag_update) + for (uint64_t i = 0; i < graph_size; i++) { + for (uint64_t ki = 0; ki < mst_graph_degree; ki++) { + uint64_t j = mst_graph_ptr[(mst_graph_degree * i) + ki]; + if (j >= graph_size) continue; + if (label_ptr[i] > label_ptr[j]) { + flag_update += 1; + label_ptr[i] = label_ptr[j]; + } + } + } + } + } + // 4. 
Calculate the number of clusters and the size of each cluster + num_clusters = 0; + if (use_gpu) { + stats_ptr[0] = num_clusters; + raft::copy(d_stats_ptr, stats_ptr, 1, raft::resource::get_cuda_stream(res)); + + constexpr uint64_t n_threads = 256; + const dim3 threads(n_threads, 1, 1); + const dim3 blocks(raft::ceildiv(graph_size, n_threads), 1, 1); + kern_mst_opt_cluster_size<<>>( + d_cluster_size_ptr, d_label_ptr, graph_size, d_stats_ptr); + + raft::copy(stats_ptr, d_stats_ptr, 1, raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + num_clusters = stats_ptr[0]; + } else { +#pragma omp parallel for reduction(+ : num_clusters) + for (uint64_t i = 0; i < graph_size; i++) { + uint64_t ri = get_root_label(i, label_ptr); + if (ri == i) { + num_clusters += 1; + } else { +#pragma omp atomic update + cluster_size_ptr[ri] += cluster_size_ptr[i]; + cluster_size_ptr[i] = 0; + } + } + } + // 5. Postprocessings + // * Adjust incoming_num_edges + // * Calculate the min/max size of clusters. + // * Calculate the total number of outgoing/incoming edges + // * Increase the limit of outgoing edges as needed + cluster_size_min = graph_size; + cluster_size_max = 0; + uint64_t total_outgoing_edges = 0; + uint64_t total_incoming_edges = 0; + if (use_gpu) { + stats_ptr[0] = cluster_size_min; + stats_ptr[1] = cluster_size_max; + stats_ptr[2] = total_outgoing_edges; + stats_ptr[3] = total_incoming_edges; + raft::copy(d_stats_ptr, stats_ptr, 4, raft::resource::get_cuda_stream(res)); + + constexpr uint64_t n_threads = 256; + const dim3 threads(n_threads, 1, 1); + const dim3 blocks((graph_size + n_threads - 1) / n_threads, 1, 1); + kern_mst_opt_postprocessing<<>>( + d_outgoing_num_edges_ptr, + d_incoming_num_edges_ptr, + d_outgoing_max_edges_ptr, + d_incoming_max_edges_ptr, + d_cluster_size_ptr, + graph_size, + mst_graph_degree, + d_stats_ptr); + + raft::copy(stats_ptr, d_stats_ptr, 4, raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + cluster_size_min = stats_ptr[0]; + cluster_size_max = stats_ptr[1]; + total_outgoing_edges = stats_ptr[2]; + total_incoming_edges = stats_ptr[3]; + } else { +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + if (incoming_num_edges_ptr[i] > incoming_max_edges_ptr[i]) { + incoming_num_edges_ptr[i] = incoming_max_edges_ptr[i]; + } + } + +#pragma omp parallel for reduction(max : cluster_size_max) reduction(min : cluster_size_min) + for (uint64_t i = 0; i < graph_size; i++) { + if (cluster_size_ptr[i] == 0) continue; + cluster_size_min = min(cluster_size_min, cluster_size_ptr[i]); + cluster_size_max = max(cluster_size_max, cluster_size_ptr[i]); + } + +#pragma omp parallel for reduction(+ : total_outgoing_edges, total_incoming_edges) + for (uint64_t i = 0; i < graph_size; i++) { + total_outgoing_edges += outgoing_num_edges_ptr[i]; + total_incoming_edges += incoming_num_edges_ptr[i]; + } + +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + if (outgoing_num_edges_ptr[i] < outgoing_max_edges_ptr[i]) continue; + if (outgoing_num_edges_ptr[i] + incoming_num_edges_ptr[i] == mst_graph_degree) continue; + assert(outgoing_num_edges_ptr[i] + incoming_num_edges_ptr[i] < mst_graph_degree); + outgoing_max_edges_ptr[i] += 1; + incoming_max_edges_ptr[i] = mst_graph_degree - outgoing_max_edges_ptr[i]; + } + } + // 6. 
Show stats + if (num_clusters != num_clusters_pre) { + std::string msg = "# k: " + std::to_string(k); + msg += ", num_clusters: " + std::to_string(num_clusters); + msg += ", cluster_size: " + std::to_string(cluster_size_min) + " to " + + std::to_string(cluster_size_max); + msg += ", total_num_edges: " + std::to_string(total_outgoing_edges) + ", " + + std::to_string(total_incoming_edges); + if (num_alternate + num_failure > 0) { + msg += ", altenate: " + std::to_string(num_alternate); + if (num_failure > 0) { msg += ", failure: " + std::to_string(num_failure); } + } + } + assert(num_clusters > 0); + assert(total_outgoing_edges == total_incoming_edges); + if (num_clusters == 1) { break; } + num_clusters_pre = num_clusters; + } + // The edges that make up the MST are stored as edges in the output graph. + if (use_gpu) { + raft::copy(mst_graph_ptr, + d_mst_graph_ptr, + (size_t)graph_size * mst_graph_degree, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + } +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + uint64_t k = 0; + for (uint64_t kj = 0; kj < mst_graph_degree; kj++) { + uint64_t j = mst_graph_ptr[(mst_graph_degree * i) + kj]; + if (j >= graph_size) continue; + + // Check to avoid duplication + auto flag_match = false; + for (uint64_t ki = 0; ki < k; ki++) { + if (j == output_graph_ptr[(output_graph_degree * i) + ki]) { + flag_match = true; + break; + } + } + if (flag_match) continue; + + output_graph_ptr[(output_graph_degree * i) + k] = j; + k += 1; + } + mst_graph_num_edges_ptr[i] = k; + } + + const double time_mst_opt_end = cur_time(); +} + +template < + typename IdxT = uint32_t, + typename g_accessor = + raft::host_device_accessor, raft::memory_type::host>> +void optimize(raft::resources const& res, + raft::host_matrix_view knn_graph, + raft::host_matrix_view new_graph, + const bool guarantee_connectivity = true) +{ + auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); + + RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0), + "Each input array is expected to have the same number of rows"); + RAFT_EXPECTS(new_graph.extent(1) <= knn_graph.extent(1), + "output graph cannot have more columns than input graph"); + const uint32_t input_graph_degree = knn_graph.extent(1); + const uint32_t output_graph_degree = new_graph.extent(1); + auto input_graph_ptr = knn_graph.data_handle(); + auto output_graph_ptr = new_graph.data_handle(); + const IdxT graph_size = new_graph.extent(0); + + // MST optimization + auto mst_graph_num_edges = raft::make_host_vector(graph_size); + auto mst_graph_num_edges_ptr = mst_graph_num_edges.data_handle(); +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + mst_graph_num_edges_ptr[i] = 0; + } + if (guarantee_connectivity) { + constexpr bool use_gpu = true; + mst_optimization(res, knn_graph, new_graph, mst_graph_num_edges.view(), use_gpu); + } + + auto pruned_graph = raft::make_host_matrix(graph_size, output_graph_degree); + { + // + // Prune kNN graph + // + auto d_detour_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, input_graph_degree)); + + RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), + 0xff, + graph_size * input_graph_degree * sizeof(uint8_t), + raft::resource::get_cuda_stream(res))); + + auto d_num_no_detour_edges = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size)); + RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), + 0x00, + graph_size * sizeof(uint32_t), + 
raft::resource::get_cuda_stream(res))); + + auto dev_stats = raft::make_device_vector(res, 2); + auto host_stats = raft::make_host_vector(2); + + // + // Prune unimportant edges. + // + // The edge to be retained is determined without explicitly considering + // distance or angle. Suppose the edge is the k-th edge of some node-A to + // node-B (A->B). Among the edges originating at node-A, there are k-1 edges + // shorter than the edge A->B. Each of these k-1 edges are connected to a + // different k-1 nodes. Among these k-1 nodes, count the number of nodes with + // edges to node-B, which is the number of 2-hop detours for the edge A->B. + // Once the number of 2-hop detours has been counted for all edges, the + // specified number of edges are picked up for each node, starting with the + // edge with the lowest number of 2-hop detours. + // + const double time_prune_start = cur_time(); + + // Copy input_graph_ptr over to device if necessary + auto d_input_graph = + raft::make_device_matrix(res, graph_size, input_graph_degree); + raft::copy(d_input_graph.data_handle(), + input_graph_ptr, + graph_size * input_graph_degree, + raft::resource::get_cuda_stream(res)); + // device_matrix_view_from_host d_input_graph( + // res, + // raft::make_host_matrix_view(input_graph_ptr, graph_size, + // input_graph_degree)); + + constexpr int MAX_DEGREE = 1024; + if (input_graph_degree > MAX_DEGREE) { + RAFT_FAIL( + "The degree of input knn graph is too large (%u). " + "It must be equal to or smaller than %d.", + input_graph_degree, + 1024); + } + const uint32_t batch_size = + std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + const dim3 threads_prune(32, 1, 1); + const dim3 blocks_prune(batch_size, 1, 1); + + RAFT_CUDA_TRY(cudaMemsetAsync( + dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); + + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + kern_prune + <<>>( + d_input_graph.data_handle(), + graph_size, + input_graph_degree, + output_graph_degree, + batch_size, + i_batch, + d_detour_count.data_handle(), + d_num_no_detour_edges.data_handle(), + dev_stats.data_handle()); + raft::resource::sync_stream(res); + } + raft::resource::sync_stream(res); + + // host_matrix_view_from_device detour_count(res, d_detour_count.view()); + auto detour_count = raft::make_host_matrix(graph_size, input_graph_degree); + raft::copy(detour_count.data_handle(), + d_detour_count.data_handle(), + graph_size * input_graph_degree, + raft::resource::get_cuda_stream(res)); + + raft::copy( + host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); + const auto num_keep = host_stats.data_handle()[0]; + const auto num_full = host_stats.data_handle()[1]; + + // Create pruned kNN graph +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + // Find the `output_graph_degree` smallest detourable count nodes by checking the detourable + // count of the neighbors while increasing the target detourable count from zero. 
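+      // Illustrative example (hypothetical counts): with detour counts [0, 2, 0, 1] and
+      // output_graph_degree = 3, the pass with num_detour = 0 keeps neighbors 0 and 2,
+      // num_detour then advances to the next-smallest observed count (1), and neighbor 3
+      // is kept, yielding the three neighbors with the fewest 2-hop detours.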
+ uint64_t pk = 0; + uint32_t num_detour = 0; + while (pk < output_graph_degree) { + uint32_t next_num_detour = std::numeric_limits::max(); + for (uint64_t k = 0; k < input_graph_degree; k++) { + const auto num_detour_k = detour_count.data_handle()[k + (input_graph_degree * i)]; + // Find the detourable count to check in the next iteration + if (num_detour_k > num_detour) { + next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); + } + + // Store the neighbor index if its detourable count is equal to `num_detour`. + if (num_detour_k != num_detour) { continue; } + output_graph_ptr[pk + (output_graph_degree * i)] = + input_graph_ptr[k + (input_graph_degree * i)]; + pk += 1; + if (pk >= output_graph_degree) break; + } + if (pk >= output_graph_degree) break; + + assert(next_num_detour != std::numeric_limits::max()); + num_detour = next_num_detour; + } + RAFT_EXPECTS(pk == output_graph_degree, + "Couldn't find the output_graph_degree (%u) smallest detourable count nodes for " + "node %lu in the rank-based node reranking process", + output_graph_degree, + static_cast(i)); + } + + const double time_prune_end = cur_time(); + } + + auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); + auto rev_graph_count = raft::make_host_vector(graph_size); + + { + // + // Make reverse graph + // + const double time_make_start = cur_time(); + auto d_rev_graph = + raft::make_device_matrix(res, graph_size, output_graph_degree); + raft::copy(d_rev_graph.data_handle(), + rev_graph.data_handle(), + graph_size * output_graph_degree, + raft::resource::get_cuda_stream(res)); + // device_matrix_view_from_host d_rev_graph(res, rev_graph.view()); + + RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), + 0xff, + graph_size * output_graph_degree * sizeof(IdxT), + raft::resource::get_cuda_stream(res))); + + auto d_rev_graph_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size)); + RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(), + 0x00, + graph_size * sizeof(uint32_t), + raft::resource::get_cuda_stream(res))); + + auto dest_nodes = raft::make_host_vector(graph_size); + auto d_dest_nodes = + raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); + + for (uint64_t k = 0; k < output_graph_degree; k++) { +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + // dest_nodes.data_handle()[i] = output_graph_ptr[k + (output_graph_degree * i)]; + dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; + } + raft::resource::sync_stream(res); + + raft::copy(d_dest_nodes.data_handle(), + dest_nodes.data_handle(), + graph_size, + raft::resource::get_cuda_stream(res)); + + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + kern_make_rev_graph<<>>( + d_dest_nodes.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree); + } + + raft::resource::sync_stream(res); + + raft::copy(rev_graph.data_handle(), + d_rev_graph.data_handle(), + graph_size * output_graph_degree, + raft::resource::get_cuda_stream(res)); + raft::copy(rev_graph_count.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + raft::resource::get_cuda_stream(res)); + + const double time_make_end = cur_time(); + } + + { + // + // Create search graphs from MST and pruned and reverse graphs + // + const double time_replace_start = cur_time(); + +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + auto my_fwd_graph = pruned_graph.data_handle() + 
(output_graph_degree * i); + auto my_rev_graph = rev_graph.data_handle() + (output_graph_degree * i); + auto my_out_graph = output_graph_ptr + (output_graph_degree * i); + uint32_t kf = 0; + uint32_t k = mst_graph_num_edges_ptr[i]; + + const uint64_t num_protected_edges = max(k, output_graph_degree / 2); + assert(num_protected_edges <= output_graph_degree); + if (num_protected_edges == output_graph_degree) continue; + + // Append edges from the pruned graph to output graph + while (k < output_graph_degree && kf < output_graph_degree) { + if (my_fwd_graph[kf] < graph_size) { + auto flag_match = false; + for (uint32_t kk = 0; kk < k; kk++) { + if (my_out_graph[kk] == my_fwd_graph[kf]) { + flag_match = true; + break; + } + } + if (!flag_match) { + my_out_graph[k] = my_fwd_graph[kf]; + k += 1; + } + } + kf += 1; + } + assert(k == output_graph_degree); + assert(kf <= output_graph_degree); + + // Replace some edges of the output graph with edges of the reverse graph. + uint32_t kr = std::min(rev_graph_count.data_handle()[i], output_graph_degree); + while (kr) { + kr -= 1; + if (my_rev_graph[kr] < graph_size) { + uint64_t pos = pos_in_array(my_rev_graph[kr], my_out_graph, output_graph_degree); + if (pos < num_protected_edges) { continue; } + uint64_t num_shift = pos - num_protected_edges; + if (pos >= output_graph_degree) { + num_shift = output_graph_degree - num_protected_edges - 1; + } + shift_array(my_out_graph + num_protected_edges, num_shift); + my_out_graph[num_protected_edges] = my_rev_graph[kr]; + } + } + } + + const double time_replace_end = cur_time(); + + /* stats */ + uint64_t num_replaced_edges = 0; +#pragma omp parallel for reduction(+ : num_replaced_edges) + for (uint64_t i = 0; i < graph_size; i++) { + for (uint64_t k = 0; k < output_graph_degree; k++) { + const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; + const uint64_t pos = + pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); + if (pos == output_graph_degree) { num_replaced_edges += 1; } + } + } + } + + // Check number of incoming edges + { + auto in_edge_count = raft::make_host_vector(graph_size); + auto in_edge_count_ptr = in_edge_count.data_handle(); +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + in_edge_count_ptr[i] = 0; + } +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + for (uint64_t k = 0; k < output_graph_degree; k++) { + const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; + if (j >= graph_size) continue; +#pragma omp atomic + in_edge_count_ptr[j] += 1; + } + } + auto hist = raft::make_host_vector(output_graph_degree); + auto hist_ptr = hist.data_handle(); + for (uint64_t k = 0; k < output_graph_degree; k++) { + hist_ptr[k] = 0; + } +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + uint32_t count = in_edge_count_ptr[i]; + if (count >= output_graph_degree) continue; +#pragma omp atomic + hist_ptr[count] += 1; + } + uint32_t sum_hist = 0; + for (uint64_t k = 0; k < output_graph_degree; k++) { + sum_hist += hist_ptr[k]; + } + } +} + +}; // end namespace Reachability +}; // end namespace detail +}; // end namespace HDBSCAN +}; // end namespace ML diff --git a/cpp/src/hdbscan/detail/reachability.cuh b/cpp/src/hdbscan/detail/reachability.cuh index 03a7f7c0ad..4f855e145f 100644 --- a/cpp/src/hdbscan/detail/reachability.cuh +++ b/cpp/src/hdbscan/detail/reachability.cuh @@ -16,11 +16,18 @@ #pragma once +#include "mst_opt.cuh" + +#include +#include #include +#include #include #include 
#include +#include +#include #include #include #include @@ -34,6 +41,8 @@ #include #include +namespace NNDescent = raft::neighbors::experimental::nn_descent; + namespace ML { namespace HDBSCAN { namespace detail { @@ -68,6 +77,63 @@ void core_distances( }); } +// Functor to post-process distances by sqrt +// For usage with NN Descent which internally supports L2Expanded only +template +struct DistancePostProcessSqrt : NNDescent::DistEpilogue { + DI value_t operator()(value_t value, value_idx row, value_idx col) const + { + return powf(fabsf(value), 0.5); + } +}; + +template +// out and in can be same (can be done in-place) +CUML_KERNEL void copy_first_k_cols_shift_self( + T* out, T* in, size_t out_k, size_t in_k, size_t nrows) +{ + size_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row < nrows) { + for (size_t i = out_k - 1; i >= 1; i--) { + out[row * out_k + i] = in[row * in_k + i - 1]; + } + out[row * out_k] = row; + } +} + +template +CUML_KERNEL void copy_first_k_cols_shift_zero( + T* out, T* in, size_t out_k, size_t in_k, size_t nrows) +{ + size_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row < nrows) { + for (size_t i = 1; i < out_k; i++) { + out[row * out_k + i] = in[row * in_k + i - 1]; + } + out[row * out_k] = static_cast(0); + } +} + +template +auto get_graph_nnd(const raft::handle_t& handle, + const value_t* X, + size_t m, + size_t n, + epilogue_op distance_epilogue, + Common::nn_index_params build_params) +{ + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, X)); + float* ptr = reinterpret_cast(attr.devicePointer); + if (ptr != nullptr) { + auto dataset = raft::make_device_matrix_view(X, m, n); + return NNDescent::build(handle, build_params, dataset, distance_epilogue); + } else { + auto dataset = raft::make_host_matrix_view(X, m, n); + return NNDescent::build(handle, build_params, dataset, distance_epilogue); + } +} + /** * Wraps the brute force knn API, to be used for both training and prediction * @tparam value_idx data type for integrals @@ -93,33 +159,68 @@ void compute_knn(const raft::handle_t& handle, const value_t* search_items, size_t n_search_items, int k, - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, + Common::GRAPH_BUILD_ALGO build_algo = Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN, + Common::nn_index_params build_params = Common::nn_index_params{}) { auto stream = handle.get_stream(); auto exec_policy = handle.get_thrust_policy(); - std::vector inputs; - inputs.push_back(const_cast(X)); - - std::vector sizes; - sizes.push_back(m); - // This is temporary. Once faiss is updated, we should be able to // pass value_idx through to knn. 
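+  // In the NN_DESCENT branch below, the k-NN graph built by NN Descent does not include
+  // each query point as its own first neighbor (which brute-force kNN does), so the helper
+  // kernels copy_first_k_cols_shift_self / copy_first_k_cols_shift_zero shift the first
+  // k-1 neighbor columns right by one and write the row's own index / a zero distance
+  // into column 0 to match the brute-force output layout.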
rmm::device_uvector int64_indices(k * n_search_items, stream); - // perform knn - brute_force_knn(handle, - inputs, - sizes, - n, - const_cast(search_items), - n_search_items, - int64_indices.data(), - dists, - k, - true, - true, - metric); + if (build_algo == Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN) { + std::vector inputs; + inputs.push_back(const_cast(X)); + + std::vector sizes; + sizes.push_back(m); + + // perform knn + brute_force_knn(handle, + inputs, + sizes, + n, + const_cast(search_items), + n_search_items, + int64_indices.data(), + dists, + k, + true, + true, + metric); + } else { // NN_DESCENT + RAFT_EXPECTS(static_cast(k) <= build_params.graph_degree, + "n_neighbors should be smaller than the graph degree computed by nn descent"); + build_params.return_distances = true; + + auto epilogue = DistancePostProcessSqrt{}; + build_params.return_distances = true; + auto graph = get_graph_nnd(handle, X, m, n, epilogue, build_params); + + size_t TPB = 256; + size_t num_blocks = static_cast((m + TPB) / TPB); + + auto indices_d = + raft::make_device_matrix(handle, m, build_params.graph_degree); + + raft::copy( + indices_d.data_handle(), graph.graph().data_handle(), m * build_params.graph_degree, stream); + + RAFT_EXPECTS(graph.distances().has_value(), + "return_distances for nn descent should be set to true to be used for HDBSCAN"); + copy_first_k_cols_shift_zero + <<>>(dists, + graph.distances().value().data_handle(), + static_cast(k), + build_params.graph_degree, + m); + copy_first_k_cols_shift_self<<>>(int64_indices.data(), + indices_d.data_handle(), + static_cast(k), + build_params.graph_degree, + m); + } // convert from current knn's 64-bit to 32-bit. thrust::transform(exec_policy, @@ -134,13 +235,16 @@ void compute_knn(const raft::handle_t& handle, to compute core_dists */ template -void _compute_core_dists(const raft::handle_t& handle, - const value_t* X, - value_t* core_dists, - size_t m, - size_t n, - raft::distance::DistanceType metric, - int min_samples) +void _compute_core_dists( + const raft::handle_t& handle, + const value_t* X, + value_t* core_dists, + size_t m, + size_t n, + raft::distance::DistanceType metric, + int min_samples, + Common::GRAPH_BUILD_ALGO build_algo = Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN, + Common::nn_index_params build_params = Common::nn_index_params{}) { RAFT_EXPECTS(metric == raft::distance::DistanceType::L2SqrtExpanded, "Currently only L2 expanded distance is supported"); @@ -151,7 +255,18 @@ void _compute_core_dists(const raft::handle_t& handle, rmm::device_uvector dists(min_samples * m, stream); // perform knn - compute_knn(handle, X, inds.data(), dists.data(), m, n, X, m, min_samples, metric); + compute_knn(handle, + X, + inds.data(), + dists.data(), + m, + n, + X, + m, + min_samples, + metric, + build_algo, + build_params); // Slice core distances (distances to kth nearest neighbor) core_distances(dists.data(), min_samples, min_samples, m, core_dists, stream); @@ -169,6 +284,135 @@ struct ReachabilityPostProcess { value_t alpha; }; +// Functor to post-process distances into reachability space (Sqrt) +// For usage with NN Descent which internally supports L2Expanded only +template +struct ReachabilityPostProcessSqrt : NNDescent::DistEpilogue { + ReachabilityPostProcessSqrt(value_t* core_dists_, value_t alpha_) + : NNDescent::DistEpilogue(), + core_dists(core_dists_), + alpha(alpha_), + value_t_max(std::numeric_limits::max()){}; + + __device__ value_t operator()(value_t value, value_idx row, value_idx col) const + { + if (cluster_indices == 
nullptr) { + return max(core_dists[col], max(core_dists[row], powf(fabsf(alpha * value), 0.5))); + } else { + if (row < num_data_in_cluster && col < num_data_in_cluster) { + return max(core_dists[cluster_indices[col]], + max(core_dists[cluster_indices[row]], powf(fabsf(alpha * value), 0.5))); + } else { + return value_t_max; + } + } + } + + __host__ void preprocess_for_batch(value_idx* cluster_indices_, size_t num_data_in_cluster_) + { + cluster_indices = cluster_indices_; + num_data_in_cluster = num_data_in_cluster_; + } + + const value_t* core_dists; + value_t alpha; + value_t value_t_max; + value_idx* cluster_indices = nullptr; + size_t num_data_in_cluster = 0; +}; + +template +struct CustomComparator { + __host__ __device__ bool operator()(const thrust::tuple& lhs, + const thrust::tuple& rhs) const + { + if (thrust::get<0>(lhs) != thrust::get<0>(rhs)) { + return thrust::get<0>(lhs) < thrust::get<0>(rhs); + } + return thrust::get<1>(lhs) < thrust::get<1>(rhs); + } +}; + +template +float calculate_mutual_reach_dist( + const raft::handle_t& handle, const value_t* X, int i, int j, float core_dist, size_t dim) +{ + auto x_i = raft::make_host_vector_view(X + i * dim, dim); + auto x_j = raft::make_host_vector_view(X + j * dim, dim); + + float x_i_norm = 0; + float x_j_norm = 0; + float dot = 0; + for (int d = 0; d < dim; d++) { + x_i_norm += x_i(d) * x_i(d); + x_j_norm += x_j(d) * x_j(d); + dot += x_i(d) * x_j(d); + } + + return std::max((float)(std::sqrt(x_i_norm + x_j_norm - 2.0 * dot)), core_dist); +} + +template +struct KeyValuePair { + KeyType key; + ValueType value; +}; + +template +struct CustomKeyComparator { + __device__ bool operator()(const KeyValuePair& a, + const KeyValuePair& b) const + { + if (a.key == b.key) { return a.value < b.value; } + return a.key < b.key; + } +}; + +template +CUML_KERNEL void sort_by_key(float* out_dists, + value_idx* out_inds, + size_t graph_degree, + size_t nrows) +{ + size_t row = blockIdx.x; + typedef cub::BlockMergeSort, BLOCK_SIZE, ITEMS_PER_THREAD> + BlockMergeSortType; + __shared__ typename cub::BlockMergeSort, + BLOCK_SIZE, + ITEMS_PER_THREAD>::TempStorage tmpSmem; + + if (row < nrows) { + KeyValuePair threadKeyValuePair[ITEMS_PER_THREAD]; + + // load key values + size_t arrIdxBase = row * graph_degree; + size_t idxBase = static_cast(threadIdx.x) * static_cast(ITEMS_PER_THREAD); + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + static_cast(i); + if (colId < graph_degree) { + threadKeyValuePair[i].key = out_dists[arrIdxBase + colId]; + threadKeyValuePair[i].value = out_inds[arrIdxBase + colId]; + } else { + threadKeyValuePair[i].key = std::numeric_limits::max(); + threadKeyValuePair[i].value = std::numeric_limits::max(); + } + } + + __syncthreads(); + + BlockMergeSortType(tmpSmem).Sort(threadKeyValuePair, CustomKeyComparator{}); + + // load back to global mem + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + static_cast(i); + if (colId < graph_degree) { + out_dists[arrIdxBase + colId] = threadKeyValuePair[i].key; + out_inds[arrIdxBase + colId] = threadKeyValuePair[i].value; + } + } + } +} + /** * Given core distances, Fuses computations of L2 distances between all * points, projection into mutual reachability space, and k-selection. 
@@ -184,38 +428,125 @@ struct ReachabilityPostProcess { * @param[in] core_dists array of core distances (size m) */ template -void mutual_reachability_knn_l2(const raft::handle_t& handle, - value_idx* out_inds, - value_t* out_dists, - const value_t* X, - size_t m, - size_t n, - int k, - value_t* core_dists, - value_t alpha) +void mutual_reachability_knn_l2( + const raft::handle_t& handle, + value_idx* out_inds, + value_t* out_dists, + const value_t* X, + size_t m, + size_t n, + int k, + value_t* core_dists, + value_t alpha, + Common::GRAPH_BUILD_ALGO build_algo = Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN, + Common::nn_index_params build_params = Common::nn_index_params{}) { // Create a functor to postprocess distances into mutual reachability space // Note that we can't use a lambda for this here, since we get errors like: // `A type local to a function cannot be used in the template argument of the // enclosing parent function (and any parent classes) of an extended __device__ // or __host__ __device__ lambda` - auto epilogue = ReachabilityPostProcess{core_dists, alpha}; - - auto X_view = raft::make_device_matrix_view(X, m, n); - std::vector> index = {X_view}; - - raft::neighbors::brute_force::knn( - handle, - index, - X_view, - raft::make_device_matrix_view(out_inds, m, static_cast(k)), - raft::make_device_matrix_view(out_dists, m, static_cast(k)), - // TODO: expand distance metrics to support more than just L2 distance - // https://github.com/rapidsai/cuml/issues/5301 - raft::distance::DistanceType::L2SqrtExpanded, - std::make_optional(2.0f), - std::nullopt, - epilogue); + + if (build_algo == Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN) { + auto epilogue = ReachabilityPostProcess{core_dists, alpha}; + auto X_view = raft::make_device_matrix_view(X, m, n); + std::vector> index = {X_view}; + + raft::neighbors::brute_force::knn( + handle, + index, + X_view, + raft::make_device_matrix_view(out_inds, m, static_cast(k)), + raft::make_device_matrix_view(out_dists, m, static_cast(k)), + // TODO: expand distance metrics to support more than just L2 distance + // https://github.com/rapidsai/cuml/issues/5301 + raft::distance::DistanceType::L2SqrtExpanded, + std::make_optional(2.0f), + std::nullopt, + epilogue); + + } else { + RAFT_EXPECTS(static_cast(k) <= build_params.graph_degree, + "n_neighbors should be smaller than the graph degree computed by nn descent"); + auto epilogue = ReachabilityPostProcessSqrt{core_dists, alpha}; + build_params.return_distances = true; + auto graph = get_graph_nnd(handle, X, m, n, epilogue, build_params); + + RAFT_EXPECTS(graph.distances().has_value(), + "return_distances for nn descent should be set to true to be used for HDBSCAN"); + + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, X)); + float* ptr = reinterpret_cast(attr.devicePointer); + + if (ptr != nullptr) { // data on device + auto indices_d = + raft::make_device_matrix(handle, m, build_params.graph_degree); + + raft::copy(indices_d.data_handle(), + graph.graph().data_handle(), + m * build_params.graph_degree, + handle.get_stream()); + + raft::matrix::slice_coordinates coords{static_cast(0), + static_cast(0), + static_cast(m), + static_cast(k)}; + + auto out_knn_dists_view = raft::make_device_matrix_view(out_dists, m, (size_t)k); + raft::matrix::slice( + handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords); + auto out_knn_indices_view = + raft::make_device_matrix_view(out_inds, m, (size_t)k); + raft::matrix::slice( + handle, 
raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords); + } else { + auto new_inds = raft::make_host_matrix(m, k); + auto new_dists = raft::make_host_matrix(m, k); + + auto knn_dists = raft::make_host_matrix(m, build_params.graph_degree); + + auto knn_inds = raft::make_host_matrix_view( + graph.graph().data_handle(), m, build_params.graph_degree); // reuse memory + raft::copy(knn_dists.data_handle(), + graph.distances().value().data_handle(), + m * build_params.graph_degree, + handle.get_stream()); + + optimize(handle, knn_inds, new_inds.view(), true); + + auto core_dists_h = raft::make_host_vector(m); + raft::copy(core_dists_h.data_handle(), core_dists, m, handle.get_stream()); + +#pragma omp parallel for + for (size_t i = 0; i < m; i++) { + for (int j = 0; j < k; j++) { + value_idx curr_idx = new_inds.data_handle()[i * k + j]; + bool found = false; + for (int l = 0; l < build_params.graph_degree; l++) { + if (knn_inds.data_handle()[i * build_params.graph_degree + l] == curr_idx) { + new_dists.data_handle()[i * k + j] = + knn_dists.data_handle()[i * build_params.graph_degree + l]; + found = true; + break; + } + } + if (!found) { + new_dists.data_handle()[i * k + j] = + calculate_mutual_reach_dist(handle, X, i, curr_idx, core_dists_h(i), n); + } + } + } + + raft::copy(out_inds, new_inds.data_handle(), m * k, handle.get_stream()); + raft::copy(out_dists, new_dists.data_handle(), m * k, handle.get_stream()); + + if (k <= 128) { + sort_by_key<<>>(out_dists, out_inds, k, m); + } + handle.sync_stream(); + } + } } /** @@ -260,20 +591,22 @@ void mutual_reachability_knn_l2(const raft::handle_t& handle, * neighbors. */ template -void mutual_reachability_graph(const raft::handle_t& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - int min_samples, - value_t alpha, - value_idx* indptr, - value_t* core_dists, - raft::sparse::COO& out) +void mutual_reachability_graph( + const raft::handle_t& handle, + const value_t* X, + size_t m, + size_t n, + raft::distance::DistanceType metric, + int min_samples, + value_t alpha, + value_idx* indptr, + value_t* core_dists, + raft::sparse::COO& out, + Common::GRAPH_BUILD_ALGO build_algo = Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN, + Common::nn_index_params build_params = Common::nn_index_params{}) { RAFT_EXPECTS(metric == raft::distance::DistanceType::L2SqrtExpanded, "Currently only L2 expanded distance is supported"); - auto stream = handle.get_stream(); auto exec_policy = handle.get_thrust_policy(); @@ -282,16 +615,35 @@ void mutual_reachability_graph(const raft::handle_t& handle, rmm::device_uvector dists(min_samples * m, stream); // perform knn - compute_knn(handle, X, inds.data(), dists.data(), m, n, X, m, min_samples, metric); + compute_knn(handle, + X, + inds.data(), + dists.data(), + m, + n, + X, + m, + min_samples, + metric, + build_algo, + build_params); // Slice core distances (distances to kth nearest neighbor) core_distances(dists.data(), min_samples, min_samples, m, core_dists, stream); - /** * Compute L2 norm */ - mutual_reachability_knn_l2( - handle, inds.data(), dists.data(), X, m, n, min_samples, core_dists, (value_t)1.0 / alpha); + mutual_reachability_knn_l2(handle, + inds.data(), + dists.data(), + X, + m, + n, + min_samples, + core_dists, + (value_t)1.0 / alpha, + build_algo, + build_params); // self-loops get max distance auto coo_rows_counting_itr = thrust::make_counting_iterator(0); diff --git a/cpp/src/hdbscan/hdbscan.cu b/cpp/src/hdbscan/hdbscan.cu index ea64d20f6b..32ef78b470 
100644 --- a/cpp/src/hdbscan/hdbscan.cu +++ b/cpp/src/hdbscan/hdbscan.cu @@ -158,10 +158,12 @@ void compute_core_dists(const raft::handle_t& handle, size_t m, size_t n, raft::distance::DistanceType metric, - int min_samples) + int min_samples, + HDBSCAN::Common::GRAPH_BUILD_ALGO build_algo, + HDBSCAN::Common::nn_index_params build_params) { HDBSCAN::detail::Reachability::_compute_core_dists( - handle, X, core_dists, m, n, metric, min_samples); + handle, X, core_dists, m, n, metric, min_samples, build_algo, build_params); } void compute_inverse_label_map(const raft::handle_t& handle, diff --git a/cpp/src/hdbscan/runner.h b/cpp/src/hdbscan/runner.h index c79148eed2..76481acb90 100644 --- a/cpp/src/hdbscan/runner.h +++ b/cpp/src/hdbscan/runner.h @@ -183,8 +183,9 @@ void build_linkage(const raft::handle_t& handle, params.alpha, mutual_reachability_indptr.data(), core_dists, - mutual_reachability_coo); - + mutual_reachability_coo, + params.build_algo, + params.nn_descent_params); /** * Construct MST sorted by weights */ @@ -289,7 +290,6 @@ void _fit_hdbscan(const raft::handle_t& handle, m, out.get_stabilities(), label_map.data()); - /** * Normalize labels so they are drawn from a monotonically increasing set * starting at 0 even in the presence of noise (-1) diff --git a/python/cuml/cuml/cluster/hdbscan/hdbscan.pyx b/python/cuml/cuml/cluster/hdbscan/hdbscan.pyx index f7691c1684..aec107cc49 100644 --- a/python/cuml/cuml/cluster/hdbscan/hdbscan.pyx +++ b/python/cuml/cuml/cluster/hdbscan/hdbscan.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# Copyright (c) 2021-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ from cuml.internals.safe_imports import gpu_only_import cp = gpu_only_import('cupy') from warnings import warn +from cuml.internals import logger from cuml.internals.array import CumlArray from cuml.internals.base import UniversalBase from cuml.common.doc_utils import generate_docstring @@ -31,6 +32,7 @@ from cuml.common import input_to_cuml_array from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.internals.api_decorators import device_interop_preparation from cuml.internals.api_decorators import enable_device_interop +from cuml.internals.mem_type import MemoryType from cuml.internals.mixins import ClusterMixin from cuml.internals.mixins import CMajorInputTagMixin from cuml.internals.import_utils import has_hdbscan @@ -46,12 +48,25 @@ IF GPUBUILD == 1: from pylibraft.common.handle import Handle from pylibraft.common.handle cimport handle_t + cdef extern from "raft/neighbors/nn_descent_types.hpp" namespace "raft::neighbors::experimental::nn_descent": + cdef struct index_params: + size_t graph_degree, + size_t intermediate_graph_degree, + size_t max_iterations, + float termination_threshold, + bool return_distances, + size_t n_clusters, + cdef extern from "cuml/cluster/hdbscan.hpp" namespace "ML::HDBSCAN::Common": ctypedef enum CLUSTER_SELECTION_METHOD: EOM "ML::HDBSCAN::Common::CLUSTER_SELECTION_METHOD::EOM" LEAF "ML::HDBSCAN::Common::CLUSTER_SELECTION_METHOD::LEAF" + ctypedef enum GRAPH_BUILD_ALGO: + BRUTE_FORCE_KNN "ML::HDBSCAN::Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN" + NN_DESCENT "ML::HDBSCAN::Common::GRAPH_BUILD_ALGO::NN_DESCENT" + cdef cppclass CondensedHierarchy[value_idx, value_t]: CondensedHierarchy( const handle_t &handle, size_t n_leaves) @@ -98,6 +113,8 @@ IF GPUBUILD == 1: bool allow_single_cluster, 
CLUSTER_SELECTION_METHOD cluster_selection_method, + GRAPH_BUILD_ALGO build_algo, + index_params nn_descent_params, cdef cppclass PredictionData[int, float]: PredictionData(const handle_t &handle, @@ -151,7 +168,9 @@ IF GPUBUILD == 1: size_t m, size_t n, DistanceType metric, - int min_samples) + int min_samples, + GRAPH_BUILD_ALGO build_algo, + index_params build_params) void compute_inverse_label_map(const handle_t& handle, CondensedHierarchy[int, float]& @@ -501,7 +520,9 @@ class HDBSCAN(UniversalBase, ClusterMixin, CMajorInputTagMixin): verbose=False, connectivity='knn', output_type=None, - prediction_data=False): + prediction_data=False, + build_algo='auto', + build_kwds=None): super().__init__(handle=handle, verbose=verbose, @@ -532,6 +553,9 @@ class HDBSCAN(UniversalBase, ClusterMixin, CMajorInputTagMixin): self.fit_called_ = False self.prediction_data = prediction_data + self.build_algo = build_algo + self.build_kwds = build_kwds + self.n_clusters_ = None self.n_leaves_ = None @@ -547,6 +571,8 @@ class HDBSCAN(UniversalBase, ClusterMixin, CMajorInputTagMixin): self.prediction_data_ptr = None self._cpu_to_gpu_interop_prepped = False + logger.set_level(verbose) + @property def condensed_tree_(self): @@ -753,17 +779,22 @@ class HDBSCAN(UniversalBase, ClusterMixin, CMajorInputTagMixin): @generate_docstring() @enable_device_interop - def fit(self, X, y=None, convert_dtype=True) -> "HDBSCAN": + def fit(self, X, y=None, convert_dtype=True, data_on_host=False) -> "HDBSCAN": """ Fit HDBSCAN model from features. """ + if data_on_host: + convert_to_mem_type = MemoryType.host + else: + convert_to_mem_type = MemoryType.device X_m, n_rows, n_cols, self.dtype = \ input_to_cuml_array(X, order='C', check_dtype=[np.float32], convert_to_dtype=(np.float32 if convert_dtype - else None)) + else None), + convert_to_mem_type=convert_to_mem_type) self.X_m = X_m self.n_rows = n_rows @@ -831,6 +862,37 @@ class HDBSCAN(UniversalBase, ClusterMixin, CMajorInputTagMixin): raise ValueError("Cluster selection method not supported. " "Must one of {'eom', 'leaf'}") + if self.build_algo == "auto": + if self.n_rows <= 50000: + # brute force is faster for small datasets + logger.warn("Building knn graph using brute force") + self.build_algo = "brute_force_knn" + else: + logger.warn("Building knn graph using nn descent") + self.build_algo = "nn_descent" + + if self.build_algo == 'brute_force_knn': + params.build_algo = GRAPH_BUILD_ALGO.BRUTE_FORCE_KNN + elif self.build_algo == 'nn_descent': + params.build_algo = GRAPH_BUILD_ALGO.NN_DESCENT + if self.build_kwds is None: + params.nn_descent_params.graph_degree = 64 + params.nn_descent_params.intermediate_graph_degree = 128 + params.nn_descent_params.max_iterations = 20 + params.nn_descent_params.termination_threshold = 0.0001 + params.nn_descent_params.return_distances = True + params.nn_descent_params.n_clusters = 1 + else: + params.nn_descent_params.graph_degree = self.build_kwds.get("nnd_graph_degree", 64) + params.nn_descent_params.intermediate_graph_degree = self.build_kwds.get("nnd_intermediate_graph_degree", 128) + params.nn_descent_params.max_iterations = self.build_kwds.get("nnd_max_iterations", 20) + params.nn_descent_params.termination_threshold = self.build_kwds.get("nnd_termination_threshold", 0.0001) + params.nn_descent_params.return_distances = self.build_kwds.get("nnd_return_distances", True) + params.nn_descent_params.n_clusters = self.build_kwds.get("nnd_n_clusters", 1) + else: + raise ValueError("Build algo not supported. 
" + "Must one of {'brute_force_knn', 'nn_descent'}") + cdef DistanceType metric if self.metric in _metrics_mapping: metric = _metrics_mapping[self.metric] @@ -1071,13 +1133,46 @@ class HDBSCAN(UniversalBase, ClusterMixin, CMajorInputTagMixin): cdef uintptr_t X_ptr = self.X_m.ptr cdef uintptr_t core_dists_ptr = self.core_dists.ptr + cdef GRAPH_BUILD_ALGO build_algo + cdef index_params build_params + + if self.build_algo == "auto": + if self.n_rows <= 50000: + # brute force is faster for small datasets + logger.warn("Building knn graph using brute force") + self.build_algo = "brute_force_knn" + else: + logger.warn("Building knn graph using nn descent") + self.build_algo = "nn_descent" + + if self.build_algo == 'brute_force_knn': + build_algo = GRAPH_BUILD_ALGO.BRUTE_FORCE_KNN + elif self.build_algo == 'nn_descent': + build_algo = GRAPH_BUILD_ALGO.NN_DESCENT + if self.build_kwds is None: + build_params.graph_degree = 64 + build_params.intermediate_graph_degree = 128 + build_params.max_iterations = 20 + build_params.termination_threshold = 0.0001 + build_params.return_distances = True + build_params.n_clusters = 1 + else: + build_params.graph_degree = self.build_kwds.get("nnd_graph_degree", 64) + build_params.intermediate_graph_degree = self.build_kwds.get("nnd_intermediate_graph_degree", 128) + build_params.max_iterations = self.build_kwds.get("nnd_max_iterations", 20) + build_params.termination_threshold = self.build_kwds.get("nnd_termination_threshold", 0.0001) + build_params.return_distances = self.build_kwds.get("nnd_return_distances", True) + build_params.n_clusters = self.build_kwds.get("nnd_n_clusters", 1) + compute_core_dists(handle_[0], X_ptr, core_dists_ptr, self.n_rows, self.n_cols, metric, - self.min_samples) + self.min_samples, + build_algo, + build_params) cdef device_uvector[int] *inverse_label_map = \ new device_uvector[int](0, handle_[0].get_stream()) @@ -1125,7 +1220,9 @@ class HDBSCAN(UniversalBase, ClusterMixin, CMajorInputTagMixin): "connectivity", "alpha", "gen_min_span_tree", - "prediction_data" + "prediction_data", + "build_algo", + "build_kwds" ] def get_attr_names(self): diff --git a/python/cuml/cuml/tests/test_hdbscan.py b/python/cuml/cuml/tests/test_hdbscan.py index 0a9a3a6382..a061f71d56 100644 --- a/python/cuml/cuml/tests/test_hdbscan.py +++ b/python/cuml/cuml/tests/test_hdbscan.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# Copyright (c) 2021-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -44,6 +44,12 @@ dataset_names = ["noisy_circles", "noisy_moons", "varied"] +def get_graph_degree(n_samples): + graph_degree = max(int((1 + ((n_samples * 1.5) // 32)) * 32), 64) + intermediate_graph_degree = int(1 + ((graph_degree * 1.3) // 32) * 32) + return graph_degree, intermediate_graph_degree + + def assert_cluster_counts(sk_agg, cuml_agg, digits=25): sk_unique, sk_counts = np.unique(sk_agg.labels_, return_counts=True) sk_counts = np.sort(sk_counts) @@ -142,14 +148,19 @@ def assert_membership_vectors(cu_vecs, sk_vecs): cu_labels_sorted = np.argsort(cu_vecs)[::-1] sk_labels_sorted = np.argsort(sk_vecs)[::-1] - k = min(sk_vecs.shape[1], 10) - for i in range(k): + if len(sk_vecs.shape) == 1: assert ( - adjusted_rand_score( - cu_labels_sorted[:, i], sk_labels_sorted[:, i] - ) - >= 0.90 + adjusted_rand_score(cu_labels_sorted, sk_labels_sorted) >= 0.9 ) + else: + k = min(sk_vecs.shape[1], 10) + for i in range(k): + assert ( + adjusted_rand_score( + cu_labels_sorted[:, i], sk_labels_sorted[:, i] + ) + >= 0.9 + ) @pytest.mark.parametrize("nrows", [500]) @@ -308,6 +319,7 @@ def test_hdbscan_sklearn_extract_clusters( allow_single_cluster, ): X = test_datasets.data + cuml_agg = HDBSCAN( verbose=logger.level_info, allow_single_cluster=allow_single_cluster, @@ -349,6 +361,7 @@ def test_hdbscan_sklearn_extract_clusters( @pytest.mark.parametrize("max_cluster_size", [0]) @pytest.mark.parametrize("cluster_selection_method", ["eom"]) @pytest.mark.parametrize("connectivity", ["knn"]) +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) def test_hdbscan_cluster_patterns( dataset, nrows, @@ -359,11 +372,11 @@ def test_hdbscan_cluster_patterns( allow_single_cluster, max_cluster_size, min_samples, + build_algo, ): # This also tests duplicate data points X, y = get_pattern(dataset, nrows)[0] - cuml_agg = HDBSCAN( verbose=logger.level_info, allow_single_cluster=allow_single_cluster, @@ -372,6 +385,7 @@ def test_hdbscan_cluster_patterns( min_cluster_size=min_cluster_size, cluster_selection_epsilon=cluster_selection_epsilon, cluster_selection_method=cluster_selection_method, + build_algo=build_algo, ) cuml_agg.fit(X) @@ -412,6 +426,7 @@ def test_hdbscan_cluster_patterns( @pytest.mark.parametrize("max_cluster_size", [0]) @pytest.mark.parametrize("cluster_selection_method", ["eom", "leaf"]) @pytest.mark.parametrize("connectivity", ["knn"]) +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) def test_hdbscan_cluster_patterns_extract_clusters( dataset, nrows, @@ -422,11 +437,12 @@ def test_hdbscan_cluster_patterns_extract_clusters( allow_single_cluster, max_cluster_size, min_samples, + build_algo, ): # This also tests duplicate data points X, y = get_pattern(dataset, nrows)[0] - + graph_degree, intermediate_graph_degree = get_graph_degree(min_samples) cuml_agg = HDBSCAN( verbose=logger.level_info, allow_single_cluster=allow_single_cluster, @@ -435,6 +451,11 @@ def test_hdbscan_cluster_patterns_extract_clusters( min_cluster_size=min_cluster_size, cluster_selection_epsilon=cluster_selection_epsilon, cluster_selection_method=cluster_selection_method, + build_algo=build_algo, + build_kwds={ + "nnd_graph_degree": graph_degree, + "nnd_intermediate_graph_degree": intermediate_graph_degree, + }, ) sk_agg = hdbscan.HDBSCAN( @@ -494,7 +515,8 @@ def test_hdbscan_metric_parameter_input(metric, supported): clf.fit(X) -def test_hdbscan_empty_cluster_tree(): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_hdbscan_empty_cluster_tree(build_algo): 
raw_tree = np.recarray( shape=(5,), @@ -510,7 +532,9 @@ def test_hdbscan_empty_cluster_tree(): condensed_tree = CondensedTree(raw_tree, 0.0, True) cuml_agg = HDBSCAN( - allow_single_cluster=True, cluster_selection_method="eom" + allow_single_cluster=True, + cluster_selection_method="eom", + build_algo=build_algo, ) cuml_agg._extract_clusters(condensed_tree) @@ -570,7 +594,6 @@ def test_all_points_membership_vectors_blobs( shuffle=True, random_state=42, ) - cuml_agg = HDBSCAN( verbose=logger.level_info, allow_single_cluster=allow_single_cluster, @@ -613,6 +636,7 @@ def test_all_points_membership_vectors_blobs( @pytest.mark.parametrize("cluster_selection_method", ["eom", "leaf"]) @pytest.mark.parametrize("connectivity", ["knn"]) @pytest.mark.parametrize("batch_size", [128, 1000]) +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) def test_all_points_membership_vectors_moons( nrows, min_samples, @@ -623,6 +647,7 @@ def test_all_points_membership_vectors_moons( max_cluster_size, connectivity, batch_size, + build_algo, ): X, y = datasets.make_moons(n_samples=nrows, noise=0.05, random_state=42) @@ -636,6 +661,7 @@ def test_all_points_membership_vectors_moons( cluster_selection_epsilon=cluster_selection_epsilon, cluster_selection_method=cluster_selection_method, prediction_data=True, + build_algo=build_algo, ) cuml_agg.fit(X) @@ -934,6 +960,7 @@ def test_approximate_predict_circles( @pytest.mark.parametrize("max_cluster_size", [0]) @pytest.mark.parametrize("cluster_selection_method", ["eom"]) @pytest.mark.parametrize("connectivity", ["knn"]) +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) def test_approximate_predict_digits( n_points_to_predict, min_samples, @@ -943,6 +970,7 @@ def test_approximate_predict_digits( max_cluster_size, cluster_selection_method, connectivity, + build_algo, ): digits = datasets.load_digits() X, y = digits.data, digits.target @@ -966,6 +994,7 @@ def test_approximate_predict_digits( cluster_selection_epsilon=cluster_selection_epsilon, cluster_selection_method=cluster_selection_method, prediction_data=True, + build_algo=build_algo, ) cuml_agg.fit(X_train) @@ -1077,6 +1106,7 @@ def test_membership_vector_blobs( @pytest.mark.parametrize("cluster_selection_method", ["eom", "leaf"]) @pytest.mark.parametrize("connectivity", ["knn"]) @pytest.mark.parametrize("batch_size", [16]) +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) def test_membership_vector_moons( nrows, n_points_to_predict, @@ -1088,6 +1118,7 @@ def test_membership_vector_moons( max_cluster_size, connectivity, batch_size, + build_algo, ): X, y = datasets.make_moons( @@ -1106,6 +1137,7 @@ def test_membership_vector_moons( cluster_selection_epsilon=cluster_selection_epsilon, cluster_selection_method=cluster_selection_method, prediction_data=True, + build_algo=build_algo, ) cuml_agg.fit(X_train) @@ -1193,5 +1225,4 @@ def test_membership_vector_circles( sk_membership_vectors = hdbscan.membership_vector(sk_agg, X_test).astype( "float32" ) - assert_membership_vectors(cu_membership_vectors, sk_membership_vectors)
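
[Reviewer note] As a reading aid for the host-memory branch added to mutual_reachability_knn_l2 earlier in this patch: when an edge kept by optimize() is not present in the original NN Descent graph, its distance is recomputed via calculate_mutual_reach_dist. The NumPy sketch below shows the mutual-reachability rule that computation follows; it is a simplification (it ignores the alpha scaling applied by the reachability epilogue) and is not the patch's helper.

    # Simplified NumPy sketch of the mutual-reachability fill-in; illustrative only.
    import numpy as np

    def mutual_reach_dist(X, core_dists, i, j):
        # Plain Euclidean distance between points i and j.
        d = float(np.sqrt(((X[i] - X[j]) ** 2).sum()))
        # The mutual reachability distance can never be smaller than either
        # endpoint's core distance (distance to its k-th nearest neighbor).
        return max(core_dists[i], core_dists[j], d)
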