Commit 49561f6

danielsuo authored and Google-ML-Automation committed
[pmap] Remove any wrappings that have auxiliary values.
We want to avoid StoreException or StoreEmpty errors. The previous code added another transformation to reset stores. The new code does away with this and also lets us remove some unnecessary flatten/unflatten steps, improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`.

PiperOrigin-RevId: 861864367
1 parent 1cf6690 commit 49561f6
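
For orientation, the pattern the commit message refers to, implementing a pmap-like transform as `jax.jit` around `jax.shard_map`, can be sketched with public APIs only. This is a hypothetical illustration (the names `my_pmap` and `per_shard` are made up, and it assumes a JAX release that exports `jax.shard_map`); it is not the code in this file:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

def my_pmap(f, axis_name="i"):
  """Map f over the leading axis by sharding it across all devices."""
  mesh = Mesh(np.array(jax.devices()), (axis_name,))

  def per_shard(x):
    # Each shard carries a leading axis of size 1: drop it, apply f,
    # then reinsert it, mirroring the squeeze/expand_dims in the diff below.
    return jnp.expand_dims(f(jnp.squeeze(x, 0)), 0)

  sharded = jax.shard_map(per_shard, mesh=mesh,
                          in_specs=P(axis_name), out_specs=P(axis_name))
  return jax.jit(sharded)

x = jnp.arange(jax.device_count() * 4.0).reshape(jax.device_count(), 4)
print(my_pmap(lambda v: v * 2)(x))
```

On a single-device host this still runs, with one shard covering the whole array.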


jax/_src/pmap.py

Lines changed: 45 additions & 41 deletions
@@ -18,8 +18,7 @@
 
 from jax._src import api
 from jax._src.api_util import (
-    argnums_partial, donation_vector, flatten_fun,
-    flat_out_axes, fun_signature, fun_sourceinfo)
+    argnums_partial, donation_vector, fun_signature, fun_sourceinfo)
 from jax._src import array
 from jax._src import config
 from jax._src import core
@@ -37,7 +36,8 @@
 from jax._src.mesh import Mesh
 from jax._src.shard_map import _axes_to_pspec, _shard_map
 from jax._src.tree_util import (
-    broadcast_flattened_prefix_with_treedef, prefix_errors, tree_flatten, tree_map, tree_unflatten)
+    broadcast_flattened_prefix_with_treedef, broadcast_prefix,
+    prefix_errors, tree_flatten, tree_map, tree_unflatten)
 import numpy as np
 
 map, unsafe_map = util.safe_map, map
@@ -61,11 +61,13 @@ def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
   if isinstance(axis_name, core._TempAxisName):  # pylint: disable=protected-access
     axis_name = repr(axis_name)
   wrapped_fun = _pmap_wrap_init(f, static_broadcasted_tuple)
+  out_axes_flat, out_axes_tree = tree_flatten(out_axes)
+  out_axes_flat = tuple(out_axes_flat)
 
   def infer_params(*args, **kwargs):
     process_count = xb.process_count(backend)
     trace_state_clean = core.trace_state_clean()
-    fun, dyn_argnums, dyn_args = _get_dyn_args(
+    dyn_f, dyn_argnums, dyn_args = _get_dyn_args(
         wrapped_fun, static_broadcasted_tuple, args)
     dyn_args_flat, dyn_args_tree = tree_flatten((dyn_args, kwargs))
     in_axes_flat = _get_in_axes_flat(
@@ -76,11 +78,9 @@ def infer_params(*args, **kwargs):
         donate_tuple, dyn_args_tree, len(dyn_args_flat))
     mesh_devices = _get_mesh_devices(
        devices, backend, local_axis_size, axis_size, trace_state_clean)
-    fun, out_axes_thunk = flat_out_axes(fun, out_axes)
-    flat_fun, out_tree = flatten_fun(fun, dyn_args_tree)
-    mesh = Mesh(mesh_devices, (axis_name,))
-    _pmapped, in_specs, out_specs = _cached_shard_map(
-        flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name)
+    _pmapped, in_specs, out_specs, mesh = _cached_shard_map(
+        dyn_f, dyn_args_tree, in_axes_flat, out_axes_flat, out_axes_tree,
+        mesh_devices, axis_name)
     jitted_f = api.jit(
         _pmapped,
         donate_argnums=[i for i, val in enumerate(donated_invars) if val])
@@ -94,61 +94,65 @@
          dyn_args_flat, mesh, list(in_specs))
     else:
       flat_global_args = dyn_args_flat
-    return (jitted_f, flat_global_args, dyn_args_tree, out_tree, mesh, out_specs,
+    return (jitted_f, flat_global_args, dyn_args_tree, mesh, out_specs,
            donate_tuple, process_count, trace_state_clean)
 
   @util.wraps(f)
   def wrapped(*args, **kwargs):
-    (jitted_f, flat_global_args, _, out_tree, mesh, out_specs,
+    (jitted_f, flat_global_args, _, mesh, out_specs,
      _, process_count, trace_state_clean) = infer_params(*args, **kwargs)
     outs = jitted_f(*flat_global_args)
     if process_count > 1:
       if trace_state_clean:
-        outs = [
-            global_array_to_host_local_array(out, global_mesh=mesh, pspec=spec)
-            for out, spec in zip(outs, out_specs())
-        ]
+        outs = tree_map(
+            lambda out, spec: global_array_to_host_local_array(
+                out, global_mesh=mesh, pspec=spec),
+            outs, out_specs, is_leaf=lambda x: x is None)
       else:
-        outs = mhu.global_array_to_host_local_array(outs, mesh, out_specs())
-    return tree_unflatten(out_tree(), outs)
+        outs = mhu.global_array_to_host_local_array(outs, mesh, out_specs)
+    return outs
 
   def lower(*args, **kwargs):
-    jitted_f, flat_global_args, in_tree, out_tree, _, _, donate_tuple, _, _ = infer_params(
+    jitted_f, flat_global_args, in_tree, _, _, donate_tuple, _, _ = infer_params(
        *args, **kwargs)
    abstract_args = list(map(core.shaped_abstractify, flat_global_args))
    args_info = stages.make_args_info(in_tree, abstract_args, donate_tuple)
    lowered = jitted_f.trace(*flat_global_args).lower()
-    lowered = stages.Lowered(lowered._lowering, args_info, out_tree(),
+    lowered = stages.Lowered(lowered._lowering, args_info, lowered.out_tree,
                              no_kwargs=lowered._no_kwargs)
    return lowered
   wrapped.lower = lower
   return wrapped
 
 
 @lu.cache
-def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name):
-  f_transformed = flat_fun.f_transformed
-  def reset_stores_f_transformed(*args, **kwargs):
-    for store in flat_fun.stores:
-      if store is not None:
-        store.reset()
-    return f_transformed(*args, **kwargs)
-  flat_fun.f_transformed = reset_stores_f_transformed
+def _cached_shard_map(fun, in_tree, in_axes_flat,
+                      out_axes_flat, out_axes_tree, mesh_devices, axis_name):
+  mesh = Mesh(mesh_devices, (axis_name,))
+  out_axes = tree_unflatten(out_axes_tree, list(out_axes_flat))
   in_specs = tuple(map(partial(_axes_to_pspec, axis_name), in_axes_flat))
-  out_specs = lambda: map(partial(_axes_to_pspec, axis_name), out_axes_thunk())
-  fun = _handle_reshapes(flat_fun, in_axes_flat, out_axes_thunk)
-  return (_shard_map(fun.call_wrapped, mesh=mesh, in_specs=in_specs,
-                     out_specs=out_specs, check_vma=False,
-                     axis_names=set(mesh.axis_names)),
-          in_specs, out_specs)
-
-@lu.transformation2
-def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs):
-  args = tree_map(lambda x, ax: x if ax is None else lax.squeeze(x, [ax]),
-                  list(args), list(in_axes))
-  out = f(*args)
-  return tree_map(lambda x, ax: x if ax is None else lax.expand_dims(x, [ax]),
-                  list(out), list(out_axes_thunk()))
+  out_specs = tree_map(
+      partial(_axes_to_pspec, axis_name), out_axes, is_leaf=lambda x: x is None
+  )
+  def _fun(*flat_args):
+    args = tree_map(
+        lambda x, ax: x if ax is None else lax.squeeze(x, [ax]),
+        flat_args,
+        in_axes_flat,
+    )
+    args, kwargs = tree_unflatten(in_tree, args)
+    out = fun.call_wrapped(*args, **kwargs)
+    out_flat, out_tree = tree_flatten(out)
+    out_axes_flat = broadcast_prefix(out_axes, out, is_leaf=lambda x: x is None)
+    out_flat = tree_map(
+        lambda x, ax: x if ax is None else lax.expand_dims(x, [ax]),
+        out_flat,
+        out_axes_flat,
+    )
+    return tree_unflatten(out_tree, out_flat)
+  _pmapped = _shard_map(_fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs,
                        check_vma=False, axis_names=set(mesh.axis_names))
+  return (_pmapped, in_specs, out_specs, mesh)
 
 
 def _mapped_axis_size(args, in_axes):
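
The heart of the change is the new inner `_fun` in `_cached_shard_map`, which now does the flatten/unflatten and axis squeeze/expand work inside a plain closure instead of through `lu` transformations carrying auxiliary values. Below is a self-contained sketch of that axis handling using only `jax.tree_util`; the helper `wrap_for_shard_map` is hypothetical, and for simplicity it assumes a single integer `out_axis` rather than the pytree prefix the real code resolves with `broadcast_prefix`.

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_map, tree_unflatten

def wrap_for_shard_map(fun, in_tree, in_axes_flat, out_axis=0):
  """Wrap `fun` so it takes and returns flat per-shard arrays with a mapped axis."""
  def _fun(*flat_args):
    # Drop the mapped dimension from each input leaf (None means "not mapped").
    args = tree_map(
        lambda x, ax: x if ax is None else jnp.squeeze(x, ax),
        list(flat_args), list(in_axes_flat))
    # Rebuild the original (args, kwargs) pytree and call the user function.
    args, kwargs = tree_unflatten(in_tree, args)
    out = fun(*args, **kwargs)
    # Reinsert the mapped dimension on every output leaf.
    out_flat, out_tree = tree_flatten(out)
    out_flat = [jnp.expand_dims(x, out_axis) for x in out_flat]
    return tree_unflatten(out_tree, out_flat)
  return _fun

def f(x, y):
  return {"sum": x + y, "prod": x * y}

# Flatten ((x, y), {}) the same way infer_params flattens (dyn_args, kwargs).
args_flat, in_tree = tree_flatten(((jnp.ones((1, 3)), jnp.ones((1, 3))), {}))
wrapped = wrap_for_shard_map(f, in_tree, in_axes_flat=(0, 0))
print(tree_map(jnp.shape, wrapped(*args_flat)))  # {'prod': (1, 3), 'sum': (1, 3)}
```

Because the closure returns ordinary pytrees and keeps no auxiliary state, nothing has to be read back out of a store afterwards, which is the StoreException/StoreEmpty failure mode the commit message mentions for wrappings with auxiliary values.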
