Description
Describe the bug
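Converting a convolution from JAX -> TF -> ONNX with XLA enabled (`jax2tf.convert(..., enable_xla=True)`) produces an ONNX graph that still contains the TensorFlow/XLA op `XlaConvV2`. No ONNX op is registered for it, so onnxruntime rejects the model as invalid (see the stack trace below).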
System information
- macOS 11
- Python version: 3.9.13
- Relevant package versions:
flax==0.5.2
jax==0.3.13
jaxlib==0.3.10
onnx==1.12.0
onnxruntime==1.11.1
tensorflow==2.8.0
tf2onnx==1.11.1
To Reproduce
import jax
import jax.numpy as jnp
import numpy as np
import onnx
import onnxruntime
import tensorflow as tf
import tf2onnx
from flax import linen as nn
from jax.experimental import jax2tf
class ConvBlock(nn.Module):
    """One 7x7, stride-2 convolution with 10 output features, followed by ReLU.

    With 3 pixels of padding on each side of both spatial dims, a
    (1, 32, 32, 3) input produces a (1, 16, 16, 10) output:
    floor((32 + 3 + 3 - 7) / 2) + 1 == 16.
    """

    @nn.compact
    def __call__(self, x):
        conv = nn.Conv(
            features=10,
            kernel_size=(7, 7),
            strides=(2, 2),
            # Explicit (low, high) padding per spatial dimension.
            padding=[(3, 3), (3, 3)],
        )
        return nn.relu(conv(x))
# Input batch shape as (batch, height, width, channels); Flax convolutions
# operate on channels-last inputs.
batch_shape = (1, 32, 32, 3)

model = ConvBlock()
rng = jax.random.PRNGKey(seed=42)

# All-ones dummy batch, used here only to initialise the model parameters.
init_array = jnp.ones(shape=batch_shape)
variables = model.init(rng, init_array)
def inference(x):
    """Apply the initialised ConvBlock to ``x`` and reduce the output to its sum.

    Reads the module-level ``model`` and ``variables`` created above.
    """
    activations = model.apply(variables, x)
    return activations.sum()
# NOTE(review): with ENABLE_XLA=True, jax2tf emits XLA-specific TF ops. tf2onnx
# leaves one of them (XlaConvV2) unconverted in the ONNX graph, and onnxruntime
# then fails with "No Op registered for XlaConvV2". Setting this to False asks
# jax2tf to lower to standard TF ops instead. That is the usual workaround for
# non-XLA consumers; confirm it for your jax / tf2onnx versions.
ENABLE_XLA = True

inference_tf = jax2tf.convert(inference, enable_xla=ENABLE_XLA)
inference_tf = tf.function(inference_tf, autograph=False)

inference_onnx = tf2onnx.convert.from_function(
    inference_tf,
    input_signature=[tf.TensorSpec(batch_shape, tf.float32)],
    opset=16,
)
# from_function returns (ModelProto, external_tensor_storage); only the proto is used.
model_proto, external_tensor_storage = inference_onnx
session = onnxruntime.InferenceSession(model_proto.SerializeToString())

rng, subkey = jax.random.split(rng)
images = jax.random.normal(subkey, shape=batch_shape)

input_name = session.get_inputs()[0].name
# onnxruntime's input_feed expects NumPy arrays, so convert the JAX array
# explicitly. Passing None as output_names requests every model output.
onnx_output = session.run(None, {input_name: np.asarray(images)})[0]
jax_output = inference(images)

print(f"onnx_output: {onnx_output}")
print(f"jax_output: {jax_output}")
Stack Trace
InvalidGraph Traceback (most recent call last)
Input In [6], in
40 inference_onnx = tf2onnx.convert.from_function(inference_tf, input_signature=[tf.TensorSpec(batch_shape)], opset=16)
41 model_proto, external_tensor_storage = inference_onnx
---> 43 session = onnxruntime.InferenceSession(model_proto.SerializeToString())
45 rng, subkey = jax.random.split(rng)
46 images = jax.random.normal(subkey, shape=batch_shape)
File ~/.virtualenvs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:335, in InferenceSession.__init__(self, path_or_bytes, sess_options, providers, provider_options, **kwargs)
332 disabled_optimizers = kwargs['disabled_optimizers'] if 'disabled_optimizers' in kwargs else None
334 try:
--> 335 self._create_inference_session(providers, provider_options, disabled_optimizers)
336 except ValueError:
337 if self._enable_fallback:
File ~/.virtualenvs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:372, in InferenceSession._create_inference_session(self, providers, provider_options, disabled_optimizers)
370 sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
371 else:
--> 372 sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
374 if disabled_optimizers is None:
375 disabled_optimizers = set()
InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. In Node, ("XlaConvV2", XlaConvV2, "", -1) : ("args_0": tensor(float),"Const:0": tensor(float),"XlaConvV2/window_strides:0": tensor(int32),"XlaConvV2/padding:0": tensor(int32),"XlaConvV2/lhs_dilation:0": tensor(int32),"XlaConvV2/lhs_dilation:0": tensor(int32),"XlaConvV2/feature_group_count:0": tensor(int32),) -> ("XlaConvV2:0",) , Error No Op registered for XlaConvV2 with domain_version of 16