# [RFC] Async Ops

Status: In Review<br/>
Initial version: 02/09/2026<br/>
Last updated: 02/23/2026<br/>
Discussion thread: [here][discussion_thread]

## Motivation

StableHLO programs can do two things: perform local computation (e.g., matrix
multiplication) and exchange data via collectives (e.g., an all-reduce). To get
high performance, it is crucial that these two things are overlapped. Local
computation should be executed while collectives are running in the background,
whenever possible.

Today, StableHLO doesn't implement any kind of communication-compute overlap,
though [XLA does][async_hlo]. The six StableHLO collective
operations---`all_gather`, `all_reduce`, `all_to_all`, `collective_broadcast`,
`collective_permute`, and `reduce_scatter`---are lowered to HLO equivalents.
Internally, [the XLA compiler splits these operations into asynchronous
start/done pairs][async_collective_creator]. For example, an `all-reduce`
operation becomes a pair of an `all-reduce-start` and `all-reduce-done`. Then,
the XLA scheduler---the component responsible for picking the order in which to
run ops---can schedule local computation between a start/done pair.
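
As a simplified sketch (actual HLO syntax elides some details), the pass
rewrites a blocking collective into a start/done pair that the scheduler can
interleave with independent work:

```text
// Before: a blocking all-reduce.
%sum = all-reduce(%x), to_apply=add

// After the async collective creator pass: the collective runs in the
// background between the start and done ops.
%start = all-reduce-start(%x), to_apply=add
// ... the scheduler can place independent local computation here ...
%sum = all-reduce-done(%start)
```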

The XLA scheduler is not perfect. Sometimes, it picks bad schedules. That's why
we want to allow JAX programmers to manually specify (or at least influence) how
their programs are scheduled. This RFC proposes adding asynchronous collectives
to StableHLO, which is one step towards this goal.

By exposing async collectives in StableHLO (and also in JAX and other
higher-level frameworks), a programmer can write code like the following:

```python
future = all_reduce_start(...)
perform_local_computation(...)
all_reduce_done(future)
```

## Overview

This RFC introduces an `async_start` op and an `async_done` op that allow an
operation to be run asynchronously. We also introduce a new future type (e.g.,
`future<tensor<2xf32>>`) to represent the in-flight result of a start
operation. In the future, we are likely to consider adding scheduling
dependencies between async ops and other ops to enforce execution orderings,
but in the meantime async ops are used to denote that a backend should use an
async decomposition for a given op.

## Proposed Type Changes

We introduce a new future type as follows.

```ebnf
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType | FutureType
FutureType ::= 'future' '<' FutureValueType '>'
FutureValueType ::= TensorType | QuantizedTensorType
```
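
For example, both of the following are valid future types, while futures of
futures (e.g., `future<future<tensor<2xf32>>>`) are not expressible under this
grammar:

```text
future<tensor<2xf32>>
future<tensor<8x8xi64>>
```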

## Proposed Op Changes

We introduce an `async_start` op that takes a variadic number of tensors as
arguments. The op also has a single region that must contain exactly one op
(plus a terminating `return`): one of the six collective ops or one of the
slice ops (`slice`, `dynamic_slice`, `dynamic_update_slice`). `async_start`
returns a future. Here's an example:

```text
"stablehlo.async_start"(%x) ({
%y = "stablehlo.all_gather"(%x) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
} : (tensor<8x2xf32>) -> tensor<8x8xf32>
"stablehlo.return"(%y) : (tensor<8x8xf32>) -> ()
}) : (tensor<8x2xf32>) -> !stablehlo.future<tensor<8x8xf32>>
```

It is an error if the region contains anything other than a single call to one
of the permitted ops (a collective or a slice).
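
For example, the following program would be invalid because the region
contains an `add` in addition to the collective (attributes elided for
brevity):

```text
// Invalid: the region must contain a single collective (or slice) op.
"stablehlo.async_start"(%x) ({
  %y = "stablehlo.all_gather"(%x) { ... } : (tensor<8x2xf32>) -> tensor<8x8xf32>
  %z = "stablehlo.add"(%y, %y) : (tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32>
  "stablehlo.return"(%z) : (tensor<8x8xf32>) -> ()
}) : (tensor<8x2xf32>) -> !stablehlo.future<tensor<8x8xf32>>
```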

We also introduce an `async_done` op that takes a future, waits for the
underlying operation to complete, and returns its result. Here's an example.

```text
"stablehlo.async_done"(%f1) : (!stablehlo.future<tensor<4x4xf32>>) -> (tensor<4x4xf32>)
```

## Alternatives

### Fully Generic Async Ops

[This][generic_async_rfc] is a StableHLO RFC that proposes adding generic
`async_start` and `async_done` ops that can be used to call *any* function
asynchronously, not just collectives. Here's an example from the RFC that
performs an asynchronous add:

```text
// %init_i: 2
// %init_sum: 3
%future = "stablehlo.async_start"(
%init_i as %arg0: tensor<i64>,
%init_sum as %arg1: tensor<i64>)
{
%new_sum = stablehlo.add %arg1, %arg0 : tensor<i64>
stablehlo.return %new_sum : tensor<i64>
} : (tensor<i64>, tensor<i64>) -> async<tensor<i64>>

%result = "stablehlo.async_done"(%future): async<tensor<i64>> -> tensor<i64>
// %result: 5
```

This RFC proposes something simpler yet less powerful. In the future, we
could migrate to fully generic async ops.

### Explicit Start/Done Pairs

Rather than introducing `async_start` and `async_done`, we could introduce six
new **start ops**:

- `all_gather_start`
- `all_reduce_start`
- `all_to_all_start`
- `collective_broadcast_start`
- `collective_permute_start`
- `reduce_scatter_start`

These ops are identical to their non-asynchronous counterparts. They take the
same arguments and have the same constraints. The only difference is that they
return futures. Here's an example:

```text
%future = "stablehlo.collective_permute_start"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> future<tensor<2x2xi64>>
```

We could also introduce six **done ops**.

- `all_gather_done`
- `all_reduce_done`
- `all_to_all_done`
- `collective_broadcast_done`
- `collective_permute_done`
- `reduce_scatter_done`

A done op takes a `future<T>` as an argument and returns a `T`. Continuing the
example above:

```text
%result = "stablehlo.collective_permute_done"(%future) : (future<tensor<2x2xi64>>) -> tensor<2x2xi64>
```

### Tensors Instead of Futures

Start ops could return regular tensors instead of futures. The value of these
tensors, however, would be indeterminate. The tensors should not be used in any
way besides as arguments to done ops. Here's an example:

```text
%indeterminate = "stablehlo.collective_permute_start"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
%result = "stablehlo.collective_permute_done"(%indeterminate) : (tensor<2x2xi64>) -> tensor<2x2xi64>
```

This approach mirrors how HLO represents asynchronous ops. It also avoids
introducing a new future type. However, it is less type-safe.

### Collective in Types

This RFC has every collective return the same future type. Thus, the following
code is well-typed but erroneous.

```text
%future = "stablehlo.collective_permute_start"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> future<tensor<2x2xi64>>
%result = "stablehlo.all_reduce_done"(%future) : (future<tensor<2x2xi64>>) -> tensor<2x2xi64>
```

We could instead introduce a separate future type for every collective. For
example, `collective_permute_start` could return a
`collective_permute_future<...>`, and `collective_permute_done` could take a
`collective_permute_future<...>` as an argument.
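
Continuing the earlier example with a hypothetical `collective_permute_future`
type, the well-typed version would look like this, and the mismatched
`all_reduce_done` call above would now fail to type-check:

```text
%future = "stablehlo.collective_permute_start"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> collective_permute_future<tensor<2x2xi64>>
%result = "stablehlo.collective_permute_done"(%future) : (collective_permute_future<tensor<2x2xi64>>) -> tensor<2x2xi64>
```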

This would introduce more type safety.

[async_collective_creator]: https://github.com/openxla/xla/blob/391c1c5fdadde89ee81886495d32dc32f9238af1/xla/hlo/transforms/collectives/async_collective_creator.h#L38
[async_hlo]: https://openxla.org/xla/async_ops
[discussion_thread]: https://github.com/openxla/stablehlo/pull/2897/changes
[generic_async_rfc]: https://github.com/openxla/stablehlo/pull/2551