1414# limitations under the License.
1515
1616from collections import OrderedDict , defaultdict
17- from typing import Any , cast , overload
17+ from typing import Any , cast
1818
1919import jax
20- import jax ._src .core as jcore
2120import numpy as np
2221
22+ from jaxpp .jax_compat import core as jcore
2323from jaxpp .mesh import MpmdMesh
2424from jaxpp .types import MpmdSharding
25- from jaxpp .utils import get_named_sharding , update_named_sharding
25+ from jaxpp .utils import filter_axes , get_named_sharding , update_named_sharding
2626
2727
2828class MpmdArray :
@@ -83,8 +83,8 @@ def __init__(
8383 mesh = get_named_sharding (arr ).mesh
8484 if (mpmd_idx := mpmd_mesh .mpmd_idx_for_mesh .get (mesh )) is None :
8585 raise ValueError (
86- f"Argument array { idx } { arr .shape } is not on a mesh that is part "
87- f"mpmd_mesh={ mpmd_mesh .jax_mesh } "
86+ f"Argument array { idx } { arr .shape } is not on a mesh that is part"
87+ f" of mpmd_mesh={ mpmd_mesh .jax_mesh } "
8888 )
8989
9090 if mpmd_idx not in mpmd_idxs :
@@ -118,8 +118,9 @@ def __init__(
118118 assert all (_ == shape for _ in shapes ), (shape , shapes )
119119 dtypes = [a .dtype for a in self ._partially_addressable_arrays .values ()]
120120 assert all (_ == dtype for _ in dtypes ), (dtype , dtypes )
121+ mpmd_axis = mpmd_sharding .mpmd_mesh .mpmd_axis_name
121122 specs = [
122- get_named_sharding (a ).spec
123+ filter_axes ( get_named_sharding (a ).spec , { mpmd_axis })
123124 for a in self ._partially_addressable_arrays .values ()
124125 ]
125126 assert all (_ == mpmd_sharding .spec for _ in specs ), (
@@ -281,56 +282,6 @@ def _to_global_jax_array(mpmd_array: MpmdArray) -> jax.Array | None:
281282 )
282283
283284
284- @overload
285- def filter_axes (
286- sharding_or_pspec : jax .sharding .NamedSharding , axes : set [str ]
287- ) -> jax .sharding .NamedSharding : ...
288-
289-
290- @overload
291- def filter_axes (
292- sharding_or_pspec : jax .sharding .PartitionSpec , axes : set [str ]
293- ) -> jax .sharding .PartitionSpec : ...
294-
295-
296- def filter_axes (
297- sharding_or_pspec : jax .sharding .NamedSharding | jax .sharding .PartitionSpec ,
298- axes : set [str ],
299- ) -> jax .sharding .NamedSharding | jax .sharding .PartitionSpec :
300- """Filter out specified axes from a sharding or partition spec.
301-
302- Args:
303- sharding_or_pspec: Either a NamedSharding or PartitionSpec to filter.
304- axes: Set of axis names to remove from the spec.
305-
306- Returns:
307- Same type as input with specified axes filtered out.
308- """
309- if isinstance (sharding_or_pspec , jax .sharding .NamedSharding ):
310- return jax .sharding .NamedSharding (
311- sharding_or_pspec .mesh , filter_axes (sharding_or_pspec .spec , axes )
312- )
313-
314- assert isinstance (sharding_or_pspec , jax .sharding .PartitionSpec )
315-
316- new_spec = []
317- for axis in sharding_or_pspec :
318- if axis is None :
319- new_spec .append (None )
320- elif isinstance (axis , str ):
321- if axis not in axes :
322- new_spec .append (axis )
323- else :
324- new_spec .append (None )
325- elif isinstance (axis , (list , tuple )):
326- new_axis = [a for a in axis if a not in axes ]
327- new_spec .append (type (axis )(new_axis ))
328- else :
329- raise ValueError (f"Unsupported_axis_type: { type (axis )} " )
330-
331- return jax .sharding .PartitionSpec (* new_spec )
332-
333-
334285def _id (* xs ):
335286 return xs
336287
@@ -346,8 +297,9 @@ def _spmd_to_mpmd_reshard(
346297
347298 for spmd_value , dist_sharding in zip (spmd_values , dist_shardings ):
348299 assert isinstance (
349- spmd_value .sharding , jax .sharding .NamedSharding
350- ), spmd_value .sharding
300+ spmd_value .sharding ,
301+ (jax .sharding .NamedSharding , jax .sharding .SingleDeviceSharding ),
302+ ), f"Unsupported sharding type: { spmd_value .sharding } "
351303
352304 # NOTE: We filter out the mpmd axis from the sharding so that
353305 # the output is replicated across all mpmd ranks.
@@ -362,11 +314,7 @@ def _spmd_to_mpmd_reshard(
362314 for dist_sharding in dist_shardings
363315 )
364316
365- res : list [jax .Array ] = jax .jit (
366- _id ,
367- in_shardings = tuple (_ .sharding for _ in spmd_values ),
368- out_shardings = _actual_shardings ,
369- )(* spmd_values )
317+ res : list [jax .Array ] = jax .jit (_id , out_shardings = _actual_shardings )(* spmd_values )
370318
371319 for spmd_value , donated in zip (spmd_values , donate , strict = True ):
372320 if donated :
@@ -398,17 +346,19 @@ def _spmd_to_mpmd_reshard(
398346 _res = []
399347 for arr , dsh in zip (res , dist_shardings , strict = True ):
400348 mesh_ids = dsh .mesh_ids
401- filtered_sharding = MpmdSharding (
349+ # MpmdSharding.__post_init__ canonicalizes the spec by filtering out
350+ # the mpmd axis, so we can just use dsh.spec directly
351+ mpmd_sharding = MpmdSharding (
402352 mpmd_mesh = dsh .mpmd_mesh ,
403353 mesh_ids = dsh .mesh_ids ,
404- spec = filter_axes ( dsh .sharding . spec , { mpmd_mesh . mpmd_axis_name }) ,
354+ spec = dsh .spec ,
405355 )
406356
407357 if mpmd_mesh .my_mpmd_axis_index not in mesh_ids :
408358 _res .append (
409359 MpmdArray (
410360 partially_addressable_arrays = [],
411- mpmd_sharding = filtered_sharding ,
361+ mpmd_sharding = mpmd_sharding ,
412362 shape = arr .shape ,
413363 dtype = arr .dtype ,
414364 )
@@ -425,7 +375,7 @@ def _spmd_to_mpmd_reshard(
425375 _res .append (
426376 MpmdArray (
427377 partially_addressable_arrays = [new_arr ],
428- mpmd_sharding = filtered_sharding ,
378+ mpmd_sharding = mpmd_sharding ,
429379 )
430380 )
431381 return _res
@@ -481,6 +431,9 @@ def spmd_to_mpmd_reshard(
481431 The specs of the returned arrays will _not_ have `mpmd_mesh.mpmd_axis_name` in
482432 them.
483433
434+ Limitations: same constraints as jax.jit apply (e.g. _device_assignment must be the
435+ same for all arrays)
436+
484437 Args:
485438 mpmd_mesh: The MPMD mesh definition.
486439 spmd_arrays: A pytree of source SPMD arrays.
@@ -506,28 +459,6 @@ def spmd_to_mpmd_reshard(
506459
507460 assert spmd_tree_def == mpmd_tree_def
508461
509- # Verify all arrays are on the same mesh
510- first_path , first_leaf = spmd_arrays_with_path [0 ]
511- first_sharding = first_leaf .sharding
512- assert isinstance (first_sharding , jax .sharding .NamedSharding ), first_sharding
513- mesh = first_sharding .mesh
514-
515- # This check is the same as the one performed by jax.jit
516- assert mesh ._flat_devices_tuple == mpmd_mesh .jax_mesh ._flat_devices_tuple , (
517- mesh ,
518- mpmd_mesh .jax_mesh ,
519- )
520- for path , leaf in spmd_arrays_with_path :
521- assert isinstance (leaf .sharding , jax .sharding .NamedSharding ), (
522- path ,
523- leaf .sharding ,
524- )
525- assert leaf .sharding .mesh ._flat_devices_tuple == mesh ._flat_devices_tuple , (
526- path ,
527- mesh ,
528- leaf .sharding .mesh ,
529- )
530-
531462 _ , spmd_arrays_flat = jax ._src .util .unzip2 (spmd_arrays_with_path )
532463 spmd_arrays_flat_list = list (spmd_arrays_flat )
533464
@@ -605,7 +536,11 @@ def _axis_name_in_spec(axis_name: str, spec) -> bool:
605536
606537
607538def logically_stacked (
608- array : jax .Array , comm_mesh : jax .sharding .Mesh , axis_name : str , strict : bool = False
539+ array : jax .Array ,
540+ comm_mesh : jax .sharding .Mesh ,
541+ mesh_axis_name : str ,
542+ array_axis : int = 0 ,
543+ strict : bool = False ,
609544):
610545 """
611546 Logically stacks an array along a new axis corresponding to the MPMD dimension.
@@ -619,18 +554,25 @@ def logically_stacked(
619554 if strict :
620555 spec = array .sharding .spec
621556 assert not _axis_name_in_spec (
622- axis_name , spec
623- ), f"axis_name { axis_name !r} already exists in spec { spec } "
557+ mesh_axis_name , spec
558+ ), f"axis_name { mesh_axis_name !r} already exists in spec { spec } "
624559 else :
625- spec = filter_axes (array .sharding .spec , {axis_name })
560+ spec = filter_axes (array .sharding .spec , {mesh_axis_name })
626561
627- expanded_array = jax .numpy .expand_dims (array , 0 )
562+ expanded_array = jax .numpy .expand_dims (array , array_axis )
628563 in_sharding = jax .sharding .NamedSharding (
629- comm_mesh , jax .sharding .PartitionSpec (axis_name , * spec )
564+ comm_mesh ,
565+ jax .sharding .PartitionSpec (
566+ * spec [:array_axis ], mesh_axis_name , * spec [array_axis :]
567+ ),
630568 )
631569
632570 global_array = jax .make_array_from_single_device_arrays (
633- (comm_mesh .shape [axis_name ], * array .shape ),
571+ (
572+ * array .shape [:array_axis ],
573+ comm_mesh .shape [mesh_axis_name ],
574+ * array .shape [array_axis :],
575+ ),
634576 in_sharding ,
635577 [s .data for s in expanded_array .addressable_shards ],
636578 )
@@ -664,6 +606,12 @@ def mpmd_to_spmd_reshard(
664606 Returns:
665607 A pytree of JAX arrays with the same structure as mpmd_arrays.
666608 """
609+
610+ if not mpmd_mesh .jax_mesh .is_multi_process :
611+ return jax .device_put (
612+ jax .tree .map (lambda _ : _ .first_mpmd_replica , mpmd_arrays ), spmd_shardings
613+ )
614+
667615 mpmd_arrays_with_path , mpmd_tree_def = jax .tree .flatten_with_path (mpmd_arrays )
668616 mpmd_arrays_with_path : list [tuple [Any , MpmdArray ]]
669617 spmd_shardings_flat , spmd_tree_def = jax .tree .flatten (spmd_shardings )
0 commit comments