Skip to content

Commit aea2c8e

Browse files
committed
feat: Add Mixed Array Support
1 parent cd5a6eb commit aea2c8e

File tree

2 files changed

+325
-13
lines changed

2 files changed

+325
-13
lines changed

onnxruntime/src/lib.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ use sys::OnnxEnumInt;
154154

155155
// Re-export ndarray as it's part of the public API anyway
156156
pub use ndarray;
157+
use ndarray::Array;
158+
use crate::tensor::OrtTensor;
157159

158160
lazy_static! {
159161
// static ref G_ORT: Arc<Mutex<AtomicPtr<sys::OrtApi>>> =
@@ -459,6 +461,34 @@ impl_type_trait!(u64, Uint64);
459461
// impl_type_trait!(, Complex128);
460462
// impl_type_trait!(, Bfloat16);
461463

464+
#[derive(Debug)]
465+
pub enum TypedArray<D: ndarray::Dimension> {
466+
F32(Array<f32, D>),
467+
U8(Array<u8, D>),
468+
I8(Array<i8, D>),
469+
U16(Array<u16, D>),
470+
I16(Array<i16, D>),
471+
I32(Array<i32, D>),
472+
I64(Array<i64, D>),
473+
F64(Array<f64, D>),
474+
U32(Array<u32, D>),
475+
U64(Array<u64, D>),
476+
}
477+
478+
#[derive(Debug)]
479+
pub enum TypedOrtTensor<'t, D: ndarray::Dimension> {
480+
F32(OrtTensor<'t, f32, D>),
481+
U8(OrtTensor<'t, u8, D>),
482+
I8(OrtTensor<'t, i8, D>),
483+
U16(OrtTensor<'t, u16, D>),
484+
I16(OrtTensor<'t, i16, D>),
485+
I32(OrtTensor<'t, i32, D>),
486+
I64(OrtTensor<'t, i64, D>),
487+
F64(OrtTensor<'t, f64, D>),
488+
U32(OrtTensor<'t, u32, D>),
489+
U64(OrtTensor<'t, u64, D>),
490+
}
491+
462492
/// Adapter for common Rust string types to Onnx strings.
463493
///
464494
/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but

0 commit comments

Comments
 (0)