Skip to content

Commit 692d820

Browse files
danielsuo authored and Google-ML-Automation committed
[pmap] Inline _prepare_pmap and clean up unused structs.
Improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`.

PiperOrigin-RevId: 861777664
1 parent 762f066 commit 692d820

File tree

2 files changed

+39
-66
lines changed

2 files changed

+39
-66
lines changed

jax/_src/pmap.py

Lines changed: 34 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
from functools import lru_cache, partial
17-
from typing import Any, Callable, NamedTuple, Sequence
17+
from typing import Any
1818

1919
from jax._src import api
2020
from jax._src.api_util import (
@@ -60,33 +60,46 @@ def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
6060
f, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, out_axes)
6161
if isinstance(axis_name, core._TempAxisName): # pylint: disable=protected-access
6262
axis_name = repr(axis_name)
63-
fun = _pmap_wrap_init(f, static_broadcasted_tuple)
63+
wrapped_fun = _pmap_wrap_init(f, static_broadcasted_tuple)
6464

6565
def infer_params(*args, **kwargs):
66-
p = _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
67-
donate_tuple, devices, backend, axis_size, args, kwargs)
66+
fun, dyn_argnums, dyn_args = _get_dyn_args(
67+
wrapped_fun, static_broadcasted_tuple, args)
68+
dyn_args_flat, dyn_args_tree = tree_flatten((dyn_args, kwargs))
69+
in_axes_flat = _get_in_axes_flat(
70+
in_axes, dyn_argnums, dyn_args, kwargs, len(dyn_args_flat),
71+
dyn_args_tree)
72+
local_axis_size = _mapped_axis_size(dyn_args_flat, in_axes_flat)
73+
donated_invars = _get_donated_invars(
74+
donate_tuple, dyn_args_tree, len(dyn_args_flat))
75+
fun, out_axes_thunk = flat_out_axes(fun, out_axes)
76+
flat_fun, out_tree = flatten_fun(fun, dyn_args_tree)
77+
global_axis_size = _get_global_axis_size(local_axis_size, devices,
78+
backend, axis_size)
6879
trace_state_clean = core.trace_state_clean()
69-
mesh = Mesh(_get_devices(p, backend), (axis_name,))
80+
mesh = Mesh(
81+
_get_devices(devices, local_axis_size, global_axis_size, backend),
82+
(axis_name,))
7083
_pmapped, in_specs, out_specs = _cached_shard_map(
71-
p.flat_fun, mesh, p.in_axes_flat, p.out_axes_thunk, axis_name)
84+
flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name)
7285
jitted_f = api.jit(
7386
_pmapped,
74-
donate_argnums=[i for i, val in enumerate(p.donated_invars) if val])
87+
donate_argnums=[i for i, val in enumerate(donated_invars) if val])
7588
if xb.process_count() > 1:
7689
if trace_state_clean:
7790
flat_global_args = [
7891
host_local_array_to_global_array(arr, global_mesh=mesh, pspec=spec)
79-
for arr, spec in zip(p.flat_args, in_specs)]
92+
for arr, spec in zip(dyn_args_flat, in_specs)]
8093
else:
8194
flat_global_args = mhu.host_local_array_to_global_array(
82-
p.flat_args, mesh, list(in_specs))
95+
dyn_args_flat, mesh, list(in_specs))
8396
else:
84-
flat_global_args = p.flat_args
85-
return jitted_f, flat_global_args, p, mesh, out_specs, donate_tuple
97+
flat_global_args = dyn_args_flat
98+
return jitted_f, flat_global_args, dyn_args_tree, out_tree, mesh, out_specs, donate_tuple
8699

87100
@util.wraps(f)
88101
def wrapped(*args, **kwargs):
89-
jitted_f, flat_global_args, p, mesh, out_specs, _ = infer_params(
102+
jitted_f, flat_global_args, _, out_tree, mesh, out_specs, _ = infer_params(
90103
*args, **kwargs)
91104
outs = jitted_f(*flat_global_args)
92105
if xb.process_count() > 1:
@@ -97,15 +110,15 @@ def wrapped(*args, **kwargs):
97110
]
98111
else:
99112
outs = mhu.global_array_to_host_local_array(outs, mesh, out_specs())
100-
return tree_unflatten(p.out_tree(), outs)
113+
return tree_unflatten(out_tree(), outs)
101114

102115
def lower(*args, **kwargs):
103-
jitted_f, flat_global_args, p, _, _, donate_tuple = infer_params(
116+
jitted_f, flat_global_args, in_tree, out_tree, _, _, donate_tuple = infer_params(
104117
*args, **kwargs)
105118
abstract_args = list(map(core.shaped_abstractify, flat_global_args))
106-
args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple)
119+
args_info = stages.make_args_info(in_tree, abstract_args, donate_tuple)
107120
lowered = jitted_f.trace(*flat_global_args).lower()
108-
lowered = stages.Lowered(lowered._lowering, args_info, p.out_tree(),
121+
lowered = stages.Lowered(lowered._lowering, args_info, out_tree(),
109122
no_kwargs=lowered._no_kwargs)
110123
return lowered
111124
wrapped.lower = lower
@@ -137,14 +150,14 @@ def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs):
137150
return tree_map(lambda x, ax: x if ax is None else lax.expand_dims(x, [ax]),
138151
list(out), list(out_axes_thunk()))
139152

140-
def _get_devices(p, backend):
141-
if backend is not None and p.devices is None:
153+
def _get_devices(devices, local_axis_size, global_axis_size, backend):
154+
if backend is not None and devices is None:
142155
devs = xb.devices(backend=backend)
143156
else:
144-
devs = xb.devices() if p.devices is None else p.devices
157+
devs = xb.devices() if devices is None else devices
145158
if xb.process_count() > 1:
146-
return devs[:p.global_axis_size]
147-
return devs[:p.local_axis_size]
159+
return devs[:global_axis_size]
160+
return devs[:local_axis_size]
148161

149162

150163
def _ensure_index_tuple(x) -> tuple[int, ...]:
@@ -181,20 +194,6 @@ def _mapped_axis_size(args, in_axes):
181194
raise ValueError("pmap requires at least one argument with a mapped axis.")
182195

183196

184-
class PmapCallInfo(NamedTuple):
185-
flat_fun: lu.WrappedFun
186-
in_tree: Any
187-
out_tree: Callable
188-
flat_args: Sequence[Any]
189-
donated_invars: Sequence[bool]
190-
in_axes_flat: Sequence[int | None]
191-
local_axis_size: int
192-
out_axes_thunk: Callable
193-
devices: Sequence | None
194-
global_axis_size: int
195-
is_explicit_global_axis_size: bool
196-
197-
198197
def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
199198
global_axis_size: int | None):
200199
if (xb.process_count() == 1 and global_axis_size is not None and
@@ -221,36 +220,6 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
221220
return global_axis_size
222221

223222

224-
def _prepare_pmap(fun: lu.WrappedFun, in_axes, out_axes,
225-
static_broadcasted_tuple, donate_tuple, in_devices,
226-
backend_name, axis_size, args, kwargs):
227-
fun, dyn_argnums, dyn_args = _get_dyn_args(fun, static_broadcasted_tuple, args)
228-
dyn_args_flat, dyn_args_tree = tree_flatten((dyn_args, kwargs))
229-
in_axes_flat = _get_in_axes_flat(in_axes, dyn_argnums, dyn_args, kwargs,
230-
len(dyn_args_flat), dyn_args_tree)
231-
local_axis_size = _mapped_axis_size(dyn_args_flat, in_axes_flat)
232-
donated_invars = _get_donated_invars(donate_tuple, dyn_args_tree,
233-
len(dyn_args_flat))
234-
235-
fun, out_axes_thunk = flat_out_axes(fun, out_axes)
236-
flat_fun, out_tree = flatten_fun(fun, dyn_args_tree)
237-
238-
is_explicit_global_axis_size = axis_size is not None
239-
global_axis_size = _get_global_axis_size(local_axis_size, in_devices,
240-
backend_name, axis_size)
241-
return PmapCallInfo(flat_fun=flat_fun,
242-
in_tree=dyn_args_tree,
243-
out_tree=out_tree,
244-
flat_args=dyn_args_flat,
245-
donated_invars=donated_invars,
246-
in_axes_flat=in_axes_flat,
247-
local_axis_size=local_axis_size,
248-
out_axes_thunk=out_axes_thunk,
249-
devices=None if in_devices is None else tuple(in_devices),
250-
global_axis_size=global_axis_size,
251-
is_explicit_global_axis_size=is_explicit_global_axis_size)
252-
253-
254223
def _pmap_wrap_init(f, static_broadcasted_tuple):
255224
"""Create a wrapped function with DebugInfo for pmap.
256225

tests/pmap_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,11 @@ def testInAxesPyTreePrefixMismatchErrorKwargs(self):
689689
def testOutAxesPyTreePrefixMismatchError(self):
690690
x = jnp.array([3.14])
691691
f = jax.pmap(lambda x, y: ((x, x), x), out_axes=((0, 0, 0), 0))
692-
with self.assertRaisesRegex(ValueError, re.escape("pmap out_axes[0]")):
692+
if config.pmap_shmap_merge.value:
693+
regex = "pytree structure error: different lengths of tuple at key path.*"
694+
else:
695+
regex = re.escape("pmap out_axes[0]")
696+
with self.assertRaisesRegex(ValueError, regex):
693697
f(x, x)
694698

695699
@parameterized.named_parameters(

0 commit comments

Comments (0)