jax2onnx 🌟

jax2onnx converts your JAX, Flax NNX, Flax Linen, and Equinox functions directly into the ONNX format.

✨ Key Features

  • simple API
    Easily convert JAX callables—including Flax NNX, Flax Linen and Equinox models—into ONNX format using to_onnx(...).

  • model structure preserved
    With @onnx_function, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.

  • dynamic input support
    Use abstract dimensions like 'B' or pass scalars as runtime inputs. Models stay flexible without retracing.

  • plugin-based extensibility
    Add support for new primitives by writing small, local plugins.

  • onnx-ir native pipeline
    Conversion, optimization, and post-processing all run on the typed onnx_ir toolkit—no protobuf juggling—and stay memory-lean before the final ONNX serialization.

  • Netron-friendly outputs
    Generated graphs carry shape/type annotations and a clean hierarchy, so tools like Netron stay easy to read.


🚀 Quickstart

Install and export your first model in minutes:

pip install jax2onnx

Convert your JAX callable to ONNX in just a few lines:

from flax import nnx
from jax2onnx import to_onnx

# Define a simple MLP (from Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs): 
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) 
    def __call__(self, x): 
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)

# Instantiate model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Export straight to disk without keeping the proto in memory
to_onnx(
    my_callable,
    [("B", 30)],
    return_mode="file",
    output_path="my_callable.onnx",
)

🔎 See it visualized: my_callable.onnx
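
To sanity-check the export, load it with onnxruntime and run it at two different batch sizes; because the first axis was exported as the symbolic dimension 'B', no re-export is needed. This is a minimal sketch (it assumes onnxruntime is installed and reads the input name from the session instead of hard-coding it):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("my_callable.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

for batch in (1, 8):
    x = np.random.rand(batch, 30).astype(np.float32)
    outputs = session.run(None, {input_name: x})
    print(outputs[0].shape)  # (1, 10), then (8, 10)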


🧠 ONNX Functions — Minimal Example

ONNX functions help encapsulate reusable subgraphs. Simply apply the @onnx_function decorator to turn your callable into an ONNX function.

from flax import nnx
from jax2onnx import onnx_function, to_onnx

# just an @onnx_function decorator to make your callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)
  def __call__(self, x):
    return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))

# Use it inside another module
class MyModel(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.block1 = MLPBlock(dim, rngs=rngs)
    self.block2 = MLPBlock(dim, rngs=rngs)
  def __call__(self, x):
    return self.block2(self.block1(x))

my_model = MyModel(256, rngs=nnx.Rngs(0))
to_onnx(
    my_model,
    [(100, 256)],
    return_mode="file",
    output_path="model_with_function.onnx",
)

🔎 See it visualized: model_with_function.onnx
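
The decorated blocks are stored as reusable functions inside the exported model. One quick way to confirm this from Python is to list them with the onnx package (a small sketch; it only assumes onnx is installed):

import onnx

model = onnx.load("model_with_function.onnx")
# Modules decorated with @onnx_function are emitted as FunctionProto entries.
print([f.name for f in model.functions])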


SotA examples 🚀


🧩 Coverage & Examples (Interactive)

Tip: JAX · Flax · Equinox — explore everything that’s supported and see it in action.

  • Support matrix: status per component
  • 🧪 Exact regression testcase for each entry
  • 🔍 One-click Netron graph to inspect nodes, shapes, attributes
  • 🧩 Examples that compose multiple components (Conv→Norm→Activation→Pool, MLP w/ LayerNorm+Dropout, reshape/transpose/concat, scan/while_loop, gather/scatter, …)

Links: Open support matrix ↗ · Browse examples ↗


📅 Roadmap and Releases

Planned

  • Broaden coverage of JAX, Flax NNX/Linen, and Equinox components.
  • Expand SotA example support for vision and language models.
  • Improve support for physics-based simulations.

Current Production Version

  • 0.11.0:
    • Initial Flax Linen support:
      • Core layers: Dense/DenseGeneral, Conv/ConvTranspose/ConvLocal, pooling, BatchNorm/LayerNorm/GroupNorm/RMSNorm/InstanceNorm, plus Dropout, Einsum/Embed, and spectral/weight norm wrappers.
      • Activation coverage: GELU plus glu/hard_/log_/relu6/silu-swish/tanh/normalize/one_hot.
      • Attention stack: dot_product_attention, dot_product_attention_weights, make_attention_mask/make_causal_mask, SelfAttention, MultiHeadDotProductAttention, MultiHeadAttention.
      • Recurrent stack: SimpleCell, GRUCell, MGUCell, LSTMCell, OptimizedLSTMCell, ConvLSTMCell, RNN, Bidirectional.
      • Linen examples: MLP/CNN/Sequential.
    • Modernized IR optimization pipeline: standard onnx_ir CSE pass adoption, removed legacy helpers/getattr patterns, and simplified tests with direct graph iteration.

Past Versions

See past_versions for the full release archive.


❓ Troubleshooting

If conversion doesn't work out of the box, it could be due to:

  • Non-dynamic function references:
    JAXPR-based conversion requires function references to be resolved dynamically at call-time.
    Solution: Wrap your function call inside a lambda to enforce dynamic resolution (see the sketch after this list):

    my_dynamic_callable_function = lambda x: original_function(x)
  • Unsupported primitives:
    The callable may use a primitive not yet or not fully supported by jax2onnx.
    Solution: Write a plugin to handle the unsupported function (this is straightforward!).
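
Putting the lambda workaround into a full export call looks like this (a sketch: original_function and the ("B", 30) input shape are placeholders for your own code):

from jax2onnx import to_onnx

# The lambda defers the lookup of original_function to call time.
my_dynamic_callable_function = lambda x: original_function(x)

to_onnx(
    my_dynamic_callable_function,
    [("B", 30)],            # replace with your input shapes
    return_mode="file",
    output_path="original_function.onnx",
)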

Looking for provenance details while debugging? Check out the new Stacktrace Metadata guide.


🤝 How to Contribute

We warmly welcome contributions!

How you can help:

  • Add a plugin: Extend jax2onnx by writing a simple Python file in jax2onnx/plugins: a primitive or an example. The Plugin Quickstart walks through the process step-by-step.
  • Bug fixes & improvements: PRs and issues are always welcome.

📌 Dependencies

Latest supported versions of major dependencies:

Library        Version
JAX            0.8.2
Flax           0.12.2
Equinox        0.13.2
onnx-ir        0.1.13
onnx           1.20.0
onnxruntime    1.23.2

For exact pins and extras, see pyproject.toml.


📜 License

This project is licensed under the Apache License, Version 2.0. See LICENSE for details.


🌟 Special Thanks

✨ Special thanks to @clementpoiret for initiating Equinox support and for Equimo, which brings modern vision models—such as DINOv3—to JAX/Equinox.

✨ Special thanks to @justinchuby for introducing onnx-ir as a scalable and more efficient way to handle ONNX model construction.

✨ Special thanks to @atveit for introducing us to gpt-oss-jax-vs-torch-numerical-comparison.

✨ Special thanks for example contributions to @burakssen, @Cadynum, @clementpoiret, and @PVirie.

✨ Special thanks for plugin contributions to @burakssen, @clementpoiret, @Clouder0, @rakadam, and @benmacadam64.

✨ Special thanks to @benmacadam64 for championing the complex-number handling initiative.

✨ Special thanks to tumaer/JAXFLUIDS for contributing valuable insights rooted in physics simulation use cases.

✨ Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.

✨ Special thanks to the community members involved in:

✨ Special thanks to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.


Happy converting! 🎉
