Skip to content

Commit 80b68d1

Browse files
Support the onnx string type in output tensors
This approach allocates owned Strings for each element, which works, but stresses the allocator, and incurs unnecessary copying. Part of the complication stems from the limitation that in Rust, a field can't be a reference to another field in the same struct. This means that having a Vec<u8> of copied data, referred to by a Vec<&str>, which is then referred to by an ArrayView, requires a sequence of 3 structs to express. Building a Vec<String> gets rid of the references, but also loses the efficiency of 1 allocation with strs pointing into it.
1 parent 555bec7 commit 80b68d1

File tree

7 files changed

+236
-61
lines changed

7 files changed

+236
-61
lines changed

onnxruntime/examples/sample.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ fn run() -> Result<(), Error> {
6666
let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?;
6767

6868
let output: OrtOwnedTensor<f32, _> = outputs[0].try_extract().unwrap();
69-
assert_eq!(output.shape(), output0_shape.as_slice());
69+
assert_eq!(output.view().shape(), output0_shape.as_slice());
7070
for i in 0..5 {
71-
println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]);
71+
println!("Score for class [{}] = {}", i, output.view()[[0, i, 0, 0]]);
7272
}
7373

7474
Ok(())

onnxruntime/src/error.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Module containing error definitions.
22
3-
use std::{io, path::PathBuf};
3+
use std::{io, path::PathBuf, string};
44

55
use thiserror::Error;
66

@@ -53,6 +53,12 @@ pub enum OrtError {
5353
/// Error occurred when getting ONNX dimensions
5454
#[error("Failed to get dimensions: {0}")]
5555
GetDimensions(OrtApiError),
56+
/// Error occurred when getting string length
57+
#[error("Failed to get string tensor length: {0}")]
58+
GetStringTensorDataLength(OrtApiError),
59+
/// Error occurred when getting tensor element count
60+
#[error("Failed to get tensor element count: {0}")]
61+
GetTensorShapeElementCount(OrtApiError),
5662
/// Error occurred when creating CPU memory information
5763
#[error("Failed to get dimensions: {0}")]
5864
CreateCpuMemoryInfo(OrtApiError),
@@ -77,6 +83,12 @@ pub enum OrtError {
7783
/// Error occurred when extracting data from an ONNX tensor into an C array to be used as an `ndarray::ArrayView`
7884
#[error("Failed to get tensor data: {0}")]
7985
GetTensorMutableData(OrtApiError),
86+
/// Error occurred when extracting string data from an ONNX tensor
87+
#[error("Failed to get tensor string data: {0}")]
88+
GetStringTensorContent(OrtApiError),
89+
/// Error occurred when converting data to a String
90+
#[error("Data was not UTF-8: {0}")]
91+
StringFromUtf8Error(#[from] string::FromUtf8Error),
8092

8193
/// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models)
8294
#[error("Failed to download ONNX model: {0}")]

onnxruntime/src/session.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Module containing session types
22
3-
use std::{ffi::CString, fmt::Debug, path::Path};
3+
use std::{convert::TryInto as _, ffi::CString, fmt::Debug, path::Path};
44

55
#[cfg(not(target_family = "windows"))]
66
use std::os::unix::ffi::OsStrExt;
@@ -436,21 +436,40 @@ impl<'a> Session<'a> {
436436
output_tensor_ptrs
437437
.into_iter()
438438
.map(|tensor_ptr| {
439-
let (dims, data_type) = unsafe {
439+
let (dims, data_type, len) = unsafe {
440440
call_with_tensor_info(tensor_ptr, |tensor_info_ptr| {
441441
get_tensor_dimensions(tensor_info_ptr)
442442
.map(|dims| dims.iter().map(|&n| n as usize).collect::<Vec<_>>())
443443
.and_then(|dims| {
444444
extract_data_type(tensor_info_ptr)
445445
.map(|data_type| (dims, data_type))
446446
})
447+
.and_then(|(dims, data_type)| {
448+
let mut len = 0_u64;
449+
450+
call_ort(|ort| {
451+
ort.GetTensorShapeElementCount.unwrap()(
452+
tensor_info_ptr,
453+
&mut len,
454+
)
455+
})
456+
.map_err(OrtError::GetTensorShapeElementCount)?;
457+
458+
Ok((
459+
dims,
460+
data_type,
461+
len.try_into()
462+
.expect("u64 length could not fit into usize"),
463+
))
464+
})
447465
})
448466
}?;
449467

450468
Ok(DynOrtTensor::new(
451469
tensor_ptr,
452470
memory_info_ref,
453471
ndarray::IxDyn(&dims),
472+
len,
454473
data_type,
455474
))
456475
})

onnxruntime/src/tensor.rs

Lines changed: 110 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@ pub mod ort_tensor;
3030
pub use ort_owned_tensor::{DynOrtTensor, OrtOwnedTensor};
3131
pub use ort_tensor::OrtTensor;
3232

33-
use crate::{OrtError, Result};
33+
use crate::tensor::ort_owned_tensor::TensorPointerHolder;
34+
use crate::{error::call_ort, OrtError, Result};
3435
use onnxruntime_sys::{self as sys, OnnxEnumInt};
35-
use std::{fmt, ptr};
36+
use std::{convert::TryInto as _, ffi, fmt, ptr, rc, result, string};
3637

3738
// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
3839
// FIXME: Add tests to cover the commented out types
@@ -188,14 +189,41 @@ pub trait TensorDataToType: Sized + fmt::Debug {
188189
fn tensor_element_data_type() -> TensorElementDataType;
189190

190191
/// Extract an `ArrayView` from the ort-owned tensor.
191-
fn extract_array<'t, D>(
192+
fn extract_data<'t, D>(
192193
shape: D,
193-
tensor: *mut sys::OrtValue,
194-
) -> Result<ndarray::ArrayView<'t, Self, D>>
194+
tensor_element_len: usize,
195+
tensor_ptr: rc::Rc<TensorPointerHolder>,
196+
) -> Result<TensorData<'t, Self, D>>
195197
where
196198
D: ndarray::Dimension;
197199
}
198200

201+
/// Represents the possible ways tensor data can be accessed.
202+
///
203+
/// This should only be used internally.
204+
#[derive(Debug)]
205+
pub enum TensorData<'t, T, D>
206+
where
207+
D: ndarray::Dimension,
208+
{
209+
/// Data resides in ort's tensor, in which case the 't lifetime is what makes this valid.
210+
/// This is used for data types whose in-memory form from ort is compatible with Rust's, like
211+
/// primitive numeric types.
212+
TensorPtr {
213+
/// The pointer ort produced. Kept alive so that `array_view` is valid.
214+
ptr: rc::Rc<TensorPointerHolder>,
215+
/// A view into `ptr`
216+
array_view: ndarray::ArrayView<'t, T, D>,
217+
},
218+
/// String data is output differently by ort, and of course is also variable size, so it cannot
219+
/// use the same simple pointer representation.
220+
// Since 't outlives this struct, the 't lifetime is more than we need, but no harm done.
221+
Strings {
222+
/// Owned Strings copied out of ort's output
223+
strings: ndarray::Array<T, D>,
224+
},
225+
}
226+
199227
/// Implements `OwnedTensorDataToType` for primitives, which can use `GetTensorMutableData`
200228
macro_rules! impl_prim_type_from_ort_trait {
201229
($type_:ty, $variant:ident) => {
@@ -204,14 +232,20 @@ macro_rules! impl_prim_type_from_ort_trait {
204232
TensorElementDataType::$variant
205233
}
206234

207-
fn extract_array<'t, D>(
235+
fn extract_data<'t, D>(
208236
shape: D,
209-
tensor: *mut sys::OrtValue,
210-
) -> Result<ndarray::ArrayView<'t, Self, D>>
237+
_tensor_element_len: usize,
238+
tensor_ptr: rc::Rc<TensorPointerHolder>,
239+
) -> Result<TensorData<'t, Self, D>>
211240
where
212241
D: ndarray::Dimension,
213242
{
214-
extract_primitive_array(shape, tensor)
243+
extract_primitive_array(shape, tensor_ptr.tensor_ptr).map(|v| {
244+
TensorData::TensorPtr {
245+
ptr: tensor_ptr,
246+
array_view: v,
247+
}
248+
})
215249
}
216250
}
217251
};
@@ -255,3 +289,70 @@ impl_prim_type_from_ort_trait!(i64, Int64);
255289
impl_prim_type_from_ort_trait!(f64, Double);
256290
impl_prim_type_from_ort_trait!(u32, Uint32);
257291
impl_prim_type_from_ort_trait!(u64, Uint64);
292+
293+
impl TensorDataToType for String {
294+
fn tensor_element_data_type() -> TensorElementDataType {
295+
TensorElementDataType::String
296+
}
297+
298+
fn extract_data<'t, D: ndarray::Dimension>(
299+
shape: D,
300+
tensor_element_len: usize,
301+
tensor_ptr: rc::Rc<TensorPointerHolder>,
302+
) -> Result<TensorData<'t, Self, D>> {
303+
// Total length of string data, not including \0 suffix
304+
let mut total_length = 0_u64;
305+
unsafe {
306+
call_ort(|ort| {
307+
ort.GetStringTensorDataLength.unwrap()(tensor_ptr.tensor_ptr, &mut total_length)
308+
})
309+
.map_err(OrtError::GetStringTensorDataLength)?
310+
}
311+
312+
// In the JNI impl of this, tensor_element_len was included in addition to total_length,
313+
// but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
314+
// don't seem to be written to in practice either.
315+
// If the string data actually did go farther, it would panic below when using the offset
316+
// data to get slices for each string.
317+
let mut string_contents = vec![0_u8; total_length as usize];
318+
// one extra slot so that the total length can go in the last one, making all per-string
319+
// length calculations easy
320+
let mut offsets = vec![0_u64; tensor_element_len as usize + 1];
321+
322+
unsafe {
323+
call_ort(|ort| {
324+
ort.GetStringTensorContent.unwrap()(
325+
tensor_ptr.tensor_ptr,
326+
string_contents.as_mut_ptr() as *mut ffi::c_void,
327+
total_length,
328+
offsets.as_mut_ptr(),
329+
tensor_element_len as u64,
330+
)
331+
})
332+
.map_err(OrtError::GetStringTensorContent)?
333+
}
334+
335+
// final offset = overall length so that per-string length calculations work for the last
336+
// string
337+
debug_assert_eq!(0, offsets[tensor_element_len]);
338+
offsets[tensor_element_len] = total_length;
339+
340+
let strings = offsets
341+
// offsets has 1 extra offset past the end so that all windows work
342+
.windows(2)
343+
.map(|w| {
344+
let start: usize = w[0].try_into().expect("Offset didn't fit into usize");
345+
let next_start: usize = w[1].try_into().expect("Offset didn't fit into usize");
346+
347+
let slice = &string_contents[start..next_start];
348+
String::from_utf8(slice.into())
349+
})
350+
.collect::<result::Result<Vec<String>, string::FromUtf8Error>>()
351+
.map_err(OrtError::StringFromUtf8Error)?;
352+
353+
let array = ndarray::Array::from_shape_vec(shape, strings)
354+
.expect("Shape extracted from tensor didn't match tensor contents");
355+
356+
Ok(TensorData::Strings { strings: array })
357+
}
358+
}

0 commit comments

Comments
 (0)