What's Changed
- Set numpy<2.4 to fix DeprecationWarning in the CI doctest by @vfdev-5 in #5163
- ignore optax deprecation warning by @cgarciae in #5165
- fix general guides landing page by @ivrolan in #5139
- Remove nnx.split_rngs call wrapping nnx.scan in linen to nnx tutorial by @samanklesaria in #5160
- Add out_sharding arguments to linear layers where supported by @jackopenn in #5156
- fix kw_only_dataclasses for python 3.14 (part 2) by @copybara-service[bot] in #5135
- No public description by @copybara-service[bot] in #5164
- Have _graph_flatten respect nnx.data declarations (extension of #5140) by @samanklesaria in #5159
- [pmap] Avoid degraded performance under the new jax.pmap. by @copybara-service[bot] in #5152
- Support multiple None and UNCONSTRAINED when resolving logical rules by @copybara-service[bot] in #5129
- improve hijax guide by @cgarciae in #5115
- docs: fix typo 'paramater' -> 'parameter' by @ayulockedin in #5166
- Make nnx.pop remove sown attributes. by @samanklesaria in #5133
- Rename sharding_names to sharding_metadata by @samanklesaria in #5089
- Fix bug in graph overhead benchmark by @samanklesaria in #5183
- Fixed typos in the docstrings using antigravity by @vfdev-5 in #5145
- docs(nnx): add missing functional args to Conv and LinearGeneral by @ayulockedin in #5174
- empty change by @copybara-service[bot] in #5184
- Use nnx split during tabulate (clone of #5069) by @samanklesaria in #5186
- Handle pure bodies in nnx.fori_loop by @samanklesaria in #5141
- Typo fix in _cached_partial method by @vfdev-5 in #5142
- Update mnist example to use NNX (clone of #5064) by @samanklesaria in #5188
- Docs: Fix typo and clarify introduction in Functional API section by @Moriyuki-S in #5157
- Shard split rngs in lifted vmap if spmd_axis_name is given and applied to the vmapped axis. Without this, jax.vmap raises "ValueError: Mapped away dimension of inputs passed to vmap should be sharded the same. Got inconsistent axis specs: None vs batch" because split_rngs are replicated. by @copybara-service[bot] in #5189
- Remove ref cycles introduced by self-calling nested functions. by @copybara-service[bot] in #5193
- Add HijaxTransformCoverageTest by @cgarciae in #5190
- allow nnx standalone import p1 by @copybara-service[bot] in #5196
- Add _graph_node_set_key method for List class by @samanklesaria in #5171
- _apply_sharding disallow mixed Explicit/Auto mesh by @copybara-service[bot] in #5199
- update flax to version 0.12.3 by @copybara-service[bot] in #5206
New Contributors
- @ivrolan made their first contribution in #5139
- @jackopenn made their first contribution in #5156
- @ayulockedin made their first contribution in #5166
- @Moriyuki-S made their first contribution in #5157
Full Changelog: v0.12.2...v0.12.3