Skip to content

Commit 482668d

Browse files
committed
NVExecutionProvider
1 parent 4341eac commit 482668d

File tree

7 files changed

+60
-7
lines changed

7 files changed

+60
-7
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ cann = [ "ort-sys/cann" ]
8484
qnn = [ "ort-sys/qnn" ]
8585
webgpu = [ "ort-sys/webgpu" ]
8686
azure = [ "ort-sys/azure" ]
87+
nv = [ "ort-sys/nv" ]
8788

8889
[dependencies]
8990
ort-sys = { version = "=2.0.0-rc.9", path = "ort-sys", default-features = false }

ort-sys/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ cann = []
4444
qnn = []
4545
webgpu = [ "dep:glob" ]
4646
azure = []
47+
nv = []
4748

4849
[build-dependencies]
4950
ureq = { version = "3", optional = true, default-features = false, features = [ "native-tls", "socks-proxy" ] }

src/execution_providers/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ pub mod webgpu;
6363
pub use self::webgpu::WebGPUExecutionProvider;
6464
pub mod azure;
6565
pub use self::azure::AzureExecutionProvider;
66+
pub mod nv;
67+
pub use self::nv::NVExecutionProvider;
6668

6769
/// ONNX Runtime works with different hardware acceleration libraries through its extensible **Execution Providers**
6870
/// (EP) framework to optimally execute the ONNX models on the hardware platform. This interface enables flexibility for

src/execution_providers/nv.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
2+
use crate::{error::Result, session::builder::SessionBuilder};
3+
4+
#[derive(Debug, Default, Clone)]
5+
pub struct NVExecutionProvider {
6+
options: ExecutionProviderOptions
7+
}
8+
9+
// NOTE(review): `impl_ep!` is defined in the parent module and is not visible from this file;
// presumably it generates the shared boilerplate for this provider type — confirm there.
super::impl_ep!(arbitrary; NVExecutionProvider);
10+
11+
impl NVExecutionProvider {
12+
pub fn with_device_id(mut self, device_id: u32) -> Self {
13+
self.options.set("ep.nvtensorrtrtxexecutionprovider.device_id", device_id.to_string());
14+
self
15+
}
16+
17+
pub fn with_cuda_graph(mut self, enable: bool) -> Self {
18+
self.options
19+
.set("ep.nvtensorrtrtxexecutionprovider.nv_cuda_graph_enable", if enable { "1" } else { "0" });
20+
self
21+
}
22+
}
23+
24+
impl ExecutionProvider for NVExecutionProvider {
25+
fn as_str(&self) -> &'static str {
26+
"NvTensorRTRTXExecutionProvider"
27+
}
28+
29+
fn supported_by_platform(&self) -> bool {
30+
cfg!(any(all(target_os = "linux", any(target_arch = "aarch64", target_arch = "x86_64")), all(target_os = "windows", target_arch = "x86_64")))
31+
}
32+
33+
#[allow(unused, unreachable_code)]
34+
fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
35+
#[cfg(any(feature = "load-dynamic", feature = "nv"))]
36+
{
37+
use crate::{AsPointer, ortsys};
38+
39+
let ffi_options = self.options.to_ffi();
40+
ortsys![unsafe SessionOptionsAppendExecutionProvider(
41+
session_builder.ptr_mut(),
42+
c"NvTensorRtRtx".as_ptr().cast::<core::ffi::c_char>(),
43+
ffi_options.key_ptrs(),
44+
ffi_options.value_ptrs(),
45+
ffi_options.len(),
46+
)?];
47+
return Ok(());
48+
}
49+
50+
Err(RegisterError::MissingFeature)
51+
}
52+
}

src/memory.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,13 +255,12 @@ impl AllocationDevice {
255255
pub const CANN: AllocationDevice = AllocationDevice("Cann\0");
256256
pub const CANN_PINNED: AllocationDevice = AllocationDevice("CannPinned\0");
257257
pub const DIRECTML: AllocationDevice = AllocationDevice("DML\0");
258-
pub const DIRECTML_CPU: AllocationDevice = AllocationDevice("DML CPU\0");
259258
pub const HIP: AllocationDevice = AllocationDevice("Hip\0");
260259
pub const HIP_PINNED: AllocationDevice = AllocationDevice("HipPinned\0");
261260
pub const OPENVINO_CPU: AllocationDevice = AllocationDevice("OpenVINO_CPU\0");
262261
pub const OPENVINO_GPU: AllocationDevice = AllocationDevice("OpenVINO_GPU\0");
263-
pub const XNNPACK: AllocationDevice = AllocationDevice("XnnpackExecutionProvider\0");
264-
pub const TVM: AllocationDevice = AllocationDevice("TVM\0");
262+
pub const QNN_HTP_SHARED: AllocationDevice = AllocationDevice("QnnHtpShared\0");
263+
pub const WEBGPU_BUFFER: AllocationDevice = AllocationDevice("WebGPU_Buffer\0");
265264

266265
pub fn as_str(&self) -> &'static str {
267266
&self.0[..self.0.len() - 1]

src/session/builder/editable.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use alloc::sync::Arc;
1+
use alloc::{boxed::Box, sync::Arc};
22
use core::{
33
ops::Deref,
44
ptr::{self, NonNull}

src/value/impl_tensor/copy.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ fn ep_for_device(device: AllocationDevice, device_id: i32) -> Result<ep::Executi
4343
.with_conv_max_workspace(false)
4444
.with_conv_algorithm_search(ep::cuda::CuDNNConvAlgorithmSearch::Default)
4545
.build(),
46-
AllocationDevice::DIRECTML | AllocationDevice::DIRECTML_CPU => ep::DirectMLExecutionProvider::default().with_device_id(device_id).build(),
46+
AllocationDevice::DIRECTML => ep::DirectMLExecutionProvider::default().with_device_id(device_id).build(),
4747
AllocationDevice::CANN | AllocationDevice::CANN_PINNED => ep::CANNExecutionProvider::default()
4848
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
4949
.with_cann_graph(false)
@@ -63,8 +63,6 @@ fn ep_for_device(device: AllocationDevice, device_id: i32) -> Result<ep::Executi
6363
.with_exhaustive_conv_search(false)
6464
.with_device_id(device_id)
6565
.build(),
66-
AllocationDevice::TVM => ep::TVMExecutionProvider::default().build(),
67-
AllocationDevice::XNNPACK => ep::XNNPACKExecutionProvider::default().build(),
6866
_ => return Err(crate::Error::new("Unsupported allocation device {device} for tensor copy target"))
6967
})
7068
}

0 commit comments

Comments
 (0)