Skip to content

Commit 555bec7

Browse files
Use DynOrtTensor for model output tensors
Outputs aren't all the same type for a single model, so this allows extracting different types per tensor.
1 parent 0d34d23 commit 555bec7

File tree

13 files changed

+295
-79
lines changed

13 files changed

+295
-79
lines changed

onnxruntime/examples/issue22.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,12 @@ fn main() {
3434
let input_ids = Array2::<i64>::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap();
3535
let attention_mask = Array2::<i64>::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap();
3636

37-
let outputs: Vec<OrtOwnedTensor<f32, _>> =
38-
session.run(vec![input_ids, attention_mask]).unwrap();
37+
let outputs: Vec<OrtOwnedTensor<f32, _>> = session
38+
.run(vec![input_ids, attention_mask])
39+
.unwrap()
40+
.into_iter()
41+
.map(|dyn_tensor| dyn_tensor.try_extract())
42+
.collect::<Result<_, _>>()
43+
.unwrap();
3944
print!("outputs: {:#?}", outputs);
4045
}

onnxruntime/examples/sample.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
#![forbid(unsafe_code)]
22

33
use onnxruntime::{
4-
environment::Environment, ndarray::Array, tensor::OrtOwnedTensor, GraphOptimizationLevel,
5-
LoggingLevel,
4+
environment::Environment,
5+
ndarray::Array,
6+
tensor::{DynOrtTensor, OrtOwnedTensor},
7+
GraphOptimizationLevel, LoggingLevel,
68
};
79
use tracing::Level;
810
use tracing_subscriber::FmtSubscriber;
@@ -61,11 +63,12 @@ fn run() -> Result<(), Error> {
6163
.unwrap();
6264
let input_tensor_values = vec![array];
6365

64-
let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor_values)?;
66+
let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?;
6567

66-
assert_eq!(outputs[0].shape(), output0_shape.as_slice());
68+
let output: OrtOwnedTensor<f32, _> = outputs[0].try_extract().unwrap();
69+
assert_eq!(output.shape(), output0_shape.as_slice());
6770
for i in 0..5 {
68-
println!("Score for class [{}] = {}", i, outputs[0][[0, i, 0, 0]]);
71+
println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]);
6972
}
7073

7174
Ok(())

onnxruntime/src/lib.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ to download.
104104
//! let array = ndarray::Array::linspace(0.0_f32, 1.0, 100);
105105
//! // Multiple inputs and outputs are possible
106106
//! let input_tensor = vec![array];
107-
//! let outputs: Vec<OrtOwnedTensor<f32,_>> = session.run(input_tensor)?;
107+
//! let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor)?
108+
//! .into_iter()
109+
//! .map(|dyn_tensor| dyn_tensor.try_extract())
110+
//! .collect::<Result<_, _>>()?;
108111
//! # Ok(())
109112
//! # }
110113
//! ```
@@ -115,7 +118,10 @@ to download.
115118
//! See the [`sample.rs`](https://github.com/nbigaouette/onnxruntime-rs/blob/master/onnxruntime/examples/sample.rs)
116119
//! example for more details.
117120
118-
use std::sync::{atomic::AtomicPtr, Arc, Mutex};
121+
use std::{
122+
ffi, ptr,
123+
sync::{atomic::AtomicPtr, Arc, Mutex},
124+
};
119125

120126
use lazy_static::lazy_static;
121127

@@ -142,7 +148,7 @@ lazy_static! {
142148
// } as *mut sys::OrtApi)));
143149
static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = {
144150
let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() };
145-
assert_ne!(base, std::ptr::null());
151+
assert_ne!(base, ptr::null());
146152
let get_api: unsafe extern "C" fn(u32) -> *const onnxruntime_sys::OrtApi =
147153
unsafe { (*base).GetApi.unwrap() };
148154
let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) };
@@ -157,13 +163,13 @@ fn g_ort() -> sys::OrtApi {
157163
let api_ref_mut: &mut *mut sys::OrtApi = api_ref.get_mut();
158164
let api_ptr_mut: *mut sys::OrtApi = *api_ref_mut;
159165

160-
assert_ne!(api_ptr_mut, std::ptr::null_mut());
166+
assert_ne!(api_ptr_mut, ptr::null_mut());
161167

162168
unsafe { *api_ptr_mut }
163169
}
164170

165171
fn char_p_to_string(raw: *const i8) -> Result<String> {
166-
let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8).to_owned() };
172+
let c_string = unsafe { ffi::CStr::from_ptr(raw as *mut i8).to_owned() };
167173

168174
match c_string.into_string() {
169175
Ok(string) => Ok(string),
@@ -176,7 +182,7 @@ mod onnxruntime {
176182
//! Module containing a custom logger, used to catch the runtime's own logging and send it
177183
//! to Rust's tracing logging instead.
178184
179-
use std::ffi::CStr;
185+
use std::{ffi, ffi::CStr, ptr};
180186
use tracing::{debug, error, info, span, trace, warn, Level};
181187

182188
use onnxruntime_sys as sys;
@@ -212,7 +218,7 @@ mod onnxruntime {
212218

213219
/// Callback from C that will handle the logging, forwarding the runtime's logs to the tracing crate.
214220
pub(crate) extern "C" fn custom_logger(
215-
_params: *mut std::ffi::c_void,
221+
_params: *mut ffi::c_void,
216222
severity: sys::OrtLoggingLevel,
217223
category: *const i8,
218224
logid: *const i8,
@@ -227,16 +233,16 @@ mod onnxruntime {
227233
sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => Level::ERROR,
228234
};
229235

230-
assert_ne!(category, std::ptr::null());
236+
assert_ne!(category, ptr::null());
231237
let category = unsafe { CStr::from_ptr(category) };
232-
assert_ne!(code_location, std::ptr::null());
238+
assert_ne!(code_location, ptr::null());
233239
let code_location = unsafe { CStr::from_ptr(code_location) }
234240
.to_str()
235241
.unwrap_or("unknown");
236-
assert_ne!(message, std::ptr::null());
242+
assert_ne!(message, ptr::null());
237243
let message = unsafe { CStr::from_ptr(message) };
238244

239-
assert_ne!(logid, std::ptr::null());
245+
assert_ne!(logid, ptr::null());
240246
let logid = unsafe { CStr::from_ptr(logid) };
241247

242248
// Parse the code location
@@ -376,7 +382,7 @@ mod test {
376382

377383
#[test]
378384
fn test_char_p_to_string() {
379-
let s = std::ffi::CString::new("foo").unwrap();
385+
let s = ffi::CString::new("foo").unwrap();
380386
let ptr = s.as_c_str().as_ptr();
381387
assert_eq!("foo", char_p_to_string(ptr).unwrap());
382388
}

onnxruntime/src/session.rs

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@ use crate::{
2121
error::{call_ort, status_to_result, NonMatchingDimensionsError, OrtError, Result},
2222
g_ort,
2323
memory::MemoryInfo,
24-
tensor::{
25-
ort_owned_tensor::OrtOwnedTensor, OrtTensor, TensorDataToType, TensorElementDataType,
26-
TypeToTensorElementDataType,
27-
},
24+
tensor::{DynOrtTensor, OrtTensor, TensorElementDataType, TypeToTensorElementDataType},
2825
AllocatorType, GraphOptimizationLevel, MemType,
2926
};
3027

@@ -364,13 +361,12 @@ impl<'a> Session<'a> {
364361
///
365362
/// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus
366363
/// used for the input data here.
367-
pub fn run<'s, 't, 'm, TIn, TOut, D>(
364+
pub fn run<'s, 't, 'm, TIn, D>(
368365
&'s mut self,
369366
input_arrays: Vec<Array<TIn, D>>,
370-
) -> Result<Vec<OrtOwnedTensor<'t, 'm, TOut, ndarray::IxDyn>>>
367+
) -> Result<Vec<DynOrtTensor<'m, ndarray::IxDyn>>>
371368
where
372369
TIn: TypeToTensorElementDataType + Debug + Clone,
373-
TOut: TensorDataToType,
374370
D: ndarray::Dimension,
375371
'm: 't, // 'm outlives 't (memory info outlives tensor)
376372
's: 'm, // 's outlives 'm (session outlives memory info)
@@ -404,7 +400,7 @@ impl<'a> Session<'a> {
404400
.map(|n| n.as_ptr() as *const i8)
405401
.collect();
406402

407-
let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> =
403+
let mut output_tensor_ptrs: Vec<*mut sys::OrtValue> =
408404
vec![std::ptr::null_mut(); self.outputs.len()];
409405

410406
// The C API expects pointers for the arrays (pointers to C-arrays)
@@ -430,38 +426,32 @@ impl<'a> Session<'a> {
430426
input_ort_values.len() as u64, // C API expects a u64, not isize
431427
output_names_ptr.as_ptr(),
432428
output_names_ptr.len() as u64, // C API expects a u64, not isize
433-
output_tensor_extractors_ptrs.as_mut_ptr(),
429+
output_tensor_ptrs.as_mut_ptr(),
434430
)
435431
};
436432
status_to_result(status).map_err(OrtError::Run)?;
437433

438434
let memory_info_ref = &self.memory_info;
439-
let outputs: Result<Vec<OrtOwnedTensor<TOut, ndarray::Dim<ndarray::IxDynImpl>>>> =
440-
output_tensor_extractors_ptrs
435+
let outputs: Result<Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>>> =
436+
output_tensor_ptrs
441437
.into_iter()
442438
.map(|tensor_ptr| {
443-
let dims = unsafe {
439+
let (dims, data_type) = unsafe {
444440
call_with_tensor_info(tensor_ptr, |tensor_info_ptr| {
445441
get_tensor_dimensions(tensor_info_ptr)
446442
.map(|dims| dims.iter().map(|&n| n as usize).collect::<Vec<_>>())
443+
.and_then(|dims| {
444+
extract_data_type(tensor_info_ptr)
445+
.map(|data_type| (dims, data_type))
446+
})
447447
})
448448
}?;
449449

450-
// Note: Both tensor and array will point to the same data, nothing is copied.
451-
// As such, there is no need to free the pointer used to create the ArrayView.
452-
assert_ne!(tensor_ptr, std::ptr::null_mut());
453-
454-
let mut is_tensor = 0;
455-
unsafe { call_ort(|ort| ort.IsTensor.unwrap()(tensor_ptr, &mut is_tensor)) }
456-
.map_err(OrtError::IsTensor)?;
457-
assert_eq!(is_tensor, 1);
458-
459-
let array_view = TOut::extract_array(ndarray::IxDyn(&dims), tensor_ptr)?;
460-
461-
Ok(OrtOwnedTensor::new(
450+
Ok(DynOrtTensor::new(
462451
tensor_ptr,
463-
array_view,
464-
&memory_info_ref,
452+
memory_info_ref,
453+
ndarray::IxDyn(&dims),
454+
data_type,
465455
))
466456
})
467457
.collect();

onnxruntime/src/tensor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub mod ndarray_tensor;
2727
pub mod ort_owned_tensor;
2828
pub mod ort_tensor;
2929

30-
pub use ort_owned_tensor::OrtOwnedTensor;
30+
pub use ort_owned_tensor::{DynOrtTensor, OrtOwnedTensor};
3131
pub use ort_tensor::OrtTensor;
3232

3333
use crate::{OrtError, Result};

0 commit comments

Comments
 (0)