
Conv3D performance degradation after ONNX conversion #2303

Open
@jm2201

Description


Describe the bug
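A simple TensorFlow model built from Conv3D and MaxPooling3D layers runs about 3.8x slower on CPU after conversion to ONNX (737 ms vs. 195 ms per call).
The same model with the 3D layers replaced by Conv2D and MaxPooling2D runs about 13x faster on CPU after conversion to ONNX (4.55 ms vs. 59.7 ms per call).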

Urgency
If this is not resolved within the next 4-6 months, it will block the planned release of a model converted from TensorFlow to ONNX.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 18.04*): Windows 10 Pro 22H2
  • TensorFlow Version: 2.10.0
  • Python version: 3.9.16
  • ONNX version (if applicable, e.g. 1.11*): 1.15.0
  • ONNXRuntime version (if applicable, e.g. 1.11*): 1.16.3
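The list above does not include the tf2onnx version, which is relevant when reporting a converter issue. As a sketch, the installed versions can be collected with the standard library alone; the distribution names below are assumptions and may need adjusting (for example, `onnxruntime-gpu` instead of `onnxruntime`), and `collect_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version

# PyPI distribution names (assumed); adjust if a different build is installed,
# e.g. "onnxruntime-gpu" instead of "onnxruntime".
PACKAGES = ("tensorflow", "tf2onnx", "onnx", "onnxruntime", "numpy")


def collect_versions(packages=PACKAGES):
    """Return {package: installed version string or 'not installed'}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    import platform

    print("Python", platform.python_version(), "on", platform.platform())
    for name, ver in collect_versions().items():
        print(f"{name}: {ver}")
```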

To Reproduce

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, optimizers
import onnxruntime as rt

def get_model(_3d=True):
    # Build every layer with channels-first tensors (NCHW for 2D, NCDHW for 3D).
    tf.keras.backend.set_image_data_format('channels_first')
    if _3d:
        input_shape = (1, 80, 128, 128)  # (channels, depth, height, width)
        conv_layer = layers.Conv3D
        pool_layer = layers.MaxPooling3D
    else:
        input_shape = (1, 128, 128)  # (channels, height, width)
        conv_layer = layers.Conv2D
        pool_layer = layers.MaxPooling2D

    inputs = layers.Input(input_shape, name='input_0')
    # Three blocks of two same-padded 3x3 (or 3x3x3) convolutions, each followed by 2x pooling.
    x = conv_layer(32, 3, activation='relu', padding='same')(inputs)
    x = conv_layer(32, 3, activation='relu', padding='same')(x)
    x = pool_layer(2)(x)

    x = conv_layer(64, 3, activation='relu', padding='same')(x)
    x = conv_layer(64, 3, activation='relu', padding='same')(x)
    x = pool_layer(2)(x)

    x = conv_layer(128, 3, activation='relu', padding='same')(x)
    x = conv_layer(128, 3, activation='relu', padding='same')(x)
    x = pool_layer(2)(x)

    # 1x1 convolution down to a single sigmoid output channel.
    out = conv_layer(1, 1, activation='sigmoid')(x)

    model = models.Model(inputs=inputs, outputs=out)
    # `learning_rate` replaces the deprecated `lr` argument.
    model.compile(optimizer=optimizers.Adam(learning_rate=1e-3, decay=1e-5), loss='binary_crossentropy')
    return model
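To put the two variants in perspective, the work done by the convolutions can be estimated from the layer plan in `get_model()`. The sketch below counts multiply-accumulates (MACs) of the convolution layers only, assuming stride 1, `padding='same'` (spatial size preserved), and pooling that floor-divides every spatial axis by 2; biases, activations and pooling cost are ignored, and `LAYERS` and `conv_macs` are hypothetical names introduced here.

```python
from math import prod

# Layer plan mirroring get_model(): ("conv", filters, kernel) or ("pool", size).
LAYERS = [
    ("conv", 32, 3), ("conv", 32, 3), ("pool", 2),
    ("conv", 64, 3), ("conv", 64, 3), ("pool", 2),
    ("conv", 128, 3), ("conv", 128, 3), ("pool", 2),
    ("conv", 1, 1),
]


def conv_macs(input_shape, layers=LAYERS):
    """Return (total conv MACs, output shape) for a channels-first input.

    input_shape excludes the batch axis, e.g. (1, 80, 128, 128) or (1, 128, 128).
    Assumes stride 1, 'same' padding and floor-dividing pooling.
    """
    channels, spatial = input_shape[0], list(input_shape[1:])
    ndim = len(spatial)
    total = 0
    for layer in layers:
        if layer[0] == "conv":
            _, filters, kernel = layer
            # Each output voxel needs in_channels * kernel**ndim MACs per filter.
            total += prod(spatial) * filters * channels * kernel ** ndim
            channels = filters
        else:
            _, size = layer
            spatial = [s // size for s in spatial]
    return total, (channels, *spatial)


if __name__ == "__main__":
    macs_3d, out_3d = conv_macs((1, 80, 128, 128))
    macs_2d, out_2d = conv_macs((1, 128, 128))
    print(f"3D: {macs_3d / 1e9:.1f} GMACs, output {out_3d}")
    print(f"2D: {macs_2d / 1e9:.2f} GMACs, output {out_2d}")
    print(f"ratio: {macs_3d / macs_2d:.0f}x")
```

Under these assumptions the 3D model performs roughly 78 GMACs per inference versus about 0.61 GMACs for the 2D model, a ratio of about 128x, so per-call timings of the two variants reflect very different amounts of arithmetic.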

Test the 3D TensorFlow model:

MODEL_EX = 'model_ex/saved_model'
_3D = True
tf_model = get_model(_3D)
input_ = np.random.random((1, 1, 80, 128, 128)).astype(np.float32)
%timeit out = tf_model.predict(input_)
1/1 [==============================] - 0s 89ms/step
1/1 [==============================] - 0s 20ms/step
1/1 [==============================] - 0s 21ms/step
1/1 [==============================] - 0s 20ms/step
1/1 [==============================] - 0s 21ms/step
1/1 [==============================] - 0s 23ms/step
1/1 [==============================] - 0s 20ms/step
1/1 [==============================] - 0s 21ms/step
195 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
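The Keras documentation for `Model.predict` notes that it is designed for batch processing and recommends calling the model directly (for example `model(x, training=False)`) for small inputs that fit in one batch. A hypothetical `benchmark` helper like the one below, a sketch rather than a replacement for `%timeit`, applies the same warm-up and repetition policy to any zero-argument callable so that both runtimes can be timed the same way.

```python
import statistics
import time


def benchmark(fn, *, warmup=3, repeat=20):
    """Call fn() `warmup` times untimed, then `repeat` times timed.

    Returns (median_ms, stdev_ms). The warm-up calls absorb one-off costs
    (tracing, memory allocation) so they do not skew the measurement.
    """
    if repeat < 2:
        raise ValueError("repeat must be >= 2 to compute a standard deviation")
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples), statistics.stdev(samples)
```

For example, `benchmark(lambda: tf_model(input_, training=False))` and, once the ONNX session exists, `benchmark(lambda: sess.run(None, {'input_0': input_}))` would time the direct Keras call and the ONNX Runtime session under identical conditions.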

Convert the 3D model to ONNX:

tf_model.save(MODEL_EX)
!python -m tf2onnx.convert --opset 18 --saved-model model_ex/saved_model --output model_ex/tmp.onnx
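One way to check what the converter produced is to tally the op types in the exported graph, for instance to see whether Transpose nodes surround the Conv and MaxPool nodes. The sketch below assumes the `onnx` package is installed for loading the file; `count_op_types` and `summarize_onnx` are hypothetical helper names.

```python
from collections import Counter


def count_op_types(nodes):
    """Tally node.op_type over an iterable of ONNX NodeProto-like objects."""
    return Counter(node.op_type for node in nodes)


def summarize_onnx(path):
    """Load an ONNX file and return its op-type histogram."""
    import onnx  # imported lazily so count_op_types works without onnx installed

    model = onnx.load(path)
    return count_op_types(model.graph.node)
```

For example, `summarize_onnx('model_ex/tmp.onnx').most_common()` lists the op types of the converted model, most frequent first.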

Test the 3D ONNX model:

TMP_MODEL = os.path.join('model_ex', 'tmp.onnx')
sess = rt.InferenceSession(TMP_MODEL, providers=['CPUExecutionProvider'])
%timeit result = sess.run(None, {'input_0': input_})
737 ms ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
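To see where ONNX Runtime spends its time, its built-in profiler can be enabled through `SessionOptions.enable_profiling`; `InferenceSession.end_profiling()` returns the path of the JSON trace it wrote. The parser below assumes the trace layout described in its docstring (node events with `cat == "Node"`, names ending in `_kernel_time`, durations in microseconds, op type in `args["op_name"]`), which should be checked against an actual trace; `op_time_from_profile` and `profile_session` are hypothetical names.

```python
import json
from collections import defaultdict


def op_time_from_profile(events):
    """Sum kernel durations (microseconds) per ONNX op type, largest first.

    Assumed layout: a list of trace events where node-level entries have
    cat == "Node", a name ending in "_kernel_time", a "dur" field in
    microseconds and args["op_name"] holding the op type.
    """
    totals = defaultdict(int)
    for ev in events:
        if ev.get("cat") == "Node" and ev.get("name", "").endswith("_kernel_time"):
            totals[ev.get("args", {}).get("op_name", "?")] += ev.get("dur", 0)
    return dict(sorted(totals.items(), key=lambda kv: kv[1], reverse=True))


def profile_session(model_path, feeds, runs=5):
    """Run an ORT session with profiling enabled and return per-op totals."""
    import onnxruntime as rt  # lazy import: the parser above needs only the stdlib

    opts = rt.SessionOptions()
    opts.enable_profiling = True
    sess = rt.InferenceSession(model_path, sess_options=opts,
                               providers=["CPUExecutionProvider"])
    for _ in range(runs):
        sess.run(None, feeds)
    profile_file = sess.end_profiling()  # path of the JSON trace written by ORT
    with open(profile_file) as f:
        return op_time_from_profile(json.load(f))
```

For the 3D model, `profile_session(TMP_MODEL, {'input_0': input_})` would return per-op kernel time, which would indicate whether Conv or another op accounts for most of the 737 ms.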

Test the 2D TensorFlow model:

MODEL_EX = 'model_ex/saved_model'
_3D = False
tf_model = get_model(_3D)
input_ = np.random.random((1, 1, 128, 128)).astype(np.float32)
%timeit out = tf_model.predict(input_)
1/1 [==============================] - 0s 19ms/step
1/1 [==============================] - 0s 19ms/step
1/1 [==============================] - 0s 20ms/step
1/1 [==============================] - 0s 20ms/step
...
1/1 [==============================] - 0s 19ms/step
1/1 [==============================] - 0s 19ms/step
1/1 [==============================] - 0s 18ms/step
59.7 ms ± 1.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Convert the 2D model to ONNX:

tf_model.save(MODEL_EX)
!python -m tf2onnx.convert --opset 18 --saved-model model_ex/saved_model --output model_ex/tmp.onnx

Test the 2D ONNX model:

TMP_MODEL = os.path.join('model_ex', 'tmp.onnx')
sess = rt.InferenceSession(TMP_MODEL, providers=['CPUExecutionProvider'])
%timeit result = sess.run(None, {'input_0': input_})
4.55 ms ± 304 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
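The ratios quoted in the bug description can be recomputed from the mean `%timeit` values above. The sketch below simply divides the TensorFlow time by the ONNX Runtime time for each variant; `TIMINGS_MS` and `onnx_speedup` are illustrative names.

```python
# Mean per-call timings (ms) copied from the %timeit outputs above.
TIMINGS_MS = {
    "3D": {"tensorflow": 195.0, "onnxruntime": 737.0},
    "2D": {"tensorflow": 59.7, "onnxruntime": 4.55},
}


def onnx_speedup(timings):
    """Return {variant: tf_time / ort_time}; values below 1 mean ONNX is slower."""
    return {name: t["tensorflow"] / t["onnxruntime"] for name, t in timings.items()}


for name, ratio in onnx_speedup(TIMINGS_MS).items():
    if ratio >= 1:
        print(f"{name}: ONNX Runtime {ratio:.1f}x faster")
    else:
        print(f"{name}: ONNX Runtime {1 / ratio:.1f}x slower")
```

With the reported numbers this prints about 3.8x slower for the 3D model and 13.1x faster for the 2D model.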

Metadata


Assignees: No one assigned
Labels: bug (An unexpected problem or unintended behavior), pending on user response (Waiting for more information or validation from user)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
