Skip to content

Commit 1bc3436

Browse files
committed
Track tensor and index lifetimes in the Rust DLPack bindings
The previous `ManagedTensor` type was a non-owning wrapper over the C's FFI `DLManagedTensor` type with no lifetime attached. Hence there was no mechanism to tie the tensor data and shape/stride metadata it referenced to their owners. Indexes that keep a non-owning view of their dataset (CAGRA, brute-force) could outlive that data. Here we replace it with lifetime-parameterized `DLTensorView` and `DLTensorViewMut` views. They are produced by the public `IntoDlTensor` and `IntoDlTensorMut` traits. Users are now expected to implement these traits on their tensor types, so that our API can accept them as input/output arguments.
1 parent 6672103 commit 1bc3436

18 files changed

Lines changed: 1337 additions & 734 deletions

File tree

rust/cuvs/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ doc-only = ["cuvs-sys/doc-only"]
1414

1515
[dependencies]
1616
cuvs-sys = { workspace = true }
17-
ndarray = "0.15"
17+
thiserror = "2"
18+
tinyvec = { version = "1", features = ["alloc", "latest_stable_rust"] }
1819

1920
[dev-dependencies]
20-
ndarray-rand = "0.14"
21+
ndarray = "0.17"
22+
ndarray-rand = "0.16"
2123

2224
[package.metadata.docs.rs]
2325
features = ["doc-only"]

rust/cuvs/examples/cagra.rs

Lines changed: 188 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,62 +3,220 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6+
//! CAGRA example with a user-provided GPU tensor.
7+
//!
8+
//! This demonstrates how to feed your own device memory into cuVS by
9+
//! implementing the public [`IntoDlTensor`]/[`IntoDlTensorMut`] traits. The
10+
//! [`CudaTensor`] type manages device memory directly through the CUDA runtime
11+
//! (`cudaMalloc`/`cudaFree`) and copies to/from host arrays with `cudaMemcpyAsync`
12+
//! on the cuVS stream, reusing the resources handle's `get_cuda_stream`/
13+
//! `sync_stream` for stream access and synchronization.
14+
//!
15+
//! A real application would likely rely on a helper crate such as `cudarc`
16+
//! and its `CudaSlice`.
17+
18+
use std::ffi::c_void;
19+
use std::marker::PhantomData;
20+
use std::os::raw::c_int;
21+
22+
use cuvs::Resources;
623
use cuvs::cagra::{Index, IndexParams, SearchParams};
7-
use cuvs::{ManagedTensor, Resources, Result};
24+
use cuvs::dlpack::{
25+
DLDevice, DLDeviceType, DLPackError, DLTensorView, DLTensorViewMut, DType, IntoDlTensor,
26+
IntoDlTensorMut,
27+
};
828

929
use ndarray::s;
1030
use ndarray_rand::RandomExt;
1131
use ndarray_rand::rand_distr::Uniform;
1232

13-
/// Example showing how to index and search data with CAGRA
14-
fn cagra_example() -> Result<()> {
33+
type ExampleResult<T> = std::result::Result<T, Box<dyn std::error::Error>>;
34+
35+
// ---------------------------------------------------------------------------
36+
// Minimal CUDA runtime FFI
37+
// ---------------------------------------------------------------------------
38+
39+
#[allow(non_camel_case_types)]
40+
type cudaError_t = c_int;
41+
const CUDA_SUCCESS: cudaError_t = 0;
42+
const CUDA_MEMCPY_HOST_TO_DEVICE: c_int = 1;
43+
const CUDA_MEMCPY_DEVICE_TO_HOST: c_int = 2;
44+
45+
#[link(name = "cudart")]
46+
unsafe extern "C" {
47+
fn cudaMalloc(ptr: *mut *mut c_void, size: usize) -> cudaError_t;
48+
fn cudaFree(ptr: *mut c_void) -> cudaError_t;
49+
fn cudaMemcpyAsync(
50+
dst: *mut c_void,
51+
src: *const c_void,
52+
count: usize,
53+
kind: c_int,
54+
stream: cuvs_sys::cudaStream_t,
55+
) -> cudaError_t;
56+
}
57+
58+
fn check_cuda(status: cudaError_t) -> ExampleResult<()> {
59+
if status == CUDA_SUCCESS {
60+
Ok(())
61+
} else {
62+
Err(format!("CUDA runtime error: {status}").into())
63+
}
64+
}
65+
66+
// ---------------------------------------------------------------------------
67+
// A custom device tensor backed by the CUDA runtime
68+
// ---------------------------------------------------------------------------
69+
70+
struct CudaTensor<T: DType> {
71+
data: *mut c_void,
72+
shape: Vec<i64>,
73+
bytes: usize,
74+
_marker: PhantomData<T>,
75+
}
76+
77+
impl<T: DType> CudaTensor<T> {
78+
/// Allocate an uninitialized device buffer (used for search outputs).
79+
fn alloc(shape: &[usize]) -> ExampleResult<Self> {
80+
let bytes = shape.iter().product::<usize>() * std::mem::size_of::<T>();
81+
let mut data: *mut c_void = std::ptr::null_mut();
82+
check_cuda(unsafe { cudaMalloc(&mut data, bytes) })?;
83+
Ok(Self {
84+
data,
85+
shape: shape.iter().map(|&d| d as i64).collect(),
86+
bytes,
87+
_marker: PhantomData,
88+
})
89+
}
90+
91+
/// Copy a contiguous host array onto the device.
92+
fn from_host<D>(res: &Resources, host: &ndarray::ArrayRef<T, D>) -> ExampleResult<Self>
93+
where
94+
D: ndarray::Dimension,
95+
{
96+
if !host.is_standard_layout() {
97+
return Err("host array must be contiguous (row-major)".into());
98+
}
99+
let tensor = Self::alloc(host.shape())?;
100+
101+
let stream = res.get_cuda_stream()?;
102+
check_cuda(unsafe {
103+
cudaMemcpyAsync(
104+
tensor.data,
105+
host.as_ptr() as *const c_void,
106+
tensor.bytes,
107+
CUDA_MEMCPY_HOST_TO_DEVICE,
108+
stream,
109+
)
110+
})?;
111+
res.sync_stream()?;
112+
113+
Ok(tensor)
114+
}
115+
116+
/// Copy the device buffer back into a contiguous host array.
117+
fn to_host<D>(&self, res: &Resources, host: &mut ndarray::ArrayRef<T, D>) -> ExampleResult<()>
118+
where
119+
D: ndarray::Dimension,
120+
{
121+
if !host.is_standard_layout() {
122+
return Err("host array must be contiguous (row-major)".into());
123+
}
124+
125+
let stream = res.get_cuda_stream()?;
126+
check_cuda(unsafe {
127+
cudaMemcpyAsync(
128+
host.as_mut_ptr() as *mut c_void,
129+
self.data,
130+
self.bytes,
131+
CUDA_MEMCPY_DEVICE_TO_HOST,
132+
stream,
133+
)
134+
})?;
135+
res.sync_stream()?;
136+
137+
Ok(())
138+
}
139+
}
140+
141+
impl<T: DType> Drop for CudaTensor<T> {
142+
fn drop(&mut self) {
143+
if !self.data.is_null() {
144+
unsafe { cudaFree(self.data) };
145+
}
146+
}
147+
}
148+
149+
impl<'a, T: DType> IntoDlTensor<'a> for &'a CudaTensor<T> {
150+
fn into_dl_tensor(self) -> std::result::Result<DLTensorView<'a>, DLPackError> {
151+
unsafe {
152+
DLTensorView::from_raw_parts(
153+
self.data,
154+
DLDevice { device_type: DLDeviceType::kDLCUDA, device_id: 0 },
155+
&self.shape,
156+
None,
157+
T::dl_dtype(),
158+
)
159+
}
160+
}
161+
}
162+
163+
impl<'a, T: DType> IntoDlTensorMut<'a> for &'a mut CudaTensor<T> {
164+
fn into_dl_tensor_mut(self) -> std::result::Result<DLTensorViewMut<'a>, DLPackError> {
165+
unsafe {
166+
DLTensorViewMut::from_raw_parts(
167+
self.data,
168+
DLDevice { device_type: DLDeviceType::kDLCUDA, device_id: 0 },
169+
&self.shape,
170+
None,
171+
T::dl_dtype(),
172+
)
173+
}
174+
}
175+
}
176+
177+
/// Example showing how to index and search data with CAGRA.
178+
fn cagra_example() -> ExampleResult<()> {
15179
let res = Resources::new()?;
16180

17-
// Create a new random dataset to index
181+
// Create a new random dataset to index and copy it to the device.
18182
let n_datapoints = 65536;
19183
let n_features = 512;
20-
let dataset =
21-
ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
184+
let dataset_host = ndarray::Array::<f32, _>::random(
185+
(n_datapoints, n_features),
186+
Uniform::new(0., 1.0).unwrap(),
187+
);
188+
let dataset = CudaTensor::from_host(&res, &dataset_host)?;
22189

23-
// build the cagra index
190+
// Build the CAGRA index.
24191
let build_params = IndexParams::new()?;
25192
let index = Index::build(&res, &build_params, &dataset)?;
26-
println!("Indexed {}x{} datapoints into cagra index", n_datapoints, n_features);
193+
println!("Indexed {n_datapoints}x{n_features} datapoints into cagra index");
27194

28-
// use the first 4 points from the dataset as queries : will test that we get them back
29-
// as their own nearest neighbor
195+
// Use the first 4 points as queries; each should be its own nearest neighbor.
30196
let n_queries = 4;
31-
let queries = dataset.slice(s![0..n_queries, ..]);
32-
33197
let k = 10;
198+
let queries_host = dataset_host.slice(s![0..n_queries, ..]).to_owned();
199+
let queries = CudaTensor::from_host(&res, &queries_host)?;
34200

35-
// CAGRA search API requires queries and outputs to be on device memory
36-
// copy query data over, and allocate new device memory for the distances/ neighbors
37-
// outputs
38-
let queries = ManagedTensor::from(&queries).to_device(&res)?;
39-
let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
40-
let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res)?;
41-
42-
let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
43-
let distances = ManagedTensor::from(&distances_host).to_device(&res)?;
201+
let mut neighbors = CudaTensor::<u32>::alloc(&[n_queries, k])?;
202+
let mut distances = CudaTensor::<f32>::alloc(&[n_queries, k])?;
44203

45204
let search_params = SearchParams::new()?;
205+
index.search(&res, &search_params, &queries, &mut neighbors, &mut distances)?;
46206

47-
index.search(&res, &search_params, &queries, &neighbors, &distances)?;
48-
49-
// Copy back to host memory
50-
distances.to_host(&res, &mut distances_host)?;
207+
// Copy the results back to the host.
208+
let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
209+
let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
51210
neighbors.to_host(&res, &mut neighbors_host)?;
211+
distances.to_host(&res, &mut distances_host)?;
52212

53-
// nearest neighbors should be themselves, since queries are from the
54-
// dataset
55-
println!("Neighbors {:?}", neighbors_host);
56-
println!("Distances {:?}", distances_host);
213+
println!("Neighbors {neighbors_host:?}");
214+
println!("Distances {distances_host:?}");
57215
Ok(())
58216
}
59217

60218
fn main() {
61219
if let Err(e) = cagra_example() {
62-
println!("Failed to run CAGRA: {:?}", e);
220+
println!("Failed to run CAGRA: {e:?}");
63221
}
64222
}

0 commit comments

Comments
 (0)