
Commit f4bebae

danielsuo authored and Google-ML-Automation committed
[pmap] Cache the computation of effective mesh devices.

Clean up and combine the logic for computing the global axis size and mesh devices. This improves the `jax.jit(jax.shard_map)` implementation of `jax.pmap`.

PiperOrigin-RevId: 858891987
1 parent 9f327f6 · commit f4bebae
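The key change is memoizing the device computation: the new `_get_mesh_devices` helper (added in the diff below) is decorated with `@lru_cache`, so repeated `pmap` calls with the same devices, backend, and axis sizes reuse the previously computed device tuple instead of re-deriving it. A minimal standalone sketch of that pattern, assuming a hypothetical `pick_devices` helper in place of the real JAX code:

```python
# Minimal sketch (not the JAX source): how an lru_cache-wrapped helper avoids
# recomputing a device list when pmap is called repeatedly with the same
# configuration. `pick_devices` and its arguments are hypothetical.
from functools import lru_cache

calls = 0

@lru_cache
def pick_devices(backend, local_axis_size, axis_size):
  global calls
  calls += 1
  # Stand-in for querying the runtime for available devices.
  all_devices = tuple(f"{backend}:{i}" for i in range(8))
  return all_devices[:axis_size if axis_size is not None else local_axis_size]

pick_devices("cpu", 4, None)  # computed once
pick_devices("cpu", 4, None)  # served from the cache, no recomputation
assert calls == 1
```

Because the arguments are hashable (tuples, strings, ints, bools), `lru_cache` can key the cache on them directly; this is the same property the real `_get_mesh_devices` relies on.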

File tree

1 file changed: +79 −55 lines


jax/_src/pmap.py

Lines changed: 79 additions & 55 deletions
```diff
@@ -63,6 +63,8 @@ def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
   wrapped_fun = _pmap_wrap_init(f, static_broadcasted_tuple)
 
   def infer_params(*args, **kwargs):
+    process_count = xb.process_count(backend)
+    trace_state_clean = core.trace_state_clean()
     fun, dyn_argnums, dyn_args = _get_dyn_args(
         wrapped_fun, static_broadcasted_tuple, args)
     dyn_args_flat, dyn_args_tree = tree_flatten((dyn_args, kwargs))
@@ -72,20 +74,17 @@ def infer_params(*args, **kwargs):
     local_axis_size = _mapped_axis_size(dyn_args_flat, in_axes_flat)
     donated_invars = _get_donated_invars(
         donate_tuple, dyn_args_tree, len(dyn_args_flat))
+    mesh_devices = _get_mesh_devices(
+        devices, backend, local_axis_size, axis_size, trace_state_clean)
     fun, out_axes_thunk = flat_out_axes(fun, out_axes)
     flat_fun, out_tree = flatten_fun(fun, dyn_args_tree)
-    global_axis_size = _get_global_axis_size(local_axis_size, devices,
-                                             backend, axis_size)
-    trace_state_clean = core.trace_state_clean()
-    mesh = Mesh(
-        _get_devices(devices, local_axis_size, global_axis_size, backend),
-        (axis_name,))
+    mesh = Mesh(mesh_devices, (axis_name,))
     _pmapped, in_specs, out_specs = _cached_shard_map(
         flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name)
     jitted_f = api.jit(
         _pmapped,
         donate_argnums=[i for i, val in enumerate(donated_invars) if val])
-    if xb.process_count() > 1:
+    if process_count > 1:
       if trace_state_clean:
         flat_global_args = [
             host_local_array_to_global_array(arr, global_mesh=mesh, pspec=spec)
@@ -95,15 +94,16 @@ def infer_params(*args, **kwargs):
             dyn_args_flat, mesh, list(in_specs))
       else:
         flat_global_args = dyn_args_flat
-    return jitted_f, flat_global_args, dyn_args_tree, out_tree, mesh, out_specs, donate_tuple
+    return (jitted_f, flat_global_args, dyn_args_tree, out_tree, mesh, out_specs,
+            donate_tuple, process_count, trace_state_clean)
 
   @util.wraps(f)
   def wrapped(*args, **kwargs):
-    jitted_f, flat_global_args, _, out_tree, mesh, out_specs, _ = infer_params(
-        *args, **kwargs)
+    (jitted_f, flat_global_args, _, out_tree, mesh, out_specs,
+     _, process_count, trace_state_clean) = infer_params(*args, **kwargs)
     outs = jitted_f(*flat_global_args)
-    if xb.process_count() > 1:
-      if core.trace_state_clean():
+    if process_count > 1:
+      if trace_state_clean:
         outs = [
             global_array_to_host_local_array(out, global_mesh=mesh, pspec=spec)
             for out, spec in zip(outs, out_specs())
@@ -113,7 +113,7 @@ def wrapped(*args, **kwargs):
     return tree_unflatten(out_tree(), outs)
 
   def lower(*args, **kwargs):
-    jitted_f, flat_global_args, in_tree, out_tree, _, _, donate_tuple = infer_params(
+    jitted_f, flat_global_args, in_tree, out_tree, _, _, donate_tuple, _, _ = infer_params(
         *args, **kwargs)
     abstract_args = list(map(core.shaped_abstractify, flat_global_args))
     args_info = stages.make_args_info(in_tree, abstract_args, donate_tuple)
@@ -150,22 +150,6 @@ def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs):
   return tree_map(lambda x, ax: x if ax is None else lax.expand_dims(x, [ax]),
                   list(out), list(out_axes_thunk()))
 
-def _get_devices(devices, local_axis_size, global_axis_size, backend):
-  if backend is not None and devices is None:
-    devs = xb.devices(backend=backend)
-  else:
-    devs = xb.devices() if devices is None else devices
-  if xb.process_count() > 1:
-    return devs[:global_axis_size]
-  return devs[:local_axis_size]
-
-
-def _ensure_index_tuple(x) -> tuple[int, ...]:
-  try:
-    return (int(x),)
-  except TypeError:
-    return tuple(int(i) for i in x)
-
 
 def _mapped_axis_size(args, in_axes):
   """Infer axis size from the first mapped argument.
@@ -194,32 +178,6 @@ def _mapped_axis_size(args, in_axes):
   raise ValueError("pmap requires at least one argument with a mapped axis.")
 
 
-def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
-                          global_axis_size: int | None):
-  if (xb.process_count() == 1 and global_axis_size is not None and
-      global_axis_size != local_axis_size):
-    raise ValueError(
-        f"Specified axis_size {global_axis_size} doesn't match received "
-        f"axis_size {local_axis_size}.")
-
-  if in_devices is not None and backend_name is None:
-    backend = xb.get_device_backend(in_devices[0])
-  else:
-    backend = xb.get_backend(backend_name)
-
-  if global_axis_size is None:
-    if xb.process_count(backend) == 1:
-      global_axis_size = local_axis_size
-    elif in_devices is not None:
-      global_axis_size = len(in_devices)
-    else:
-      global_axis_size = local_axis_size * xb.process_count(backend)
-      assert all(
-          len(xb.local_devices(pi, backend)) == xb.local_device_count(backend)
-          for pi in range(xb.process_count(backend)))
-  return global_axis_size
-
-
 def _pmap_wrap_init(f, static_broadcasted_tuple):
   """Create a wrapped function with DebugInfo for pmap.
 
@@ -387,6 +345,72 @@ def _get_donated_invars(donate_tuple, in_tree, num_flat_args):
   return (False,) * num_flat_args
 
 
+@lru_cache
+def _get_mesh_devices(devices, backend, local_axis_size, axis_size,
+                      trace_state_clean):
+  """Compute effective mesh devices based on context.
+
+  Args:
+    devices: The mesh devices tuple.
+    backend: The backend to use.
+    local_axis_size: The local axis size (per-process).
+    axis_size: User-specified global axis size (optional).
+    trace_state_clean: True if in execution mode (not tracing).
+
+  Returns:
+    Tuple of effective mesh devices sliced appropriately.
+
+  Raises:
+    ValueError: If axis_size doesn't match inferred size in single-process.
+  """
+  process_count = xb.process_count(backend)
+
+  # Validate explicit axis_size in single-process mode
+  if (process_count == 1 and axis_size is not None and
+      axis_size != local_axis_size):
+    raise ValueError(
+        f"Specified axis_size {axis_size} doesn't match received "
+        f"axis_size {local_axis_size}.")
+
+  # Compute global_axis_size
+  if axis_size is not None:
+    global_axis_size = axis_size
+  elif process_count > 1:
+    global_axis_size = local_axis_size * process_count
+    # Validate all processes have the same number of local devices
+    assert all(
+        len(xb.local_devices(pi, backend)) == xb.local_device_count(backend)
+        for pi in range(process_count))
+  else:
+    global_axis_size = local_axis_size
+
+  # Determine mesh devices
+  if devices is not None:
+    mesh_devices = devices
+  elif process_count > 1:
+    # Multi-process: group devices by process (host) for optimal collective
+    # performance. This matches the old pmap's device ordering which uses
+    # local_devices(process_index) in a nested loop, ensuring devices from
+    # the same host are contiguous in the mesh.
+    # pylint: disable=g-complex-comprehension
+    mesh_devices = tuple(
+        d
+        for process_index in range(process_count)
+        for d in xb.local_devices(process_index, backend)
+    )
+    # pylint: enable=g-complex-comprehension
+  elif backend is not None:
+    mesh_devices = tuple(xb.devices(backend=backend))
+  else:
+    mesh_devices = tuple(xb.devices())
+
+  if not trace_state_clean and process_count > 1:
+    # Tracing in multihost: use local devices
+    return tuple(xb.local_devices(backend=backend)[:local_axis_size])
+  else:
+    return mesh_devices[:global_axis_size]
+
+
 @lru_cache
 def _local_to_global_aval(local_aval, mesh, pspec):
   pspec = sharding_impls.prepare_axis_resources(pspec, 'pspec to array_mapping')
```
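The multi-process branch of `_get_mesh_devices` orders devices process by process, so devices from the same host end up contiguous in the mesh (matching the old pmap ordering, as the inline comment notes). A toy sketch of that ordering, using a hypothetical `fake_local_devices` in place of `xb.local_devices(process_index, backend)`:

```python
# Toy illustration of the per-process device ordering used above;
# `fake_local_devices` is a made-up stand-in for xb.local_devices(...).
def fake_local_devices(process_index, devices_per_process=2):
  return [f"host{process_index}/dev{d}" for d in range(devices_per_process)]

process_count = 3
mesh_devices = tuple(
    d
    for process_index in range(process_count)
    for d in fake_local_devices(process_index)
)
# Devices from the same host are contiguous:
# ('host0/dev0', 'host0/dev1', 'host1/dev0', 'host1/dev1',
#  'host2/dev0', 'host2/dev1')
print(mesh_devices)
```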

0 commit comments
