Make flash_attention Dynamo/AOTAutograd traceable #8633
Description
One-line pitch: make torch_xla.experimental.custom_kernel.flash_attention traceable by AOTAutograd/Dynamo, with both the forward and backward passes supported.
Motivation
The torch_xla.experimental.custom_kernel.flash_attention function calls into JAX to build a Pallas kernel. If we naively trace this function with AOTAutograd, we end up tracing JAX operations with PyTorch fake tensors, which is unsupported. Instead, we need to teach AOTAutograd to place a custom torch op that encapsulates the forward pass of the flash attention kernel while it builds the forward graph, and a corresponding custom op in the backward graph.
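As a minimal sketch of the mechanism (the op name and signature below are illustrative, not the actual PyTorch/XLA registration): wrapping the kernel in an opaque custom op with a fake/meta implementation means tracing under fake tensors only propagates shapes and dtypes and never reaches JAX.

```python
import torch
from torch.library import custom_op

# Illustrative op; the real registration lives in torch_xla and takes more arguments.
@custom_op("xla_sketch::flash_forward", mutates_args=())
def flash_forward(q: torch.Tensor, k: torch.Tensor,
                  v: torch.Tensor) -> torch.Tensor:
    # The real op would dispatch to the JAX-built Pallas kernel here;
    # a naive attention serves as a stand-in.
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

@flash_forward.register_fake
def _(q, k, v):
    # This is what runs under fake tensors during tracing:
    # shape/dtype propagation only, no JAX involved.
    return torch.empty_like(q)
```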
PyTorch/XLA has some initial support for AOTAutograd-friendly flash attention: if you trace the flash_attention_xla
custom op, the AOTAutograd graph will contain a single torch.ops.xla.flash_attention
node. However, it has some limitations:
- No support for the autograd/backward pass: the custom op returns tensors detached from the autograd graph.
- No way to pass in partition_spec and mesh, which are required to use flash attention under SPMD.
I've prototyped [1] an AOTAutograd-friendly FlashAttention as a torch.autograd.Function where (see the sketch after this list):
- The forward of the Function calls into a custom op (fa_custom_forward) that returns both the attention output and the residuals.
- The backward of the Function calls into a custom op (fa_custom_backward) that takes in the gradient w.r.t. the attention output along with the residuals, and returns gradients w.r.t. the q/k/v tensors.
- The partition_spec and mesh are hardcoded to what's used in Llama 3 2D sharding.
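Structurally, the prototype looks roughly like the sketch below. The signatures are simplified: the real custom ops take additional arguments (causal flag, segment ids, softmax scale, etc.), and in the prototype the partition_spec and mesh are hardcoded inside the ops rather than passed through.

```python
import torch

class FlashAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        # fa_custom_forward returns the attention output plus the residuals
        # needed to compute gradients later.
        out, *residuals = torch.ops.xla.fa_custom_forward(q, k, v)
        ctx.save_for_backward(q, k, v, *residuals)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, *residuals = ctx.saved_tensors
        # fa_custom_backward consumes the gradient w.r.t. the attention output
        # and the residuals, and returns gradients w.r.t. q, k and v.
        grad_q, grad_k, grad_v = torch.ops.xla.fa_custom_backward(
            grad_out, q, k, v, *residuals)
        return grad_q, grad_k, grad_v

def flash_attention(q, k, v):
    return FlashAttention.apply(q, k, v)
```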
When AOTAutograd traces this Function, instead of tracing into JAX it places a torch.ops.xla.fa_custom_forward op in the forward graph and a torch.ops.xla.fa_custom_backward op in the backward graph.
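A hedged way to confirm that placement, assuming the flash_attention wrapper from the sketch above, is to trace it through functorch's AOTAutograd entry point and print the resulting graphs:

```python
import torch
import torch_xla.core.xla_model as xm
from functorch.compile import aot_function

def print_graph(name):
    def compiler(gm, example_inputs):
        print(f"--- {name} graph ---")
        # The forward graph should contain a torch.ops.xla.fa_custom_forward
        # node and the backward graph a torch.ops.xla.fa_custom_backward node.
        gm.graph.print_tabular()
        return gm
    return compiler

dev = xm.xla_device()
q, k, v = (torch.randn(1, 8, 128, 64, device=dev, requires_grad=True)
           for _ in range(3))

compiled = aot_function(flash_attention,
                        fw_compiler=print_graph("forward"),
                        bw_compiler=print_graph("backward"))
compiled(q, k, v).sum().backward()
```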