Skip to content

Commit ce6b3c8

Browse files
committed
Merge pull request #105 from Qianruipku:hotfix2
PiperOrigin-RevId: 807836181 Change-Id: Id83ca6c48188051d92c88d3983ae8712e2186864
2 parents fa69b4c + d3d70f7 commit ce6b3c8

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

ferminet/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,8 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
883883
elif isinstance(optimizer, optax.GradientTransformation):
884884
# optax/optax-compatible optimizer (ADAM, LAMB, ...)
885885
opt_state = jax.pmap(optimizer.init)(params)
886-
opt_state = opt_state_ckpt or opt_state # avoid overwriting ckpted state
886+
if opt_state_ckpt is not None:
887+
opt_state = tuple(opt_state_ckpt)
887888
step = make_training_step(
888889
mcmc_step=mcmc_step,
889890
optimizer_step=make_opt_update_step(evaluate_loss, optimizer),

0 commit comments

Comments
 (0)