Conversation
felixwqp
left a comment
There was a problem hiding this comment.
Can I assume implicitly this will only work for shard_map-based sharding? like how ragged_all_to_all is used?
I'm not sure. I haven't thought that far ahead. When I support these async collectives in JAX, though, I do plan on only supporting shard_map at first. |
89da850 to
3926874
Compare
|
Would this RFC extend to async dynamic-slice/dynamic-update-slice? |
This should naturally extend to these ops. I'm OK to bring them in scope under the umbrella of "known ops that we want to have an async decomposition by a backend" |
See openxla/stablehlo#2897 for context. This CL introduces a `stablehlo.future` type and `stablehlo.async_start` and `stablehlo.async_done`. It does not add any translation for them. That will come in a later change. I also did not make `async_start` and `async_done` variadic for now, even though some collectives are variadic. We can add support for that later if we need it. PiperOrigin-RevId: 874740969
See openxla/stablehlo#2897 for context. This CL introduces a `stablehlo.future` type and `stablehlo.async_start` and `stablehlo.async_done`. It does not add any translation for them. That will come in a later change. I also did not make `async_start` and `async_done` variadic for now, even though some collectives are variadic. We can add support for that later if we need it. PiperOrigin-RevId: 874740969
|
I updated the RFC to include slice ops. I also made things less variadic for now. We can change that later if needed. |
See openxla/stablehlo#2897 for context. This CL introduces a `stablehlo.future` type and `stablehlo.async_start` and `stablehlo.async_done`. It does not add any translation for them. That will come in a later change. I also did not make `async_start` and `async_done` variadic for now, even though some collectives are variadic. We can add support for that later if we need it. PiperOrigin-RevId: 874740969
See openxla/stablehlo#2897 for context. This CL introduces a `stablehlo.future` type and `stablehlo.async_start` and `stablehlo.async_done`. It does not add any translation for them. That will come in a later change. I also did not make `async_start` and `async_done` variadic for now, even though some collectives are variadic. We can add support for that later if we need it. PiperOrigin-RevId: 874740969
See openxla/stablehlo#2897 for context. This CL introduces a `stablehlo.future` type and `stablehlo.async_start` and `stablehlo.async_done`. It does not add any translation for them. That will come in a later change. I also did not make `async_start` and `async_done` variadic for now, even though some collectives are variadic. We can add support for that later if we need it. PiperOrigin-RevId: 874740969
See openxla/stablehlo#2897 for context. This CL introduces a `stablehlo.future` type and `stablehlo.async_start` and `stablehlo.async_done`. It does not add any translation for them. That will come in a later change. I also did not make `async_start` and `async_done` variadic for now, even though some collectives are variadic. We can add support for that later if we need it. PiperOrigin-RevId: 874740969
See openxla/stablehlo#2897 for context. This CL introduces a `stablehlo.future` type and `stablehlo.async_start` and `stablehlo.async_done`. It does not add any translation for them. That will come in a later change. I also did not make `async_start` and `async_done` variadic for now, even though some collectives are variadic. We can add support for that later if we need it. PiperOrigin-RevId: 874740969
See openxla/stablehlo#2897 for context. This CL introduces a `stablehlo.future` type and `stablehlo.async_start` and `stablehlo.async_done`. It does not add any translation for them. That will come in a later change. I also did not make `async_start` and `async_done` variadic for now, even though some collectives are variadic. We can add support for that later if we need it. PiperOrigin-RevId: 874740969
See openxla/stablehlo#2897 for context. This CL introduces a `stablehlo.future` type and `stablehlo.async_start` and `stablehlo.async_done`. It does not add any translation for them. That will come in a later change. I also did not make `async_start` and `async_done` variadic for now, even though some collectives are variadic. We can add support for that later if we need it. PiperOrigin-RevId: 874740969
See openxla/stablehlo#2897 for context. This CL introduces a `stablehlo.future` type and `stablehlo.async_start` and `stablehlo.async_done`. It does not add any translation for them. That will come in a later change. I also did not make `async_start` and `async_done` variadic for now, even though some collectives are variadic. We can add support for that later if we need it. PiperOrigin-RevId: 874740969
See openxla/stablehlo#2897 for context. This CL introduces a `stablehlo.future` type and `stablehlo.async_start` and `stablehlo.async_done`. It does not add any translation for them. That will come in a later change. I also did not make `async_start` and `async_done` variadic for now, even though some collectives are variadic. We can add support for that later if we need it. PiperOrigin-RevId: 890017379
See openxla/stablehlo#2897 for context. This CL introduces a `stablehlo.future` type and `stablehlo.async_start` and `stablehlo.async_done`. It does not add any translation for them. That will come in a later change. I also did not make `async_start` and `async_done` variadic for now, even though some collectives are variadic. We can add support for that later if we need it. PiperOrigin-RevId: 890017379
See openxla/stablehlo#2897 for context. This CL introduces a `stablehlo.future` type and `stablehlo.async_start` and `stablehlo.async_done`. It does not add any translation for them. That will come in a later change. I also did not make `async_start` and `async_done` variadic for now, even though some collectives are variadic. We can add support for that later if we need it. PiperOrigin-RevId: 890017379
No description provided.