Status: In Review
Initial version: 02/09/2026
Last updated: 02/23/2026
Discussion thread: here
StableHLO programs can do two things: perform local computation (e.g., matrix multiplication) and exchange data via collectives (e.g., an all-reduce). To get high performance, it is crucial that these two things are overlapped. Local computation should be executed while collectives are running in the background, whenever possible.
Today, StableHLO doesn't implement any kind of communication-compute overlap,
though XLA does. The six StableHLO collective
operations---all_gather, all_reduce, all_to_all, collective_broadcast,
collective_permute, and reduce_scatter---are lowered to HLO equivalents.
Internally, the XLA compiler splits these operations into asynchronous
start/done pairs. For example, an all-reduce
operation becomes an all-reduce-start/all-reduce-done pair. Then,
the XLA scheduler---the component responsible for picking the order in which to
run ops---can schedule local computation between a start/done pair.
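To illustrate why placing local computation between a start/done pair helps, here is a toy timing model in Python. It is only a sketch under simplifying assumptions that are not taken from XLA: one compute stream, one communication stream, and made-up durations. The function name `makespan` and the schedule encoding are hypothetical.

```python
def makespan(schedule, durations):
    """Return the finish time of a schedule in a toy two-stream model.

    schedule is a list of (kind, name) pairs, where kind is "start",
    "done", or "compute". This is an illustration, not XLA's cost model.
    """
    clock = 0.0      # current time on the compute stream
    comm_free = 0.0  # time at which the communication stream is next idle
    finish = {}      # collective name -> time its transfer completes
    for kind, name in schedule:
        if kind == "start":
            # Issue the collective; it runs in the background.
            begin = max(clock, comm_free)
            finish[name] = begin + durations[name]
            comm_free = finish[name]
        elif kind == "done":
            # Block the compute stream until the collective has finished.
            clock = max(clock, finish[name])
        else:
            # Local computation occupies the compute stream.
            clock += durations[name]
    return clock


durations = {"all_reduce": 3.0, "matmul": 2.0}  # made-up numbers

no_overlap = [("start", "all_reduce"), ("done", "all_reduce"), ("compute", "matmul")]
overlapped = [("start", "all_reduce"), ("compute", "matmul"), ("done", "all_reduce")]

print(makespan(no_overlap, durations))  # 5.0
print(makespan(overlapped, durations))  # 3.0
```

In this model the overlapped schedule hides the matmul entirely behind the all-reduce, which is the effect a good scheduler is trying to achieve.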
The XLA scheduler is not perfect. Sometimes, it picks bad schedules. That's why we want to allow JAX programmers to manually specify (or at least influence) how their programs are scheduled. This RFC proposes adding asynchronous collectives to StableHLO, which is one step towards this goal.
By exposing async collectives in StableHLO (and also in JAX and other higher-level frameworks), a programmer can write code like the following:
future = all_reduce_start(...)
perform_local_computation(...)
all_reduce_done(future)
This RFC introduces an async_start op and an async_done op that allow you to
run an operation asynchronously. We also introduce a new future type (e.g.,
future<tensor<2xf32>>) to represent the output of a start operation. Later,
we are likely to consider adding scheduling dependencies between async
ops and other ops to enforce an execution ordering, but in the meantime async
ops are used to denote that a backend should use an async decomposition for a
given op.
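The start/compute/done pattern described above can be mimicked in plain Python to make the data flow concrete. This is a minimal sketch, not a JAX or StableHLO API: the three function names are borrowed from the pseudocode, and the "collective" is simulated with a background thread that sums per-device shards after an artificial delay.

```python
import time
from concurrent.futures import ThreadPoolExecutor

# A single background worker stands in for the communication stream.
_comm_stream = ThreadPoolExecutor(max_workers=1)


def all_reduce_start(shards):
    """Hypothetical stand-in: begin an element-wise sum over shards, return a future."""
    def run():
        time.sleep(0.05)  # pretend this is network latency
        return [sum(column) for column in zip(*shards)]
    return _comm_stream.submit(run)


def all_reduce_done(future):
    """Block until the background reduction finishes and unwrap its value."""
    return future.result()


def perform_local_computation(values):
    """Local work that does not depend on the reduction."""
    return [v * v for v in values]


shards = [[1.0, 2.0], [3.0, 4.0]]                    # one list per simulated device
future = all_reduce_start(shards)                    # kicks off the "collective"
local = perform_local_computation([1.0, 2.0, 3.0])   # runs while it is in flight
reduced = all_reduce_done(future)                    # waits and unwraps

print(local)    # [1.0, 4.0, 9.0]
print(reduced)  # [4.0, 6.0]
```

The future plays the same role as the proposed future type: it can be created early and carried around, but its value is only available after the done call.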
We introduce a new future type as follows.
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType | FutureType
FutureType ::= 'future' '<' FutureValueType '>'
FutureValueType ::= TensorType | QuantizedTensorType
We introduce an async_start op that takes a variadic number of tensors as
arguments. The op also has a single region that must contain only a call to one
of the six collective ops, or a call to one of the slice ops (slice,
dynamic_slice, dynamic_update_slice). async_start returns a future.
Here's an example:
%future = "stablehlo.async_start"(%x) ({
  %y = "stablehlo.all_gather"(%x) {
    all_gather_dim = 1 : i64,
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  } : (tensor<8x2xf32>) -> tensor<8x8xf32>
  "stablehlo.return"(%y) : (tensor<8x8xf32>) -> ()
}) : (tensor<8x2xf32>) -> !stablehlo.future<tensor<8x8xf32>>
It is an error if the region contains anything other than a single call to one of the permitted ops (the six collectives or the three slice ops).
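The region constraint can be expressed as a small check. The sketch below is written in Python rather than as an MLIR verifier, and it follows the list of permitted ops given earlier (the six collectives plus the three slice ops). The function name `verify_async_start_region` and the representation of a region as a list of op names are assumptions made for illustration, not part of StableHLO.

```python
COLLECTIVE_OPS = {
    "all_gather", "all_reduce", "all_to_all",
    "collective_broadcast", "collective_permute", "reduce_scatter",
}
SLICE_OPS = {"slice", "dynamic_slice", "dynamic_update_slice"}
ALLOWED_OPS = COLLECTIVE_OPS | SLICE_OPS


def verify_async_start_region(op_names):
    """Raise ValueError if the region body is not a single permitted op.

    op_names lists the ops in the region in order, without the
    "stablehlo." prefix, and includes the terminating "return".
    """
    if len(op_names) != 2 or op_names[-1] != "return":
        raise ValueError("region must hold exactly one op followed by return")
    if op_names[0] not in ALLOWED_OPS:
        raise ValueError(f"{op_names[0]} may not be wrapped by async_start")


# The all_gather example above passes the check.
verify_async_start_region(["all_gather", "return"])
```

A region such as `["add", "return"]` or `["all_gather", "all_reduce", "return"]` would be rejected by this check.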
We also introduce an async_done op, which takes a future and unwraps it. Here's
an example:
%result = "stablehlo.async_done"(%f1) : (!stablehlo.future<tensor<4x4xf32>>) -> (tensor<4x4xf32>)
A separate StableHLO RFC proposes adding generic
async_start and async_done ops that can be used to call any function
asynchronously, not just collectives. Here's an example from that RFC that
performs an asynchronous add:
// %init_i: 2
// %init_sum: 3
%future = "stablehlo.async_start"(
    %init_i as %arg0: tensor<i64>,
    %init_sum as %arg1: tensor<i64>)
{
  %new_sum = stablehlo.add %arg1, %arg0 : tensor<i64>
  stablehlo.return %new_sum : tensor<i64>
} : (tensor<i64>, tensor<i64>) -> async<tensor<i64>>
%result = "stablehlo.async_done"(%future): async<tensor<i64>> -> tensor<i64>
// %result: 5
This RFC proposes something simpler yet less powerful. In the future, we could migrate to fully generic async ops.
Rather than introducing async_start and async_done, we could introduce six
new start ops:
all_gather_start, all_reduce_start, all_to_all_start, collective_broadcast_start, collective_permute_start, and reduce_scatter_start.
These ops are identical to their non-asynchronous counterparts. They take the same arguments and have the same constraints. The only difference is that they return futures. Here's an example:
%future = "stablehlo.collective_permute_start"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> future<tensor<2x2xi64>>
We could also introduce six done ops.
all_gather_done, all_reduce_done, all_to_all_done, collective_broadcast_done, collective_permute_done, and reduce_scatter_done.
A done op takes a future<T> as an argument and returns a T. Continuing the
example above:
%result = "stablehlo.collective_permute_done"(%future) : (future<tensor<2x2xi64>>) -> tensor<2x2xi64>
Start ops could return regular tensors instead of futures. The value of these tensors, however, would be indeterminate. The tensors should not be used in any way besides as arguments to done ops. Here's an example:
%indeterminate = "stablehlo.collective_permute_start"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
%result = "stablehlo.collective_permute_done"(%indeterminate) : (tensor<2x2xi64>) -> tensor<2x2xi64>
This approach mirrors how HLO represents asynchronous ops. It also avoids introducing a new future type. However, it is less type-safe.
This RFC has every collective return the same future type. Thus, the following code is well-typed but erroneous.
%future = "stablehlo.collective_permute_start"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> future<tensor<2x2xi64>>
%result = "stablehlo.all_reduce_done"(%future) : (future<tensor<2x2xi64>>) -> tensor<2x2xi64>
We could instead introduce a separate future type for every collective. For
example, collective_permute_start could return a
collective_permute_future<...>, and collective_permute_done could take a
collective_permute_future<...> as an argument.
This would introduce more type safety.
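The effect of per-collective future types can be modeled in a few lines of Python. This is a sketch under assumptions, not a proposed implementation: the `Future` class, its `producer` field, and the `start` and `done` helpers are hypothetical, and types are represented as plain strings.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Future:
    producer: str    # collective that created the future, e.g. "all_reduce"
    value_type: str  # wrapped type, e.g. "tensor<2x2xi64>"


def start(collective, value_type):
    """Model of <collective>_start: return a future tagged with its producer."""
    return Future(producer=collective, value_type=value_type)


def done(collective, future):
    """Model of <collective>_done: reject futures made by a different collective."""
    if future.producer != collective:
        raise TypeError(
            f"{collective}_done cannot consume a future from {future.producer}_start"
        )
    return future.value_type


f = start("collective_permute", "tensor<2x2xi64>")
result_type = done("collective_permute", f)
print(result_type)  # tensor<2x2xi64>
```

Calling `done("all_reduce", f)` on the same future raises a `TypeError`, which corresponds to the mismatched start/done pair being rejected at type-checking time instead of being accepted as well-typed.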