44 */
55
66#include < cstdint>
7+
78#include < dlpack/dlpack.h>
89
910#include < cuvs/cluster/kmeans.h>
@@ -17,16 +18,18 @@ namespace {
1718
1819cuvs::cluster::kmeans::params convert_params (const cuvsKMeansParams& params)
1920{
20- auto kmeans_params = cuvs::cluster::kmeans::params ();
21- kmeans_params.metric = static_cast <cuvs::distance::DistanceType>(params.metric );
22- kmeans_params.init = static_cast <cuvs::cluster::kmeans::params::InitMethod>(params.init );
23- kmeans_params.n_clusters = params.n_clusters ;
24- kmeans_params.max_iter = params.max_iter ;
25- kmeans_params.tol = params.tol ;
21+ auto kmeans_params = cuvs::cluster::kmeans::params ();
22+ kmeans_params.metric = static_cast <cuvs::distance::DistanceType>(params.metric );
23+ kmeans_params.init = static_cast <cuvs::cluster::kmeans::params::InitMethod>(params.init );
24+ kmeans_params.n_clusters = params.n_clusters ;
25+ kmeans_params.max_iter = params.max_iter ;
26+ kmeans_params.tol = params.tol ;
27+ kmeans_params.n_init = params.n_init ;
2628 kmeans_params.oversampling_factor = params.oversampling_factor ;
2729 kmeans_params.batch_samples = params.batch_samples ;
2830 kmeans_params.batch_centroids = params.batch_centroids ;
2931 kmeans_params.inertia_check = params.inertia_check ;
32+ kmeans_params.streaming_batch_size = params.streaming_batch_size ;
3033 return kmeans_params;
3134}
3235
@@ -38,7 +41,7 @@ cuvs::cluster::kmeans::balanced_params convert_balanced_params(const cuvsKMeansP
3841 return kmeans_params;
3942}
4043
41- template <typename T, typename IdxT = int32_t >
44+ template <typename T, typename IdxT = int64_t >
4245void _fit (cuvsResources_t res,
4346 const cuvsKMeansParams& params,
4447 DLManagedTensor* X_tensor,
@@ -50,7 +53,51 @@ void _fit(cuvsResources_t res,
5053 auto X = X_tensor->dl_tensor ;
5154 auto res_ptr = reinterpret_cast <raft::resources*>(res);
5255
53- if (cuvs::core::is_dlpack_device_compatible (X)) {
56+ if (!cuvs::core::is_dlpack_device_compatible (X)) {
57+ auto n_samples = static_cast <IdxT>(X.shape [0 ]);
58+ auto n_features = static_cast <IdxT>(X.shape [1 ]);
59+
60+ if (params.hierarchical ) {
61+ RAFT_FAIL (" hierarchical kmeans is not supported with host data" );
62+ }
63+
64+ auto centroids_dl = centroids_tensor->dl_tensor ;
65+ if (!cuvs::core::is_dlpack_device_compatible (centroids_dl)) {
66+ RAFT_FAIL (" centroids must be on device memory" );
67+ }
68+
69+ auto X_view = raft::make_host_matrix_view<T const , IdxT>(
70+ reinterpret_cast <T const *>(X.data ), n_samples, n_features);
71+ auto centroids_view =
72+ cuvs::core::from_dlpack<raft::device_matrix_view<T, IdxT, raft::row_major>>(
73+ centroids_tensor);
74+
75+ std::optional<raft::host_vector_view<T const , IdxT>> sample_weight;
76+ if (sample_weight_tensor != NULL ) {
77+ auto sw = sample_weight_tensor->dl_tensor ;
78+ if (!cuvs::core::is_dlpack_host_compatible (sw)) {
79+ RAFT_FAIL (" sample_weight must be host accessible when X is on host" );
80+ }
81+ sample_weight = raft::make_host_vector_view<T const , IdxT>(
82+ reinterpret_cast <T const *>(sw.data ), n_samples);
83+ }
84+
85+ T inertia_temp;
86+ IdxT n_iter_temp;
87+
88+ auto kmeans_params = convert_params (params);
89+ cuvs::cluster::kmeans::fit (*res_ptr,
90+ kmeans_params,
91+ X_view,
92+ sample_weight,
93+ centroids_view,
94+ raft::make_host_scalar_view<T>(&inertia_temp),
95+ raft::make_host_scalar_view<IdxT>(&n_iter_temp));
96+
97+ *inertia = inertia_temp;
98+ *n_iter = n_iter_temp;
99+
100+ } else {
54101 using const_mdspan_type = raft::device_matrix_view<T const , IdxT, raft::row_major>;
55102 using mdspan_type = raft::device_matrix_view<T, IdxT, raft::row_major>;
56103
@@ -85,13 +132,11 @@ void _fit(cuvsResources_t res,
85132 cuvs::core::from_dlpack<const_mdspan_type>(X_tensor),
86133 sample_weight,
87134 cuvs::core::from_dlpack<mdspan_type>(centroids_tensor),
88- raft::make_host_scalar_view<T, IdxT >(&inertia_temp),
89- raft::make_host_scalar_view<IdxT, IdxT >(&n_iter_temp));
135+ raft::make_host_scalar_view<T>(&inertia_temp),
136+ raft::make_host_scalar_view<IdxT>(&n_iter_temp));
90137 *inertia = inertia_temp;
91138 *n_iter = n_iter_temp;
92139 }
93- } else {
94- RAFT_FAIL (" X dataset must be accessible on device memory" );
95140 }
96141}
97142
@@ -143,7 +188,7 @@ void _predict(cuvsResources_t res,
143188 cuvs::core::from_dlpack<const_mdspan_type>(centroids_tensor),
144189 cuvs::core::from_dlpack<labels_mdspan_type>(labels_tensor),
145190 normalize_weight,
146- raft::make_host_scalar_view<T, IdxT >(&inertia_temp));
191+ raft::make_host_scalar_view<T>(&inertia_temp));
147192 *inertia = inertia_temp;
148193 }
149194 } else {
@@ -168,7 +213,7 @@ void _cluster_cost(cuvsResources_t res,
168213 cuvs::cluster::kmeans::cluster_cost (*res_ptr,
169214 cuvs::core::from_dlpack<mdspan_type>(X_tensor),
170215 cuvs::core::from_dlpack<mdspan_type>(centroids_tensor),
171- raft::make_host_scalar_view<T, IdxT >(&cost_temp));
216+ raft::make_host_scalar_view<T>(&cost_temp));
172217 } else {
173218 RAFT_FAIL (" X dataset must be accessible on device memory" );
174219 }
@@ -182,17 +227,20 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params)
182227 return cuvs::core::translate_exceptions ([=] {
183228 cuvs::cluster::kmeans::params cpp_params;
184229 cuvs::cluster::kmeans::balanced_params cpp_balanced_params;
185- *params =
186- new cuvsKMeansParams{.metric = static_cast <cuvsDistanceType>(cpp_params.metric ),
187- .n_clusters = cpp_params.n_clusters ,
188- .init = static_cast <cuvsKMeansInitMethod>(cpp_params.init ),
189- .max_iter = cpp_params.max_iter ,
190- .tol = cpp_params.tol ,
191- .oversampling_factor = cpp_params.oversampling_factor ,
192- .batch_samples = cpp_params.batch_samples ,
193- .inertia_check = cpp_params.inertia_check ,
194- .hierarchical = false ,
195- .hierarchical_n_iters = static_cast <int >(cpp_balanced_params.n_iters )};
230+ *params = new cuvsKMeansParams{
231+ .metric = static_cast <cuvsDistanceType>(cpp_params.metric ),
232+ .n_clusters = cpp_params.n_clusters ,
233+ .init = static_cast <cuvsKMeansInitMethod>(cpp_params.init ),
234+ .max_iter = cpp_params.max_iter ,
235+ .tol = cpp_params.tol ,
236+ .n_init = cpp_params.n_init ,
237+ .oversampling_factor = cpp_params.oversampling_factor ,
238+ .batch_samples = cpp_params.batch_samples ,
239+ .batch_centroids = cpp_params.batch_centroids ,
240+ .inertia_check = cpp_params.inertia_check ,
241+ .hierarchical = false ,
242+ .hierarchical_n_iters = static_cast <int >(cpp_balanced_params.n_iters ),
243+ .streaming_batch_size = cpp_params.streaming_batch_size };
196244 });
197245}
198246
@@ -235,10 +283,9 @@ extern "C" cuvsError_t cuvsKMeansPredict(cuvsResources_t res,
235283 return cuvs::core::translate_exceptions ([=] {
236284 auto dataset = X->dl_tensor ;
237285 if (dataset.dtype .code == kDLFloat && dataset.dtype .bits == 32 ) {
238- _predict<float >(res, *params, X, sample_weight, centroids, labels, normalize_weight, inertia);
286+ _predict<float >(res, *params, X, sample_weight, centroids, labels, normalize_weight, inertia);
239287 } else if (dataset.dtype .code == kDLFloat && dataset.dtype .bits == 64 ) {
240- _predict<double >(
241- res, *params, X, sample_weight, centroids, labels, normalize_weight, inertia);
288+ _predict<double >(res, *params, X, sample_weight, centroids, labels, normalize_weight, inertia);
242289 } else {
243290 RAFT_FAIL (" Unsupported dataset DLtensor dtype: %d and bits: %d" ,
244291 dataset.dtype .code ,
0 commit comments