Skip to content

Commit acdc209

Browse files
committed
Add initial working implementation of Runner
1 parent 67a525e commit acdc209

File tree

3 files changed

+236
-0
lines changed

3 files changed

+236
-0
lines changed

onnxruntime/src/error.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::{io, path::PathBuf};
44

55
use thiserror::Error;
66

7+
use ndarray::ShapeError;
78
use onnxruntime_sys as sys;
89

910
use crate::{char_p_to_string, g_ort};
@@ -91,6 +92,9 @@ pub enum OrtError {
9192
/// Attempt to build a Rust `CString` from a null pointer
9293
#[error("Failed to build CString when original contains null: {0}")]
9394
CStringNulError(#[from] std::ffi::NulError),
95+
/// Output dimensionality mismatch
96+
#[error("Output dimensionality mismatch: {0}")]
97+
OutputDimensionalityMismatch(#[from] ShapeError),
9498
}
9599

96100
/// Error used when dimensions of input (from model and from inference call)

onnxruntime/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ pub mod download;
125125
pub mod environment;
126126
pub mod error;
127127
mod memory;
128+
pub mod runner;
128129
pub mod session;
129130
pub mod tensor;
130131

onnxruntime/src/runner.rs

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
use std::ffi::CString;
2+
use std::fmt::Debug;
3+
use std::ops::{Index, IndexMut};
4+
5+
use ndarray::{Array, Dimension, IxDyn};
6+
7+
use onnxruntime_sys as sys;
8+
9+
use crate::error::{status_to_result, OrtError};
10+
use crate::memory::MemoryInfo;
11+
use crate::session::{Output, Session};
12+
use crate::tensor::OrtTensor;
13+
use crate::{g_ort, Result, TypeToTensorElementDataType};
14+
15+
pub trait Element: 'static + Clone + Debug + TypeToTensorElementDataType + Default {}
16+
17+
impl<T: 'static + Clone + Debug + TypeToTensorElementDataType + Default> Element for T {}
18+
19+
fn names_to_ptrs(names: impl Iterator<Item = String>) -> Vec<*const i8> {
20+
names
21+
.map(|name| CString::new(name.clone()).unwrap().into_raw() as *const _)
22+
.collect()
23+
}
24+
25+
fn compute_output_shapes<TIn, DIn: Dimension>(
26+
input_arrays: &[Array<TIn, DIn>],
27+
outputs: &[Output],
28+
) -> Vec<Vec<usize>> {
29+
outputs
30+
.iter()
31+
.enumerate()
32+
.map(|(idx, output)| {
33+
output
34+
.dimensions
35+
.iter()
36+
.enumerate()
37+
.map(|(jdx, dim)| match dim {
38+
None => input_arrays[idx].shape()[jdx],
39+
Some(d) => *d as usize,
40+
})
41+
.collect()
42+
})
43+
.collect()
44+
}
45+
46+
fn arrays_to_tensors<T: Element, D: Dimension>(
47+
memory_info: &MemoryInfo,
48+
arrays: impl IntoIterator<Item = Array<T, D>>,
49+
) -> Result<Vec<OrtTensor<T, D>>> {
50+
Ok(arrays
51+
.into_iter()
52+
.map(|arr| OrtTensor::from_array(memory_info, arr))
53+
.collect::<Result<Vec<_>>>()?)
54+
}
55+
56+
fn tensors_to_ptr<'a, 's: 'a, T: Element, D: Dimension + 'a>(
57+
tensors: impl IntoIterator<Item = &'a OrtTensor<'s, T, D>>,
58+
) -> Vec<*const sys::OrtValue> {
59+
tensors
60+
.into_iter()
61+
.map(|tensor| tensor.c_ptr as *const _)
62+
.collect()
63+
}
64+
65+
fn tensors_to_mut_ptr<'a, 's: 'a, T: Element, D: Dimension + 'a>(
66+
tensors: impl IntoIterator<Item = &'a mut OrtTensor<'s, T, D>>,
67+
) -> Vec<*mut sys::OrtValue> {
68+
tensors
69+
.into_iter()
70+
.map(|tensor| tensor.c_ptr as *mut _)
71+
.collect()
72+
}
73+
74+
fn arrays_to_ort<T: Element, D: Dimension>(
75+
memory_info: &MemoryInfo,
76+
arrays: impl IntoIterator<Item = Array<T, D>>,
77+
) -> Result<(Vec<OrtTensor<T, D>>, Vec<*const sys::OrtValue>)> {
78+
let ort_tensors = arrays
79+
.into_iter()
80+
.map(|arr| OrtTensor::from_array(memory_info, arr))
81+
.collect::<Result<Vec<_>>>()?;
82+
let ort_values = ort_tensors
83+
.iter()
84+
.map(|tensor| tensor.c_ptr as *const _)
85+
.collect();
86+
Ok((ort_tensors, ort_values))
87+
}
88+
89+
fn arrays_with_shapes<T: Element, D: Dimension>(shapes: &[Vec<usize>]) -> Result<Vec<Array<T, D>>> {
90+
Ok(shapes
91+
.into_iter()
92+
.map(|shape| Array::<_, IxDyn>::default(shape.clone()).into_dimensionality())
93+
.collect::<std::result::Result<Vec<Array<T, D>>, _>>()?)
94+
}
95+
96+
pub struct Inputs<'r, 'a, T: Element, D: Dimension> {
97+
tensors: &'a mut [OrtTensor<'r, T, D>],
98+
}
99+
100+
impl<T: Element, D: Dimension> Inputs<'_, '_, T, D> {}
101+
102+
impl<T: Element, D: Dimension> Index<usize> for Inputs<'_, '_, T, D> {
103+
type Output = Array<T, D>;
104+
105+
#[inline]
106+
fn index(&self, index: usize) -> &Self::Output {
107+
&(*self.tensors[index])
108+
}
109+
}
110+
111+
impl<T: Element, D: Dimension> IndexMut<usize> for Inputs<'_, '_, T, D> {
112+
#[inline]
113+
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
114+
&mut (*self.tensors[index])
115+
}
116+
}
117+
118+
pub struct Outputs<'r, 'a, T: Element, D: Dimension> {
119+
tensors: &'a [OrtTensor<'r, T, D>],
120+
}
121+
122+
impl<T: Element, D: Dimension> Outputs<'_, '_, T, D> {}
123+
124+
impl<T: Element, D: Dimension> Index<usize> for Outputs<'_, '_, T, D> {
125+
type Output = Array<T, D>;
126+
127+
#[inline]
128+
fn index(&self, index: usize) -> &Self::Output {
129+
&(*self.tensors[index])
130+
}
131+
}
132+
133+
pub struct RunnerBuilder<'s, TIn: Element, DIn: Dimension> {
134+
session: &'s Session,
135+
input_arrays: Vec<Array<TIn, DIn>>,
136+
}
137+
138+
impl<'s, TIn: Element, DIn: Dimension> RunnerBuilder<'s, TIn, DIn> {
139+
#[inline]
140+
pub fn new(
141+
session: &'s Session,
142+
input_arrays: impl IntoIterator<Item = Array<TIn, DIn>>,
143+
) -> Self {
144+
Self {
145+
session,
146+
input_arrays: input_arrays.into_iter().collect(),
147+
}
148+
}
149+
150+
#[inline]
151+
pub fn with_output<TOut: Element, DOut: Dimension>(
152+
self,
153+
) -> Result<Runner<'s, TIn, DIn, TOut, DOut>> {
154+
Runner::new(self.session, self.input_arrays)
155+
}
156+
157+
#[inline]
158+
pub fn with_output_dyn<TOut: Element>(self) -> Result<Runner<'s, TIn, DIn, TOut, IxDyn>> {
159+
Runner::new(self.session, self.input_arrays)
160+
}
161+
}
162+
163+
pub struct Runner<'s, TIn: Element, DIn: Dimension, TOut: Element, DOut: Dimension> {
164+
session: &'s Session,
165+
input_names_ptr: Vec<*const i8>,
166+
output_names_ptr: Vec<*const i8>,
167+
input_ort_tensors: Vec<OrtTensor<'s, TIn, DIn>>,
168+
input_ort_values_ptr: Vec<*const sys::OrtValue>,
169+
output_ort_tensors: Vec<OrtTensor<'s, TOut, DOut>>,
170+
output_ort_values_ptr: Vec<*mut sys::OrtValue>,
171+
}
172+
173+
impl<'s, TIn: Element, DIn: Dimension, TOut: Element, DOut: Dimension>
174+
Runner<'s, TIn, DIn, TOut, DOut>
175+
{
176+
pub fn new(
177+
session: &'s Session,
178+
input_arrays: impl IntoIterator<Item = Array<TIn, DIn>>,
179+
) -> Result<Self> {
180+
let input_names_ptr = names_to_ptrs(session.inputs.iter().map(|i| i.name.clone()));
181+
let output_names_ptr = names_to_ptrs(session.outputs.iter().map(|o| o.name.clone()));
182+
let input_arrays = input_arrays.into_iter().collect::<Vec<_>>();
183+
session.validate_input_shapes(&input_arrays)?;
184+
let output_shapes = compute_output_shapes(&input_arrays, &session.outputs);
185+
let output_arrays = arrays_with_shapes::<_, DOut>(&output_shapes)?;
186+
let input_ort_tensors = arrays_to_tensors(&session.memory_info, input_arrays)?;
187+
let input_ort_values_ptr = tensors_to_ptr(&input_ort_tensors);
188+
let mut output_ort_tensors = arrays_to_tensors(&session.memory_info, output_arrays)?;
189+
let output_ort_values_ptr = tensors_to_mut_ptr(&mut output_ort_tensors);
190+
Ok(Self {
191+
session,
192+
input_names_ptr,
193+
output_names_ptr,
194+
input_ort_tensors,
195+
input_ort_values_ptr,
196+
output_ort_tensors,
197+
output_ort_values_ptr,
198+
})
199+
}
200+
201+
#[inline]
202+
pub fn inputs(&mut self) -> Inputs<'s, '_, TIn, DIn> {
203+
Inputs {
204+
tensors: self.input_ort_tensors.as_mut_slice(),
205+
}
206+
}
207+
208+
#[inline]
209+
pub fn outputs(&'s self) -> Outputs<'s, '_, TOut, DOut> {
210+
Outputs {
211+
tensors: self.output_ort_tensors.as_slice(),
212+
}
213+
}
214+
215+
#[inline]
216+
pub fn execute(&mut self) -> Result<()> {
217+
Ok(status_to_result(unsafe {
218+
g_ort().Run.unwrap()(
219+
self.session.session_ptr,
220+
std::ptr::null() as _,
221+
self.input_names_ptr.as_ptr(),
222+
self.input_ort_values_ptr.as_ptr(),
223+
self.input_ort_values_ptr.len() as _,
224+
self.output_names_ptr.as_ptr(),
225+
self.output_names_ptr.len() as _,
226+
self.output_ort_values_ptr.as_mut_ptr(),
227+
)
228+
})
229+
.map_err(OrtError::Run)?)
230+
}
231+
}

0 commit comments

Comments
 (0)