-
Notifications
You must be signed in to change notification settings - Fork 191
Add async collectives RFC. #2897
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
mwhittaker
wants to merge
3
commits into
openxla:main
Choose a base branch
from
mwhittaker:async_collectives_rfc
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+194
−0
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,194 @@ | ||
| # [RFC] Async Ops | ||
|
|
||
| Status: In Review<br/> | ||
| Initial version: 02/09/2026<br/> | ||
| Last updated: 02/23/2026<br/> | ||
| Discussion thread: [here][discussion_thread] | ||
|
|
||
| ## Motivation | ||
|
|
||
| StableHLO programs can do two things: perform local computation (e.g., matrix | ||
| multiplication) and exchange data via collectives (e.g., an all-reduce). To get | ||
| high performance, it is crucial that these two things are overlapped. Local | ||
| computation should be executed while collectives are running in the background, | ||
| whenever possible. | ||
|
|
||
| Today, StableHLO doesn't implement any kind of communication-compute overlap, | ||
| though [XLA does][async_hlo]. The six StableHLO collective | ||
| operations---`all_gather`, `all_reduce`, `all_to_all`, `collective_broadcast`, | ||
| `collective_permute`, and `reduce_scatter`---are lowered to HLO equivalents. | ||
|
mwhittaker marked this conversation as resolved.
|
||
| Internally, [the XLA compiler splits these operations into asynchronous | ||
| start/done pairs][async_collective_creator]. For example, an `all-reduce` | ||
| operation becomes a pair of an `all-reduce-start` and `all-reduce-done`. Then, | ||
| the XLA scheduler---the component responsible for picking the order in which to | ||
| run ops---can schedule local computation between a start/done pair. | ||
|
|
||
| The XLA scheduler is not perfect. Sometimes, it picks bad schedules. That's why | ||
| we want to allow JAX programmers to manually specify (or at least influence) how | ||
| their programs are scheduled. This RFC proposes adding asynchronous collectives | ||
| to StableHLO, which is one step towards this goal. | ||
|
|
||
| By exposing async collectives in StableHLO (and also in JAX and other | ||
| higher-level frameworks), a programmer can write code like the following: | ||
|
|
||
| ```python | ||
| future = all_reduce_start(...) | ||
| perform_local_computation(...) | ||
| all_reduce_done(future) | ||
| ``` | ||
|
|
||
| ## Overview | ||
|
|
||
| This RFC introduces an `async_start` op and an `async_done` op that allow you to | ||
| run an operation asynchronously. We also introduce a new future type (e.g., | ||
| `future<tensor<2xf32>>`) to represent the output of a start operation. In the | ||
| future, we are likely to consider adding scheduling dependencies between async | ||
| ops and other ops to enforce an execution orderings, but in the meantime async | ||
| ops are used to denote that a backend should use an async decomposition for a | ||
| given op. | ||
|
|
||
| ## Proposed Type Changes | ||
|
|
||
| We introduce a new future type as follows. | ||
|
|
||
| ```ebnf | ||
| ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType | FutureType | ||
| FutureType ::= 'future' '<' FutureValueType '>' | ||
| FutureValueType ::= TensorType | QuantizedTensorType | ||
| ``` | ||
|
|
||
| ## Proposed Op Changes | ||
|
mwhittaker marked this conversation as resolved.
mwhittaker marked this conversation as resolved.
|
||
|
|
||
| We introduce an `async_start` op that takes a variadic number of tensors as | ||
| arguments. The op also has a single region that must contain only a call to one | ||
| of the six collective ops, or a call to one of the slice ops (`slice`, | ||
| `dynamic_slice`, `dynamic_update_slice`). `async_start` returns a future. | ||
| Here's an example: | ||
|
|
||
| ```text | ||
| "stablehlo.async_start"(%x) ({ | ||
| %y = "stablehlo.all_gather"(%x) { | ||
| all_gather_dim = 1 : i64, | ||
| replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> | ||
| } : (tensor<8x2xf32>) -> tensor<8x8xf32> | ||
| "stablehlo.return"(%y) : (tensor<8x8xf32>) -> () | ||
| }) : (tensor<8x2xf32>) -> !stablehlo.future<tensor<8x8xf32>> | ||
| ``` | ||
|
|
||
| It is an error if the region contains anything other than a single call to a | ||
| collective. | ||
|
|
||
| We also introduce an `async_done` op which takes a future and unwraps it. Here's | ||
| an example. | ||
|
|
||
| ```text | ||
| "stablehlo.async_done"(%f1) : (!stablehlo.future<tensor<4x4xf32>>) -> (tensor<4x4xf32>) | ||
| ``` | ||
|
|
||
| ## Alternatives | ||
|
mwhittaker marked this conversation as resolved.
|
||
|
|
||
| ### Fully Generic Async Ops | ||
|
|
||
| [This][generic_async_rfc] is a StableHLO RFC that proposes adding generic | ||
| `async_start` and `async_done` ops that can be used to call *any* function | ||
| asynchronously, not just collectives. Here's an example from the RFC that | ||
| performs an asynchronous add: | ||
|
|
||
| ```text | ||
| // %init_i: 2 | ||
| // %init_sum: 3 | ||
| %future = "stablehlo.async_start"( | ||
| %init_i as %arg0: tensor<i64>, | ||
| %init_sum as %arg1: tensor<i64>) | ||
| { | ||
| %new_sum = stablehlo.add %arg1, %arg0 : tensor<i64> | ||
| stablehlo.return %new_sum : tensor<i64> | ||
| } : (tensor<i64>, tensor<i64>) -> async<tensor<i64>> | ||
|
|
||
| %result = "stablehlo.async_done"(%future): async<tensor<i64>> -> tensor<i64> | ||
| // %result: 5 | ||
| ``` | ||
|
|
||
| This RFC proposes something simpler yet less powerful. In the future, we | ||
| could migrate to fully generic async ops. | ||
|
|
||
| ### Explicit Start/Done Pairs | ||
|
mwhittaker marked this conversation as resolved.
|
||
|
|
||
| Rather than introducing `async_start` and `async_done`, we could introduce six | ||
| new **start ops**: | ||
|
|
||
| - `all_gather_start` | ||
| - `all_reduce_start` | ||
| - `all_to_all_start` | ||
| - `collective_broadcast_start` | ||
| - `collective_permute_start` | ||
| - `reduce_scatter_start` | ||
|
|
||
| These ops are identical to their non-asynchronous counterparts. They take the | ||
| same arguments and have the same constraints. The only difference is that they | ||
| return futures. Here's an example: | ||
|
|
||
| ```text | ||
| %future = "stablehlo.collective_permute_start"(%operand) { | ||
| source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>, | ||
| channel_handle = #stablehlo.channel_handle<handle = 0, type = 0> | ||
| } : (tensor<2x2xi64>) -> future<tensor<2x2xi64>> | ||
| ``` | ||
|
|
||
| We could also introduce six **done ops**. | ||
|
|
||
| - `all_gather_done` | ||
| - `all_reduce_done` | ||
| - `all_to_all _done` | ||
| - `collective_broadcast_done` | ||
| - `collective_permute _done` | ||
| - `reduce_scatter_done` | ||
|
|
||
| A done op takes a `future<T>` as an argument and returns a `T`. Continuing the | ||
| example above: | ||
|
|
||
| ```text | ||
| %result = "stablehlo.collective_permute_done"(%future) : (future<tensor<2x2xi64>>) -> tensor<2x2xi64> | ||
| ``` | ||
|
|
||
| ### Tensors Instead of Futures | ||
|
|
||
| Start ops could return regular tensors instead of futures. The value of these | ||
| tensors, however, would be indeterminate. The tensors should not be used in any | ||
| way besides as arguments to done ops. Here's an example: | ||
|
|
||
| ```text | ||
| %indeterminate = "stablehlo.collective_permute_start"(%operand) { | ||
| source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>, | ||
| channel_handle = #stablehlo.channel_handle<handle = 0, type = 0> | ||
| } : (tensor<2x2xi64>) -> tensor<2x2xi64> | ||
| %result = "stablehlo.collective_permute_done"(%indeterminate) : (tensor<2x2xi64>) -> tensor<2x2xi64> | ||
| ``` | ||
|
|
||
| This approach mirrors how HLO represents asynchronous ops. It also avoids | ||
| introducing a new future type. However, it is less type-safe. | ||
|
|
||
| ### Collective in Types | ||
|
|
||
| This RFC has every collective return the same future type. Thus, the following | ||
| code is well-typed but erroneous. | ||
|
|
||
| ```text | ||
| %future = "stablehlo.collective_permute_start"(%operand) { | ||
| source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>, | ||
| channel_handle = #stablehlo.channel_handle<handle = 0, type = 0> | ||
| } : (tensor<2x2xi64>) -> future<tensor<2x2xi64>> | ||
| %result = "stablehlo.all_reduce_done"(%future) : (future<tensor<2x2xi64>>) -> tensor<2x2xi64> | ||
| ``` | ||
|
|
||
| We could instead introduce a separate future type for every collective. For | ||
| example, `collective_permute_start` could return a | ||
| `collective_permute_future<...>`, and `collective_permute_done` could take a | ||
| `collective_permute_future<...>` as an argument. | ||
|
|
||
| This would introduce more type safety. | ||
|
|
||
| [async_collective_creator]: https://github.com/openxla/xla/blob/391c1c5fdadde89ee81886495d32dc32f9238af1/xla/hlo/transforms/collectives/async_collective_creator.h#L38 | ||
| [async_hlo]: https://openxla.org/xla/async_ops | ||
| [discussion_thread]: https://github.com/openxla/stablehlo/pull/2897/changes | ||
| [generic_async_rfc]: https://github.com/openxla/stablehlo/pull/2551 | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.