[Training] Can we use ORTModule for inference? #20281

Open
@LSC527

Description

Describe the issue

I see that ORTModule is easy to use for PyTorch model training. When it comes to inference, however, it seems we have to write much more code to run inference with torch tensors. Example code from the official docs:

import numpy as np
import onnxruntime
import torch

# X is a PyTorch tensor on device
session = onnxruntime.InferenceSession('model.onnx', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
binding = session.io_binding()

X_tensor = X.contiguous()

binding.bind_input(
    name='X',
    device_type='cuda',
    device_id=0,
    element_type=np.float32,
    shape=tuple(X_tensor.shape),
    buffer_ptr=X_tensor.data_ptr(),
)

# Allocate the PyTorch tensor for the model output
Y_shape = ...  # You need to specify the output PyTorch tensor shape
Y_tensor = torch.empty(Y_shape, dtype=torch.float32, device='cuda:0').contiguous()
binding.bind_output(
    name='Y',
    device_type='cuda',
    device_id=0,
    element_type=np.float32,
    shape=tuple(Y_tensor.shape),
    buffer_ptr=Y_tensor.data_ptr(),
)

session.run_with_iobinding(binding)
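
(As an aside, if I understand the IOBinding API correctly, the output pre-allocation can be skipped by letting ONNX Runtime allocate the buffer, roughly like the sketch below, but even that is still a lot more code than just calling the module.)

# Sketch: let ONNX Runtime allocate the output, so Y_shape does not
# need to be known in advance. Reuses session/binding from above.
binding.bind_output(name='Y', device_type='cuda', device_id=0)
session.run_with_iobinding(binding)
Y = binding.copy_outputs_to_cpu()[0]  # numpy array copied back to host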

So, can ORTModule support inference with execution providers directly?
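
Ideally, inference could reuse the ORTModule-wrapped model directly, something like the sketch below (assuming ORTModule skips gradient-graph building in eval mode; MyModel and x are placeholders):

import torch
from onnxruntime.training import ORTModule

# MyModel is a placeholder nn.Module; x is a torch CUDA tensor
model = ORTModule(MyModel().to('cuda'))
model.eval()  # hoping this selects an inference-only graph

with torch.no_grad():
    y = model(x)  # y would come back as a torch tensor on device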

To reproduce

:)

Urgency

No response

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.17.1

PyTorch Version

2.2.0

Execution Provider

CUDA

Execution Provider Library Version

No response


Labels

training: issues related to ONNX Runtime training; typically submitted using template
