
jax2onnx 🌟

jax2onnx converts your JAX/Flax (NNX) functions directly into the ONNX format.

✨ Key Features

  • Simple API
    Easily convert JAX callables—including Flax (NNX) models—into ONNX format using to_onnx(...).

  • Model structure preserved
    With @onnx_function, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.

  • Dynamic input support
    Use abstract dimensions like 'B' or pass scalars as runtime inputs. Models stay flexible without retracing.

  • Plugin-based extensibility
    Add support for new primitives by writing small, local plugins.

  • Netron-friendly outputs
    All generated ONNX graphs include shape/type annotations and are structured for clear visualization.


🚀 Quickstart

Convert your JAX callable to ONNX in just a few lines:

import onnx
from flax import nnx
from jax2onnx import to_onnx

# Define a simple MLP (from Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs): 
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) 
    def __call__(self, x): 
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)

# Instantiate model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Convert to ONNX
onnx_model = to_onnx(my_callable, [("B", 30)])

# Save the model
onnx.save_model(onnx_model, "my_callable.onnx")

🔎 See it visualized: my_callable.onnx
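
Because the batch axis was exported with the symbolic dimension "B", the resulting ONNX model accepts any batch size. A quick sanity check is to run the exported file with onnxruntime. This is a minimal sketch: the input name is assigned by jax2onnx during export, so it is looked up via session.get_inputs() rather than hard-coded.

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("my_callable.onnx")
input_name = session.get_inputs()[0].name  # name assigned during export

x = np.random.rand(4, 30).astype(np.float32)  # any batch size works for "B"
(output,) = session.run(None, {input_name: x})
print(output.shape)  # (4, 10)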


🧠 ONNX Functions — Minimal Example

ONNX functions encapsulate reusable subgraphs. Simply apply the @onnx_function decorator to turn your callable into an ONNX function.

from onnx import save_model
from flax import nnx
from jax2onnx import onnx_function, to_onnx

# just an @onnx_function decorator to make your callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)
  def __call__(self, x):
    return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))

# Use it inside another module
class MyModel(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.block1 = MLPBlock(dim, rngs=rngs)
    self.block2 = MLPBlock(dim, rngs=rngs)
  def __call__(self, x):
    return self.block2(self.block1(x))

my_model = MyModel(256, rngs=nnx.Rngs(0))
model = to_onnx(my_model, [(100, 256)])
save_model(model, "docs/onnx/model_with_function.onnx")

🔎 See it visualized: model_with_function.onnx
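
To verify that the blocks were exported as reusable ONNX functions rather than inlined nodes, you can list the functions embedded in the saved model. The exact names depend on how jax2onnx labels each instance, so treat this as a quick inspection sketch:

import onnx

m = onnx.load("docs/onnx/model_with_function.onnx")
# ModelProto.functions holds one FunctionProto per ONNX function.
print([f.name for f in m.functions])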


📅 Roadmap and Releases

Planned Versions

  • Ongoing

    • Expanding coverage of JAX and Flax (NNX) components
    • Enhancing support for physics-based simulations
  • Under Evaluation

    • Integrating onnx-ir as a backend to improve ONNX model construction, memory efficiency, and performance
  • Upcoming

    • Advanced ONNX function support, including function reuse, optimized internal graph structure, and improved variable naming for clarity and readability
    • Support for Equinox models

Current Production Version

  • 0.7.0 (PyPI):
    • Added a GPT-2 model example based on nanoGPT, featuring ONNX function support and attention masking
    • New support for jnp.concatenate, jnp.take, nnx.Embed
    • ONNX models are now hosted on Hugging Face

Past Versions

  • 0.6.5: Improved support for nnx.batch_norm, nnx.group_norm, nnx.layer_norm, nnx.rms_norm, lax.broadcast_in_dim, lax.cond, lax.fori_loop, lax.integer_pow, lax.scan, lax.scatter, lax.scatter_add, lax.scatter_mul and lax.while_loop; added support for lax.and, lax.rem and lax.remat2.
  • 0.6.4: Improved support for lax.scatter_mul.
  • 0.6.3: Double precision fixes for lax.fori_loop and lax.while_loop. Fixed bugs in lax.scan and jnp.where.
  • 0.6.2: Fixed bugs in nnx.conv and lax.reshape; added new primitive jnp.prod.
  • 0.6.1: Improved support for lax.cond and lax.select_n; added new primitives (lax.reduce_and, lax.reduce_or, lax.reduce_prod, lax.reduce_xor); and introduced new examples for jnp.select and jnp.sort.
  • 0.6.0: Introduced the enable_double_precision parameter (default: False) to support physics simulations, and enhanced handling of lax.scatter.
  • 0.5.2: Added support for additional primitives: jnp.where, jnp.arange, jnp.linspace.
  • 0.5.1: Added support for subgraphs using the primitives lax.while_loop, lax.cond, lax.fori_loop, and lax.scan.
  • 0.5.0: Improved dynamic batch dimension handling by leveraging shape polymorphism for more robust and flexible model export. Added support for jnp.sign, jnp.abs, jnp.iota primitives.
  • 0.4.4: Added support for lax.cos, lax.cosh, lax.sin, lax.sinh and lax.scatter primitives.
  • 0.4.3: Fixed a bug in the validation of JAX callable outputs against their ONNX counterparts. This fix exposed previously hidden failing tests, which are now fixed.
  • 0.4.2: Cleanup and fixes to the basic ONNX function release.
  • 0.4.1 (ONNX functions): Introduced simple ONNX function support. Using ONNX functions is easy: a single @onnx_function decorator turns a callable into an ONNX function. Each @onnx_function decorator creates a new ONNX function instance on the call graph.
  • 0.3.2: Relaxed the minimum Python version to 3.10.
  • 0.3.0: Streamlined the plugin system with automatic registration and simplified integration of custom primitives.
  • 0.2.0 (First PyPI Release): Rebased the implementation on jaxpr, improving usability and adding low-level lax components.
  • 0.1.0 (Initial Approach, Not Released to PyPI): Produced ONNX exports for some nnx components and nnx-based examples, including a VisualTransformer.

❓ Troubleshooting

If conversion doesn't work out of the box, it could be due to:

  • Non-dynamic function references:
    JAXPR-based conversion requires function references to be resolved dynamically at call-time.
    Solution: Wrap your function call inside a lambda to enforce dynamic resolution:

    my_dynamic_callable_function = lambda x: original_function(x)
  • Unsupported primitives:
    The callable may use a primitive that is not yet, or not fully, supported by jax2onnx.
    Solution: Write a plugin to handle the unsupported primitive (this is straightforward; see the sketch below).
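
For orientation, here is a hypothetical plugin skeleton. The hook name (to_onnx) and the converter.add_node helper are illustrative assumptions rather than the verified jax2onnx API; copy an existing file from jax2onnx/plugins as the authoritative template.

# Hypothetical sketch only -- mirror a real file in jax2onnx/plugins.
# The class layout, hook name, and converter.add_node are assumptions.
from jax import lax

class NegPlugin:
    # The JAX primitive this plugin handles.
    primitive = lax.neg_p

    def to_onnx(self, converter, node_inputs, node_outputs, params):
        # Map lax.neg_p onto a single ONNX "Neg" node, wiring the traced
        # input variable directly to the output variable.
        converter.add_node("Neg", inputs=node_inputs, outputs=node_outputs)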


🧩 Supported JAX/ONNX Components

JAX Component | ONNX Components | Testcases | Since
core.dim_as_value Cast
Gather
Reshape
Shape
dim_as_value_dynamic
dim_as_value_dynamic_f64
dim_as_value
dim_as_value_f64
v0.5.0
eqx.dropout Dropout
Not
eqx_dropout_inference_mode
eqx_dropout_inference_mode_f64
eqx_dropout_training_mode_dynamic
eqx_dropout_training_mode_dynamic_f64
eqx_dropout_training_mode
eqx_dropout_training_mode_f64
eqx_dropout_dynamic_inference
eqx_dropout_dynamic_inference_f64
v0.7.0
eqx.identity Identity eqx_identity_static
eqx_identity_static_f64
eqx_identity_symbolic_batch_dynamic
eqx_identity_symbolic_batch_dynamic_f64
eqx_identity_symbolic_batch
eqx_identity_symbolic_batch_f64
v0.7.0
eqx.layer_norm LayerNormalization layer_norm
layer_norm_f64
layer_norm_multiaxis
layer_norm_multiaxis_f64
batched_layer_norm_dynamic
batched_layer_norm_dynamic_f64
batched_layer_norm
batched_layer_norm_f64
layer_norm_no_bias_no_scale
layer_norm_no_bias_no_scale_f64
v0.7.0
eqx.linear Gemm
Reshape
eqx_linear_symbolic_batch_dynamic
eqx_linear_symbolic_batch_dynamic_f64
eqx_linear_symbolic_batch
eqx_linear_symbolic_batch_f64
eqx_linear_high_rank
eqx_linear_high_rank_f64
v0.7.0
jnp.add Add add
add_f64
v0.1.0
jnp.arange Range arange_stop_only_concrete_input_val
arange_stop_only_concrete_input_val_f64
arange_start_stop_concrete_input_val
arange_start_stop_concrete_input_val_f64
arange_start_stop_step_concrete_input_val
arange_start_stop_step_concrete_input_val_f64
arange_float_concrete_input_val
arange_float_concrete_input_val_f64
arange_static_stop_only_int
arange_static_stop_only_int_f64
arange_static_stop_only_float
arange_static_stop_only_float_f64
arange_static_start_stop_int
arange_static_start_stop_int_f64
arange_static_start_stop_step_int
arange_static_start_stop_step_int_f64
arange_static_empty_result_pos_step
arange_static_empty_result_pos_step_f64
arange_static_empty_result_neg_step
arange_static_empty_result_neg_step_f64
arange_static_negative_step
arange_static_negative_step_f64
arange_static_float_step_explicit_dtype
arange_static_float_step_explicit_dtype_f64
arange_static_float_step_inferred_dtype
arange_static_float_step_inferred_dtype_f64
arange_static_stop_zero
arange_static_stop_zero_f64
arange_static_start_equals_stop
arange_static_start_equals_stop_f64
arange_static_large_numbers_int
arange_static_large_numbers_int_f64
v0.5.2
jnp.concatenate Concat concatenate
concatenate_f64
concatenate_abstract_middle_dim_dynamic
concatenate_abstract_middle_dim_dynamic_f64
concatenate_abstract_middle_dim
concatenate_abstract_middle_dim_f64
concatenate_tile_and_symbolic_dynamic
concatenate_tile_and_symbolic_dynamic_f64
concatenate_tile_and_symbolic
concatenate_tile_and_symbolic_f64
v0.1.0
jnp.einsum Einsum einsum_vector_dot
einsum_vector_dot_f64
einsum_matrix_vector
einsum_matrix_vector_f64
einsum_matrix_matrix_dynamic
einsum_matrix_matrix_dynamic_f64
einsum_matrix_matrix
einsum_matrix_matrix_f64
einsum_transpose
einsum_transpose_f64
einsum_batch_transpose_dynamic
einsum_batch_transpose_dynamic_f64
einsum_batch_transpose
einsum_batch_transpose_f64
einsum_diag
einsum_diag_f64
einsum_sum_reduce
einsum_sum_reduce_f64
einsum_multi_operand
einsum_multi_operand_f64
einsum_attention_logits_orig_dynamic
einsum_attention_logits_orig_dynamic_f64
einsum_attention_logits_orig
einsum_attention_logits_orig_f64
einsum_attention_output_orig_dynamic
einsum_attention_output_orig_dynamic_f64
einsum_attention_output_orig
einsum_attention_output_orig_f64
einsum_attention_logits_batched_dynamic
einsum_attention_logits_batched_dynamic_f64
einsum_attention_logits_batched
einsum_attention_logits_batched_f64
einsum_attention_output_batched_dynamic
einsum_attention_output_batched_dynamic_f64
einsum_attention_output_batched
einsum_attention_output_batched_f64
einsum_ellipsis_rank_mismatch
einsum_ellipsis_rank_mismatch_f64
einsum_attention_logits_batched_rank_mismatch
einsum_attention_logits_batched_rank_mismatch_f64
v0.1.0
jnp.linspace Constant linspace_static_basic
linspace_static_basic_f64
linspace_static_endpoint_false
linspace_static_endpoint_false_f64
linspace_static_num_1
linspace_static_num_1_f64
linspace_static_num_0
linspace_static_num_0_f64
linspace_static_int_inputs_default_dtype
linspace_static_int_inputs_default_dtype_f64
v0.5.2
jnp.matmul MatMul matmul_2d
matmul_2d_f64
matmul_1d_2d
matmul_1d_2d_f64
matmul_2d_1d
matmul_2d_1d_f64
matmul_dynamic_dynamic
matmul_dynamic_dynamic_f64
matmul_dynamic
matmul_dynamic_f64
matmul_dynamic_a_dynamic
matmul_dynamic_a_dynamic_f64
matmul_dynamic_a
matmul_dynamic_a_f64
matmul_1d
matmul_1d_f64
matmul_3d
matmul_3d_f64
v0.1.0
jnp.prod ReduceProd basic_prod
basic_prod_f64
prod_with_axis
prod_with_axis_f64
prod_with_keepdims
prod_with_keepdims_f64
v0.6.2
jnp.reshape Reshape reshape_1
reshape_1_f64
reshape_2
reshape_2_f64
reshape_3
reshape_3_f64
reshape_4_dynamic
reshape_4_dynamic_f64
reshape_4
reshape_4_f64
reshape_to_scalar
reshape_to_scalar_f64
reshape_from_scalar
reshape_from_scalar_f64
reshape_cnn_dynamic
reshape_cnn_dynamic_f64
reshape_cnn
reshape_cnn_f64
reshape_valid_flatten_trailing
reshape_valid_flatten_trailing_f64
reshape_with_target_shape_from_symbolic_dim_computation
reshape_with_target_shape_from_symbolic_dim_computation_f64
v0.1.0
jnp.shape Shape shape_basic
shape_basic_f64
shape_dynamic_dynamic
shape_dynamic_dynamic_f64
shape_dynamic
shape_dynamic_f64
v0.4.0
jnp.sort TopK sort_1d
sort_1d_f64
sort_2d_axis0_dynamic
sort_2d_axis0_dynamic_f64
sort_2d_axis0
sort_2d_axis0_f64
v0.5.2
jnp.squeeze Squeeze squeeze_single_dim
squeeze_single_dim_f64
squeeze_multiple_dims
squeeze_multiple_dims_f64
squeeze_vit_output
squeeze_vit_output_f64
squeeze_dynamic_batch_dynamic
squeeze_dynamic_batch_dynamic_f64
squeeze_dynamic_batch
squeeze_dynamic_batch_f64
squeeze_all_dims
squeeze_all_dims_f64
squeeze_negative_axis
squeeze_negative_axis_f64
squeeze_negative_axis_tuple
squeeze_negative_axis_tuple_f64
squeeze_dynamic_and_negative_axis_dynamic
squeeze_dynamic_and_negative_axis_dynamic_f64
squeeze_dynamic_and_negative_axis
squeeze_dynamic_and_negative_axis_f64
v0.1.0
jnp.take Gather take_data_dependent_indices v0.7.0
jnp.tile Tile tile_repeats
tile_repeats_f64
tile_a
tile_a_f64
tile_b
tile_b_f64
tile_c
tile_c_f64
tile_d
tile_d_f64
tile_dynamic_input_static
tile_dynamic_input_static_f64
tile_dynamic_input_dynamic
tile_dynamic_input_dynamic_f64
tile_dynamic_input
tile_dynamic_input_f64
tile_pad
tile_pad_f64
tile_with_symbolic_repeats_static
tile_with_symbolic_repeats_static_f64
tile_with_symbolic_repeats_dynamic
tile_with_symbolic_repeats_dynamic_f64
tile_with_symbolic_repeats
tile_with_symbolic_repeats_f64
tile_param_symbolic_dynamic
tile_param_symbolic_dynamic_f64
tile_param_symbolic
tile_param_symbolic_f64
v0.1.0
jnp.transpose Transpose transpose_basic
transpose_basic_f64
transpose_reverse
transpose_reverse_f64
transpose_4d_dynamic
transpose_4d_dynamic_f64
transpose_4d
transpose_4d_f64
transpose_square_matrix
transpose_square_matrix_f64
transpose_high_dim
transpose_high_dim_f64
transpose_no_axes
transpose_no_axes_f64
transpose_3d_dynamic
transpose_3d_dynamic_f64
transpose_3d
transpose_3d_f64
v0.1.0
jnp.where Where where_simple
where_simple_f64
where_broadcast
where_broadcast_f64
where_multidim_condition_scalar_branches_broadcast
where_multidim_condition_scalar_branches_broadcast_f64
where_A
where_A_f64
where_B
where_B_f64
where_jax_int_literals_broadcast_f64_mode
v0.5.2
lax.abs Abs abs
abs_f64
v0.5.0
lax.add Add add
add_f64
v0.2.0
lax.and And
BitwiseAnd
and_bool
and_bool_f64
and_int
and_int_f64
v0.6.5
lax.argmax ArgMax argmax_float_axis0
argmax_float_axis0_f64
argmax_float_axis1
argmax_float_axis1_f64
argmax_boolean_input_axis0_specific_values
argmax_boolean_input_axis0_specific_values_f64
argmax_boolean_input_axis1_specific_values
argmax_boolean_input_axis1_specific_values_f64
argmax_boolean_random_input_axis0
argmax_boolean_random_input_axis0_f64
v0.2.0
lax.argmin ArgMin argmin_test1
argmin_test1_f64
argmin_test2
argmin_test2_f64
v0.2.0
lax.broadcast_in_dim Expand
Identity
Reshape
broadcast_in_dim
broadcast_in_dim_f64
broadcast_in_dim_2d_to_3d
broadcast_in_dim_2d_to_3d_f64
broadcast_in_dim_scalar
broadcast_in_dim_scalar_f64
broadcast_in_dim_batch_dynamic
broadcast_in_dim_batch_dynamic_f64
broadcast_in_dim_batch
broadcast_in_dim_batch_f64
broadcast_in_dim_dynamic_B_dynamic
broadcast_in_dim_dynamic_B_dynamic_f64
broadcast_in_dim_dynamic_B
broadcast_in_dim_dynamic_B_f64
v0.2.0
lax.concatenate Concat concatenate
concatenate_f64
concatenate_axis1_dynamic
concatenate_axis1_dynamic_f64
concatenate_axis1
concatenate_axis1_f64
concatenate_axis0
concatenate_axis0_f64
concatenate_3d
concatenate_3d_f64
v0.2.0
lax.cond If cond_scalar
cond_scalar_f64
cond_multiple_operands_in_tuple
cond_multiple_operands_in_tuple_f64
cond_my_new_complex_scenario
cond_my_new_complex_scenario_f64
cond_nested_conditional
cond_nested_conditional_f64
cond_variables
cond_variables_f64
cond_internal_constant_f64
cond_passthrough_identity
cond_passthrough_identity_f64
cond_with_scatter
cond_with_scatter_f64
v0.5.1
lax.conv Conv conv
conv_f64
conv2
conv2_f64
v0.2.0
lax.convert_element_type Cast convert_element_type
convert_element_type_f64
v0.2.0
lax.copy (handles the JAX primitive lax.copy_p; the jax.lax.copy API itself has been removed) Identity copy_float32_array
copy_int64_scalar
<your_current_version>
lax.cos Cos cos
cos_f64
v0.4.4
lax.cosh Cosh cosh
cosh_f64
v0.4.4
lax.device_put Identity device_put_array
device_put_array_f64
device_put_scalar
device_put_scalar_f64
v0.4.0
lax.div Div div
div_f64
v0.2.0
lax.dot_general MatMul dot_general
dot_general_f64
dot_general_lhs1_rhs1
dot_general_lhs1_rhs1_f64
v0.2.0
lax.dynamic_slice Slice dynamic_slice_test1
dynamic_slice_test1_f64
dynamic_slice_2d
dynamic_slice_2d_f64
dynamic_slice_3d
dynamic_slice_3d_f64
dynamic_slice_vit_like
dynamic_slice_vit_like_f64
dynamic_slice_vit_like_dynamic_dynamic
dynamic_slice_vit_like_dynamic_dynamic_f64
dynamic_slice_vit_like_dynamic
dynamic_slice_vit_like_dynamic_f64
v0.1.0
lax.eq Equal eq
eq_f64
v0.2.0
lax.exp Exp exp
exp_f64
v0.2.0
lax.fori_loop Loop fori_loop_counter
fori_loop_counter_f64
fori_loop_zero
fori_loop_zero_f64
fori_loop_vector
fori_loop_vector_f64
fori_loop_example
fori_loop_example_f64
fori_loop_test
fori_loop_test_f64
v0.5.1
lax.gather GatherND gather_static
gather_static_f64
gather_dynamic_batch_simple_index_dynamic
gather_dynamic_batch_simple_index_dynamic_f64
gather_dynamic_batch_simple_index
gather_dynamic_batch_simple_index_f64
v0.2.0
lax.gt Greater gt
gt_f64
v0.2.0
lax.integer_pow Pow integer_pow
integer_pow_f64
v0.2.0
lax.iota Range iota_int32
iota_int32_f64
iota_float32
iota_float32_f64
broadcasted_iota
broadcasted_iota_f64
v0.5.0
lax.log Log log
log_f64
v0.2.0
lax.lt Less lt
lt_f64
v0.2.0
lax.max Max max
max_f64
v0.2.0
lax.min Min min_test1
min_test1_f64
v0.1.0
lax.mul Mul mul_test1
mul_test1_f64
mul_test2
mul_test2_f64
v0.1.0
lax.ne Equal
Not
ne
ne_f64
v0.2.0
lax.neg Neg neg
neg_f64
v0.2.0
lax.reduce_and Cast
ReduceMin
reduce_and_all_true
reduce_and_all_true_f64
reduce_and_one_false
reduce_and_one_false_f64
reduce_and_keepdims
reduce_and_keepdims_f64
v0.6.1
lax.reduce_max ReduceMax reduce_max
reduce_max_f64
reduce_max_allaxes
reduce_max_allaxes_f64
reduce_max_keepdims
reduce_max_keepdims_f64
v0.2.0
lax.reduce_min ReduceMin reduce_min
reduce_min_f64
reduce_min_allaxes
reduce_min_allaxes_f64
reduce_min_keepdims
reduce_min_keepdims_f64
v0.2.0
lax.reduce_or Cast
ReduceMax
reduce_or_all_false
reduce_or_all_false_f64
reduce_or_one_true
reduce_or_one_true_f64
reduce_or_keepdims
reduce_or_keepdims_f64
v0.6.1
lax.reduce_prod ReduceProd reduce_prod
reduce_prod_f64
reduce_prod_allaxes
reduce_prod_allaxes_f64
reduce_prod_keepdims
reduce_prod_keepdims_f64
reduce_prod_dtype_f64
reduce_prod_dtype
v0.6.1
lax.reduce_sum ReduceSum reduce_sum
reduce_sum_f64
reduce_sum_allaxes
reduce_sum_allaxes_f64
reduce_sum_keepdims
reduce_sum_keepdims_f64
reduce_sum_dtype_f64
reduce_sum_dtype
v0.2.0
lax.reduce_xor Cast
Mod
ReduceSum
reduce_xor_all_false
reduce_xor_all_false_f64
reduce_xor_one_true
reduce_xor_one_true_f64
reduce_xor_two_true
reduce_xor_two_true_f64
reduce_xor_keepdims
reduce_xor_keepdims_f64
v0.6.1
lax.rem Div
Mod
rem_int
rem_int_f64
rem_float
rem_float_f64
rem_int_neg
rem_int_neg_f64
rem_float_neg
rem_float_neg_f64
v0.6.5
lax.reshape Reshape reshape
reshape_f64
reshape_valid_squeeze_middle_dim_from_problematic_source
reshape_valid_squeeze_middle_dim_from_problematic_source_f64
reshape_valid_flatten_trailing
reshape_valid_flatten_trailing_f64
reshape_with_target_shape_from_symbolic_dim_computation
reshape_with_target_shape_from_symbolic_dim_computation_f64
reshape_with_inferred_dimension_from_input_dynamic_dynamic
reshape_with_inferred_dimension_from_input_dynamic_dynamic_f64
reshape_with_inferred_dimension_from_input_dynamic
reshape_with_inferred_dimension_from_input_dynamic_f64
reshape_with_inferred_dimension_from_input
reshape_with_inferred_dimension_from_input_f64
v0.2.0
lax.scan Scan scan_cumsum
scan_cumsum_f64
scan_carry_only
scan_carry_only_f64
scan_multiple_sequences
scan_multiple_sequences_f64
scan_multiple_carry
scan_multiple_carry_f64
scan_matrix_carry_multidim_xs
scan_matrix_carry_multidim_xs_f64
scan_no_xs
scan_no_xs_f64
scan_fn
scan_fn_f64
scan_jit_no_xs
scan_jit_no_xs_f64
v0.5.1
lax.scatter ScatterND scatter_set_axis0
scatter_set_axis0_f64
scatter_set_middle
scatter_set_middle_f64
scatter_correct_axis_determination
scatter_correct_axis_determination_f64
scatter_updates_slice_needed_axis0
scatter_updates_slice_needed_axis0_f64
scatter_from_user_warning_shapes_valid_jax
scatter_from_user_warning_shapes_valid_jax_f64
scatter_user_error_scenario_precise
scatter_user_error_scenario_precise_f64
v0.4.4
lax.scatter_add ScatterND scatter_add_simple_1d
scatter_add_simple_1d_f64
scatter_add_window_2d_operand_1d_indices
scatter_add_window_2d_operand_1d_indices_f64
scatter_add_batch_updates_1d_operand
scatter_add_batch_updates_1d_operand_f64
scatter_add_mismatched_window_dims_from_user_report
scatter_add_mismatched_window_dims_from_user_report2
scatter_add_mismatched_window_dims_from_user_report3
scatter_add_fluids_pattern_updates_5_4_1_1
scatter_add_in_cond_float64
v0.5.3
lax.scatter_mul ScatterND scatter_mul_simple_1d
scatter_mul_simple_1d_f64
scatter_mul_window_2d_operand_1d_indices
scatter_mul_window_2d_operand_1d_indices_f64
scatter_mul_batch_updates_1d_operand
scatter_mul_batch_updates_1d_operand_f64
scatter_mul_mismatched_window_dims_from_user_report
scatter_mul_mismatched_window_dims_from_user_report2
scatter_mul_mismatched_window_dims_from_user_report3
scatter_mul_fluids_pattern_updates_5_4_1_1
scatter_mul_in_cond_float64
v0.6.4
lax.select_n Where select_n_bool_predicate_two_cases_float
select_n_bool_predicate_two_cases_float_f64
select_n_bool_predicate_two_cases_int
select_n_bool_predicate_two_cases_int_f64
select_n_bool_predicate_scalar_broadcast
select_n_bool_predicate_scalar_broadcast_f64
select_n_int_indices_three_cases
select_n_int_indices_three_cases_f64
select_n_int_indices_four_cases
select_n_int_indices_four_cases_f64
v0.2.0
lax.sign Sign sign
sign_f64
v0.5.0
lax.sin Sin sin
sin_f64
v0.4.4
lax.sinh Sinh sinh
sinh_f64
v0.4.4
lax.slice Slice slice_test1
slice_test1_f64
slice_3d_none_strides
slice_3d_none_strides_f64
v0.1.0
lax.sort TopK sort_1d
sort_1d_f64
sort_2d
sort_2d_f64
v0.2.0
lax.sqrt Sqrt sqrt
sqrt_f64
v0.2.0
lax.square Mul square
square_f64
v0.2.0
lax.squeeze Squeeze lax_squeeze_specific_axis_0
lax_squeeze_specific_axis_0_f64
lax_squeeze_multiple_axes
lax_squeeze_multiple_axes_f64
lax_squeeze_no_op_empty_dims
lax_squeeze_no_op_empty_dims_f64
lax_squeeze_problem_case_input_squeeze_only_axis_0
lax_squeeze_problem_case_input_squeeze_only_axis_0_f64
lax_squeeze_problem_case_input_squeeze_axes_0_2
lax_squeeze_problem_case_input_squeeze_axes_0_2_f64
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly_f64
v0.2.0
lax.stop_gradient Identity stop_gradient
stop_gradient_f64
v0.2.0
lax.sub Sub sub_test1
sub_test1_f64
sub_test2
sub_test2_f64
v0.1.0
lax.tanh Tanh tanh
tanh_f64
v0.2.0
lax.transpose Transpose transpose_basic
transpose_basic_f64
v0.2.0
lax.while_loop Loop while_loop_counter
while_loop_counter_f64
while_loop_vector
while_loop_vector_f64
while_loop_f64
while_loop_multi_state_f32
while_loop_multi_state_f64
while_loop_with_closure
while_loop_with_closure_f64
while_loop_basic
while_loop_two_state
while_loop_captured_tracer
while_loop_tracer_passthrough
while_loop_no_loop_output_reused_as_input
while_loop_with_closure2_dynamic
while_loop_with_closure2_dynamic_f64
while_loop_with_closure2
while_loop_with_closure2_f64
v0.5.1
nn.celu Celu jaxnn_celu
jaxnn_celu_f64
jaxnn_celu_1
jaxnn_celu_1_f64
v0.7.0
nn.dot_product_attention Add
Cast
MatMul
Mul
Not
Softmax
Transpose
Where
dpa_basic
dpa_basic_f64
dpa_diff_heads_embed
dpa_diff_heads_embed_f64
dpa_batch4_seq16
dpa_batch4_seq16_f64
dpa_float64
dpa_float64_f64
dpa_heads1_embed4
dpa_heads1_embed4_f64
dpa_heads8_embed8
dpa_heads8_embed8_f64
dpa_batch1_seq2
dpa_batch1_seq2_f64
dpa_batch8_seq4
dpa_batch8_seq4_f64
dpa_axis1
dpa_axis1_f64
dpa_with_tensor_mask
dpa_with_tensor_mask_f64
dpa_tiny_mask_all_valid
dpa_tiny_mask_all_valid_f64
dpa_tiny_mask_mixed
dpa_tiny_mask_mixed_f64
dpa_one_false
dpa_one_false_f64
dpa_mostly_false
dpa_mostly_false_f64
dpa_with_causal_mask
dpa_with_causal_mask_f64
dpa_with_padding_mask
dpa_with_padding_mask_f64
dpa_with_local_window_mask
dpa_with_local_window_mask_f64
v0.1.0
nn.elu Elu jaxnn_elu
jaxnn_elu_f64
jaxnn_elu_1
jaxnn_elu_1_f64
v0.7.0
nn.gelu Gelu jaxnn_gelu
jaxnn_gelu_f64
jaxnn_gelu_1
jaxnn_gelu_1_f64
jaxnn_gelu_approx
jaxnn_gelu_approx_f64
v0.7.0
nn.identity Identity jaxnn_identity
jaxnn_identity_f64
jaxnn_identity_1
jaxnn_identity_1_f64
v0.7.0
nn.leaky_relu LeakyRelu jaxnn_leaky_relu
jaxnn_leaky_relu_f64
jaxnn_leaky_relu_1
jaxnn_leaky_relu_1_f64
v0.7.0
nn.mish Mish jaxnn_mish
jaxnn_mish_f64
jaxnn_mish_1
jaxnn_mish_1_f64
v0.7.0
nn.relu Relu jaxnn_relu
jaxnn_relu_f64
jaxnn_relu_1
jaxnn_relu_1_f64
v0.7.0
nn.selu Selu jaxnn_selu
jaxnn_selu_f64
jaxnn_selu_1
jaxnn_selu_1_f64
v0.7.0
nn.sigmoid Sigmoid jaxnn_sigmoid
jaxnn_sigmoid_f64
jaxnn_sigmoid_1
jaxnn_sigmoid_1_f64
v0.7.0
nn.soft_sign Softsign jaxnn_soft_sign
jaxnn_soft_sign_f64
jaxnn_soft_sign_1
jaxnn_soft_sign_1_f64
v0.7.0
nn.softmax Softmax softmax
softmax_f64
softmax_2d
softmax_2d_f64
softmax_3d
softmax_3d_f64
v0.1.0
nn.softplus Softplus jaxnn_softplus
jaxnn_softplus_f64
jaxnn_softplus_1
jaxnn_softplus_1_f64
v0.7.0
nnx.avg_pool AveragePool
Transpose
avg_pool_dynamic
avg_pool_dynamic_f64
avg_pool
avg_pool_f64
avg_pool_same_padding_dynamic
avg_pool_same_padding_dynamic_f64
avg_pool_same_padding
avg_pool_same_padding_f64
avg_pool_default_padding_dynamic
avg_pool_default_padding_dynamic_f64
avg_pool_default_padding
avg_pool_default_padding_f64
avg_pool_stride1_dynamic
avg_pool_stride1_dynamic_f64
avg_pool_stride1
avg_pool_stride1_f64
avg_pool_win3x3_stride2_dynamic
avg_pool_win3x3_stride2_dynamic_f64
avg_pool_win3x3_stride2
avg_pool_win3x3_stride2_f64
avg_pool_stride_none_dynamic
avg_pool_stride_none_dynamic_f64
avg_pool_stride_none
avg_pool_stride_none_f64
avg_pool_count_include_pad_false_dynamic
avg_pool_count_include_pad_false_dynamic_f64
avg_pool_count_include_pad_false
avg_pool_count_include_pad_false_f64
v0.1.0
nnx.batch_norm BatchNormalization batch_norm_no_bias_no_scale_dynamic
batch_norm_no_bias_no_scale_dynamic_f64
batch_norm_no_bias_no_scale
batch_norm_no_bias_no_scale_f64
batch_norm_bias_no_scale_dynamic
batch_norm_bias_no_scale_dynamic_f64
batch_norm_bias_no_scale
batch_norm_bias_no_scale_f64
batch_norm_no_bias_scale_dynamic
batch_norm_no_bias_scale_dynamic_f64
batch_norm_no_bias_scale
batch_norm_no_bias_scale_f64
batch_norm_bias_scale_dynamic
batch_norm_bias_scale_dynamic_f64
batch_norm_bias_scale
batch_norm_bias_scale_f64
batch_norm_4d_dynamic
batch_norm_4d_dynamic_f64
batch_norm_4d
batch_norm_4d_f64
batch_norm_4d_no_bias_no_scale_dynamic
batch_norm_4d_no_bias_no_scale_dynamic_f64
batch_norm_4d_no_bias_no_scale
batch_norm_4d_no_bias_no_scale_f64
batch_norm_training_mode_fallback_dynamic
batch_norm_training_mode_fallback_dynamic_f64
batch_norm_training_mode_fallback
batch_norm_training_mode_fallback_f64
v0.1.0
nnx.conv Conv
Transpose
conv_basic_bias_dynamic
conv_basic_bias_dynamic_f64
conv_basic_bias
conv_basic_bias_f64
conv_basic_bias_2
conv_basic_bias_2_f64
conv_basic_bias_3
conv_basic_bias_3_f64
conv_stride2_bias
conv_stride2_bias_f64
conv_no_bias_dynamic
conv_no_bias_dynamic_f64
conv_no_bias
conv_no_bias_f64
conv_valid_padding
conv_valid_padding_f64
conv_stride1
conv_stride1_f64
conv_stride2
conv_stride2_f64
conv_different_kernel
conv_different_kernel_f64
conv_float64
conv_float64_f64
conv_single_batch
conv_single_batch_f64
conv_large_batch
conv_large_batch_f64
v0.1.0
nnx.dot_product_attention Cast
Div
Einsum
Gather
Shape
Softmax
Sqrt
dpa_basic
dpa_basic_f64
dpa_with_tensor_mask
dpa_with_tensor_mask_f64
dpa_with_bias
dpa_with_bias_f64
dpa_with_causal_mask
dpa_with_causal_mask_f64
dpa_with_mask_and_bias
dpa_with_mask_and_bias_f64
v0.1.0
nnx.dropout Dropout dropout_init_params_dynamic
dropout_init_params_dynamic_f64
dropout_init_params
dropout_init_params_f64
dropout_call_params_dynamic
dropout_call_params_dynamic_f64
dropout_call_params
dropout_call_params_f64
v0.1.0
nnx.einsum Add
Einsum
einsum_module_with_bias
einsum_module_with_bias_f64
einsum_module_no_bias
einsum_module_no_bias_f64
v0.4.2
nnx.elu Elu elu
elu_f64
v0.1.0
nnx.embed Gather token_embedding_dynamic
token_embedding_dynamic_f64
token_embedding
token_embedding_f64
positional_embedding_dynamic
positional_embedding_dynamic_f64
positional_embedding
positional_embedding_f64
v0.7.0
nnx.gelu Gelu gelu
gelu_f64
gelu_1
gelu_1_f64
gelu_2
gelu_2_f64
gelu_3_dynamic
gelu_3_dynamic_f64
gelu_3
gelu_3_f64
v0.1.0
nnx.group_norm GroupNormalization group_norm
group_norm_f64
group_norm_2
group_norm_2_f64
group_norm_no_bias_no_scale_dynamic
group_norm_no_bias_no_scale_dynamic_f64
group_norm_no_bias_no_scale
group_norm_no_bias_no_scale_f64
group_norm_bias_no_scale_dynamic
group_norm_bias_no_scale_dynamic_f64
group_norm_bias_no_scale
group_norm_bias_no_scale_f64
group_norm_no_bias_scale_dynamic
group_norm_no_bias_scale_dynamic_f64
group_norm_no_bias_scale
group_norm_no_bias_scale_f64
group_norm_bias_scale_dynamic
group_norm_bias_scale_dynamic_f64
group_norm_bias_scale
group_norm_bias_scale_f64
v0.3.0
nnx.layer_norm LayerNormalization layer_norm_dynamic
layer_norm_dynamic_f64
layer_norm
layer_norm_f64
layer_norm_no_bias_no_scale_dynamic
layer_norm_no_bias_no_scale_dynamic_f64
layer_norm_no_bias_no_scale
layer_norm_no_bias_no_scale_f64
layer_norm_bias_no_scale_dynamic
layer_norm_bias_no_scale_dynamic_f64
layer_norm_bias_no_scale
layer_norm_bias_no_scale_f64
layer_norm_no_bias_scale_dynamic
layer_norm_no_bias_scale_dynamic_f64
layer_norm_no_bias_scale
layer_norm_no_bias_scale_f64
layer_norm_bias_scale_dynamic
layer_norm_bias_scale_dynamic_f64
layer_norm_bias_scale
layer_norm_bias_scale_f64
layer_norm_multiaxis_dynamic
layer_norm_multiaxis_dynamic_f64
layer_norm_multiaxis
layer_norm_multiaxis_f64
v0.1.0
nnx.leaky_relu LeakyRelu leaky_relu
leaky_relu_f64
v0.1.0
nnx.linear Gemm
Reshape
linear_symbolic_batch_dynamic
linear_symbolic_batch_dynamic_f64
linear_symbolic_batch
linear_symbolic_batch_f64
linear_high_rank_dynamic
linear_high_rank_dynamic_f64
linear_high_rank
linear_high_rank_f64
linear_no_bias_dynamic
linear_no_bias_dynamic_f64
linear_no_bias
linear_no_bias_f64
linear_high_rank_no_bias_dynamic
linear_high_rank_no_bias_dynamic_f64
linear_high_rank_no_bias
linear_high_rank_no_bias_f64
v0.1.0
nnx.linear_general Gemm
Reshape
linear_general_dynamic
linear_general_dynamic_f64
linear_general
linear_general_f64
linear_general_2
linear_general_2_f64
linear_general_3
linear_general_3_f64
linear_general_4
linear_general_4_f64
linear_general_abstract_eval_axes
linear_general_abstract_eval_axes_f64
linear_general_abstract_eval_axes_pair
linear_general_abstract_eval_axes_pair_f64
v0.1.0
nnx.log_softmax LogSoftmax log_softmax
log_softmax_f64
v0.1.0
nnx.max_pool MaxPool
Transpose
max_pool
max_pool_f64
max_pool_same_padding
max_pool_same_padding_f64
v0.1.0
nnx.relu Relu relu_1d
relu_1d_f64
relu_4d_dynamic
relu_4d_dynamic_f64
relu_4d
relu_4d_f64
v0.1.0
nnx.rms_norm RMSNormalization rms_norm_basic
rms_norm_basic_f64
rms_norm_use_scale_false
rms_norm_use_scale_false_f64
rms_norm_4d_dynamic_dynamic
rms_norm_4d_dynamic_dynamic_f64
rms_norm_4d_dynamic
rms_norm_4d_dynamic_f64
rms_norm_4d_dynamic_no_scale_dynamic
rms_norm_4d_dynamic_no_scale_dynamic_f64
rms_norm_4d_dynamic_no_scale
rms_norm_4d_dynamic_no_scale_f64
v0.3.0
nnx.sigmoid Sigmoid sigmoid
sigmoid_f64
v0.1.0
nnx.softmax Softmax softmax_dynamic
softmax_dynamic_f64
softmax
softmax_f64
v0.1.0
nnx.softplus Softplus softplus
softplus_f64
v0.1.0
nnx.tanh Tanh tanh
tanh_f64
v0.1.0

Legend:
✅ = Passed
❌ = Failed
➖ = No testcase yet


🎯 Examples

Component | Description | Testcases | Since
MlpExample A simple MLP example using Equinox. mlp_training_mode
mlp_training_mode_f64
mlp_inference_mode
mlp_inference_mode_f64
mlp_batched_training_mode_dynamic
mlp_batched_training_mode_dynamic_f64
mlp_batched_training_mode
mlp_batched_training_mode_f64
v0.7.0
SimpleLinearExample A simple linear layer example using Equinox. simple_linear_dynamic
simple_linear_dynamic_f64
simple_linear
simple_linear_f64
nn_linear_dynamic
nn_linear_dynamic_f64
nn_linear
nn_linear_f64
v0.7.0
GPT A simple GPT model that reuses nnx.MultiHeadAttention. gpt_dynamic
gpt
v0.7.0
GPT_CausalSelfAttention A causal self-attention module. causal_self_attention_dynamic
causal_self_attention
v0.7.0
GPT_Embeddings Combines token and position embeddings with dropout. gpt_embeddings_dynamic
gpt_embeddings
v0.7.0
GPT_Head The head of the GPT model. gpt_head_dynamic
gpt_head
v0.7.0
GPT_MLP An MLP block with GELU activation from nanoGPT. gpt_mlp_dynamic
gpt_mlp
v0.7.0
GPT_PositionEmbedding A positional embedding layer using nnx.Embed. position_embedding v0.7.0
GPT_TokenEmbedding A token embedding layer using nnx.Embed. token_embedding_dynamic
token_embedding
v0.7.0
GPT_TransformerBlock A transformer block combining attention and MLP. gpt_block_dynamic
gpt_block
v0.7.0
GPT_TransformerStack A stack of transformer blocks. transformer_stack_dynamic
transformer_stack
v0.7.0
cfl_timestep Tests the CFL condition timestep calculation. cfl_timestep_f64 v0.6.5
weno_reconstruction Tests the complex arithmetic pattern found in WENO schemes. weno_reconstruction_f64 v0.6.5
fori_loop_test Demonstrates jax.lax.fori_loop with a simple loop. fori_loop_test
fori_loop_test_f64
v0.6.3
issue18_abs Test jnp.abs from issue 18 abs_fn
abs_fn_f64
v0.6.3
issue18_arange Test arange from issue 18 arange_fn
arange_fn_f64
v0.6.3
issue18_fori_loop Test fori_loop from issue 18 fori_loop_fn
fori_loop_fn_f64
v0.6.3
issue18_linspace Test linspace from issue 18 linspace_fn
linspace_fn_f64
v0.6.3
issue18_scan Test scan from issue 18 (no xs) scan_fn
scan_fn_f64
v0.6.3
issue18_sign Test jnp.sign from issue 18 sign_fn
sign_fn_f64
v0.6.3
issue18_where Test where from issue 18 where_fn
where_fn_f64
v0.6.3
issue18_while_loop Test while_loop from issue 18 while_loop_fn
while_loop_fn_f64
v0.6.3
select_test Demonstrates jnp.select with a dynamic condition based on an input array. select_test_all_options
select_test_scalar_select_option_0
select_test_scalar_select_option_1
select_test_scalar_select_option_2
select_test_default_case
v0.6.1
sort_test Demonstrates jnp.sort on slices of an input array. sort_test_basic v0.6.1
cond_scatter_add_mul Tests scatter_add/mul inside jnp.where branches cond_scatter_add_mul_f64 v0.6.4
cond_scatter_repro Reproduces a bug where lax.cond subgraphs do not inherit parent initializers. cond_scatter_repro_f64 v0.6.4
remat2 Tests a simple case of jax.checkpoint (also known as jax.remat), which lowers to the remat2 primitive. checkpoint_scalar_f32
checkpoint_scalar_f32_f64
v0.6.5
AutoEncoder A simple autoencoder example. simple_autoencoder
simple_autoencoder_f64
v0.2.0
CNN A simple convolutional neural network (CNN). simple_cnn_explicit_dimensions
simple_cnn_explicit_dimensions_f64
simple_cnn_dynamic
simple_cnn_dynamic_f64
simple_cnn
simple_cnn_f64
v0.1.0
CNN2 A CNN with a while_loop. simple_cnn_inference_dynamic
simple_cnn_inference_dynamic_f64
simple_cnn_inference
simple_cnn_inference_f64
v0.6.5
ForiLoop fori_loop example fori_loop_counter
fori_loop_counter_f64
v0.5.1
MLP A simple Multi-Layer Perceptron (MLP) with BatchNorm, Dropout, and GELU activation. simple_mlp_dynamic
simple_mlp_dynamic_f64
simple_mlp
simple_mlp_f64
simple_mlp_with_call_params_dynamic
simple_mlp_with_call_params_dynamic_f64
simple_mlp_with_call_params
simple_mlp_with_call_params_f64
v0.1.0
MultiHeadAttention A multi-head attention module implemented by Flax (NNX) that has no ONNX counterpart at the same granularity. multihead_attention_nn_dynamic
multihead_attention_nn_dynamic_f64
multihead_attention_nn
multihead_attention_nn_f64
multihead_attention_nnx_dynamic
multihead_attention_nnx_dynamic_f64
multihead_attention_nnx
multihead_attention_nnx_f64
v0.2.0
onnx_functions_000 one function on an outer layer. 000_one_function_on_outer_layer_dynamic
000_one_function_on_outer_layer_dynamic_f64
000_one_function_on_outer_layer
000_one_function_on_outer_layer_f64
v0.4.0
onnx_functions_001 one function on an inner layer. 001_one_function_inner_dynamic
001_one_function_inner_dynamic_f64
001_one_function_inner
001_one_function_inner_f64
v0.4.0
onnx_functions_002 two nested functions. 002_two_nested_functions_dynamic
002_two_nested_functions_dynamic_f64
002_two_nested_functions
002_two_nested_functions_f64
v0.4.0
onnx_functions_003 two nested functions. 003_two_simple_nested_functions_dynamic
003_two_simple_nested_functions_dynamic_f64
003_two_simple_nested_functions
003_two_simple_nested_functions_f64
v0.4.0
onnx_functions_004 nested function plus component 004_nested_function_plus_component_dynamic
004_nested_function_plus_component_dynamic_f64
004_nested_function_plus_component
004_nested_function_plus_component_f64
v0.4.0
onnx_functions_005 nested function plus more components 005_nested_function_plus_component_dynamic
005_nested_function_plus_component_dynamic_f64
005_nested_function_plus_component
005_nested_function_plus_component_f64
v0.4.0
onnx_functions_006 one function on an outer layer. 006_one_function_outer_dynamic
006_one_function_outer_dynamic_f64
006_one_function_outer
006_one_function_outer_f64
v0.4.0
onnx_functions_007 transformer block with nested mlp block with call parameter 007_transformer_block_dynamic
007_transformer_block_dynamic_f64
007_transformer_block
007_transformer_block_f64
v0.4.0
onnx_functions_008 transformer block with nested mlp block no call parameter 008_transformer_block_dynamic
008_transformer_block_dynamic_f64
008_transformer_block
008_transformer_block_f64
v0.4.0
onnx_functions_009 transformer block using decorator on class and function 009_transformer_block_dynamic
009_transformer_block_dynamic_f64
009_transformer_block
009_transformer_block_f64
v0.4.0
onnx_functions_010 transformer stack 010_transformer_stack_dynamic
010_transformer_stack_dynamic_f64
010_transformer_stack
010_transformer_stack_f64
v0.4.0
onnx_functions_012 Vision Transformer (ViT) 012_vit_conv_embedding_dynamic
012_vit_conv_embedding_dynamic_f64
012_vit_conv_embedding
012_vit_conv_embedding_f64
v0.4.0
onnx_functions_013 Vision Transformer (ViT) 013_vit_conv_embedding_with_call_params_dynamic
013_vit_conv_embedding_with_call_params_dynamic_f64
013_vit_conv_embedding_with_call_params
013_vit_conv_embedding_with_call_params_f64
013_vit_conv_embedding_with_internal_call_params_dynamic
013_vit_conv_embedding_with_internal_call_params_dynamic_f64
013_vit_conv_embedding_with_internal_call_params
013_vit_conv_embedding_with_internal_call_params_f64
v0.4.0
onnx_functions_014 one function on an outer layer. 014_one_function_with_input_param_with_default_value
014_one_function_with_input_param_with_default_value_f64
014_one_function_without_input_param_with_default_value_dynamic
014_one_function_without_input_param_with_default_value_dynamic_f64
014_one_function_without_input_param_with_default_value
014_one_function_without_input_param_with_default_value_f64
v0.4.0
onnx_functions_015 one function on an outer layer. 015_one_function_with_input_param_without_default_value_dynamic
015_one_function_with_input_param_without_default_value_dynamic_f64
015_one_function_with_input_param_without_default_value
015_one_function_with_input_param_without_default_value_f64
v0.4.0
onnx_functions_016 nested function plus more components 016_internal_function_with_input_param_with_default_value_dynamic
016_internal_function_with_input_param_with_default_value_dynamic_f64
016_internal_function_with_input_param_with_default_value
016_internal_function_with_input_param_with_default_value_f64
v0.4.0
ClassificationHead Classification head for Vision Transformer classification_head_dynamic
classification_head_dynamic_f64
classification_head
classification_head_f64
v0.4.0
ClassificationHeadFlatten Classification head for Vision Transformer classification_head_flat_dynamic
classification_head_flat_dynamic_f64
classification_head_flat
classification_head_flat_f64
v0.4.0
ConcatClsToken Concatenate CLS token to the input embedding concat_cls_token_dynamic
concat_cls_token_dynamic_f64
concat_cls_token
concat_cls_token_f64
v0.4.0
ConcatClsTokenFlatten Concatenate CLS token to the input embedding concat_cls_token_flat_dynamic
concat_cls_token_flat_dynamic_f64
concat_cls_token_flat
concat_cls_token_flat_f64
v0.4.0
ConvEmbedding Convolutional Token Embedding for MNIST with hierarchical downsampling. mnist_conv_embedding_dynamic
mnist_conv_embedding_dynamic_f64
mnist_conv_embedding
mnist_conv_embedding_f64
v0.1.0
ConvEmbeddingFlatten Convolutional Token Embedding for MNIST with hierarchical downsampling. mnist_conv_embedding_flat_dynamic
mnist_conv_embedding_flat_dynamic_f64
mnist_conv_embedding_flat
mnist_conv_embedding_flat_f64
v0.1.0
FeedForward MLP in Transformer feed_forward_dynamic
feed_forward_dynamic_f64
feed_forward
feed_forward_f64
v0.1.0
FeedForwardFlatten MLP in Transformer feed_forward_flat_dynamic
feed_forward_flat_dynamic_f64
feed_forward_flat
feed_forward_flat_f64
v0.1.0
GetToken Get the CLS token from the input embedding get_token_dynamic
get_token_dynamic_f64
get_token
get_token_f64
v0.4.0
GetTokenFlatten Get the CLS token from the input embedding get_token_flat_dynamic
get_token_flat_dynamic_f64
get_token_flat
get_token_flat_f64
v0.4.0
PatchEmbedding Cutting the image into patches and linearly embedding them. patch_embedding_dynamic
patch_embedding_dynamic_f64
patch_embedding
patch_embedding_f64
v0.1.0
PatchEmbeddingFlatten Cutting the image into patches and linearly embedding them. patch_embedding_flat_dynamic
patch_embedding_flat_dynamic_f64
patch_embedding_flat
patch_embedding_flat_f64
v0.1.0
PositionalEmbedding Add positional embedding to the input embedding positional_embedding_dynamic
positional_embedding_dynamic_f64
positional_embedding
positional_embedding_f64
v0.4.0
PositionalEmbeddingFlatten Add positional embedding to the input embedding positional_embedding_flat_dynamic
positional_embedding_flat_dynamic_f64
positional_embedding_flat
positional_embedding_flat_f64
v0.4.0
TransformerBlock Transformer from 'Attention Is All You Need.' transformer_block_dynamic
transformer_block_dynamic_f64
transformer_block
transformer_block_f64
v0.1.0
TransformerBlockFlatten Transformer from 'Attention Is All You Need.' transformer_block_flat_dynamic
transformer_block_flat_dynamic_f64
transformer_block_flat
transformer_block_flat_f64
v0.1.0
TransformerStack Stack of Transformer blocks transformer_stack_dynamic
transformer_stack_dynamic_f64
transformer_stack
transformer_stack_f64
v0.1.0
TransformerStackFlatten Stack of Transformer blocks transformer_stack_flat_dynamic
transformer_stack_flat_dynamic_f64
transformer_stack_flat
transformer_stack_flat_f64
v0.1.0
VisionTransformer A Vision Transformer (ViT) model for MNIST with configurable embedding type. vit_conv_embedding_dynamic
vit_conv_embedding_dynamic_f64
vit_conv_embedding
vit_conv_embedding_f64
vit_patch_embedding
vit_patch_embedding_f64
v0.2.0
VisionTransformerFlatten A Vision Transformer (ViT) model for MNIST with configurable embedding type. vit_conv_embedding_flat_dynamic
vit_conv_embedding_flat_dynamic_f64
vit_conv_embedding_flat
vit_conv_embedding_flat_f64
vit_patch_embedding_flat_dynamic
vit_patch_embedding_flat_dynamic_f64
vit_patch_embedding_flat
vit_patch_embedding_flat_f64
v0.2.0

📌 Dependencies

Versions of Major Dependencies:

Library | Version
JAX | 0.6.2
Flax | 0.10.6
onnx | 1.18.0
onnxruntime | 1.22.0

Note: For more details, check pyproject.toml.


⚠️ Limitations

  • Currently not all JAX/Flax components are supported (you can easily help expand this coverage!).
  • Function references need dynamic resolution at call-time.
  • ONNX graph composition is done in-memory before saving to disk, potentially causing memory issues with very large models.

🤝 How to Contribute

We warmly welcome contributions!

How you can help:

  • Add a plugin: Extend jax2onnx by writing a simple Python file in jax2onnx/plugins that supports a custom primitive or adds an example.
  • Bug fixes & improvements: PRs and issues are always welcome.

💾 Installation

Install from PyPI:

pip install jax2onnx  

📜 License

This project is licensed under the Apache License, Version 2.0. See LICENSE for details.


🌟 Special Thanks

Special thanks for example contributions to @burakssen, @Cadynum, and @clementpoiret.

Special thanks for plugin contributions to @burakssen.

Special thanks to tumaer/JAXFLUIDS for contributing valuable insights rooted in physics simulation use cases.

Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.

A huge thanks especially to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.


Happy converting! 🎉
