Releases: enpasos/jax2onnx
0.12.0
jax2onnx 0.12.0 – Layout controls, opset 23 defaults, and regression hardening
- NCHW boundary layout support: Added `inputs_as_nchw`/`outputs_as_nchw` for `to_onnx(...)` and `allclose(...)`, with layout-optimization docs/tests and transpose-cleanup improvements for Conv-heavy graphs (see the sketch after this list).
- Depth-to-space and residual-stack coverage: Added `dm_pix.depth_to_space` lowering to ONNX `DepthToSpace` and expanded NNX regression examples/tests for depth-to-space and nested residual groups.
- Primitive and IR improvements: Added `jax.numpy.mean` lowering to `ReduceMean`; fixed symbolic `dim_as_value` handling; and stabilized dynamic reshape folding used by CLIP/MaxText exports.
- ONNX opset 23 path for attention models: Added opset >= 23 `RotaryEmbedding`/`Attention` support and made opset 23 the default in `to_onnx(...)`.
- Gather/scatter regression fixes: Fixed scatter-add broadcast window handling and issue #52 lowering edge cases; fixed gather indexing and `vmap(dynamic_slice_in_dim)` gather lowering regressions.
- Compatibility refresh: Expanded tested Python versions to 3.11-3.14 and updated runtime dependency floors (`onnx`, `onnxruntime`, `dm-pix`) for the new paths.
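
A minimal usage sketch for the new layout flags, assuming the README-style `to_onnx(callable, [input_shapes])` call shape; the conv callable, tensor names, and the list-of-names value format are placeholders, not taken from the release notes:

```python
# Hypothetical sketch: export an NHWC JAX callable while exposing NCHW tensors
# at the ONNX boundary via the new inputs_as_nchw / outputs_as_nchw flags.
import jax.numpy as jnp
from jax import lax
from jax2onnx import to_onnx

def nhwc_conv(x):
    # Placeholder NHWC conv with a fixed 3x3 kernel; a real model would go here.
    kernel = jnp.ones((3, 3, 3, 8), dtype=x.dtype)  # HWIO layout
    return lax.conv_general_dilated(
        x, kernel, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )

onnx_model = to_onnx(
    nhwc_conv,
    [("B", 224, 224, 3)],            # symbolic batch, NHWC example input
    inputs_as_nchw=["input_0"],      # assumed: names of inputs exposed as NCHW
    outputs_as_nchw=["output_0"],    # assumed: names of outputs exposed as NCHW
)
```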
0.11.2
jax2onnx 0.11.2 – Native onnx-ir cloning, unified passes & sturdier MaxDiffusion exports
- Native `onnx-ir` graph cloning: Replaced the custom `ir_clone` implementation with `onnx-ir`'s native `Graph.clone()` method, improving maintainability and leveraging upstream validation (PR #162, #163).
- Unified IR pass infrastructure: Streamlined the optimization pipeline by adopting standard `onnx-ir` passes and removing redundant custom pass logic (PR #151).
- MaxDiffusion robustness: Fixed environment-dependent crashes (`UnboundLocalError`) and corrected type annotations in the MaxDiffusion plugin stack.
0.11.1
jax2onnx 0.11.1 – MaxText model family coverage & cleaner exported graphs
- Comprehensive MaxText example stack: Added a comprehensive MaxText example + test suite covering exports for the DeepSeek, Gemma, GPT-3, Kimi, Llama, Mistral, and Qwen model families.
- MaxText stubs & new primitive coverage: Introduced MaxText dependency stubs and implemented the new primitive support required to enable those exports end-to-end.
- Cleaner ONNX graphs via stricter subgraph cleanup: Tightened subgraph cleanup to produce cleaner, more minimal ONNX graphs (less leftover/unused substructure after export).
0.11.0
jax2onnx 0.11.0 – Flax Linen support & modern IR optimization
- Initial Flax Linen support: Added first-class export support for Flax Linen, including core layers (`Dense`, `DenseGeneral`, `Conv*`, pooling), normalization stacks (BatchNorm, LayerNorm, GroupNorm, RMSNorm, InstanceNorm), Dropout, Einsum/Embed, spectral/weight-norm wrappers, extended activation coverage (GELU, the GLU family, the hard_* and log_* variants, relu6, silu/swish, tanh, normalize, one_hot), the full Linen attention stack (dot_product_attention, mask helpers, `SelfAttention`, `MultiHeadDotProductAttention`, `MultiHeadAttention`), and recurrent modules (`SimpleCell`, `GRUCell`, `MGUCell`, `LSTMCell`, `OptimizedLSTMCell`, `ConvLSTMCell`, `RNN`, `Bidirectional`).
- Linen examples: Added end-to-end Linen reference examples (MLP, CNN, Sequential) to validate export coverage and serve as usage templates (see the sketch after this list).
- IR optimizer modernization: Modernized the IR optimization pipeline by adopting the standard onnx-ir CSE pass, removing legacy helpers and `getattr`-based patterns, and simplifying tests via direct graph iteration.
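
A minimal Linen export sketch, assuming Linen models are exported like any other JAX callable (here as an apply-closure over initialized params) with the README-style `to_onnx` call; the module, shapes, and closure pattern are illustrative:

```python
# Hypothetical sketch: export a small Flax Linen MLP. Only the Linen layer
# coverage is from the release notes; the export call shape is an assumption.
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax2onnx import to_onnx

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        return nn.Dense(10)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 32)))

# Export the bound apply function with a symbolic batch dimension.
onnx_model = to_onnx(lambda x: model.apply(params, x), [("B", 32)])
```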
0.10.4
jax2onnx 0.10.4 – vmap fixes, IR optimizations & CI
- `vmap` batching fixes: Corrected batching for `jax.numpy.reshape`, `transpose`, and several other primitives (see the sketch after this list).
- IR optimizer refactor: Switched to onnx-ir public APIs (`value.consumers()`, `graph.remove()`), removing internal helpers.
- New IR passes: Added Common Subexpression Elimination (CSE) and Constant Lifting in `ir_optimizations.py`.
- CI: Added GitHub Actions for automated testing.
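
An illustrative JAX-side pattern of the kind whose batched lowering was fixed here; the shapes and the `to_onnx` call shape are placeholders, not taken from the release notes:

```python
# Illustrative: a vmapped reshape/transpose now lowers correctly to ONNX.
import jax
import jax.numpy as jnp
from jax2onnx import to_onnx

def per_example(x):                  # x: (4, 6)
    return jnp.transpose(jnp.reshape(x, (2, 12)))

batched = jax.vmap(per_example)      # maps over a leading batch dimension
onnx_model = to_onnx(batched, [("B", 4, 6)])
```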
0.10.3
jax2onnx 0.10.3 – DINOv3 NNX stack & example registry hygiene
- Flax/NNX DINOv3 VisionTransformer stack:
  - Added a DINOv3 VisionTransformer example (`plugins/examples/nnx/dinov3.py`) with deterministic RNG helpers, rotary-cache capture, and `expect_graph` coverage across multiple ViT variants.
  - Documented NNX DINO exports for static and dynamic batch sizes and kept generated ONNX artifacts out of git via a dedicated `.gitignore`.
- Equinox→NNX DINO parity: Introduced Equinox→NNX DINOv3 parity testing (weight copy + forward-output check) to keep the Equinox and Flax/NNX export paths aligned.
- Example registry hygiene: Made example registry keys context-aware (`context::component`) and added override warnings to avoid collisions between example stacks.
0.10.2
jax2onnx 0.10.2 – GPT-OSS export, symbolic dims, typing overhaul
- GPT-OSS export stack: Equinox + Flax/NNX reference modules, parity harnesses, exporter scripts, and docs so the open MoE transformer can be converted end-to-end and validated numerically.
- New ops & masking semantics:
  - New primitive coverage: `lax.top_k`, `lax.rsqrt`, and Equinox `RMSNorm` lowerings, all landing with tests.
  - Masked softmax now lowers `where`-masked calls to ONNX `Softmax` + `Where`, explicitly zeroing masked positions (see the sketch after this list).
- Control-flow & scatter correctness: Scatter operations embedded in `cond`/`scan` now preserve ONNX-compliant initializers and dtypes via the renewed index helpers, fixing the issue #139 regression suite.
- Symbolic dimensions & IR determinism:
  - Symbolic-dimension support strengthened via `DimExpr` lowering and shape-polynomial helpers to stabilize broadcast/loop/gather shapes.
  - IR return-mode / `input_param` materialization fixed (and legacy `serde_onnx` removed) so IR-only output stays deterministic.
- Typing & tooling:
  - Typing overhaul with shared `typing_support` protocols, stricter mypy coverage, and helper scripts (`check_typing.sh`, `report_rng_traces.py`).
  - Flax NNX compatibility tweaks for the newer `Linear`/`Einsum` parameter access patterns.
- Dependencies: Dependency stack bumped to JAX 0.8.1 / Flax 0.12.1 with corresponding NNX plugin updates.
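
For the masked-softmax lowering, the JAX-side pattern in question looks roughly like the sketch below; the function name, argument names, and mask convention are illustrative:

```python
# Illustrative where-masked softmax: masked positions are pushed to the dtype
# minimum before the softmax and explicitly zeroed afterwards, matching the
# Softmax + Where lowering described above.
import jax
import jax.numpy as jnp

def masked_softmax(logits, mask):
    neg_inf = jnp.finfo(logits.dtype).min
    probs = jax.nn.softmax(jnp.where(mask, logits, neg_inf), axis=-1)
    return jnp.where(mask, probs, 0.0)
```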
0.10.1
jax2onnx 0.10.1 – Complex numbers, stacktrace metadata, op polish
- Complex numbers: unified packed layout (`[..., 2]`), a helper stack for packing/unpacking & dtype reconciliation, and the first wave of plugin coverage (see the sketch after this list):
  - Elementwise: `lax`/`jnp` add/sub/mul/div
  - Conjugation: `lax.conj`, `jnp.conj`
  - Bilinear ops: `lax.dot_general`, `jnp.matmul`, `lax.conv_general_dilated`
  - Transforms: FFT via an ONNX-compliant DFT lowering
  - Docs: see `docs/dev_guides/complex_numbers.md`
- Observability & debugging: stacktrace metadata toggles (`pkg.jax2onnx.callsite`/`pkg.jax2onnx.plugin`) with optional full Python/JAX traces.
- Ops & lowerings:
  - `lax.dot_general`: add an `Einsum` fallback.
  - `lax.broadcast_in_dim`: keep constant folding on handler infrastructure, preserve loop-extent metadata, and always emit `Expand` for deterministic IR.
  - `lax.reduce_window_sum`: new Conv-based lowering that handles strides, window dilation, integer operands (via cast wrappers), and static base-dilation expansion.
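
A minimal FFT export sketch, assuming the README-style `to_onnx` call shape; only the packed `[..., 2]` layout and the DFT-based FFT lowering are from this release, the callable and shapes are placeholders:

```python
# Hypothetical sketch: a function whose intermediate values are complex. The FFT
# is lowered via an ONNX DFT; complex intermediates use the packed [..., 2] real
# layout internally.
import jax.numpy as jnp
from jax2onnx import to_onnx

def spectrum(x):
    return jnp.abs(jnp.fft.fft(x, axis=-1))   # real in, complex inside, real out

onnx_model = to_onnx(spectrum, [("B", 256)])
```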
0.10.0
jax2onnx 0.10.0 – Equinox/DINOv3 plugins, unique function reuse, JAX 0.8.0
- Expanded Equinox coverage for a DINOv3 exporter:
  - New plugins: `equinox/eqx/nn/conv.py`, `multihead_attention.py`, `rotary_positional_embedding.py`
  - Example: `plugins/examples/eqx/dino.py`
- Added lowering helpers and plugins: `_axis0_utils.py`, `_loop_extent_meta.py`, `jax/lax/gather_compile.py`, `jax/lax/gather_helpers.py`, `jax/image/resize.py`, `jax/numpy/outer.py`
- Rewrote and extended existing plugins, especially `jax.lax` control-flow & scatter/gather (incl. `while_loop`), `jax.numpy` batching ops (`arange`, `reshape`, `split`, `stack`, `tile`, `where`), and `jax.nn` activations/initializers, improving metadata, axis handling, and ONNX parity.
- `@onnx_function`: declare once and reuse by passing the optional `unique=True` (see the sketch after this list).
- IR builder refactor: live graph proxies and a reusable `clone_graph` keep function/loop subgraphs detached and eliminate cross-graph ownership errors.
- Dependency updates: JAX 0.8.0, onnx-ir 0.1.11.
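
A minimal sketch of the `unique=True` reuse flag; the decorated body, argument names, and call sites are placeholders, and the decorator's other arguments are not shown:

```python
# Hypothetical sketch: declare an ONNX function once and have every call site
# reuse the same function definition via unique=True (from this release).
import jax.numpy as jnp
from jax2onnx import onnx_function

@onnx_function(unique=True)
def mlp_block(x, w, b):
    return jnp.tanh(x @ w + b)

def model(x, w1, b1, w2, b2):
    # Both calls should map to one shared ONNX function definition.
    return mlp_block(mlp_block(x, w1, b1), w2, b2)
```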
0.9.0
jax2onnx 0.9.0 – ONNX-IR core, return_mode, JAX 0.7.2
- Migrated from a proto-based ONNX representation to an IR-based one, significantly reducing peak memory during conversion, especially on large models.
- Added `return_mode` to `to_onnx` (see the sketch after this list):
  - `"proto"` (default) → returns an `onnx.ModelProto`
  - `"ir"` → returns the intermediate `onnx_ir.Model`
  - `"file"` → serializes directly to disk (faster than "proto" → external save)
- Dependency updates: JAX 0.7.2, Flax 0.12.0 (requires Python ≥3.11), Equinox 0.13.2, onnx-ir 0.1.10, onnx 1.19.1.
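
A minimal `return_mode` sketch, assuming the README-style `to_onnx` call shape; the callable and input spec are placeholders:

```python
# Hypothetical sketch of the return_mode options described above.
import jax.numpy as jnp
from jax2onnx import to_onnx

def fn(x):
    return jnp.tanh(x)

proto = to_onnx(fn, [("B", 32)])                       # default "proto": onnx.ModelProto
ir_model = to_onnx(fn, [("B", 32)], return_mode="ir")  # intermediate onnx_ir.Model
# return_mode="file" serializes straight to disk (the path argument is not shown here).
```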