|
| 1 | +use std::ffi::CString; |
| 2 | +use std::fmt::Debug; |
| 3 | +use std::ops::{Index, IndexMut}; |
| 4 | + |
| 5 | +use ndarray::{Array, Dimension, IxDyn}; |
| 6 | + |
| 7 | +use onnxruntime_sys as sys; |
| 8 | + |
| 9 | +use crate::error::{status_to_result, OrtError}; |
| 10 | +use crate::memory::MemoryInfo; |
| 11 | +use crate::session::{Output, Session}; |
| 12 | +use crate::tensor::OrtTensor; |
| 13 | +use crate::{g_ort, Result, TypeToTensorElementDataType}; |
| 14 | + |
| 15 | +pub trait Element: 'static + Clone + Debug + TypeToTensorElementDataType + Default {} |
| 16 | + |
| 17 | +impl<T: 'static + Clone + Debug + TypeToTensorElementDataType + Default> Element for T {} |
| 18 | + |
| 19 | +fn names_to_ptrs(names: impl Iterator<Item = String>) -> Vec<*const i8> { |
| 20 | + names |
| 21 | + .map(|name| CString::new(name.clone()).unwrap().into_raw() as *const _) |
| 22 | + .collect() |
| 23 | +} |
| 24 | + |
| 25 | +fn compute_output_shapes<TIn, DIn: Dimension>( |
| 26 | + input_arrays: &[Array<TIn, DIn>], |
| 27 | + outputs: &[Output], |
| 28 | +) -> Vec<Vec<usize>> { |
| 29 | + outputs |
| 30 | + .iter() |
| 31 | + .enumerate() |
| 32 | + .map(|(idx, output)| { |
| 33 | + output |
| 34 | + .dimensions |
| 35 | + .iter() |
| 36 | + .enumerate() |
| 37 | + .map(|(jdx, dim)| match dim { |
| 38 | + None => input_arrays[idx].shape()[jdx], |
| 39 | + Some(d) => *d as usize, |
| 40 | + }) |
| 41 | + .collect() |
| 42 | + }) |
| 43 | + .collect() |
| 44 | +} |
| 45 | + |
| 46 | +fn arrays_to_tensors<T: Element, D: Dimension>( |
| 47 | + memory_info: &MemoryInfo, |
| 48 | + arrays: impl IntoIterator<Item = Array<T, D>>, |
| 49 | +) -> Result<Vec<OrtTensor<T, D>>> { |
| 50 | + Ok(arrays |
| 51 | + .into_iter() |
| 52 | + .map(|arr| OrtTensor::from_array(memory_info, arr)) |
| 53 | + .collect::<Result<Vec<_>>>()?) |
| 54 | +} |
| 55 | + |
| 56 | +fn tensors_to_ptr<'a, 's: 'a, T: Element, D: Dimension + 'a>( |
| 57 | + tensors: impl IntoIterator<Item = &'a OrtTensor<'s, T, D>>, |
| 58 | +) -> Vec<*const sys::OrtValue> { |
| 59 | + tensors |
| 60 | + .into_iter() |
| 61 | + .map(|tensor| tensor.c_ptr as *const _) |
| 62 | + .collect() |
| 63 | +} |
| 64 | + |
| 65 | +fn tensors_to_mut_ptr<'a, 's: 'a, T: Element, D: Dimension + 'a>( |
| 66 | + tensors: impl IntoIterator<Item = &'a mut OrtTensor<'s, T, D>>, |
| 67 | +) -> Vec<*mut sys::OrtValue> { |
| 68 | + tensors |
| 69 | + .into_iter() |
| 70 | + .map(|tensor| tensor.c_ptr as *mut _) |
| 71 | + .collect() |
| 72 | +} |
| 73 | + |
| 74 | +fn arrays_to_ort<T: Element, D: Dimension>( |
| 75 | + memory_info: &MemoryInfo, |
| 76 | + arrays: impl IntoIterator<Item = Array<T, D>>, |
| 77 | +) -> Result<(Vec<OrtTensor<T, D>>, Vec<*const sys::OrtValue>)> { |
| 78 | + let ort_tensors = arrays |
| 79 | + .into_iter() |
| 80 | + .map(|arr| OrtTensor::from_array(memory_info, arr)) |
| 81 | + .collect::<Result<Vec<_>>>()?; |
| 82 | + let ort_values = ort_tensors |
| 83 | + .iter() |
| 84 | + .map(|tensor| tensor.c_ptr as *const _) |
| 85 | + .collect(); |
| 86 | + Ok((ort_tensors, ort_values)) |
| 87 | +} |
| 88 | + |
| 89 | +fn arrays_with_shapes<T: Element, D: Dimension>(shapes: &[Vec<usize>]) -> Result<Vec<Array<T, D>>> { |
| 90 | + Ok(shapes |
| 91 | + .into_iter() |
| 92 | + .map(|shape| Array::<_, IxDyn>::default(shape.clone()).into_dimensionality()) |
| 93 | + .collect::<std::result::Result<Vec<Array<T, D>>, _>>()?) |
| 94 | +} |
| 95 | + |
| 96 | +pub struct Inputs<'r, 'a, T: Element, D: Dimension> { |
| 97 | + tensors: &'a mut [OrtTensor<'r, T, D>], |
| 98 | +} |
| 99 | + |
| 100 | +impl<T: Element, D: Dimension> Inputs<'_, '_, T, D> {} |
| 101 | + |
| 102 | +impl<T: Element, D: Dimension> Index<usize> for Inputs<'_, '_, T, D> { |
| 103 | + type Output = Array<T, D>; |
| 104 | + |
| 105 | + #[inline] |
| 106 | + fn index(&self, index: usize) -> &Self::Output { |
| 107 | + &(*self.tensors[index]) |
| 108 | + } |
| 109 | +} |
| 110 | + |
| 111 | +impl<T: Element, D: Dimension> IndexMut<usize> for Inputs<'_, '_, T, D> { |
| 112 | + #[inline] |
| 113 | + fn index_mut(&mut self, index: usize) -> &mut Self::Output { |
| 114 | + &mut (*self.tensors[index]) |
| 115 | + } |
| 116 | +} |
| 117 | + |
| 118 | +pub struct Outputs<'r, 'a, T: Element, D: Dimension> { |
| 119 | + tensors: &'a [OrtTensor<'r, T, D>], |
| 120 | +} |
| 121 | + |
| 122 | +impl<T: Element, D: Dimension> Outputs<'_, '_, T, D> {} |
| 123 | + |
| 124 | +impl<T: Element, D: Dimension> Index<usize> for Outputs<'_, '_, T, D> { |
| 125 | + type Output = Array<T, D>; |
| 126 | + |
| 127 | + #[inline] |
| 128 | + fn index(&self, index: usize) -> &Self::Output { |
| 129 | + &(*self.tensors[index]) |
| 130 | + } |
| 131 | +} |
| 132 | + |
| 133 | +pub struct RunnerBuilder<'s, TIn: Element, DIn: Dimension> { |
| 134 | + session: &'s Session, |
| 135 | + input_arrays: Vec<Array<TIn, DIn>>, |
| 136 | +} |
| 137 | + |
| 138 | +impl<'s, TIn: Element, DIn: Dimension> RunnerBuilder<'s, TIn, DIn> { |
| 139 | + #[inline] |
| 140 | + pub fn new( |
| 141 | + session: &'s Session, |
| 142 | + input_arrays: impl IntoIterator<Item = Array<TIn, DIn>>, |
| 143 | + ) -> Self { |
| 144 | + Self { |
| 145 | + session, |
| 146 | + input_arrays: input_arrays.into_iter().collect(), |
| 147 | + } |
| 148 | + } |
| 149 | + |
| 150 | + #[inline] |
| 151 | + pub fn with_output<TOut: Element, DOut: Dimension>( |
| 152 | + self, |
| 153 | + ) -> Result<Runner<'s, TIn, DIn, TOut, DOut>> { |
| 154 | + Runner::new(self.session, self.input_arrays) |
| 155 | + } |
| 156 | + |
| 157 | + #[inline] |
| 158 | + pub fn with_output_dyn<TOut: Element>(self) -> Result<Runner<'s, TIn, DIn, TOut, IxDyn>> { |
| 159 | + Runner::new(self.session, self.input_arrays) |
| 160 | + } |
| 161 | +} |
| 162 | + |
| 163 | +pub struct Runner<'s, TIn: Element, DIn: Dimension, TOut: Element, DOut: Dimension> { |
| 164 | + session: &'s Session, |
| 165 | + input_names_ptr: Vec<*const i8>, |
| 166 | + output_names_ptr: Vec<*const i8>, |
| 167 | + input_ort_tensors: Vec<OrtTensor<'s, TIn, DIn>>, |
| 168 | + input_ort_values_ptr: Vec<*const sys::OrtValue>, |
| 169 | + output_ort_tensors: Vec<OrtTensor<'s, TOut, DOut>>, |
| 170 | + output_ort_values_ptr: Vec<*mut sys::OrtValue>, |
| 171 | +} |
| 172 | + |
| 173 | +impl<'s, TIn: Element, DIn: Dimension, TOut: Element, DOut: Dimension> |
| 174 | + Runner<'s, TIn, DIn, TOut, DOut> |
| 175 | +{ |
| 176 | + pub fn new( |
| 177 | + session: &'s Session, |
| 178 | + input_arrays: impl IntoIterator<Item = Array<TIn, DIn>>, |
| 179 | + ) -> Result<Self> { |
| 180 | + let input_names_ptr = names_to_ptrs(session.inputs.iter().map(|i| i.name.clone())); |
| 181 | + let output_names_ptr = names_to_ptrs(session.outputs.iter().map(|o| o.name.clone())); |
| 182 | + let input_arrays = input_arrays.into_iter().collect::<Vec<_>>(); |
| 183 | + session.validate_input_shapes(&input_arrays)?; |
| 184 | + let output_shapes = compute_output_shapes(&input_arrays, &session.outputs); |
| 185 | + let output_arrays = arrays_with_shapes::<_, DOut>(&output_shapes)?; |
| 186 | + let input_ort_tensors = arrays_to_tensors(&session.memory_info, input_arrays)?; |
| 187 | + let input_ort_values_ptr = tensors_to_ptr(&input_ort_tensors); |
| 188 | + let mut output_ort_tensors = arrays_to_tensors(&session.memory_info, output_arrays)?; |
| 189 | + let output_ort_values_ptr = tensors_to_mut_ptr(&mut output_ort_tensors); |
| 190 | + Ok(Self { |
| 191 | + session, |
| 192 | + input_names_ptr, |
| 193 | + output_names_ptr, |
| 194 | + input_ort_tensors, |
| 195 | + input_ort_values_ptr, |
| 196 | + output_ort_tensors, |
| 197 | + output_ort_values_ptr, |
| 198 | + }) |
| 199 | + } |
| 200 | + |
| 201 | + #[inline] |
| 202 | + pub fn inputs(&mut self) -> Inputs<'s, '_, TIn, DIn> { |
| 203 | + Inputs { |
| 204 | + tensors: self.input_ort_tensors.as_mut_slice(), |
| 205 | + } |
| 206 | + } |
| 207 | + |
| 208 | + #[inline] |
| 209 | + pub fn outputs(&'s self) -> Outputs<'s, '_, TOut, DOut> { |
| 210 | + Outputs { |
| 211 | + tensors: self.output_ort_tensors.as_slice(), |
| 212 | + } |
| 213 | + } |
| 214 | + |
| 215 | + #[inline] |
| 216 | + pub fn execute(&mut self) -> Result<()> { |
| 217 | + Ok(status_to_result(unsafe { |
| 218 | + g_ort().Run.unwrap()( |
| 219 | + self.session.session_ptr, |
| 220 | + std::ptr::null() as _, |
| 221 | + self.input_names_ptr.as_ptr(), |
| 222 | + self.input_ort_values_ptr.as_ptr(), |
| 223 | + self.input_ort_values_ptr.len() as _, |
| 224 | + self.output_names_ptr.as_ptr(), |
| 225 | + self.output_names_ptr.len() as _, |
| 226 | + self.output_ort_values_ptr.as_mut_ptr(), |
| 227 | + ) |
| 228 | + }) |
| 229 | + .map_err(OrtError::Run)?) |
| 230 | + } |
| 231 | +} |
0 commit comments