docs_nnx/guides/opt_cookbook.rst
Lines changed: 11 additions & 11 deletions
@@ -36,7 +36,7 @@ The following sections will all be training the same toy model. We will allow ex

   param_init = jax.nn.initializers.lecun_normal()

-  nnx_keys = nnx.Rngs(0)
+  rngs = nnx.Rngs(0)

   def nnx_model(rngs, **kwargs):
     return nnx.Sequential(
@@ -76,8 +76,8 @@ We'll operate on the following fake data:
   :title: NNX, Jax
   :sync:

-  x = nnx_keys.normal((32, 2))
-  y = nnx_keys.normal((32, 8))
+  x = rngs.normal((32, 2))
+  y = rngs.normal((32, 8))

   ---

@@ -88,13 +88,13 @@ We'll operate on the following fake data:
 Exponential Moving Average
 ===========================

-Neural network see increased robustness when, rather than using only the weights available at the end of training, we use an exponential moving average of the weights produced throughout training. It is easy to modify the standard Jax training loop to accomodate calculating exponential moving averages.
+Neural networks see increased robustness when, rather than using only the weights available at the end of training, we use an exponential moving average of the weights produced throughout training. It is easy to modify the standard Jax training loop to accommodate calculating exponential moving averages.

 .. codediff::
   :title: NNX, Jax
   :sync:

-  model = nnx_model(nnx_keys)
+  model = nnx_model(rngs)

   nnx_optimizer = nnx.Optimizer(
     model,
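
To make the EMA pattern above concrete, here is a minimal, self-contained sketch of tracking an exponential moving average of the parameters alongside an ordinary training step; the `decay` value and the `train_step`/`ema_params` names are illustrative assumptions, not the cookbook's actual code::

  import jax

  def ema_update(ema_params, new_params, decay=0.999):
    # EMA <- decay * EMA + (1 - decay) * new, applied leaf-wise over the pytree.
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, new_params)

  # Inside a hypothetical training loop:
  #   params, opt_state = train_step(params, opt_state, batch)  # usual update
  #   ema_params = ema_update(ema_params, params)               # track the average
  # At evaluation time, run the model with ema_params instead of params.
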
@@ -163,10 +163,10 @@ The pattern for adding low rank adaptation to an optimization loop is very simil
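
The body of this hunk is collapsed, so the sketch below is only a rough, hypothetical illustration of the low-rank-adaptation idea the header refers to; `W_frozen`, `A`, `B`, and the rank are made-up names for this example, not the cookbook's code::

  import jax.numpy as jnp

  def lora_apply(x, W_frozen, A, B):
    # y = x @ (W + A @ B): the pretrained weight W_frozen stays fixed, and only
    # the low-rank factors A (d_in, r) and B (r, d_out) are handed to the
    # optimizer, so the optimization loop looks the same with a smaller pytree.
    return x @ W_frozen + (x @ A) @ B
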
@@ -383,7 +383,7 @@ Sharding Optimization State Differently from Parameters
 Say we're doing data parallelism. We want to replicate our parameters across all GPUs so we can do the forward and backward passes without communication latency.

-But we don't need to replicate the optimizer state, as it's not invovled in SPMD computations. One copy is enough, and we can shard this copy across our mesh to reduce memory usage. This means that we need the optimier state to be sharded differently from the parameters themselves.
+But we don't need to replicate the optimizer state, as it's not involved in SPMD computations. One copy is enough, and we can shard this copy across our mesh to reduce memory usage. This means that we need the optimizer state to be sharded differently from the parameters themselves.

 To do this, we can pass the params initializer given to the optimizer a `sharding` argument. This will shard the optimization state the same way. But when we initialize the model parameters themselves, we won't provide a sharding, allowing for data parallelism.
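
As an illustration of the point above, here is a minimal JAX sketch of keeping the parameters replicated while sharding a single copy of the optimizer state across the mesh; the mesh axis name, the array shapes, and the `adam_mu` placeholder are assumptions for this example, not the cookbook's code::

  import jax
  import jax.numpy as jnp
  import numpy as np
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  mesh = Mesh(np.array(jax.devices()), ("data",))

  # Parameters: replicated on every device, so the forward and backward
  # passes run without extra communication.
  replicated = NamedSharding(mesh, P())
  params = jax.device_put(jnp.zeros((8192, 1024)), replicated)

  # Optimizer state (e.g. an Adam moment): one copy, sharded across the
  # mesh along its first axis to reduce memory usage.
  sharded = NamedSharding(mesh, P("data"))
  adam_mu = jax.device_put(jnp.zeros((8192, 1024)), sharded)
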