Skip to content

Commit aff6dd4

Browse files
seanprime7 and Angelogeb committed
Update 26.01.18
- 9e20ec9dc2b556db2faba1990bf6579da80bf6ea by Anxhelo Xhebraj <axhebraj@nvidia.com> - de11bd8aa863acbab840e33368ffdb199893e2e4 by Anxhelo Xhebraj <axhebraj@nvidia.com> - 2d398e9802ab2ad6016e385c9a4f18c86f6381d0 by Sean Lee <selee@nvidia.com> - 960973450466b0a29b8c40e56799bc8f34e76025 by Anxhelo Xhebraj <axhebraj@nvidia.com> - 1d19d8f7a05de81c2efd3214d4f050421dbaec2c by Anxhelo Xhebraj <axhebraj@nvidia.com> - aaec562a093047ad2847b576ed1113587ae77f17 by Anxhelo Xhebraj <axhebraj@nvidia.com> - 7164423d1ea990ea482612631042f88569ca857e by Anxhelo Xhebraj <axhebraj@nvidia.com> - 00f246f52cbfb8d3086f995b8d3dfb7b4c1d861e by Anxhelo Xhebraj <axhebraj@nvidia.com> - c360203d29269207ced92451746b5f3fe08d5250 by Anxhelo Xhebraj <axhebraj@nvidia.com> Co-authored-by: Anxhelo Xhebraj <axhebraj@nvidia.com> Signed-off-by: Sean Lee <selee@nvidia.com> GitOrigin-RevId: 9e20ec9dc2b556db2faba1990bf6579da80bf6ea
1 parent 298036e commit aff6dd4

File tree

17 files changed

+785
-370
lines changed

17 files changed

+785
-370
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ classifiers = [
9999
"Operating System :: OS Independent",
100100
]
101101
dependencies = [
102-
"jax[cuda12]>=0.5.1,<=0.8.1",
102+
"jax[cuda12]>=0.6.2,<=0.8.2",
103103
"cupy-cuda12x",
104104
"portpicker==1.6.0",
105105
]

scripts/local_mc.sh

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515

1616
set -u
1717
set -o pipefail
18+
set -m
1819

1920
if [ -z "${N_PROCS:-}" ] || [ -z "${N_GPUS:-}" ] || [ -z "${COMMAND:-}" ]; then
2021
echo "N_PROCS, N_GPUS, and COMMAND must be set"
2122
exit 1
2223
fi
2324

24-
25-
2625
# Default coordinator setup
2726
export JAX_COORDINATOR_IP="${JAX_COORDINATOR_IP:-localhost}"
2827
export JAX_COORDINATOR_PORT="${JAX_COORDINATOR_PORT:-1234}"
@@ -35,9 +34,7 @@ PIDS=()
3534
cleanup() {
3635
echo "Cleaning up..."
3736
for pid in "${PIDS[@]}"; do
38-
if kill -0 "$pid" 2>/dev/null; then
39-
kill "$pid" 2>/dev/null
40-
fi
37+
kill -9 -- -"$pid" 2>/dev/null || true
4138
done
4239
}
4340

scripts/test_jax_versions.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/bin/bash
22
set -e
33

4-
JAX_VERSIONS=("0.6.1" "0.6.2" "0.7.0" "0.7.1" "0.7.2" "0.8.0" "0.8.1")
4+
JAX_VERSIONS=("0.6.2" "0.7.0" "0.7.1" "0.7.2" "0.8.0" "0.8.1" "0.8.2")
55

66
for version in "${JAX_VERSIONS[@]}"
77
do
@@ -13,6 +13,6 @@ done
1313

1414
# Test nightly
1515
pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
16-
# JAXPP_DEBUG_FORCE_MPMDIFY=True JAXPP_ENABLE_LICM=True python examples/basic.py --train_steps=10
16+
JAXPP_DEBUG_FORCE_MPMDIFY=True JAXPP_ENABLE_LICM=True python examples/basic.py --train_steps=10
1717

1818
exit 0

src/jaxpp/api.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
mpmd_to_spmd_reshard,
2424
spmd_to_mpmd_reshard,
2525
)
26-
from jaxpp.core import mpmd_jit_by_yield, mpmd_jit_rev, mpmd_jit_with_loop
27-
from jaxpp.jax_primitives import add_multi_p
26+
from jaxpp.core import mpmd_jit_by_yield, mpmd_jit_with_loop
27+
from jaxpp.jax_primitives import add_multi_p, gather_multi_p
2828
from jaxpp.mesh import MpmdMesh
2929
from jaxpp.pipelining import pipeline_enter_stage
3030
from jaxpp.schedules import (
@@ -45,3 +45,10 @@ def cross_mpmd_all_reduce(*args):
4545
f"All arguments must have the same dtype, got {[a.dtype for a in args]}"
4646
)
4747
return add_multi_p.bind(*args)
48+
49+
50+
def cross_mpmd_stack(arrays, axis: int = 0):
51+
import jax
52+
53+
expanded = [jax.numpy.expand_dims(a, axis=axis) for a in arrays]
54+
return gather_multi_p.bind(*expanded, axis=axis)

src/jaxpp/array.py

Lines changed: 44 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
# limitations under the License.
1515

1616
from collections import OrderedDict, defaultdict
17-
from typing import Any, cast, overload
17+
from typing import Any, cast
1818

1919
import jax
20-
import jax._src.core as jcore
2120
import numpy as np
2221

22+
from jaxpp.jax_compat import core as jcore
2323
from jaxpp.mesh import MpmdMesh
2424
from jaxpp.types import MpmdSharding
25-
from jaxpp.utils import get_named_sharding, update_named_sharding
25+
from jaxpp.utils import filter_axes, get_named_sharding, update_named_sharding
2626

2727

2828
class MpmdArray:
@@ -83,8 +83,8 @@ def __init__(
8383
mesh = get_named_sharding(arr).mesh
8484
if (mpmd_idx := mpmd_mesh.mpmd_idx_for_mesh.get(mesh)) is None:
8585
raise ValueError(
86-
f"Argument array {idx} {arr.shape} is not on a mesh that is part "
87-
f"mpmd_mesh={mpmd_mesh.jax_mesh}"
86+
f"Argument array {idx} {arr.shape} is not on a mesh that is part"
87+
f" of mpmd_mesh={mpmd_mesh.jax_mesh}"
8888
)
8989

9090
if mpmd_idx not in mpmd_idxs:
@@ -118,8 +118,9 @@ def __init__(
118118
assert all(_ == shape for _ in shapes), (shape, shapes)
119119
dtypes = [a.dtype for a in self._partially_addressable_arrays.values()]
120120
assert all(_ == dtype for _ in dtypes), (dtype, dtypes)
121+
mpmd_axis = mpmd_sharding.mpmd_mesh.mpmd_axis_name
121122
specs = [
122-
get_named_sharding(a).spec
123+
filter_axes(get_named_sharding(a).spec, {mpmd_axis})
123124
for a in self._partially_addressable_arrays.values()
124125
]
125126
assert all(_ == mpmd_sharding.spec for _ in specs), (
@@ -281,56 +282,6 @@ def _to_global_jax_array(mpmd_array: MpmdArray) -> jax.Array | None:
281282
)
282283

283284

284-
@overload
285-
def filter_axes(
286-
sharding_or_pspec: jax.sharding.NamedSharding, axes: set[str]
287-
) -> jax.sharding.NamedSharding: ...
288-
289-
290-
@overload
291-
def filter_axes(
292-
sharding_or_pspec: jax.sharding.PartitionSpec, axes: set[str]
293-
) -> jax.sharding.PartitionSpec: ...
294-
295-
296-
def filter_axes(
297-
sharding_or_pspec: jax.sharding.NamedSharding | jax.sharding.PartitionSpec,
298-
axes: set[str],
299-
) -> jax.sharding.NamedSharding | jax.sharding.PartitionSpec:
300-
"""Filter out specified axes from a sharding or partition spec.
301-
302-
Args:
303-
sharding_or_pspec: Either a NamedSharding or PartitionSpec to filter.
304-
axes: Set of axis names to remove from the spec.
305-
306-
Returns:
307-
Same type as input with specified axes filtered out.
308-
"""
309-
if isinstance(sharding_or_pspec, jax.sharding.NamedSharding):
310-
return jax.sharding.NamedSharding(
311-
sharding_or_pspec.mesh, filter_axes(sharding_or_pspec.spec, axes)
312-
)
313-
314-
assert isinstance(sharding_or_pspec, jax.sharding.PartitionSpec)
315-
316-
new_spec = []
317-
for axis in sharding_or_pspec:
318-
if axis is None:
319-
new_spec.append(None)
320-
elif isinstance(axis, str):
321-
if axis not in axes:
322-
new_spec.append(axis)
323-
else:
324-
new_spec.append(None)
325-
elif isinstance(axis, (list, tuple)):
326-
new_axis = [a for a in axis if a not in axes]
327-
new_spec.append(type(axis)(new_axis))
328-
else:
329-
raise ValueError(f"Unsupported_axis_type: {type(axis)}")
330-
331-
return jax.sharding.PartitionSpec(*new_spec)
332-
333-
334285
def _id(*xs):
335286
return xs
336287

@@ -346,8 +297,9 @@ def _spmd_to_mpmd_reshard(
346297

347298
for spmd_value, dist_sharding in zip(spmd_values, dist_shardings):
348299
assert isinstance(
349-
spmd_value.sharding, jax.sharding.NamedSharding
350-
), spmd_value.sharding
300+
spmd_value.sharding,
301+
(jax.sharding.NamedSharding, jax.sharding.SingleDeviceSharding),
302+
), f"Unsupported sharding type: {spmd_value.sharding}"
351303

352304
# NOTE: We filter out the mpmd axis from the sharding so that
353305
# the output is replicated across all mpmd ranks.
@@ -362,11 +314,7 @@ def _spmd_to_mpmd_reshard(
362314
for dist_sharding in dist_shardings
363315
)
364316

365-
res: list[jax.Array] = jax.jit(
366-
_id,
367-
in_shardings=tuple(_.sharding for _ in spmd_values),
368-
out_shardings=_actual_shardings,
369-
)(*spmd_values)
317+
res: list[jax.Array] = jax.jit(_id, out_shardings=_actual_shardings)(*spmd_values)
370318

371319
for spmd_value, donated in zip(spmd_values, donate, strict=True):
372320
if donated:
@@ -398,17 +346,19 @@ def _spmd_to_mpmd_reshard(
398346
_res = []
399347
for arr, dsh in zip(res, dist_shardings, strict=True):
400348
mesh_ids = dsh.mesh_ids
401-
filtered_sharding = MpmdSharding(
349+
# MpmdSharding.__post_init__ canonicalizes the spec by filtering out
350+
# the mpmd axis, so we can just use dsh.spec directly
351+
mpmd_sharding = MpmdSharding(
402352
mpmd_mesh=dsh.mpmd_mesh,
403353
mesh_ids=dsh.mesh_ids,
404-
spec=filter_axes(dsh.sharding.spec, {mpmd_mesh.mpmd_axis_name}),
354+
spec=dsh.spec,
405355
)
406356

407357
if mpmd_mesh.my_mpmd_axis_index not in mesh_ids:
408358
_res.append(
409359
MpmdArray(
410360
partially_addressable_arrays=[],
411-
mpmd_sharding=filtered_sharding,
361+
mpmd_sharding=mpmd_sharding,
412362
shape=arr.shape,
413363
dtype=arr.dtype,
414364
)
@@ -425,7 +375,7 @@ def _spmd_to_mpmd_reshard(
425375
_res.append(
426376
MpmdArray(
427377
partially_addressable_arrays=[new_arr],
428-
mpmd_sharding=filtered_sharding,
378+
mpmd_sharding=mpmd_sharding,
429379
)
430380
)
431381
return _res
@@ -481,6 +431,9 @@ def spmd_to_mpmd_reshard(
481431
The specs of the returned arrays will _not_ have `mpmd_mesh.mpmd_axis_name` in
482432
them.
483433
434+
Limitations: same constraints as jax.jit apply (e.g. _device_assignment must be the
435+
same for all arrays)
436+
484437
Args:
485438
mpmd_mesh: The MPMD mesh definition.
486439
spmd_arrays: A pytree of source SPMD arrays.
@@ -506,28 +459,6 @@ def spmd_to_mpmd_reshard(
506459

507460
assert spmd_tree_def == mpmd_tree_def
508461

509-
# Verify all arrays are on the same mesh
510-
first_path, first_leaf = spmd_arrays_with_path[0]
511-
first_sharding = first_leaf.sharding
512-
assert isinstance(first_sharding, jax.sharding.NamedSharding), first_sharding
513-
mesh = first_sharding.mesh
514-
515-
# This check is the same as the one performed by jax.jit
516-
assert mesh._flat_devices_tuple == mpmd_mesh.jax_mesh._flat_devices_tuple, (
517-
mesh,
518-
mpmd_mesh.jax_mesh,
519-
)
520-
for path, leaf in spmd_arrays_with_path:
521-
assert isinstance(leaf.sharding, jax.sharding.NamedSharding), (
522-
path,
523-
leaf.sharding,
524-
)
525-
assert leaf.sharding.mesh._flat_devices_tuple == mesh._flat_devices_tuple, (
526-
path,
527-
mesh,
528-
leaf.sharding.mesh,
529-
)
530-
531462
_, spmd_arrays_flat = jax._src.util.unzip2(spmd_arrays_with_path)
532463
spmd_arrays_flat_list = list(spmd_arrays_flat)
533464

@@ -605,7 +536,11 @@ def _axis_name_in_spec(axis_name: str, spec) -> bool:
605536

606537

607538
def logically_stacked(
608-
array: jax.Array, comm_mesh: jax.sharding.Mesh, axis_name: str, strict: bool = False
539+
array: jax.Array,
540+
comm_mesh: jax.sharding.Mesh,
541+
mesh_axis_name: str,
542+
array_axis: int = 0,
543+
strict: bool = False,
609544
):
610545
"""
611546
Logically stacks an array along a new axis corresponding to the MPMD dimension.
@@ -619,18 +554,25 @@ def logically_stacked(
619554
if strict:
620555
spec = array.sharding.spec
621556
assert not _axis_name_in_spec(
622-
axis_name, spec
623-
), f"axis_name {axis_name!r} already exists in spec {spec}"
557+
mesh_axis_name, spec
558+
), f"axis_name {mesh_axis_name!r} already exists in spec {spec}"
624559
else:
625-
spec = filter_axes(array.sharding.spec, {axis_name})
560+
spec = filter_axes(array.sharding.spec, {mesh_axis_name})
626561

627-
expanded_array = jax.numpy.expand_dims(array, 0)
562+
expanded_array = jax.numpy.expand_dims(array, array_axis)
628563
in_sharding = jax.sharding.NamedSharding(
629-
comm_mesh, jax.sharding.PartitionSpec(axis_name, *spec)
564+
comm_mesh,
565+
jax.sharding.PartitionSpec(
566+
*spec[:array_axis], mesh_axis_name, *spec[array_axis:]
567+
),
630568
)
631569

632570
global_array = jax.make_array_from_single_device_arrays(
633-
(comm_mesh.shape[axis_name], *array.shape),
571+
(
572+
*array.shape[:array_axis],
573+
comm_mesh.shape[mesh_axis_name],
574+
*array.shape[array_axis:],
575+
),
634576
in_sharding,
635577
[s.data for s in expanded_array.addressable_shards],
636578
)
@@ -664,6 +606,12 @@ def mpmd_to_spmd_reshard(
664606
Returns:
665607
A pytree of JAX arrays with the same structure as mpmd_arrays.
666608
"""
609+
610+
if not mpmd_mesh.jax_mesh.is_multi_process:
611+
return jax.device_put(
612+
jax.tree.map(lambda _: _.first_mpmd_replica, mpmd_arrays), spmd_shardings
613+
)
614+
667615
mpmd_arrays_with_path, mpmd_tree_def = jax.tree.flatten_with_path(mpmd_arrays)
668616
mpmd_arrays_with_path: list[tuple[Any, MpmdArray]]
669617
spmd_shardings_flat, spmd_tree_def = jax.tree.flatten(spmd_shardings)

0 commit comments

Comments (0)