Skip to content

Commit a13b4de

Browse files
committed
fix: prevent NaN loss from updating weighted statistics
1 parent 48931c6 commit a13b4de

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

ferminet/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -979,8 +979,9 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
979979
loss = loss[0]
980980
# per batch variance isn't informative. Use weighted mean and variance
981981
# instead.
982-
weighted_stats = statistics.exponentialy_weighted_stats(
983-
alpha=0.1, observation=loss, previous_stats=weighted_stats)
982+
if not jnp.isnan(loss):
983+
weighted_stats = statistics.exponentialy_weighted_stats(
984+
alpha=0.1, observation=loss, previous_stats=weighted_stats)
984985
pmove = pmove[0]
985986

986987
# Update observables

0 commit comments

Comments
 (0)