@@ -193,8 +193,44 @@ def reduce_step(reduced: M, metric: M) -> tuple[M, None]:
193193 # pylint: disable-next=protected-access
194194 return reduced ._reduce_merge (metric ), None
195195
196- first = jax .tree_util .tree_map (lambda x : x [0 ], self )
197- remainder = jax .tree_util .tree_map (lambda x : x [1 :], self )
196+ # Avoid degraded performance under the new jax.pmap. See
197+ # https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays.
198+ # Only use the sharding path for concrete sharded arrays, not tracers.
199+ def _is_concrete_sharded (x ):
200+ if isinstance (x , jax .core .Tracer ):
201+ return False
202+ if not hasattr (x , "addressable_shards" ):
203+ return False
204+ shards = x .addressable_shards
205+ if not shards :
206+ return False
207+ # Only use sharding path when shards have shape (1, ...) from pmap
208+ return shards [0 ].data .ndim > 0 and shards [0 ].data .shape [0 ] == 1
209+
210+ leaves = jax .tree_util .tree_leaves (self )
211+ use_sharding_path = (
212+ jax .config .jax_pmap_shmap_merge
213+ and leaves
214+ and _is_concrete_sharded (leaves [0 ])
215+ )
216+
217+ if use_sharding_path :
218+
219+ def get_first (x ):
220+ return x .addressable_shards [0 ].data .squeeze (0 )
221+
222+ def get_remainder (x ):
223+ shards = x .addressable_shards
224+ if len (shards ) <= 1 :
225+ shape = shards [0 ].data .squeeze (0 ).shape
226+ return jnp .empty ((0 ,) + shape , dtype = shards [0 ].data .dtype )
227+ return jnp .stack ([s .data .squeeze (0 ) for s in shards [1 :]], axis = 0 )
228+
229+ first = jax .tree_util .tree_map (get_first , self )
230+ remainder = jax .tree_util .tree_map (get_remainder , self )
231+ else :
232+ first = jax .tree_util .tree_map (lambda x : x [0 ], self )
233+ remainder = jax .tree_util .tree_map (lambda x : x [1 :], self )
198234 # According to b/160868467#comment4, usage of `jax.lax.scan` does not add a
199235 # significant computational cost for simple metrics where e.g. `jnp.sum`
200236 # could be used instead.
0 commit comments