Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions guides/ipynb/keras_nnx_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,8 @@
" y_pred = model_(x)\n",
" return jnp.mean((y - y_pred) ** 2)\n",
"\n",
" grads = nnx.grad(loss_fn, wrt=trainable_var)(model)\n",
" diff_state = nnx.DiffState(0, trainable_var)\n",
" grads = nnx.grad(loss_fn, argnums=diff_state)(model)\n",
Comment on lines +345 to +346
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For conciseness, you can inline the creation of nnx.DiffState directly into the nnx.grad call. This removes the need for a temporary variable and makes the code a single, expressive line.

Suggested change
" diff_state = nnx.DiffState(0, trainable_var)\n",
" grads = nnx.grad(loss_fn, argnums=diff_state)(model)\n",
" grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, trainable_var))(model)\n",

" optimizer.update(model, grads)\n",
"\n",
"\n",
Expand Down Expand Up @@ -506,4 +507,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
3 changes: 2 additions & 1 deletion guides/keras_nnx_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ def loss_fn(model_):
y_pred = model_(x)
return jnp.mean((y - y_pred) ** 2)

grads = nnx.grad(loss_fn, wrt=trainable_var)(model)
diff_state = nnx.DiffState(0, trainable_var)
grads = nnx.grad(loss_fn, argnums=diff_state)(model)
Comment on lines +211 to +212
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For conciseness, you can inline the creation of nnx.DiffState directly into the nnx.grad call. This removes the need for a temporary variable and makes the code a single, expressive line.

Suggested change
diff_state = nnx.DiffState(0, trainable_var)
grads = nnx.grad(loss_fn, argnums=diff_state)(model)
grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, trainable_var))(model)

optimizer.update(model, grads)


Expand Down
3 changes: 2 additions & 1 deletion guides/md/keras_nnx_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ def train_step(model, optimizer, batch):
y_pred = model_(x)
return jnp.mean((y - y_pred) ** 2)

grads = nnx.grad(loss_fn, wrt=trainable_var)(model)
diff_state = nnx.DiffState(0, trainable_var)
grads = nnx.grad(loss_fn, argnums=diff_state)(model)
Comment on lines +226 to +227
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For conciseness, you can inline the creation of nnx.DiffState directly into the nnx.grad call. This removes the need for a temporary variable and makes the code a single, expressive line. This can help keep code examples in the guide clean and focused.

Suggested change
diff_state = nnx.DiffState(0, trainable_var)
grads = nnx.grad(loss_fn, argnums=diff_state)(model)
grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, trainable_var))(model)

optimizer.update(model, grads)


Expand Down