Describe the issue
This actually comes from https://github.com/microsoft/onnx-converters-private/issues/203
The ONNX spec for LayerNormalization-17 allows the operator to produce between one and three outputs: only the first output is required, whereas the last two are optional.
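To make the optional outputs concrete, here is a minimal pure-Python sketch of what the three outputs (`Y`, `Mean`, `InvStdDev`) represent when normalizing over the last axis. The function name `layer_norm_reference` and its `num_outputs` parameter are hypothetical and exist only for illustration; this is not ONNX Runtime's implementation, and it ignores the `axis` and `stash_type` attributes.

```python
import math


def layer_norm_reference(x, scale, bias=None, epsilon=1e-5, num_outputs=1):
    """Illustrative LayerNormalization over the last axis of a 2-D list.

    Returns a tuple of length `num_outputs`:
    (Y,), (Y, Mean) or (Y, Mean, InvStdDev).
    `num_outputs` is a hypothetical knob that mimics a graph node
    declaring only some of the operator's outputs.
    """
    if num_outputs not in (1, 2, 3):
        raise ValueError("LayerNormalization has between 1 and 3 outputs")

    y, means, inv_std_devs = [], [], []
    for row in x:
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        inv_std = 1.0 / math.sqrt(var + epsilon)

        # Normalize, then apply the elementwise scale and optional bias.
        normalized = [(v - mean) * inv_std * s for v, s in zip(row, scale)]
        if bias is not None:
            normalized = [n + b for n, b in zip(normalized, bias)]

        y.append(normalized)
        means.append(mean)
        inv_std_devs.append(inv_std)

    # Only the first output is required; the trailing ones are optional.
    return (y, means, inv_std_devs)[:num_outputs]


# A node that declares a single output receives only Y.
(y_only,) = layer_norm_reference([[1.0, 2.0, 3.0]], scale=[1.0, 1.0, 1.0])

# A node that declares all three also receives Mean and InvStdDev.
y, mean, inv_std_dev = layer_norm_reference(
    [[1.0, 2.0, 3.0]], scale=[1.0, 1.0, 1.0], num_outputs=3
)
```

Under the spec, a consumer should handle both the single-output and the three-output form of the node.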
To reproduce
import torch.nn.functional as F
from torch import nn
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
import torch
import os

os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = str(17)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        # self.norm = nn.BatchNorm1d([784])
        self.norm = nn.LayerNorm([784])

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        x = self.norm(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = F.log_softmax(x, dim=1)
        return x

model = Net()
# model = ORTModule(model)
model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='tmp', log_level=LogLevel.INFO))
model.to("cuda")
images = torch.randn(8, 28, 28).to("cuda")
output = model(images)
Urgency
Consult with @prathikr
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
main
PyTorch Version
main
Execution Provider
CUDA
Execution Provider Library Version
No response