
Commit f894a0d

Use rngs instead of nnx_keys

1 parent b939636

1 file changed

docs_nnx/guides/opt_cookbook.rst

Lines changed: 11 additions & 11 deletions
@@ -36,7 +36,7 @@ The following sections will all be training the same toy model. We will allow ex

  param_init = jax.nn.initializers.lecun_normal()

- nnx_keys = nnx.Rngs(0)
+ rngs = nnx.Rngs(0)

  def nnx_model(rngs, **kwargs):
      return nnx.Sequential(
@@ -76,8 +76,8 @@ We'll operate on the following fake data:
  :title: NNX, Jax
  :sync:

- x = nnx_keys.normal((32, 2))
- y = nnx_keys.normal((32, 8))
+ x = rngs.normal((32, 2))
+ y = rngs.normal((32, 8))

  ---
@@ -88,13 +88,13 @@ We'll operate on the following fake data:
  Exponential Moving Average
  ===========================

- Neural network see increased robustness when, rather than using only the weights available at the end of training, we use an exponential moving average of the weights produced throughout training. It is easy to modify the standard Jax training loop to accomodate calculating exponential moving averages.
+ Neural networks see increased robustness when, rather than using only the weights available at the end of training, we use an exponential moving average of the weights produced throughout training. It is easy to modify the standard Jax training loop to accommodate calculating exponential moving averages.

  .. codediff::
  :title: NNX, Jax
  :sync:

- model = nnx_model(nnx_keys)
+ model = nnx_model(rngs)

  nnx_optimizer = nnx.Optimizer(
      model,
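The hunk above sits in the cookbook's exponential-moving-average section. As a rough sketch of the underlying idea on the plain-JAX side (the decay value and pytree layout here are assumptions for illustration, not the cookbook's code):

    import jax
    import jax.numpy as jnp

    decay = 0.999  # assumed EMA decay rate, not taken from the cookbook

    def ema_step(ema_params, params):
        # ema <- decay * ema + (1 - decay) * params, applied leaf by leaf
        return jax.tree_util.tree_map(
            lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params)

    # Toy usage: after each optimizer step, fold the fresh weights into the average.
    params = {'w': jnp.ones((2, 8))}   # stand-in parameters
    ema_params = params                # start the average at the initial weights
    ema_params = ema_step(ema_params, params)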
@@ -163,10 +163,10 @@ The pattern for adding low rank adaptation to an optimization loop is very simil

  def add_rank2_lora(path, node):
      if isinstance(node, nnx.Linear):
-         return nnx.LoRA(node.in_features, 2, node.out_features, base_module=node, rngs=nnx_keys)
+         return nnx.LoRA(node.in_features, 2, node.out_features, base_module=node, rngs=rngs)
      return node

- base_model = nnx_model(nnx_keys)
+ base_model = nnx_model(rngs)
  model = nnx.recursive_map(add_rank2_lora, base_model)

  @nnx.jit
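This hunk is part of the low-rank-adaptation section. A self-contained sketch of what `add_rank2_lora` does to a single layer, with stand-in layer sizes chosen only for illustration:

    from flax import nnx

    rngs = nnx.Rngs(0)
    layer = nnx.Linear(2, 8, rngs=rngs)   # stand-in base layer
    # Add a trainable rank-2 update on top of the wrapped layer's output,
    # mirroring the nnx.LoRA call in the hunk above.
    lora_layer = nnx.LoRA(layer.in_features, 2, layer.out_features,
                          base_module=layer, rngs=rngs)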
@@ -244,7 +244,7 @@ So far, we've been using optax optimizers with the interface ``optimizer.update(
          value_fn=loss_fn_state)
      return loss

- model = nnx_model(nnx_keys)
+ model = nnx_model(rngs)

  nnx_optimizer = nnx.Optimizer(
      model,
@@ -294,7 +294,7 @@ In Flax, we will also initialize a partitioned optax optimizer. But unlike the J
  :title: NNX, Jax
  :sync:

- model = nnx_model(nnx_keys)
+ model = nnx_model(rngs)
  state = nnx.state(model, nnx.Param)
  rates = {'kernel': optax.adam(1e-3), 'bias': optax.adam(1e-2)}
  param_tys = nnx.map_state(lambda p, v: list(p)[-1], state)
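This hunk belongs to the partitioned-optimizer section, where different parameter groups get different learning rates. A small hedged sketch of the plain-optax counterpart, using stand-in parameter shapes and a `rates` dict mirroring the one in the diff:

    import jax.numpy as jnp
    import optax

    # Stand-in parameters; a real model would take these from the network state.
    params = {'kernel': jnp.zeros((2, 8)), 'bias': jnp.zeros((8,))}

    # One transformation per label, mirroring the `rates` dict in the diff above.
    rates = {'kernel': optax.adam(1e-3), 'bias': optax.adam(1e-2)}

    # Labels share the structure of `params` and route each leaf to a transformation.
    labels = {'kernel': 'kernel', 'bias': 'bias'}

    tx = optax.multi_transform(rates, labels)
    opt_state = tx.init(params)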
@@ -342,7 +342,7 @@ In Flax, we can just wrap the ``MultiSteps`` optimizer with the ``nnx.Optim
  :title: NNX, Jax
  :sync:

- model = nnx_model(nnx_keys)
+ model = nnx_model(rngs)
  nnx_optimizer = nnx.Optimizer(model, tx=optax.MultiSteps(optax.adam(1e-3), every_k_schedule=3), wrt=nnx.Param)

  @nnx.jit
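This hunk touches the gradient-accumulation section. A brief sketch of how `optax.MultiSteps(..., every_k_schedule=3)` behaves on its own, with stand-in parameters and gradients:

    import jax.numpy as jnp
    import optax

    params = {'w': jnp.zeros((2, 8))}   # stand-in parameters
    grads = {'w': jnp.ones((2, 8))}     # stand-in gradients

    tx = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=3)
    opt_state = tx.init(params)

    # The first two calls only accumulate; the third emits the averaged update.
    for _ in range(3):
        updates, opt_state = tx.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)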
@@ -383,7 +383,7 @@ Sharding Optimization State Differently from Parameters
  Say we're doing data parallelism. We want to replicate our parameters across all GPUs so we can do the forward and backward passes without communication latency.


- But we don't need to replicate the optimizer state, as it's not invovled in SPMD computations. One copy is enough, and we can shard this copy across our mesh to reduce memory usage. This means that we need the optimier state to be sharded differently from the parameters themselves.
+ But we don't need to replicate the optimizer state, as it's not involved in SPMD computations. One copy is enough, and we can shard this copy across our mesh to reduce memory usage. This means that we need the optimizer state to be sharded differently from the parameters themselves.


  To do this, we can pass the params initializer given to the optimizer a `sharding` argument. This will shard the optimization state the same way. But when we initialize the model parameters themselves, we won't provide a sharding, allowing for data parallelism.
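The prose in this final hunk describes the intended layout rather than giving code, so as a rough illustration only (the mesh axis name and array shapes below are assumptions): the goal is a parameter replicated on every device alongside an optimizer moment that is sharded across the mesh.

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(jax.devices(), ('data',))
    replicated = NamedSharding(mesh, P())          # full copy on every device
    row_sharded = NamedSharding(mesh, P('data'))   # rows split across the mesh

    n = jax.device_count()
    w = jax.device_put(jnp.zeros((8 * n, 128)), replicated)   # a parameter
    m = jax.device_put(jnp.zeros((8 * n, 128)), row_sharded)  # its optimizer moment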
