Commit face046
Minor bug fix: changing train_step in examples code to take a mean of the stats instead of taking from the first device. Because the optimizer syncs its own stats (like loss), this didn't matter except for stats returned from the kfac_jax optimizer (or Optax optimizers using OptaxWrapper). However, the Polyak averaged loss wasn't actually synced across devices (as its not part of the optimizer anymore), so "loss_polyak" was being reported only for the first device.
PiperOrigin-RevId: 7013186261 parent 4de99f5 commit face046
1 file changed
+16
-6
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
621 | 621 | | |
622 | 622 | | |
623 | 623 | | |
624 | | - | |
625 | | - | |
626 | | - | |
627 | | - | |
628 | 624 | | |
629 | 625 | | |
630 | 626 | | |
| |||
633 | 629 | | |
634 | 630 | | |
635 | 631 | | |
636 | | - | |
| 632 | + | |
| 633 | + | |
| 634 | + | |
| 635 | + | |
| 636 | + | |
| 637 | + | |
637 | 638 | | |
638 | 639 | | |
639 | 640 | | |
| |||
774 | 775 | | |
775 | 776 | | |
776 | 777 | | |
| 778 | + | |
777 | 779 | | |
| 780 | + | |
778 | 781 | | |
779 | 782 | | |
780 | 783 | | |
| 784 | + | |
781 | 785 | | |
782 | 786 | | |
783 | 787 | | |
784 | 788 | | |
785 | 789 | | |
| 790 | + | |
786 | 791 | | |
| 792 | + | |
787 | 793 | | |
788 | 794 | | |
789 | 795 | | |
| 796 | + | |
790 | 797 | | |
791 | 798 | | |
792 | 799 | | |
793 | 800 | | |
794 | 801 | | |
795 | 802 | | |
796 | | - | |
| 803 | + | |
| 804 | + | |
| 805 | + | |
| 806 | + | |
797 | 807 | | |
798 | 808 | | |
799 | 809 | | |
| |||
0 commit comments