Skip to content

Commit 1b2cb0f

Browse files
james-martensKfacJaxDev
authored andcommitted
Adding JAX config change jax_pmap_shmap_merge=False to the top of the optimizer.py file. The new pmap implementation currently seems to slow down projects using kfac_jax.
PiperOrigin-RevId: 840716303
1 parent ae4f29d commit 1b2cb0f

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

kfac_jax/_src/optimizer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@
2525
from kfac_jax._src import utils
2626
from typing_extensions import Self
2727

28+
# For now we are opting out of JAX's new pmap simulation in favor using the old
29+
# pmap implementation. This is because the new simulation currently leads to
30+
# a compute performance regression in some experiments.
31+
if "jax_pmap_shmap_merge" in jax.config.values:
32+
jax.config.update("jax_pmap_shmap_merge", False)
33+
2834

2935
# Types for annotation
3036
Array = utils.Array

0 commit comments

Comments
 (0)