Skip to content

Commit face046

Browse files
james-martens, KfacJaxDev
authored and committed
Minor bug fix: changing train_step in the example code to take the mean of the stats across devices instead of taking them from the first device. Because the optimizer syncs its own stats (like loss), this didn't matter except for stats returned from the kfac_jax optimizer (or Optax optimizers using OptaxWrapper). However, the Polyak-averaged loss wasn't actually synced across devices (as it's not part of the optimizer anymore), so "loss_polyak" was being reported only for the first device.
PiperOrigin-RevId: 701318626
1 parent 4de99f5 commit face046

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

examples/training.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -621,10 +621,6 @@ def train_step(self, global_step: Array, rng: PRNGKey) -> dict[str, Numeric]:
621621
if "aux" in stats:
622622
stats.update(stats.pop("aux", {}))
623623

624-
self._python_step += 1
625-
626-
stats["progress"] = self.progress(self._python_step)
627-
628624
for name in self.config.get("per_device_stats_to_log", []):
629625
gathered_stat = jnp.reshape(
630626
kfac_jax.utils.host_all_gather(stats[name]), [-1]
@@ -633,7 +629,12 @@ def train_step(self, global_step: Array, rng: PRNGKey) -> dict[str, Numeric]:
633629
for i in range(gathered_stat.shape[0]):
634630
stats[f"{name}_{i}"] = jnp.array([gathered_stat[i]])
635631

636-
return kfac_jax.utils.get_first(stats)
632+
stats = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), stats)
633+
634+
self._python_step += 1
635+
stats["progress"] = self.progress(self._python_step)
636+
637+
return stats
637638

638639
# _
639640
# _____ ____ _| |
@@ -774,26 +775,35 @@ def run_evaluation(
774775
global_step, self._params, self._state, self._opt_state, key, batch)
775776

776777
if params_polyak is not None:
778+
777779
stats_no_polyak = stats
780+
778781
stats = self.eval_batch_pmap(
779782
global_step, params_polyak, func_state_polyak, self._opt_state,
780783
key, batch)
784+
781785
stats.update(
782786
{k + "_no_polyak": v for k, v in stats_no_polyak.items()
783787
if k != "data_seen"})
784788

785789
if params_schedule_free is not None:
790+
786791
stats_no_sf = stats
792+
787793
stats = self.eval_batch_pmap(
788794
global_step, params_schedule_free, func_state_schedule_free,
789795
self._opt_state, key, batch)
796+
790797
stats.update(
791798
{k + "_no_sf": v for k, v in stats_no_sf.items()
792799
if k != "data_seen"})
793800

794801
averaged_stats.add(stats, 1)
795802

796-
# Extract all stats
803+
# Extract all stats.
804+
# Note that MultiChunkAccumulator.value will perform a pmean
805+
# automatically, so it's fine to call "get_first" here instead of taking
806+
# the mean.
797807
for k, v in averaged_stats.value.items(): # pytype: disable=attribute-error
798808
all_stats[f"{name}_{k}"] = kfac_jax.utils.get_first(v)
799809

0 commit comments

Comments
 (0)