Skip to content

Commit 6cd5d86

Browse files
committed
PR feedback: migrate to pub trait AsDlTensor that takes &self
1 parent e42c0d9 commit 6cd5d86

15 files changed

Lines changed: 241 additions & 214 deletions

File tree

rust/cuvs/examples/cagra.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
//! CAGRA example with a user-provided GPU tensor.
77
//!
88
//! This demonstrates how to feed your own device memory into cuVS by
9-
//! implementing the public [`IntoDlTensor`]/[`IntoDlTensorMut`] traits. The
9+
//! implementing the public [`AsDlTensor`]/[`AsDlTensorMut`] traits. The
1010
//! [`CudaTensor`] type manages device memory directly through the CUDA runtime
1111
//! (`cudaMalloc`/`cudaFree`) and copies to/from host arrays with `cudaMemcpyAsync`
1212
//! on the cuVS stream, reusing the resources handle's `get_cuda_stream`/
@@ -22,8 +22,8 @@ use std::os::raw::c_int;
2222
use cuvs::Resources;
2323
use cuvs::cagra::{Index, IndexParams, SearchParams};
2424
use cuvs::dlpack::{
25-
DLDevice, DLDeviceType, DLPackError, DLTensorView, DLTensorViewMut, DType, IntoDlTensor,
26-
IntoDlTensorMut,
25+
AsDlTensor, AsDlTensorMut, DLDevice, DLDeviceType, DLPackError, DLTensorView, DLTensorViewMut,
26+
DType,
2727
};
2828

2929
use ndarray::s;
@@ -146,8 +146,8 @@ impl<T: DType> Drop for CudaTensor<T> {
146146
}
147147
}
148148

149-
impl<'a, T: DType> IntoDlTensor<'a> for &'a CudaTensor<T> {
150-
fn into_dl_tensor(self) -> std::result::Result<DLTensorView<'a>, DLPackError> {
149+
impl<T: DType> AsDlTensor for CudaTensor<T> {
150+
fn as_dl_tensor(&self) -> std::result::Result<DLTensorView<'_>, DLPackError> {
151151
unsafe {
152152
DLTensorView::from_raw_parts(
153153
self.data,
@@ -160,8 +160,8 @@ impl<'a, T: DType> IntoDlTensor<'a> for &'a CudaTensor<T> {
160160
}
161161
}
162162

163-
impl<'a, T: DType> IntoDlTensorMut<'a> for &'a mut CudaTensor<T> {
164-
fn into_dl_tensor_mut(self) -> std::result::Result<DLTensorViewMut<'a>, DLPackError> {
163+
impl<T: DType> AsDlTensorMut for CudaTensor<T> {
164+
fn as_dl_tensor_mut(&mut self) -> std::result::Result<DLTensorViewMut<'_>, DLPackError> {
165165
unsafe {
166166
DLTensorViewMut::from_raw_parts(
167167
self.data,

rust/cuvs/src/brute_force.rs

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
//! Brute-force (exact) k-NN.
66
//!
77
//! Build an [`Index`] over a dataset, then [`search`](Index::search) it with
8-
//! device-resident queries and output buffers. Tensors are passed through the
9-
//! [`IntoDlTensor`] /
10-
//! [`IntoDlTensorMut`] traits; see the
8+
//! device-resident queries and output buffers. Tensors are borrowed through the
9+
//! [`AsDlTensor`] /
10+
//! [`AsDlTensorMut`] traits; see the
1111
//! [`dlpack`](crate::dlpack) module for the tensor model and `examples/cagra.rs`
1212
//! for the same build/search workflow.
1313
1414
use std::io::{Write, stderr};
1515
use std::marker::PhantomData;
1616

1717
use crate::distance_type::DistanceType;
18-
use crate::dlpack::{IntoDlTensor, IntoDlTensorMut};
18+
use crate::dlpack::{AsDlTensor, AsDlTensorMut};
1919
use crate::error::{Result, check_cuvs};
2020
use crate::resources::Resources;
2121

@@ -33,16 +33,19 @@ impl<'d> Index<'d> {
3333
///
3434
/// `metric` selects the distance and `metric_arg` is the optional `p` for
3535
/// Minkowski distances (defaults to 2). `dataset` is a row-major matrix on
36-
/// the host or device implementing [`IntoDlTensor`]; the
36+
/// the host or device implementing [`AsDlTensor`]; the
3737
/// C++ index keeps a non-owning view of it, so the returned [`Index`] borrows
3838
/// it for `'d` and cannot outlive it.
39-
pub fn build(
39+
pub fn build<T>(
4040
res: &Resources,
4141
metric: DistanceType,
4242
metric_arg: Option<f32>,
43-
dataset: impl IntoDlTensor<'d>,
44-
) -> Result<Index<'d>> {
45-
let dataset = dataset.into_dl_tensor()?;
43+
dataset: &'d T,
44+
) -> Result<Index<'d>>
45+
where
46+
T: AsDlTensor + ?Sized,
47+
{
48+
let dataset = dataset.as_dl_tensor()?;
4649
let index = Index::new()?;
4750
unsafe {
4851
check_cuvs(ffi::cuvsBruteForceBuild(
@@ -68,20 +71,25 @@ impl<'d> Index<'d> {
6871
/// Searches the index for the `k` nearest neighbors of each query.
6972
///
7073
/// `queries`, `neighbors`, and `distances` must reside in device memory and
71-
/// implement [`IntoDlTensor`] /
72-
/// [`IntoDlTensorMut`]. `neighbors` receives the
74+
/// implement [`AsDlTensor`] /
75+
/// [`AsDlTensorMut`]. `neighbors` receives the
7376
/// neighbor indices and `distances` their distances; both are written in
7477
/// place.
75-
pub fn search<'a>(
78+
pub fn search<Q, N, D>(
7679
&self,
7780
res: &Resources,
78-
queries: impl IntoDlTensor<'a>,
79-
neighbors: impl IntoDlTensorMut<'a>,
80-
distances: impl IntoDlTensorMut<'a>,
81-
) -> Result<()> {
82-
let queries = queries.into_dl_tensor()?;
83-
let neighbors = neighbors.into_dl_tensor_mut()?;
84-
let distances = distances.into_dl_tensor_mut()?;
81+
queries: &Q,
82+
neighbors: &mut N,
83+
distances: &mut D,
84+
) -> Result<()>
85+
where
86+
Q: AsDlTensor + ?Sized,
87+
N: AsDlTensorMut + ?Sized,
88+
D: AsDlTensorMut + ?Sized,
89+
{
90+
let queries = queries.as_dl_tensor()?;
91+
let neighbors = neighbors.as_dl_tensor_mut()?;
92+
let distances = distances.as_dl_tensor_mut()?;
8593
unsafe {
8694
let prefilter = ffi::cuvsFilter { addr: 0, type_: ffi::cuvsFilterType::NO_FILTER };
8795

rust/cuvs/src/cagra/index.rs

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::marker::PhantomData;
99
use std::path::Path;
1010

1111
use crate::cagra::{IndexParams, SearchParams};
12-
use crate::dlpack::{IntoDlTensor, IntoDlTensorMut};
12+
use crate::dlpack::{AsDlTensor, AsDlTensorMut};
1313
use crate::error::{Error, Result, check_cuvs};
1414
use crate::resources::Resources;
1515

@@ -41,15 +41,14 @@ impl<'d> Index<'d> {
4141
/// Builds a CAGRA index over `dataset` for efficient search.
4242
///
4343
/// `dataset` is a row-major matrix on the host or device implementing
44-
/// [`IntoDlTensor`](crate::IntoDlTensor). The C++ index keeps a non-owning
44+
/// [`AsDlTensor`]. The C++ index keeps a non-owning
4545
/// view of it, so the returned [`Index`] borrows `dataset` for `'d` and
4646
/// cannot outlive it.
47-
pub fn build(
48-
res: &Resources,
49-
params: &IndexParams,
50-
dataset: impl IntoDlTensor<'d>,
51-
) -> Result<Index<'d>> {
52-
let dataset = dataset.into_dl_tensor()?;
47+
pub fn build<T>(res: &Resources, params: &IndexParams, dataset: &'d T) -> Result<Index<'d>>
48+
where
49+
T: AsDlTensor + ?Sized,
50+
{
51+
let dataset = dataset.as_dl_tensor()?;
5352
let index = Index::new()?;
5453
unsafe {
5554
check_cuvs(ffi::cuvsCagraBuild(
@@ -74,21 +73,26 @@ impl<'d> Index<'d> {
7473
/// Searches the index for the `k` nearest neighbors of each query.
7574
///
7675
/// `queries`, `neighbors`, and `distances` must reside in device memory and
77-
/// implement [`IntoDlTensor`](crate::IntoDlTensor) /
78-
/// [`IntoDlTensorMut`](crate::IntoDlTensorMut). `neighbors` (shape
76+
/// implement [`AsDlTensor`] /
77+
/// [`AsDlTensorMut`]. `neighbors` (shape
7978
/// `n_queries × k`) receives the neighbor indices and `distances` their
8079
/// distances; both are written in place.
81-
pub fn search<'a>(
80+
pub fn search<Q, N, D>(
8281
&self,
8382
res: &Resources,
8483
params: &SearchParams,
85-
queries: impl IntoDlTensor<'a>,
86-
neighbors: impl IntoDlTensorMut<'a>,
87-
distances: impl IntoDlTensorMut<'a>,
88-
) -> Result<()> {
89-
let queries = queries.into_dl_tensor()?;
90-
let neighbors = neighbors.into_dl_tensor_mut()?;
91-
let distances = distances.into_dl_tensor_mut()?;
84+
queries: &Q,
85+
neighbors: &mut N,
86+
distances: &mut D,
87+
) -> Result<()>
88+
where
89+
Q: AsDlTensor + ?Sized,
90+
N: AsDlTensorMut + ?Sized,
91+
D: AsDlTensorMut + ?Sized,
92+
{
93+
let queries = queries.as_dl_tensor()?;
94+
let neighbors = neighbors.as_dl_tensor_mut()?;
95+
let distances = distances.as_dl_tensor_mut()?;
9296
unsafe {
9397
let prefilter = ffi::cuvsFilter { addr: 0, type_: ffi::cuvsFilterType::NO_FILTER };
9498

@@ -113,19 +117,25 @@ impl<'d> Index<'d> {
113117
/// `queries`, `neighbors`, and `distances` are as in [`search`](Self::search).
114118
/// `bitset` is a 1-D `uint32` device tensor of `ceil(n_rows / 32)` elements,
115119
/// where each bit maps to a dataset row (1 = include, 0 = exclude).
116-
pub fn search_with_filter<'a>(
120+
pub fn search_with_filter<Q, N, D, B>(
117121
&self,
118122
res: &Resources,
119123
params: &SearchParams,
120-
queries: impl IntoDlTensor<'a>,
121-
neighbors: impl IntoDlTensorMut<'a>,
122-
distances: impl IntoDlTensorMut<'a>,
123-
bitset: impl IntoDlTensor<'a>,
124-
) -> Result<()> {
125-
let queries = queries.into_dl_tensor()?;
126-
let neighbors = neighbors.into_dl_tensor_mut()?;
127-
let distances = distances.into_dl_tensor_mut()?;
128-
let bitset = bitset.into_dl_tensor()?;
124+
queries: &Q,
125+
neighbors: &mut N,
126+
distances: &mut D,
127+
bitset: &B,
128+
) -> Result<()>
129+
where
130+
Q: AsDlTensor + ?Sized,
131+
N: AsDlTensorMut + ?Sized,
132+
D: AsDlTensorMut + ?Sized,
133+
B: AsDlTensor + ?Sized,
134+
{
135+
let queries = queries.as_dl_tensor()?;
136+
let neighbors = neighbors.as_dl_tensor_mut()?;
137+
let distances = distances.as_dl_tensor_mut()?;
138+
let bitset = bitset.as_dl_tensor()?;
129139
// The bitset pointer is cast to `usize` and stored in `prefilter`, then read
130140
// by the search call, so its `ManagedTensorRef` must outlive both.
131141
// Hence we keep it bound instead of chaining `to_c().as_mut_ptr()`.

rust/cuvs/src/cagra/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
//!
99
//! Build an [`Index`] from a dataset, then [`search`](Index::search) it with
1010
//! device-resident queries and output buffers. Tensors are passed through the
11-
//! [`IntoDlTensor`](crate::IntoDlTensor) /
12-
//! [`IntoDlTensorMut`](crate::IntoDlTensorMut) traits; see the
11+
//! [`AsDlTensor`](crate::AsDlTensor) /
12+
//! [`AsDlTensorMut`](crate::AsDlTensorMut) traits; see the
1313
//! [`dlpack`](crate::dlpack) module for the tensor model and `examples/cagra.rs`
1414
//! for a complete, runnable example.
1515

rust/cuvs/src/cluster/kmeans/mod.rs

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77
//!
88
//! [`fit`] computes cluster centroids for a dataset, [`predict`] assigns points
99
//! to clusters, and [`cluster_cost`] reports the inertia. All inputs and outputs
10-
//! reside in device memory and are passed through the
11-
//! [`IntoDlTensor`] /
12-
//! [`IntoDlTensorMut`] traits; see the
10+
//! reside in device memory and are borrowed through the
11+
//! [`AsDlTensor`] /
12+
//! [`AsDlTensorMut`] traits; see the
1313
//! [`dlpack`](crate::dlpack) module for the tensor model.
1414
1515
mod params;
1616

1717
pub use params::Params;
1818

19-
use crate::dlpack::{IntoDlTensor, IntoDlTensorMut};
19+
use crate::dlpack::{AsDlTensor, AsDlTensorMut};
2020
use crate::error::{Result, check_cuvs};
2121
use crate::resources::Resources;
2222

@@ -25,18 +25,23 @@ use crate::resources::Resources;
2525
/// `x` (shape `m × k`) is the input matrix and `centroids` (shape
2626
/// `n_clusters × k`) receives the fitted centroids; `sample_weight` is an
2727
/// optional per-sample weight. All reside in device memory and implement
28-
/// [`IntoDlTensor`] /
29-
/// [`IntoDlTensorMut`].
30-
pub fn fit<'a>(
28+
/// [`AsDlTensor`] /
29+
/// [`AsDlTensorMut`].
30+
pub fn fit<X, W, C>(
3131
res: &Resources,
3232
params: &Params,
33-
x: impl IntoDlTensor<'a>,
34-
sample_weight: Option<impl IntoDlTensor<'a>>,
35-
centroids: impl IntoDlTensorMut<'a>,
36-
) -> Result<(f64, i32)> {
37-
let x = x.into_dl_tensor()?;
38-
let sample_weight = sample_weight.map(|w| w.into_dl_tensor()).transpose()?;
39-
let centroids = centroids.into_dl_tensor_mut()?;
33+
x: &X,
34+
sample_weight: Option<&W>,
35+
centroids: &mut C,
36+
) -> Result<(f64, i32)>
37+
where
38+
X: AsDlTensor + ?Sized,
39+
W: AsDlTensor + ?Sized,
40+
C: AsDlTensorMut + ?Sized,
41+
{
42+
let x = x.as_dl_tensor()?;
43+
let sample_weight = sample_weight.map(|w| w.as_dl_tensor()).transpose()?;
44+
let centroids = centroids.as_dl_tensor_mut()?;
4045
let mut inertia: f64 = 0.0;
4146
let mut niter: i32 = 0;
4247
let mut sample_weight_c = sample_weight.as_ref().map(|w| w.to_c());
@@ -62,22 +67,28 @@ pub fn fit<'a>(
6267
///
6368
/// `x` (shape `m × k`), `centroids` (shape `n_clusters × k`), the optional
6469
/// `sample_weight`, and `labels` (shape `m × 1`) reside in device memory and
65-
/// implement [`IntoDlTensor`] /
66-
/// [`IntoDlTensorMut`]. `normalize_weight` selects
70+
/// implement [`AsDlTensor`] /
71+
/// [`AsDlTensorMut`]. `normalize_weight` selects
6772
/// whether the sample weights are normalized.
68-
pub fn predict<'a>(
73+
pub fn predict<X, W, C, L>(
6974
res: &Resources,
7075
params: &Params,
71-
x: impl IntoDlTensor<'a>,
72-
sample_weight: Option<impl IntoDlTensor<'a>>,
73-
centroids: impl IntoDlTensor<'a>,
74-
labels: impl IntoDlTensorMut<'a>,
76+
x: &X,
77+
sample_weight: Option<&W>,
78+
centroids: &C,
79+
labels: &mut L,
7580
normalize_weight: bool,
76-
) -> Result<f64> {
77-
let x = x.into_dl_tensor()?;
78-
let sample_weight = sample_weight.map(|w| w.into_dl_tensor()).transpose()?;
79-
let centroids = centroids.into_dl_tensor()?;
80-
let labels = labels.into_dl_tensor_mut()?;
81+
) -> Result<f64>
82+
where
83+
X: AsDlTensor + ?Sized,
84+
W: AsDlTensor + ?Sized,
85+
C: AsDlTensor + ?Sized,
86+
L: AsDlTensorMut + ?Sized,
87+
{
88+
let x = x.as_dl_tensor()?;
89+
let sample_weight = sample_weight.map(|w| w.as_dl_tensor()).transpose()?;
90+
let centroids = centroids.as_dl_tensor()?;
91+
let labels = labels.as_dl_tensor_mut()?;
8192
let mut inertia: f64 = 0.0;
8293
let mut sample_weight_c = sample_weight.as_ref().map(|w| w.to_c());
8394
let sample_weight_ptr =
@@ -101,14 +112,14 @@ pub fn predict<'a>(
101112
/// Computes the k-means cost (inertia) of `x` against existing `centroids`.
102113
///
103114
/// `x` (shape `m × k`) and `centroids` (shape `n_clusters × k`) reside in device
104-
/// memory and implement [`IntoDlTensor`].
105-
pub fn cluster_cost<'a>(
106-
res: &Resources,
107-
x: impl IntoDlTensor<'a>,
108-
centroids: impl IntoDlTensor<'a>,
109-
) -> Result<f64> {
110-
let x = x.into_dl_tensor()?;
111-
let centroids = centroids.into_dl_tensor()?;
115+
/// memory and implement [`AsDlTensor`].
116+
pub fn cluster_cost<X, C>(res: &Resources, x: &X, centroids: &C) -> Result<f64>
117+
where
118+
X: AsDlTensor + ?Sized,
119+
C: AsDlTensor + ?Sized,
120+
{
121+
let x = x.as_dl_tensor()?;
122+
let centroids = centroids.as_dl_tensor()?;
112123
let mut inertia: f64 = 0.0;
113124

114125
unsafe {

0 commit comments

Comments
 (0)