diff --git a/test/test_pallas.py b/test/test_pallas.py
index 8c96bae0301..4c291de3acc 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -242,6 +242,21 @@ def test_flash_attention_wrapper(self):
     expected_o = self._attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  @with_jax_high_precision
+  def test_flash_attention_wrapper_kv_and_ab_padding(self):
+    from torch_xla.experimental.custom_kernel import flash_attention
+
+    q = torch.randn(1, 2, 513, 4).to("xla")
+    k = torch.randn(1, 2, 513, 4).to("xla")
+    v = torch.randn(1, 2, 513, 4).to("xla")
+    ab = torch.randn(1, 2, 513, 513).to("xla")
+
+    o = flash_attention(q, k, v, ab=ab)
+    expected_o = self._attention(q, k, v, ab=ab)
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
   @with_jax_high_precision
diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py
index ba9b819c653..f9a48aab7fb 100644
--- a/test/test_pallas_spmd.py
+++ b/test/test_pallas_spmd.py
@@ -77,6 +77,26 @@ def test_flash_attention_spmd_data_parallel(self):
     expected_o = self._attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  @with_jax_high_precision
+  def test_flash_attention_spmd_data_parallel_kv_and_ab_padding(self):
+    n_devices = xr.global_runtime_device_count()
+    xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))
+
+    q = torch.randn(4, 2, 513, 4).to("xla")
+    k = torch.randn(4, 2, 513, 4).to("xla")
+    v = torch.randn(4, 2, 513, 4).to("xla")
+    ab = torch.randn(4, 2, 513, 513).to("xla")
+
+    o = flash_attention(q, k, v, ab=ab, partition_spec=range(n_devices))
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(o),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+
+    expected_o = self._attention(q, k, v, ab=ab)
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
   @with_jax_high_precision
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 45c7a4a6de0..e5299683951 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -248,6 +248,18 @@ def fa_custom_forward(
     full_ab = ab.clone()
   else:
     full_ab = None
+
+  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
+                      k.shape[2])
+  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
+  k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
+  if k_pad_size > 0:
+    v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
+    if ab is None:
+      ab = torch.zeros((q.shape[0], q.shape[1], q.shape[2], q.shape[2]))
+    ab, _ = _pad_to_block_size(
+        ab, max(block_k_major, block_k), 3, padding_minus_inf=True)
+
   if partition_spec is not None:
     q_full_shape = q.shape
     q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
@@ -295,8 +307,8 @@ def fa_custom_forward(
           sm_scale,
           min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]),
           min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]),
-          min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
-          min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]),
+          block_k_major,
+          block_k,
           False,
           static_argnums=range(5, 13),
           use_cache=True,
@@ -337,6 +349,27 @@ def fa_custom_forward(
   return tuple(outs)
 
 
+def _pad_to_block_size(
+    tensor: torch.Tensor,
+    block_size: int,
+    dim: int,
+    padding_minus_inf: bool = False) -> Tuple[torch.Tensor, int]:
+  size = tensor.shape[dim]
+  if size % block_size == 0:
+    return tensor, 0
+
+  pad_size = block_size - (size % block_size)
+  pad_shape = list(tensor.shape)
+  pad_shape[dim] = pad_size
+  padding = torch.full(
+      pad_shape,
+      torch.finfo(tensor.dtype).min if padding_minus_inf else 0,
+      dtype=tensor.dtype,
+      device=tensor.device)
+  padded = torch.cat([tensor, padding], dim=dim)
+  return padded, pad_size
+
+
 @custom_op("xla::fa_custom_backward", mutates_args=())
 def fa_custom_backward(
     grad_output: torch.Tensor, q: torch.Tensor, k: torch.Tensor,
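A minimal, self-contained sketch of the padding scheme the diff introduces, runnable on CPU. The pad_to_block_size function below mirrors the _pad_to_block_size helper added to custom_kernel.py (it is a local illustration, not the library API), and 512 is used here only as a stand-in for the kv block size that the real code takes from FlashAttention.DEFAULT_BLOCK_SIZES. With a kv length of 513, as in the new tests, k and v are zero-padded up to the next block multiple, while the attention bias ab is padded with the dtype minimum so the extra key positions are effectively masked out of the softmax.

# Illustration only; assumes a 512 kv block size and CPU tensors.
from typing import Tuple

import torch


def pad_to_block_size(tensor: torch.Tensor,
                      block_size: int,
                      dim: int,
                      padding_minus_inf: bool = False
                     ) -> Tuple[torch.Tensor, int]:
  # Returns the tensor padded along `dim` to a multiple of `block_size`,
  # plus the amount of padding added.
  size = tensor.shape[dim]
  if size % block_size == 0:
    return tensor, 0
  pad_size = block_size - (size % block_size)
  pad_shape = list(tensor.shape)
  pad_shape[dim] = pad_size
  # Zeros for k/v; the dtype minimum (an effective -inf) for the bias ab.
  padding = torch.full(
      pad_shape,
      torch.finfo(tensor.dtype).min if padding_minus_inf else 0,
      dtype=tensor.dtype,
      device=tensor.device)
  return torch.cat([tensor, padding], dim=dim), pad_size


k = torch.randn(1, 2, 513, 4)
ab = torch.randn(1, 2, 513, 513)
k_padded, k_pad = pad_to_block_size(k, 512, dim=2)
ab_padded, _ = pad_to_block_size(ab, 512, dim=3, padding_minus_inf=True)
print(k_padded.shape, k_pad)  # torch.Size([1, 2, 1024, 4]) 511
print(ab_padded.shape)        # torch.Size([1, 2, 513, 1024])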