We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 48931c6 commit a13b4deCopy full SHA for a13b4de
ferminet/train.py
@@ -979,8 +979,9 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
979
loss = loss[0]
980
# per batch variance isn't informative. Use weighted mean and variance
981
# instead.
982
- weighted_stats = statistics.exponentialy_weighted_stats(
983
- alpha=0.1, observation=loss, previous_stats=weighted_stats)
+ if not jnp.isnan(loss):
+ weighted_stats = statistics.exponentialy_weighted_stats(
984
+ alpha=0.1, observation=loss, previous_stats=weighted_stats)
985
pmove = pmove[0]
986
987
# Update observables
0 commit comments