Description
The exported ONNX model does not correctly infer the dynamic shape for a `tf.reshape` operation when it is performed after a time-distributed `Dense` layer.
For example, let's consider the following model:
```python
import numpy as np
import tensorflow as tf

class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, H, N, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dense = tf.keras.layers.Dense(512)
        self.N = N
        assert not H % N
        self.D = H // N

    def call(self, x):
        x = self.dense(x)
        x_shape = tf.shape(x)
        x = tf.reshape(x, [x_shape[0], x_shape[1], self.N, self.D])
        return x

def get_model():
    inputs = tf.keras.Input((None, 512))
    x = CustomLayer(H=512, N=8)(inputs)
    return tf.keras.Model(inputs=inputs, outputs=x)

model = get_model()

@tf.function(input_signature=[tf.TensorSpec(shape=[None, None, 512], dtype=tf.float32, name="input")])
def infer(x):
    return model(x)

concrete_func = infer.get_concrete_function()
tf.saved_model.save(model, 'model', signatures={'serving_default': concrete_func})
```
If I run inference with TF everything works fine and the tensor is reshaped from `None x None x 512` to `None x None x 8 x 64`:

```python
a = np.random.rand(2, 100, 512).astype('float32')
concrete_func(a).shape
# Out[1]: TensorShape([2, 100, 8, 64])
```
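The exact conversion command is not shown above; this is the standard tf2onnx CLI form for the SavedModel produced by the script (the output name `model.onnx` is an assumption):

```shell
python -m tf2onnx.convert --saved-model model --output model.onnx --opset 17
```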
However, after converting to ONNX with opset 17 and trying to run inference with onnxruntime, the following error is shown:
```python
import onnxruntime as rt

sess = rt.InferenceSession('model.onnx')
result = sess.run(['output_0'], {"input": a})
```

```
---------------------------------------------------------------------------
RuntimeException                          Traceback (most recent call last)
/work/255973546.1.grid/ipykernel_46657/489356439.py in <module>
----> 1 result = sess.run(['output_0'], {"input": a})
      2 result

/gpfs/ess2_fs1/amr_tools/data/amr/tools/francesco_salvetti/exp/onnx/env_onnx/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py in run(self, output_names, input_feed, run_options)
    198         output_names = [output.name for output in self._outputs_meta]
    199         try:
--> 200             return self._sess.run(output_names, input_feed, run_options)
    201         except C.EPFail as err:
    202             if self._enable_fallback:

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'StatefulPartitionedCall/model_22/custom_layer_29/Reshape' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:40 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, onnxruntime::TensorShapeVector&, bool) gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{200,512}, requested shape:{200,512,8,64}
```
As shown in the error, after the time-distributed `Dense` the tensor appears to have shape `200 x 512` instead of `2 x 100 x 512`, as if the time-distributed `Dense` had merged the first two axes. The reshape then, instead of targeting `2 x 100 x 8 x 64`, requests `200 x 512 x 8 x 64`, which causes the error. Without the `Dense` everything works correctly. I also tried to explicitly use `tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(...))`, but the error occurs in that case as well.
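The failure is just ONNX Runtime's element-count check firing: the merged input `{200, 512}` holds 200 × 512 elements, while the requested shape `{200, 512, 8, 64}` would require far more. A minimal sketch of that check in plain Python (independent of ONNX Runtime; the shapes are taken from the error message above):

```python
from math import prod

input_shape = [200, 512]              # merged axes, as reported by the error
requested_shape = [200, 512, 8, 64]   # wrong target built from the merged shape
intended_shape = [2, 100, 8, 64]      # what the reshape should have produced

# A reshape is only valid when the element counts match --
# this is exactly the check ONNX Runtime performs.
print(prod(input_shape) == prod(requested_shape))   # False -> RUNTIME_EXCEPTION
print(prod([2, 100, 512]) == prod(intended_shape))  # True: the intended reshape is fine
```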
System information
- TensorFlow Version: 2.11.0
- Python version: 3.7.5
- ONNX version: 1.12.0 (opset 17)
- ONNXRuntime version: 1.13.1
Edit
Apparently the problem is solved if the tensor is explicitly reshaped to `(2 x 100) x 512` before applying the `Dense`, avoiding the time-distributed behavior. However, I think this pattern should be supported by tf2onnx, since time-distributed `Dense` layers are very common (for example in Transformers).
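The workaround's shape flow can be sketched with NumPy alone (a dense layer acting on the last axis is just a matmul; the names here are illustrative, not taken from the actual model):

```python
import numpy as np

B, T, H, N = 2, 100, 512, 8
D = H // N
W = np.random.rand(H, H).astype('float32')  # stand-in for the Dense kernel

x = np.random.rand(B, T, H).astype('float32')

# Explicitly merge batch and time axes before the dense projection ...
x_flat = x.reshape(B * T, H) @ W             # (200, 512)
# ... then restore them, splitting the feature axis into N heads of size D.
out = x_flat.reshape(B, T, N, D)             # (2, 100, 8, 64)
print(out.shape)
```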
Edit 2
After further investigation, I found that the problem lies in *when* the shape of `x` is read to obtain the first two dimensions used in the subsequent reshape. If `tf.shape` is read before the `Dense`, everything works. If it is read after the `Dense`, even with the explicit reshape workaround, it no longer works, since it reads the "wrong" shape with merged axes.
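The same point in NumPy terms: capture the leading dimensions *before* the operation that merges the axes, then reuse them for the final reshape (an illustrative sketch, not the exact traced graph):

```python
import numpy as np

B, T, H, N = 2, 100, 512, 8
D = H // N
W = np.random.rand(H, H).astype('float32')  # stand-in for the Dense kernel
x = np.random.rand(B, T, H).astype('float32')

# Read the dynamic shape BEFORE the projection, while it is still (B, T, H).
batch, time = x.shape[0], x.shape[1]

# The projection itself may run on merged axes: (B*T, H).
y = x.reshape(-1, H) @ W

# Reshape with the previously captured dimensions, not with y's merged shape.
y = y.reshape(batch, time, N, D)
print(y.shape)  # (2, 100, 8, 64)
```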