Skip to content

Commit b1c414c

Browse files
author
Ehsan M. Kermani
committed
Make manager context pinnable with safer and more ergonomic bi-directional conversions
1 parent b088453 commit b1c414c

File tree

3 files changed

+160
-88
lines changed

3 files changed

+160
-88
lines changed

examples/sample/src/main.rs

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
#![allow(clippy::drop_copy)]
1+
#![allow(clippy::drop_ref)]
22

3-
use std::{ffi::c_void, mem, ptr};
3+
use std::ffi::c_void;
44

5-
use dlpackrs::{ffi, DataType, Device, ManagedTensor as DLManagedTensor, Tensor as DLTensor};
5+
use dlpackrs::{DataType, Device, ManagedTensor as DLManagedTensor, Tensor as DLTensor};
66
use ndarray::{Array, ArrayD, RawArrayViewMut};
77

88
#[derive(Debug)]
@@ -32,19 +32,20 @@ impl<'tensor> From<&'tensor mut Tensor<'tensor>> for ArrayD<f32> {
3232
}
3333
}
3434

35+
// The context holds DLManagedTensor
3536
#[derive(Debug)]
36-
pub struct ManagedTensor<'tensor, 'ctx>(DLManagedTensor<'tensor, 'ctx>);
37+
pub struct ManagedContext<'tensor, C>(DLManagedTensor<'tensor, C>);
3738

38-
impl<'tensor, 'ctx> From<&'tensor mut ArrayD<f32>> for ManagedTensor<'tensor, 'ctx> {
39+
impl<'tensor, C> From<&'tensor mut ArrayD<f32>> for ManagedContext<'tensor, C> {
3940
fn from(t: &'tensor mut ArrayD<f32>) -> Self {
4041
let dlt: Tensor<'tensor> = Tensor::from(t);
41-
let inner = DLManagedTensor::new(dlt.0, ptr::null_mut());
42-
ManagedTensor(inner)
42+
let inner = DLManagedTensor::new(dlt.0, None);
43+
ManagedContext(inner)
4344
}
4445
}
4546

46-
impl<'tensor, 'ctx> From<&mut ManagedTensor<'tensor, 'ctx>> for ArrayD<f32> {
47-
fn from(mt: &mut ManagedTensor<'tensor, 'ctx>) -> Self {
47+
impl<'tensor, C> From<&mut ManagedContext<'tensor, C>> for ArrayD<f32> {
48+
fn from(mt: &mut ManagedContext<'tensor, C>) -> Self {
4849
let dlt: DLTensor = mt.0.inner.dl_tensor.into();
4950
unsafe {
5051
let arr = RawArrayViewMut::from_shape_ptr(dlt.shape().unwrap(), dlt.data() as *mut f32);
@@ -70,16 +71,11 @@ fn main() {
7071
let pong = ArrayD::from(&mut tensor);
7172
println!("pong {:?}", pong);
7273
assert!(pong.into_dyn().abs_diff_eq(&ping, 1e-8f32));
73-
// TODO: use dummy ctx holding the managed_tensor
74-
let mut managed_tensor = ManagedTensor::from(&mut ping);
74+
let mut managed_tensor: ManagedContext<f32> = (&mut ping).into();
7575
println!("managed tensor {:?}", managed_tensor);
76-
let deleter = unsafe {
77-
mem::transmute::<fn(&mut ManagedTensor), unsafe extern "C" fn(*mut ffi::DLManagedTensor)>(
78-
|mt| {
79-
println!("deleter is called!");
80-
drop(mt as *mut _);
81-
},
82-
)
76+
let deleter: fn(&mut DLManagedTensor<f32>) = |managed_tensor| {
77+
println!("manager tensor deleter is called");
78+
drop(managed_tensor);
8379
};
8480
managed_tensor.0.set_deleter(deleter);
8581
println!("managed tensor with deleter {:?}", managed_tensor);

src/lib.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
//!
2626
//! In this case, `ManagedTensor` is built from `ManagedTensorProxy` which is a safe proxy for the unsafe `ffi::DLManagedTensor`.
2727
//!
28-
//! ### Non-Memory Manged Tensor (View)
28+
//! ### Plain Not-Memory-Managed Tensor
2929
//!
3030
//! In this case, the (invariant) Rust wrapper `Tensor` can be used or if needed the unsafe `ffi::DLTensor`.
3131
//!
@@ -35,21 +35,21 @@
3535
//!
3636
//! When ownership is concerned, one can use the `ManagedTensor`. Here is an example on how the bi-directional conversion
3737
//!
38-
//! <div align="center">ndarray::ArrayD <--> ManagedTensor</div>
38+
//! <div align="center">ndarray::ArrayD <---> ManagedTensor</div>
3939
//!
4040
//! is done at zero-cost.
4141
//!
4242
//! ```no_run
43-
//! impl<'tensor, 'ctx> From<&'tensor mut ArrayD<f32>> for ManagedTensor<'tensor, 'ctx> {
43+
//! impl<'tensor, C> From<&'tensor mut ArrayD<f32>> for ManagedContext<'tensor, C> {
4444
//! fn from(t: &'tensor mut ArrayD<f32>) -> Self {
4545
//! let dlt: Tensor<'tensor> = Tensor::from(t);
46-
//! let inner = DLManagedTensor::new(dlt.0, ptr::null_mut());
47-
//! ManagedTensor(inner)
46+
//! let inner = DLManagedTensor::new(dlt.0, None);
47+
//! ManagedContext(inner)
4848
//! }
4949
//! }
5050
//!
51-
//! impl<'tensor, 'ctx> From<&mut ManagedTensor<'tensor, 'ctx>> for ArrayD<f32> {
52-
//! fn from(mt: &mut ManagedTensor<'tensor, 'ctx>) -> Self {
51+
//! impl<'tensor, C> From<&mut ManagedContext<'tensor, C>> for ArrayD<f32> {
52+
//! fn from(mt: &mut ManagedContext<'tensor, C>) -> Self {
5353
//! let dlt: DLTensor = mt.0.inner.dl_tensor.into();
5454
//! unsafe {
5555
//! let arr = RawArrayViewMut::from_shape_ptr(dlt.shape().unwrap(), dlt.data() as *mut f32);
@@ -63,7 +63,7 @@
6363
//!
6464
//! And when ownership is not concerned, one can use `Tensor` as a view. Here is an example on how the bi-directional converion
6565
//!
66-
//! <div align="center">ndarray::ArrayD <--> Tensor</div>
66+
//! <div align="center">ndarray::ArrayD <---> Tensor</div>
6767
//!
6868
//! is done at zero-cost.
6969
//!
@@ -104,12 +104,12 @@ pub mod ffi {
104104

105105
pub mod datatype;
106106
pub mod device;
107-
mod errors;
107+
pub mod errors;
108108
pub mod tensor;
109109

110110
pub use datatype::{DataType, DataTypeCode};
111111
pub use device::{Device, DeviceType};
112-
pub use tensor::{ManagedTensor, ManagedTensorProxy, Tensor};
112+
pub use tensor::{ManagedTensor, ManagedTensorProxy, ManagerContext, Tensor};
113113

114114
pub fn version() -> u32 {
115115
ffi::DLPACK_VERSION

src/tensor.rs

Lines changed: 136 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
use pin_project::{pin_project, pinned_drop};
22

33
use core::slice;
4-
use std::{marker::PhantomData, os::raw::c_void, pin::Pin};
4+
use std::{
5+
fmt::Debug,
6+
marker::{PhantomData, PhantomPinned},
7+
mem::transmute,
8+
os::raw::c_void,
9+
pin::Pin,
10+
ptr::{self, NonNull},
11+
};
512

613
use crate::{
714
datatype::DataType,
@@ -138,84 +145,137 @@ impl<'tensor> Tensor<'tensor> {
138145
}
139146
}
140147

141-
/// Safe proxy to ffi::DLManagedTensor which is self-referential by design.
142-
/// See [DLManagedTensor](https://dmlc.github.io/dlpack/latest/c_api.html#_CPPv415DLManagedTensor)
148+
/// A typed ManagerContext type that is `!Unpin` i.e. pinnable for safety since it holds a pointer to the underlying DLTensor.
143149
#[derive(Debug)]
144-
#[pin_project(PinnedDrop)]
145-
pub struct ManagedTensorProxy<'ctx> {
146-
pub dl_tensor: DLTensor,
147-
#[pin]
148-
/// Holds the underlying DLTensor.
149-
pub manager_ctx: *mut c_void,
150-
#[pin]
151-
pub deleter: Option<unsafe extern "C" fn(*mut DLManagedTensor)>,
152-
_marker: PhantomData<&'ctx ()>, // covariant wrt 'ctx
150+
#[repr(C)]
151+
pub struct ManagerContext<C> {
152+
pub ptr: Option<NonNull<*mut c_void>>,
153+
ty: PhantomData<C>,
154+
_pin: PhantomPinned,
153155
}
154156

155-
impl<'ctx> From<DLManagedTensor> for ManagedTensorProxy<'ctx> {
156-
fn from(dlmt: DLManagedTensor) -> Self {
157-
ManagedTensorProxy {
158-
dl_tensor: dlmt.dl_tensor,
159-
manager_ctx: dlmt.manager_ctx,
160-
deleter: dlmt.deleter,
161-
_marker: PhantomData,
157+
impl<C> ManagerContext<C> {
158+
pub fn new(ptr: Option<NonNull<*mut c_void>>) -> Self {
159+
Self {
160+
ptr,
161+
ty: PhantomData,
162+
_pin: PhantomPinned,
162163
}
163164
}
164165
}
165166

166-
impl<'ctx> From<ManagedTensorProxy<'ctx>> for DLManagedTensor {
167-
fn from(pmt: ManagedTensorProxy<'ctx>) -> Self {
168-
DLManagedTensor {
169-
dl_tensor: pmt.dl_tensor,
170-
manager_ctx: pmt.manager_ctx,
171-
deleter: pmt.deleter,
172-
}
173-
}
167+
/// Safe proxy to ffi::DLManagedTensor which is self-referential by design.
168+
/// See [DLManagedTensor](https://dmlc.github.io/dlpack/latest/c_api.html#_CPPv415DLManagedTensor)
169+
#[pin_project(PinnedDrop)]
170+
#[repr(C)]
171+
pub struct ManagedTensorProxy<C> {
172+
/// Holds the underlying tensor.
173+
pub dl_tensor: DLTensor,
174+
/// The context holding the underlying DLTensor.
175+
#[pin]
176+
pub manager_ctx: ManagerContext<C>, // safe typed wrapper for *mut c_void which is !Unpin i.e. pinnable
177+
/// Deleter function pointer.
178+
// TODO: should this be `#[pin]`?
179+
pub deleter: Option<fn(&mut ManagedTensor<C>)>,
174180
}
175181

176-
impl<'ctx> From<Pin<&mut ManagedTensorProxy<'ctx>>> for DLManagedTensor {
177-
fn from(pmt: Pin<&mut ManagedTensorProxy<'ctx>>) -> Self {
178-
DLManagedTensor {
179-
dl_tensor: pmt.dl_tensor,
180-
manager_ctx: pmt.manager_ctx,
181-
deleter: pmt.deleter,
182-
}
182+
impl<C: Debug> Debug for ManagedTensorProxy<C> {
183+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184+
f.debug_struct("ManagedTensorProxy")
185+
.field("dl_tensor", &self.dl_tensor)
186+
.field("manager_ctx", &self.manager_ctx)
187+
.finish()
183188
}
184189
}
185190

186-
impl<'ctx> ManagedTensorProxy<'ctx> {
191+
impl<C> ManagedTensorProxy<C> {
187192
pub fn dl_tensor(&self) -> DLTensor {
188193
self.dl_tensor
189194
}
190195

191-
pub fn manager_ctx(self: Pin<&mut Self>) -> *mut c_void {
192-
let this = self.project();
193-
*this.manager_ctx.get_mut()
196+
pub fn manager_ctx(self: Pin<&mut Self>) -> Option<NonNull<*mut c_void>> {
197+
let mut this = self.project();
198+
this.manager_ctx.as_mut().ptr
194199
}
195200

196-
pub fn set_manager_ctx(self: Pin<&mut Self>, manager_ctx: *mut c_void) {
201+
pub fn set_manager_ctx(self: Pin<&mut Self>, manager_ctx: NonNull<*mut c_void>) {
197202
let mut this = self.project();
198-
*this.manager_ctx = manager_ctx;
203+
let new = ManagerContext::new(Some(manager_ctx));
204+
this.manager_ctx.set(new);
205+
}
206+
}
207+
208+
impl<C> From<DLManagedTensor> for ManagedTensorProxy<C> {
209+
fn from(mut dlmt: DLManagedTensor) -> Self {
210+
let ptr: Option<NonNull<*mut c_void>> = if dlmt.manager_ctx.is_null() {
211+
None
212+
} else {
213+
unsafe { Some(NonNull::new_unchecked(&mut dlmt.manager_ctx as *mut _)) }
214+
};
215+
let manager_ctx = ManagerContext::new(ptr);
216+
let deleter = dlmt.deleter.take().map(|del| unsafe {
217+
transmute::<unsafe extern "C" fn(*mut DLManagedTensor), fn(&mut ManagedTensor<C>)>(del)
218+
});
219+
ManagedTensorProxy {
220+
dl_tensor: dlmt.dl_tensor,
221+
manager_ctx,
222+
deleter,
223+
}
224+
}
225+
}
226+
227+
impl<C> From<ManagedTensorProxy<C>> for DLManagedTensor {
228+
fn from(pmt: ManagedTensorProxy<C>) -> Self {
229+
let dl_tensor = pmt.dl_tensor;
230+
let manager_ctx = match pmt.manager_ctx.ptr {
231+
None => ptr::null_mut(),
232+
Some(nnptr) => unsafe { *nnptr.as_ptr() },
233+
};
234+
let deleter = unsafe {
235+
pmt.deleter.map(|del_fn| {
236+
transmute::<fn(&mut ManagedTensor<C>), unsafe extern "C" fn(*mut DLManagedTensor)>(
237+
del_fn,
238+
)
239+
})
240+
};
241+
DLManagedTensor {
242+
dl_tensor,
243+
manager_ctx,
244+
deleter,
245+
}
199246
}
247+
}
200248

201-
pub fn deleter(
202-
self: Pin<&mut Self>,
203-
) -> Option<unsafe extern "C" fn(self_: *mut DLManagedTensor)> {
204-
let this = self.project();
205-
*this.deleter.get_mut()
249+
impl<C> From<Pin<&mut ManagedTensorProxy<C>>> for DLManagedTensor {
250+
fn from(pmt: Pin<&mut ManagedTensorProxy<C>>) -> Self {
251+
let dl_tensor = pmt.dl_tensor;
252+
let manager_ctx = match pmt.manager_ctx.ptr {
253+
None => ptr::null_mut(),
254+
Some(nnptr) => unsafe { *nnptr.as_ptr() },
255+
};
256+
let deleter = unsafe {
257+
pmt.deleter.map(|del_fn| {
258+
transmute::<fn(&mut ManagedTensor<C>), unsafe extern "C" fn(*mut DLManagedTensor)>(
259+
del_fn,
260+
)
261+
})
262+
};
263+
DLManagedTensor {
264+
dl_tensor,
265+
manager_ctx,
266+
deleter,
267+
}
206268
}
207269
}
208270

209271
#[allow(clippy::needless_lifetimes)]
210272
#[pinned_drop]
211-
impl<'ctx> PinnedDrop for ManagedTensorProxy<'ctx> {
273+
impl<C> PinnedDrop for ManagedTensorProxy<C> {
212274
fn drop(mut self: Pin<&mut Self>) {
213275
let mut dlm: DLManagedTensor = self.as_mut().into();
214-
if let Some(fptr) = self.deleter() {
276+
if let Some(fptr) = self.deleter {
215277
unsafe {
216-
let cfptr = std::mem::transmute::<*const (), unsafe fn(*mut DLManagedTensor)>(
217-
fptr as *const (),
218-
);
278+
let cfptr = transmute::<fn(&mut ManagedTensor<C>), fn(*mut DLManagedTensor)>(fptr);
219279
cfptr(&mut dlm as *mut _);
220280
};
221281
}
@@ -227,18 +287,35 @@ impl<'ctx> PinnedDrop for ManagedTensorProxy<'ctx> {
227287
/// See [DLManagedTensor](https://dmlc.github.io/dlpack/latest/c_api.html#_CPPv415DLManagedTensor)
228288
#[derive(Debug)]
229289
#[repr(transparent)]
230-
pub struct ManagedTensor<'tensor, 'ctx: 'tensor> {
231-
pub inner: ManagedTensorProxy<'ctx>,
232-
_marker: PhantomData<fn(&'tensor ()) -> &'tensor ()>,
290+
pub struct ManagedTensor<'tensor, C: 'tensor> {
291+
pub inner: ManagedTensorProxy<C>,
292+
_marker: PhantomData<fn(&'tensor ()) -> &'tensor ()>, // invariant wrt 'tensor
293+
}
294+
295+
impl<'tensor, C> From<DLManagedTensor> for ManagedTensor<'tensor, C> {
296+
fn from(dlm: DLManagedTensor) -> Self {
297+
let proxy: ManagedTensorProxy<C> = dlm.into();
298+
ManagedTensor {
299+
inner: proxy,
300+
_marker: PhantomData,
301+
}
302+
}
233303
}
234304

235-
impl<'tensor, 'ctx> ManagedTensor<'tensor, 'ctx> {
236-
pub fn new(tensor: Tensor<'tensor>, manager_ctx: *mut c_void) -> Self {
305+
impl<'tensor, C> From<ManagedTensor<'tensor, C>> for DLManagedTensor {
306+
fn from(mt: ManagedTensor<'tensor, C>) -> Self {
307+
mt.inner.into()
308+
}
309+
}
310+
311+
impl<'tensor, C: 'tensor> ManagedTensor<'tensor, C> {
312+
/// Contructor.
313+
pub fn new(tensor: Tensor<'tensor>, manager_ctx: Option<NonNull<*mut c_void>>) -> Self {
314+
let manager_ctx = ManagerContext::new(manager_ctx);
237315
let inner = ManagedTensorProxy {
238316
dl_tensor: tensor.into_inner(),
239317
manager_ctx,
240318
deleter: None,
241-
_marker: PhantomData,
242319
};
243320

244321
ManagedTensor {
@@ -247,8 +324,8 @@ impl<'tensor, 'ctx> ManagedTensor<'tensor, 'ctx> {
247324
}
248325
}
249326

250-
/// Sets a deleter function pointer
251-
pub fn set_deleter(&mut self, deleter: unsafe extern "C" fn(*mut DLManagedTensor)) {
327+
/// Sets a deleter function pointer.
328+
pub fn set_deleter(&mut self, deleter: fn(&mut ManagedTensor<C>)) {
252329
self.inner.deleter = Some(deleter);
253330
}
254331

@@ -261,9 +338,8 @@ impl<'tensor, 'ctx> ManagedTensor<'tensor, 'ctx> {
261338
/// Returns a ManagedTensor instances from a raw pointer to DLManagedTensor.
262339
pub unsafe fn from_raw(ptr: *mut DLManagedTensor) -> Self {
263340
debug_assert!(!ptr.is_null());
264-
let proxy = (*ptr).into();
265341
ManagedTensor {
266-
inner: proxy,
342+
inner: (*ptr).into(),
267343
_marker: PhantomData,
268344
}
269345
}

0 commit comments

Comments
 (0)