
reduce is under-specified for non-commutative operations over multiple axes #2919

@jakevdp

Description

What happened?

Reported to JAX at jax-ml/jax#36011 (click through for a JAX reproduction).

TL;DR: reduce requires a monoid, whose operation must be associative but need not be commutative. For non-commutative reductions over multiple dimensions, the order in which the dimensions are reduced can affect the output, but the StableHLO spec does not specify which order should be used.
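To make the ambiguity concrete, here is a plain-Python model (not StableHLO or JAX code) of a 2-D reduce over both dimensions, carried out in each of the two possible dimension orders. String concatenation serves as the body: it is associative with identity "" but not commutative. The helper names are illustrative only.

```python
import operator
from functools import reduce


def reduce_dim1_then_dim0(grid, op, init):
    # Reduce within each row (dimension 1) first, then combine the row results (dimension 0).
    row_results = [reduce(op, row, init) for row in grid]
    return reduce(op, row_results, init)


def reduce_dim0_then_dim1(grid, op, init):
    # Reduce within each column (dimension 0) first, then combine the column results (dimension 1).
    col_results = [reduce(op, col, init) for col in zip(*grid)]
    return reduce(op, col_results, init)


# On str, operator.add is concatenation: a monoid with identity "", but non-commutative.
letters = [["a", "b"], ["c", "d"]]
print(reduce_dim1_then_dim0(letters, operator.add, ""))  # abcd
print(reduce_dim0_then_dim1(letters, operator.add, ""))  # acbd

# On int, operator.add is commutative, so both orders agree.
numbers = [[1, 2], [3, 4]]
print(reduce_dim1_then_dim0(numbers, operator.add, 0))  # 10
print(reduce_dim0_then_dim1(numbers, operator.add, 0))  # 10
```

Associativity alone lets each order regroup freely, but only commutativity would let "b" and "c" swap places, which is exactly what the two orders disagree on.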

As I see it, there are three options here for StableHLO:

  1. support non-commutative reductions and specify a particular reduction order in the multi-axis case (possibly a user-defined order).
  2. support non-commutative reductions but explicitly state that reduction order is platform-dependent.
  3. explicitly disallow non-commutative reductions in the spec.
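As a rough illustration of option 1, the following is a minimal sketch of reference semantics for a multi-axis reduce with a fixed, caller-chosen order. Both the function `reduce_in_order` and its convention (elements are folded in lexicographic order of their indices along `dims`, with the last listed dimension varying fastest) are hypothetical; they are not part of StableHLO, XLA, or JAX.

```python
from functools import reduce

import numpy as np


def reduce_in_order(x, op, init, dims):
    """Hypothetical multi-axis reduce with a fully specified order.

    For each output element, the reduced elements are folded left to right in
    lexicographic order of their indices along `dims`; the last entry of
    `dims` varies fastest.
    """
    x = np.asarray(x, dtype=object)
    kept = [d for d in range(x.ndim) if d not in dims]
    # Move kept dimensions to the front, then reduced dimensions in the requested order.
    moved = np.transpose(x, kept + list(dims))
    out_shape = moved.shape[: len(kept)]
    n = int(np.prod(moved.shape[len(kept):]))
    # Flatten the reduced dimensions (row-major) so each output element gets one 1-D run.
    flat = moved.reshape(out_shape + (n,))
    result = np.empty(out_shape, dtype=object)
    for idx in np.ndindex(*out_shape):
        result[idx] = reduce(op, flat[idx], init)
    return result


letters = [["a", "b"], ["c", "d"]]
concat = lambda a, b: a + b  # associative, identity "", not commutative
print(reduce_in_order(letters, concat, "", [0, 1])[()])  # abcd
print(reduce_in_order(letters, concat, "", [1, 0])[()])  # acbd
```

Because the fold order is pinned down by `dims`, any associative body gives a single well-defined answer; the trade-off is that an implementation could no longer freely reorder the reduction across dimensions for performance.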
