How to access the gradient value when using flax.nnx.value_and_grad #26863
Unanswered · adamhadani asked this question in Q&A
Replies: 1 comment · 2 replies
-
Hi @adamhadani,

```python
from flax import nnx
import jax.numpy as jnp

class Foo(nnx.Module):
    def __init__(self):
        self.a = nnx.Param(jnp.ones((2, 2)))

def loss_fn(model):
    return jnp.sum(model.a.value)

model = Foo()
grads = nnx.grad(loss_fn)(model)
print(grads)
```

This prints grads as an nnx.State of concrete arrays; here the gradient of a is all ones, since the loss is a plain sum over a.
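For the gradient-norm tracking mentioned in the question, grads can be treated as an ordinary JAX pytree. A minimal sketch, assuming optax is installed; optax.global_norm and the jax.tree_util helpers are standard, but the exact State layout can vary across Flax versions:

```python
import jax
import jax.numpy as jnp
import optax

# grads is an nnx.State; it is registered as a JAX pytree,
# so standard pytree utilities see its array leaves.
global_norm = optax.global_norm(grads)

# Per-parameter L2 norms, keyed by each leaf's path inside the State.
per_param_norms = {
    jax.tree_util.keystr(path): jnp.linalg.norm(leaf)
    for path, leaf in jax.tree_util.tree_leaves_with_path(grads)
}
print(global_norm, per_param_norms)
```

These are concrete jnp arrays, so outside of a jitted function they can be converted with float(...) or .item() and logged to TensorBoard directly.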
-
Hello,
I have recently started using the jax/flax/optax stack (stax?) for some deep reinforcement learning projects.
I ran into the following issue pretty early on. Notice I'm using the
flax.nnx.value_and_grad
"lifted" version ofjax.value_and_grad
that provides some boilerplate for state management and working withnnx.Module
s. A single training step looks like this for example:I would like to track the gradient norms (e.g. for tracking in tensorboard). how do I access the actual numeric value for the computed gradients? the
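(A minimal sketch of such a step; the Policy model, the regression loss, and the nnx.Optimizer setup below are placeholder assumptions, not the original snippet.)

```python
import jax.numpy as jnp
import optax
from flax import nnx

class Policy(nnx.Module):
    """Illustrative stand-in for the actual model."""

    def __init__(self, din, dout, *, rngs):
        self.linear = nnx.Linear(din, dout, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

model = Policy(4, 2, rngs=nnx.Rngs(0))
# The exact Optimizer constructor varies slightly across Flax versions.
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

@nnx.jit
def train_step(model, optimizer, batch):
    def loss_fn(model):
        preds = model(batch["obs"])
        return jnp.mean((preds - batch["target"]) ** 2)

    # Lifted transform: takes the nnx.Module directly and returns
    # the loss together with gradients for the module's parameters.
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # in newer Flax versions: optimizer.update(model, grads)
    return loss
```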
I would like to track the gradient norms (e.g. to log them to TensorBoard). How do I access the actual numeric values of the computed gradients? The grads value returned here seems to be a function/callable that is consumed by optimizer.update, rather than the actual computed gradients?

Sorry if this is a silly question or belongs in some different flax forum!

Cheers,
Adam