jax2onnx converts your JAX, Flax NNX, Flax Linen, and Equinox functions directly into the ONNX format.
- **Simple API**: Easily convert JAX callables, including Flax NNX, Flax Linen, and Equinox models, into ONNX format using `to_onnx(...)`.
- **Model structure preserved**: With `@onnx_function`, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.
- **Dynamic input support**: Use abstract dimensions like `'B'` or pass scalars as runtime inputs. Models stay flexible without retracing.
- **Plugin-based extensibility**: Add support for new primitives by writing small, local plugins.
- **onnx-ir native pipeline**: Conversion, optimization, and post-processing all run on the typed `onnx_ir` toolkit, with no protobuf juggling, and stay memory-lean until the final ONNX serialization.
- **Netron-friendly outputs**: Generated graphs carry shape/type annotations and a clean hierarchy, so tools like Netron stay easy to read.
Install and export your first model in minutes:
```bash
pip install jax2onnx
```

Convert your JAX callable to ONNX in just a few lines:
```python
from flax import nnx

from jax2onnx import to_onnx


# Define a simple MLP (from Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)


# Instantiate model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Export straight to disk without keeping the proto in memory
to_onnx(
    my_callable,
    [("B", 30)],
    return_mode="file",
    output_path="my_callable.onnx",
)
```

🔎 See it visualized: my_callable.onnx
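To sanity-check the export, you can run the saved file with onnxruntime. A minimal sketch (the input name is read from the session rather than assumed); because the first axis was exported as the abstract dimension `"B"`, any batch size works without re-exporting:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("my_callable.onnx")
input_name = session.get_inputs()[0].name

# The abstract "B" dimension accepts any batch size at runtime.
for batch in (1, 32):
    x = np.random.rand(batch, 30).astype(np.float32)
    (output,) = session.run(None, {input_name: x})
    print(output.shape)  # (1, 10), then (32, 10)
```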
ONNX functions help encapsulate reusable subgraphs. Simply use the `@onnx_function` decorator to make your callable an ONNX function.
```python
from flax import nnx

from jax2onnx import onnx_function, to_onnx


# just an @onnx_function decorator to make your callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
        self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
        self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)

    def __call__(self, x):
        return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))


# Use it inside another module
class MyModel(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.block1 = MLPBlock(dim, rngs=rngs)
        self.block2 = MLPBlock(dim, rngs=rngs)

    def __call__(self, x):
        return self.block2(self.block1(x))


model = MyModel(256, rngs=nnx.Rngs(0))

to_onnx(
    model,
    [(100, 256)],
    return_mode="file",
    output_path="model_with_function.onnx",
)
```

🔎 See it visualized: model_with_function.onnx
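You can confirm that each decorated module became a reusable function in the saved file. A small sketch using the standard onnx API (`ModelProto.functions` holds one FunctionProto per ONNX function):

```python
import onnx

model = onnx.load("model_with_function.onnx")

# Each @onnx_function-decorated module appears as a named ONNX function;
# the main graph calls it instead of inlining the subgraph twice.
for function in model.functions:
    print(function.name, len(function.node))
```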
- **Language: GPT-OSS (open-source MoE Transformer)**
  - Architecture: Flax/NNX + Equinox reference stacks with gating/routing capture, MoE MLP rebuilds, and deterministic ONNX exporters (see `jax2onnx/plugins/examples/nnx/gpt_oss_flax.py` and `jax2onnx/plugins/examples/eqx/gpt_oss.py`).
  - Structural graph:
  - How-to: Getting GPT-OSS weights into jax2onnx
  - Equivalence check: Routing parity harness · Flax parity tests · Equinox parity tests
  - Optional pretrained weights: openai/gpt-oss-20b · openai/gpt-oss-120b (weights and model cards list `license: apache-2.0`)
- **Vision: DINOv3**
  - Architecture: Equimo's clean-room Equinox/JAX implementation, following Meta AI's DINOv3 paper. Flax/NNX parity modules now live under `jax2onnx/plugins/examples/nnx/dinov3.py` (randomly initialised example stack for IR-only exports).
  - Structural graphs (selected examples):
  - How-to: Getting Meta weights into jax2onnx
  - Equivalence check: Comparing Meta vs jax2onnx ONNX
  - Optional pretrained weights (Meta AI): facebook/dinov3-vitb16-pretrain-lvd1689m (other variants live under the same namespace); the DINOv3 license applies, so review it before downloading or redistributing.
Tip
JAX · Flax · Equinox — explore everything that’s supported and see it in action.
- ✅ Support matrix: status per component
- 🧪 Exact regression testcase for each entry
- 🔍 One-click Netron graph to inspect nodes, shapes, attributes
- 🧩 Examples that compose multiple components (Conv→Norm→Activation→Pool, MLP w/ LayerNorm+Dropout, `reshape`/`transpose`/`concat`, `scan`/`while_loop`, `gather`/`scatter`, …)
Links: Open support matrix ↗ · Browse examples ↗
- Broaden coverage of JAX, Flax NNX/Linen, and Equinox components.
- Expand SotA example support for vision and language models.
- Improve support for physics-based simulations.
- 0.11.0:
  - Initial Flax Linen support: core layers (Dense/DenseGeneral, Conv/ConvTranspose/ConvLocal, pooling, BatchNorm/LayerNorm/GroupNorm/RMSNorm/InstanceNorm), Dropout, Einsum/Embed, spectral/weight-norm wrappers, activation coverage (GELU plus glu/hard_/log_/relu6/silu-swish/tanh/normalize/one_hot), the attention stack (dot_product_attention, dot_product_attention_weights, make_attention_mask/make_causal_mask, SelfAttention, MultiHeadDotProductAttention, MultiHeadAttention), the recurrent stack (SimpleCell, GRUCell, MGUCell, LSTMCell, OptimizedLSTMCell, ConvLSTMCell, RNN, Bidirectional), and Linen examples (MLP/CNN/Sequential).
  - Modernized IR optimization pipeline: adopted the standard onnx_ir CSE pass, removed legacy helpers/getattr patterns, and simplified tests with direct graph iteration.
See past_versions for the full release archive.
If conversion doesn't work out of the box, it could be due to:
- **Non-dynamic function references:**
  JAXPR-based conversion requires function references to be resolved dynamically at call time.
  Solution: wrap your function call inside a lambda to enforce dynamic resolution (see the sketch after this list): `my_dynamic_callable_function = lambda x: original_function(x)`
- **Unsupported primitives:**
  The callable may use a primitive that is not yet, or not fully, supported by `jax2onnx`.
  Solution: write a plugin to handle the unsupported function (this is straightforward!).
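To make the lambda workaround concrete, here is a minimal sketch; `original_function` is a hypothetical stand-in for whatever callable was captured too early:

```python
import jax.numpy as jnp

from jax2onnx import to_onnx


def original_function(x):
    # hypothetical stand-in for a function whose reference would otherwise
    # be captured statically instead of resolved at call time
    return jnp.tanh(x) * 2.0


# Wrapping the call in a lambda defers resolution until tracing happens.
my_dynamic_callable_function = lambda x: original_function(x)

to_onnx(
    my_dynamic_callable_function,
    [("B", 8)],
    return_mode="file",
    output_path="wrapped_function.onnx",
)
```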
Looking for provenance details while debugging? Check out the new Stacktrace Metadata guide.
We warmly welcome contributions!
How you can help:
- **Add a plugin:** Extend `jax2onnx` by writing a simple Python file in `jax2onnx/plugins`: a primitive or an example. The Plugin Quickstart walks through the process step by step.
- **Bug fixes & improvements:** PRs and issues are always welcome.
Latest supported versions of major dependencies:

| Library | Version |
|---|---|
| JAX | 0.8.2 |
| Flax | 0.12.2 |
| Equinox | 0.13.2 |
| onnx-ir | 0.1.13 |
| onnx | 1.20.0 |
| onnxruntime | 1.23.2 |
For exact pins and extras, see pyproject.toml.
This project is licensed under the Apache License, Version 2.0. See LICENSE for details.
✨ Special thanks to @clementpoiret for initiating Equinox support and for Equimo, which brings modern vision models—such as DINOv3—to JAX/Equinox.
✨ Special thanks to @justinchuby for introducing onnx-ir as a scalable and more efficient way to handle ONNX model construction.
✨ Special thanks to @atveit for introducing us to gpt-oss-jax-vs-torch-numerical-comparison.
✨ Special thanks for example contributions to @burakssen, @Cadynum, @clementpoiret and @PVirie
✨ Special thanks for plugin contributions to @burakssen, @clementpoiret, @Clouder0, @rakadam and @benmacadam64
✨ Special thanks to @benmacadam64 for championing the complex-number handling initiative.
✨ Special thanks to tumaer/JAXFLUIDS for contributing valuable insights rooted in physics simulation use cases.
✨ Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.
✨ Special thanks to the community members involved in:
✨ Special thanks to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.
Happy converting! 🎉