@@ -154,6 +154,8 @@ use sys::OnnxEnumInt;
154
154
155
155
// Re-export ndarray as it's part of the public API anyway
156
156
pub use ndarray;
157
+ use ndarray:: Array ;
158
+ use crate :: tensor:: OrtTensor ;
157
159
158
160
lazy_static ! {
159
161
// static ref G_ORT: Arc<Mutex<AtomicPtr<sys::OrtApi>>> =
@@ -459,6 +461,34 @@ impl_type_trait!(u64, Uint64);
459
461
// impl_type_trait!(, Complex128);
460
462
// impl_type_trait!(, Bfloat16);
461
463
464
+ #[ derive( Debug ) ]
465
+ pub enum TypedArray < D : ndarray:: Dimension > {
466
+ F32 ( Array < f32 , D > ) ,
467
+ U8 ( Array < u8 , D > ) ,
468
+ I8 ( Array < i8 , D > ) ,
469
+ U16 ( Array < u16 , D > ) ,
470
+ I16 ( Array < i16 , D > ) ,
471
+ I32 ( Array < i32 , D > ) ,
472
+ I64 ( Array < i64 , D > ) ,
473
+ F64 ( Array < f64 , D > ) ,
474
+ U32 ( Array < u32 , D > ) ,
475
+ U64 ( Array < u64 , D > ) ,
476
+ }
477
+
478
+ #[ derive( Debug ) ]
479
+ pub enum TypedOrtTensor < ' t , D : ndarray:: Dimension > {
480
+ F32 ( OrtTensor < ' t , f32 , D > ) ,
481
+ U8 ( OrtTensor < ' t , u8 , D > ) ,
482
+ I8 ( OrtTensor < ' t , i8 , D > ) ,
483
+ U16 ( OrtTensor < ' t , u16 , D > ) ,
484
+ I16 ( OrtTensor < ' t , i16 , D > ) ,
485
+ I32 ( OrtTensor < ' t , i32 , D > ) ,
486
+ I64 ( OrtTensor < ' t , i64 , D > ) ,
487
+ F64 ( OrtTensor < ' t , f64 , D > ) ,
488
+ U32 ( OrtTensor < ' t , u32 , D > ) ,
489
+ U64 ( OrtTensor < ' t , u64 , D > ) ,
490
+ }
491
+
462
492
/// Adapter for common Rust string types to Onnx strings.
463
493
///
464
494
/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but
0 commit comments