`jax2onnx` converts your JAX/Flax functions directly into the ONNX format.

- **Simple API**: Convert any JAX/Flax model to ONNX using `to_onnx(...)`.
- **Model structure preserved**: With `@onnx_function`, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.
- **Dynamic input support**: Use abstract dimensions like `'B'` or pass scalars as runtime inputs. Models stay flexible without retracing.
- **Plugin-based extensibility**: Add support for new primitives by writing small, local plugins.
- **Netron-friendly outputs**: All generated ONNX graphs include shape/type annotations and are structured for clear visualization.
Convert your JAX callable to ONNX in just a few lines:
```python
import onnx
from flax import nnx
from jax2onnx import to_onnx


# Define a simple MLP (from Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)


# Instantiate model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Convert to ONNX
onnx_model = to_onnx(my_callable, [("B", 30)])

# Save the model
onnx.save_model(onnx_model, "my_callable.onnx")
```
🔎 See it visualized: my_callable.onnx
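To sanity-check the export, you can run the saved file with ONNX Runtime. The snippet below is a minimal sketch, assuming the `my_callable.onnx` produced by the example above; it looks up the input name from the session rather than assuming one, feeds a single batch, and checks the output shape:

```python
import numpy as np
import onnxruntime as ort

# Load the exported model and look up its input name from the session
session = ort.InferenceSession("my_callable.onnx")
input_name = session.get_inputs()[0].name

# A batch of 4 samples fills the dynamic "B" dimension declared at export time
x = np.random.rand(4, 30).astype(np.float32)
(output,) = session.run(None, {input_name: x})

print(output.shape)  # expected (4, 10): batch of 4, dout=10 from the MLP above
```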
ONNX functions help encapsulate reusable subgraphs. Simply apply the `@onnx_function` decorator to make your callable an ONNX function.
```python
from onnx import save_model
from flax import nnx
from jax2onnx import onnx_function, to_onnx


# just an @onnx_function decorator to make your callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
        self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
        self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)

    def __call__(self, x):
        return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))


# Use it inside another module
class MyModel(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.block1 = MLPBlock(dim, rngs=rngs)
        self.block2 = MLPBlock(dim, rngs=rngs)

    def __call__(self, x):
        return self.block2(self.block1(x))


callable = MyModel(256, rngs=nnx.Rngs(0))
model = to_onnx(callable, [(100, 256)])
save_model(model, "docs/onnx/model_with_function.onnx")
```
🔎 See it visualized: model_with_function.onnx
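To confirm programmatically that the decorated blocks were exported as ONNX functions, you can inspect the `functions` field of the saved ModelProto. A small sketch, assuming the `docs/onnx/model_with_function.onnx` file saved above:

```python
import onnx

# Load the exported model and list the ONNX functions it contains
model = onnx.load("docs/onnx/model_with_function.onnx")
for fn in model.functions:
    print(fn.name, [node.op_type for node in fn.node])
```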
- Ongoing: Expanding JAX component coverage.
- 0.5.3: More ONNX function support: function reuse, graph optimization inside functions, and user-friendly variable names.
- 0.5.2: Added support for additional primitives: `jnp.where`, `jnp.arange`, `jnp.linspace`.
- 0.5.1 (PyPI): Added support for primitives that use subgraphs: `lax.while_loop`, `lax.cond`, `lax.fori_loop`, `lax.scan`.
- 0.5.0: Improved dynamic batch dimension handling by leveraging shape polymorphism for more robust and flexible model export. Added support for the `jnp.sign`, `jnp.abs`, `jnp.iota` primitives.
- 0.4.4: Added support for the `lax.cos`, `lax.cosh`, `lax.sin`, `lax.sinh`, and `lax.scatter` primitives.
- 0.4.3: Fixed a bug in the validation of JAX callable outputs against their ONNX counterparts. This fix exposed previously hidden failing tests, which are now fixed.
- 0.4.2: Cleanup and fixes to the basic ONNX function release.
- 0.4.1 (ONNX functions): Introduced simple ONNX function support. Using ONNX functions is easy: just add an `@onnx_function` decorator to make a callable an ONNX function. Each `@onnx_function` decorator creates a new ONNX function instance on the call graph.
- 0.3.2: Relaxed the minimum Python version to 3.10.
- 0.3.0: Streamlined the plugin system with automatic registration and simplified integration of custom primitives.
- 0.2.0 (First PyPI release): Rebased the implementation on `jaxpr`, improving usability and adding low-level `lax` components.
- 0.1.0 (Initial approach, not released to PyPI): Produced ONNX exports for some `nnx` components and `nnx`-based examples, including a Vision Transformer.
If conversion doesn't work out of the box, it could be due to:

- **Non-dynamic function references**: JAXPR-based conversion requires function references to be resolved dynamically at call-time.
  Solution: Wrap your function call inside a lambda to enforce dynamic resolution (see the sketch after this list): `my_dynamic_callable_function = lambda x: original_function(x)`
- **Unsupported primitives**: The callable may use a primitive that is not yet, or not fully, supported by `jax2onnx`.
  Solution: Write a plugin to handle the unsupported function (this is straightforward!).
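A minimal sketch of the lambda-wrapping workaround; `original_function` and the `("B", 8)` input shape below are illustrative placeholders, not part of the jax2onnx API:

```python
import jax.numpy as jnp
from jax2onnx import to_onnx


def original_function(x):
    # Placeholder for a function whose reference would otherwise be resolved too early
    return jnp.tanh(x) * 2.0


# Wrapping the call in a lambda forces the reference to be resolved at call-time
my_dynamic_callable_function = lambda x: original_function(x)

onnx_model = to_onnx(my_dynamic_callable_function, [("B", 8)])
```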
Legend:
✅ = Passed
❌ = Failed
➖ = No testcase yet
Component | Description | Testcases | Since |
---|---|---|---|
AutoEncoder | A simple autoencoder example. | simple_autoencoder ✅ | v0.2.0 |
CNN | A simple convolutional neural network (CNN). | simple_cnn_dynamic ✅ simple_cnn ✅ | v0.1.0 |
ClassificationHead | Classification head for Vision Transformer | classification_head_dynamic ✅ classification_head ✅ | v0.4.0 |
ClassificationHeadFlatten | Classification head for Vision Transformer | classification_head_flat_dynamic ✅ classification_head_flat ✅ | v0.4.0 |
ConcatClsToken | Concatenate CLS token to the input embedding | concat_cls_token_dynamic ✅ concat_cls_token ✅ | v0.4.0 |
ConcatClsTokenFlatten | Concatenate CLS token to the input embedding | concat_cls_token_flat_dynamic ✅ concat_cls_token_flat ✅ | v0.4.0 |
ConvEmbedding | Convolutional Token Embedding for MNIST with hierarchical downsampling. | mnist_conv_embedding_dynamic ✅ mnist_conv_embedding ✅ | v0.1.0 |
ConvEmbeddingFlatten | Convolutional Token Embedding for MNIST with hierarchical downsampling. | mnist_conv_embedding_flat_dynamic ✅ mnist_conv_embedding_flat ✅ | v0.1.0 |
FeedForward | MLP in Transformer | feed_forward_dynamic ✅ feed_forward ✅ | v0.1.0 |
FeedForwardFlatten | MLP in Transformer | feed_forward_flat_dynamic ✅ feed_forward_flat ✅ | v0.1.0 |
ForiLoop | fori_loop example | fori_loop_counter ✅ | v0.5.1 |
GetToken | Get the CLS token from the input embedding | get_token_dynamic ✅ get_token ✅ | v0.4.0 |
GetTokenFlatten | Get the CLS token from the input embedding | get_token_flat_dynamic ✅ get_token_flat ✅ | v0.4.0 |
MLP | A simple Multi-Layer Perceptron (MLP) with BatchNorm, Dropout, and GELU activation. | simple_mlp_dynamic ✅ simple_mlp ✅ simple_mlp_with_call_params_dynamic ✅ simple_mlp_with_call_params ✅ | v0.1.0 |
MultiHeadAttention | A multi-head attention module implemented in Flax/nnx that has no ONNX counterpart at the same granularity. | multihead_attention_nn_dynamic ✅ multihead_attention_nn ✅ multihead_attention_nnx_dynamic ✅ multihead_attention_nnx ✅ | v0.2.0 |
PatchEmbedding | Cutting the image into patches and linearly embedding them. | patch_embedding_dynamic ✅ patch_embedding ✅ | v0.1.0 |
PatchEmbeddingFlatten | Cutting the image into patches and linearly embedding them. | patch_embedding_flat_dynamic ✅ patch_embedding_flat ✅ | v0.1.0 |
PositionalEmbedding | Add positional embedding to the input embedding | positional_embedding_dynamic ✅ positional_embedding ✅ | v0.4.0 |
PositionalEmbeddingFlatten | Add positional embedding to the input embedding | positional_embedding_flat_dynamic ✅ positional_embedding_flat ✅ | v0.4.0 |
TransformerBlock | Transformer block from 'Attention Is All You Need.' | transformer_block_dynamic ✅ transformer_block ✅ | v0.1.0 |
TransformerBlockFlatten | Transformer block from 'Attention Is All You Need.' | transformer_block_flat_dynamic ✅ transformer_block_flat ✅ | v0.1.0 |
TransformerStack | Stack of Transformer blocks | transformer_stack_dynamic ✅ transformer_stack ✅ | v0.1.0 |
TransformerStackFlatten | Stack of Transformer blocks | transformer_stack_flat_dynamic ✅ transformer_stack_flat ✅ | v0.1.0 |
VisionTransformer | A Vision Transformer (ViT) model for MNIST with configurable embedding type. | vit_conv_embedding_dynamic ✅ vit_conv_embedding ✅ vit_patch_embedding ✅ | v0.2.0 |
VisionTransformerFlatten | A Vision Transformer (ViT) model for MNIST with configurable embedding type. | vit_conv_embedding_flat_dynamic ✅ vit_conv_embedding_flat ✅ vit_patch_embedding_flat_dynamic ✅ vit_patch_embedding_flat ✅ | v0.2.0 |
onnx_functions_000 | one function on an outer layer | 000_one_function_on_outer_layer_dynamic ✅ 000_one_function_on_outer_layer ✅ | v0.4.0 |
onnx_functions_001 | one function on an inner layer | 001_one_function_inner_dynamic ✅ 001_one_function_inner ✅ | v0.4.0 |
onnx_functions_002 | two nested functions | 002_two_nested_functions_dynamic ✅ 002_two_nested_functions ✅ | v0.4.0 |
onnx_functions_003 | two nested functions | 003_two_simple_nested_functions_dynamic ✅ 003_two_simple_nested_functions ✅ | v0.4.0 |
onnx_functions_004 | nested function plus component | 004_nested_function_plus_component_dynamic ✅ 004_nested_function_plus_component ✅ | v0.4.0 |
onnx_functions_005 | nested function plus more components | 005_nested_function_plus_component_dynamic ✅ 005_nested_function_plus_component ✅ | v0.4.0 |
onnx_functions_006 | one function on an outer layer | 006_one_function_outer_dynamic ✅ 006_one_function_outer ✅ | v0.4.0 |
onnx_functions_007 | transformer block with nested MLP block with call parameter | 007_transformer_block_dynamic ✅ 007_transformer_block ✅ | v0.4.0 |
onnx_functions_008 | transformer block with nested MLP block, no call parameter | 008_transformer_block_dynamic ✅ 008_transformer_block ✅ | v0.4.0 |
onnx_functions_009 | transformer block using decorator on class and function | 009_transformer_block_dynamic ✅ 009_transformer_block ✅ | v0.4.0 |
onnx_functions_010 | transformer stack | 010_transformer_stack_dynamic ✅ 010_transformer_stack ✅ | v0.4.0 |
onnx_functions_012 | Vision Transformer (ViT) | 012_vit_conv_embedding_dynamic ✅ 012_vit_conv_embedding ✅ | v0.4.0 |
onnx_functions_013 | Vision Transformer (ViT) | 013_vit_conv_embedding_with_call_params_dynamic ✅ 013_vit_conv_embedding_with_call_params ✅ 013_vit_conv_embedding_with_internal_call_params_dynamic ✅ 013_vit_conv_embedding_with_internal_call_params ✅ | v0.4.0 |
onnx_functions_014 | one function on an outer layer | 014_one_function_with_input_param_with_default_value ✅ 014_one_function_without_input_param_with_default_value_dynamic ✅ 014_one_function_without_input_param_with_default_value ✅ | v0.4.0 |
onnx_functions_015 | one function on an outer layer | 015_one_function_with_input_param_without_default_value_dynamic ✅ 015_one_function_with_input_param_without_default_value ✅ | v0.4.0 |
onnx_functions_016 | nested function plus more components | 016_internal_function_with_input_param_with_default_value_dynamic ✅ 016_internal_function_with_input_param_with_default_value ✅ | v0.4.0 |
Versions of Major Dependencies:

Library | Version |
---|---|
JAX | 0.6.0 |
Flax | 0.10.6 |
onnx | 1.17.0 |
onnxruntime | 1.21.1 |

Note: For more details, check `pyproject.toml`.
- Currently not all JAX/Flax components are supported (you can easily help expand this coverage!).
- Function references need dynamic resolution at call-time.
- ONNX graph composition is done in-memory before saving to disk, potentially causing memory issues with very large models.
We warmly welcome contributions!
How you can help:
- Add a plugin: Extend `jax2onnx` by writing a simple Python file in `jax2onnx/plugins` for a custom primitive or an example.
- Bug fixes & improvements: PRs and issues are always welcome.
Install from PyPI:

```bash
pip install jax2onnx
```
This project is licensed under the Apache License, Version 2.0. See LICENSE
for details.
Special thanks for plugin contributions to @burakksen
Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.
Special thanks to the community members involved in:
A huge thanks especially to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.
Happy converting! 🎉