jax2onnx 🌟

jax2onnx converts your JAX/Flax functions directly into the ONNX format.


✨ Key Features

  • Simple API
    Convert any JAX/Flax model to ONNX using to_onnx(...)

  • Model structure preserved
    With @onnx_function, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.

  • Dynamic input support
    Use abstract dimensions like 'B' or pass scalars as runtime inputs. Models stay flexible without retracing.

  • Plugin-based extensibility
    Add support for new primitives by writing small, local plugins.

  • Netron-friendly outputs
    All generated ONNX graphs include shape/type annotations and are structured for clear visualization.


🚀 Quickstart

Convert your JAX callable to ONNX in just a few lines:

import onnx
from flax import nnx
from jax2onnx import to_onnx

# Define a simple MLP (from Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs): 
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) 
    def __call__(self, x): 
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)

# Instantiate model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Convert to ONNX
onnx_model = to_onnx(my_callable, [("B", 30)])

# Save the model
onnx.save_model(onnx_model, "my_callable.onnx")

🔎 See it visualized: my_callable.onnx
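
To sanity-check the export, you can run the saved model with onnxruntime. A minimal sketch, assuming onnxruntime and numpy are installed; because 'B' is dynamic, any batch size works at runtime:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("my_callable.onnx")
input_name = session.get_inputs()[0].name

x = np.random.rand(8, 30).astype(np.float32)  # batch size 8 on the dynamic 'B' axis
(output,) = session.run(None, {input_name: x})
print(output.shape)  # expected: (8, 10)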


🧠 ONNX Functions — Minimal Example

ONNX functions help encapsulate reusable subgraphs. Simply add the @onnx_function decorator to make your callable an ONNX function.

from onnx import save_model
from flax import nnx
from jax2onnx import onnx_function, to_onnx

# just an @onnx_function decorator to make your callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)
  def __call__(self, x):
    return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))

# Use it inside another module
class MyModel(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.block1 = MLPBlock(dim, rngs=rngs)
    self.block2 = MLPBlock(dim, rngs=rngs)
  def __call__(self, x):
    return self.block2(self.block1(x))

callable = MyModel(256, rngs=nnx.Rngs(0))
model = to_onnx(callable, [(100, 256)])
save_model(model, "docs/onnx/model_with_function.onnx")

🔎 See it visualized: model_with_function.onnx
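
Each @onnx_function-decorated module is stored as a function in the exported model. A quick way to confirm this, assuming the model variable from the snippet above:

# list the ONNX functions contained in the exported model
for f in model.functions:
    print(f.name, "-", len(f.node), "nodes")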


📅 Roadmap and Releases

Planned Versions

  • Ongoing: Expanding JAX component coverage.
  • 0.5.3: Broader ONNX function support: function reuse, a graph optimizer that works within functions, and user-friendly variable names.
  • 0.5.2: Add support for additional primitives: jnp.where, jnp.arange, jnp.linspace.

Current Production Version

  • 0.5.1 (PyPI): Added support for subgraph-using primitives: lax.while_loop, lax.cond, lax.fori_loop, lax.scan.

Past Versions

  • 0.5.0: Improved dynamic batch dimension handling by leveraging shape polymorphism for more robust and flexible model export. Added support for the lax.sign, lax.abs, and lax.iota primitives.
  • 0.4.4: Added support for lax.cos, lax.cosh, lax.sin, lax.sinh and lax.scatter primitives.
  • 0.4.3: Fixed a bug in the validation of JAX callable outputs against their ONNX counterparts. This fix exposed previously hidden failing tests, which are now fixed.
  • 0.4.2: Cleanup and fixes to the basic ONNX function release.
  • 0.4.1 (ONNX functions): Introduced simple ONNX function support. Using ONNX functions is easy: an @onnx_function decorator turns a callable into an ONNX function. Each @onnx_function decorator creates a new ONNX function instance on the call graph.
  • 0.3.2: Relaxed the minimum Python version to 3.10.
  • 0.3.0: Streamlined the plugin system with automatic registration and simplified integration of custom primitives.
  • 0.2.0 (First PyPI Release): Rebased the implementation on jaxpr, improving usability and adding low-level lax components.
  • 0.1.0 (Initial Approach, Not Released to PyPI): Produced ONNX exports for some nnx components and nnx-based examples, including a VisualTransformer.

❓ Troubleshooting

If conversion doesn't work out of the box, it could be due to:

  • Non-dynamic function references:
    JAXPR-based conversion requires function references to be resolved dynamically at call-time.
    Solution: Wrap your function call inside a lambda to enforce dynamic resolution (see the full sketch after this list):

    my_dynamic_callable_function = lambda x: original_function(x)
  • Unsupported primitives:
    The callable may use a primitive not yet or not fully supported by jax2onnx.
    Solution: Write a plugin to handle the unsupported function (this is straightforward!).
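
A minimal end-to-end sketch of the lambda workaround, assuming a plain JAX function and a symbolic batch dimension (the function and shapes are illustrative):

import jax.numpy as jnp
from jax2onnx import to_onnx

def original_function(x):
    return jnp.tanh(x) * 2.0

# passing a lambda instead of the bare reference forces call-time resolution
onnx_model = to_onnx(lambda x: original_function(x), [("B", 16)])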


🧩 Supported JAX/ONNX Components

| JAX Component | ONNX Components | Testcases | Since |
|---|---|---|---|
| core.dim_as_value | Cast, Gather, Reshape, Shape | dim_as_value_dynamic, dim_as_value | v0.5.0 |
| jnp.add | Add | add | v0.1.0 |
| jnp.concatenate | Concat | concatenate, concatenate_abstract_middle_dim_dynamic, concatenate_abstract_middle_dim, concatenate_tile_and_symbolic_dynamic, concatenate_tile_and_symbolic | v0.1.0 |
| jnp.einsum | Einsum | einsum_vector_dot, einsum_matrix_vector, einsum_matrix_matrix_dynamic, einsum_matrix_matrix, einsum_transpose, einsum_batch_transpose_dynamic, einsum_batch_transpose, einsum_diag, einsum_sum_reduce, einsum_multi_operand, einsum_attention_logits_orig_dynamic, einsum_attention_logits_orig, einsum_attention_output_orig_dynamic, einsum_attention_output_orig, einsum_attention_logits_batched_dynamic, einsum_attention_logits_batched, einsum_attention_output_batched_dynamic, einsum_attention_output_batched, einsum_ellipsis_rank_mismatch | v0.1.0 |
| jnp.matmul | MatMul | matmul_2d, matmul_1d_2d, matmul_2d_1d, matmul_dynamic_dynamic, matmul_dynamic, matmul_dynamic_a_dynamic, matmul_dynamic_a, matmul_1d, matmul_3d | v0.1.0 |
| jnp.reshape | Reshape | reshape_1, reshape_2, reshape_3, reshape_4_dynamic, reshape_4, reshape_to_scalar, reshape_from_scalar, reshape_cnn_dynamic, reshape_cnn | v0.1.0 |
| jnp.shape | Shape | shape_basic, shape_dynamic_dynamic, shape_dynamic | v0.4.0 |
| jnp.squeeze | Squeeze | squeeze_single_dim, squeeze_multiple_dims, squeeze_vit_output, squeeze_dynamic_batch_dynamic, squeeze_dynamic_batch, squeeze_all_dims, squeeze_negative_axis, squeeze_negative_axis_tuple, squeeze_dynamic_and_negative_axis_dynamic, squeeze_dynamic_and_negative_axis | v0.1.0 |
| jnp.tile | Tile | tile_repeats, tile_a, tile_b, tile_c, tile_d, tile_dynamic_input_static, tile_dynamic_input_dynamic, tile_dynamic_input, tile_pad, tile_with_symbolic_repeats_static, tile_with_symbolic_repeats_dynamic, tile_with_symbolic_repeats, tile_param_symbolic_dynamic, tile_param_symbolic | v0.1.0 |
| jnp.transpose | Transpose | transpose_basic, transpose_reverse, transpose_4d_dynamic, transpose_4d, transpose_square_matrix, transpose_high_dim, transpose_no_axes, transpose_3d_dynamic, transpose_3d | v0.1.0 |
| lax.abs | Abs | abs | v0.5.0 |
| lax.add | Add | add | v0.2.0 |
| lax.argmax | ArgMax | argmax_test1, argmax_test2 | v0.2.0 |
| lax.argmin | ArgMin | argmin_test1, argmin_test2 | v0.2.0 |
| lax.broadcast_in_dim | Expand, Identity, Reshape | broadcast_in_dim, broadcast_in_dim_2d_to_3d, broadcast_in_dim_scalar, broadcast_in_dim_batch_dynamic, broadcast_in_dim_batch | v0.2.0 |
| lax.concatenate | Concat | concatenate, concatenate_axis1_dynamic, concatenate_axis1, concatenate_axis0, concatenate_3d | v0.2.0 |
| lax.cond | If | cond_scalar | v0.5.1 |
| lax.conv | Conv | conv, conv2 | v0.2.0 |
| lax.convert_element_type | Cast | convert_element_type | v0.2.0 |
| lax.cos | Cos | cos | v0.4.4 |
| lax.cosh | Cosh | cosh | v0.4.4 |
| lax.device_put | Identity | device_put_array, device_put_scalar | v0.4.0 |
| lax.div | Div | div | v0.2.0 |
| lax.dot_general | MatMul | dot_general | v0.2.0 |
| lax.dynamic_slice | Slice | dynamic_slice_test1, dynamic_slice_2d, dynamic_slice_3d, dynamic_slice_vit_like, dynamic_slice_vit_like_dynamic_dynamic, dynamic_slice_vit_like_dynamic | v0.1.0 |
| lax.eq | Equal | eq | v0.2.0 |
| lax.exp | Exp | exp | v0.2.0 |
| lax.fori_loop | Loop | fori_loop_counter, fori_loop_zero, fori_loop_vector, fori_loop_example | v0.5.1 |
| lax.gather | GatherND | gather_static, gather_dynamic_batch_simple_index_dynamic, gather_dynamic_batch_simple_index | v0.2.0 |
| lax.gt | Greater | gt | v0.2.0 |
| lax.integer_pow | Pow | integer_pow | v0.2.0 |
| lax.iota | Range | iota_int32, iota_float32, broadcasted_iota | v0.5.0 |
| lax.log | Log | log | v0.2.0 |
| lax.lt | Less | lt | v0.2.0 |
| lax.max | Max | max | v0.2.0 |
| lax.min | Min | min_test1 | v0.1.0 |
| lax.mul | Mul | mul_test1, mul_test2 | v0.1.0 |
| lax.ne | Equal, Not | ne | v0.2.0 |
| lax.neg | Neg | neg | v0.2.0 |
| lax.reduce_max | ReduceMax | reduce_max | v0.2.0 |
| lax.reduce_min | ReduceMin | reduce_min | v0.2.0 |
| lax.reduce_sum | ReduceSum | reduce_sum | v0.2.0 |
| lax.scan | Scan | scan_cumsum, scan_carry_only, scan_multiple_sequences, scan_multiple_carry, scan_matrix_carry_multidim_xs | v0.5.1 |
| lax.scatter | ScatterElements | scatter_set_axis0, scatter_set_middle | v0.4.4 |
| lax.sign | Sign | sign | v0.5.0 |
| lax.sin | Sin | sin | v0.4.4 |
| lax.sinh | Sinh | sinh | v0.4.4 |
| lax.slice | Slice | slice_test1, slice_3d_none_strides | v0.1.0 |
| lax.sort | TopK | sort_1d, sort_1d_empty, sort_1d_single, sort_1d_larger, sort_1d_specific_values | v0.2.0 |
| lax.sqrt | Sqrt | sqrt | v0.2.0 |
| lax.square | Mul | square | v0.2.0 |
| lax.squeeze | Squeeze | squeeze | v0.2.0 |
| lax.stop_gradient | Identity | stop_gradient | v0.2.0 |
| lax.sub | Sub | sub_test1, sub_test2 | v0.1.0 |
| lax.tanh | Tanh | tanh | v0.2.0 |
| lax.transpose | Transpose | transpose_basic | v0.2.0 |
| lax.while_loop | Loop | while_loop_counter, while_loop_vector | v0.5.1 |
| nn.dot_product_attention | Cast, Div, Einsum, Gather, Shape, Softmax, Sqrt | dpa_basic, dpa_diff_heads_embed, dpa_batch4_seq16, dpa_float64, dpa_heads1_embed4, dpa_heads8_embed8, dpa_batch1_seq2, dpa_batch8_seq4, dpa_axis1 | v0.1.0 |
| nn.softmax | Softmax | softmax, softmax_2d, softmax_3d | v0.1.0 |
| nnx.avg_pool | AveragePool, Transpose | avg_pool_dynamic, avg_pool, avg_pool_same_padding_dynamic, avg_pool_same_padding, avg_pool_default_padding_dynamic, avg_pool_default_padding, avg_pool_stride1_dynamic, avg_pool_stride1, avg_pool_win3x3_stride2_dynamic, avg_pool_win3x3_stride2, avg_pool_stride_none_dynamic, avg_pool_stride_none, avg_pool_count_include_pad_false_dynamic, avg_pool_count_include_pad_false | v0.1.0 |
| nnx.batch_norm | BatchNormalization | batch_norm_simple_dynamic, batch_norm_simple, batch_norm_2d_dynamic, batch_norm_2d, batch_norm_2d_use_bias_false_dynamic, batch_norm_2d_use_bias_false, batch_norm_2d_use_scale_false_dynamic, batch_norm_2d_use_scale_false, batch_norm_4d_dynamic, batch_norm_4d, batch_norm_4d_use_bias_false_dynamic, batch_norm_4d_use_bias_false, batch_norm_4d_use_scale_false_dynamic, batch_norm_4d_use_scale_false, batch_norm_minimal | v0.1.0 |
| nnx.conv | Conv, Transpose | conv_basic_bias_dynamic, conv_basic_bias, conv_basic_bias_2, conv_basic_bias_3, conv_stride2_bias, conv_no_bias_dynamic, conv_no_bias, conv_valid_padding, conv_stride1, conv_stride2, conv_different_kernel, conv_float64, conv_single_batch, conv_large_batch | v0.1.0 |
| nnx.dot_product_attention | Cast, Div, Einsum, Gather, Shape, Softmax, Sqrt | dpa_basic, dpa_diff_heads_embed, dpa_batch4_seq16, dpa_float64, dpa_heads1_embed4, dpa_heads8_embed8, dpa_batch1_seq2, dpa_batch8_seq4, dpa_axis1 | v0.1.0 |
| nnx.dropout | Dropout | dropout_init_params_dynamic, dropout_init_params, dropout_call_params_dynamic, dropout_call_params | v0.1.0 |
| nnx.einsum | Add, Einsum | einsum_module_with_bias, einsum_module_no_bias | v0.4.2 |
| nnx.elu | Elu | elu | v0.1.0 |
| nnx.gelu | Gelu | gelu, gelu_1, gelu_2, gelu_3_dynamic, gelu_3 | v0.1.0 |
| nnx.group_norm | GroupNormalization | group_norm, group_norm_2 | v0.3.0 |
| nnx.layer_norm | LayerNormalization | layer_norm_dynamic, layer_norm, layer_norm_multiaxis_dynamic, layer_norm_multiaxis | v0.1.0 |
| nnx.leaky_relu | LeakyRelu | leaky_relu | v0.1.0 |
| nnx.linear | Gemm, Reshape | linear_symbolic_batch_dynamic, linear_symbolic_batch, linear_high_rank | v0.1.0 |
| nnx.linear_general | Gemm, Reshape | linear_general_dynamic, linear_general, linear_general_2, linear_general_3, linear_general_4, linear_general_abstract_eval_axes, linear_general_abstract_eval_axes_pair | v0.1.0 |
| nnx.log_softmax | LogSoftmax | log_softmax | v0.1.0 |
| nnx.max_pool | MaxPool, Transpose | max_pool, max_pool_same_padding | v0.1.0 |
| nnx.relu | Relu | relu_1d, relu_4d_dynamic, relu_4d | v0.1.0 |
| nnx.rms_norm | RMSNormalization | rms_norm, rms_norm_2 | v0.3.0 |
| nnx.sigmoid | Sigmoid | sigmoid | v0.1.0 |
| nnx.softmax | Softmax | softmax_dynamic, softmax | v0.1.0 |
| nnx.softplus | Softplus | softplus | v0.1.0 |
| nnx.tanh | Tanh | tanh | v0.1.0 |

Legend:
✅ = Passed
❌ = Failed
➖ = No testcase yet
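
The control-flow primitives listed above (lax.cond, lax.while_loop, lax.fori_loop, lax.scan) are exported as ONNX subgraph nodes. A minimal sketch, assuming a fixed-length 1-D float input; per the table, lax.scan lowers to an ONNX Scan node:

import jax
import jax.numpy as jnp
from jax2onnx import to_onnx

def cumulative_sum(xs):
    def step(carry, x):
        carry = carry + x
        return carry, carry  # new carry, per-step output
    _, ys = jax.lax.scan(step, jnp.zeros((), dtype=xs.dtype), xs)
    return ys

onnx_model = to_onnx(cumulative_sum, [(10,)])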


🎯 Examples

| Component | Description | Testcases | Since |
|---|---|---|---|
| AutoEncoder | A simple autoencoder example. | simple_autoencoder | v0.2.0 |
| CNN | A simple convolutional neural network (CNN). | simple_cnn_dynamic, simple_cnn | v0.1.0 |
| ClassificationHead | Classification head for Vision Transformer. | classification_head_dynamic, classification_head | v0.4.0 |
| ClassificationHeadFlatten | Classification head for Vision Transformer. | classification_head_flat_dynamic, classification_head_flat | v0.4.0 |
| ConcatClsToken | Concatenate the CLS token to the input embedding. | concat_cls_token_dynamic, concat_cls_token | v0.4.0 |
| ConcatClsTokenFlatten | Concatenate the CLS token to the input embedding. | concat_cls_token_flat_dynamic, concat_cls_token_flat | v0.4.0 |
| ConvEmbedding | Convolutional token embedding for MNIST with hierarchical downsampling. | mnist_conv_embedding_dynamic, mnist_conv_embedding | v0.1.0 |
| ConvEmbeddingFlatten | Convolutional token embedding for MNIST with hierarchical downsampling. | mnist_conv_embedding_flat_dynamic, mnist_conv_embedding_flat | v0.1.0 |
| FeedForward | MLP in Transformer. | feed_forward_dynamic, feed_forward | v0.1.0 |
| FeedForwardFlatten | MLP in Transformer. | feed_forward_flat_dynamic, feed_forward_flat | v0.1.0 |
| ForiLoop | fori_loop example. | fori_loop_counter | v0.5.1 |
| GetToken | Get the CLS token from the input embedding. | get_token_dynamic, get_token | v0.4.0 |
| GetTokenFlatten | Get the CLS token from the input embedding. | get_token_flat_dynamic, get_token_flat | v0.4.0 |
| MLP | A simple Multi-Layer Perceptron (MLP) with BatchNorm, Dropout, and GELU activation. | simple_mlp_dynamic, simple_mlp, simple_mlp_with_call_params_dynamic, simple_mlp_with_call_params | v0.1.0 |
| MultiHeadAttention | Multi-head attention module from Flax nnx with no ONNX counterpart at the same granularity. | multihead_attention_nn_dynamic, multihead_attention_nn, multihead_attention_nnx_dynamic, multihead_attention_nnx | v0.2.0 |
| PatchEmbedding | Cuts the image into patches and linearly embeds them. | patch_embedding_dynamic, patch_embedding | v0.1.0 |
| PatchEmbeddingFlatten | Cuts the image into patches and linearly embeds them. | patch_embedding_flat_dynamic, patch_embedding_flat | v0.1.0 |
| PositionalEmbedding | Add positional embedding to the input embedding. | positional_embedding_dynamic, positional_embedding | v0.4.0 |
| PositionalEmbeddingFlatten | Add positional embedding to the input embedding. | positional_embedding_flat_dynamic, positional_embedding_flat | v0.4.0 |
| TransformerBlock | Transformer block from "Attention Is All You Need". | transformer_block_dynamic, transformer_block | v0.1.0 |
| TransformerBlockFlatten | Transformer block from "Attention Is All You Need". | transformer_block_flat_dynamic, transformer_block_flat | v0.1.0 |
| TransformerStack | Stack of Transformer blocks. | transformer_stack_dynamic, transformer_stack | v0.1.0 |
| TransformerStackFlatten | Stack of Transformer blocks. | transformer_stack_flat_dynamic, transformer_stack_flat | v0.1.0 |
| VisionTransformer | A Vision Transformer (ViT) model for MNIST with configurable embedding type. | vit_conv_embedding_dynamic, vit_conv_embedding, vit_patch_embedding | v0.2.0 |
| VisionTransformerFlatten | A Vision Transformer (ViT) model for MNIST with configurable embedding type. | vit_conv_embedding_flat_dynamic, vit_conv_embedding_flat, vit_patch_embedding_flat_dynamic, vit_patch_embedding_flat | v0.2.0 |
| onnx_functions_000 | One function on an outer layer. | 000_one_function_on_outer_layer_dynamic, 000_one_function_on_outer_layer | v0.4.0 |
| onnx_functions_001 | One function on an inner layer. | 001_one_function_inner_dynamic, 001_one_function_inner | v0.4.0 |
| onnx_functions_002 | Two nested functions. | 002_two_nested_functions_dynamic, 002_two_nested_functions | v0.4.0 |
| onnx_functions_003 | Two nested functions. | 003_two_simple_nested_functions_dynamic, 003_two_simple_nested_functions | v0.4.0 |
| onnx_functions_004 | Nested function plus component. | 004_nested_function_plus_component_dynamic, 004_nested_function_plus_component | v0.4.0 |
| onnx_functions_005 | Nested function plus more components. | 005_nested_function_plus_component_dynamic, 005_nested_function_plus_component | v0.4.0 |
| onnx_functions_006 | One function on an outer layer. | 006_one_function_outer_dynamic, 006_one_function_outer | v0.4.0 |
| onnx_functions_007 | Transformer block with nested MLP block, with call parameter. | 007_transformer_block_dynamic, 007_transformer_block | v0.4.0 |
| onnx_functions_008 | Transformer block with nested MLP block, no call parameter. | 008_transformer_block_dynamic, 008_transformer_block | v0.4.0 |
| onnx_functions_009 | Transformer block using decorator on class and function. | 009_transformer_block_dynamic, 009_transformer_block | v0.4.0 |
| onnx_functions_010 | Transformer stack. | 010_transformer_stack_dynamic, 010_transformer_stack | v0.4.0 |
| onnx_functions_012 | Vision Transformer (ViT). | 012_vit_conv_embedding_dynamic, 012_vit_conv_embedding | v0.4.0 |
| onnx_functions_013 | Vision Transformer (ViT). | 013_vit_conv_embedding_with_call_params_dynamic, 013_vit_conv_embedding_with_call_params, 013_vit_conv_embedding_with_internal_call_params_dynamic, 013_vit_conv_embedding_with_internal_call_params | v0.4.0 |
| onnx_functions_014 | One function on an outer layer. | 014_one_function_with_input_param_with_default_value, 014_one_function_without_input_param_with_default_value_dynamic, 014_one_function_without_input_param_with_default_value | v0.4.0 |
| onnx_functions_015 | One function on an outer layer. | 015_one_function_with_input_param_without_default_value_dynamic, 015_one_function_with_input_param_without_default_value | v0.4.0 |
| onnx_functions_016 | Nested function plus more components. | 016_internal_function_with_input_param_with_default_value_dynamic, 016_internal_function_with_input_param_with_default_value | v0.4.0 |

📌 Dependencies

Versions of Major Dependencies:

| Library | Version |
|---|---|
| JAX | 0.6.0 |
| Flax | 0.10.6 |
| onnx | 1.17.0 |
| onnxruntime | 1.21.1 |

Note: For more details, check pyproject.toml.


⚠️ Limitations

  • Currently not all JAX/Flax components are supported (you can easily help expand this coverage!).
  • Function references need dynamic resolution at call-time.
  • ONNX graph composition is done in-memory before saving to disk, potentially causing memory issues with very large models.

🤝 How to Contribute

We warmly welcome contributions!

How you can help:

  • Add a plugin: Extend jax2onnx by writing a simple Python file in jax2onnx/plugins, covering a custom primitive or an example.
  • Bug fixes & improvements: PRs and issues are always welcome.

💾 Installation

Install from PyPI:

pip install jax2onnx  

📜 License

This project is licensed under the Apache License, Version 2.0. See LICENSE for details.


🌟 Special Thanks

Special thanks for plugin contributions to @burakksen

Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.

Special thanks to the community members involved in issues and discussions.

A huge thanks especially to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.


Happy converting! 🎉
