Skip to content

Commit 9b5bd2a

Browse files
committed
Change output mdspan
Signed-off-by: Mickael Ide <mide@nvidia.com>
1 parent e091067 commit 9b5bd2a

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

cpp/src/neighbors/detail/vpq_dataset.cuh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,16 @@ auto train_vq(const raft::resources& res, const vpq_params& params, const Datase
183183
auto vq_centers =
184184
raft::make_device_matrix<MathT, uint32_t, raft::row_major>(res, vq_n_centers, dim);
185185

186-
auto vq_centers_view =
187-
raft::make_device_matrix_view<MathT, ix_t>(vq_centers.data_handle(), vq_n_centers, dim);
188186
auto vq_trainset_view = raft::make_device_matrix_view<const kmeans_in_type, ix_t>(
189187
vq_trainset.data_handle(), n_rows_train, dim);
190188

191189
if (vq_n_centers == 1) {
190+
auto vq_centers_view =
191+
raft::make_device_vector_view<MathT, ix_t>(vq_centers.data_handle(), dim);
192192
raft::stats::mean(res, vq_trainset_view, vq_centers_view);
193193
} else {
194+
auto vq_centers_view =
195+
raft::make_device_matrix_view<MathT, ix_t>(vq_centers.data_handle(), vq_n_centers, dim);
194196
cuvs::cluster::kmeans::balanced_params kmeans_params;
195197
kmeans_params.n_iters = params.kmeans_n_iters;
196198
kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded;

0 commit comments

Comments
 (0)