Skip to content

Commit 0b3f785

Browse files
yashk2810 (Google-ML-Automation)
authored and committed
Fix bugs wrt partial-auto when there are multiple levels of nesting.
The changes are: * If the mesh passed to shard_map doesn't match the context mesh (if present), error out * Whenever we trace a jaxpr in shard_map: * the avals passed via `_shard_aval` should union the current manual axes on mesh with the newly manual axes specified on shard_map's `axis_name` argument * The mesh we enter into when in `use_abstract_mesh` should also union the current manual axes with the newly manual axes PiperOrigin-RevId: 750774376
1 parent 614e975 commit 0b3f785

File tree

5 files changed

+100
-55
lines changed

5 files changed

+100
-55
lines changed

jax/_src/debugging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def f():
141141
return jax.lax.cond(idx == 0,
142142
lambda: debug_callback_p.bind(*args, **params),
143143
lambda: [])
144-
return jax.shard_map(f, mesh=axis_context.mesh, in_specs=(), out_specs=[])()
144+
return jax.shard_map(f, in_specs=(), out_specs=[])()
145145

146146
def debug_callback_lowering(ctx, *args, effect, partitioned, callback, **params):
147147
axis_context = ctx.module_context.axis_context

jax/_src/lax/parallel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1925,8 +1925,8 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
19251925
def f():
19261926
return axis_index_p.bind(axis_name=axis_name)
19271927
return mlir.lower_fun(
1928-
lambda: [jax.shard_map(f, mesh=axis_context.mesh, check_vma=False,
1929-
in_specs=(), out_specs=P())()])(ctx)[0]
1928+
lambda: [jax.shard_map(f, check_vma=False, in_specs=(),
1929+
out_specs=P())()])(ctx)[0]
19301930

19311931
nreplicas = axis_env.nreps // math.prod(axis_env.sizes)
19321932
div = mlir.ir_constant(

jax/_src/shard_map.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ def shard_map(f=None, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(),
128128

129129
def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None,
130130
in_specs: Specs, out_specs: Specs | Callable[[], Specs],
131-
axis_names: Set[AxisName], check_vma: bool):
131+
axis_names: Set[AxisName], check_vma: bool,
132+
_skip_mesh_check: bool = False):
132133
if not callable(f):
133134
raise TypeError("shard_map requires a callable for its first argument, "
134135
f"but got {f} of type {type(f)}.")
@@ -140,6 +141,14 @@ def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None,
140141
"The context mesh cannot be empty. Either use"
141142
" `jax.sharding.use_mesh(mesh)` to enter into a mesh context or pass"
142143
" a mesh to `shard_map` via the `mesh` keyword argument.")
144+
else:
145+
ctx_mesh = get_abstract_mesh()
146+
if (not _skip_mesh_check and not ctx_mesh.empty and
147+
mesh.abstract_mesh != ctx_mesh):
148+
raise ValueError(
149+
f"The context mesh {ctx_mesh} should match the mesh passed to"
150+
f" shard_map {mesh}")
151+
143152
if not isinstance(mesh, (Mesh, AbstractMesh)):
144153
raise TypeError("shard_map requires a `jax.sharding.Mesh` or a "
145154
"`jax.sharding.AbstractMesh` instance for its "
@@ -540,7 +549,7 @@ def _as_manual_mesh(mesh, manual_axes: frozenset):
540549
if cur_mesh._name_to_type[a] == AxisType.Auto:
541550
auto_axes.add(a)
542551
else:
543-
assert cur_mesh._name_to_type[a] == AxisType.Explicit
552+
assert cur_mesh._name_to_type[a] == AxisType.Explicit, cur_mesh._name_to_type[a]
544553
explicit_axes.add(a)
545554

546555
new_axis_types = []
@@ -558,7 +567,7 @@ def _as_manual_mesh(mesh, manual_axes: frozenset):
558567

559568
def _extend_axis_env(mesh, manual_axes):
560569
return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items()
561-
if k in manual_axes])
570+
if k in manual_axes])
562571

563572
def _shard_map_staging(
564573
trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun,
@@ -571,11 +580,11 @@ def _shard_map_staging(
571580
source_info = source_info_util.current()
572581
to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info)
573582
in_tracers = map(to_jaxpr_tracer, in_tracers)
583+
inner_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))
574584
in_avals = [t.aval for t in in_tracers]
575585
in_avals_ = map(partial(_shard_aval, mesh, manual_axes, check_vma), in_names,
576586
in_avals)
577-
manual_mesh = _as_manual_mesh(mesh, manual_axes)
578-
with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(manual_mesh),
587+
with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
579588
config._check_vma(check_vma)):
580589
jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
581590
_check_names(out_names_thunk(), out_avals_)
@@ -590,7 +599,7 @@ def _shard_map_staging(
590599
constvars = map(trace.getvar, map(to_jaxpr_tracer, consts))
591600
outvars = map(trace.makevar, out_tracers)
592601
in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore
593-
with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(manual_mesh),
602+
with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
594603
config._check_vma(check_vma)):
595604
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
596605
params = dict(mesh=mesh, in_names=in_names_staged,
@@ -629,10 +638,11 @@ def _shard_shaped_array(mesh: Mesh, manual_axes: frozenset, check_vma,
629638
assert isinstance(aval, core.ShapedArray)
630639
new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
631640
for i, sz in enumerate(aval.shape))
632-
manual_mesh = _as_manual_mesh(mesh, manual_axes)
641+
manual_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))
633642
new_sharding = NamedSharding(manual_mesh, aval.sharding.spec)
634643
vma = (frozenset({n for ns in names.values() for n in ns})
635644
if check_vma else frozenset())
645+
vma = vma | aval.vma
636646
return aval.update(shape=new_shape, sharding=new_sharding, vma=vma)
637647
core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array
638648

@@ -695,7 +705,7 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names,
695705

696706

697707
def _valid_repeats(mesh: Mesh, vma: Set[AxisName], names: AxisNames) -> bool:
698-
um = set(_unmentioned(mesh, names))
708+
um = set(_unmentioned(mesh, names)) - set(mesh.manual_axes)
699709
if any(u in vma for u in um):
700710
return False
701711
return True
@@ -808,8 +818,10 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names,
808818
if len(manual_axes) < len(mesh.axis_names) else set())
809819
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,
810820
unspecified_dims=unspecified)
811-
manual_proto = pxla.manual_proto(aval_in, manual_axes, mesh)
812-
return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified)
821+
manual_proto = pxla.manual_proto(
822+
aval_in, manual_axes | set(mesh.manual_axes), mesh)
823+
return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto,
824+
unspecified)
813825

814826
def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names,
815827
aval_in, aval_out, x):
@@ -824,8 +836,10 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names,
824836
if len(manual_axes) < len(mesh.axis_names) else set())
825837
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
826838
aval_in = core.physical_aval(aval_in)
827-
manual_proto = pxla.manual_proto(aval_in, manual_axes, mesh)
828-
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified)
839+
manual_proto = pxla.manual_proto(
840+
aval_in, manual_axes | set(mesh.manual_axes), mesh)
841+
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto,
842+
unspecified_dims=unspecified)
829843
shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
830844
return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto,
831845
unspecified)
@@ -894,9 +908,9 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
894908
def _run_shmap(f, mesh, manual_axes, args, vmas, check_vma, context_mesh):
895909
trace = ShardMapTrace(mesh, manual_axes, check_vma, context_mesh)
896910
in_tracers = map(partial(ShardMapTracer, trace), vmas, args)
897-
manual_mesh = _as_manual_mesh(mesh, manual_axes)
911+
inner_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))
898912
with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes),
899-
use_abstract_mesh(manual_mesh), config._check_vma(check_vma)):
913+
use_abstract_mesh(inner_mesh), config._check_vma(check_vma)):
900914
ans = f.call_wrapped(*in_tracers)
901915
outs, out_vma = unzip2(map(trace.to_val_vma_pair, ans))
902916
return outs, out_vma
@@ -1318,7 +1332,7 @@ def fwd_out_names_thunk():
13181332
args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None
13191333
for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)]
13201334
with (_extend_axis_env(mesh, manual_axes),
1321-
use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)),
1335+
use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))),
13221336
config._check_vma(check_vma)):
13231337
lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
13241338
out_names = out_names_thunk()
@@ -1483,7 +1497,7 @@ def _partial_eval_jaxpr_custom_rule(
14831497
jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh']
14841498
check_vma, manual_axes = eqn.params['check_vma'], eqn.params['manual_axes']
14851499
with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma),
1486-
use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))):
1500+
use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)))):
14871501
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
14881502
pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
14891503
num_out_primals = len(jaxpr_known.outvars) - num_res
@@ -1494,7 +1508,7 @@ def _partial_eval_jaxpr_custom_rule(
14941508
which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
14951509
mesh = eqn.params['mesh']
14961510
with (_extend_axis_env(mesh, manual_axes),
1497-
use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)),
1511+
use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))),
14981512
config._check_vma(check_vma)):
14991513
jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
15001514
jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)

jax/experimental/shard_map.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,6 @@ def shard_map(
7777
.. _SPMD multi-device parallelism with shard_map: https://docs.jax.dev/en/latest/notebooks/shard_map.html
7878
"""
7979
axis_names = frozenset(mesh.axis_names) - auto
80-
return jshmap.shard_map(
80+
return jshmap._shard_map(
8181
f, mesh=mesh, in_specs=in_specs, out_specs=out_specs,
82-
check_vma=check_rep, axis_names=axis_names)
82+
check_vma=check_rep, axis_names=axis_names, _skip_mesh_check=True)

tests/shard_map_test.py

Lines changed: 64 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -438,19 +438,17 @@ def test_replication_checker_jit(self):
438438
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
439439
x = np.arange(8 * 8.).reshape(8, 8)
440440

441-
def f(x):
442-
return 2 * x
443441
def g(x):
444-
return shard_map(f, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)
442+
return shard_map(lambda x: x * 2, mesh=mesh, in_specs=P('x', 'y'),
443+
out_specs=P(None, 'y'))(x)
445444

446445
with self.assertRaisesRegex(ValueError, 'statically inferred'):
447446
jax.jit(g)(x)
448447

449-
def f2(x):
450-
return jax.lax.psum(x, 'x')
451448
def g2(x):
452-
return shard_map(f2, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)
453-
_ = jax.jit(g2)(x) # doesn't crash
449+
return shard_map(lambda x: jax.lax.psum(x, 'x'), mesh=mesh,
450+
in_specs=P('x', 'y'), out_specs=P(None, 'y'))(x)
451+
jax.jit(g2)(x) # doesn't crash
454452

455453
def test_process_env_traces(self):
456454
mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
@@ -2242,16 +2240,54 @@ def g(x):
22422240
return x * x
22432241

22442242
def h(x):
2245-
return shard_map(g, mesh=mesh, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x)
2243+
return shard_map(g, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x)
22462244

22472245
@jax.jit
22482246
def f(x):
2249-
return shard_map(h, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None),
2247+
return shard_map(h, in_specs=P('i', None), out_specs=P('i', None),
22502248
check_vma=False, axis_names=frozenset({'i'}))(x)
22512249

22522250
v = jnp.arange(32.).reshape(4, 8)
22532251
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
2254-
self.assertAllClose(v*v, f(v), check_dtypes=False)
2252+
with jax.sharding.use_mesh(mesh):
2253+
self.assertAllClose(v*v, f(v), check_dtypes=False)
2254+
2255+
@parameterized.named_parameters(
2256+
('0', 'x', 'y', {'x'}, {'x', 'y'}),
2257+
('1', None, 'y', frozenset(), {'y'}),
2258+
('2', 'x', None, {'x'}, {'x'}),
2259+
('3', None, None, frozenset(), frozenset()),
2260+
)
2261+
def test_nested_partial_auto_1d(self, dim1, dim2, outer_vma, inner_vma):
2262+
mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
2263+
np_inp = np.arange(32.).reshape(4, 8)
2264+
arr = jax.device_put(np_inp, NamedSharding(mesh, P(dim1, dim2)))
2265+
2266+
def g(x):
2267+
self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y'))
2268+
self.assertEqual(get_abstract_mesh().auto_axes, ('z',))
2269+
self.assertEqual(x.aval.vma, inner_vma)
2270+
out = x * x
2271+
self.assertEqual(out.aval.vma, inner_vma)
2272+
return out
2273+
2274+
def h(x):
2275+
self.assertEqual(get_abstract_mesh().manual_axes, ('x',))
2276+
self.assertEqual(get_abstract_mesh().auto_axes, ('y', 'z'))
2277+
self.assertEqual(x.aval.vma, outer_vma)
2278+
out = shard_map(g, in_specs=P(None, dim2),
2279+
out_specs=P(None, dim2), axis_names={'y'})(x)
2280+
self.assertEqual(out.aval.vma, outer_vma)
2281+
return out
2282+
2283+
@jax.jit
2284+
def f(x):
2285+
return shard_map(h, in_specs=P(dim1, None),
2286+
out_specs=P(dim1, None), axis_names={'x'})(x)
2287+
2288+
with jax.sharding.use_mesh(mesh):
2289+
out = f(arr)
2290+
self.assertArraysEqual(out, np_inp * np_inp)
22552291

22562292
def test_grad_nested_partial_auto(self):
22572293
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
@@ -2262,22 +2298,19 @@ def g(x):
22622298

22632299
def h(x):
22642300
# auto: 'j', manual: 'i'
2265-
return shard_map(g, mesh=mesh,
2266-
in_specs=P(None, 'j'),
2267-
out_specs=P(None, 'j'))(x)
2301+
return shard_map(g, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x)
22682302

22692303
@jax.jit
22702304
def f(x):
22712305
# auto: 'i', 'j'
2272-
return shard_map(h, mesh=mesh,
2273-
in_specs=P('i', None),
2274-
out_specs=P('i', None),
2275-
check_vma=False,
2276-
axis_names=frozenset({'i'}))(x).sum()
2306+
return shard_map(h, in_specs=P('i', None), out_specs=P('i', None),
2307+
check_vma=False, axis_names=frozenset({'i'}))(x).sum()
22772308

22782309
v = jnp.arange(32.).reshape(4, 8)
22792310
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
2280-
self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False)
2311+
with jax.sharding.use_mesh(mesh):
2312+
out = jax.grad(f)(v)
2313+
self.assertAllClose(out, v * 2, check_dtypes=False)
22812314

22822315
def test_grad_nested_partial_auto_with_residuals(self):
22832316
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
@@ -2286,21 +2319,18 @@ def g(x):
22862319
return x * x * x
22872320

22882321
def h(x):
2289-
return shard_map(g, mesh=mesh,
2290-
in_specs=P(None, 'j'),
2291-
out_specs=P(None, 'j'))(x)
2322+
return shard_map(g, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x)
22922323

22932324
@jax.jit
22942325
def f(x):
2295-
return shard_map(h, mesh=mesh,
2296-
in_specs=P('i', None),
2297-
out_specs=P('i', None),
2298-
check_vma=False,
2299-
axis_names=frozenset({'i'}))(x).sum()
2326+
return shard_map(h, in_specs=P('i', None), out_specs=P('i', None),
2327+
check_vma=False, axis_names=frozenset({'i'}))(x).sum()
23002328

23012329
v = jnp.arange(32.).reshape(4, 8)
23022330
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
2303-
self.assertAllClose(v*v*3, jax.grad(f)(v), check_dtypes=False)
2331+
with jax.sharding.use_mesh(mesh):
2332+
out = jax.grad(f)(v)
2333+
self.assertAllClose(out, v * v * 3, check_dtypes=False)
23042334

23052335
def test_axis_size_1_partial_auto(self):
23062336
mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k'))
@@ -2367,10 +2397,11 @@ def test_partial_auto_axis_index(self):
23672397
@partial(jax.jit, out_shardings=out_sharding)
23682398
def f():
23692399
return shard_map(lambda: jax.lax.axis_index('i').reshape(1,1),
2370-
mesh=mesh, in_specs=P('i', None), out_specs=P('i', None),
2400+
in_specs=P('i', None), out_specs=P('i', None),
23712401
check_vma=False, axis_names=frozenset({'i'}))()
23722402

2373-
self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1))
2403+
with jax.sharding.use_mesh(mesh):
2404+
self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1))
23742405

23752406
def test_partial_auto_axis_index_degenerated_axis(self):
23762407
mesh = jtu.create_mesh((1, 2), ('i', 'j'))
@@ -2432,11 +2463,11 @@ def g(x):
24322463

24332464
@jax.jit
24342465
def f(x):
2435-
return shard_map(g,
2436-
mesh=mesh, in_specs=P('i'), out_specs=None,
2466+
return shard_map(g, mesh=mesh, in_specs=P('i'), out_specs=None,
24372467
check_vma=False, axis_names=frozenset({'i'}))(x)
24382468

2439-
y = f(x) # don't crash
2469+
with jax.sharding.use_mesh(mesh):
2470+
f(x) # don't crash
24402471

24412472
def test_partial_auto_of_random_keys(self):
24422473
mesh = jtu.create_mesh((4, 2), ('i', 'j'))

0 commit comments

Comments (0)