Skip to content

Commit d352f4f

Browse files
dougalmGoogle-ML-Automation
authored andcommitted
Put the set of current spmd axis names in the axis env instead of spelunking
through the trace stack to find it. PiperOrigin-RevId: 694710181
1 parent 85dae9e commit d352f4f

File tree

3 files changed

+28
-14
lines changed

3 files changed

+28
-14
lines changed

jax/_src/core.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,7 @@ def __eq__(self, other):
955955
@dataclass(frozen=True)
956956
class AxisEnv:
957957
axis_sizes : dict[AxisName, int]
958+
spmd_axis_names : set[AxisName]
958959

959960
def axis_size(self, axis_name):
960961
if axis_name not in self.axis_sizes:
@@ -971,20 +972,24 @@ def axis_names(self):
971972
def pop_pure(self, axis_name):
972973
new_sizes = self.axis_sizes.copy()
973974
new_sizes.pop(axis_name)
974-
return AxisEnv(new_sizes)
975+
return AxisEnv(new_sizes, self.spmd_axis_names)
975976

976977
def extend_pure(self, name_size_pairs):
977978
new_sizes = self.axis_sizes.copy()
978979
new_sizes.update((name, size) for name, size in name_size_pairs
979980
if name is not no_axis_name)
980-
return AxisEnv(new_sizes)
981+
return AxisEnv(new_sizes, self.spmd_axis_names)
982+
983+
def add_spmd_axis_names(self, axis_names):
984+
new_spmd_axis_names = self.spmd_axis_names | set(axis_names)
985+
return AxisEnv(self.axis_sizes, new_spmd_axis_names)
981986

982987
def as_hashable_key(self):
983988
return tuple((name, size) for (name, size) in self.axis_sizes.items()
984989
if name is not no_axis_name)
985990

986991
eval_trace = EvalTrace()
987-
top_axis_env = AxisEnv({})
992+
top_axis_env = AxisEnv({}, set())
988993

989994
class TracingContext(threading.local):
990995
trace: Trace | None
@@ -1045,6 +1050,16 @@ def extend_axis_env_nd(name_size_pairs : Iterable[tuple[AxisName, int]]):
10451050
finally:
10461051
trace_ctx.set_axis_env(prev)
10471052

1053+
@contextmanager
1054+
def add_spmd_axis_names(axis_names: AxisName | None):
1055+
prev = trace_ctx.axis_env
1056+
try:
1057+
if axis_names is not None:
1058+
trace_ctx.set_axis_env(prev.add_spmd_axis_names(axis_names))
1059+
yield
1060+
finally:
1061+
trace_ctx.set_axis_env(prev)
1062+
10481063
def get_axis_env():
10491064
return trace_ctx.axis_env
10501065

jax/_src/interpreters/batching.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -596,9 +596,10 @@ def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals):
596596
idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0,
597597
source_info_util.current()))
598598
in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
599-
with core.set_current_trace(trace):
600-
with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]):
601-
outs = yield in_tracers, {}
599+
with (core.set_current_trace(trace),
600+
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
601+
core.add_spmd_axis_names(axis_data.spmd_name)):
602+
outs = yield in_tracers, {}
602603

603604
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
604605
out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)),
@@ -795,9 +796,10 @@ def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals):
795796
_, in_axes = resolve_ragged_axes(in_vals, in_axes)
796797
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
797798
for val, dim in zip(in_vals, in_axes)]
798-
with core.set_current_trace(trace):
799-
with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]):
800-
outs = yield in_tracers, {}
799+
with (core.set_current_trace(trace),
800+
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
801+
core.add_spmd_axis_names(axis_data.spmd_name)):
802+
outs = yield in_tracers, {}
801803
out_vals, out_axes = unzip2(map(trace.to_batch_info, outs))
802804
new_out_axes = indirectify_ragged_axes_against_inputs_outputs(
803805
out_axes, in_vals, out_vals)

jax/experimental/shard_map.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,7 +1506,7 @@ def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]:
15061506
# We use a filtered-down version of unmentioned to avoid defensive-psum over
15071507
# more chips than required in the transpose-no-check-rep case.
15081508
name_set = {n for ns in names.values() for n in ns}
1509-
return [n for n in mesh.axis_names if n not in name_set]
1509+
return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set]
15101510

15111511

15121512
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
@@ -1652,10 +1652,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
16521652

16531653
# TODO(mattjj): remove this mechanism when we revise mesh scopes
16541654
def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]:
1655-
trace = core.unsafe_get_current_trace() if trace is None else trace
1656-
stack = core.unsafe_get_trace_stack(trace)
1657-
batch_traces = [t for t in stack if isinstance(t, batching.BatchTrace)]
1658-
spmd_names = {n for trace in batch_traces for n in trace.axis_data.spmd_name }
1655+
spmd_names = core.get_axis_env().spmd_axis_names
16591656
return tuple(name for name in mesh.axis_names if name not in spmd_names)
16601657

16611658
# DCE

0 commit comments

Comments
 (0)