11use pin_project:: { pin_project, pinned_drop} ;
22
33use core:: slice;
4- use std:: { marker:: PhantomData , os:: raw:: c_void, pin:: Pin } ;
4+ use std:: {
5+ fmt:: Debug ,
6+ marker:: { PhantomData , PhantomPinned } ,
7+ mem:: transmute,
8+ os:: raw:: c_void,
9+ pin:: Pin ,
10+ ptr:: { self , NonNull } ,
11+ } ;
512
613use crate :: {
714 datatype:: DataType ,
@@ -138,84 +145,137 @@ impl<'tensor> Tensor<'tensor> {
138145 }
139146}
140147
141- /// Safe proxy to ffi::DLManagedTensor which is self-referential by design.
142- /// See [DLManagedTensor](https://dmlc.github.io/dlpack/latest/c_api.html#_CPPv415DLManagedTensor)
148+ /// A typed ManagerContext type that is `!Unpin` i.e. pinnable for safety since it holds a pointer to the underlying DLTensor.
143149#[ derive( Debug ) ]
144- #[ pin_project( PinnedDrop ) ]
145- pub struct ManagedTensorProxy < ' ctx > {
146- pub dl_tensor : DLTensor ,
147- #[ pin]
148- /// Holds the underlying DLTensor.
149- pub manager_ctx : * mut c_void ,
150- #[ pin]
151- pub deleter : Option < unsafe extern "C" fn ( * mut DLManagedTensor ) > ,
152- _marker : PhantomData < & ' ctx ( ) > , // covariant wrt 'ctx
150+ #[ repr( C ) ]
151+ pub struct ManagerContext < C > {
152+ pub ptr : Option < NonNull < * mut c_void > > ,
153+ ty : PhantomData < C > ,
154+ _pin : PhantomPinned ,
153155}
154156
155- impl < ' ctx > From < DLManagedTensor > for ManagedTensorProxy < ' ctx > {
156- fn from ( dlmt : DLManagedTensor ) -> Self {
157- ManagedTensorProxy {
158- dl_tensor : dlmt. dl_tensor ,
159- manager_ctx : dlmt. manager_ctx ,
160- deleter : dlmt. deleter ,
161- _marker : PhantomData ,
157+ impl < C > ManagerContext < C > {
158+ pub fn new ( ptr : Option < NonNull < * mut c_void > > ) -> Self {
159+ Self {
160+ ptr,
161+ ty : PhantomData ,
162+ _pin : PhantomPinned ,
162163 }
163164 }
164165}
165166
166- impl < ' ctx > From < ManagedTensorProxy < ' ctx > > for DLManagedTensor {
167- fn from ( pmt : ManagedTensorProxy < ' ctx > ) -> Self {
168- DLManagedTensor {
169- dl_tensor : pmt. dl_tensor ,
170- manager_ctx : pmt. manager_ctx ,
171- deleter : pmt. deleter ,
172- }
173- }
167+ /// Safe proxy to ffi::DLManagedTensor which is self-referential by design.
168+ /// See [DLManagedTensor](https://dmlc.github.io/dlpack/latest/c_api.html#_CPPv415DLManagedTensor)
169+ #[ pin_project( PinnedDrop ) ]
170+ #[ repr( C ) ]
171+ pub struct ManagedTensorProxy < C > {
172+ /// Holds the underlying tensor.
173+ pub dl_tensor : DLTensor ,
174+ /// The context holding the underlying DLTensor.
175+ #[ pin]
176+ pub manager_ctx : ManagerContext < C > , // safe typed wrapper for *mut c_void which is !Unpin i.e. pinnable
177+ /// Deleter function pointer.
178+ // TODO: should this be `#[pin]`?
179+ pub deleter : Option < fn ( & mut ManagedTensor < C > ) > ,
174180}
175181
176- impl < ' ctx > From < Pin < & mut ManagedTensorProxy < ' ctx > > > for DLManagedTensor {
177- fn from ( pmt : Pin < & mut ManagedTensorProxy < ' ctx > > ) -> Self {
178- DLManagedTensor {
179- dl_tensor : pmt. dl_tensor ,
180- manager_ctx : pmt. manager_ctx ,
181- deleter : pmt. deleter ,
182- }
182+ impl < C : Debug > Debug for ManagedTensorProxy < C > {
183+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
184+ f. debug_struct ( "ManagedTensorProxy" )
185+ . field ( "dl_tensor" , & self . dl_tensor )
186+ . field ( "manager_ctx" , & self . manager_ctx )
187+ . finish ( )
183188 }
184189}
185190
186- impl < ' ctx > ManagedTensorProxy < ' ctx > {
191+ impl < C > ManagedTensorProxy < C > {
187192 pub fn dl_tensor ( & self ) -> DLTensor {
188193 self . dl_tensor
189194 }
190195
191- pub fn manager_ctx ( self : Pin < & mut Self > ) -> * mut c_void {
192- let this = self . project ( ) ;
193- * this. manager_ctx . get_mut ( )
196+ pub fn manager_ctx ( self : Pin < & mut Self > ) -> Option < NonNull < * mut c_void > > {
197+ let mut this = self . project ( ) ;
198+ this. manager_ctx . as_mut ( ) . ptr
194199 }
195200
196- pub fn set_manager_ctx ( self : Pin < & mut Self > , manager_ctx : * mut c_void ) {
201+ pub fn set_manager_ctx ( self : Pin < & mut Self > , manager_ctx : NonNull < * mut c_void > ) {
197202 let mut this = self . project ( ) ;
198- * this. manager_ctx = manager_ctx;
203+ let new = ManagerContext :: new ( Some ( manager_ctx) ) ;
204+ this. manager_ctx . set ( new) ;
205+ }
206+ }
207+
208+ impl < C > From < DLManagedTensor > for ManagedTensorProxy < C > {
209+ fn from ( mut dlmt : DLManagedTensor ) -> Self {
210+ let ptr: Option < NonNull < * mut c_void > > = if dlmt. manager_ctx . is_null ( ) {
211+ None
212+ } else {
213+ unsafe { Some ( NonNull :: new_unchecked ( & mut dlmt. manager_ctx as * mut _ ) ) }
214+ } ;
215+ let manager_ctx = ManagerContext :: new ( ptr) ;
216+ let deleter = dlmt. deleter . take ( ) . map ( |del| unsafe {
217+ transmute :: < unsafe extern "C" fn ( * mut DLManagedTensor ) , fn ( & mut ManagedTensor < C > ) > ( del)
218+ } ) ;
219+ ManagedTensorProxy {
220+ dl_tensor : dlmt. dl_tensor ,
221+ manager_ctx,
222+ deleter,
223+ }
224+ }
225+ }
226+
227+ impl < C > From < ManagedTensorProxy < C > > for DLManagedTensor {
228+ fn from ( pmt : ManagedTensorProxy < C > ) -> Self {
229+ let dl_tensor = pmt. dl_tensor ;
230+ let manager_ctx = match pmt. manager_ctx . ptr {
231+ None => ptr:: null_mut ( ) ,
232+ Some ( nnptr) => unsafe { * nnptr. as_ptr ( ) } ,
233+ } ;
234+ let deleter = unsafe {
235+ pmt. deleter . map ( |del_fn| {
236+ transmute :: < fn ( & mut ManagedTensor < C > ) , unsafe extern "C" fn ( * mut DLManagedTensor ) > (
237+ del_fn,
238+ )
239+ } )
240+ } ;
241+ DLManagedTensor {
242+ dl_tensor,
243+ manager_ctx,
244+ deleter,
245+ }
199246 }
247+ }
200248
201- pub fn deleter (
202- self : Pin < & mut Self > ,
203- ) -> Option < unsafe extern "C" fn ( self_ : * mut DLManagedTensor ) > {
204- let this = self . project ( ) ;
205- * this. deleter . get_mut ( )
249+ impl < C > From < Pin < & mut ManagedTensorProxy < C > > > for DLManagedTensor {
250+ fn from ( pmt : Pin < & mut ManagedTensorProxy < C > > ) -> Self {
251+ let dl_tensor = pmt. dl_tensor ;
252+ let manager_ctx = match pmt. manager_ctx . ptr {
253+ None => ptr:: null_mut ( ) ,
254+ Some ( nnptr) => unsafe { * nnptr. as_ptr ( ) } ,
255+ } ;
256+ let deleter = unsafe {
257+ pmt. deleter . map ( |del_fn| {
258+ transmute :: < fn ( & mut ManagedTensor < C > ) , unsafe extern "C" fn ( * mut DLManagedTensor ) > (
259+ del_fn,
260+ )
261+ } )
262+ } ;
263+ DLManagedTensor {
264+ dl_tensor,
265+ manager_ctx,
266+ deleter,
267+ }
206268 }
207269}
208270
209271#[ allow( clippy:: needless_lifetimes) ]
210272#[ pinned_drop]
211- impl < ' ctx > PinnedDrop for ManagedTensorProxy < ' ctx > {
273+ impl < C > PinnedDrop for ManagedTensorProxy < C > {
212274 fn drop ( mut self : Pin < & mut Self > ) {
213275 let mut dlm: DLManagedTensor = self . as_mut ( ) . into ( ) ;
214- if let Some ( fptr) = self . deleter ( ) {
276+ if let Some ( fptr) = self . deleter {
215277 unsafe {
216- let cfptr = std:: mem:: transmute :: < * const ( ) , unsafe fn ( * mut DLManagedTensor ) > (
217- fptr as * const ( ) ,
218- ) ;
278+ let cfptr = transmute :: < fn ( & mut ManagedTensor < C > ) , fn ( * mut DLManagedTensor ) > ( fptr) ;
219279 cfptr ( & mut dlm as * mut _ ) ;
220280 } ;
221281 }
@@ -227,18 +287,35 @@ impl<'ctx> PinnedDrop for ManagedTensorProxy<'ctx> {
227287/// See [DLManagedTensor](https://dmlc.github.io/dlpack/latest/c_api.html#_CPPv415DLManagedTensor)
228288#[ derive( Debug ) ]
229289#[ repr( transparent) ]
230- pub struct ManagedTensor < ' tensor , ' ctx : ' tensor > {
231- pub inner : ManagedTensorProxy < ' ctx > ,
232- _marker : PhantomData < fn ( & ' tensor ( ) ) -> & ' tensor ( ) > ,
290+ pub struct ManagedTensor < ' tensor , C : ' tensor > {
291+ pub inner : ManagedTensorProxy < C > ,
292+ _marker : PhantomData < fn ( & ' tensor ( ) ) -> & ' tensor ( ) > , // invariant wrt 'tensor
293+ }
294+
295+ impl < ' tensor , C > From < DLManagedTensor > for ManagedTensor < ' tensor , C > {
296+ fn from ( dlm : DLManagedTensor ) -> Self {
297+ let proxy: ManagedTensorProxy < C > = dlm. into ( ) ;
298+ ManagedTensor {
299+ inner : proxy,
300+ _marker : PhantomData ,
301+ }
302+ }
233303}
234304
235- impl < ' tensor , ' ctx > ManagedTensor < ' tensor , ' ctx > {
236- pub fn new ( tensor : Tensor < ' tensor > , manager_ctx : * mut c_void ) -> Self {
305+ impl < ' tensor , C > From < ManagedTensor < ' tensor , C > > for DLManagedTensor {
306+ fn from ( mt : ManagedTensor < ' tensor , C > ) -> Self {
307+ mt. inner . into ( )
308+ }
309+ }
310+
311+ impl < ' tensor , C : ' tensor > ManagedTensor < ' tensor , C > {
312+ /// Contructor.
313+ pub fn new ( tensor : Tensor < ' tensor > , manager_ctx : Option < NonNull < * mut c_void > > ) -> Self {
314+ let manager_ctx = ManagerContext :: new ( manager_ctx) ;
237315 let inner = ManagedTensorProxy {
238316 dl_tensor : tensor. into_inner ( ) ,
239317 manager_ctx,
240318 deleter : None ,
241- _marker : PhantomData ,
242319 } ;
243320
244321 ManagedTensor {
@@ -247,8 +324,8 @@ impl<'tensor, 'ctx> ManagedTensor<'tensor, 'ctx> {
247324 }
248325 }
249326
250- /// Sets a deleter function pointer
251- pub fn set_deleter ( & mut self , deleter : unsafe extern "C" fn ( * mut DLManagedTensor ) ) {
327+ /// Sets a deleter function pointer.
328+ pub fn set_deleter ( & mut self , deleter : fn ( & mut ManagedTensor < C > ) ) {
252329 self . inner . deleter = Some ( deleter) ;
253330 }
254331
@@ -261,9 +338,8 @@ impl<'tensor, 'ctx> ManagedTensor<'tensor, 'ctx> {
261338 /// Returns a ManagedTensor instances from a raw pointer to DLManagedTensor.
262339 pub unsafe fn from_raw ( ptr : * mut DLManagedTensor ) -> Self {
263340 debug_assert ! ( !ptr. is_null( ) ) ;
264- let proxy = ( * ptr) . into ( ) ;
265341 ManagedTensor {
266- inner : proxy ,
342+ inner : ( * ptr ) . into ( ) ,
267343 _marker : PhantomData ,
268344 }
269345 }
0 commit comments