File tree Expand file tree Collapse file tree 2 files changed +3
-2
lines changed
examples/mnist_classification Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -165,7 +165,7 @@ metrics = nnx.MultiMetric(
165165 accuracy = nnx.metrics.Accuracy(),
166166 loss = nnx.metrics.Average(" loss" ),
167167)
168- hook = hook_fn(metrics, val_iter, hook_every_n_steps)
168+ hook = hook_fn(metrics, val_iter, hook_every_n_steps)
169169```
170170
171171This creates a hook function ` hook ` that after ` eval_every_n_steps ` steps iterates over the validation set
Original file line number Diff line number Diff line change 88from flax import nnx
99from jax import random as jr
1010from jax .experimental import mesh_utils
11- from model import CNN , train_step , val_step
1211
1312from blaxbird import get_default_checkpointer , train_fn
1413
14+ from model import CNN , train_step , val_step
15+
1516
1617def get_optimizer (model , lr = 1e-4 ):
1718 tx = optax .adamw (lr )
You can’t perform that action at this time.
0 commit comments