Skip to content

Commit fae5cb7

Browse files
seanprime7 and Angelogeb committed
Update 26.02.13
Updates include * KimiK2 schedule * bug fixes - 00df43e65ba442c6dc8d6ec6378303bd9b15e644 by Anxhelo Xhebraj <axhebraj@nvidia.com> - a218842766a8d2967619aed225fa6d910ebbbe80 by Sean Lee <selee@nvidia.com> - d50fd8b5681c29b138a6df530fe3e451ec17b6b6 by Anxhelo Xhebraj <axhebraj@nvidia.com> - dc85730f57d2444021a7ed799c4daa3840cb2472 by Anxhelo Xhebraj <axhebraj@nvidia.com> - ce305ed9972778ff8d39b1a3223a886fa51ea89b by Sean Lee <selee@nvidia.com> - b5a509e339010808f47a1a0c12452f0dcecf7b60 by Sean Lee <selee@nvidia.com> - 03898159ec8c9732b9731154762b72e6712986c6 by Sean Lee <selee@nvidia.com> Co-authored-by: Anxhelo Xhebraj <axhebraj@nvidia.com> Signed-off-by: Sean Lee <selee@nvidia.com> GitOrigin-RevId: 00df43e65ba442c6dc8d6ec6378303bd9b15e644
1 parent 842c336 commit fae5cb7

File tree

8 files changed

+136
-298
lines changed

8 files changed

+136
-298
lines changed

scripts/test_jax_versions.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/bin/bash
22
set -e
33

4-
JAX_VERSIONS=("0.6.2" "0.7.0" "0.7.1" "0.7.2" "0.8.0" "0.8.1" "0.8.2")
4+
JAX_VERSIONS=("0.6.2" "0.7.0" "0.7.1" "0.7.2" "0.8.0" "0.8.1" "0.8.2" "0.9.0")
55

66
for version in "${JAX_VERSIONS[@]}"
77
do

src/jaxpp/array.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,12 +379,23 @@ def _spmd_to_mpmd_reshard(
379379

380380

381381
def _get_working_memory_threshold() -> int:
382-
"""Get the minimum available working memory across local devices."""
382+
"""Get the minimum available working memory across all devices globally."""
383383
min_available = float("inf")
384384
for d in jax.local_devices():
385385
stats = d.memory_stats()
386386
available = stats["bytes_limit"] - stats["peak_bytes_in_use"]
387387
min_available = min(min_available, available)
388+
389+
# Ensure all processes use the same threshold by computing global minimum
390+
if jax.process_count() > 1:
391+
# Use process_allgather to collect all local minimums, then compute global min
392+
# Note: process_allgather requires an array, not a scalar
393+
local_min_array = jax.numpy.array(min_available, dtype=jax.numpy.float32)
394+
all_mins = jax.experimental.multihost_utils.process_allgather(
395+
local_min_array, tiled=True
396+
)
397+
min_available = float(jax.numpy.min(all_mins))
398+
388399
return int(min_available // 3)
389400

390401

src/jaxpp/core.py

Lines changed: 23 additions & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -103,31 +103,6 @@
103103
logger = logging.getLogger(__name__)
104104

105105

106-
@contextmanager
107-
def stable_names_ctx(anno: Callable[[jcore.Var], str | None] = lambda v: None):
108-
prev_repr = jax._src.core.Var.__repr__
109-
prev_pp_var = jax._src.core.pp_var
110-
111-
ctx = jcore.JaxprPpContext()
112-
113-
def __repr__(v):
114-
if isinstance(v, jcore.Literal):
115-
return f"{v}"
116-
if (s := anno(v)) is not None:
117-
return f"{prev_pp_var(v, ctx)}{{{s}}}"
118-
119-
return f"{prev_pp_var(v, ctx)}"
120-
121-
jax._src.core.pp_var = lambda v, _: __repr__(v)
122-
jax._src.core.Var.__repr__ = __repr__
123-
124-
try:
125-
yield
126-
finally:
127-
jax._src.core.Var.__repr__ = prev_repr
128-
jax._src.core.pp_var = prev_pp_var
129-
130-
131106
CJaxpr = TypeVar("CJaxpr", jcore.ClosedJaxpr, jcore.Jaxpr)
132107
Res = TypeVar("Res")
133108
P = ParamSpec("P")
@@ -893,9 +868,7 @@ def first_pipeline_yield_eqn_idx(eqns: Iterable[jcore.JaxprEqn]) -> int | None:
893868
def infer_cluster_idx_for_eqns(
894869
clusters: list[Cluster],
895870
eqns: list[jcore.JaxprEqn],
896-
bias: dict[jcore.Var, set[MpmdIdx]] | None = None,
897871
) -> list[int | None]:
898-
bias = bias or {}
899872
cluster_info = get_cluster_information(clusters)
900873
var_def_cluster_idx = cluster_info.var_def_cluster_idx
901874
var_ref_cluster_idx = cluster_info.var_ref_cluster_idx
@@ -1048,11 +1021,9 @@ def cluster_by_yield_eqns(
10481021
def cluster_eqns(
10491022
eqns: list[jcore.JaxprEqn],
10501023
get_mpmd_idx: Callable[[int], MpmdIdx],
1051-
bias: dict[jcore.Var, set[MpmdIdx]] | None = None,
10521024
) -> tuple[list[Cluster], list[jcore.JaxprEqn]]:
1053-
bias = bias or {}
10541025
clusters, rest = cluster_by_yield_eqns(eqns, get_mpmd_idx)
1055-
eqns_cluster_idxs = infer_cluster_idx_for_eqns(clusters, rest, bias)
1026+
eqns_cluster_idxs = infer_cluster_idx_for_eqns(clusters, rest)
10561027
unclustered_eqns = list[jcore.JaxprEqn]()
10571028
for cluster_idx, eqn in zip(eqns_cluster_idxs, rest, strict=True):
10581029
if cluster_idx is not None:
@@ -1124,19 +1095,10 @@ def cluster_jaxpr(
11241095
target_num_stages: int,
11251096
is_partial_bwd: bool,
11261097
get_mpmd_idx: Callable[[int], MpmdIdx],
1127-
bias: list[set[MpmdIdx] | None] | None = None,
11281098
is_loop: bool = True,
11291099
):
11301100
# TODO: remove is_loop parameter and make the caller perform the checks
1131-
bias_map = None
1132-
if bias is not None:
1133-
bias_map = {
1134-
invar: p
1135-
for invar, p in zip(jaxpr.invars, bias, strict=True)
1136-
if p is not None
1137-
}
1138-
1139-
clusters, unclustered_eqns = cluster_eqns(jaxpr.eqns, get_mpmd_idx, bias_map)
1101+
clusters, unclustered_eqns = cluster_eqns(jaxpr.eqns, get_mpmd_idx)
11401102
if (
11411103
is_loop
11421104
and len(unclustered_eqns) != 0
@@ -1190,10 +1152,7 @@ def cluster_jaxpr(
11901152
return clustered_jaxpr
11911153

11921154

1193-
def wrap_into_tasks_inside_loop(
1194-
loop_eqn: jcore.JaxprEqn,
1195-
bias: list[set[MpmdIdx] | None] | None = None,
1196-
) -> jcore.JaxprEqn:
1155+
def wrap_into_tasks_inside_loop(loop_eqn: jcore.JaxprEqn) -> jcore.JaxprEqn:
11971156
jaxpr: jcore.Jaxpr = loop_eqn.params["jaxpr"].jaxpr
11981157
# TODO: let bind literals
11991158
assert len(jaxpr.outvars) == len(
@@ -1205,7 +1164,6 @@ def wrap_into_tasks_inside_loop(
12051164
target_num_stages=loop_eqn.params["schedule"].num_stages,
12061165
is_partial_bwd=loop_eqn.params["schedule"].is_partial_bwd,
12071166
get_mpmd_idx=loop_eqn.params["schedule"].get_mpmd_idx,
1208-
bias=bias,
12091167
)
12101168

12111169
# Infer where loop inputs are used (refs) and where loop outputs
@@ -1401,43 +1359,6 @@ def make_replicated_jaxpr(
14011359
return res, invar_mpmd_refs
14021360

14031361

1404-
def infer_outvar_placement_rev(
1405-
jaxpr: jcore.Jaxpr, partial_outvar_placement: Iterable[set[MpmdIdx] | None]
1406-
) -> tuple[list[set[MpmdIdx]], list[set[MpmdIdx] | None]]:
1407-
partial_outvar_placement = tuple(partial_outvar_placement)
1408-
outvars = cast(list[jcore.Var], jaxpr.outvars)
1409-
placement = {
1410-
outvar: maybe_p
1411-
for outvar, maybe_p in jc.safe_zip(outvars, partial_outvar_placement)
1412-
if isinstance(outvar, jcore.Var) and maybe_p is not None
1413-
}
1414-
1415-
# Infer from outvars to invars
1416-
for eqn in reversed(jaxpr.eqns):
1417-
eqn_p = set.union(
1418-
set(), *(placement.get(outvar, set()) for outvar in eqn.outvars)
1419-
)
1420-
if len(eqn_p) > 0:
1421-
for invar in nonlit(eqn.invars):
1422-
placement[invar] = placement.get(invar, set()) | eqn_p
1423-
1424-
# Infer from invars to outvars
1425-
for eqn in jaxpr.eqns:
1426-
eqn_p = set.union(
1427-
set(), *(placement.get(invar, set()) for invar in nonlit(eqn.invars))
1428-
)
1429-
if len(eqn_p) > 0:
1430-
for outvar in eqn.outvars:
1431-
placement[outvar] = placement.get(outvar, set()) | eqn_p
1432-
1433-
return [placement.get(invar) for invar in jaxpr.invars], [
1434-
placement.get(outvar)
1435-
if isinstance(outvar, jcore.Var)
1436-
else partial_outvar_placement[outvar_idx]
1437-
for outvar_idx, outvar in enumerate(outvars)
1438-
]
1439-
1440-
14411362
def get_one_loop_eqn_idx(
14421363
eqns_or_jaxpr: jcore.ClosedJaxpr | jcore.Jaxpr | Iterable[jcore.JaxprEqn],
14431364
) -> int:
@@ -1716,25 +1637,18 @@ def loop_placement_by_clusters(
17161637
if (idx := outvar_idx.get(outvar)) is not None:
17171638
out_mpmd_defs[idx] = clusters[def_cluster_idx].mpmd_idx
17181639

1719-
with stable_names_ctx(
1720-
lambda v: {clusters[idx].mpmd_idx for idx in idxs}
1721-
if (idxs := cluster_info.var_ref_cluster_idx.get(v)) is not None
1722-
else {clusters[idx].mpmd_idx}
1723-
if (idx := cluster_info.var_def_cluster_idx.get(v)) is not None
1724-
else None
1640+
for in_idx, (mpmd_refs, mpmd_def) in enumerate(
1641+
jc.safe_zip(in_mpmd_refs[n_consts:], out_mpmd_defs), start=n_consts
17251642
):
1726-
for in_idx, (mpmd_refs, mpmd_def) in enumerate(
1727-
jc.safe_zip(in_mpmd_refs[n_consts:], out_mpmd_defs), start=n_consts
1728-
):
1729-
# Check that the mpmd_index that produces an outvar
1730-
# is a subset of the ones that refer to it.
1731-
if mpmd_refs is not None:
1732-
if mpmd_def not in mpmd_refs:
1733-
raise AssertionError(
1734-
f"Loop state is not stable across iterations {in_idx=} {in_idx - n_consts=}"
1735-
)
1736-
elif mpmd_def is not None:
1737-
in_mpmd_refs[in_idx] = {mpmd_def}
1643+
# Check that the mpmd_index that produces an outvar
1644+
# is a subset of the ones that refer to it.
1645+
if mpmd_refs is not None:
1646+
if mpmd_def not in mpmd_refs:
1647+
raise AssertionError(
1648+
f"Loop state is not stable across iterations {in_idx=} {in_idx - n_consts=}"
1649+
)
1650+
elif mpmd_def is not None:
1651+
in_mpmd_refs[in_idx] = {mpmd_def}
17381652

17391653
return in_mpmd_refs, out_mpmd_defs
17401654

@@ -1766,77 +1680,6 @@ def join_argument_refs(
17661680
return loop_args_mpmd_refs_map
17671681

17681682

1769-
def _compute_bias(jaxpr: jcore.Jaxpr, loop_eqn_idx: int):
1770-
loop_eqn = jaxpr.eqns[loop_eqn_idx]
1771-
1772-
# Infer partial placement from loop body
1773-
loop_in_mpmd_refs, loop_out_mpmd_defs = loop_placement_by_clusters(
1774-
loop_eqn, loop_eqn.params["schedule"].get_mpmd_idx
1775-
)
1776-
1777-
# Use partial placement from loop body to infer
1778-
# before loop partial placement
1779-
before_loop_jaxpr = jaxpr_from_eqns(
1780-
jaxpr.eqns[:loop_eqn_idx], eqns_free_vars(jaxpr.eqns[loop_eqn_idx:])[0]
1781-
)
1782-
1783-
loop_args_mpmd_refs = join_argument_refs(
1784-
cast(list[jcore.Var], loop_eqn.invars), loop_in_mpmd_refs
1785-
)
1786-
1787-
before_loop_invar_placement, before_loop_outvar_mpmd_defs = (
1788-
infer_outvar_placement_rev(
1789-
before_loop_jaxpr,
1790-
partial_outvar_placement=tuple(
1791-
loop_args_mpmd_refs.get(outvar)
1792-
# NOTE: before loop outvars are just vars as it comes from
1793-
# `jaxpr_from_eqns`
1794-
for outvar in cast(list[jcore.Var], before_loop_jaxpr.outvars)
1795-
),
1796-
)
1797-
)
1798-
1799-
# make_replicated_jaxpr(
1800-
# before_loop_jaxpr,
1801-
# tuple(
1802-
# map(
1803-
# join_argument_refs(loop_eqn.invars, loop_in_mpmd_refs).get,
1804-
# before_loop_jaxpr.outvars,
1805-
# )
1806-
# ),
1807-
# list(range(mpmd_dim)),
1808-
# )
1809-
1810-
# Merge all partial placement known so far
1811-
placement = {}
1812-
for invar, p in zip(
1813-
before_loop_jaxpr.invars, before_loop_invar_placement, strict=True
1814-
):
1815-
assert invar not in placement
1816-
placement[invar] = p
1817-
1818-
for outvar, p in zip(
1819-
before_loop_jaxpr.outvars, before_loop_outvar_mpmd_defs, strict=True
1820-
):
1821-
assert outvar not in placement
1822-
if p is not None:
1823-
placement[outvar] = p
1824-
1825-
bias: list[set[MpmdIdx] | None] = [None] * len(loop_eqn.invars)
1826-
for invar_idx, invar in enumerate(loop_eqn.invars):
1827-
p = placement.get(invar)
1828-
loop_parameter_p = loop_in_mpmd_refs[invar_idx]
1829-
if loop_parameter_p is not None and p is not None:
1830-
if not loop_parameter_p.issubset(p):
1831-
raise AssertionError()
1832-
1833-
if loop_parameter_p is None and p is not None:
1834-
bias[invar_idx] = p
1835-
loop_placement_changed = any(b is not None for b in bias) # noqa: F841
1836-
1837-
return before_loop_jaxpr, bias
1838-
1839-
18401683
@jc.weakref_lru_cache
18411684
def _wrap_into_tasks(
18421685
cjaxpr: jcore.ClosedJaxpr, used_invars: Sequence[bool], mpmd_dim: int
@@ -1854,11 +1697,11 @@ def _wrap_into_tasks(
18541697
loop_eqn_idx = len(before_loop_eqns)
18551698
loop_eqn = jaxpr.eqns[loop_eqn_idx]
18561699

1857-
# TODO: use partial placement to obtain partial placement from
1858-
# after loop part
1859-
before_loop_jaxpr, bias = _compute_bias(jaxpr, loop_eqn_idx)
1700+
before_loop_jaxpr = jaxpr_from_eqns(
1701+
jaxpr.eqns[:loop_eqn_idx], eqns_free_vars(jaxpr.eqns[loop_eqn_idx:])[0]
1702+
)
18601703
# Use current placement to taskify loop body
1861-
tasked_loop_eqn = wrap_into_tasks_inside_loop(loop_eqn, bias)
1704+
tasked_loop_eqn = wrap_into_tasks_inside_loop(loop_eqn)
18621705

18631706
# Use current placement to taskify before loop
18641707
loop_args_mpmd_refs = join_argument_refs(
@@ -3640,8 +3483,9 @@ def __call__(self, *args, **kwargs):
36403483
jc.equality_errors_pytreedef(self.in_info.in_tree, in_tree)
36413484
)
36423485

3486+
n_consts = len(self.consts)
36433487
for i, arg in enumerate(flat_args):
3644-
expected_mpmd_idx = set(self.in_info.in_mpmd_defs[len(self.consts) + i])
3488+
expected_mpmd_idx = set(self.in_info.in_mpmd_defs[n_consts + i])
36453489
if len(expected_mpmd_idx) == 0:
36463490
continue
36473491

@@ -3663,9 +3507,10 @@ def __call__(self, *args, **kwargs):
36633507
except KeyError:
36643508
pass
36653509

3510+
in_sharding = self.in_info.in_shardings[n_consts + i]
36663511
for mpmd_idx in expected_mpmd_idx:
36673512
(sh,) = updated_named_sharding_mesh(
3668-
(self.in_info.in_shardings[i],),
3513+
(in_sharding,),
36693514
self.mpmd_mesh.unstack[mpmd_idx],
36703515
)
36713516
values[mpmd_idx] = jax.device_put(arg, sh)
@@ -3677,7 +3522,7 @@ def __call__(self, *args, **kwargs):
36773522
mpmd_sharding=MpmdSharding(
36783523
self.mpmd_mesh,
36793524
expected_mpmd_idx,
3680-
self.in_info.in_shardings[i].spec,
3525+
in_sharding.spec,
36813526
),
36823527
)
36833528

src/jaxpp/jax_compat/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,23 @@
8484
from jax._src.pjit import jit_p
8585

8686
# _infer_params was renamed to _trace_for_jit in JAX 0.8.3
# NOTE(review): the condition also reverts to _infer_params for versions newer
# than 0.9.0 — presumably the symbol was renamed back upstream; verify against
# the JAX changelog before relying on this branch.
87-
if jax.__version_info__ < (0, 8, 3):
87+
if jax.__version_info__ < (0, 8, 3) or jax.__version_info__ > (0, 9, 0):
8888
from jax._src.pjit import _infer_params
8989
else:
9090
from jax._src.pjit import _trace_for_jit as _infer_params
9191

9292

93+
def set_mesh(mesh: jax.sharding.Mesh):
94+
"""Return a context manager that sets the mesh.
95+
96+
JAX >= 0.8 requires ``jax.set_mesh`` for the mesh to be visible to
97+
``jax.jit``; older versions only support ``with mesh:``.
98+
"""
99+
if jax.__version_info__ >= (0, 8):
100+
return jax.set_mesh(mesh)
101+
return mesh
102+
103+
93104
def map_dynamic_args(args, kwargs, static_argnums, static_argnames, fn):
94105
static_argnums = static_argnums or ()
95106
static_argnames = static_argnames or ()
@@ -165,4 +176,5 @@ def map_dynamic_args(args, kwargs, static_argnums, static_argnames, fn):
165176
"convert_constvars_jaxpr",
166177
# utilities
167178
"map_dynamic_args",
179+
"set_mesh",
168180
]

0 commit comments

Comments (0)