Commit cff9f4e

Fix a bug in flash attention where kv_seq_len should divide block_k_major. (#8671)
1 parent c0afda3 commit cff9f4e

File tree: 3 files changed, +70 -2 lines changed
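Background for the fix: the Pallas flash attention kernel tiles the kv sequence into blocks of block_k_major / block_k, so the kv sequence length must be a multiple of the chosen block size. A length such as 513 used to violate that constraint; this commit pads k, v, and ab up to the next block boundary. The sketch below shows only the pad-size arithmetic the new helper performs, assuming default block sizes of 512 (an assumed value, not stated in this diff; check FlashAttention.DEFAULT_BLOCK_SIZES).

# Minimal sketch of the pad-size arithmetic, not the library code itself.
DEFAULT_BLOCK_K_MAJOR = 512  # assumed default
DEFAULT_BLOCK_K = 512        # assumed default

def pad_size_for(kv_seq_len: int) -> int:
  # Mirror the forward pass: block sizes are capped by the actual kv length.
  block_k_major = min(DEFAULT_BLOCK_K_MAJOR, kv_seq_len)
  block_k = min(DEFAULT_BLOCK_K, kv_seq_len)
  block = max(block_k_major, block_k)
  # Padding needed so that the block size divides the kv length.
  return 0 if kv_seq_len % block == 0 else block - (kv_seq_len % block)

print(pad_size_for(512))  # 0   -> already divisible, no padding
print(pad_size_for(513))  # 511 -> padded to 1024, the case the new tests exercise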

test/test_pallas.py (+15)

@@ -242,6 +242,21 @@ def test_flash_attention_wrapper(self):
     expected_o = self._attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  @with_jax_high_precision
+  def test_flash_attention_wrapper_kv_and_ab_padding(self):
+    from torch_xla.experimental.custom_kernel import flash_attention
+
+    q = torch.randn(1, 2, 513, 4).to("xla")
+    k = torch.randn(1, 2, 513, 4).to("xla")
+    v = torch.randn(1, 2, 513, 4).to("xla")
+    ab = torch.randn(1, 2, 513, 513).to("xla")
+
+    o = flash_attention(q, k, v, ab=ab)
+    expected_o = self._attention(q, k, v, ab=ab)
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
   @with_jax_high_precision
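The new test compares flash_attention against the test class's _attention reference, which this diff does not show. As an assumption, a plain (non-fused) attention reference with an additive bias ab would look roughly like the sketch below; it is a hypothetical stand-in, not the test suite's actual helper.

import torch

def reference_attention(q, k, v, ab=None):
  # Plain O(n^2) attention used as a correctness oracle.
  scores = q @ k.transpose(-2, -1)        # [batch, heads, q_len, kv_len]
  if ab is not None:
    scores = scores + ab                  # additive attention bias
  weights = torch.softmax(scores, dim=-1) # normalize over the kv dimension
  return weights @ v                      # [batch, heads, q_len, head_dim]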

test/test_pallas_spmd.py (+20)

@@ -77,6 +77,26 @@ def test_flash_attention_spmd_data_parallel(self):
     expected_o = self._attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  @with_jax_high_precision
+  def test_flash_attention_spmd_data_parallel_kv_and_ab_padding(self):
+    n_devices = xr.global_runtime_device_count()
+    xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))
+
+    q = torch.randn(4, 2, 513, 4).to("xla")
+    k = torch.randn(4, 2, 513, 4).to("xla")
+    v = torch.randn(4, 2, 513, 4).to("xla")
+    ab = torch.randn(4, 2, 513, 513).to("xla")
+
+    o = flash_attention(q, k, v, ab=ab, partition_spec=range(n_devices))
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(o),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+
+    expected_o = self._attention(q, k, v, ab=ab)
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
   @with_jax_high_precision
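The SPMD test additionally asserts that the output keeps its batch-dimension sharding after going through the padding path. A quick illustration (not part of the diff) of what the expected sharding-spec string evaluates to for four devices:

n_devices = 4  # the device count this assertion string is written for
expected = f"{{devices=[{n_devices},1,1,1]0,1,2,3}}"
print(expected)  # {devices=[4,1,1,1]0,1,2,3}
# i.e. the output is split 4 ways along dim 0 (batch), left whole on the
# other dims, with shards placed on devices 0..3.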

torch_xla/experimental/custom_kernel.py (+35 -2)

@@ -248,6 +248,18 @@ def fa_custom_forward(
     full_ab = ab.clone()
   else:
     full_ab = None
+
+  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
+                      k.shape[2])
+  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
+  k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
+  if k_pad_size > 0:
+    v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
+    if ab is None:
+      ab = torch.zeros((q.shape[0], q.shape[1], q.shape[2], q.shape[2]))
+    ab, _ = _pad_to_block_size(
+        ab, max(block_k_major, block_k), 3, padding_minus_inf=True)
+
   if partition_spec is not None:
     q_full_shape = q.shape
     q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor

@@ -295,8 +307,8 @@ def fa_custom_forward(
       sm_scale,
       min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]),
       min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]),
-      min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
-      min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]),
+      block_k_major,
+      block_k,
       False,
       static_argnums=range(5, 13),
       use_cache=True,

@@ -337,6 +349,27 @@ def fa_custom_forward(
   return tuple(outs)
 
 
+def _pad_to_block_size(
+    tensor: torch.Tensor,
+    block_size: int,
+    dim: int,
+    padding_minus_inf: bool = False) -> Tuple[torch.Tensor, int]:
+  size = tensor.shape[dim]
+  if size % block_size == 0:
+    return tensor, 0
+
+  pad_size = block_size - (size % block_size)
+  pad_shape = list(tensor.shape)
+  pad_shape[dim] = pad_size
+  padding = torch.full(
+      pad_shape,
+      torch.finfo(tensor.dtype).min if padding_minus_inf else 0,
+      dtype=tensor.dtype,
+      device=tensor.device)
+  padded = torch.cat([tensor, padding], dim=dim)
+  return padded, pad_size
+
+
 @custom_op("xla::fa_custom_backward", mutates_args=())
 def fa_custom_backward(
     grad_output: torch.Tensor, q: torch.Tensor, k: torch.Tensor,
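The new _pad_to_block_size helper is device-agnostic, so its behavior can be sanity-checked on CPU. A usage sketch (not part of the commit) follows, using a block size of 512 for illustration: k and v are padded with zeros along the kv dimension, while ab is padded with the dtype's minimum value, which acts as an effective -inf so the softmax gives the padded key positions essentially zero weight.

import torch
# Assumed import path: after this commit the helper lives in the module the diff edits.
from torch_xla.experimental.custom_kernel import _pad_to_block_size

k = torch.randn(1, 2, 513, 4)
ab = torch.randn(1, 2, 513, 513)

k_padded, k_pad = _pad_to_block_size(k, 512, dim=2)
ab_padded, _ = _pad_to_block_size(ab, 512, dim=3, padding_minus_inf=True)

print(k_pad)                     # 511: the kv length 513 is padded up to 1024
print(k_padded.shape)            # torch.Size([1, 2, 1024, 4]), zero-padded
print(ab_padded.shape)           # torch.Size([1, 2, 513, 1024])
print(ab_padded[..., -1].min())  # ~ -3.4e38 for float32, i.e. effectively -inf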
