import unittest

import torch
+import numpy as np
from torch import nn as nn

import torch_xla

...

class PallasTest(unittest.TestCase):

-  def _attention(self, q, k, v):
+  # This creates a block-diagonal mask in which only elements within the same
+  # segment can attend to each other. Since the mask marks the positions to
+  # mask out (the irrelevant ones) rather than the ones to keep, we use !=
+  # instead of ==.
+  def _make_attention_mask_from_segment_ids(self, q_segment_ids,
+                                            kv_segment_ids):
+    return q_segment_ids.view(q_segment_ids.shape[0], 1,
+                              q_segment_ids.shape[1], 1) != kv_segment_ids.view(
+                                  kv_segment_ids.shape[0], 1, 1,
+                                  kv_segment_ids.shape[1])
+
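For reference, a minimal standalone sketch of how that broadcast works; the tensor values and shapes below are illustrative only, not taken from the tests:

import torch

q_segment_ids = torch.tensor([[0, 0, 1, 1]])   # [batch=1, q_len=4]
kv_segment_ids = torch.tensor([[0, 1, 1, 1]])  # [batch=1, kv_len=4]

# [B, 1, q_len, 1] != [B, 1, 1, kv_len] broadcasts to [B, 1, q_len, kv_len];
# True marks pairs whose segment ids differ, i.e. the positions to mask out.
mask = q_segment_ids.view(1, 1, 4, 1) != kv_segment_ids.view(1, 1, 1, 4)
print(mask.shape)  # torch.Size([1, 1, 4, 4])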
+  def _attention(self, q, k, v, *, attn_mask=None, ab=None):
     attn_weight = q @ k.transpose(-2, -1)
+    if attn_mask is not None:
+      # Mask out the irrelevant positions.
+      attn_weight = attn_weight.masked_fill(attn_mask,
+                                            torch.finfo(attn_weight.dtype).min)
+    if ab is not None:
+      attn_weight = attn_weight + ab
     attn_weight = nn.functional.softmax(attn_weight, dim=-1)
     attn_output = attn_weight @ v
     return attn_output
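A quick worked note on the masking trick above (values are illustrative): filling the masked logits with torch.finfo(dtype).min drives their softmax weight to effectively zero, so masked positions contribute nothing to the output.

import torch

logits = torch.tensor([2.0, 1.0, 3.0])
mask = torch.tensor([False, True, False])  # True = mask out
masked = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
weights = torch.softmax(masked, dim=-1)
print(weights)  # ~[0.269, 0.000, 0.731]; the masked position gets ~zero weight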
@@ -98,6 +115,117 @@ def test_flash_attention_backward_spmd_data_parallel(self):
    self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
    jax.config.update('jax_default_matmul_precision', "default")

+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  def test_flash_attention_wrapper_segment_ids_spmd(self):
+    from torch_xla.experimental.custom_kernel import flash_attention
+    from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention, SegmentIds
+    xs.set_global_mesh(xs.get_1d_mesh("data"))
+
+    q = torch.randn(3, 2, 128, 4)
+    k = torch.randn(3, 2, 128, 4)
+    v = torch.randn(3, 2, 128, 4)
+    zeros = torch.zeros(3, 32)
+    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
+    segment_ids_xla = segment_ids.to("xla")
+    # Only shard the data dimension.
+    o = flash_attention(
+        q.to("xla"),
+        k.to("xla"),
+        v.to("xla"),
+        False,
+        segment_ids_xla,
+        segment_ids.to("xla"),
+        partition_spec=("data", None, None, None))
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(o),
+        f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")
+
+    jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
+    jax_k = jnp.array(k.numpy(), dtype=jnp.float32)
+    jax_v = jnp.array(v.numpy(), dtype=jnp.float32)
+    jax_segment_ids = jnp.array(segment_ids.numpy(), dtype=jnp.float32)
+    expected_o = torch.from_numpy(
+        np.array(
+            jax_flash_attention(
+                jax_q,
+                jax_k,
+                jax_v,
+                segment_ids=SegmentIds(jax_segment_ids, jax_segment_ids),
+            )))
+
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    jax.config.update('jax_default_matmul_precision', "default")
+
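As an aside, a small sketch of the sharding-spec string these assertions compare against; the helper name below is hypothetical, only the f-string comes from the test. "{devices=[N,1,1,1]0,...,N-1}" means the output is split N ways along dim 0 (the "data" mesh axis) and replicated along the remaining three dims.

def expected_batch_sharding_spec(n_devices):
  # Mirrors the expected-spec f-string used in the assertions above.
  device_ids = ",".join(str(i) for i in range(n_devices))
  return f"{{devices=[{n_devices},1,1,1]{device_ids}}}"

print(expected_batch_sharding_spec(4))  # {devices=[4,1,1,1]0,1,2,3}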
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  def test_flash_attention_backward_segment_ids_spmd(self):
+    jax.config.update("jax_default_matmul_precision", "highest")
+    from torch_xla.experimental.custom_kernel import flash_attention
+    n_devices = xr.global_runtime_device_count()
+    xs.set_global_mesh(xs.get_1d_mesh("data"))
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    zeros = torch.zeros(4, 32).to("xla")
+    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+
+    o = flash_attention(
+        q,
+        k,
+        v,
+        False,
+        segment_ids,
+        segment_ids,
+        partition_spec=("data", None, None, None))
+    loss = o.sum()
+    loss.backward()
+    q_grad = q.grad
+    k_grad = k.grad
+    v_grad = v.grad
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(o),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(q_grad),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(k_grad),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(v_grad),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+    torch_xla.sync()
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    zeros = torch.zeros(4, 32).to("xla")
+    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+
+    o = self._attention(
+        q,
+        k,
+        v,
+        attn_mask=self._make_attention_mask_from_segment_ids(
+            segment_ids, segment_ids))
+    loss = o.sum()
+    loss.backward()
+    xm.mark_step()
+
+    for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
+      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
+    jax.config.update("jax_default_matmul_precision", "default")
+

if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)