Commit 1b2cb0f
Adding JAX config change
PiperOrigin-RevId: 840716303jax_pmap_shmap_merge=False to the top of the optimizer.py file. The new pmap implementation currently seems to slow down projects using kfac_jax.1 parent ae4f29d commit 1b2cb0f
1 file changed
+6
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
25 | 25 | | |
26 | 26 | | |
27 | 27 | | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
28 | 34 | | |
29 | 35 | | |
30 | 36 | | |
| |||
0 commit comments