Commit 9ae017e

Dynamo/AOTAutograd traceable flash attention (#8654)
1 parent 82d3504 commit 9ae017e
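
This change makes torch_xla.experimental.custom_kernel.flash_attention traceable by Dynamo/AOTAutograd. As a minimal sketch of what that enables, mirroring the new tests added in this commit (it assumes a TPU-backed XLA device; the positional argument order q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale is the one the tests use):

# Sketch only: trace flash_attention through AOTAutograd with a pass-through
# compiler, the same way the new tests below do.
import torch
import torch_xla.core.xla_model as xm
from functorch.compile import aot_function, make_boxed_func
from torch_xla.experimental.custom_kernel import flash_attention


def compiler(gm, _):
  # Pass-through "compiler": AOTAutograd hands us an fx.GraphModule; boxing it
  # lets AOTAutograd call it with a flat list of arguments.
  return make_boxed_func(gm)


q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")

compiled_flash_attention = aot_function(flash_attention, fw_compiler=compiler)
o = compiled_flash_attention(q, k, v, False, None, None, 1.0)  # causal=False
o.sum().backward()  # the backward kernel is traced through AOTAutograd too
xm.mark_step()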

File tree

4 files changed (+707 −285 lines)


test/test_pallas.py

+167 −23
@@ -1,5 +1,7 @@
 import logging
+import sys
 import unittest
+from absl.testing import parameterized
 
 import torch
 from torch import nn as nn
@@ -19,7 +21,20 @@
 from jax.experimental import pallas as pl
 
 
-class PallasTest(unittest.TestCase):
+def with_jax_high_precision(func):
+
+  def wrapper(*args, **kwargs):
+    jax.config.update('jax_default_matmul_precision', "highest")
+    try:
+      result = func(*args, **kwargs)
+    finally:
+      jax.config.update('jax_default_matmul_precision', "default")
+    return result
+
+  return wrapper
+
+
+class PallasTest(parameterized.TestCase):
 
   # This is to create a diagonal mask where only elements within the same segment
   # can attend to each other. Since the mask is to mask out the unrelevant parts,
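
The with_jax_high_precision decorator added above replaces the per-test jax.config.update("jax_default_matmul_precision", ...) calls and restores the default even when a test fails. A small, self-contained sketch of the same guard pattern (check_precision is an illustrative name, not part of the change):

# Self-contained sketch of the precision-guard pattern the decorator uses.
import jax
import jax.numpy as jnp


def with_jax_high_precision(func):

  def wrapper(*args, **kwargs):
    jax.config.update('jax_default_matmul_precision', "highest")
    try:
      result = func(*args, **kwargs)
    finally:
      # Restore the default even if the wrapped function raises.
      jax.config.update('jax_default_matmul_precision', "default")
    return result

  return wrapper


@with_jax_high_precision
def check_precision():
  # Runs with full-precision matmuls even on backends that default to lower
  # precision (e.g. TPU).
  return jnp.ones((8, 8)) @ jnp.ones((8, 8))


print(check_precision().sum())  # 512.0; the config is restored afterwards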
@@ -33,12 +48,11 @@ def _make_attention_mask_from_segment_ids(self, q_segment_ids,
 
   def _attention(self, q, k, v, *, attn_mask=None, ab=None):
     attn_weight = q @ k.transpose(-2, -1)
-    if attn_mask is not None:
-      # Masked out the unrelevant parts.
-      attn_weight = attn_weight.masked_fill(attn_mask,
-                                            torch.finfo(attn_weight.dtype).min)
     if ab is not None:
       attn_weight = attn_weight + ab
+    if attn_mask is not None:
+      attn_weight = attn_weight.masked_fill(attn_mask.bool(),
+                                            torch.finfo(attn_weight.dtype).min)
     attn_weight = nn.functional.softmax(attn_weight, dim=-1)
     attn_output = attn_weight @ v
     return attn_output
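
The reference _attention helper now adds the additive bias ab before applying the mask, and casts the mask with .bool() so a nonzero entry always means "do not attend". A CPU-only restatement of that ordering, for illustration (shapes and names below are illustrative, not part of the diff):

# Standalone restatement of the reference ordering: bias first, then mask,
# then softmax.
import torch
from torch import nn


def reference_attention(q, k, v, attn_mask=None, ab=None):
  attn_weight = q @ k.transpose(-2, -1)
  if ab is not None:
    attn_weight = attn_weight + ab  # additive attention bias
  if attn_mask is not None:
    # True (or nonzero) mask entries are masked out.
    attn_weight = attn_weight.masked_fill(attn_mask.bool(),
                                          torch.finfo(attn_weight.dtype).min)
  attn_weight = nn.functional.softmax(attn_weight, dim=-1)
  return attn_weight @ v


q = k = v = torch.randn(1, 2, 8, 4)
causal_mask = torch.triu(torch.ones(8, 8), diagonal=1)
print(reference_attention(q, k, v, attn_mask=causal_mask).shape)  # (1, 2, 8, 4)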
@@ -216,8 +230,8 @@ def test_tpu_custom_call_pallas_wrap_flash_attention(self):
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
+  @with_jax_high_precision
   def test_flash_attention_wrapper(self):
-    jax.config.update("jax_default_matmul_precision", "highest")
     from torch_xla.experimental.custom_kernel import flash_attention
 
     q = torch.randn(3, 2, 128, 4).to("xla")
@@ -227,12 +241,11 @@ def test_flash_attention_wrapper(self):
     o = flash_attention(q, k, v)
     expected_o = self._attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
-    jax.config.update("jax_default_matmul_precision", "default")
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
+  @with_jax_high_precision
   def test_flash_attention_wrapper_with_dynamo(self):
-    jax.config.update("jax_default_matmul_precision", "highest")
     from torch_xla.experimental.custom_kernel import flash_attention
 
     def flash_attention_wrapper(q, k, v, causal=False):
@@ -253,12 +266,11 @@ def flash_attention_wrapper(q, k, v, causal=False):
     # therefore it speeds up the compute but also changes the output.
     self.assertFalse(
         torch.allclose(o_with_causal.cpu(), expected_o.cpu(), atol=1e-05))
-    jax.config.update("jax_default_matmul_precision", "default")
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
+  @with_jax_high_precision
   def test_flash_attention_wrapper_causal(self):
-    jax.config.update("jax_default_matmul_precision", "highest")
     from torch_xla.experimental.custom_kernel import flash_attention
 
     q = torch.randn(3, 2, 128, 4).to("xla")
@@ -270,7 +282,6 @@ def test_flash_attention_wrapper_causal(self):
     o = flash_attention(q, k, v, causal=True)
     expected_o = self._attention(q, k, v)
     self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
-    jax.config.update("jax_default_matmul_precision", "default")
 
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_multiple_returns(self):
@@ -450,8 +461,8 @@ def test__flash_attention_bwd_dkv(self):
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
+  @with_jax_high_precision
   def test_flash_attention_backward(self):
-    jax.config.update("jax_default_matmul_precision", "highest")
     from torch_xla.experimental.custom_kernel import flash_attention
 
     torch.manual_seed(42)
@@ -486,7 +497,6 @@ def test_flash_attention_backward(self):
 
     for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
-    jax.config.update("jax_default_matmul_precision", "default")
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
@@ -1026,8 +1036,8 @@ def test_flash_attention_wrapper_segment_ids_1(self):
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
+  @with_jax_high_precision
   def test_flash_attention_wrapper_segment_ids_2(self):
-    jax.config.update("jax_default_matmul_precision", "highest")
     from torch_xla.experimental.custom_kernel import flash_attention
 
     q = torch.randn(3, 2, 128, 4).to("xla")
@@ -1093,12 +1103,11 @@ def test_flash_attention_backward_segment_ids(self):
 
     for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
-    jax.config.update("jax_default_matmul_precision", "default")
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
+  @with_jax_high_precision
   def test_flash_attention_wrapper_sm_scale(self):
-    jax.config.update("jax_default_matmul_precision", "highest")
     from torch_xla.experimental.custom_kernel import flash_attention
 
     q = torch.randn(3, 2, 128, 4).to("xla")
@@ -1109,12 +1118,11 @@ def test_flash_attention_wrapper_sm_scale(self):
 
     expected_o = self._attention(q * sm_scale, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
-    jax.config.update("jax_default_matmul_precision", "default")
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
+  @with_jax_high_precision
   def test_flash_attention_sm_scale_backward(self):
-    jax.config.update("jax_default_matmul_precision", "highest")
     from torch_xla.experimental.custom_kernel import flash_attention
 
     torch.manual_seed(42)
@@ -1151,12 +1159,11 @@ def test_flash_attention_sm_scale_backward(self):
     # Hmm, the gradients are the same even the autograd graph seems different.
     for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
-    jax.config.update("jax_default_matmul_precision", "default")
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
+  @with_jax_high_precision
   def test_flash_attention_ab(self):
-    jax.config.update("jax_default_matmul_precision", "highest")
     from torch_xla.experimental.custom_kernel import flash_attention
 
     q = torch.randn(3, 2, 128, 4).to("xla")
@@ -1208,12 +1215,11 @@ def test_flash_attention_ab_backward_1(self):
 
     for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
-    jax.config.update("jax_default_matmul_precision", "default")
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
+  @with_jax_high_precision
   def test_flash_attention_ab_backward_2(self):
-    jax.config.update("jax_default_matmul_precision", "highest")
     from torch_xla.experimental.custom_kernel import flash_attention
 
     torch.manual_seed(42)
@@ -1251,7 +1257,145 @@ def test_flash_attention_ab_backward_2(self):
 
     for i in [(q, q_grad), (k, k_grad), (v, v_grad), (ab, ab_grad)]:
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
-    jax.config.update("jax_default_matmul_precision", "default")
+
+  @parameterized.named_parameters(('off', False), ('on', True))
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  @with_jax_high_precision
+  def test_flash_attention_forward_aot_autograd_traceable_causal(self, causal):
+    from functorch.compile import aot_function, make_boxed_func
+    from torch_xla.experimental.custom_kernel import flash_attention
+    import torch_xla.core.xla_model as xm
+
+    def compiler(gm, _):
+      return make_boxed_func(gm)
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+    B, N, SEQ, H = q.size()
+    q_segment_ids = None
+    kv_segment_ids = None
+    sm_scale = 1.0
+
+    compiled_flash_attention = aot_function(
+        flash_attention, fw_compiler=compiler)
+    o_actual = compiled_flash_attention(q, k, v, causal, q_segment_ids,
+                                        kv_segment_ids, sm_scale)
+    xm.mark_step()
+    if causal:
+      attention_mask = torch.triu(torch.ones(SEQ, SEQ), diagonal=1).to("xla")
+    else:
+      attention_mask = None
+
+    expected_output = self._attention(q, k, v, attn_mask=attention_mask)
+    xm.mark_step()
+    self.assertTrue(
+        torch.allclose(o_actual.cpu(), expected_output.cpu(), atol=1e-5))
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  @with_jax_high_precision
+  def test_flash_attention_forward_aot_autograd_traceable_ab(self):
+    from functorch.compile import aot_function, make_boxed_func
+    from torch_xla.experimental.custom_kernel import flash_attention
+    import torch_xla.core.xla_model as xm
+
+    def compiler(gm, _):
+      return make_boxed_func(gm)
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8).to("xla")
+    k = torch.randn(4, 2, 128, 8).to("xla")
+    v = torch.randn(4, 2, 128, 8).to("xla")
+    B, N, SEQ, H = q.size()
+    causal = False
+    q_segment_ids = None
+    kv_segment_ids = None
+    sm_scale = 1.0
+    mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla")
+    ab = torch.ones(4, 2, 128, 128).to("xla")
+    ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min)
+
+    compiled_flash_attention = aot_function(
+        flash_attention, fw_compiler=compiler)
+    o_actual = compiled_flash_attention(
+        q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab)
+    xm.mark_step()
+
+    expected_output = self._attention(q, k, v, ab=ab)
+    xm.mark_step()
+    self.assertTrue(
+        torch.allclose(o_actual.cpu(), expected_output.cpu(), atol=1e-5))
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  @with_jax_high_precision
+  def test_flash_attention_backward_aot_autograd_traceable(self):
+    from functorch.compile import aot_function, make_boxed_func
+    from torch_xla.experimental.custom_kernel import flash_attention
+    import torch_xla.core.xla_model as xm
+
+    def compiler(gm, _):
+      return make_boxed_func(gm)
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+    B, N, SEQ, H = q.size()
+    mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla")
+    ab = torch.ones(4, 2, 128, 128).to("xla")
+    ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min).requires_grad_()
+    ab.retain_grad()
+
+    causal = False
+    q_segment_ids = None
+    kv_segment_ids = None
+    sm_scale = 1.0
+    compiled_flash_attention = aot_function(
+        flash_attention, fw_compiler=compiler)
+    o_actual = compiled_flash_attention(
+        q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab)
+    loss = o_actual.sum()
+    loss.backward()
+    xm.mark_step()
+    q_grad = q.grad
+    k_grad = k.grad
+    v_grad = v.grad
+    ab_grad = ab.grad
+
+    torch.manual_seed(42)
+    expected_q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    expected_k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    expected_v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    expected_q.retain_grad()
+    expected_k.retain_grad()
+    expected_v.retain_grad()
+    expected_ab = torch.ones(4, 2, 128, 128).to("xla")
+    expected_ab = expected_ab.masked_fill(mask,
+                                          torch.finfo(
+                                              ab.dtype).min).requires_grad_()
+    expected_ab.retain_grad()
+    o = self._attention(expected_q, expected_k, expected_v, ab=expected_ab)
+    loss = o.sum()
+    loss.backward()
+    xm.mark_step()
+
+    for expected_tensor, actual_tensor_grad in [(expected_q, q_grad),
+                                                (expected_k, k_grad),
+                                                (expected_v, v_grad),
+                                                (expected_ab, ab_grad)]:
+      self.assertTrue(
+          torch.allclose(
+              expected_tensor.grad.cpu(), actual_tensor_grad.cpu(), atol=1e-02))
 
 
 if __name__ == '__main__':
