Skip to content

Commit 61dc6a1

Browse files
committed
update
1 parent 333fbdb commit 61dc6a1

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3538,6 +3538,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
35383538

35393539
// Create compute function
35403540
compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
3541+
cudaSetDevice(device_id_);
35413542
Ort::KernelContext ctx(context);
35423543

35433544
TensorrtFuncState* trt_state = reinterpret_cast<TensorrtFuncState*>(state);
@@ -4212,6 +4213,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
42124213

42134214
// Create compute function
42144215
compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
4216+
cudaSetDevice(device_id_);
42154217
Ort::KernelContext ctx(context);
42164218

42174219
TensorrtShortFuncState* trt_state = reinterpret_cast<TensorrtShortFuncState*>(state);

0 commit comments

Comments
 (0)