@@ -22,8 +22,14 @@
 
 class PallasTest(unittest.TestCase):
 
-  def _attention(self, q, k, v):
+  def _attention(self, q, k, v, *, attn_mask=None, ab=None):
     attn_weight = q @ k.transpose(-2, -1)
+    if attn_mask is not None:
+      # Mask out the irrelevant parts.
+      attn_weight = attn_weight.masked_fill(attn_mask,
+                                            torch.finfo(attn_weight.dtype).min)
+    if ab is not None:
+      attn_weight = attn_weight + ab
     attn_weight = nn.functional.softmax(attn_weight, dim=-1)
     attn_output = attn_weight @ v
     return attn_output
@@ -98,6 +104,36 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
     jax.config.update('jax_default_matmul_precision', "default")
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
+    jax.config.update('jax_default_matmul_precision', "highest")
+    n_devices = xr.global_runtime_device_count()
+    xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))
+
+    q = torch.randn(16, 32, 2048, 64).to("xla")
+    k = torch.randn(16, 32, 128, 64).to("xla")
+    v = torch.randn(16, 32, 128, 64).to("xla")
+    q_segment_ids = torch.ones(16, 2048, dtype=torch.float32).to("xla")
+    kv_segment_ids = torch.zeros(16, 1, 128, dtype=torch.float32).to("xla")
+    # convert mask into a bias that can be added to attention scores:
+    # (keep = +0, discard = -10000.0)
+    kv_segment_ids[:8, :, 30:] = -10000.0
+    kv_segment_ids[8:, :, 60:] = -10000.0
+
+    o = flash_attention(
+        q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(n_devices))
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(o),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+
+    attention_mask = kv_segment_ids.repeat_interleave(32, dim=0)
+    attention_mask = attention_mask.view(16, 32, 1, 128)
+
+    expected_o = self._attention(q, k, v, attn_mask=attention_mask)
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    jax.config.update('jax_default_matmul_precision', "default")
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
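For reference only (not part of the diff): a minimal, self-contained sketch of the masking convention the new test exercises, using a hypothetical reference_attention helper that mirrors the attn_mask / ab arguments added to _attention above. The shapes, the helper name, and the -10000.0 bias value are illustrative assumptions, not values from the PR.

import torch
import torch.nn.functional as F

def reference_attention(q, k, v, attn_mask=None, ab=None):
  # q, k, v: (batch, heads, seq_len, head_dim)
  attn_weight = q @ k.transpose(-2, -1)
  if attn_mask is not None:
    # Boolean mask: True entries are dropped before the softmax.
    attn_weight = attn_weight.masked_fill(attn_mask,
                                          torch.finfo(attn_weight.dtype).min)
  if ab is not None:
    # Additive bias: 0.0 keeps a position, a large negative value discards it.
    attn_weight = attn_weight + ab
  return F.softmax(attn_weight, dim=-1) @ v

q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)
bias = torch.zeros(2, 4, 1, 8)
bias[..., 4:] = -10000.0  # discard the last four kv positions for every query
out = reference_attention(q, k, v, ab=bias)
print(out.shape)  # torch.Size([2, 4, 8, 16])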