We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents fa69b4c + d3d70f7 — commit ce6b3c8 (copy full SHA for ce6b3c8)
ferminet/train.py
@@ -883,7 +883,8 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
883
elif isinstance(optimizer, optax.GradientTransformation):
884
# optax/optax-compatible optimizer (ADAM, LAMB, ...)
885
opt_state = jax.pmap(optimizer.init)(params)
886
- opt_state = opt_state_ckpt or opt_state # avoid overwriting ckpted state
+ if opt_state_ckpt is not None:
887
+ opt_state = tuple(opt_state_ckpt)
888
step = make_training_step(
889
mcmc_step=mcmc_step,
890
optimizer_step=make_opt_update_step(evaluate_loss, optimizer),
0 commit comments