Skip to content

Commit 2c9b917

Browse files
pschuh authored and Google-ML-Automation committed
Don't psum over auto mesh dims in _unmentioned2.
PiperOrigin-RevId: 698440525
1 parent eab9026 commit 2c9b917

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

jax/experimental/shard_map.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1547,10 +1547,11 @@ def fun(*res_and_args):
15471547
return jaxpr
15481548

15491549

1550-
def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]:
1550+
def _unmentioned2(mesh: Mesh, names: AxisNames,
1551+
auto: frozenset[AxisName]) -> list[AxisName]:
15511552
# We use a filtered-down version of unmentioned to avoid defensive-psum over
15521553
# more chips than required in the transpose-no-check-rep case.
1553-
name_set = {n for ns in names.values() for n in ns}
1554+
name_set = {n for ns in names.values() for n in ns} | auto
15541555
return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set]
15551556

15561557

@@ -1559,7 +1560,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
15591560
mb_div = lambda x, y: x / y if y != 1 else x
15601561
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
15611562
else x if rewrite or dtypes.dtype(x) == dtypes.float0
1562-
else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns))))
1563+
else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto))))
15631564
for ns, x in zip(out_names, out_cts)]
15641565
args = [x if type(x) is not ad.UndefinedPrimal else
15651566
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
@@ -1577,7 +1578,7 @@ def fun_trans(out_cts, args):
15771578
)
15781579
out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
15791580
else x if rewrite
1580-
else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns)))
1581+
else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto)))
15811582
for ns, x in zip(in_names, out)]
15821583
return out
15831584

tests/shard_map_test.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2046,6 +2046,29 @@ def f(x):
20462046
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
20472047
self.assertAllClose(v*v, f(v), check_dtypes=False)
20482048

2049+
def test_grad_nested_partial_auto(self):
2050+
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
2051+
2052+
def g(x):
2053+
return x * x
2054+
2055+
def h(x):
2056+
return shard_map(g, mesh,
2057+
in_specs=P(None, 'j'),
2058+
out_specs=P(None, 'j'))(x)
2059+
2060+
@jax.jit
2061+
def f(x):
2062+
return shard_map(h, mesh,
2063+
in_specs=P('i', None),
2064+
out_specs=P('i', None),
2065+
check_rep=False,
2066+
auto=frozenset({'j'}))(x).sum()
2067+
2068+
v = jnp.arange(32.).reshape(4, 8)
2069+
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
2070+
self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False)
2071+
20492072
def test_axis_size_1_partial_auto(self):
20502073
mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k'))
20512074

0 commit comments

Comments
 (0)