Feature Request: Support setting inputs and outputs as NCHW to prevent having to transpose data several times through the model #145

@Artoriuz

Description

The preferred image layout in JAX code is NHWC (Flax's nn.Conv, for example, expects channels-last input), but ONNX uses NCHW instead. When a model is exported directly, many operations end up "wrapped" in extra operations that transpose the data back and forth, like this:

[Image: screenshot of the exported ONNX graph showing the extra transpose operations]
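To make the "wrapping" concrete: an NHWC convolution is equivalent to transposing the input to NCHW, running an NCHW convolution, and transposing the result back, which is the pattern an NCHW-only target like ONNX ends up with. The sketch below checks that equivalence with jax.lax.conv_general_dilated; the shapes, padding, and function names are illustrative assumptions and are not taken from jax2onnx.

```python
import jax
import jax.numpy as jnp
from jax import lax

def conv_nhwc(x, w):
    # Flax-style layouts: input NHWC, kernel HWIO, output NHWC.
    return lax.conv_general_dilated(
        x, w, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))

def conv_nhwc_via_nchw(x, w):
    # What an NCHW-only backend has to do: transpose in, convolve, transpose out.
    x_nchw = jnp.transpose(x, (0, 3, 1, 2))   # NHWC -> NCHW
    w_oihw = jnp.transpose(w, (3, 2, 0, 1))   # HWIO -> OIHW
    y_nchw = lax.conv_general_dilated(
        x_nchw, w_oihw, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"))
    return jnp.transpose(y_nchw, (0, 2, 3, 1))  # NCHW -> NHWC

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 8, 8, 3))   # N=2, H=W=8, C=3
w = jax.random.normal(key_w, (3, 3, 3, 4))   # 3x3 kernel, 3 -> 4 channels
y_direct = conv_nhwc(x, w)
y_wrapped = conv_nhwc_via_nchw(x, w)
print(y_direct.shape)  # (2, 8, 8, 4)
print(jnp.allclose(y_direct, y_wrapped, atol=1e-4, rtol=1e-4))
```

Each such pair of transposes is cheap on its own, but one pair per layer adds up across a whole network.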

This was also the case when exporting from TensorFlow, which also defaults to NHWC, but the TF/Keras exporter, tf2onnx (https://github.com/onnx/tensorflow-onnx/blob/4fed7de9534b6a084f7f2326bae775545bd97f9e/tf2onnx/convert.py#L412), supports the following arguments:

inputs_as_nchw: transpose inputs in list from nhwc to nchw
outputs_as_nchw: transpose outputs in list from nhwc to nchw

With these options enabled, the exported graph looked clean and was mostly free of unnecessary transpose operations, since everything could be treated as NCHW instead.
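The reason the graph gets cleaner is that back-to-back transposes compose into a single permutation, and a transpose to NCHW immediately followed by one back to NHWC is the identity, so the pair can be dropped. Below is a minimal sketch of that peephole idea over a hypothetical linear list of ops; the ("Transpose", perm) representation and the cancel_transposes name are made up for illustration, and real converters run more general optimizers that also push transposes through other ops. Adding one transpose at each graph boundary, which is roughly the effect of inputs_as_nchw and outputs_as_nchw, lets the remaining outer pair cancel too.

```python
from typing import List, Optional, Tuple

Op = Tuple[str, Optional[Tuple[int, ...]]]  # (op name, permutation for Transpose ops)

NHWC_TO_NCHW = (0, 3, 1, 2)
NCHW_TO_NHWC = (0, 2, 3, 1)

def compose(p1: Tuple[int, ...], p2: Tuple[int, ...]) -> Tuple[int, ...]:
    # Transposing by p1 and then by p2 equals a single transpose by this permutation
    # (ONNX/NumPy semantics: output axis i comes from input axis perm[i]).
    return tuple(p1[j] for j in p2)

def cancel_transposes(ops: List[Op]) -> List[Op]:
    # Hypothetical peephole pass: merge adjacent Transpose ops and drop identities.
    out: List[Op] = []
    for name, perm in ops:
        if name == "Transpose" and out and out[-1][0] == "Transpose":
            merged = compose(out.pop()[1], perm)
            if merged != tuple(range(len(merged))):
                out.append(("Transpose", merged))
        else:
            out.append((name, perm))
    return out

def names(ops: List[Op]) -> List[str]:
    return [name for name, _ in ops]

# NHWC model exported op by op: every NCHW op is wrapped in a pair of transposes.
exported: List[Op] = [
    ("Transpose", NHWC_TO_NCHW), ("Conv", None), ("Transpose", NCHW_TO_NHWC),
    ("Transpose", NHWC_TO_NCHW), ("MaxPool", None), ("Transpose", NCHW_TO_NHWC),
]
print(names(cancel_transposes(exported)))   # ['Transpose', 'Conv', 'MaxPool', 'Transpose']

# Treating the graph's input and output as NCHW adds one transpose at each boundary,
# which cancels the remaining outer pair as well.
with_nchw = [("Transpose", NCHW_TO_NHWC)] + exported + [("Transpose", NHWC_TO_NCHW)]
print(names(cancel_transposes(with_nchw)))  # ['Conv', 'MaxPool']
```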

It would be nice if jax2onnx had a similar mechanism.
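Until something like this exists, one user-side approximation is to give the model an NCHW interface in JAX itself before exporting. The sketch below uses a hypothetical as_nchw helper (not part of jax2onnx or JAX) and a toy channels-last model. It only changes the layout at the boundary, so whether the inner transpose pairs actually disappear from the exported graph still depends on an optimization pass cancelling them.

```python
import jax
import jax.numpy as jnp

NHWC_TO_NCHW = (0, 3, 1, 2)
NCHW_TO_NHWC = (0, 2, 3, 1)

def as_nchw(fn):
    """Hypothetical helper: give a channels-last (NHWC) function an NCHW interface."""
    def wrapped(x_nchw):
        x_nhwc = jnp.transpose(x_nchw, NCHW_TO_NHWC)  # NCHW input -> NHWC for the model
        y_nhwc = fn(x_nhwc)
        return jnp.transpose(y_nhwc, NHWC_TO_NCHW)    # NHWC output -> NCHW for the caller
    return wrapped

def nhwc_model(x):
    # Toy channels-last model: scale channel c by (c + 1), then apply ReLU.
    scale = jnp.arange(1.0, x.shape[-1] + 1.0)  # broadcasts over the last (C) axis
    return jax.nn.relu(x * scale)

nchw_model = jax.jit(as_nchw(nhwc_model))
x_nchw = jnp.ones((1, 3, 4, 5))  # N=1, C=3, H=4, W=5
y_nchw = nchw_model(x_nchw)
print(y_nchw.shape)        # (1, 3, 4, 5)
print(y_nchw[0, :, 0, 0])  # [1. 2. 3.]
```

In this approach, nchw_model (rather than nhwc_model) is the function that would be handed to the exporter.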

Thanks in advance.

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests