|
3 | 3 | * SPDX-License-Identifier: Apache-2.0 |
4 | 4 | */ |
5 | 5 |
|
| 6 | +//! CAGRA example with a user-provided GPU tensor. |
| 7 | +//! |
| 8 | +//! This demonstrates how to feed your own device memory into cuVS by |
| 9 | +//! implementing the public [`IntoDlTensor`]/[`IntoDlTensorMut`] traits. The |
| 10 | +//! [`CudaTensor`] type manages device memory directly through the CUDA runtime |
| 11 | +//! (`cudaMalloc`/`cudaFree`) and copies to/from host arrays with `cudaMemcpyAsync` |
| 12 | +//! on the cuVS stream, reusing the resources handle's `get_cuda_stream`/ |
| 13 | +//! `sync_stream` for stream access and synchronization. |
| 14 | +//! |
| 15 | +//! A real application would likely rely on a helper crate such as `cudarc` |
| 16 | +//! and its `CudaSlice`. |
| 17 | +
|
| 18 | +use std::ffi::c_void; |
| 19 | +use std::marker::PhantomData; |
| 20 | +use std::os::raw::c_int; |
| 21 | + |
| 22 | +use cuvs::Resources; |
6 | 23 | use cuvs::cagra::{Index, IndexParams, SearchParams}; |
7 | | -use cuvs::{ManagedTensor, Resources, Result}; |
| 24 | +use cuvs::dlpack::{ |
| 25 | + DLDevice, DLDeviceType, DLPackError, DLTensorView, DLTensorViewMut, DType, IntoDlTensor, |
| 26 | + IntoDlTensorMut, |
| 27 | +}; |
8 | 28 |
|
9 | 29 | use ndarray::s; |
10 | 30 | use ndarray_rand::RandomExt; |
11 | 31 | use ndarray_rand::rand_distr::Uniform; |
12 | 32 |
|
13 | | -/// Example showing how to index and search data with CAGRA |
14 | | -fn cagra_example() -> Result<()> { |
/// Result alias used by every fallible function in this example.
///
/// The error side is a boxed trait object so that `?` can propagate the
/// different error types produced here (cuVS errors, DLPack errors, integer
/// conversion errors and plain string messages) through a single signature.
type ExampleResult<T> = std::result::Result<T, Box<dyn std::error::Error>>;
| 34 | + |
| 35 | +// --------------------------------------------------------------------------- |
| 36 | +// Minimal CUDA runtime FFI |
| 37 | +// --------------------------------------------------------------------------- |
| 38 | + |
// The alias deliberately keeps the CUDA runtime's C spelling, so the Rust
// naming lint is silenced for this one item.
#[allow(non_camel_case_types)]
type cudaError_t = c_int;

/// Status value treated as success by `check_cuda` (CUDA's `cudaSuccess`).
const CUDA_SUCCESS: cudaError_t = 0;

/// `kind` argument for a host-to-device copy (CUDA's `cudaMemcpyHostToDevice`).
const CUDA_MEMCPY_HOST_TO_DEVICE: c_int = 1;

/// `kind` argument for a device-to-host copy (CUDA's `cudaMemcpyDeviceToHost`).
const CUDA_MEMCPY_DEVICE_TO_HOST: c_int = 2;
| 44 | + |
| 45 | +#[link(name = "cudart")] |
| 46 | +unsafe extern "C" { |
| 47 | + fn cudaMalloc(ptr: *mut *mut c_void, size: usize) -> cudaError_t; |
| 48 | + fn cudaFree(ptr: *mut c_void) -> cudaError_t; |
| 49 | + fn cudaMemcpyAsync( |
| 50 | + dst: *mut c_void, |
| 51 | + src: *const c_void, |
| 52 | + count: usize, |
| 53 | + kind: c_int, |
| 54 | + stream: cuvs_sys::cudaStream_t, |
| 55 | + ) -> cudaError_t; |
| 56 | +} |
| 57 | + |
| 58 | +fn check_cuda(status: cudaError_t) -> ExampleResult<()> { |
| 59 | + if status == CUDA_SUCCESS { |
| 60 | + Ok(()) |
| 61 | + } else { |
| 62 | + Err(format!("CUDA runtime error: {status}").into()) |
| 63 | + } |
| 64 | +} |
| 65 | + |
| 66 | +// --------------------------------------------------------------------------- |
| 67 | +// A custom device tensor backed by the CUDA runtime |
| 68 | +// --------------------------------------------------------------------------- |
| 69 | + |
| 70 | +struct CudaTensor<T: DType> { |
| 71 | + data: *mut c_void, |
| 72 | + shape: Vec<i64>, |
| 73 | + bytes: usize, |
| 74 | + _marker: PhantomData<T>, |
| 75 | +} |
| 76 | + |
| 77 | +impl<T: DType> CudaTensor<T> { |
| 78 | + /// Allocate an uninitialized device buffer (used for search outputs). |
| 79 | + fn alloc(shape: &[usize]) -> ExampleResult<Self> { |
| 80 | + let bytes = shape.iter().product::<usize>() * std::mem::size_of::<T>(); |
| 81 | + let mut data: *mut c_void = std::ptr::null_mut(); |
| 82 | + check_cuda(unsafe { cudaMalloc(&mut data, bytes) })?; |
| 83 | + Ok(Self { |
| 84 | + data, |
| 85 | + shape: shape.iter().map(|&d| d as i64).collect(), |
| 86 | + bytes, |
| 87 | + _marker: PhantomData, |
| 88 | + }) |
| 89 | + } |
| 90 | + |
| 91 | + /// Copy a contiguous host array onto the device. |
| 92 | + fn from_host<D>(res: &Resources, host: &ndarray::ArrayRef<T, D>) -> ExampleResult<Self> |
| 93 | + where |
| 94 | + D: ndarray::Dimension, |
| 95 | + { |
| 96 | + if !host.is_standard_layout() { |
| 97 | + return Err("host array must be contiguous (row-major)".into()); |
| 98 | + } |
| 99 | + let tensor = Self::alloc(host.shape())?; |
| 100 | + |
| 101 | + let stream = res.get_cuda_stream()?; |
| 102 | + check_cuda(unsafe { |
| 103 | + cudaMemcpyAsync( |
| 104 | + tensor.data, |
| 105 | + host.as_ptr() as *const c_void, |
| 106 | + tensor.bytes, |
| 107 | + CUDA_MEMCPY_HOST_TO_DEVICE, |
| 108 | + stream, |
| 109 | + ) |
| 110 | + })?; |
| 111 | + res.sync_stream()?; |
| 112 | + |
| 113 | + Ok(tensor) |
| 114 | + } |
| 115 | + |
| 116 | + /// Copy the device buffer back into a contiguous host array. |
| 117 | + fn to_host<D>(&self, res: &Resources, host: &mut ndarray::ArrayRef<T, D>) -> ExampleResult<()> |
| 118 | + where |
| 119 | + D: ndarray::Dimension, |
| 120 | + { |
| 121 | + if !host.is_standard_layout() { |
| 122 | + return Err("host array must be contiguous (row-major)".into()); |
| 123 | + } |
| 124 | + |
| 125 | + let stream = res.get_cuda_stream()?; |
| 126 | + check_cuda(unsafe { |
| 127 | + cudaMemcpyAsync( |
| 128 | + host.as_mut_ptr() as *mut c_void, |
| 129 | + self.data, |
| 130 | + self.bytes, |
| 131 | + CUDA_MEMCPY_DEVICE_TO_HOST, |
| 132 | + stream, |
| 133 | + ) |
| 134 | + })?; |
| 135 | + res.sync_stream()?; |
| 136 | + |
| 137 | + Ok(()) |
| 138 | + } |
| 139 | +} |
| 140 | + |
| 141 | +impl<T: DType> Drop for CudaTensor<T> { |
| 142 | + fn drop(&mut self) { |
| 143 | + if !self.data.is_null() { |
| 144 | + unsafe { cudaFree(self.data) }; |
| 145 | + } |
| 146 | + } |
| 147 | +} |
| 148 | + |
| 149 | +impl<'a, T: DType> IntoDlTensor<'a> for &'a CudaTensor<T> { |
| 150 | + fn into_dl_tensor(self) -> std::result::Result<DLTensorView<'a>, DLPackError> { |
| 151 | + unsafe { |
| 152 | + DLTensorView::from_raw_parts( |
| 153 | + self.data, |
| 154 | + DLDevice { device_type: DLDeviceType::kDLCUDA, device_id: 0 }, |
| 155 | + &self.shape, |
| 156 | + None, |
| 157 | + T::dl_dtype(), |
| 158 | + ) |
| 159 | + } |
| 160 | + } |
| 161 | +} |
| 162 | + |
| 163 | +impl<'a, T: DType> IntoDlTensorMut<'a> for &'a mut CudaTensor<T> { |
| 164 | + fn into_dl_tensor_mut(self) -> std::result::Result<DLTensorViewMut<'a>, DLPackError> { |
| 165 | + unsafe { |
| 166 | + DLTensorViewMut::from_raw_parts( |
| 167 | + self.data, |
| 168 | + DLDevice { device_type: DLDeviceType::kDLCUDA, device_id: 0 }, |
| 169 | + &self.shape, |
| 170 | + None, |
| 171 | + T::dl_dtype(), |
| 172 | + ) |
| 173 | + } |
| 174 | + } |
| 175 | +} |
| 176 | + |
| 177 | +/// Example showing how to index and search data with CAGRA. |
| 178 | +fn cagra_example() -> ExampleResult<()> { |
15 | 179 | let res = Resources::new()?; |
16 | 180 |
|
17 | | - // Create a new random dataset to index |
| 181 | + // Create a new random dataset to index and copy it to the device. |
18 | 182 | let n_datapoints = 65536; |
19 | 183 | let n_features = 512; |
20 | | - let dataset = |
21 | | - ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0)); |
| 184 | + let dataset_host = ndarray::Array::<f32, _>::random( |
| 185 | + (n_datapoints, n_features), |
| 186 | + Uniform::new(0., 1.0).unwrap(), |
| 187 | + ); |
| 188 | + let dataset = CudaTensor::from_host(&res, &dataset_host)?; |
22 | 189 |
|
23 | | - // build the cagra index |
| 190 | + // Build the CAGRA index. |
24 | 191 | let build_params = IndexParams::new()?; |
25 | 192 | let index = Index::build(&res, &build_params, &dataset)?; |
26 | | - println!("Indexed {}x{} datapoints into cagra index", n_datapoints, n_features); |
| 193 | + println!("Indexed {n_datapoints}x{n_features} datapoints into cagra index"); |
27 | 194 |
|
28 | | - // use the first 4 points from the dataset as queries : will test that we get them back |
29 | | - // as their own nearest neighbor |
| 195 | + // Use the first 4 points as queries; each should be its own nearest neighbor. |
30 | 196 | let n_queries = 4; |
31 | | - let queries = dataset.slice(s![0..n_queries, ..]); |
32 | | - |
33 | 197 | let k = 10; |
| 198 | + let queries_host = dataset_host.slice(s![0..n_queries, ..]).to_owned(); |
| 199 | + let queries = CudaTensor::from_host(&res, &queries_host)?; |
34 | 200 |
|
35 | | - // CAGRA search API requires queries and outputs to be on device memory |
36 | | - // copy query data over, and allocate new device memory for the distances/ neighbors |
37 | | - // outputs |
38 | | - let queries = ManagedTensor::from(&queries).to_device(&res)?; |
39 | | - let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k)); |
40 | | - let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res)?; |
41 | | - |
42 | | - let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k)); |
43 | | - let distances = ManagedTensor::from(&distances_host).to_device(&res)?; |
| 201 | + let mut neighbors = CudaTensor::<u32>::alloc(&[n_queries, k])?; |
| 202 | + let mut distances = CudaTensor::<f32>::alloc(&[n_queries, k])?; |
44 | 203 |
|
45 | 204 | let search_params = SearchParams::new()?; |
| 205 | + index.search(&res, &search_params, &queries, &mut neighbors, &mut distances)?; |
46 | 206 |
|
47 | | - index.search(&res, &search_params, &queries, &neighbors, &distances)?; |
48 | | - |
49 | | - // Copy back to host memory |
50 | | - distances.to_host(&res, &mut distances_host)?; |
| 207 | + // Copy the results back to the host. |
| 208 | + let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k)); |
| 209 | + let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k)); |
51 | 210 | neighbors.to_host(&res, &mut neighbors_host)?; |
| 211 | + distances.to_host(&res, &mut distances_host)?; |
52 | 212 |
|
53 | | - // nearest neighbors should be themselves, since queries are from the |
54 | | - // dataset |
55 | | - println!("Neighbors {:?}", neighbors_host); |
56 | | - println!("Distances {:?}", distances_host); |
| 213 | + println!("Neighbors {neighbors_host:?}"); |
| 214 | + println!("Distances {distances_host:?}"); |
57 | 215 | Ok(()) |
58 | 216 | } |
59 | 217 |
|
60 | 218 | fn main() { |
61 | 219 | if let Err(e) = cagra_example() { |
62 | | - println!("Failed to run CAGRA: {:?}", e); |
| 220 | + println!("Failed to run CAGRA: {e:?}"); |
63 | 221 | } |
64 | 222 | } |
0 commit comments