Inconsistent Results After ONNX Runtime Optimization #23142

@Thrsu

Describe the issue

The outputs of a model optimized with ONNX Runtime (opt_level=0/1/2/99) are inconsistent with those of the original, unoptimized model. The mismatch affects only the output v4_0 and is intermittent (a flaky test): it does not occur on every run.

The following error message is seen when comparing the results of the optimized model with the original model:

AssertionError: 
Not equal to tolerance rtol=0.001, atol=0.001

Mismatched elements: 2 / 22 (9.09%)
Max absolute difference: 1841117444
Max relative difference: 4.60279361e+08
 x: array([[-1841117440],
       [      32646],
       [          3],...
 y: array([[4],
       [3],
       [3],...
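To make the size of the discrepancy easier to see, the mismatch criterion used by `np.testing.assert_allclose` can be applied by hand. The following is a minimal sketch, assuming the output is an integer tensor; `summarize_mismatch` is a hypothetical helper name, and only the three rows printed in the error message are used, since the remaining rows are elided there.

```python
import numpy as np

def summarize_mismatch(actual, expected, rtol=1e-3, atol=1e-3):
    """Return (n_mismatched, n_total, max_abs_diff) using the
    assert_allclose criterion |a - e| > atol + rtol * |e|."""
    # Widen to int64 so the subtraction cannot overflow a narrower int type.
    a = np.asarray(actual, dtype=np.int64)
    e = np.asarray(expected, dtype=np.int64)
    diff = np.abs(a - e)
    mismatched = diff > atol + rtol * np.abs(e)
    return int(mismatched.sum()), int(a.size), int(diff.max())

# Only the rows shown in the error message; the rest are elided in the report.
x = np.array([[-1841117440], [32646], [3]], dtype=np.int32)
y = np.array([[4], [3], [3]], dtype=np.int32)
print(summarize_mismatch(x, y))  # -> (2, 3, 1841117444)
```

The maximum absolute difference computed from these three rows matches the value in the error message, and differences of this magnitude are far larger than rounding error alone would produce.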

I suspect this could be related to precision loss or non-deterministic operations introduced during the optimization process. Could the team assist in analyzing the root cause of this discrepancy?
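One way to separate run-to-run non-determinism from a difference introduced by optimization is to run a single session several times on the same inputs and count how often its outputs change. The sketch below is an assumption about how that check could be done, not part of the original report; `count_unstable_runs` is a hypothetical helper that accepts any zero-argument callable returning a list of arrays.

```python
import numpy as np

def count_unstable_runs(run_fn, n_runs=20):
    """Call run_fn() n_runs times and count the runs whose outputs
    differ from those of the first run."""
    baseline = run_fn()
    unstable = 0
    for _ in range(n_runs - 1):
        outputs = run_fn()
        same = all(np.array_equal(b, o) for b, o in zip(baseline, outputs))
        if not same:
            unstable += 1
    return unstable

# Stand-in for a deterministic session: every run returns the same output.
stable_fn = lambda: [np.array([[4], [3], [3]], dtype=np.int32)]
print(count_unstable_runs(stable_fn, n_runs=5))  # -> 0
```

With the sessions from the reproduction script, the callable would be something like `lambda: optimized_session.run(output_names, input_data)`. If the optimized session disagrees with itself across runs, the problem is not a fixed numerical difference between the two graphs.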

To reproduce

  1. Download the model
  2. Run the script below:
import onnx
import onnxruntime as ort
import numpy as np
from onnxruntime.transformers import optimizer

model_path = "inconsis2.onnx"
optimized_model_path = "./opt.onnx"

# Baseline: run the original model with all graph optimizations disabled.
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
this_provider_list = ort.get_available_providers()

original_session = ort.InferenceSession(model_path, sess_options, providers=this_provider_list)
# Note: rand() returns values in [0, 1), so the int32 cast makes both inputs all zeros.
input_data = {"v7_0": np.random.rand(1, 1, 1).astype(np.int32), "v11_0": np.random.rand(1, 1).astype(np.int32)}
output_names = [output.name for output in original_session.get_outputs()]
original_result = original_session.run(output_names, input_data)

# Optimize the model offline, save it, and run the saved file with default session options.
optimized_model = optimizer.optimize_model(model_path, opt_level=1, use_gpu=True)
optimized_model.save_model_to_file(optimized_model_path)
optimized_session = ort.InferenceSession(optimized_model_path, providers=this_provider_list)
optimized_model = onnx.load(optimized_model_path)
optimized_result = optimized_session.run(output_names, input_data)

# Compare each output pair; this assertion fails intermittently for v4_0.
for r1, r2 in zip(original_result, optimized_result):
    np.testing.assert_allclose(r1, r2, atol=1e-3, rtol=1e-3)
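
Because the failure is intermittent, it can help to build the inputs from a seeded generator so that a failing run can be replayed exactly. The script draws from `np.random.rand` and casts to int32, which truncates every value to 0; the sketch below instead draws integers directly. The helper name `make_inputs` and the value range `[0, 10)` are assumptions for illustration, since the valid input range of the model is not stated in the report. The input names and shapes are taken from the script.

```python
import numpy as np

def make_inputs(seed, low=0, high=10):
    """Build a reproducible int32 feed for inputs v7_0 and v11_0.
    The range [low, high) is an assumed placeholder, not a model requirement."""
    rng = np.random.default_rng(seed)
    return {
        "v7_0": rng.integers(low, high, size=(1, 1, 1), dtype=np.int32),
        "v11_0": rng.integers(low, high, size=(1, 1), dtype=np.int32),
    }

feed = make_inputs(seed=0)
print({name: (arr.shape, arr.dtype) for name, arr in feed.items()})
```

Recording the seed alongside a failing assertion makes it possible to rerun both sessions on the same feed.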

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

5c1b7cc

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

No response

    Labels

    model:transformer (issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.)
    stale (issues that have not been addressed in a while; categorized by a bot)