jax2onnx converts your JAX/Flax (NNX) functions directly into the ONNX format.
- Simple API: Easily convert JAX callables, including Flax (NNX) models, into ONNX format using `to_onnx(...)`.
- Model structure preserved: With `@onnx_function`, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.
- Dynamic input support: Use abstract dimensions like `'B'` or pass scalars as runtime inputs. Models stay flexible without retracing.
- Plugin-based extensibility: Add support for new primitives by writing small, local plugins.
- Netron-friendly outputs: All generated ONNX graphs include shape/type annotations and are structured for clear visualization.
Convert your JAX callable to ONNX in just a few lines:
```python
import onnx
from flax import nnx
from jax2onnx import to_onnx

# Define a simple MLP (from Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)

# Instantiate the model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Convert to ONNX
onnx_model = to_onnx(my_callable, [("B", 30)])

# Save the model
onnx.save_model(onnx_model, "my_callable.onnx")
```
🔎 See it visualized: my_callable.onnx
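To sanity-check an export like this, you can load the saved file with ONNX Runtime (listed in the dependency table below) and run a dummy batch. This is a minimal sketch assuming the quickstart above produced `my_callable.onnx` in the working directory:

```python
import numpy as np
import onnxruntime as ort

# Load the exported model; the abstract "B" dimension accepts any batch size.
session = ort.InferenceSession("my_callable.onnx")
input_name = session.get_inputs()[0].name

# Feed a batch of 4 samples with 30 features (float32, matching the export above).
x = np.random.rand(4, 30).astype(np.float32)
(output,) = session.run(None, {input_name: x})
print(output.shape)  # expected: (4, 10)
```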
ONNX functions help encapsulate reusable subgraphs. Simply apply the `@onnx_function` decorator to make your callable an ONNX function.
```python
from onnx import save_model
from flax import nnx
from jax2onnx import onnx_function, to_onnx

# just an @onnx_function decorator to make your callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
        self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
        self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)

    def __call__(self, x):
        return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))

# Use it inside another module
class MyModel(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.block1 = MLPBlock(dim, rngs=rngs)
        self.block2 = MLPBlock(dim, rngs=rngs)

    def __call__(self, x):
        return self.block2(self.block1(x))

callable = MyModel(256, rngs=nnx.Rngs(0))
model = to_onnx(callable, [(100, 256)])
save_model(model, "docs/onnx/model_with_function.onnx")
```
🔎 See it visualized: model_with_function.onnx
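If you want to confirm the blocks ended up as named functions without opening Netron, you can inspect the saved model with the standard `onnx` API. A small sketch, assuming the exported functions are stored as `FunctionProto` entries on the model (the standard ONNX mechanism) and using the file path from the example above:

```python
import onnx

model = onnx.load("docs/onnx/model_with_function.onnx")

# Each ONNX function appears as a FunctionProto on the model.
for func in model.functions:
    print(func.name, "-", len(func.node), "nodes")
```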
- Ongoing
  - Expanding coverage of JAX and Flax (NNX) components
  - Enhancing support for physics-based simulations
- Under Evaluation
  - Integrating `onnx-ir` as a backend to improve ONNX model construction, memory efficiency, and performance
- Upcoming
  - Advanced ONNX function support, including function reuse, optimized internal graph structure, and improved variable naming for clarity and readability
  - Support for `equinox` models
- 0.7.0 (PyPI):
  - Added a GPT-2 model example based on nanoGPT, featuring ONNX function support and attention masking
  - New support for `jnp.concatenate`, `jnp.take`, `nnx.Embed`
  - ONNX models are now hosted on Hugging Face
- 0.6.5: Improved support for `nnx.batch_norm`, `nnx.group_norm`, `nnx.layer_norm`, `nnx.rms_norm`, `lax.broadcast_in_dim`, `lax.cond`, `lax.fori_loop`, `lax.integer_pow`, `lax.scan`, `lax.scatter`, `lax.scatter_add`, `lax.scatter_mul` and `lax.while_loop`; added support for `lax.and`, `lax.rem` and `lax.remat2`.
- 0.6.4: Improved support for `lax.scatter_mul`.
- 0.6.3: Double-precision fixes for `lax.fori_loop` and `lax.while_loop`. Fixed bugs in `lax.scan` and `jnp.where`.
- 0.6.2: Fixed bugs in `nnx.conv` and `lax.reshape`; added new primitive `jnp.prod`.
- 0.6.1: Improved support for `lax.cond` and `lax.select_n`; added new primitives (`lax.reduce_and`, `lax.reduce_or`, `lax.reduce_prod`, `lax.reduce_xor`); and introduced new examples for `jnp.select` and `jnp.sort`.
- 0.6.0: Introduced the `enable_double_precision` parameter (default: `False`) to support physics simulations, and enhanced handling of `lax.scatter`.
- 0.5.2: Added support for additional primitives: `jnp.where`, `jnp.arange`, `jnp.linspace`.
- 0.5.1: Added support for subgraph-using primitives: `lax.while_loop`, `lax.cond`, `lax.fori_loop`, `lax.scan`.
- 0.5.0: Improved dynamic batch dimension handling by leveraging shape polymorphism for more robust and flexible model export. Added support for the `jnp.sign`, `jnp.abs` and `jnp.iota` primitives.
- 0.4.4: Added support for the `lax.cos`, `lax.cosh`, `lax.sin`, `lax.sinh` and `lax.scatter` primitives.
- 0.4.3: Fixed a bug in the validation of JAX callable outputs against their ONNX counterparts. This fix exposed previously hidden failing tests, which are now fixed.
- 0.4.2: Cleanup and fixes to the basic ONNX function release.
- 0.4.1 (ONNX functions): Introduced simple ONNX function support. Using ONNX functions is easy: just apply the `@onnx_function` decorator to make a callable an ONNX function. Each `@onnx_function` decorator creates a new ONNX function instance on the call graph.
- 0.3.2: Relaxed the minimum Python version to 3.10.
- 0.3.0: Streamlined the plugin system with automatic registration and simplified integration of custom primitives.
- 0.2.0 (First PyPI Release): Rebased the implementation on `jaxpr`, improving usability and adding low-level `lax` components.
- 0.1.0 (Initial Approach, Not Released to PyPI): Produced ONNX exports for some `nnx` components and `nnx`-based examples, including a VisualTransformer.
If conversion doesn't work out of the box, it could be due to:
- Non-dynamic function references: JAXPR-based conversion requires function references to be resolved dynamically at call time.
  Solution: Wrap your function call inside a lambda to enforce dynamic resolution (see the sketch after this list): `my_dynamic_callable_function = lambda x: original_function(x)`
- Unsupported primitives: The callable may use a primitive not yet, or not fully, supported by `jax2onnx`.
  Solution: Write a plugin to handle the unsupported function (this is straightforward!).
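As an illustration of the first point, a minimal sketch; `original_function` and the input shape are placeholders, not part of the library:

```python
import jax.numpy as jnp
from jax2onnx import to_onnx

# Placeholder function standing in for code whose reference would otherwise
# be captured statically at trace time.
def original_function(x):
    return jnp.tanh(x) * 2.0

# Wrapping the call in a lambda enforces dynamic resolution, as suggested above.
my_dynamic_callable_function = lambda x: original_function(x)

onnx_model = to_onnx(my_dynamic_callable_function, [("B", 30)])
```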
Versions of Major Dependencies:

| Library     | Versions |
|-------------|----------|
| JAX         | 0.6.2    |
| Flax        | 0.10.6   |
| onnx        | 1.18.0   |
| onnxruntime | 1.22.0   |

Note: For more details, check `pyproject.toml`.
- Currently not all JAX/Flax components are supported (you can easily help expand this coverage!).
- Function references need dynamic resolution at call-time.
- ONNX graph composition is done in-memory before saving to disk, potentially causing memory issues with very large models.
We warmly welcome contributions!
How you can help:
- Add a plugin: Extend `jax2onnx` by writing a simple Python file in `jax2onnx/plugins` for a custom primitive or an example.
- Bug fixes & improvements: PRs and issues are always welcome.
Install from PyPI:
```bash
pip install jax2onnx
```
This project is licensed under the Apache License, Version 2.0. See LICENSE
for details.
Special thanks for example contributions to @burakssen, @Cadynum and @clementpoiret
Special thanks for plugin contributions to @burakssen
Special thanks to tumaer/JAXFLUIDS for contributing valuable insights rooted in physics simulation use cases.
Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.
Special thanks to the community members involved in:
A huge thanks especially to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.
Happy converting! 🎉