Skip to content

Commit 48931c6

Browse files
committed
Fix multi-device NaN check in KFAC training step
Replace direct boolean check with jnp.any() to handle loss arrays from multiple devices when detecting NaN values.
1 parent fa69b4c commit 48931c6

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

ferminet/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ def step(
378378
damping=shared_damping,
379379
)
380380

381-
if reset_if_nan and jnp.isnan(stats['loss']):
381+
if reset_if_nan and jnp.any(jnp.isnan(stats['loss'])):
382382
new_params = old_params
383383
new_state = old_state
384384
return data, new_params, new_state, stats['loss'], stats['aux'], pmove

0 commit comments

Comments
 (0)