[RFC] Async Ops

Status: In Review
Initial version: 02/09/2026
Last updated: 02/23/2026
Discussion thread: here

Motivation

StableHLO programs can do two things: perform local computation (e.g., matrix multiplication) and exchange data via collectives (e.g., an all-reduce). High performance depends on overlapping the two: whenever possible, local computation should run while collectives execute in the background.

Today, StableHLO has no notion of communication-compute overlap, though XLA does. The six StableHLO collective operations (all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute, and reduce_scatter) are lowered to their HLO equivalents. Internally, the XLA compiler splits these operations into asynchronous start/done pairs; for example, an all-reduce becomes an all-reduce-start paired with an all-reduce-done. The XLA scheduler (the component that picks the order in which ops run) can then schedule local computation between a start and its matching done.

The XLA scheduler is not perfect and sometimes picks bad schedules. That's why we want to let JAX programmers manually specify (or at least influence) how their programs are scheduled. This RFC proposes adding asynchronous collectives to StableHLO as one step towards this goal.

Once async collectives are exposed in StableHLO (and in JAX and other higher-level frameworks), a programmer can write code like the following:

future = all_reduce_start(...)
perform_local_computation(...)
all_reduce_done(future)
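To make the start/done pattern concrete, here is a minimal Python sketch that mimics it with the standard library's concurrent.futures. The names all_reduce_start, all_reduce_done, fake_all_reduce, and perform_local_computation are hypothetical stand-ins, not JAX or StableHLO APIs, and a background thread only approximates how a backend would overlap a real collective with local compute.

```python
from concurrent.futures import Future, ThreadPoolExecutor

# Hypothetical stand-in for a device: one background worker runs "collectives".
_pool = ThreadPoolExecutor(max_workers=1)


def fake_all_reduce(values):
    # Hypothetical: pretend 4 replicas each contributed the same values,
    # so the all-reduce sum is 4 * value for every element.
    return [4 * v for v in values]


def all_reduce_start(values) -> Future:
    # Launch the collective in the background and return a handle immediately.
    return _pool.submit(fake_all_reduce, values)


def all_reduce_done(future: Future):
    # Block until the collective finishes, then unwrap its result.
    return future.result()


def perform_local_computation(x):
    return [v * v for v in x]


future = all_reduce_start([1.0, 2.0])
local = perform_local_computation([3.0, 4.0])  # runs while the collective is in flight
reduced = all_reduce_done(future)
print(local, reduced)  # prints [9.0, 16.0] [4.0, 8.0]
```

The important property is that the start call returns before the collective completes, so any work placed between start and done can overlap with it.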

Overview

This RFC introduces an async_start op and an async_done op that together let you run an operation asynchronously. We also introduce a new future type (e.g., future<tensor<2xf32>>) to represent the output of a start operation. Later, we may add scheduling dependencies between async ops and other ops to enforce particular execution orders; for now, async ops only signal that a backend should use an asynchronous decomposition for the wrapped op.

Proposed Type Changes

We introduce a new future type as follows.

ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType | FutureType
FutureType ::= 'future' '<' FutureValueType '>'
FutureValueType ::= TensorType | QuantizedTensorType
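One way to read this grammar is as a constructor constraint: a future may wrap a tensor or a quantized tensor, but not a token, tuple, buffer, or another future. The Python sketch below models that rule; the class names mirror the grammar, but the fields (shape, element_type, storage_type) are simplified assumptions and omit details such as quantization parameters.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class TensorType:
    shape: Tuple[int, ...]
    element_type: str  # simplified: e.g. "f32", "i64"

    def __str__(self) -> str:
        dims = "".join(f"{d}x" for d in self.shape)
        return f"tensor<{dims}{self.element_type}>"


@dataclass(frozen=True)
class QuantizedTensorType:
    shape: Tuple[int, ...]
    storage_type: str  # simplified: real quantized types also carry scales, zero points, etc.


@dataclass(frozen=True)
class TokenType:
    pass


@dataclass(frozen=True)
class FutureType:
    value_type: Union[TensorType, QuantizedTensorType]

    def __post_init__(self) -> None:
        # FutureValueType ::= TensorType | QuantizedTensorType
        if not isinstance(self.value_type, (TensorType, QuantizedTensorType)):
            raise TypeError(f"future cannot wrap {type(self.value_type).__name__}")

    def __str__(self) -> str:
        return f"future<{self.value_type}>"


print(FutureType(TensorType((2,), "f32")))  # future<tensor<2xf32>>
```

Under this reading, nested futures such as future<future<...>> are rejected, since FutureType is not itself a FutureValueType.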

Proposed Op Changes

We introduce an async_start op that takes a variadic number of tensors as operands. The op also has a single region, which must contain exactly one call to one of the six collective ops or to one of the slice ops (slice, dynamic_slice, dynamic_update_slice), followed by a return of its result. async_start returns a future. Here's an example:

"stablehlo.async_start"(%x) ({
  %y = "stablehlo.all_gather"(%x) {
    all_gather_dim = 1 : i64,
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  } : (tensor<8x2xf32>) -> tensor<8x8xf32>
  "stablehlo.return"(%y) : (tensor<8x8xf32>) -> ()
}) : (tensor<8x2xf32>) -> !stablehlo.future<tensor<8x8xf32>>

It is an error if the region contains anything other than a single call to one of these ops and the return of its result.
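The region rule can be checked mechanically. The following Python sketch is a hypothetical verifier, not StableHLO's actual implementation; it assumes the region is given as the list of its top-level op names and follows the op list given for async_start. Nested regions, such as the reduction body of all_reduce, are not inspected.

```python
from typing import Sequence

COLLECTIVE_OPS = {
    "stablehlo.all_gather",
    "stablehlo.all_reduce",
    "stablehlo.all_to_all",
    "stablehlo.collective_broadcast",
    "stablehlo.collective_permute",
    "stablehlo.reduce_scatter",
}
SLICE_OPS = {"stablehlo.slice", "stablehlo.dynamic_slice", "stablehlo.dynamic_update_slice"}
WRAPPABLE_OPS = COLLECTIVE_OPS | SLICE_OPS


def verify_async_start_region(op_names: Sequence[str]) -> None:
    """Raise ValueError unless the region is `<wrappable op>; stablehlo.return`."""
    if len(op_names) != 2 or op_names[1] != "stablehlo.return":
        raise ValueError("region must hold exactly one wrapped op followed by stablehlo.return")
    if op_names[0] not in WRAPPABLE_OPS:
        raise ValueError(f"{op_names[0]} cannot be wrapped by async_start")


verify_async_start_region(["stablehlo.all_gather", "stablehlo.return"])  # accepted
```

With this check, a region holding an add, or two slices in a row, would be rejected.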

We also introduce an async_done op, which takes a future and unwraps it. Here's an example:

%result = "stablehlo.async_done"(%f1) : (!stablehlo.future<tensor<4x4xf32>>) -> (tensor<4x4xf32>)
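Taken together, the two examples suggest a simple typing rule: async_start produces future<R>, where R is the type the region returns, and async_done maps future<T> back to T. The Python sketch below states that rule over MLIR-style type strings; the helper names are hypothetical and the string matching is a simplification of real type parsing.

```python
import re

# Matches "!stablehlo.future<...>" and captures the wrapped type.
_FUTURE = re.compile(r"^!stablehlo\.future<(.+)>$")


def async_start_result_type(region_return_type: str) -> str:
    # async_start wraps whatever its region returns in a future.
    return f"!stablehlo.future<{region_return_type}>"


def async_done_result_type(operand_type: str) -> str:
    # async_done unwraps future<T> into T; any other operand type is rejected.
    match = _FUTURE.match(operand_type)
    if match is None:
        raise TypeError(f"async_done expects a future, got {operand_type}")
    return match.group(1)


print(async_done_result_type("!stablehlo.future<tensor<4x4xf32>>"))  # tensor<4x4xf32>
```

Because the rule is a round trip, async_done applied to the result of async_start always yields the region's original return type.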

Alternatives

Fully Generic Async Ops

A separate StableHLO RFC proposes adding generic async_start and async_done ops that can call any function asynchronously, not just collectives. Here's an example from that RFC that performs an asynchronous add:

// %init_i: 2
// %init_sum: 3
%future = "stablehlo.async_start"(
    %init_i as %arg0: tensor<i64>,
    %init_sum as %arg1: tensor<i64>)
{
    %new_sum = stablehlo.add %arg1, %arg0 : tensor<i64>
    stablehlo.return %new_sum : tensor<i64>
} : (tensor<i64>, tensor<i64>) -> async<tensor<i64>>

%result = "stablehlo.async_done"(%future): async<tensor<i64>> -> tensor<i64>
// %result: 5

This RFC proposes something simpler yet less powerful. In the future, we could migrate to fully generic async ops.

Explicit Start/Done Pairs

Rather than introducing async_start and async_done, we could introduce six new start ops:

  • all_gather_start
  • all_reduce_start
  • all_to_all_start
  • collective_broadcast_start
  • collective_permute_start
  • reduce_scatter_start

These ops are identical to their non-asynchronous counterparts. They take the same arguments and have the same constraints. The only difference is that they return futures. Here's an example:

%future = "stablehlo.collective_permute_start"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> future<tensor<2x2xi64>>

We would also introduce six matching done ops:

  • all_gather_done
  • all_reduce_done
  • all_to_all_done
  • collective_broadcast_done
  • collective_permute_done
  • reduce_scatter_done

A done op takes a future<T> as an argument and returns a T. Continuing the example above:

%result = "stablehlo.collective_permute_done"(%future) : (future<tensor<2x2xi64>>) -> tensor<2x2xi64>

Tensors Instead of Futures

Start ops could return regular tensors instead of futures. The values of these tensors, however, would be indeterminate, so they must not be used except as operands to done ops. Here's an example:

%indeterminate = "stablehlo.collective_permute_start"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
%result = "stablehlo.collective_permute_done"(%indeterminate) : (tensor<2x2xi64>) -> tensor<2x2xi64>

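This approach mirrors how HLO represents asynchronous ops, and it avoids introducing a new future type. However, it is less type-safe: nothing in the type system stops a program from passing an indeterminate tensor to an ordinary op such as an add.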
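The type-safety difference can be illustrated with a toy checker. In this Python sketch, Tensor, Future, and check_add are hypothetical models of a type system, not StableHLO classes; they show that an indeterminate plain tensor slips through an add, while a distinct future type is rejected.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Tensor:
    shape: tuple


@dataclass(frozen=True)
class Future:
    value: Tensor


def check_add(lhs, rhs) -> Tensor:
    # Toy rule: an elementwise add accepts only tensors of matching shape.
    if not (isinstance(lhs, Tensor) and isinstance(rhs, Tensor)):
        raise TypeError("add expects tensor operands")
    if lhs.shape != rhs.shape:
        raise TypeError("shape mismatch")
    return lhs


operand = Tensor((2, 2))

# "Tensors instead of futures": the start op yields a plain tensor with
# indeterminate contents, so misusing it in an add still type-checks.
indeterminate = Tensor((2, 2))
check_add(indeterminate, operand)  # accepted, but would compute garbage

# With a future type, the same mistake is caught by the checker.
future = Future(Tensor((2, 2)))
try:
    check_add(future, operand)
except TypeError as err:
    print("rejected:", err)  # prints rejected: add expects tensor operands
```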

Collective in Types

Both this RFC and the explicit start/done alternative give every collective the same future type. With explicit start/done pairs, that means the following code is well-typed but erroneous:

%future = "stablehlo.collective_permute_start"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> future<tensor<2x2xi64>>
%result = "stablehlo.all_reduce_done"(%future) : (future<tensor<2x2xi64>>) -> tensor<2x2xi64>

We could instead introduce a separate future type for every collective. For example, collective_permute_start could return a collective_permute_future<...>, and collective_permute_done could take a collective_permute_future<...> as an argument.

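This would make mismatched pairs like the one above a type error, at the cost of one future type per collective instead of a single shared one.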
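As a rough sketch of the per-collective alternative, the Python below tags each future with the collective that produced it and makes the done step check the tag. A runtime tag only stands in for distinct static types such as collective_permute_future<...>, and CollectiveFuture, start, and done are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CollectiveFuture:
    collective: str  # which start op produced it, e.g. "collective_permute"
    value_type: str  # e.g. "tensor<2x2xi64>"


def start(collective: str, value_type: str) -> CollectiveFuture:
    # Models e.g. collective_permute_start returning collective_permute_future<T>.
    return CollectiveFuture(collective, value_type)


def done(collective: str, future: CollectiveFuture) -> str:
    # A done op only accepts futures produced by its matching start op.
    if future.collective != collective:
        raise TypeError(f"{collective}_done cannot consume a {future.collective}_future")
    return future.value_type


f = start("collective_permute", "tensor<2x2xi64>")
print(done("collective_permute", f))  # tensor<2x2xi64>
```

Calling done("all_reduce", f) on the same future raises a TypeError, which is the mismatch the shared future type cannot catch.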