Skip to content

Commit e69e125

Browse files
Add support for registring custom op libraries
1 parent 80b68d1 commit e69e125

File tree

9 files changed

+286
-1
lines changed

9 files changed

+286
-1
lines changed

.dockerignore

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
*
2+
!/Cargo.*
3+
!/onnxruntime/Cargo.toml
4+
!/onnxruntime/src
5+
!/onnxruntime/tests
6+
!/onnxruntime-sys/Cargo.toml
7+
!/onnxruntime-sys/build.rs
8+
!/onnxruntime-sys/src
9+
!/test-models/tensorflow/*.onnx

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
### Added
1111

1212
- Add `String` datatype ([#58](https://github.com/nbigaouette/onnxruntime-rs/pull/58))
13+
- Support custom operator libraries
1314

1415
## [0.0.11] - 2021-02-22
1516

Dockerfile

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# onnxruntime requires execinfo.h to build, which only works on glibc-based systems, so alpine is out...
2+
FROM debian:bullseye-slim as base
3+
4+
RUN apt-get update && apt-get -y dist-upgrade
5+
6+
FROM base AS onnxruntime
7+
8+
RUN apt-get install -y \
9+
git \
10+
bash \
11+
python3 \
12+
cmake \
13+
git \
14+
build-essential \
15+
llvm \
16+
locales
17+
18+
# onnxruntime built in tests need en_US.UTF-8 available
19+
# Uncomment en_US.UTF-8, then generate
20+
RUN sed -i 's/^# *\(en_US.UTF-8\)/\1/' /etc/locale.gen && locale-gen
21+
22+
# build onnxruntime
23+
RUN mkdir -p /opt/onnxruntime/tmp
24+
# onnxruntime build relies on being in a git repo, so can't just get a tarball
25+
# it's a big repo, so fetch shallowly
26+
RUN cd /opt/onnxruntime/tmp && \
27+
git clone --recursive --depth 1 --shallow-submodules https://github.com/Microsoft/onnxruntime
28+
29+
# use version that onnxruntime-sys expects
30+
RUN cd /opt/onnxruntime/tmp/onnxruntime && \
31+
git fetch --depth 1 origin tag v1.6.0 && \
32+
git checkout v1.6.0
33+
34+
RUN /opt/onnxruntime/tmp/onnxruntime/build.sh --config RelWithDebInfo --build_shared_lib --parallel
35+
36+
# Build ort-customops, linked against the onnxruntime built above.
37+
# No tags / releases yet - that commit is from 2021-02-16
38+
RUN mkdir -p /opt/ort-customops/tmp && \
39+
cd /opt/ort-customops/tmp && \
40+
git clone --recursive https://github.com/microsoft/ort-customops.git && \
41+
cd ort-customops && \
42+
git checkout 92f6b51106c9e9143c452e537cb5e41d2dcaa266
43+
44+
RUN cd /opt/ort-customops/tmp/ort-customops && \
45+
./build.sh -D ONNXRUNTIME_LIB_DIR=/opt/onnxruntime/tmp/onnxruntime/build/Linux/RelWithDebInfo
46+
47+
48+
# install rust toolchain
49+
FROM base AS rust-toolchain
50+
51+
ARG RUST_VERSION=1.50.0
52+
53+
RUN apt-get install -y \
54+
curl
55+
56+
# install rust toolchain
57+
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain $RUST_VERSION
58+
59+
ENV PATH $PATH:/root/.cargo/bin
60+
61+
62+
# build onnxruntime-rs
63+
FROM rust-toolchain as onnxruntime-rs
64+
# clang & llvm needed by onnxruntime-sys
65+
RUN apt-get install -y \
66+
build-essential \
67+
llvm-dev \
68+
libclang-dev \
69+
clang
70+
71+
RUN mkdir -p \
72+
/onnxruntime-rs/build/onnxruntime-sys/src/ \
73+
/onnxruntime-rs/build/onnxruntime/src/ \
74+
/onnxruntime-rs/build/onnxruntime/tests/ \
75+
/opt/onnxruntime/lib \
76+
/opt/ort-customops/lib
77+
78+
COPY --from=onnxruntime /opt/onnxruntime/tmp/onnxruntime/build/Linux/RelWithDebInfo/libonnxruntime.so /opt/onnxruntime/lib/
79+
COPY --from=onnxruntime /opt/ort-customops/tmp/ort-customops/out/Linux/libortcustomops.so /opt/ort-customops/lib/
80+
81+
WORKDIR /onnxruntime-rs/build
82+
83+
ENV ORT_STRATEGY=system
84+
# this has /lib/ appended to it and is used as a lib search path in onnxruntime-sys's build.rs
85+
ENV ORT_LIB_LOCATION=/opt/onnxruntime/
86+
87+
ENV ONNXRUNTIME_RS_TEST_ORT_CUSTOMOPS_LIB=/opt/ort-customops/lib/libortcustomops.so
88+
89+
# create enough of an empty project that dependencies can build
90+
COPY /Cargo.lock /Cargo.toml /onnxruntime-rs/build/
91+
COPY /onnxruntime/Cargo.toml /onnxruntime-rs/build/onnxruntime/
92+
COPY /onnxruntime-sys/Cargo.toml /onnxruntime-sys/build.rs /onnxruntime-rs/build/onnxruntime-sys/
93+
94+
CMD cargo test
95+
96+
# build dependencies and clean the bogus contents of our two packages
97+
RUN touch \
98+
onnxruntime/src/lib.rs \
99+
onnxruntime/tests/integration_tests.rs \
100+
onnxruntime-sys/src/lib.rs \
101+
&& cargo build --tests \
102+
&& cargo clean --package onnxruntime-sys \
103+
&& cargo clean --package onnxruntime \
104+
&& rm -rf \
105+
onnxruntime/src/ \
106+
onnxruntime/tests/ \
107+
onnxruntime-sys/src/
108+
109+
# now build the actual source
110+
COPY /test-models test-models
111+
COPY /onnxruntime-sys/src onnxruntime-sys/src
112+
COPY /onnxruntime/src onnxruntime/src
113+
COPY /onnxruntime/tests onnxruntime/tests
114+
115+
RUN ln -s /opt/onnxruntime/lib/libonnxruntime.so /opt/onnxruntime/lib/libonnxruntime.so.1.6.0
116+
ENV LD_LIBRARY_PATH=/opt/onnxruntime/lib
117+
118+
RUN cargo build --tests

onnxruntime/Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ ndarray = "0.13"
2626
thiserror = "1.0"
2727
tracing = "0.1"
2828

29+
[target.'cfg(unix)'.dependencies]
30+
libc = "0.2.88"
31+
32+
[target.'cfg(windows)'.dependencies]
33+
winapi = { version = "0.3.9", features = ["std"] }
34+
2935
# Enabled with 'model-fetching' feature
3036
ureq = {version = "1.5.1", optional = true}
3137

onnxruntime/src/session.rs

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Module containing session types
22
3-
use std::{convert::TryInto as _, ffi::CString, fmt::Debug, path::Path};
3+
use std::{convert::TryInto as _, ffi, ffi::CString, fmt::Debug, path::Path};
44

55
#[cfg(not(target_family = "windows"))]
66
use std::os::unix::ffi::OsStrExt;
@@ -64,11 +64,16 @@ pub struct SessionBuilder<'a> {
6464

6565
allocator: AllocatorType,
6666
memory_type: MemType,
67+
custom_runtime_handles: Vec<*mut ::std::os::raw::c_void>,
6768
}
6869

6970
impl<'a> Drop for SessionBuilder<'a> {
7071
#[tracing::instrument]
7172
fn drop(&mut self) {
73+
for &handle in self.custom_runtime_handles.iter() {
74+
close_lib_handle(handle);
75+
}
76+
7277
debug!("Dropping the session options.");
7378
assert_ne!(self.session_options_ptr, std::ptr::null_mut());
7479
unsafe { g_ort().ReleaseSessionOptions.unwrap()(self.session_options_ptr) };
@@ -89,6 +94,7 @@ impl<'a> SessionBuilder<'a> {
8994
session_options_ptr,
9095
allocator: AllocatorType::Arena,
9196
memory_type: MemType::Default,
97+
custom_runtime_handles: Vec::new(),
9298
})
9399
}
94100

@@ -136,6 +142,39 @@ impl<'a> SessionBuilder<'a> {
136142
Ok(self)
137143
}
138144

145+
/// Registers a custom ops library with the given library path in the session.
146+
pub fn with_custom_op_lib(mut self, lib_path: &str) -> Result<SessionBuilder<'a>> {
147+
let path_cstr = ffi::CString::new(lib_path)?;
148+
149+
let mut handle: *mut ::std::os::raw::c_void = std::ptr::null_mut();
150+
151+
let status = unsafe {
152+
g_ort().RegisterCustomOpsLibrary.unwrap()(
153+
self.session_options_ptr,
154+
path_cstr.as_ptr(),
155+
&mut handle,
156+
)
157+
};
158+
159+
// per RegisterCustomOpsLibrary docs, release handle if there was an error and the handle
160+
// is non-null
161+
match status_to_result(status).map_err(OrtError::SessionOptions) {
162+
Ok(_) => {}
163+
Err(e) => {
164+
if handle != std::ptr::null_mut() {
165+
// handle was written to, should release it
166+
close_lib_handle(handle);
167+
}
168+
169+
return Err(e);
170+
}
171+
}
172+
173+
self.custom_runtime_handles.push(handle);
174+
175+
Ok(self)
176+
}
177+
139178
/// Download an ONNX pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models) and commit the session
140179
#[cfg(feature = "model-fetching")]
141180
pub fn with_model_downloaded<M>(self, model: M) -> Result<Session<'a>>
@@ -619,6 +658,20 @@ where
619658
res
620659
}
621660

661+
#[cfg(unix)]
662+
fn close_lib_handle(handle: *mut ::std::os::raw::c_void) {
663+
unsafe {
664+
libc::dlclose(handle);
665+
}
666+
}
667+
668+
#[cfg(windows)]
669+
fn close_lib_handle(handle: *mut ::std::os::raw::c_void) {
670+
unsafe {
671+
winapi::um::libloaderapi::FreeLibrary(handle as winapi::shared::minwindef::HINSTANCE)
672+
};
673+
}
674+
622675
/// This module contains dangerous functions working on raw pointers.
623676
/// Those functions are only to be used from inside the
624677
/// `SessionBuilder::with_model_from_file()` method.

onnxruntime/tests/custom_ops.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use std::error::Error;
2+
3+
use ndarray;
4+
use onnxruntime::tensor::{DynOrtTensor, OrtOwnedTensor};
5+
use onnxruntime::{environment::Environment, LoggingLevel};
6+
7+
#[test]
8+
fn run_model_with_ort_customops() -> Result<(), Box<dyn Error>> {
9+
let lib_path = match std::env::var("ONNXRUNTIME_RS_TEST_ORT_CUSTOMOPS_LIB") {
10+
Ok(s) => s,
11+
Err(_e) => {
12+
println!("Skipping ort_customops test -- no lib specified");
13+
return Ok(());
14+
}
15+
};
16+
17+
let environment = Environment::builder()
18+
.with_name("test")
19+
.with_log_level(LoggingLevel::Verbose)
20+
.build()?;
21+
22+
let mut session = environment
23+
.new_session_builder()?
24+
.with_custom_op_lib(&lib_path)?
25+
.with_model_from_file("../test-models/tensorflow/regex_model.onnx")?;
26+
27+
//Inputs:
28+
// 0:
29+
// name = input_1:0
30+
// type = String
31+
// dimensions = [None]
32+
// Outputs:
33+
// 0:
34+
// name = Identity:0
35+
// type = String
36+
// dimensions = [None]
37+
38+
let array = ndarray::Array::from(vec![String::from("Hello world!")]);
39+
let input_tensor_values = vec![array];
40+
41+
let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?;
42+
let strings: OrtOwnedTensor<String, _> = outputs[0].try_extract()?;
43+
44+
// ' ' replaced with '_'
45+
assert_eq!(
46+
&[String::from("Hello_world!")],
47+
strings.view().as_slice().unwrap()
48+
);
49+
50+
Ok(())
51+
}

test-models/tensorflow/README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,12 @@ This supports strings, and doesn't require custom operators.
1616
pipenv run python src/unique_model.py
1717
pipenv run python -m tf2onnx.convert --saved-model models/unique_model --output unique_model.onnx --opset 11
1818
```
19+
20+
# Model: Regex (uses `ort_customops`)
21+
22+
A TensorFlow model that applies a regex, which requires the onnxruntime custom ops in `ort-customops`.
23+
24+
```
25+
pipenv run python src/regex_model.py
26+
pipenv run python -m tf2onnx.convert --saved-model models/regex_model --output regex_model.onnx --extra_opset ai.onnx.contrib:1
27+
```
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
tf2onnx1.9.0:�
2+

3+
input_1:0
4+
5+
pattern__7
6+
7+
rewrite__8
8+
Identity:0)PartitionedCall/model1/StaticRegexReplace"StringRegexReplace:ai.onnx.contribtf2onnx*2_B
9+
rewrite__8*2 B
10+
pattern__7R!converted from models/regex_modelZ
11+
input_1:0
12+

13+
14+
unk__9b
15+
16+
Identity:0
17+

18+
unk__10B B
19+
ai.onnx.contrib
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
import tf2onnx
4+
5+
6+
class RegexModel(tf.keras.Model):
7+
8+
def __init__(self, name='model1', **kwargs):
9+
super(RegexModel, self).__init__(name=name, **kwargs)
10+
11+
def call(self, inputs):
12+
return tf.strings.regex_replace(inputs, " ", "_", replace_global=True)
13+
14+
15+
model1 = RegexModel()
16+
17+
print(model1(tf.constant(["Hello world!"])))
18+
19+
model1.save("models/regex_model")

0 commit comments

Comments
 (0)