jax2onnx converts your JAX, Flax NNX, Flax Linen, and Equinox functions directly into the ONNX format.
Read the full documentation here.
pip install jax2onnx

from jax2onnx import to_onnx
from flax import nnx
model = MyFlaxModel(...)
to_onnx(model, [("B", 32)], return_mode="file", output_path="model.onnx")

We warmly welcome contributions! Please check our Developer Guide for plugin tutorials and architecture details.
Apache License, Version 2.0. See LICENSE.
A huge thank you to all our contributors and the community for their help and inspiration!
Happy converting! 🎉