Skip to content

Commit f9b8b1b

Browse files
timothyn617KfacJaxDev
authored andcommitted
has_aux fix
PiperOrigin-RevId: 762317513
1 parent 85535ff commit f9b8b1b

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

examples/optimizers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def create_optimizer(
124124
"Schedule Free is only supported for optax optimizers."
125125
)
126126

127-
value_and_grad_func = jax.value_and_grad(train_model_func, has_aux=has_aux)
127+
value_and_grad_func = jax.value_and_grad(
128+
train_model_func, has_aux=has_aux or has_func_state
129+
)
128130

129131
kwargs = dict(**config[name])
130132

0 commit comments

Comments
 (0)