Skip to content

Commit fb6e2cb

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Simplify psum_invariant's abstract_eval rule by disallowing positional axes.
PiperOrigin-RevId: 860130171
1 parent 039e00d commit fb6e2cb

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

jax/_src/lax/parallel.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2268,6 +2268,7 @@ def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name):
22682268
mlir.register_lowering(axis_index_p, _axis_index_lowering)
22692269
axis_index_p.def_effectful_abstract_eval(_axis_index_effectful_abstract_eval)
22702270
batching.fancy_primitive_batchers[axis_index_p] = _axis_index_batcher
2271+
22712272
######################## psum_invariant_p ####################################
22722273

22732274
def bind_psum_invariant(leaf, *, axes, axis_index_groups):
@@ -2295,15 +2296,13 @@ def _psum_invariant_abstract_eval(name, aval, *, axes):
22952296
" names mentioned in `axes` passed to `psum` must be present in"
22962297
f" `jax.typeof(inp).vma`. Got axes={axes} and"
22972298
f" jax.typeof(inp).vma={aval.vma}")
2299+
if any(isinstance(a, int) for a in axes):
2300+
raise ValueError(f'psum_invariant does not accept integer axes. Got {axes}')
22982301

22992302
named_axes = tuple(axis for axis in axes if not isinstance(axis, int))
2300-
pos_axes = tuple(axis for axis in axes if isinstance(axis, int))
23012303
core.check_avals_context_mesh([aval], name)
23022304
check_unreduced_args([aval], name)
2303-
out_aval = core.ShapedArray(
2304-
lax._reduce_op_shape_rule(aval, axes=pos_axes), aval.dtype,
2305-
sharding=lax._reduce_op_sharding_rule(aval, axes=pos_axes),
2306-
vma=frozenset(a for a in aval.vma if a not in named_axes))
2305+
out_aval = aval.update(vma=frozenset(a for a in aval.vma if a not in named_axes))
23072306
return out_aval, {core.NamedAxisEffect(axis) for axis in named_axes}
23082307
psum_invariant_p.def_effectful_abstract_eval(
23092308
partial(_psum_invariant_abstract_eval, psum_invariant_p.name))

0 commit comments

Comments
 (0)