Fix a bug in flash attention where kv_seq_len should divide block_k_major. (#8671)
zhangp365 authored Feb 10, 2025
1 parent c0afda3 commit cff9f4e
Showing 3 changed files with 70 additions and 2 deletions.
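
For context, the failure mode this commit addresses: the underlying Pallas kernel expects the key/value sequence length to be a multiple of block_k_major, so an input such as kv_seq_len = 513 (the shape used by the new tests) previously failed. The sketch below only illustrates the padding arithmetic the fix applies; the block size of 512 and the standalone padding code are assumptions for illustration, not part of the commit.

import torch

# Assumed default block size (FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"]).
BLOCK_K_MAJOR = 512

kv_seq_len = 513  # kv sequence length used by the new tests

# Before the fix: 513 is not a multiple of 512, so the kernel rejects the input.
assert kv_seq_len % BLOCK_K_MAJOR != 0

# After the fix: pad the kv dimension up to the next multiple of the block size.
pad_size = BLOCK_K_MAJOR - (kv_seq_len % BLOCK_K_MAJOR)  # 511
padded_len = kv_seq_len + pad_size  # 1024
assert padded_len % BLOCK_K_MAJOR == 0

k = torch.randn(1, 2, kv_seq_len, 4)
k_padded = torch.cat([k, torch.zeros(1, 2, pad_size, 4, dtype=k.dtype)], dim=2)
print(k_padded.shape)  # torch.Size([1, 2, 1024, 4])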
test/test_pallas.py: 15 additions & 0 deletions
@@ -242,6 +242,21 @@ def test_flash_attention_wrapper(self):
    expected_o = self._attention(q, k, v)
    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  @with_jax_high_precision
  def test_flash_attention_wrapper_kv_and_ab_padding(self):
    from torch_xla.experimental.custom_kernel import flash_attention

    q = torch.randn(1, 2, 513, 4).to("xla")
    k = torch.randn(1, 2, 513, 4).to("xla")
    v = torch.randn(1, 2, 513, 4).to("xla")
    ab = torch.randn(1, 2, 513, 513).to("xla")

    o = flash_attention(q, k, v, ab=ab)
    expected_o = self._attention(q, k, v, ab=ab)
    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  @with_jax_high_precision
test/test_pallas_spmd.py: 20 additions & 0 deletions
@@ -77,6 +77,26 @@ def test_flash_attention_spmd_data_parallel(self):
    expected_o = self._attention(q, k, v)
    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  @with_jax_high_precision
  def test_flash_attention_spmd_data_parallel_kv_and_ab_padding(self):
    n_devices = xr.global_runtime_device_count()
    xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))

    q = torch.randn(4, 2, 513, 4).to("xla")
    k = torch.randn(4, 2, 513, 4).to("xla")
    v = torch.randn(4, 2, 513, 4).to("xla")
    ab = torch.randn(4, 2, 513, 513).to("xla")

    o = flash_attention(q, k, v, ab=ab, partition_spec=range(n_devices))
    self.assertEqual(
        torch_xla._XLAC._get_xla_sharding_spec(o),
        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")

    expected_o = self._attention(q, k, v, ab=ab)
    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  @with_jax_high_precision
torch_xla/experimental/custom_kernel.py: 35 additions & 2 deletions
@@ -248,6 +248,18 @@ def fa_custom_forward(
    full_ab = ab.clone()
  else:
    full_ab = None

  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
                      k.shape[2])
  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
  k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
  if k_pad_size > 0:
    v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
    if ab is None:
      ab = torch.zeros((q.shape[0], q.shape[1], q.shape[2], q.shape[2]))
    ab, _ = _pad_to_block_size(
        ab, max(block_k_major, block_k), 3, padding_minus_inf=True)

  if partition_spec is not None:
    q_full_shape = q.shape
    q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
@@ -295,8 +307,8 @@ def fa_custom_forward(
      sm_scale,
      min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]),
      min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]),
-     min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
-     min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]),
+     block_k_major,
+     block_k,
      False,
      static_argnums=range(5, 13),
      use_cache=True,
@@ -337,6 +349,27 @@ def fa_custom_forward(
  return tuple(outs)


def _pad_to_block_size(
    tensor: torch.Tensor,
    block_size: int,
    dim: int,
    padding_minus_inf: bool = False) -> Tuple[torch.Tensor, int]:
  size = tensor.shape[dim]
  if size % block_size == 0:
    return tensor, 0

  pad_size = block_size - (size % block_size)
  pad_shape = list(tensor.shape)
  pad_shape[dim] = pad_size
  padding = torch.full(
      pad_shape,
      torch.finfo(tensor.dtype).min if padding_minus_inf else 0,
      dtype=tensor.dtype,
      device=tensor.device)
  padded = torch.cat([tensor, padding], dim=dim)
  return padded, pad_size


@custom_op("xla::fa_custom_backward", mutates_args=())
def fa_custom_backward(
    grad_output: torch.Tensor, q: torch.Tensor, k: torch.Tensor,
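
As a usage sketch of the new helper (assumptions: a standalone copy of _pad_to_block_size run on CPU tensors rather than on an XLA device, and a block size of 512), this shows how the forward pass pads k and v with zeros and the attention bias ab with the dtype's minimum value along the kv dimension, so the padded key positions receive effectively zero weight after softmax:

from typing import Tuple

import torch


def _pad_to_block_size(tensor: torch.Tensor,
                       block_size: int,
                       dim: int,
                       padding_minus_inf: bool = False
                      ) -> Tuple[torch.Tensor, int]:
  # Standalone copy of the helper added above, for illustration only.
  size = tensor.shape[dim]
  if size % block_size == 0:
    return tensor, 0
  pad_size = block_size - (size % block_size)
  pad_shape = list(tensor.shape)
  pad_shape[dim] = pad_size
  padding = torch.full(
      pad_shape,
      torch.finfo(tensor.dtype).min if padding_minus_inf else 0,
      dtype=tensor.dtype,
      device=tensor.device)
  return torch.cat([tensor, padding], dim=dim), pad_size


block = 512
k = torch.randn(1, 2, 513, 4)
v = torch.randn(1, 2, 513, 4)
ab = torch.randn(1, 2, 513, 513)

k, k_pad_size = _pad_to_block_size(k, block, dim=2)  # zero padding -> (1, 2, 1024, 4)
v, _ = _pad_to_block_size(v, block, dim=2)           # zero padding -> (1, 2, 1024, 4)
ab, _ = _pad_to_block_size(ab, block, dim=3, padding_minus_inf=True)
# The padded ab columns hold finfo(dtype).min, so softmax over the key
# dimension assigns them (effectively) zero probability.
print(k.shape, v.shape, ab.shape, k_pad_size)  # ..., 511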
