Releases: google/flax
Version 0.12.2
What's Changed
- [flax:examples:wmt] Small linter fixes. by @copybara-service[bot] in #5012
- [flax:examples:seq2seq] Create main and default config based on seq2seq.ipynb. by @copybara-service[bot] in #5119
- [flax:examples:vae] Small linter fixes. by @copybara-service[bot] in #5014
- [flax:examples:gemma] Fixing linter errors. by @copybara-service[bot] in #5013
- [flax:examples:sst2] Fix pytype errors. by @copybara-service[bot] in #5118
- Allow substring matching in `nnx.PathContains` by @thijs-vanweezel in #5094
- [flax:examples:sst2] Fix notebook error. by @copybara-service[bot] in #5122
- [flax:examples:ppo] Fix some linter / import issues. #jax-fixit by @copybara-service[bot] in #5120
- Avoid passing `concrete` argument to `jax.remat` by @copybara-service[bot] in #5121
- [flax:examples:lm1b_nnx] Update example to work internally. #jax-fixit. by @copybara-service[bot] in #5125
- [flax:examples:nlp_seq] Create a main.py file to run tests with config files to match other examples. #jax-fixit by @copybara-service[bot] in #5126
- [jax:benchmarks] Add tracing/lowering benchmarks for a few flax examples. by @copybara-service[bot] in #4911
- remove abstracted_axes from nnx.jit by @copybara-service[bot] in #5132
- Pooling operation by @jorisSchaller in #5057
- Added is_causal mask argument to flax.nnx.dot_product_attention by @ibbyml in #5093
- Add out_sharding argument to call methods for layers with jax calls that support it by @samanklesaria in #5102
- Temporary fix for failing CI by @vfdev-5 in #5144
- New release 0.12.2 by @IvyZX in #5149
New Contributors
- @thijs-vanweezel made their first contribution in #5094
- @ibbyml made their first contribution in #5093
Full Changelog: v0.12.1...v0.12.2
v0.12.1
Deprecations
Variable.value
Variable.value is now deprecated. Consider the following example:

```python
import jax.numpy as jnp
import jax
from flax import nnx

my_param = nnx.Param({'a': 0.0})

@nnx.jit
def f(m):
  m.value['a'] = 1.0
  return m
```

Running `f(my_param)` produces `Param(value={'a': 0.0})`, not `Param(value={'a': 1.0})` as before. This is because getting the `value` property now returns a copy of the pytree values (like `dict` / `list`), so mutating the copy has no effect on the Variable. Instead, use the `__setitem__` method to update the value:
```python
@nnx.jit
def f(m):
  m['a'] = 1.0
  return m
```
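A quick check of the fixed version; a minimal sketch, assuming reads mirror writes, i.e. `__getitem__` works alongside `__setitem__` for pytree values:

```python
out = f(my_param)
print(out['a'])  # 1.0, the update is now reflected in the Variable
```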
nnx.Data and nnx.Static
`nnx.Data` and `nnx.Static` annotations are now deprecated. To create `nnx.Pytree` or `nnx.Module` dataclasses, use the new `nnx.dataclass` with `nnx.data` and `nnx.static` as field descriptors.
```python
# old
@dataclasses.dataclass
class Foo(nnx.Pytree):
  a: nnx.Data[int]
  b: nnx.Static[str]

# new
@nnx.dataclass
class Foo(nnx.Pytree):
  a: int = nnx.data()
  b: str = nnx.static()
```
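As a quick sanity check of the new style above (a sketch, assuming data fields appear as pytree leaves while static fields are folded into the treedef):

```python
import jax

foo = Foo(a=1, b='name')     # Foo as defined in the "new" snippet above
print(jax.tree.leaves(foo))  # expected: [1]; b='name' rides along as static metadata
```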
Pull Requests
- Clarify `*Norm` layer docstrings: `axis_index_groups` is unused under SPMD jit. by @copybara-service[bot] in #4940
- Move `ArrayRef` creation to the end of `Variable` creation by @IvyZX in #4980
- clean up jax.Ref-related names by @copybara-service[bot] in #4988
- Add compute_flops and compute_vjp_flops options to `nnx.tabulate` by @samanklesaria in #4948
- Fix nnx.tabulate crash with empty dict/None values (fixes #4889) by @mohsinm-dev in #4891
- Future-proof imports of jax.new_ref / jax.Ref. by @copybara-service[bot] in #4986
- Use `jnp.stack` instead of `np.stack` in `flax.training.common_utils.stack_forest` by @vfdev-5 in #4991
- Fixed broken nnx.statelib.diff by @vfdev-5 in #4992
- Implemented spectral norm in NNX by @mattbahr in #4623
- Improve Variable.{get,set}_metadata by @cgarciae in #4985
- Move iter_children and iter_modules to functions by @samanklesaria in #4961
- Avoid install, import, or tests with tensorflow-text under Python 3.13+. by @jburnim in #5001
- disallow setting metadata through settattr by @cgarciae in #4993
- Use sphinx 6.2+ for docs, which works with Python 3.13. by @jburnim in #5009
- Removed kernel_init/bias_init attributes from popular layers by @vfdev-5 in #4998
- Migrate from `jax.experimental.enable_x64` to `jax.enable_x64`. by @copybara-service[bot] in #5011
- Add Rngs KeylessInitializers by @cgarciae in #5017
- optimize scan transpositions by @cgarciae in #5015
- Variable refactor by @cgarciae in #5006
- Remove invalid gymnasium dependency in pyproject.toml by @IvyZX in #5016
- Use jax.shard_map in flax by @copybara-service[bot] in #5020
- use jax.shard_map by @copybara-service[bot] in #5018
- Fix formatting in PR template checklist by @rapsealk in #5024
- Fixed attribute visualization in treescope_repr by @vfdev-5 in #5022
- feat: add `nnx.set_metadata` to in-place change metadata of the state variables of `nnx.Module`s by @pfackeldey in #5007
- Update README to use fully qualified `nnx.Linear` in example by @rapsealk in #5023
- Fix nnx tabulate variable hooks by @mohsinm-dev in #5008
- python 3.13 support by @cgarciae in #4987
- Added a note in nnx.jit about arg donation by @vfdev-5 in #5031
- Add flip doc link to eager sharding error message by @IvyZX in #5033
- fix reseed for abstract values by @cgarciae in #5034
- Deduplicate `Variable` nodes in `iter_graph` and eliminate recursion. by @copybara-service[bot] in #5035
- Support for python 3.14 by @vfdev-5 in #5032
- [docs] Exposed more helper functions/classes in state.rst by @vfdev-5 in #5037
- Copybara import of the project: by @copybara-service[bot] in #5041
- Internal change by @copybara-service[bot] in #5048
- filter grad state in nnx.Optimizer by @copybara-service[bot] in #5049
- Add NNX WeightNorm (update of #4568) by @samanklesaria in #5043
- Fix shard_map documentation link in compilation.py by @vfdev-5 in #5038
- Fix ValueError when `nnx.jit` is used with `nnx.custom_vjp` by @samanklesaria in #5045
- Recursive map by @chapman20j in #5042
- Convert linen pytorch guide to nnx by @samanklesaria in #4999
- Set Mode with Tests by @chapman20j in #5056
- Fixing Optimizer docstring - fixing #5060 by @Lucas-Fernandes-Martins in #5061
- Update tutorial examples to thread explicit RNGs by @samanklesaria in #4975
- Fix NNX jit static args with in_shardings issue #4989 by @mohsinm-dev in #4996
- support explicit sharding in eager sharding by @cgarciae in #5070
- Added missing LayerNorm test case into TestLayersSameGraph by @vfdev-5 in #5076
- fix main by @cgarciae in #5081
- docs: Document `allow_duplicates` argument of `nnx.to_arrays`. by @dan-zheng in #5083
- add promote_dtype to all standard layers by @cgarciae in #5080
- add nnx.dataclass by @cgarciae in #5066
- Expand ConvTranspose padding documentation by @samanklesaria in #4990
- Added kernel_metadata/bias_metadata args to nnx layers by @vfdev-5 in #5074
- Add nnx.use_eager_sharding context manager by @samanklesaria in #5079
- fix main by @cgarciae in #5090
- Adding set_mode_info by @chapman20j in #5071
- Fixed nnx.scan with carry as pytree and sow by @vfdev-5 in #5073
- Fix bound method auto-unbinding for NNX transforms by @mohsinm-dev in #5055
- deprecate Variable.value by @cgarciae in #5052
- Add eq for variables by @samanklesaria in #5084
- Fixed deprecated .value usage failing CI tests by @vfdev-5 in #5097
- update jax minver to 0.8.1 by @cgarciae in #5095
New Contributors
- @samanklesaria made their first contribution in #4948
- @jburnim made their first contribution in #5001
- @rapsealk made their first contribution in #5024
- @pfackeldey made their first contribution in #5007
- @chapman20j made their first contribution in #5042
- @Lucas-Fernandes-Martins made their first contribution in #5061
Full Changelog: v0.12.0...v0.12.1
0.12.0
Flax 0.12.0 includes many updates and some important breaking changes to the NNX API.
Breaking Changes
Pytree Strict Attributes
`nnx.Pytree`, and therefore `nnx.Module`, are now stricter about attributes that contain Arrays and about changing the status of attributes. For example, the code below now fails:
```python
from flax import nnx
import jax
import jax.numpy as jnp

class Foo(nnx.Module):
  def __init__(self, use_bias, rngs):
    self.layers = [  # ERROR
      nnx.Linear(3, 3, rngs=rngs) for _ in range(5)
    ]
    self.bias = None  # status = static
    if use_bias:
      self.bias = nnx.Param(rngs.params.uniform(3,))  # ERROR
```

This happens for two reasons:
- JAX pytree structures that contain Arrays now have to be marked with `nnx.data`. Alternatively, if the container pytree is a `list` or a `dict`, you can use `nnx.List` or `nnx.Dict`, which additionally allow mixed "data" and "static" elements.
- Attributes will no longer automatically change their status; this now has to be done explicitly using `nnx.data` or `nnx.static`. Additionally, assigning Arrays or structures with Arrays to static attributes is now an error, as they will not automatically change to data.
To fix the above, create `layers` as an `nnx.List` (a Module that is automatically recognized as data), and be explicit about `bias` being a data attribute on its first assignment by using `nnx.data`:
```python
class Foo(nnx.Module):
  def __init__(self, use_bias, rngs):
    self.layers = nnx.List([  # nnx.data also works but List is recommended
      nnx.Linear(3, 3, rngs=rngs) for _ in range(5)
    ])
    self.bias = nnx.data(None)
    if use_bias:
      self.bias = nnx.Param(rngs.params.uniform(3,))
```

For more information, check the Module & Pytree guide.
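To illustrate the second rule, changing an attribute's status now has to be spelled out explicitly. A minimal sketch (the `Bar` class and the exact error sites are illustrative assumptions based on the rules above):

```python
import jax.numpy as jnp
from flax import nnx

class Bar(nnx.Module):
  def __init__(self):
    self.tag = 'bar'  # first assignment is a str, so the attribute is static

bar = Bar()
# bar.tag = jnp.zeros((3,))          # ERROR: an Array cannot be assigned to a static attribute
bar.tag = nnx.data(jnp.zeros((3,)))  # OK: explicitly flip the attribute's status to data
```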
Eager Sharding
Variables will now eagerly shard their values when `sharding_names` metadata is provided. A mesh is required; it can be provided either by passing a `mesh` metadata attribute or by setting the global mesh context via `jax.set_mesh`. This simplifies sharding a Variable by moving it to construction time:
```python
jax.config.update('jax_num_cpu_devices', 8)

mesh = jax.make_mesh((2, 4), ('data', 'model'))
with jax.set_mesh(mesh):
  variable = nnx.Param(jnp.ones((16, 32)), sharding_names=(None, 'model'))
  print(variable.value.sharding)
```

Eager sharding will also occur when using the `nnx.with_partitioning` initializer decorator, and will automatically extend to the Optimizer. This means that both model and optimizer will be sharded at construction, without the need for the somewhat cumbersome `nnx.get_partition_spec` + `jax.lax.with_sharding_constraint` + `nnx.update` pattern:
```python
import optax

with jax.set_mesh(mesh):
  linear = nnx.Linear(
    in_features=16, out_features=16, use_bias=False,
    kernel_init=nnx.with_partitioning(
      nnx.initializers.lecun_normal(), (None, 'model')
    ),
    rngs=nnx.Rngs(0),
  )
  optimizer = nnx.Optimizer(linear, optax.adam(1e-3), wrt=nnx.Param)

print(linear.kernel.value.sharding)
print(optimizer.opt_state[0].mu.kernel.value.sharding)
```

For projects that currently rely on other means for sharding, eager sharding can be turned off by passing `eager_sharding=False` to the Variable constructor, either directly or through initializer decorators like `nnx.with_partitioning`:
```python
linear = nnx.Linear(
  in_features=16, out_features=16, use_bias=False,
  kernel_init=nnx.with_partitioning(
    nnx.initializers.lecun_normal(), (None, 'model'), eager_sharding=False
  ),
  rngs=nnx.Rngs(0),
)
optimizer = nnx.Optimizer(linear, optax.adam(1e-3), wrt=nnx.Param)

print(linear.kernel.value.sharding)
print(optimizer.opt_state[0].mu.kernel.value.sharding)
```

Eager sharding can also be turned off globally via the `flax_always_shard_variable` config flag or the `FLAX_ALWAYS_SHARD_VARIABLE` environment variable:

```python
import flax

flax.config.update('flax_always_shard_variable', False)
```

For more information, check out the Variable eager sharding FLIP.
In-Place Operators No Longer Allowed
In-place operators will now raise an error. This is done as part of the push for Variables to be compatible with Tracer semantics:
```python
w = nnx.Variable(jnp.array(0))
w += 1  # ERROR
```

The fix is to simply operate on the `.value` property instead:

```python
w.value += 1
```

All Changes
- Doc fix: remove dead link to pre-Orbax checkpointing. by @copybara-service[bot] in #4914
- Fix typo in unflatten docs by @copybara-service[bot] in #4918
- fix RNN by @copybara-service[bot] in #4917
- Update optimizer.py to support masked variable from optax. by @ywrt in #4904
- Added missing functions to graph.rst by @vfdev-5 in #4922
- Update flax/docs_nnx/guides/performance.md and .ipynb by @hanrach9 in #4919
- Added preferred_element_type arg to nnx.Linear*, nnx.Conv*, nnx.Einsum by @vfdev-5 in #4920
- Update README badges and remove invalid ones by @IvyZX in #4905
- static + pytree guide by @cgarciae in #4897
- fix mypy by @copybara-service[bot] in #4931
- Avoid passing non-boolean mask to `where` argument of `jax.numpy` reductions. Non-boolean mask inputs have been deprecated for several releases, and will result in an error starting in JAX v0.8.0. by @copybara-service[bot] in #4923
- Ported nnx.PReLU from linen by @vfdev-5 in #4934
- Added nnx.scan docs and few minor docs fixes by @vfdev-5 in #4930
- add variables argument to nnx.clone by @cgarciae in #4945
- only copy dicts on State.getitem by @cgarciae in #4946
- always differentiate standalone Variables in nnx.grad by @cgarciae in #4947
- Implement instance norm in NNX by @mattbahr in #4939
- Automatically apply sharding constraints to sharded models by @IvyZX in #4844
- Add reference of flip doc to gspmd guide by @IvyZX in #4949
- Fixed nnx.is_data docstring rendering by @vfdev-5 in #4957
- expose pytree guide by @cgarciae in #4951
- fix toy examples by @cgarciae in #4952
- Explicitly cast attribute names to string before checking for private attributes. by @copybara-service[bot] in #4955
- add flax_hijax_variable flag by @cgarciae in #4953
- mark shard_map as implemented in transforms guide by @cgarciae in #4738
- improve Variable flatten by @cgarciae in #4954
- Minor typo fix in nnx.call docstring by @vfdev-5 in #4959
- allow split tuples in Rngs.fork by @cgarciae in #4958
- Fixed Gemma example using Gemma2 models by @vfdev-5 in #4830
- finish pytree guide by @cgarciae in #4929
- update bridge wrappers from maxtext by @cgarciae in #4937
- fix HashableMapping hash definition for mixed key types by @copybara-service[bot] in #4936
- Flax RNG guide for jax.jit: clarify rng outputs are shared but not inputs. by @copybara-service[bot] in #4956
- fix Variable pytree flatten by @copybara-service[bot] in #4962
- import PathParts from flax.typing by @cgarciae in #4966
- Correctly expose `flax.config.temp_flip_flag` by @IvyZX in #4969
- raise on Variable inplace operators by @cgarciae in #4967
- Copybara import of the project: by @copybara-service[bot] in #4976
- update to version 0.12.0 by @cgarciae in #4982
- Minor typo fixes in flax gspmd guide by @vfdev-5 in #4970
- ignore uv.lock by @copybara-service[bot] in #4974
- [nnx] preserve the function's type information in jit by @cgarciae in #4981
- add Variable.set_metadata by @cgarciae in #4968
- propagate eager sharding by @cgarciae in #4983
New Contributors
Full Changelog: v0.11.2...v0.12.0
0.11.2
What's Changed
`nnx.merge` no longer creates a copy of the Variables in the incoming states by default, meaning that the new merged structure holds references to the incoming Variables. This enables new patterns; for example, it's now possible to create models that share the same state but have different runtime behavior:
```python
model = SomeModel(...)

# create eval model
eval_model = nnx.merge(*nnx.split(model))  # same Variables, different structure
eval_model.eval()
```

model and eval_model share the same Variables and are therefore kept in sync, but have different runtime behavior. This avoids having to constantly mutate a single model back and forth between different runtime modes, which can be error prone and cause unwanted recompilation.
To keep the old behavior, use `nnx.merge(..., copy=True)`.
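As a minimal sketch of the shared-Variable behavior, using `nnx.Linear` as a stand-in for the hypothetical `SomeModel` above:

```python
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
eval_model = nnx.merge(*nnx.split(model))  # no copy: shares model's Variables

# an update through one view is visible through the other
model.kernel.value = jnp.zeros((2, 3))
assert (eval_model.kernel.value == 0).all()
```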
PRs
- add Rngs random helpers by @cgarciae in #4876
- Fix re-export and docs for identity by @jlperla in #4850
- Fix ToLinen docstring return description by @mohsinm-dev in #4852
- Update doc build instructions and clean up unused packages by @IvyZX in #4885
- Improve docs related with dataclasses by @IvyZX in #4884
- Fix broken contributing documentation link by @mohsinm-dev in #4855
- Internal change by @copybara-service[bot] in #4886
- Fix string key preservation in replace_by_pure_dict by @mohsinm-dev in #4860
- Remove the need for Conv and ConvTranspose to know the precise batch size. by @copybara-service[bot] in #4877
- call jax's source_info_util.register_exclusion in flax's traceback_util.register_exclusion by @copybara-service[bot] in #4887
- Update typo in nnx.Optimizer by @codinfox in #4880
- Exposed split_rngs docstring in the docs_nnx by @vfdev-5 in #4846
- Pin sentencepiece version to 0.2.0 to fix head by @IvyZX in #4892
- Relax duplicate check to exclude non-string values such as PartitionSpec.UNCONSTRAINED, since those can be repeated. by @copybara-service[bot] in #4881
- add find_duplicates by @cgarciae in #4894
- Sharding API improvements (non breaking) by @IvyZX in #4893
- document jax.random shorthand methods by @cgarciae in #4899
- Optimiser was already instantiated using the model - 05_vae.py by @nenuadrian in #4857
- revert is_leaf logic in _check_carry_same_references by @copybara-service[bot] in #4903
- Doc fix: remove outdated advice on flax v0.6.10; it was released two years ago. by @copybara-service[bot] in #4910
- Fix bug when raising ScopeParamNotFoundError. by @copybara-service[bot] in #4898
- fix mypy on main by @cgarciae in #4909
- merge no copy Variables by @cgarciae in #4912
- update version to 0.11.2 by @copybara-service[bot] in #4915
New Contributors
- @mohsinm-dev made their first contribution in #4852
- @codinfox made their first contribution in #4880
- @nenuadrian made their first contribution in #4857
Full Changelog: v0.11.1...v0.11.2
v0.11.1
What's Changed
- Make `Sequential()` be identity by @SobhanMP in #4796
- Add a JAX/Flax key concepts doc by @IvyZX in #4795
- miscellaneous improvements by @cgarciae in #4859
- Replace `jax.sharding.use_mesh` with `jax.set_mesh`. `jax.set_mesh` can act as a global setter or a context manager. by @copybara-service[bot] in #4862
- Pytree and ArrayRef refactor by @cgarciae in #4863
- Add old property attributes for object->pytree rename. by @copybara-service[bot] in #4864
- Add BatchNorm layers to CNN in MNIST tutorial for improved training stability by @sanepunk in #4773
- Description by @copybara-service[bot] in #4866
- update and pop for dict by @cgarciae in #4869
- simplify nnx_basics by @cgarciae in #4868
- updates to version 0.11.1 by @cgarciae in #4878
New Contributors
Full Changelog: v0.11.0...v0.11.1
v0.11.0
v0.11.0 - Pytrees, MutableArrays, and more!
This version of Flax introduces some changes to improve interop with native JAX and adds support for the new jax.experimental.MutableArray. More on this soon! However, some breaking changes were necessary to align with the JAX way of doing things. Most code should remain intact; the following changes deviate from previous behavior:
- `Rngs` in standard layers: all standard layers no longer hold a shared reference to the `rngs` object given in the constructor; instead they now keep a `fork`-ed copy of the `Rngs` or `RngStream` objects. This impacts Using Rngs in NNX Transforms and Loading Checkpoints with RNGs.
- Optimizer Updates: the Optimizer abstraction no longer holds a reference to the `model` to avoid reference sharing; instead, the `model` must be provided as the first argument to `update` (see the sketch after this list).
- Modules as Pytrees: Modules are now pytrees! This avoids unnecessary use of `split` and `merge` when interacting trivially with raw JAX transforms (state must still be manually propagated if not using MutableArrays, and referential transparency is still an issue). This affects code that operates on pytrees containing NNX Objects with `jax.tree.*` APIs.
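For the Optimizer change, a minimal sketch of the new calling convention (the model, loss, and hyperparameters here are illustrative):

```python
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

def loss_fn(model):
  y = model(jnp.ones((1, 2)))
  return (y ** 2).mean()

grads = nnx.grad(loss_fn)(model)
optimizer.update(model, grads)  # the model is now passed explicitly to update
```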
Check out the full NNX 0.10 to NNX 0.11 migration guide.
In the near future we'll share more information about new ways of using NNX with JAX transforms directly by leveraging the new Pytree and MutableArray support. Stay tuned!
What's Changed
- [nnx] mutable array p3 by @cgarciae in #4755
- [nnx] allow method calls in ToLinen by @cgarciae in #4808
- Internal change by @copybara-service[bot] in #4807
- Preserve sharding information in axes_scan by @copybara-service[bot] in #4806
- Deduplicate contributing and philosophy and move to main site by @IvyZX in #4809
- Fixed nnx.remat docstring rendering by @vfdev-5 in #4790
- Added a note to gemma guide about model's license consent on kaggle by @vfdev-5 in #4776
- [nnx] ToLinen add abtract_init flag by @cgarciae in #4813
- Modify NNX to use id(variable) instead of nnx.Variables as dictionary by @divyashreepathihalli in #4814
- Allow using LazyRngs for flax init/apply. by @copybara-service[bot] in #4818
- [nnx] remove VariableState by @cgarciae in #4800
- Fix failing CI jobs: trailing whitespace, deprecated `.type` usage by @vfdev-5 in #4823
- [nnx] fix Rngs dtype check by @cgarciae in #4820
- refactor: move usages of `.value` to `[...]` in modules_test.py by @lukeyeh in #4815
- Added training script for Gemma model by @vfdev-5 in #4822
- [nnx] add flax_pytree_module flag by @cgarciae in #4811
- create ModelAndOptimizer symbol by @copybara-service[bot] in #4849
- [nnx] remove Optimizer.model attribute by @cgarciae in #4842
- [nnx] add mutable array support in update by @cgarciae in #4851
- Migrate `transforms_test.py` from `.value` to `[...]` by @lukeyeh in #4841
- 0.11.0 migration guide by @cgarciae in #4854
New Contributors
- @divyashreepathihalli made their first contribution in #4814
- @lukeyeh made their first contribution in #4815
Full Changelog: v0.10.7...v0.11.0
0.10.7
What's Changed
- Added identity export from JAX by @jlperla in #4652
- Fixes a bug in type annotations for scope.param (unbox=True should accept `Callable[..., T | AxisMetadata[T]]` and return `T`, while unbox=False should always return the same type as the callable returns). by @copybara-service in #4727
- fix merge by @copybara-service in #4731
- [nnx] make Variable a pytree by @cgarciae in #4728
- [nnx] add JitWrapped API by @cgarciae in #4699
- Update JAX nightly index usage by @copybara-service in #4733
- [nnx] mutable array p1 by @cgarciae in #4715
- add dataclass by @copybara-service in #4739
- [flax] unconditionally register nnx.Variable as a pytree by @copybara-service in #4748
- Updated version of pre-commit-hooks in .pre-commit-config.yaml by @vfdev-5 in #4746
- Fixed docstring visibility for nnx.eval_shape by @vfdev-5 in #4747
- Added keep_rngs arg to MHA to optionally store rngs by @vfdev-5 in #4749
- MultiHeadAttention only keeps rngs if dropout_rate is positive by @copybara-service in #4750
- [nnx] mutable array p2 by @cgarciae in #4741
- Add in_kv_features argument to nnx.MultiHeadAttention, addressing #4756. by @copybara-service in #4757
- Fix broken link for Transforms guide by @nireekshak in #4763
- Minor improvements of lm1b_nnx example by @vfdev-5 in #4745
- Fix head CI tests by @IvyZX in #4764
- Fix typos by @nireekshak in #4725
- Check for leaves of type variablelib.Variable when getting sharding specs. by @copybara-service in #4769
- Fixes #1925: non-str dict keys not supported in module state by @muhrin in #4563
- Modified the Functional API link by @nireekshak in #4767
- Fix hardcoded link to filter guide in docs by @hamogu in #4768
- Fix bad doc links by @IvyZX in #4770
- revise axes_scan to flatten argument pytrees only once by @copybara-service in #4772
- Simplify ToNNX access of Linen module methods by @IvyZX in #4766
- Use `.input_formats` and `.output_formats` in place of `.input_layouts` and `.output_layouts` respectively. by @copybara-service in #4784
- Exposed OptState in nnx module by @vfdev-5 in #4788
- Fixes colab link for nnx docs by @vfdev-5 in #4775
- Internal changes by @copybara-service in #4786
- Fix typo in Flax `nnx_basics` doc. by @copybara-service in #4781
- update version to 0.10.7 by @cgarciae in #4798
New Contributors
- @nireekshak made their first contribution in #4763
- @muhrin made their first contribution in #4563
- @hamogu made their first contribution in #4768
Full Changelog: v0.10.6...v0.10.7
0.10.6
What's Changed
- Sow top activations based on absolute value. by @copybara-service in #4670
- Add support for layer-specific rope scale factors. by @copybara-service in #4672
- Automatic model selection for Gemma 3 models. by @copybara-service in #4671
- Make LoRA's dtype arg useful by @IvyZX in #4681
- [NVIDIA] Support FP8 Einsum Op by @kaixih in #4686
- [nnx] remove deprecated APIs by @cgarciae in #4627
- Add `attention_bias` parameter to `MultiHeadDotProductAttention`. by @copybara-service in #4694
- Unit tests for `attention_bias` parameter to `MultiHeadDotProductAttention`. Add parameter to all overloads to make pytype happy. by @copybara-service in #4702
- Rollback of attention_bias parameter, because the change overrides the attention bias for injected attention functions. by @copybara-service in #4703
- Add custom einsum op to Einsum() by @IvyZX in #4705
- [nnx] refactor GraphDef by @cgarciae in #4630
- Make fully replicated array before saving checkpoints for examples that use pmap. by @copybara-service in #4707
- Fix CI by @cgarciae in #4716
- remove "nnx" collection in ToLinen by @copybara-service in #4708
- [nnx] flaxlib types by @cgarciae in #4639
- v0.10.6 by @cgarciae in #4724
Full Changelog: v0.10.5...v0.10.6
0.10.5
What's Changed
- [nnx] fix tabulate by @cgarciae in #4580
- Refactor bridge.Module tests from `wrappers_test.py` to another file. by @copybara-service in #4581
- Avoid calls to jnp.shape for non-array inputs. by @jakevdp in #4592
- remove Embed nan casting by @cgarciae in #4600
- Add QK Norm. by @copybara-service in #4594
- Util to let bridge module work with NNX submodules by @IvyZX in #4584
- Add configurable Query Pre Attention scalar. by @copybara-service in #4595
- Make RoPE Base Frequency configurable. by @copybara-service in #4596
- [nnx] pytrees are graph nodes by @cgarciae in #4547
- Add option to load checkpoints with transposed Gating Einsum. by @copybara-service in #4597
- add top_p sampling in gemma example by @copybara-service in #4591
- Fix position and name of Post Attention Norm. by @copybara-service in #4598
- Add Sow Config to from_params constructor. by @copybara-service in #4599
- bridge module with linen submodule by @IvyZX in #4604
- Dramatically speed up sampling compilation time by @copybara-service in #4574
- [nnx] improve grad docs by @cgarciae in #4588
- [nnx] add support for standalone Variables by @cgarciae in #4606
- add promote_dtype as a config option for multiple layers by @cgarciae in #4613
- Copybara import of the project: by @copybara-service in #4616
- Fixed typo in `beam_search` loop. by @copybara-service in #4615
- support swap model params in gemma sampler by @copybara-service in #4614
- Allow bridge module to have 'name' field by @IvyZX in #4619
- fix performance guide by @cgarciae in #4621
- Copybara import of the project: by @copybara-service in #4618
- Add REFLECT padding to convolution layer by @sarlinpe in #4553
- fix trace-level detection by @cgarciae in #4527
- Add attribute path customization to bridge modules by @IvyZX in #4624
- add reprlib max depth flag by @cgarciae in #4632
- Allow custom axis metadata annotation during transforms by @IvyZX in #4637
- [bridge module] Allow name arg to represent actual submodule path by @IvyZX in #4634
- [nnx] improve Variable proxy for binary operations by @cgarciae in #4641
- Fix module stack typing annotation. by @copybara-service in #4633
- Stop passing reduce_axes to jax.grad, jax.vjp, and jax.value_and_grad. by @copybara-service in #4617
- discord release webhook by @cgarciae in #4646
- [nnx] support Array leaves in graph nodes by @cgarciae in #4612
- Roll up package jax version and uv.lock by @IvyZX in #4648
- Use jax.nn.dot_product_attention when possible by @IvyZX in #4649
- Fix flaky vmap test tolerance. by @copybara-service in #4653
- Test runner ubuntu upgrade 24.04 by @IvyZX in #4659
- Fix lazy_init typo by @IvyZX in #4657
- deflake a test by @copybara-service in #4663
- v0.10.5 by @cgarciae in #4656
New Contributors
Full Changelog: v0.10.4...v0.10.5
Release 0.10.4
What's Changed
- update pypi publish by @cgarciae in #4538
- [nnx] register_variable_name refactor by @copybara-service in #4540
- added support to the accuracy metric for binary classification by @mattbahr in #4536
- [nnx] bridge Module by @cgarciae in #4542
- [nnx] copy _var_metadata by @copybara-service in #4548
- [bridge] fix unbox logic by @copybara-service in #4551
- Add `is_initializing` API by @copybara-service in #4550
- [nnx] Add specific model typing for nnx.Optimizer by @marcelroed in #4470
- Add linen metadata conversion to linx by @IvyZX in #4552
- [bridge] improve Module context by @cgarciae in #4554
- Raise error if user uses 'name' in bridge module setup by @IvyZX in #4555
- Add deprecation warning to all `nnx.State` methods by @IvyZX in #4561
- [nnx] add shard_map by @cgarciae in #4490
- Fix CI breakages from newest jax by @IvyZX in #4576
- [bridge] Set _initializing correctly and avoid return RNG states by @copybara-service in #4569
- v0.10.4 by @cgarciae in #4579
New Contributors
- @mattbahr made their first contribution in #4536
- @marcelroed made their first contribution in #4470
Full Changelog: v0.10.3...v0.10.4