Commit dc9c815
[kfac_jax] Prepare for
Condition on `jax_pmap_shmap_merge` to grab explicitly grab the first shard rather than rely on `x[0]`. See https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays for more information.
NOTE: `kfac_jax.utils.index_if_not_scalar` has been removed and inlined into `kfac_jax.utils.get_first`.
- If you need to grab the first shard of a semantically replicated array (i.e., each shard is really a replica), use `kfac_jax.utils.get_first`.
- If you really are trying to index into an array, use the usual array indexing operators (i.e., `x[0]`).
PiperOrigin-RevId: 846329906jax_pmap_shmap_merge=True.1 parent c0d25ad commit dc9c815
File tree
3 files changed
+25
-18
lines changed- examples
- kfac_jax/_src/utils
3 files changed
+25
-18
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
520 | 520 | | |
521 | 521 | | |
522 | 522 | | |
523 | | - | |
| 523 | + | |
| 524 | + | |
524 | 525 | | |
525 | 526 | | |
526 | 527 | | |
| |||
541 | 542 | | |
542 | 543 | | |
543 | 544 | | |
544 | | - | |
| 545 | + | |
| 546 | + | |
545 | 547 | | |
546 | 548 | | |
547 | 549 | | |
| |||
671 | 673 | | |
672 | 674 | | |
673 | 675 | | |
674 | | - | |
| 676 | + | |
| 677 | + | |
| 678 | + | |
| 679 | + | |
675 | 680 | | |
676 | 681 | | |
677 | 682 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
73 | 73 | | |
74 | 74 | | |
75 | 75 | | |
76 | | - | |
77 | 76 | | |
78 | 77 | | |
79 | 78 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
82 | 82 | | |
83 | 83 | | |
84 | 84 | | |
85 | | - | |
86 | | - | |
87 | | - | |
88 | | - | |
| 85 | + | |
| 86 | + | |
89 | 87 | | |
90 | | - | |
91 | | - | |
92 | | - | |
93 | | - | |
| 88 | + | |
| 89 | + | |
94 | 90 | | |
95 | | - | |
96 | | - | |
97 | | - | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
98 | 100 | | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
99 | 104 | | |
100 | | - | |
101 | | - | |
102 | | - | |
| 105 | + | |
103 | 106 | | |
104 | 107 | | |
105 | 108 | | |
| |||
0 commit comments