[Training] ORT Gradient Builder expects optional outputs for LayerNormalization Op #19427

Open
@thiagocrepaldi


Describe the issue

This actually comes from https://github.com/microsoft/onnx-converters-private/issues/203

The ONNX spec for LayerNormalization-17 establishes that the op can produce from 1 to 3 outputs: only the first output is required, whereas the last two are optional.

image
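To make the optional-output rule concrete, here is a minimal pure-Python sketch of how a consumer of a LayerNormalization-17 node could work out which outputs are actually present. It assumes the usual ONNX convention that an optional output is absent when it is omitted from the end of the output list or named with an empty string. The helper `present_layernorm_outputs` is hypothetical and written only for illustration; it is not part of onnx or onnxruntime, and it is not how the ORT gradient builder is implemented.

```python
# Formal outputs of LayerNormalization-17, in spec order.
# Only "Y" is required; "Mean" and "InvStdDev" are optional.
LAYERNORM_OUTPUTS = ("Y", "Mean", "InvStdDev")


def present_layernorm_outputs(node_outputs):
    """Map formal output names to the node's actual output names.

    Hypothetical helper for illustration only (not an onnx/onnxruntime API).
    `node_outputs` is the list of output names stored on the node.
    """
    if not 1 <= len(node_outputs) <= len(LAYERNORM_OUTPUTS):
        raise ValueError("LayerNormalization must have between 1 and 3 outputs")
    if not node_outputs[0]:
        raise ValueError("the first output (Y) is required")
    # Assumption: an optional output is absent if it is omitted
    # or if its name is the empty string.
    return {
        formal: name
        for formal, name in zip(LAYERNORM_OUTPUTS, node_outputs)
        if name
    }


# A node exported with only the required output.
print(present_layernorm_outputs(["ln_out"]))
# A node that also exposes the two optional outputs.
print(present_layernorm_outputs(["ln_out", "mean", "inv_std_dev"]))
```

A gradient builder that follows the spec would need to handle the first case, where only `Y` exists, rather than assuming that the `Mean` and `InvStdDev` outputs are always available.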

To reproduce

import torch.nn.functional as F
from torch import nn
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
import torch
import os

os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = str(17)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        # self.norm = nn.BatchNorm1d([784])
        self.norm = nn.LayerNorm([784])

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        x = self.norm(x)
        
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = F.log_softmax(x, dim=1)

        return x

model = Net()
# model = ORTModule(model)
model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='tmp', log_level=LogLevel.INFO))
model.to("cuda")

images = torch.randn(8, 28, 28).to("cuda")
output = model(images)

image

Urgency

Consult with @prathikr

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

main

PyTorch Version

main

Execution Provider

CUDA

Execution Provider Library Version

No response


Labels

stale: issues that have not been addressed in a while; categorized by a bot
training: issues related to ONNX Runtime training; typically submitted using template
