Describe the issue
I see that ORTModule is easy to use for PyTorch model training. When it comes to inference, though, it seems we have to write much more code to run a model on torch tensors via I/O binding. Example code from the official docs:
import numpy as np
import onnxruntime
import torch

# X is a PyTorch tensor on device
session = onnxruntime.InferenceSession('model.onnx', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
binding = session.io_binding()
X_tensor = X.contiguous()
binding.bind_input(
name='X',
device_type='cuda',
device_id=0,
element_type=np.float32,
shape=tuple(X_tensor.shape),
buffer_ptr=X_tensor.data_ptr(),
)
# Allocate the PyTorch tensor for the model output
Y_shape = ... # You need to specify the output PyTorch tensor shape
Y_tensor = torch.empty(Y_shape, dtype=torch.float32, device='cuda:0').contiguous()
binding.bind_output(
name='Y',
device_type='cuda',
device_id=0,
element_type=np.float32,
shape=tuple(Y_tensor.shape),
buffer_ptr=Y_tensor.data_ptr(),
)
session.run_with_iobinding(binding)
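Even wrapped in a reusable helper, that is a lot of ceremony. Below is a sketch built only from the calls shown above; run_on_torch_tensors is my own placeholder name, not an ONNX Runtime API, and it assumes float32 CUDA tensors:

import numpy as np
import torch

def run_on_torch_tensors(session, inputs, output_shapes, device_id=0):
    # inputs: {input_name: torch CUDA tensor}; output_shapes: {output_name: shape}
    binding = session.io_binding()
    keep_alive = []  # hold contiguous copies so they outlive run_with_iobinding
    for name, tensor in inputs.items():
        tensor = tensor.contiguous()
        keep_alive.append(tensor)
        binding.bind_input(
            name=name,
            device_type='cuda',
            device_id=device_id,
            element_type=np.float32,
            shape=tuple(tensor.shape),
            buffer_ptr=tensor.data_ptr(),
        )
    outputs = {}
    for name, shape in output_shapes.items():
        out = torch.empty(shape, dtype=torch.float32, device=f'cuda:{device_id}')
        binding.bind_output(
            name=name,
            device_type='cuda',
            device_id=device_id,
            element_type=np.float32,
            shape=tuple(out.shape),
            buffer_ptr=out.data_ptr(),
        )
        outputs[name] = out
    session.run_with_iobinding(binding)
    return outputs

# e.g. Y = run_on_torch_tensors(session, {'X': X}, {'Y': Y_shape})['Y']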
So, can ORTModule support inference with execution providers, so that torch tensors go in and come out without this binding boilerplate?
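For comparison, this is the training-time ergonomics I would like inference to match. A minimal sketch, assuming the onnxruntime-training package is installed; whether the eval-mode call at the end actually runs through the CUDA execution provider is exactly what I'm asking:

import torch
from onnxruntime.training.ortmodule import ORTModule

# Wrap any torch.nn.Module; forward/backward are dispatched to ONNX Runtime.
model = ORTModule(torch.nn.Linear(128, 10).cuda())
x = torch.randn(4, 128, device='cuda')

# Training: plain tensors in, plain tensors out -- no manual I/O binding.
y = model(x)
y.sum().backward()

# Desired inference path:
model.eval()
with torch.no_grad():
    y = model(x)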
To reproduce
:)
Urgency
No response
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.17.1
PyTorch Version
2.2.0
Execution Provider
CUDA
Execution Provider Library Version
No response