Skip to content

Commit 8a69a49

Browse files
committed
merge
2 parents fc48483 + a4b9858 commit 8a69a49

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ metrics = nnx.MultiMetric(
165165
accuracy=nnx.metrics.Accuracy(),
166166
loss=nnx.metrics.Average("loss"),
167167
)
168-
hook = hook_fn(metrics, val_iter, hook_every_n_steps)
168+
hook = hook_fn(metrics, val_iter, hook_every_n_steps)
169169
```
170170

171171
This creates a hook function `hook` that after `eval_every_n_steps` steps iterates over the validation set

examples/mnist_classification/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
from flax import nnx
99
from jax import random as jr
1010
from jax.experimental import mesh_utils
11-
from model import CNN, train_step, val_step
1211

1312
from blaxbird import get_default_checkpointer, train_fn
1413

14+
from model import CNN, train_step, val_step
15+
1516

1617
def get_optimizer(model, lr=1e-4):
1718
tx = optax.adamw(lr)

0 commit comments

Comments
 (0)