Skip to content

XlaConvV2 Op not registered #1984

Open
Open
@JeffH-1QBit

Description

@JeffH-1QBit

Describe the bug
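Converting a convolution from JAX -> TF -> ONNX with XLA enabled (`jax2tf.convert(..., enable_xla=True)`) produces an ONNX graph that still contains the TensorFlow XLA op `XlaConvV2`, which has no ONNX equivalent, so ONNX Runtime rejects the model.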

System information

  • macOS 11
  • Python version: 3.9.13
  • Relevant package versions:
    flax==0.5.2
    jax==0.3.13
    jaxlib==0.3.10
    onnx==1.12.0
    onnxruntime==1.11.1
    tensorflow==2.8.0
    tf2onnx==1.11.1

To Reproduce

import tensorflow as tf

import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
from flax import linen as nn

import onnx
import tf2onnx
import onnxruntime


class ConvBlock(nn.Module):
    """One convolution layer followed by a ReLU.

    The convolution uses 10 output features, a 7x7 kernel, stride 2 in both
    spatial dimensions, and explicit padding of 3 on each side of both
    spatial dimensions.
    """

    @nn.compact
    def __call__(self, x):
        # Declared inline under @nn.compact, so flax auto-names the submodule.
        conv = nn.Conv(
            features=10,
            kernel_size=(7, 7),
            strides=(2, 2),
            padding=[(3, 3), (3, 3)],
        )
        return nn.relu(conv(x))

# Single input of shape (batch, height, width, channels); flax convolutions
# are channels-last.
batch_shape = (1, 32, 32, 3)

# Fixed seed so parameter initialisation is reproducible across runs.
rng = jax.random.PRNGKey(42)

model = ConvBlock()
# A dummy all-ones input is enough for flax to infer parameter shapes.
init_array = jnp.ones(batch_shape)
variables = model.init(rng, init_array)

def inference(x):
    """Run the initialised ``model`` on ``x`` and return the sum of all outputs.

    Reducing to a single scalar makes the JAX and ONNX results easy to compare.
    """
    activations = model.apply(variables, x)
    total = activations.sum()
    return total

# Lower the JAX function to TensorFlow. With enable_xla=True, jax2tf may emit
# XLA-specific TF ops; here the graph ends up containing XlaConvV2, which the
# ONNX conversion leaves untranslated (see the ONNX Runtime error in this
# report). NOTE(review): enable_xla=False is the usual workaround -- confirm
# it is supported by the installed jax version.
inference_tf = jax2tf.convert(inference, enable_xla=True)
inference_tf = tf.function(inference_tf, autograph=False)

inference_onnx = tf2onnx.convert.from_function(
    inference_tf,
    input_signature=[tf.TensorSpec(batch_shape)],
    opset=16,
)
model_proto, external_tensor_storage = inference_onnx

# Pin the execution provider explicitly: ONNX Runtime builds that ship more
# than one provider refuse to create a session without it.
session = onnxruntime.InferenceSession(
    model_proto.SerializeToString(),
    providers=["CPUExecutionProvider"],
)

rng, subkey = jax.random.split(rng)
images = jax.random.normal(subkey, shape=batch_shape)

# ONNX Runtime only accepts NumPy arrays as dense inputs, so copy the JAX
# array to the host (jax.device_get returns a numpy.ndarray) before feeding it.
input_name = session.get_inputs()[0].name
onnx_output = session.run([], {input_name: jax.device_get(images)})[0]

jax_output = inference(images)

print(f"onnx_output: {onnx_output}")
print(f"jax_output: {jax_output}")

Stack Trace


InvalidGraph Traceback (most recent call last)
Input In [6], in
40 inference_onnx = tf2onnx.convert.from_function(inference_tf, input_signature=[tf.TensorSpec(batch_shape)], opset=16)
41 model_proto, external_tensor_storage = inference_onnx
---> 43 session = onnxruntime.InferenceSession(model_proto.SerializeToString())
45 rng, subkey = jax.random.split(rng)
46 images = jax.random.normal(subkey, shape=batch_shape)

File ~/.virtualenvs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:335, in InferenceSession.__init__(self, path_or_bytes, sess_options, providers, provider_options, **kwargs)
332 disabled_optimizers = kwargs['disabled_optimizers'] if 'disabled_optimizers' in kwargs else None
334 try:
--> 335 self._create_inference_session(providers, provider_options, disabled_optimizers)
336 except ValueError:
337 if self._enable_fallback:

File ~/.virtualenvs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:372, in InferenceSession._create_inference_session(self, providers, provider_options, disabled_optimizers)
370 sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
371 else:
--> 372 sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
374 if disabled_optimizers is None:
375 disabled_optimizers = set()

InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. In Node, ("XlaConvV2", XlaConvV2, "", -1) : ("args_0": tensor(float),"Const:0": tensor(float),"XlaConvV2/window_strides:0": tensor(int32),"XlaConvV2/padding:0": tensor(int32),"XlaConvV2/lhs_dilation:0": tensor(int32),"XlaConvV2/lhs_dilation:0": tensor(int32),"XlaConvV2/feature_group_count:0": tensor(int32),) -> ("XlaConvV2:0",) , Error No Op registered for XlaConvV2 with domain_version of 16

Metadata

Metadata

Assignees

No one assigned

    Labels

    unsupported ops (Issues related to unsupported operators)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions