
Converting Conv3DTranspose with strides of 8 leads to inference differences #1825

Open
@nnpk


Describe the bug
When a model containing tf.keras.layers.Convolution3DTranspose with strides of (8, 8, 8) is converted from Keras to ONNX, the ONNX model's predictions differ from the Keras predictions by up to about 1.0e-1. With smaller strides, e.g. (2, 2, 2) or (4, 4, 4), no such differences occur. A minimal example follows.
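For reference, the per-axis shape arithmetic involved can be sketched as below. The Keras output length for padding="same" (with output_padding=None) is input_length * stride, and total_padding follows the formula in the ONNX ConvTranspose operator spec for the case where output_shape is given. The function names are hypothetical, and whether tf2onnx emits output_shape, explicit pads or auto_pad for this layer is an assumption not checked here.

```python
# Per-axis output-length arithmetic for a transposed convolution.
# Assumptions: dilation 1 and output_padding 0 (the issue passes output_padding=None).


def keras_same_output_length(input_length: int, stride: int) -> int:
    """Keras output length for padding="same" with output_padding=None."""
    return input_length * stride


def onnx_total_padding(input_length: int, stride: int, kernel: int, output_length: int) -> int:
    """total_padding from the ONNX ConvTranspose spec when output_shape is given."""
    return stride * (input_length - 1) + kernel - output_length


def tf_same_pad_before(input_length: int, stride: int, kernel: int) -> int:
    """Leading padding of TF's SAME rule for the forward conv that maps
    input_length * stride elements back down to input_length."""
    out_len = input_length * stride
    pad_total = max((input_length - 1) * stride + kernel - out_len, 0)
    return pad_total // 2


for stride in (2, 4, 8):
    out_len = keras_same_output_length(64, stride)
    total = onnx_total_padding(64, stride, 3, out_len)
    print(stride, out_len, total, tf_same_pad_before(64, stride, 3))
# 2 128 1 0
# 4 256 -1 0
# 8 512 -5 0
```

For strides 4 and 8 the kernel (size 3) is shorter than the stride, so the ONNX formula yields a negative total_padding, while TF's SAME rule clamps its padding at zero. On its own this does not explain why stride 4 matches and stride 8 does not.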

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 10.15.7
  • TensorFlow version: 2.7.0
  • Python version: 3.9.7
  • tf2onnx version: 1.9.3
  • onnx version: 1.10.2
  • onnxruntime version: 1.10.0
  • opset: 15

To Reproduce
Comparing the predictions of the TensorFlow SavedModel of the model below with those of the same model converted to ONNX shows large differences.

import tensorflow as tf


def model(input_shape=(64, 64, 64), channels=1):
    inputs = tf.keras.Input(
        shape=(input_shape[0], input_shape[1], input_shape[2], channels)
    )
    # Upsample every spatial axis by 8: 64 -> 512 with padding="same".
    outputs = tf.keras.layers.Convolution3DTranspose(
        filters=1,
        kernel_size=(3, 3, 3),
        strides=(8, 8, 8),
        padding="same",
        output_padding=None,
    )(inputs)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.summary(line_length=150)
    model.compile(optimizer="adam", loss="mean_squared_error", run_eagerly=True)
    return model

Now train the model and save it:

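import numpy as np

saved_model_output_path = "conv3d_transpose_stride8"  # placeholder path (hypothetical), choose your own
model = model(input_shape=[64, 64, 64], channels=1)
# Train briefly on random data. The model output has shape (1, 512, 512, 512, 1),
# so the target must have 1 channel to match filters=1.
test_input = np.random.random((1, 64, 64, 64, 1))
test_target = np.random.random((1, 512, 512, 512, 1))
model.fit(test_input, test_target)
model.save(saved_model_output_path)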

Convert to ONNX

import tf2onnx

onnx_output_path = "conv3d_transpose_stride8.onnx"  # placeholder path (hypothetical)

tf2onnx.convert.from_keras(
    model,
    opset=15,
    output_path=onnx_output_path,
)

Load the model with ONNX Runtime and with TensorFlow

import onnx
import onnxruntime as onr
import tensorflow as tf
from google.protobuf.json_format import MessageToDict


def use_tensorflow(save_model_path, vol):
    model = tf.keras.models.load_model(
        save_model_path,
        compile=False,
    )
    # compile() is not required for predict().
    model.compile(
        loss="mean_squared_error",
        optimizer=tf.keras.optimizers.Adam()
    )
    pred_mask = model.predict(vol)
    return pred_mask


def use_onnxrt(path, data):
    onnx_model = onnx.load(path)
    # Print the declared graph inputs (name, element type, shape).
    for _input in onnx_model.graph.input:
        print(MessageToDict(_input))

    onnx.checker.check_model(onnx_model)
    print("The model is checked!")

    session = onr.InferenceSession(path, providers=["CPUExecutionProvider"])
    # session.run returns one array per output; this model has a single output.
    return session.run(None, {session.get_inputs()[0].name: data})[0]

Compare both outputs on a sample input

# A single all-ones volume of shape (1, 64, 64, 64, 1) in float32.
example_input = np.ones((1, 64, 64, 64, 1), dtype=np.float32)

onnx_mask = use_onnxrt(onnx_output_path, example_input)
# Use TF
tf_mask = use_tensorflow(saved_model_output_path, example_input)
if np.allclose(onnx_mask, tf_mask, rtol=1.0e-2, atol=1.0e-2):
    print("Predictions are equal")
else:
    print("Predictions do NOT match")
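np.allclose only reports a single boolean. A hypothetical helper like the one below (plain NumPy, using the same rtol/atol semantics as np.isclose) would also report the largest absolute difference, where it occurs and how many elements fall outside the tolerance.

```python
import numpy as np


def describe_mismatch(a, b, rtol=1e-2, atol=1e-2):
    """Summarize how far two prediction arrays are apart (hypothetical helper)."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    diff = np.abs(a - b)
    # Index of the single largest absolute difference.
    worst = np.unravel_index(np.argmax(diff), diff.shape)
    # Elements violating |a - b| <= atol + rtol * |b|.
    bad = ~np.isclose(a, b, rtol=rtol, atol=atol)
    return {
        "max_abs_diff": float(diff.max()),
        "worst_index": tuple(int(i) for i in worst),
        "num_mismatched": int(bad.sum()),
        "allclose": bool(not bad.any()),
    }
```

It could be called as describe_mismatch(np.squeeze(onnx_mask), np.squeeze(tf_mask)); squeezing drops the size-1 batch and channel axes so both arrays have shape (512, 512, 512).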

With strides of (4, 4, 4), the predictions match to within a tolerance of 1.0e-7. With strides of (8, 8, 8), however, the differences exceed 1.0e-1.
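To illustrate which output positions a "same"-padded transposed convolution actually writes to, here is a 1-D NumPy reference. It is a sketch assuming TF's convention that the transposed convolution is the gradient of a SAME-padded strided convolution, so input element i scaled by kernel tap j lands on output position i * stride + j - pad_before. The bias added by the Keras layer is omitted, and conv1d_transpose_same is a hypothetical name.

```python
import numpy as np


def conv1d_transpose_same(x, w, stride):
    """1-D transposed convolution with TF-style "same" output length (sketch, no bias)."""
    n, k = len(x), len(w)
    out_len = n * stride
    # TF SAME padding of the matching forward conv (out_len -> n), clamped at 0.
    pad_total = max((n - 1) * stride + k - out_len, 0)
    pad_before = pad_total // 2
    out = np.zeros(out_len)
    for i in range(n):
        for j in range(k):
            p = i * stride + j - pad_before
            if 0 <= p < out_len:  # taps falling outside the output are dropped
                out[p] += x[i] * w[j]
    return out


print(conv1d_transpose_same(np.ones(2), np.ones(3), stride=8))
# [1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0.]
```

With kernel size 3 and stride 8, five of every eight output positions receive no kernel contribution at all (only the bias, in the real layer), so any shift in where the taps land shows up as large element-wise differences.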

Thanks for looking into it!
