@@ -128,7 +128,8 @@ def shard_map(f=None, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(),
128128
129129def _shard_map (f : Callable , * , mesh : Mesh | AbstractMesh | None ,
130130 in_specs : Specs , out_specs : Specs | Callable [[], Specs ],
131- axis_names : Set [AxisName ], check_vma : bool ):
131+ axis_names : Set [AxisName ], check_vma : bool ,
132+ _skip_mesh_check : bool = False ):
132133 if not callable (f ):
133134 raise TypeError ("shard_map requires a callable for its first argument, "
134135 f"but got { f } of type { type (f )} ." )
@@ -140,6 +141,14 @@ def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None,
140141 "The context mesh cannot be empty. Either use"
141142 " `jax.sharding.use_mesh(mesh)` to enter into a mesh context or pass"
142143 " a mesh to `shard_map` via the `mesh` keyword argument." )
144+ else :
145+ ctx_mesh = get_abstract_mesh ()
146+ if (not _skip_mesh_check and not ctx_mesh .empty and
147+ mesh .abstract_mesh != ctx_mesh ):
148+ raise ValueError (
149+ f"The context mesh { ctx_mesh } should match the mesh passed to"
150+ f" shard_map { mesh } " )
151+
143152 if not isinstance (mesh , (Mesh , AbstractMesh )):
144153 raise TypeError ("shard_map requires a `jax.sharding.Mesh` or a "
145154 "`jax.sharding.AbstractMesh` instance for its "
@@ -540,7 +549,7 @@ def _as_manual_mesh(mesh, manual_axes: frozenset):
540549 if cur_mesh ._name_to_type [a ] == AxisType .Auto :
541550 auto_axes .add (a )
542551 else :
543- assert cur_mesh ._name_to_type [a ] == AxisType .Explicit
552+ assert cur_mesh ._name_to_type [a ] == AxisType .Explicit , cur_mesh . _name_to_type [ a ]
544553 explicit_axes .add (a )
545554
546555 new_axis_types = []
@@ -558,7 +567,7 @@ def _as_manual_mesh(mesh, manual_axes: frozenset):
558567
def _extend_axis_env(mesh, manual_axes):
  """Extend the axis environment with the manual axes of ``mesh``.

  Collects the ``(name, size)`` pairs from ``mesh.shape`` whose name is a
  member of ``manual_axes`` (in the order ``mesh.shape`` yields them) and
  forwards them to ``core.extend_axis_env_nd``. The result is returned
  unchanged; callers in this file use it as a context manager.
  """
  manual_sizes = []
  for axis_name, axis_size in mesh.shape.items():
    if axis_name in manual_axes:
      manual_sizes.append((axis_name, axis_size))
  return core.extend_axis_env_nd(manual_sizes)
562571
563572def _shard_map_staging (
564573 trace : pe .DynamicJaxprTrace , prim : core .Primitive , f : lu .WrappedFun ,
@@ -571,11 +580,11 @@ def _shard_map_staging(
571580 source_info = source_info_util .current ()
572581 to_jaxpr_tracer = partial (trace .to_jaxpr_tracer , source_info = source_info )
573582 in_tracers = map (to_jaxpr_tracer , in_tracers )
583+ inner_mesh = _as_manual_mesh (mesh , manual_axes | set (mesh .manual_axes ))
574584 in_avals = [t .aval for t in in_tracers ]
575585 in_avals_ = map (partial (_shard_aval , mesh , manual_axes , check_vma ), in_names ,
576586 in_avals )
577- manual_mesh = _as_manual_mesh (mesh , manual_axes )
578- with (_extend_axis_env (mesh , manual_axes ), use_abstract_mesh (manual_mesh ),
587+ with (_extend_axis_env (mesh , manual_axes ), use_abstract_mesh (inner_mesh ),
579588 config ._check_vma (check_vma )):
580589 jaxpr , out_avals_ , consts , () = pe .trace_to_jaxpr_dynamic (f , in_avals_ )
581590 _check_names (out_names_thunk (), out_avals_ )
@@ -590,7 +599,7 @@ def _shard_map_staging(
590599 constvars = map (trace .getvar , map (to_jaxpr_tracer , consts ))
591600 outvars = map (trace .makevar , out_tracers )
592601 in_names_staged = ({},) * len (consts ) + tuple (in_names ) # type: ignore
593- with (_extend_axis_env (mesh , manual_axes ), use_abstract_mesh (manual_mesh ),
602+ with (_extend_axis_env (mesh , manual_axes ), use_abstract_mesh (inner_mesh ),
594603 config ._check_vma (check_vma )):
595604 jaxpr = pe .convert_constvars_jaxpr (jaxpr )
596605 params = dict (mesh = mesh , in_names = in_names_staged ,
@@ -629,10 +638,11 @@ def _shard_shaped_array(mesh: Mesh, manual_axes: frozenset, check_vma,
629638 assert isinstance (aval , core .ShapedArray )
630639 new_shape = tuple (sz // prod (mesh .shape [n ] for n in names .get (i , ()))
631640 for i , sz in enumerate (aval .shape ))
632- manual_mesh = _as_manual_mesh (mesh , manual_axes )
641+ manual_mesh = _as_manual_mesh (mesh , manual_axes | set ( mesh . manual_axes ) )
633642 new_sharding = NamedSharding (manual_mesh , aval .sharding .spec )
634643 vma = (frozenset ({n for ns in names .values () for n in ns })
635644 if check_vma else frozenset ())
645+ vma = vma | aval .vma
636646 return aval .update (shape = new_shape , sharding = new_sharding , vma = vma )
637647core .shard_aval_handlers [core .ShapedArray ] = _shard_shaped_array
638648
@@ -695,7 +705,7 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names,
695705
696706
def _valid_repeats(mesh: Mesh, vma: Set[AxisName], names: AxisNames) -> bool:
  """Report whether no relevant unmentioned axis appears in ``vma``.

  The axes considered are those returned by ``_unmentioned(mesh, names)``
  with the entries of ``mesh.manual_axes`` removed. Returns ``False`` as soon
  as one of them is found in ``vma`` and ``True`` otherwise.
  """
  candidates = set(_unmentioned(mesh, names)) - set(mesh.manual_axes)
  return not any(axis in vma for axis in candidates)
@@ -808,8 +818,10 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names,
808818 if len (manual_axes ) < len (mesh .axis_names ) else set ())
809819 sx = mlir .wrap_with_sharding_op (ctx , x , aval_in , shard_proto ,
810820 unspecified_dims = unspecified )
811- manual_proto = pxla .manual_proto (aval_in , manual_axes , mesh )
812- return mlir .wrap_with_full_to_shard_op (ctx , sx , aval_out , manual_proto , unspecified )
821+ manual_proto = pxla .manual_proto (
822+ aval_in , manual_axes | set (mesh .manual_axes ), mesh )
823+ return mlir .wrap_with_full_to_shard_op (ctx , sx , aval_out , manual_proto ,
824+ unspecified )
813825
814826def _xla_unshard (ctx : mlir .LoweringRuleContext , mesh , manual_axes , names ,
815827 aval_in , aval_out , x ):
@@ -824,8 +836,10 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names,
824836 if len (manual_axes ) < len (mesh .axis_names ) else set ())
825837 if dtypes .issubdtype (aval_in .dtype , dtypes .extended ):
826838 aval_in = core .physical_aval (aval_in )
827- manual_proto = pxla .manual_proto (aval_in , manual_axes , mesh )
828- sx = mlir .wrap_with_sharding_op (ctx , x , aval_in , manual_proto , unspecified_dims = unspecified )
839+ manual_proto = pxla .manual_proto (
840+ aval_in , manual_axes | set (mesh .manual_axes ), mesh )
841+ sx = mlir .wrap_with_sharding_op (ctx , x , aval_in , manual_proto ,
842+ unspecified_dims = unspecified )
829843 shard_proto = ns ._to_xla_hlo_sharding (aval_out .ndim ).to_proto ()
830844 return mlir .wrap_with_shard_to_full_op (ctx , sx , aval_out , shard_proto ,
831845 unspecified )
@@ -894,9 +908,9 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
def _run_shmap(f, mesh, manual_axes, args, vmas, check_vma, context_mesh):
  """Call ``f`` on ``ShardMapTracer``-wrapped ``args`` and unpack the results.

  A ``ShardMapTrace`` is built from ``mesh``, ``manual_axes``, ``check_vma``
  and ``context_mesh``; each ``(vma, arg)`` pair is wrapped in a
  ``ShardMapTracer`` bound to that trace. ``f.call_wrapped`` then runs with
  the trace set as current, the axis env extended for ``manual_axes``, the
  abstract mesh switched to the manual mesh, and ``config._check_vma`` set.

  Returns:
    The ``(outs, out_vma)`` pair obtained by unzipping
    ``trace.to_val_vma_pair`` applied to every output of ``f``.
  """
  shmap_trace = ShardMapTrace(mesh, manual_axes, check_vma, context_mesh)
  make_tracer = partial(ShardMapTracer, shmap_trace)
  # NOTE(review): `map` is kept as-is; presumably it is the module's own
  # list-returning variant rather than the builtin -- confirm before changing.
  wrapped_args = map(make_tracer, vmas, args)
  # Axes that are already manual on the incoming mesh stay manual inside.
  all_manual_axes = manual_axes | set(mesh.manual_axes)
  inner_mesh = _as_manual_mesh(mesh, all_manual_axes)
  with (
      core.set_current_trace(shmap_trace),
      _extend_axis_env(mesh, manual_axes),
      use_abstract_mesh(inner_mesh),
      config._check_vma(check_vma),
  ):
    results = f.call_wrapped(*wrapped_args)
    val_vma_pairs = map(shmap_trace.to_val_vma_pair, results)
    outs, out_vma = unzip2(val_vma_pairs)
  return outs, out_vma
@@ -1318,7 +1332,7 @@ def fwd_out_names_thunk():
13181332 args_to_promote = [getattr (aval , 'shape' , ()) == () and f1 is None and f2 is None
13191333 for aval , f1 , f2 in zip (res_avals , in_fwd , out_fwd )]
13201334 with (_extend_axis_env (mesh , manual_axes ),
1321- use_abstract_mesh (_as_manual_mesh (mesh , manual_axes )),
1335+ use_abstract_mesh (_as_manual_mesh (mesh , manual_axes | set ( mesh . manual_axes ) )),
13221336 config ._check_vma (check_vma )):
13231337 lin_jaxpr = _promote_scalar_residuals_jaxpr (lin_jaxpr , args_to_promote )
13241338 out_names = out_names_thunk ()
@@ -1483,7 +1497,7 @@ def _partial_eval_jaxpr_custom_rule(
14831497 jaxpr , mesh = eqn .params ['jaxpr' ], eqn .params ['mesh' ]
14841498 check_vma , manual_axes = eqn .params ['check_vma' ], eqn .params ['manual_axes' ]
14851499 with (_extend_axis_env (mesh , manual_axes ), config ._check_vma (check_vma ),
1486- use_abstract_mesh (_as_manual_mesh (mesh , manual_axes ))):
1500+ use_abstract_mesh (_as_manual_mesh (mesh , manual_axes | set ( mesh . manual_axes ) ))):
14871501 jaxpr_known , jaxpr_staged , unks_out , inst_out , num_res = \
14881502 pe .partial_eval_jaxpr_custom (jaxpr , unks_in , inst_in , False , False , saveable )
14891503 num_out_primals = len (jaxpr_known .outvars ) - num_res
@@ -1494,7 +1508,7 @@ def _partial_eval_jaxpr_custom_rule(
14941508 which = [f1 is None and f2 is None for f1 , f2 in zip (in_fwd , out_fwd )]
14951509 mesh = eqn .params ['mesh' ]
14961510 with (_extend_axis_env (mesh , manual_axes ),
1497- use_abstract_mesh (_as_manual_mesh (mesh , manual_axes )),
1511+ use_abstract_mesh (_as_manual_mesh (mesh , manual_axes | set ( mesh . manual_axes ) )),
14981512 config ._check_vma (check_vma )):
14991513 jaxpr_known = pe .prune_jaxpr_outputs (jaxpr_known , [True ] * num_out_primals + which )
15001514 jaxpr_known , jaxpr_staged = _add_reshapes (which , jaxpr_known , jaxpr_staged )
0 commit comments