We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 85535ff commit f9b8b1bCopy full SHA for f9b8b1b
examples/optimizers.py
@@ -124,7 +124,9 @@ def create_optimizer(
124
"Schedule Free is only supported for optax optimizers."
125
)
126
127
- value_and_grad_func = jax.value_and_grad(train_model_func, has_aux=has_aux)
+ value_and_grad_func = jax.value_and_grad(
128
+ train_model_func, has_aux=has_aux or has_func_state
129
+ )
130
131
kwargs = dict(**config[name])
132
0 commit comments