
onnx result error #92

@ryujaehun


Hello,
I'm running into problems with an ONNX file created in PyTorch.

The code below builds a simple NN and exports it to an ONNX file with PyTorch.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Note: self.conv is never used in forward(), so it does not appear in the exported graph
        self.conv = nn.Conv2d(3, 16, 3)
        self.layers = nn.Sequential(
            nn.Conv2d(3, 16, 3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3),
        )

    def forward(self, x):
        x = self.layers(x)
        return x


m = Net()
m.eval()  # inference mode: BatchNorm2d uses its running statistics
batch = 1
dummy_input = torch.randn(batch, 3, 32, 32, device="cpu")
input_names = ["input"]
output_names = ["output"]
torch.onnx.export(m, (dummy_input,), "test.onnx",
                  do_constant_folding=True, verbose=True,
                  input_names=input_names, output_names=output_names)
```
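As a quick sanity check before comparing values, both the original and the optimized graph should at least agree on the output shape. Neither convolution uses padding, so each 3x3 convolution shrinks the height and width by 2, while BatchNorm2d and ReLU keep the shape. The sketch below applies the standard Conv2d output-size formula (as documented for `nn.Conv2d`); the helper names are hypothetical and only mirror the `Net` defined above.

```python
def conv2d_out_size(size, kernel, stride=1, padding=0, dilation=1):
    # Standard Conv2d output-size formula for one spatial dimension
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def net_output_shape(batch=1, height=32, width=32):
    # Net.forward only runs self.layers:
    #   Conv2d(3, 16, 3) -> BatchNorm2d(16) -> ReLU -> Conv2d(16, 16, 3)
    # BatchNorm2d and ReLU are shape-preserving, so only the two convs matter.
    for _ in range(2):
        height = conv2d_out_size(height, 3)
        width = conv2d_out_size(width, 3)
    return (batch, 16, height, width)


print(net_output_shape())  # (1, 16, 28, 28)
```

If the TASO output reports a different shape than `(1, 16, 28, 28)`, the problem is structural rather than a numerical tolerance issue.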

Then I generated the optimized computation graph with TASO:

```shell
python3 example/test_onnx.py -f test.onnx
```
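A graph optimizer rewrites the graph with substitutions that are supposed to preserve its output. For this particular model, one rewrite that can be checked by hand is folding the BatchNorm2d into the preceding Conv2d, which is exact in eval mode. Whether TASO actually applies this substitution to this graph is an assumption, not something stated above. The NumPy sketch below verifies the folding identity on random data; `conv2d_nchw`, `batchnorm_eval` and `fold_bn_into_conv` are hypothetical helpers, and `eps=1e-5` is assumed to match PyTorch's BatchNorm2d default.

```python
import numpy as np


def conv2d_nchw(x, w, b):
    # Naive unpadded, stride-1 2-D convolution (cross-correlation, as in PyTorch/ONNX Conv)
    # x: (N, C, H, W), w: (O, C, kH, kW), b: (O,)
    n, _, h, wd = x.shape
    o, _, kh, kw = w.shape
    out = np.zeros((n, o, h - kh + 1, wd - kw + 1), dtype=np.float64)
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            patch = x[:, :, i:i + kh, j:j + kw]  # (N, C, kH, kW)
            out[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [1, 2, 3]))
    return out + b.reshape(1, o, 1, 1)


def batchnorm_eval(y, gamma, beta, mean, var, eps=1e-5):
    # Inference-mode BatchNorm2d: per-channel affine map using running statistics
    scale = (gamma / np.sqrt(var + eps)).reshape(1, -1, 1, 1)
    return (y - mean.reshape(1, -1, 1, 1)) * scale + beta.reshape(1, -1, 1, 1)


def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    # Rescale each output channel so that conv(x, w', b') == bn(conv(x, w, b))
    scale = gamma / np.sqrt(var + eps)
    return w * scale.reshape(-1, 1, 1, 1), (b - mean) * scale + beta


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8, 8))
w, b = rng.standard_normal((16, 3, 3, 3)), rng.standard_normal(16)
gamma, beta = rng.standard_normal(16), rng.standard_normal(16)
mean, var = rng.standard_normal(16), rng.uniform(0.5, 2.0, 16)

reference = batchnorm_eval(conv2d_nchw(x, w, b), gamma, beta, mean, var)
w_f, b_f = fold_bn_into_conv(w, b, gamma, beta, mean, var)
folded = conv2d_nchw(x, w_f, b_f)
print(np.max(np.abs(reference - folded)))  # only floating-point rounding error
```

Since the identity itself holds up to rounding, a large mismatch between the two ONNX Runtime outputs would point to how a rewrite was applied or exported, not to the folding math.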

To verify the result, I wrote inference code using ONNX Runtime. Unexpectedly, the outputs of the TASO-optimized graph differ from those of the original graph.

```python
import numpy as np
import onnxruntime

# Feed the same random input to both models
inp = np.random.randn(1, 3, 32, 32).astype(np.float32)

# Original model exported from PyTorch
ort_session = onnxruntime.InferenceSession("test.onnx", providers=["CPUExecutionProvider"])
ort_inputs = {ort_session.get_inputs()[0].name: inp}
ort_outs = ort_session.run(None, ort_inputs)

# TASO-optimized model
ort_session2 = onnxruntime.InferenceSession("test.onnx.taso.onnx", providers=["CPUExecutionProvider"])
ort_inputs2 = {ort_session2.get_inputs()[0].name: inp}
ort_outs2 = ort_session2.run(None, ort_inputs2)
```
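To make "different" concrete, the two lists returned by `InferenceSession.run(None, ...)` can be compared element-wise with a tolerance. Exact equality should not be expected even from a correct optimization, because reordering float32 arithmetic changes rounding. The sketch below is a hypothetical helper; the default `rtol`/`atol` values are assumptions, not values taken from TASO or ONNX Runtime.

```python
import numpy as np


def compare_outputs(ref_outs, new_outs, rtol=1e-4, atol=1e-5):
    """Compare two lists of arrays, e.g. the results of two InferenceSession.run calls.

    Returns one (index, allclose, max_abs_diff) tuple per output.
    """
    if len(ref_outs) != len(new_outs):
        raise ValueError(f"output count differs: {len(ref_outs)} vs {len(new_outs)}")
    report = []
    for i, (a, b) in enumerate(zip(ref_outs, new_outs)):
        a, b = np.asarray(a), np.asarray(b)
        if a.shape != b.shape:
            # Shapes disagree: values are not comparable at all
            report.append((i, False, float("inf")))
            continue
        max_abs = float(np.max(np.abs(a - b))) if a.size else 0.0
        report.append((i, bool(np.allclose(a, b, rtol=rtol, atol=atol)), max_abs))
    return report
```

With the sessions above, `for idx, ok, diff in compare_outputs(ort_outs, ort_outs2): print(idx, ok, diff)` reports whether each output matches within tolerance and how large the worst element-wise difference is, which is more informative than a plain "the results differ".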

I built and tested several other neural networks, and all of them showed the same discrepancy.

Can anybody help?
