Skip to content

Commit fb8ea90

Browse files
Update keras_nnx_guide.md
1 parent 5ec12d4 commit fb8ea90

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

guides/md/keras_nnx_guide.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,8 @@ def train_step(model, optimizer, batch):
223223
y_pred = model_(x)
224224
return jnp.mean((y - y_pred) ** 2)
225225

226-
grads = nnx.grad(loss_fn, wrt=trainable_var)(model)
226+
diff_state = nnx.DiffState(0, trainable_var)
227+
grads = nnx.grad(loss_fn, argnums=diff_state)(model)
227228
optimizer.update(model, grads)
228229

229230

0 commit comments

Comments
 (0)