103103logger = logging .getLogger (__name__ )
104104
105105
106- @contextmanager
107- def stable_names_ctx (anno : Callable [[jcore .Var ], str | None ] = lambda v : None ):
108- prev_repr = jax ._src .core .Var .__repr__
109- prev_pp_var = jax ._src .core .pp_var
110-
111- ctx = jcore .JaxprPpContext ()
112-
113- def __repr__ (v ):
114- if isinstance (v , jcore .Literal ):
115- return f"{ v } "
116- if (s := anno (v )) is not None :
117- return f"{ prev_pp_var (v , ctx )} {{{ s } }}"
118-
119- return f"{ prev_pp_var (v , ctx )} "
120-
121- jax ._src .core .pp_var = lambda v , _ : __repr__ (v )
122- jax ._src .core .Var .__repr__ = __repr__
123-
124- try :
125- yield
126- finally :
127- jax ._src .core .Var .__repr__ = prev_repr
128- jax ._src .core .pp_var = prev_pp_var
129-
130-
131106CJaxpr = TypeVar ("CJaxpr" , jcore .ClosedJaxpr , jcore .Jaxpr )
132107Res = TypeVar ("Res" )
133108P = ParamSpec ("P" )
@@ -893,9 +868,7 @@ def first_pipeline_yield_eqn_idx(eqns: Iterable[jcore.JaxprEqn]) -> int | None:
893868def infer_cluster_idx_for_eqns (
894869 clusters : list [Cluster ],
895870 eqns : list [jcore .JaxprEqn ],
896- bias : dict [jcore .Var , set [MpmdIdx ]] | None = None ,
897871) -> list [int | None ]:
898- bias = bias or {}
899872 cluster_info = get_cluster_information (clusters )
900873 var_def_cluster_idx = cluster_info .var_def_cluster_idx
901874 var_ref_cluster_idx = cluster_info .var_ref_cluster_idx
@@ -1048,11 +1021,9 @@ def cluster_by_yield_eqns(
10481021def cluster_eqns (
10491022 eqns : list [jcore .JaxprEqn ],
10501023 get_mpmd_idx : Callable [[int ], MpmdIdx ],
1051- bias : dict [jcore .Var , set [MpmdIdx ]] | None = None ,
10521024) -> tuple [list [Cluster ], list [jcore .JaxprEqn ]]:
1053- bias = bias or {}
10541025 clusters , rest = cluster_by_yield_eqns (eqns , get_mpmd_idx )
1055- eqns_cluster_idxs = infer_cluster_idx_for_eqns (clusters , rest , bias )
1026+ eqns_cluster_idxs = infer_cluster_idx_for_eqns (clusters , rest )
10561027 unclustered_eqns = list [jcore .JaxprEqn ]()
10571028 for cluster_idx , eqn in zip (eqns_cluster_idxs , rest , strict = True ):
10581029 if cluster_idx is not None :
@@ -1124,19 +1095,10 @@ def cluster_jaxpr(
11241095 target_num_stages : int ,
11251096 is_partial_bwd : bool ,
11261097 get_mpmd_idx : Callable [[int ], MpmdIdx ],
1127- bias : list [set [MpmdIdx ] | None ] | None = None ,
11281098 is_loop : bool = True ,
11291099):
11301100 # TODO: remove is_loop parameter and make the caller perform the checks
1131- bias_map = None
1132- if bias is not None :
1133- bias_map = {
1134- invar : p
1135- for invar , p in zip (jaxpr .invars , bias , strict = True )
1136- if p is not None
1137- }
1138-
1139- clusters , unclustered_eqns = cluster_eqns (jaxpr .eqns , get_mpmd_idx , bias_map )
1101+ clusters , unclustered_eqns = cluster_eqns (jaxpr .eqns , get_mpmd_idx )
11401102 if (
11411103 is_loop
11421104 and len (unclustered_eqns ) != 0
@@ -1190,10 +1152,7 @@ def cluster_jaxpr(
11901152 return clustered_jaxpr
11911153
11921154
1193- def wrap_into_tasks_inside_loop (
1194- loop_eqn : jcore .JaxprEqn ,
1195- bias : list [set [MpmdIdx ] | None ] | None = None ,
1196- ) -> jcore .JaxprEqn :
1155+ def wrap_into_tasks_inside_loop (loop_eqn : jcore .JaxprEqn ) -> jcore .JaxprEqn :
11971156 jaxpr : jcore .Jaxpr = loop_eqn .params ["jaxpr" ].jaxpr
11981157 # TODO: let bind literals
11991158 assert len (jaxpr .outvars ) == len (
@@ -1205,7 +1164,6 @@ def wrap_into_tasks_inside_loop(
12051164 target_num_stages = loop_eqn .params ["schedule" ].num_stages ,
12061165 is_partial_bwd = loop_eqn .params ["schedule" ].is_partial_bwd ,
12071166 get_mpmd_idx = loop_eqn .params ["schedule" ].get_mpmd_idx ,
1208- bias = bias ,
12091167 )
12101168
12111169 # Infer where loop inputs are used (refs) and where loop outputs
@@ -1401,43 +1359,6 @@ def make_replicated_jaxpr(
14011359 return res , invar_mpmd_refs
14021360
14031361
def infer_outvar_placement_rev(
    jaxpr: jcore.Jaxpr, partial_outvar_placement: Iterable[set[MpmdIdx] | None]
) -> tuple[list[set[MpmdIdx] | None], list[set[MpmdIdx] | None]]:
    """Propagate a partial output placement through ``jaxpr``.

    The known placements of the outputs are first pushed backwards (from each
    equation's outputs to its non-literal inputs) and then forwards (from each
    equation's non-literal inputs to its outputs). Placements are accumulated
    by set union.

    Args:
        jaxpr: The jaxpr whose equations are traversed.
        partial_outvar_placement: One entry per ``jaxpr.outvars`` element;
            ``None`` marks an output whose placement is unknown.

    Returns:
        A pair ``(invar_placement, outvar_placement)``. Each list has one
        entry per corresponding jaxpr variable and holds ``None`` where no
        placement could be inferred. Outputs that are not ``jcore.Var``
        instances keep the entry they were given in
        ``partial_outvar_placement``.
    """
    # Materialize once: the iterable is zipped here and indexed again below.
    partial_outvar_placement = tuple(partial_outvar_placement)
    outvars = cast(list[jcore.Var], jaxpr.outvars)
    placement = {
        outvar: maybe_p
        for outvar, maybe_p in jc.safe_zip(outvars, partial_outvar_placement)
        if isinstance(outvar, jcore.Var) and maybe_p is not None
    }

    # Infer from outvars to invars
    for eqn in reversed(jaxpr.eqns):
        eqn_p = set.union(
            set(), *(placement.get(outvar, set()) for outvar in eqn.outvars)
        )
        if len(eqn_p) > 0:
            for invar in nonlit(eqn.invars):
                # `|` builds a new set, so sets already stored are not mutated.
                placement[invar] = placement.get(invar, set()) | eqn_p

    # Infer from invars to outvars
    for eqn in jaxpr.eqns:
        eqn_p = set.union(
            set(), *(placement.get(invar, set()) for invar in nonlit(eqn.invars))
        )
        if len(eqn_p) > 0:
            for outvar in eqn.outvars:
                placement[outvar] = placement.get(outvar, set()) | eqn_p

    invar_placement = [placement.get(invar) for invar in jaxpr.invars]
    outvar_placement = [
        placement.get(outvar)
        if isinstance(outvar, jcore.Var)
        else partial_outvar_placement[outvar_idx]
        for outvar_idx, outvar in enumerate(outvars)
    ]
    return invar_placement, outvar_placement
1439-
1440-
14411362def get_one_loop_eqn_idx (
14421363 eqns_or_jaxpr : jcore .ClosedJaxpr | jcore .Jaxpr | Iterable [jcore .JaxprEqn ],
14431364) -> int :
@@ -1716,25 +1637,18 @@ def loop_placement_by_clusters(
17161637 if (idx := outvar_idx .get (outvar )) is not None :
17171638 out_mpmd_defs [idx ] = clusters [def_cluster_idx ].mpmd_idx
17181639
1719- with stable_names_ctx (
1720- lambda v : {clusters [idx ].mpmd_idx for idx in idxs }
1721- if (idxs := cluster_info .var_ref_cluster_idx .get (v )) is not None
1722- else {clusters [idx ].mpmd_idx }
1723- if (idx := cluster_info .var_def_cluster_idx .get (v )) is not None
1724- else None
1640+ for in_idx , (mpmd_refs , mpmd_def ) in enumerate (
1641+ jc .safe_zip (in_mpmd_refs [n_consts :], out_mpmd_defs ), start = n_consts
17251642 ):
1726- for in_idx , (mpmd_refs , mpmd_def ) in enumerate (
1727- jc .safe_zip (in_mpmd_refs [n_consts :], out_mpmd_defs ), start = n_consts
1728- ):
1729- # Check that the mpmd_index that produces an outvar
1730- # is a subset of the ones that refer to it.
1731- if mpmd_refs is not None :
1732- if mpmd_def not in mpmd_refs :
1733- raise AssertionError (
1734- f"Loop state is not stable across iterations { in_idx = } { in_idx - n_consts = } "
1735- )
1736- elif mpmd_def is not None :
1737- in_mpmd_refs [in_idx ] = {mpmd_def }
1643+ # Check that the mpmd_index that produces an outvar
1644+ # is a subset of the ones that refer to it.
1645+ if mpmd_refs is not None :
1646+ if mpmd_def not in mpmd_refs :
1647+ raise AssertionError (
1648+ f"Loop state is not stable across iterations { in_idx = } { in_idx - n_consts = } "
1649+ )
1650+ elif mpmd_def is not None :
1651+ in_mpmd_refs [in_idx ] = {mpmd_def }
17381652
17391653 return in_mpmd_refs , out_mpmd_defs
17401654
@@ -1766,77 +1680,6 @@ def join_argument_refs(
17661680 return loop_args_mpmd_refs_map
17671681
17681682
def _compute_bias(jaxpr: jcore.Jaxpr, loop_eqn_idx: int):
    """Derive a placement bias for the operands of the loop equation.

    The loop body's own placement is computed with
    ``loop_placement_by_clusters``. That placement is then propagated through
    the equations preceding the loop with ``infer_outvar_placement_rev``. Any
    loop operand whose placement the loop body leaves undetermined receives
    the placement inferred from the pre-loop equations.

    Args:
        jaxpr: The jaxpr containing the loop equation.
        loop_eqn_idx: Index of the loop equation within ``jaxpr.eqns``.

    Returns:
        A pair ``(before_loop_jaxpr, bias)``. ``before_loop_jaxpr`` is built
        from the equations preceding the loop. ``bias`` has one entry per loop
        operand, ``None`` where no bias applies.

    Raises:
        AssertionError: If the loop body's placement for an operand is not a
            subset of the placement inferred for it before the loop.
    """
    loop_eqn = jaxpr.eqns[loop_eqn_idx]

    # Infer partial placement from loop body
    loop_in_mpmd_refs, _loop_out_mpmd_defs = loop_placement_by_clusters(
        loop_eqn, loop_eqn.params["schedule"].get_mpmd_idx
    )

    # Use partial placement from loop body to infer
    # before loop partial placement
    before_loop_jaxpr = jaxpr_from_eqns(
        jaxpr.eqns[:loop_eqn_idx], eqns_free_vars(jaxpr.eqns[loop_eqn_idx:])[0]
    )

    loop_args_mpmd_refs = join_argument_refs(
        cast(list[jcore.Var], loop_eqn.invars), loop_in_mpmd_refs
    )

    before_loop_invar_placement, before_loop_outvar_mpmd_defs = (
        infer_outvar_placement_rev(
            before_loop_jaxpr,
            partial_outvar_placement=tuple(
                loop_args_mpmd_refs.get(outvar)
                # NOTE: before loop outvars are just vars as it comes from
                # `jaxpr_from_eqns`
                for outvar in cast(list[jcore.Var], before_loop_jaxpr.outvars)
            ),
        )
    )

    # Merge all partial placement known so far
    placement = {}
    for invar, p in zip(
        before_loop_jaxpr.invars, before_loop_invar_placement, strict=True
    ):
        assert invar not in placement
        placement[invar] = p

    for outvar, p in zip(
        before_loop_jaxpr.outvars, before_loop_outvar_mpmd_defs, strict=True
    ):
        assert outvar not in placement
        if p is not None:
            placement[outvar] = p

    bias: list[set[MpmdIdx] | None] = [None] * len(loop_eqn.invars)
    for invar_idx, invar in enumerate(loop_eqn.invars):
        p = placement.get(invar)
        if p is None:
            # Nothing was inferred before the loop for this operand.
            continue

        loop_parameter_p = loop_in_mpmd_refs[invar_idx]
        if loop_parameter_p is None:
            # The loop body does not constrain this operand: bias it with the
            # placement inferred from the pre-loop equations.
            bias[invar_idx] = p
        elif not loop_parameter_p.issubset(p):
            raise AssertionError(
                f"Loop operand {invar_idx} is placed on {loop_parameter_p} by "
                f"the loop body, which is not a subset of the placement {p} "
                "inferred before the loop"
            )

    return before_loop_jaxpr, bias
1838-
1839-
18401683@jc .weakref_lru_cache
18411684def _wrap_into_tasks (
18421685 cjaxpr : jcore .ClosedJaxpr , used_invars : Sequence [bool ], mpmd_dim : int
@@ -1854,11 +1697,11 @@ def _wrap_into_tasks(
18541697 loop_eqn_idx = len (before_loop_eqns )
18551698 loop_eqn = jaxpr .eqns [loop_eqn_idx ]
18561699
1857- # TODO: use partial placement to obtain partial placement from
1858- # after loop part
1859- before_loop_jaxpr , bias = _compute_bias ( jaxpr , loop_eqn_idx )
1700+ before_loop_jaxpr = jaxpr_from_eqns (
1701+ jaxpr . eqns [: loop_eqn_idx ], eqns_free_vars ( jaxpr . eqns [ loop_eqn_idx :])[ 0 ]
1702+ )
18601703 # Use current placement to taskify loop body
1861- tasked_loop_eqn = wrap_into_tasks_inside_loop (loop_eqn , bias )
1704+ tasked_loop_eqn = wrap_into_tasks_inside_loop (loop_eqn )
18621705
18631706 # Use current placement to taskify before loop
18641707 loop_args_mpmd_refs = join_argument_refs (
@@ -3640,8 +3483,9 @@ def __call__(self, *args, **kwargs):
36403483 jc .equality_errors_pytreedef (self .in_info .in_tree , in_tree )
36413484 )
36423485
3486+ n_consts = len (self .consts )
36433487 for i , arg in enumerate (flat_args ):
3644- expected_mpmd_idx = set (self .in_info .in_mpmd_defs [len ( self . consts ) + i ])
3488+ expected_mpmd_idx = set (self .in_info .in_mpmd_defs [n_consts + i ])
36453489 if len (expected_mpmd_idx ) == 0 :
36463490 continue
36473491
@@ -3663,9 +3507,10 @@ def __call__(self, *args, **kwargs):
36633507 except KeyError :
36643508 pass
36653509
3510+ in_sharding = self .in_info .in_shardings [n_consts + i ]
36663511 for mpmd_idx in expected_mpmd_idx :
36673512 (sh ,) = updated_named_sharding_mesh (
3668- (self . in_info . in_shardings [ i ] ,),
3513+ (in_sharding ,),
36693514 self .mpmd_mesh .unstack [mpmd_idx ],
36703515 )
36713516 values [mpmd_idx ] = jax .device_put (arg , sh )
@@ -3677,7 +3522,7 @@ def __call__(self, *args, **kwargs):
36773522 mpmd_sharding = MpmdSharding (
36783523 self .mpmd_mesh ,
36793524 expected_mpmd_idx ,
3680- self . in_info . in_shardings [ i ] .spec ,
3525+ in_sharding .spec ,
36813526 ),
36823527 )
36833528
0 commit comments