@@ -2268,6 +2268,7 @@ def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name):
22682268mlir .register_lowering (axis_index_p , _axis_index_lowering )
22692269axis_index_p .def_effectful_abstract_eval (_axis_index_effectful_abstract_eval )
22702270batching .fancy_primitive_batchers [axis_index_p ] = _axis_index_batcher
2271+
22712272######################## psum_invariant_p ####################################
22722273
22732274def bind_psum_invariant (leaf , * , axes , axis_index_groups ):
@@ -2295,15 +2296,13 @@ def _psum_invariant_abstract_eval(name, aval, *, axes):
22952296 " names mentioned in `axes` passed to `psum` must be present in"
22962297 f" `jax.typeof(inp).vma`. Got axes={ axes } and"
22972298 f" jax.typeof(inp).vma={ aval .vma } " )
2299+ if any (isinstance (a , int ) for a in axes ):
2300+ raise ValueError (f'psum_invariant does not accept integer axes. Got { axes } ' )
22982301
22992302 named_axes = tuple (axis for axis in axes if not isinstance (axis , int ))
2300- pos_axes = tuple (axis for axis in axes if isinstance (axis , int ))
23012303 core .check_avals_context_mesh ([aval ], name )
23022304 check_unreduced_args ([aval ], name )
2303- out_aval = core .ShapedArray (
2304- lax ._reduce_op_shape_rule (aval , axes = pos_axes ), aval .dtype ,
2305- sharding = lax ._reduce_op_sharding_rule (aval , axes = pos_axes ),
2306- vma = frozenset (a for a in aval .vma if a not in named_axes ))
2305+ out_aval = aval .update (vma = frozenset (a for a in aval .vma if a not in named_axes ))
23072306 return out_aval , {core .NamedAxisEffect (axis ) for axis in named_axes }
23082307psum_invariant_p .def_effectful_abstract_eval (
23092308 partial (_psum_invariant_abstract_eval , psum_invariant_p .name ))
0 commit comments