Skip to content

Commit 83ac17a

Browse files
danielsuocopybara-github
authored andcommitted
No public description
PiperOrigin-RevId: 853251958
1 parent 52dff62 commit 83ac17a

File tree

1 file changed

+38
-2
lines changed

1 file changed

+38
-2
lines changed

clu/metrics.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,44 @@ def reduce_step(reduced: M, metric: M) -> tuple[M, None]:
193193
# pylint: disable-next=protected-access
194194
return reduced._reduce_merge(metric), None
195195

196-
first = jax.tree_util.tree_map(lambda x: x[0], self)
197-
remainder = jax.tree_util.tree_map(lambda x: x[1:], self)
196+
# Avoid degraded performance under the new jax.pmap. See
197+
# https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays.
198+
# Only use the sharding path for concrete sharded arrays, not tracers.
199+
def _is_concrete_sharded(x):
200+
if isinstance(x, jax.core.Tracer):
201+
return False
202+
if not hasattr(x, "addressable_shards"):
203+
return False
204+
shards = x.addressable_shards
205+
if not shards:
206+
return False
207+
# Only use sharding path when shards have shape (1, ...) from pmap
208+
return shards[0].data.ndim > 0 and shards[0].data.shape[0] == 1
209+
210+
leaves = jax.tree_util.tree_leaves(self)
211+
use_sharding_path = (
212+
jax.config.jax_pmap_shmap_merge
213+
and leaves
214+
and _is_concrete_sharded(leaves[0])
215+
)
216+
217+
if use_sharding_path:
218+
219+
def get_first(x):
220+
return x.addressable_shards[0].data.squeeze(0)
221+
222+
def get_remainder(x):
223+
shards = x.addressable_shards
224+
if len(shards) <= 1:
225+
shape = shards[0].data.squeeze(0).shape
226+
return jnp.empty((0,) + shape, dtype=shards[0].data.dtype)
227+
return jnp.stack([s.data.squeeze(0) for s in shards[1:]], axis=0)
228+
229+
first = jax.tree_util.tree_map(get_first, self)
230+
remainder = jax.tree_util.tree_map(get_remainder, self)
231+
else:
232+
first = jax.tree_util.tree_map(lambda x: x[0], self)
233+
remainder = jax.tree_util.tree_map(lambda x: x[1:], self)
198234
# According to b/160868467#comment4, usage of `jax.lax.scan` does not add a
199235
# significant computational cost for simple metrics where e.g. `jnp.sum`
200236
# could be used instead.

0 commit comments

Comments
 (0)