The preferred data layout for images in JAX is NHWC, but ONNX image operators such as Conv expect NCHW. When a model is exported directly, many operations end up "wrapped" in extra Transpose operations that convert the data to NCHW before the op and back to NHWC after it.
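To make the wrapping concrete, here is a small JAX sketch. It mimics the pattern and is not jax2onnx's actual lowering code; shapes and values are arbitrary. An NHWC convolution and the same convolution computed in NCHW with a transpose on each side give the same result. Repeating this per op is what produces the Transpose pairs in the exported graph.

```python
import jax
import jax.numpy as jnp

NHWC_TO_NCHW = (0, 3, 1, 2)  # move channels from the last axis to axis 1
NCHW_TO_NHWC = (0, 2, 3, 1)  # inverse permutation

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x_nhwc = jax.random.normal(key_x, (2, 8, 8, 3))  # N, H, W, C
w_hwio = jax.random.normal(key_w, (3, 3, 3, 4))  # kernel H, W, in C, out C

# The convolution as written in a JAX/Flax model: NHWC in, NHWC out.
y_nhwc = jax.lax.conv_general_dilated(
    x_nhwc, w_hwio, window_strides=(1, 1), padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))

# The same convolution using an NCHW-only op (like ONNX Conv) has to be
# wrapped: transpose in, convolve in NCHW, transpose back out.
w_oihw = jnp.transpose(w_hwio, (3, 2, 0, 1))  # HWIO -> OIHW
y_wrapped = jnp.transpose(
    jax.lax.conv_general_dilated(
        jnp.transpose(x_nhwc, NHWC_TO_NCHW), w_oihw,
        window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW")),
    NCHW_TO_NHWC)

print(y_nhwc.shape)  # (2, 8, 8, 4)
print(jnp.allclose(y_nhwc, y_wrapped, atol=1e-4))  # True
```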
The same thing happened when exporting from TensorFlow, which also defaults to NHWC. However, the TF/Keras converter tf2onnx (https://github.com/onnx/tensorflow-onnx/blob/4fed7de9534b6a084f7f2326bae775545bd97f9e/tf2onnx/convert.py#L412) supports the following arguments:
- `inputs_as_nchw`: transpose the listed inputs from NHWC to NCHW
- `outputs_as_nchw`: transpose the listed outputs from NHWC to NCHW
With these enabled, the exported graph was clean and mostly free of unnecessary Transpose operations, because the whole graph could be treated as NCHW.
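The boundary change these options imply can be sketched in JAX as a wrapper. `as_nchw` is a hypothetical helper written for illustration. It is not a jax2onnx or tf2onnx API, and it does not remove any transposes by itself. It only gives the function an NCHW signature while the body stays NHWC. The assumption is that, after export, each boundary transpose sits next to an inverse transpose inserted around an NCHW-only op. A transpose-cancelling optimizer pass could then fold away those adjacent inverse pairs. The toy model (2x2 average pooling) is also just a placeholder.

```python
import jax
import jax.numpy as jnp

NHWC_TO_NCHW = (0, 3, 1, 2)
NCHW_TO_NHWC = (0, 2, 3, 1)

def as_nchw(fn_nhwc):
    """Hypothetical helper: expose an NHWC image function with an NCHW
    input/output boundary. The wrapped body still computes in NHWC."""
    def fn_nchw(x_nchw):
        y_nhwc = fn_nhwc(jnp.transpose(x_nchw, NCHW_TO_NHWC))
        return jnp.transpose(y_nhwc, NHWC_TO_NCHW)
    return fn_nchw

def model_nhwc(x):
    # Toy stand-in for a real model: 2x2 average pooling over H and W (NHWC).
    summed = jax.lax.reduce_window(
        x, 0.0, jax.lax.add, (1, 2, 2, 1), (1, 2, 2, 1), "VALID")
    return summed / 4.0

model = as_nchw(model_nhwc)
x = jnp.ones((1, 3, 4, 4))  # N, C, H, W
y = model(x)
print(y.shape)  # (1, 3, 2, 2)
```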
It would be nice if jax2onnx had a similar mechanism.
Thanks in advance.