Commit 84fa316

Merge branch 'main' into upgrade-haiku-linen-to-nnx
2 parents 4ed2ba8 + ab122af commit 84fa316

7 files changed (+91, -32 lines)

docs_nnx/api_reference/flax.nnx/nn/dtypes.rst

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+Dtypes
+------------------------
+
+.. automodule:: flax.nnx.nn.dtypes
+.. currentmodule:: flax.nnx.nn.dtypes
+
+.. autofunction:: canonicalize_dtype
+.. autofunction:: promote_dtype
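
As context for the new page: `canonicalize_dtype` and `promote_dtype` are the dtype helpers NNX layers use before compute. A minimal sketch of the latter, assuming the NNX variant's tuple-style signature (the Linen variant takes the arrays as separate positional arguments):

```python
# Hedged sketch: promote_dtype casts a layer's inputs and parameters to a
# common dtype before compute. The tuple-argument signature is assumed from
# the NNX variant; adjust if your Flax version differs.
import jax.numpy as jnp
from flax.nnx.nn import dtypes

x = jnp.ones((2, 3), dtype=jnp.float16)  # activations in half precision
w = jnp.ones((3, 4), dtype=jnp.float32)  # parameters in full precision
x, w = dtypes.promote_dtype((x, w), dtype=jnp.float32)
print(x.dtype, w.dtype)  # float32 float32
```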

docs_nnx/api_reference/flax.nnx/nn/index.rst

Lines changed: 3 additions & 0 deletions

@@ -9,8 +9,11 @@ See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for

 activations
 attention
+dtypes
 initializers
 linear
+lora
 normalization
+recurrent
 stochastic

docs_nnx/api_reference/flax.nnx/nn/lora.rst

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+LoRA
+------------------------
+
+NNX LoRA classes.
+
+.. automodule:: flax.nnx
+.. currentmodule:: flax.nnx
+
+.. flax_module::
+  :module: flax.nnx
+  :class: LoRA
+
+.. flax_module::
+  :module: flax.nnx
+  :class: LoRALinear
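
For orientation, a minimal usage sketch of the `LoRA` module documented above; the `(in_features, lora_rank, out_features)` argument names are assumed from the current `flax.nnx` API:

```python
# Hedged sketch of nnx.LoRA: a standalone low-rank adapter computing
# x @ A @ B with rank-`lora_rank` factors A and B (assumed behavior).
import jax.numpy as jnp
from flax import nnx

layer = nnx.LoRA(in_features=3, lora_rank=2, out_features=4, rngs=nnx.Rngs(0))
y = layer(jnp.ones((1, 3)))
print(y.shape)  # (1, 4)
```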
docs_nnx/api_reference/flax.nnx/nn/recurrent.rst

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+Recurrent
+------------------------
+
+.. automodule:: flax.nnx.nn.recurrent
+.. currentmodule:: flax.nnx.nn.recurrent
+
+.. flax_module::
+  :module: flax.nnx.nn.recurrent
+  :class: LSTMCell
+
+.. flax_module::
+  :module: flax.nnx.nn.recurrent
+  :class: OptimizedLSTMCell
+
+.. flax_module::
+  :module: flax.nnx.nn.recurrent
+  :class: SimpleCell
+
+.. flax_module::
+  :module: flax.nnx.nn.recurrent
+  :class: GRUCell
+
+.. flax_module::
+  :module: flax.nnx.nn.recurrent
+  :class: RNN
+
+.. flax_module::
+  :module: flax.nnx.nn.recurrent
+  :class: Bidirectional
+
+
+.. autofunction:: flip_sequences
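
A short sketch of how the classes listed above compose, mirroring the `GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))` pattern used in the `Bidirectional` docstring later in this commit; treat the exact keyword names as assumptions:

```python
# Hedged sketch: wrap a recurrent cell in RNN to scan it over the time axis.
import jax.numpy as jnp
from flax import nnx
from flax.nnx.nn.recurrent import LSTMCell, RNN

cell = LSTMCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))
rnn = RNN(cell)
out = rnn(jnp.ones((2, 5, 3)))  # (batch, time, features)
print(out.shape)  # (2, 5, 4): one hidden vector per step
```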

docs_nnx/guides/haiku_to_flax.rst

Lines changed: 3 additions & 3 deletions

@@ -410,7 +410,7 @@ To call those custom methods:


 Transformations
-===============
+=======================

 Both Haiku and `Flax transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`__ provide their own set of transforms that wrap `JAX transforms <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`__ in a way that they can be used with ``Module`` objects.

@@ -497,7 +497,7 @@ The only difference is that Flax ``nnx.scan`` allows you to specify which axis t


 Scan over layers
-================
+=======================

 Most Haiku transforms should look similar with Flax, since they all wraps their JAX counterparts, but the scan-over-layers use case is an exception.

@@ -645,7 +645,7 @@ Now inspect the variable pytree on both sides:


 Top-level Haiku functions vs top-level Flax modules
-================
+=======================

 In Haiku, it is possible to write the entire model as a single function by using
 the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and
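
Since "Scan over layers" is the one spot where Haiku and Flax transforms diverge, here is a condensed sketch of the NNX side of that pattern; the decorator arguments (`nnx.split_rngs`, `nnx.vmap`, `nnx.Carry`) are taken from the current NNX transforms API and should be treated as assumptions:

```python
# Hedged sketch of scan-over-layers in Flax NNX (assumed API details).
import jax.numpy as jnp
from flax import nnx

@nnx.split_rngs(splits=4)             # one RNG stream per stacked layer
@nnx.vmap(in_axes=0, out_axes=0)
def create_stack(rngs: nnx.Rngs):
  return nnx.Linear(8, 8, rngs=rngs)  # 4 layers with stacked parameters

stack = create_stack(nnx.Rngs(0))

@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(layer: nnx.Linear, x):
  return layer(x)                     # carry x through each stacked layer

y = forward(stack, jnp.ones((2, 8)))
print(y.shape)  # (2, 8)
```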

docs_nnx/index.rst

Lines changed: 3 additions & 3 deletions

@@ -17,9 +17,9 @@ Python objects. Flax NNX is an evolution of the previous `Flax Linen <https://fl
 API, and it took years of experience to bring a simpler and more user-friendly API.

 .. note::
-   Flax Linen API is not going to be deprecated in the near future as most of Flax users still
-   rely on this API. However, new users are encouraged to use Flax NNX.
-   For existing Flax Linen users planning to move to Flax NNX, check out the `evolution guide <guides/linen_to_nnx.html>`_ and `Why Flax NNX <why.html>`_.
+   Flax Linen API is not going to be deprecated in the near future as most of Flax users still rely on this API. However, new users are encouraged to use Flax NNX. Check out `Why Flax NNX <why.html>`_ for a comparison between Flax NNX and Linen, and our reasoning to make the new API.
+
+   To move your Flax Linen codebase to Flax NNX, get familiarized with the API in `NNX Basics <https://flax.readthedocs.io/en/latest/nnx_basics.html>`_ and then start your move following the `evolution guide <guides/linen_to_nnx.html>`_.

 Features
 ^^^^^^^^^

flax/nnx/nn/recurrent.py

Lines changed: 27 additions & 26 deletions

@@ -742,16 +742,17 @@ def flip_sequences(
   values for those sequences that were padded. This function keeps the padding
   at the end, while flipping the rest of the elements.

-  Example:
-  ```python
-  inputs = [[1, 0, 0],
-            [2, 3, 0]
-            [4, 5, 6]]
-  lengths = [1, 2, 3]
-  flip_sequences(inputs, lengths) = [[1, 0, 0],
-                                     [3, 2, 0],
-                                     [6, 5, 4]]
-  ```
+  Example::
+
+    >>> from flax.nnx.nn.recurrent import flip_sequences
+    >>> from jax import numpy as jnp
+    >>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
+    >>> lengths = jnp.array([1, 2, 3])
+    >>> flip_sequences(inputs, lengths, 1, False)
+    Array([[1, 0, 0],
+           [3, 2, 0],
+           [6, 5, 4]], dtype=int32)
+

   Args:
     inputs: An array of input IDs <int>[batch_size, seq_length].

@@ -810,27 +811,27 @@ def __call__(
 class Bidirectional(Module):
   """Processes the input in both directions and merges the results.

-  Example usage:
+  Example usage::
+
+    >>> from flax import nnx
+    >>> import jax
+    >>> import jax.numpy as jnp

-  ```python
-  import nnx
-  import jax
-  import jax.numpy as jnp
+    >>> # Define forward and backward RNNs
+    >>> forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
+    >>> backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))

-  # Define forward and backward RNNs
-  forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
-  backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
+    >>> # Create Bidirectional layer
+    >>> layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)

-  # Create Bidirectional layer
-  layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)
+    >>> # Input data
+    >>> x = jnp.ones((2, 3, 3))

-  # Input data
-  x = jnp.ones((2, 3, 3))
+    >>> # Apply the layer
+    >>> out = layer(x)
+    >>> print(out.shape)
+    (2, 3, 8)

-  # Apply the layer
-  out = layer(x)
-  print(out.shape)
-  ```
   """

   forward_rnn: RNNBase
