The examples are built on JAX and Flax NNX. We use jax2onnx to convert the models to ONNX format.
mnist_vit - MNIST classification using a vision transformer (ViT) with a convolutional embedding.
Install the dependencies, then train the model and export it to ONNX format:
poetry install
poetry run python jaxamples/mnist_vit.py