Releases: enpasos/jax2onnx

0.12.0

09 Feb 19:20

jax2onnx 0.12.0 – Layout controls, opset 23 default, and regression hardening

  • NCHW boundary layout support:
    Added inputs_as_nchw / outputs_as_nchw options to to_onnx(...) and allclose(...), with layout-optimization docs/tests and transpose-cleanup improvements for Conv-heavy graphs (see the usage sketch below).

  • Depth-to-space and residual-stack coverage:
    Added dm_pix.depth_to_space lowering to ONNX DepthToSpace and expanded NNX regression examples/tests for depth-to-space and nested residual groups.

  • Primitive and IR improvements:
    Added jax.numpy.mean lowering to ReduceMean; fixed symbolic dim_as_value handling; and stabilized dynamic reshape folding used by CLIP/MaxText exports.

  • ONNX opset 23 path for attention models:
    Added opset >= 23 RotaryEmbedding/Attention support and made opset 23 the default in to_onnx(...).

  • Gather/scatter regression fixes:
    Fixed scatter-add broadcast window handling and issue #52 lowering edge cases; fixed gather indexing and vmap(dynamic_slice_in_dim) gather lowering regressions.

  • Compatibility refresh:
    Expanded tested Python versions to 3.11-3.14 and updated runtime dependency floors (onnx, onnxruntime, dm-pix) for the new paths.
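
  • Usage sketch (hedged):
    A minimal sketch of the new boundary-layout controls. It assumes the to_onnx(fn, [shape, ...]) calling convention, that inputs_as_nchw / outputs_as_nchw take lists of ONNX tensor names, and an explicit opset argument; the tensor names and the toy conv below are illustrative, not verbatim jax2onnx API.

      import jax
      import jax.numpy as jnp
      from jax2onnx import to_onnx

      def conv_block(x):                       # x arrives NHWC, as the JAX code expects
          w = jnp.ones((3, 3, 3, 8), x.dtype)  # HWIO kernel, 3 -> 8 channels
          return jax.lax.conv_general_dilated(
              x, w, window_strides=(1, 1), padding="SAME",
              dimension_numbers=("NHWC", "HWIO", "NHWC"),
          )

      onnx_model = to_onnx(
          conv_block,
          [("B", 32, 32, 3)],            # NHWC example shape with a symbolic batch dim
          inputs_as_nchw=["input_0"],    # expose this ONNX input in NCHW (name assumed)
          outputs_as_nchw=["output_0"],  # likewise for the matching output
          opset=23,                      # opset 23 is now the default; spelled out here
      )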

0.11.2

26 Jan 09:15

jax2onnx 0.11.2 – Native onnx-ir cloning, unified passes & sturdier MaxDiffusion exports

  • Native onnx-ir graph cloning:
    Replaced the custom ir_clone implementation with onnx-ir’s native Graph.clone() method, improving maintainability and leveraging upstream validation (PR #162, #163).

  • Unified IR pass infrastructure:
    Streamlined the optimization pipeline by adopting standard onnx-ir passes and removing redundant custom pass logic (PR #151).

  • MaxDiffusion robustness:
    Fixed environment-dependent crashes (UnboundLocalError) and corrected type annotations in the MaxDiffusion plugin stack.

0.11.1

13 Jan 05:57

jax2onnx 0.11.1 – MaxText model family coverage & cleaner exported graphs

  • Comprehensive MaxText example stack:
    Added a MaxText example and test suite covering exports for the DeepSeek, Gemma, GPT-3, Kimi, Llama, Mistral, and Qwen model families.

  • MaxText stubs & new primitive coverage:
    Introduced MaxText dependency stubs and implemented new primitive support required to enable those exports end-to-end.

  • Cleaner ONNX graphs via stricter subgraph cleanup:
    Tightened subgraph cleanup so exported ONNX graphs are more minimal, with less leftover/unused substructure after export.

0.11.0

29 Dec 15:03

jax2onnx 0.11.0 – Flax Linen support & modern IR optimization

  • Initial Flax Linen support:
    Added first-class export support for Flax Linen, covering:
    • Core layers: Dense, DenseGeneral, Conv*, pooling
    • Normalization stacks: BatchNorm, LayerNorm, GroupNorm, RMSNorm, InstanceNorm
    • Dropout, Einsum/Embed, and spectral/weight-norm wrappers
    • Extended activation coverage: GELU, GLU family, hard_, log_, relu6, silu/swish, tanh, normalize, one_hot
    • The full Linen attention stack: dot_product_attention, mask helpers, SelfAttention, MultiHeadDotProductAttention, MultiHeadAttention
    • Recurrent modules: SimpleCell, GRUCell, MGUCell, LSTMCell, OptimizedLSTMCell, ConvLSTMCell, RNN, Bidirectional

  • Linen examples:
    Added end-to-end Linen reference examples (MLP, CNN, Sequential) to validate export coverage and serve as usage templates (see the usage sketch below).

  • IR optimizer modernization:
    Modernized the IR optimization pipeline by adopting the standard onnx-ir CSE pass, removing legacy helpers and getattr-based patterns, and simplifying tests via direct graph iteration.
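
  • Usage sketch (hedged):
    A minimal sketch of exporting a Linen module, assuming the to_onnx(fn, [shape, ...]) calling convention; the MLP and the params-binding lambda are illustrative rather than copied from the shipped examples.

      import jax
      import jax.numpy as jnp
      import flax.linen as nn
      from jax2onnx import to_onnx

      class MLP(nn.Module):
          @nn.compact
          def __call__(self, x):
              x = nn.relu(nn.Dense(64)(x))
              return nn.Dense(10)(x)

      mlp = MLP()
      params = mlp.init(jax.random.key(0), jnp.ones((1, 32)))

      # Bind the params so to_onnx sees a pure function of the data input.
      onnx_model = to_onnx(lambda x: mlp.apply(params, x), [("B", 32)])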

0.10.4

22 Dec 19:08

jax2onnx 0.10.4 – vmap fixes, IR optimizations & CI

  • vmap batching fixes:
    Corrected batching for jax.numpy.reshape, transpose, and several other primitives.

  • IR optimizer refactor:
    Switched to onnx-ir public APIs (value.consumers(), graph.remove()), removing internal helpers.

  • New IR passes:
    Added Common Subexpression Elimination (CSE) and Constant Lifting in ir_optimizations.py.

  • CI:
    Added GitHub Actions for automated testing.

0.10.3

09 Dec 19:18

jax2onnx 0.10.3 – DINOv3 NNX stack & example registry hygiene

  • Flax/NNX DINOv3 VisionTransformer stack:

    • Added a DINOv3 VisionTransformer example (plugins/examples/nnx/dinov3.py) with deterministic RNG helpers, rotary-cache capture, and expect_graph coverage across multiple ViT variants.
    • Documented NNX DINO exports for static and dynamic batch sizes and kept generated ONNX artifacts out of git via a dedicated .gitignore.
  • Equinox→NNX DINO parity:

    • Introduced Equinox→NNX DINOv3 parity testing (weight copy + forward-output check) to keep the Equinox and Flax/NNX export paths aligned.
  • Example registry hygiene:

    • Made example registry keys context-aware (context::component) and added override warnings to avoid collisions between example stacks.

0.10.2

01 Dec 14:32

jax2onnx 0.10.2 – GPT-OSS export, symbolic dims, typing overhaul

  • GPT-OSS export stack: Equinox + Flax/NNX reference modules, parity harnesses, exporter scripts, and docs so the open MoE transformer can be converted end-to-end and validated numerically.

  • New ops & masking semantics:

    • New primitive coverage: lax.top_k, lax.rsqrt, and Equinox RMSNorm lowerings, all landing with tests.
    • Masked softmax now lowers where-masked calls to ONNX Softmax + Where, explicitly zeroing masked positions.
  • Control-flow & scatter correctness:

    • Scatter operations embedded in cond / scan now preserve ONNX-compliant initializers and dtypes via the reworked index helpers, fixing the regressions tracked in issue #139.
  • Symbolic dimensions & IR determinism:

    • Symbolic-dimension support strengthened via DimExpr lowering and shape-polynomial helpers to stabilize broadcast / loop / gather shapes.
    • IR return-mode / input_param materialization fixed (and legacy serde_onnx removed) so IR-only output stays deterministic.
  • Typing & tooling:

    • Typing overhaul with shared typing_support protocols, stricter mypy coverage, and helper scripts (check_typing.sh, report_rng_traces.py).
    • Flax NNX compatibility tweaks for the newer Linear / Einsum parameter access patterns.
  • Dependencies:

    • Dependency stack bumped to JAX 0.8.1 / Flax 0.12.1 with corresponding NNX plugin updates.

0.10.1

06 Nov 13:45

jax2onnx 0.10.1 – Complex numbers, stacktrace metadata, op polish

  • Complex numbers: unified packed layout ([..., 2]), a helper stack for packing/unpacking & dtype reconciliation, and the first wave of plugin coverage (see the usage sketch below):

    • Elementwise: lax/jnp add / sub / mul / div
    • Conjugation: lax.conj, jnp.conj
    • Bilinear ops: lax.dot_general, jnp.matmul, lax.conv_general_dilated
    • Transforms: FFT via an ONNX-compliant DFT lowering
    • Docs: see docs/dev_guides/complex_numbers.md
  • Observability & debugging: stacktrace metadata toggles (pkg.jax2onnx.callsite / pkg.jax2onnx.plugin) with optional full Python/JAX traces.

  • Ops & lowerings:

    • lax.dot_general: add Einsum fallback.
    • lax.broadcast_in_dim: keep constant folding on handler infrastructure, preserve loop-extent metadata, and always emit Expand for deterministic IR.
    • lax.reduce_window_sum: new Conv-based lowering that handles strides, window dilation, integer operands (via cast wrappers), and static base-dilation expansion.
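
  • Usage sketch (hedged):
    A sketch of the kind of complex-valued function this coverage targets, assuming the to_onnx(fn, [shape, ...]) calling convention; only the packed [..., 2] boundary layout is taken from the notes above, the rest is illustrative.

      import jax.numpy as jnp
      from jax2onnx import to_onnx

      def spectral_gain(x):
          # Built from the ops listed above: FFT, complex multiply, conjugation.
          freq = jnp.fft.fft(x.astype(jnp.complex64))
          return jnp.conj(freq * (0.5 + 0.5j))

      # At the ONNX boundary, complex tensors travel in the packed [..., 2]
      # (real, imag) layout; the call itself is an assumed invocation.
      onnx_model = to_onnx(spectral_gain, [("B", 128)])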

0.10.0

24 Oct 13:09

jax2onnx 0.10.0 – Equinox/DINOv3 plugins, unique function reuse, JAX 0.8.0

  • Expanded Equinox coverage for a DINOv3 exporter:

    • New plugins: equinox/eqx/nn/conv.py, multihead_attention.py, rotary_positional_embedding.py
    • Example: plugins/examples/eqx/dino.py
  • Added lowering helpers and plugins:

    • _axis0_utils.py, _loop_extent_meta.py, jax/lax/gather_compile.py, jax/lax/gather_helpers.py, jax/image/resize.py, jax/numpy/outer.py
  • Rewrote and extended existing plugins—especially jax.lax control-flow & scatter/gather (incl. while_loop), jax.numpy batching ops (arange, reshape, split, stack, tile, where), and jax.nn activations/initializers—improving metadata, axis handling, and ONNX parity.

  • @onnx_function: declare a function once and reuse it by passing the optional unique=True (see the usage sketch below).

  • IR builder refactor: live graph proxies and a reusable clone_graph keep function/loop subgraphs detached and eliminate cross-graph ownership errors.

  • Dependency updates: JAX 0.8.0, onnx-ir 0.1.11.
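
  • Usage sketch (hedged):
    A sketch of the unique=True reuse pattern, assuming @onnx_function accepts keyword arguments and the to_onnx(fn, [shape, ...]) calling convention; the decorated block is a made-up example.

      import jax.numpy as jnp
      from jax2onnx import onnx_function, to_onnx

      @onnx_function(unique=True)          # declare once, reuse per these notes
      def smooth(x):
          return jnp.tanh(x) * 0.5 + 0.5   # hypothetical reusable sub-block

      def model(x):
          return smooth(smooth(x))         # two call sites for the same function

      onnx_model = to_onnx(model, [("B", 16)])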

0.9.0

12 Oct 17:37

jax2onnx 0.9.0 – ONNX-IR core, return_mode, JAX 0.7.2

  • Migrated from a proto-based ONNX representation to an IR-based one, significantly reducing peak memory during conversion, especially on large models.

  • Added return_mode to to_onnx (see the usage sketch below):

    • "proto" (default) → returns an onnx.ModelProto
    • "ir" → returns the intermediate onnx_ir.Model
    • "file" → serializes directly to disk (faster than “proto” → external save).
  • Dependency updates: JAX 0.7.2, Flax 0.12.0 (requires Python ≥3.11), Equinox 0.13.2, onnx-ir 0.1.10, onnx 1.19.1.
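
  • Usage sketch (hedged):
    A sketch of the three return_mode values, assuming the to_onnx(fn, [shape, ...]) calling convention; everything beyond the documented mode strings is illustrative.

      from jax2onnx import to_onnx

      def double_plus_one(x):
          return x * 2.0 + 1.0

      proto = to_onnx(double_plus_one, [("B", 8)])                       # default: onnx.ModelProto
      ir_model = to_onnx(double_plus_one, [("B", 8)], return_mode="ir")  # onnx_ir.Model
      # return_mode="file" serializes straight to disk instead of returning a model
      # in memory; the argument naming the target file is not shown here.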