What happened?
Reported to JAX at jax-ml/jax#36011 (click through for a JAX reproduction).
TL;DR: `reduce` requires a monoid, which guarantees associativity but not necessarily commutativity. For a non-commutative reduction over multiple dimensions, the order in which those dimensions are reduced can affect the output, but the StableHLO spec does not specify which order should be used.
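To make the order dependence concrete, here is a minimal pure-Python sketch. It is not StableHLO or JAX code: string concatenation (identity `""`) stands in as an associative but non-commutative monoid, and the helper names `reduce_dim0_then_dim1` and `reduce_dim1_then_dim0` are hypothetical. Both helpers are valid folds over dimensions `[0, 1]` of the same 2x2 input, yet they disagree.

```python
from functools import reduce

# Associative but non-commutative monoid: string concatenation, identity "".
def concat(a: str, b: str) -> str:
    return a + b

IDENTITY = ""

# A 2x2 "tensor" to be reduced over both dimensions 0 and 1.
x = [["a", "b"],
     ["c", "d"]]

def reduce_dim0_then_dim1(t):
    # Fold down each column (dimension 0) first, then fold the column results.
    cols = [reduce(concat, (row[j] for row in t), IDENTITY) for j in range(len(t[0]))]
    return reduce(concat, cols, IDENTITY)

def reduce_dim1_then_dim0(t):
    # Fold across each row (dimension 1) first, then fold the row results.
    rows = [reduce(concat, row, IDENTITY) for row in t]
    return reduce(concat, rows, IDENTITY)

print(reduce_dim0_then_dim1(x))  # acbd
print(reduce_dim1_then_dim0(x))  # abcd
```

With a commutative monoid such as integer addition, both orders give the same result; the discrepancy only appears once commutativity is dropped.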
As I see it, there are three options here for StableHLO:
- support non-commutative reductions & specify a particular reduction order in the multi-axis case (possibly with a user-defined order).
- support non-commutative reductions but explicitly state that reduction order is platform-dependent.
- explicitly disallow non-commutative reductions in the spec.