@@ -30,9 +30,10 @@ pub mod ort_tensor;
30
30
pub use ort_owned_tensor:: { DynOrtTensor , OrtOwnedTensor } ;
31
31
pub use ort_tensor:: OrtTensor ;
32
32
33
- use crate :: { OrtError , Result } ;
33
+ use crate :: tensor:: ort_owned_tensor:: TensorPointerHolder ;
34
+ use crate :: { error:: call_ort, OrtError , Result } ;
34
35
use onnxruntime_sys:: { self as sys, OnnxEnumInt } ;
35
- use std:: { fmt, ptr} ;
36
+ use std:: { convert :: TryInto as _ , ffi , fmt, ptr, rc , result , string } ;
36
37
37
38
// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
38
39
// FIXME: Add tests to cover the commented out types
@@ -188,14 +189,41 @@ pub trait TensorDataToType: Sized + fmt::Debug {
188
189
fn tensor_element_data_type ( ) -> TensorElementDataType ;
189
190
190
191
/// Extract an `ArrayView` from the ort-owned tensor.
191
- fn extract_array < ' t , D > (
192
+ fn extract_data < ' t , D > (
192
193
shape : D ,
193
- tensor : * mut sys:: OrtValue ,
194
- ) -> Result < ndarray:: ArrayView < ' t , Self , D > >
194
+ tensor_element_len : usize ,
195
+ tensor_ptr : rc:: Rc < TensorPointerHolder > ,
196
+ ) -> Result < TensorData < ' t , Self , D > >
195
197
where
196
198
D : ndarray:: Dimension ;
197
199
}
198
200
201
+ /// Represents the possible ways tensor data can be accessed.
202
+ ///
203
+ /// This should only be used internally.
204
+ #[ derive( Debug ) ]
205
+ pub enum TensorData < ' t , T , D >
206
+ where
207
+ D : ndarray:: Dimension ,
208
+ {
209
+ /// Data resides in ort's tensor, in which case the 't lifetime is what makes this valid.
210
+ /// This is used for data types whose in-memory form from ort is compatible with Rust's, like
211
+ /// primitive numeric types.
212
+ TensorPtr {
213
+ /// The pointer ort produced. Kept alive so that `array_view` is valid.
214
+ ptr : rc:: Rc < TensorPointerHolder > ,
215
+ /// A view into `ptr`
216
+ array_view : ndarray:: ArrayView < ' t , T , D > ,
217
+ } ,
218
+ /// String data is output differently by ort, and of course is also variable size, so it cannot
219
+ /// use the same simple pointer representation.
220
+ // Since 't outlives this struct, the 't lifetime is more than we need, but no harm done.
221
+ Strings {
222
+ /// Owned Strings copied out of ort's output
223
+ strings : ndarray:: Array < T , D > ,
224
+ } ,
225
+ }
226
+
199
227
/// Implements `OwnedTensorDataToType` for primitives, which can use `GetTensorMutableData`
200
228
macro_rules! impl_prim_type_from_ort_trait {
201
229
( $type_: ty, $variant: ident) => {
@@ -204,14 +232,20 @@ macro_rules! impl_prim_type_from_ort_trait {
204
232
TensorElementDataType :: $variant
205
233
}
206
234
207
- fn extract_array <' t, D >(
235
+ fn extract_data <' t, D >(
208
236
shape: D ,
209
- tensor: * mut sys:: OrtValue ,
210
- ) -> Result <ndarray:: ArrayView <' t, Self , D >>
237
+ _tensor_element_len: usize ,
238
+ tensor_ptr: rc:: Rc <TensorPointerHolder >,
239
+ ) -> Result <TensorData <' t, Self , D >>
211
240
where
212
241
D : ndarray:: Dimension ,
213
242
{
214
- extract_primitive_array( shape, tensor)
243
+ extract_primitive_array( shape, tensor_ptr. tensor_ptr) . map( |v| {
244
+ TensorData :: TensorPtr {
245
+ ptr: tensor_ptr,
246
+ array_view: v,
247
+ }
248
+ } )
215
249
}
216
250
}
217
251
} ;
@@ -255,3 +289,70 @@ impl_prim_type_from_ort_trait!(i64, Int64);
255
289
impl_prim_type_from_ort_trait ! ( f64 , Double ) ;
256
290
impl_prim_type_from_ort_trait ! ( u32 , Uint32 ) ;
257
291
impl_prim_type_from_ort_trait ! ( u64 , Uint64 ) ;
292
+
293
+ impl TensorDataToType for String {
294
+ fn tensor_element_data_type ( ) -> TensorElementDataType {
295
+ TensorElementDataType :: String
296
+ }
297
+
298
+ fn extract_data < ' t , D : ndarray:: Dimension > (
299
+ shape : D ,
300
+ tensor_element_len : usize ,
301
+ tensor_ptr : rc:: Rc < TensorPointerHolder > ,
302
+ ) -> Result < TensorData < ' t , Self , D > > {
303
+ // Total length of string data, not including \0 suffix
304
+ let mut total_length = 0_u64 ;
305
+ unsafe {
306
+ call_ort ( |ort| {
307
+ ort. GetStringTensorDataLength . unwrap ( ) ( tensor_ptr. tensor_ptr , & mut total_length)
308
+ } )
309
+ . map_err ( OrtError :: GetStringTensorDataLength ) ?
310
+ }
311
+
312
+ // In the JNI impl of this, tensor_element_len was included in addition to total_length,
313
+ // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
314
+ // don't seem to be written to in practice either.
315
+ // If the string data actually did go farther, it would panic below when using the offset
316
+ // data to get slices for each string.
317
+ let mut string_contents = vec ! [ 0_u8 ; total_length as usize ] ;
318
+ // one extra slot so that the total length can go in the last one, making all per-string
319
+ // length calculations easy
320
+ let mut offsets = vec ! [ 0_u64 ; tensor_element_len as usize + 1 ] ;
321
+
322
+ unsafe {
323
+ call_ort ( |ort| {
324
+ ort. GetStringTensorContent . unwrap ( ) (
325
+ tensor_ptr. tensor_ptr ,
326
+ string_contents. as_mut_ptr ( ) as * mut ffi:: c_void ,
327
+ total_length,
328
+ offsets. as_mut_ptr ( ) ,
329
+ tensor_element_len as u64 ,
330
+ )
331
+ } )
332
+ . map_err ( OrtError :: GetStringTensorContent ) ?
333
+ }
334
+
335
+ // final offset = overall length so that per-string length calculations work for the last
336
+ // string
337
+ debug_assert_eq ! ( 0 , offsets[ tensor_element_len] ) ;
338
+ offsets[ tensor_element_len] = total_length;
339
+
340
+ let strings = offsets
341
+ // offsets has 1 extra offset past the end so that all windows work
342
+ . windows ( 2 )
343
+ . map ( |w| {
344
+ let start: usize = w[ 0 ] . try_into ( ) . expect ( "Offset didn't fit into usize" ) ;
345
+ let next_start: usize = w[ 1 ] . try_into ( ) . expect ( "Offset didn't fit into usize" ) ;
346
+
347
+ let slice = & string_contents[ start..next_start] ;
348
+ String :: from_utf8 ( slice. into ( ) )
349
+ } )
350
+ . collect :: < result:: Result < Vec < String > , string:: FromUtf8Error > > ( )
351
+ . map_err ( OrtError :: StringFromUtf8Error ) ?;
352
+
353
+ let array = ndarray:: Array :: from_shape_vec ( shape, strings)
354
+ . expect ( "Shape extracted from tensor didn't match tensor contents" ) ;
355
+
356
+ Ok ( TensorData :: Strings { strings : array } )
357
+ }
358
+ }
0 commit comments