@@ -1547,10 +1547,11 @@ def fun(*res_and_args):
15471547 return jaxpr
15481548
15491549
1550- def _unmentioned2 (mesh : Mesh , names : AxisNames ) -> list [AxisName ]:
1550+ def _unmentioned2 (mesh : Mesh , names : AxisNames ,
1551+ auto : frozenset [AxisName ]) -> list [AxisName ]:
15511552 # We use a filtered-down version of unmentioned to avoid defensive-psum over
15521553 # more chips than required in the transpose-no-check-rep case.
1553- name_set = {n for ns in names .values () for n in ns }
1554+ name_set = {n for ns in names .values () for n in ns } | auto
15541555 return [n for n in _all_mesh_names_except_spmd (mesh ) if n not in name_set ]
15551556
15561557
@@ -1559,7 +1560,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
15591560 mb_div = lambda x , y : x / y if y != 1 else x
15601561 out_cts = [ad .Zero (_shard_aval (mesh , ns , x .aval )) if type (x ) is ad .Zero
15611562 else x if rewrite or dtypes .dtype (x ) == dtypes .float0
1562- else mb_div (x , prod (map (mesh .shape .get , _unmentioned2 (mesh , ns ))))
1563+ else mb_div (x , prod (map (mesh .shape .get , _unmentioned2 (mesh , ns , auto ))))
15631564 for ns , x in zip (out_names , out_cts )]
15641565 args = [x if type (x ) is not ad .UndefinedPrimal else
15651566 ad .UndefinedPrimal (_shard_aval (mesh , ns , x .aval ))
@@ -1577,7 +1578,7 @@ def fun_trans(out_cts, args):
15771578 )
15781579 out = [ad .Zero (_unshard_aval (mesh , ns , x .aval )) if type (x ) is ad .Zero
15791580 else x if rewrite
1580- else jax .lax .psum (x , tuple (_unmentioned2 (mesh , ns )))
1581+ else jax .lax .psum (x , tuple (_unmentioned2 (mesh , ns , auto )))
15811582 for ns , x in zip (in_names , out )]
15821583 return out
15831584
0 commit comments