import logging
+ import sys
import unittest
+ from absl.testing import parameterized

import torch
from torch import nn as nn
from jax.experimental import pallas as pl


- class PallasTest(unittest.TestCase):
+ def with_jax_high_precision(func):
+
+   def wrapper(*args, **kwargs):
+     jax.config.update('jax_default_matmul_precision', "highest")
+     try:
+       result = func(*args, **kwargs)
+     finally:
+       jax.config.update('jax_default_matmul_precision', "default")
+     return result
+
+   return wrapper
+
+
+ class PallasTest(parameterized.TestCase):
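For reference, a minimal standalone sketch of what the new decorator guarantees: matmul precision is raised only while the wrapped callable runs and is restored afterwards, even if the call raises. The function name and arrays below are hypothetical and are not part of this change.

import jax
import jax.numpy as jnp

@with_jax_high_precision
def run_high_precision_matmul(a, b):
  # While this body runs, jax_default_matmul_precision is "highest".
  return a @ b

x = jnp.ones((8, 8), dtype=jnp.float32)
print(run_high_precision_matmul(x, x))
# On return (or on an exception), the finally block restores the "default" precision.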

  # This is to create a diagonal mask where only elements within the same segment
  # can attend to each other. Since the mask is to mask out the unrelevant parts,
@@ -33,12 +48,11 @@ def _make_attention_mask_from_segment_ids(self, q_segment_ids,

  def _attention(self, q, k, v, *, attn_mask=None, ab=None):
    attn_weight = q @ k.transpose(-2, -1)
-     if attn_mask is not None:
-       # Masked out the unrelevant parts.
-       attn_weight = attn_weight.masked_fill(attn_mask,
-                                             torch.finfo(attn_weight.dtype).min)
    if ab is not None:
      attn_weight = attn_weight + ab
+     if attn_mask is not None:
+       attn_weight = attn_weight.masked_fill(attn_mask.bool(),
+                                             torch.finfo(attn_weight.dtype).min)
    attn_weight = nn.functional.softmax(attn_weight, dim=-1)
    attn_output = attn_weight @ v
    return attn_output
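A note on the reordering in this hunk: the additive bias ab is now applied before the mask, so masked positions are forced to the dtype minimum regardless of the bias, and attn_mask is cast with .bool() so integer or float masks work with masked_fill. A small CPU-only sketch of that reference ordering, with arbitrary shapes and seed:

import torch
from torch import nn

torch.manual_seed(0)
q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)
ab = torch.randn(1, 1, 4, 4)                           # additive attention bias
attn_mask = torch.triu(torch.ones(4, 4), diagonal=1)   # 1.0 marks masked-out positions

attn_weight = q @ k.transpose(-2, -1)
attn_weight = attn_weight + ab                         # bias first
attn_weight = attn_weight.masked_fill(attn_mask.bool(),
                                      torch.finfo(attn_weight.dtype).min)
out = nn.functional.softmax(attn_weight, dim=-1) @ v
print(out.shape)  # torch.Size([1, 1, 4, 8]); masked positions get ~0 attention weight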
@@ -216,8 +230,8 @@ def test_tpu_custom_call_pallas_wrap_flash_attention(self):

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
+   @with_jax_high_precision
  def test_flash_attention_wrapper(self):
-     jax.config.update("jax_default_matmul_precision", "highest")
    from torch_xla.experimental.custom_kernel import flash_attention

    q = torch.randn(3, 2, 128, 4).to("xla")
@@ -227,12 +241,11 @@ def test_flash_attention_wrapper(self):
    o = flash_attention(q, k, v)
    expected_o = self._attention(q, k, v)
    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
-     jax.config.update("jax_default_matmul_precision", "default")

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
+   @with_jax_high_precision
  def test_flash_attention_wrapper_with_dynamo(self):
-     jax.config.update("jax_default_matmul_precision", "highest")
    from torch_xla.experimental.custom_kernel import flash_attention

    def flash_attention_wrapper(q, k, v, causal=False):
@@ -253,12 +266,11 @@ def flash_attention_wrapper(q, k, v, causal=False):
    # therefore it speeds up the compute but also changes the output.
    self.assertFalse(
        torch.allclose(o_with_causal.cpu(), expected_o.cpu(), atol=1e-05))
-     jax.config.update("jax_default_matmul_precision", "default")

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
+   @with_jax_high_precision
  def test_flash_attention_wrapper_causal(self):
-     jax.config.update("jax_default_matmul_precision", "highest")
    from torch_xla.experimental.custom_kernel import flash_attention

    q = torch.randn(3, 2, 128, 4).to("xla")
@@ -270,7 +282,6 @@ def test_flash_attention_wrapper_causal(self):
    o = flash_attention(q, k, v, causal=True)
    expected_o = self._attention(q, k, v)
    self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
-     jax.config.update("jax_default_matmul_precision", "default")

  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  def test_multiple_returns(self):
@@ -450,8 +461,8 @@ def test__flash_attention_bwd_dkv(self):

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
+   @with_jax_high_precision
  def test_flash_attention_backward(self):
-     jax.config.update("jax_default_matmul_precision", "highest")
    from torch_xla.experimental.custom_kernel import flash_attention

    torch.manual_seed(42)
@@ -486,7 +497,6 @@ def test_flash_attention_backward(self):

    for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
-     jax.config.update("jax_default_matmul_precision", "default")

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                   "This test only works on TPUv4+.")
@@ -1026,8 +1036,8 @@ def test_flash_attention_wrapper_segment_ids_1(self):

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
+   @with_jax_high_precision
  def test_flash_attention_wrapper_segment_ids_2(self):
-     jax.config.update("jax_default_matmul_precision", "highest")
    from torch_xla.experimental.custom_kernel import flash_attention

    q = torch.randn(3, 2, 128, 4).to("xla")
@@ -1093,12 +1103,11 @@ def test_flash_attention_backward_segment_ids(self):

    for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
-     jax.config.update("jax_default_matmul_precision", "default")

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
+   @with_jax_high_precision
  def test_flash_attention_wrapper_sm_scale(self):
-     jax.config.update("jax_default_matmul_precision", "highest")
    from torch_xla.experimental.custom_kernel import flash_attention

    q = torch.randn(3, 2, 128, 4).to("xla")
@@ -1109,12 +1118,11 @@ def test_flash_attention_wrapper_sm_scale(self):

    expected_o = self._attention(q * sm_scale, k, v)
    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
-     jax.config.update("jax_default_matmul_precision", "default")

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
+   @with_jax_high_precision
  def test_flash_attention_sm_scale_backward(self):
-     jax.config.update("jax_default_matmul_precision", "highest")
    from torch_xla.experimental.custom_kernel import flash_attention

    torch.manual_seed(42)
@@ -1151,12 +1159,11 @@ def test_flash_attention_sm_scale_backward(self):
    # Hmm, the gradients are the same even the autograd graph seems different.
    for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
-     jax.config.update("jax_default_matmul_precision", "default")

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
+   @with_jax_high_precision
  def test_flash_attention_ab(self):
-     jax.config.update("jax_default_matmul_precision", "highest")
    from torch_xla.experimental.custom_kernel import flash_attention

    q = torch.randn(3, 2, 128, 4).to("xla")
@@ -1208,12 +1215,11 @@ def test_flash_attention_ab_backward_1(self):

    for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
-     jax.config.update("jax_default_matmul_precision", "default")

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
+   @with_jax_high_precision
  def test_flash_attention_ab_backward_2(self):
-     jax.config.update("jax_default_matmul_precision", "highest")
    from torch_xla.experimental.custom_kernel import flash_attention

    torch.manual_seed(42)
@@ -1251,7 +1257,145 @@ def test_flash_attention_ab_backward_2(self):

    for i in [(q, q_grad), (k, k_grad), (v, v_grad), (ab, ab_grad)]:
      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
-     jax.config.update("jax_default_matmul_precision", "default")
+
+   @parameterized.named_parameters(('off', False), ('on', True))
+   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                    "This test only works on TPUv4+.")
+   @with_jax_high_precision
+   def test_flash_attention_forward_aot_autograd_traceable_causal(self, causal):
+     from functorch.compile import aot_function, make_boxed_func
+     from torch_xla.experimental.custom_kernel import flash_attention
+     import torch_xla.core.xla_model as xm
+
+     def compiler(gm, _):
+       return make_boxed_func(gm)
+
+     torch.manual_seed(42)
+     q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+     k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+     v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+     q.retain_grad()
+     k.retain_grad()
+     v.retain_grad()
+     B, N, SEQ, H = q.size()
+     q_segment_ids = None
+     kv_segment_ids = None
+     sm_scale = 1.0
+
+     compiled_flash_attention = aot_function(
+         flash_attention, fw_compiler=compiler)
+     o_actual = compiled_flash_attention(q, k, v, causal, q_segment_ids,
+                                         kv_segment_ids, sm_scale)
+     xm.mark_step()
+     if causal:
+       attention_mask = torch.triu(torch.ones(SEQ, SEQ), diagonal=1).to("xla")
+     else:
+       attention_mask = None
+
+     expected_output = self._attention(q, k, v, attn_mask=attention_mask)
+     xm.mark_step()
+     self.assertTrue(
+         torch.allclose(o_actual.cpu(), expected_output.cpu(), atol=1e-5))
+
+   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                    "This test only works on TPUv4+.")
+   @with_jax_high_precision
+   def test_flash_attention_forward_aot_autograd_traceable_ab(self):
+     from functorch.compile import aot_function, make_boxed_func
+     from torch_xla.experimental.custom_kernel import flash_attention
+     import torch_xla.core.xla_model as xm
+
+     def compiler(gm, _):
+       return make_boxed_func(gm)
+
+     torch.manual_seed(42)
+     q = torch.randn(4, 2, 128, 8).to("xla")
+     k = torch.randn(4, 2, 128, 8).to("xla")
+     v = torch.randn(4, 2, 128, 8).to("xla")
+     B, N, SEQ, H = q.size()
+     causal = False
+     q_segment_ids = None
+     kv_segment_ids = None
+     sm_scale = 1.0
+     mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla")
+     ab = torch.ones(4, 2, 128, 128).to("xla")
+     ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min)
+
+     compiled_flash_attention = aot_function(
+         flash_attention, fw_compiler=compiler)
+     o_actual = compiled_flash_attention(
+         q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab)
+     xm.mark_step()
+
+     expected_output = self._attention(q, k, v, ab=ab)
+     xm.mark_step()
+     self.assertTrue(
+         torch.allclose(o_actual.cpu(), expected_output.cpu(), atol=1e-5))
+
+   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                    "This test only works on TPUv4+.")
+   @with_jax_high_precision
+   def test_flash_attention_backward_aot_autograd_traceable(self):
+     from functorch.compile import aot_function, make_boxed_func
+     from torch_xla.experimental.custom_kernel import flash_attention
+     import torch_xla.core.xla_model as xm
+
+     def compiler(gm, _):
+       return make_boxed_func(gm)
+
+     torch.manual_seed(42)
+     q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+     k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+     v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+     q.retain_grad()
+     k.retain_grad()
+     v.retain_grad()
+     B, N, SEQ, H = q.size()
+     mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla")
+     ab = torch.ones(4, 2, 128, 128).to("xla")
+     ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min).requires_grad_()
+     ab.retain_grad()
+
+     causal = False
+     q_segment_ids = None
+     kv_segment_ids = None
+     sm_scale = 1.0
+     compiled_flash_attention = aot_function(
+         flash_attention, fw_compiler=compiler)
+     o_actual = compiled_flash_attention(
+         q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab)
+     loss = o_actual.sum()
+     loss.backward()
+     xm.mark_step()
+     q_grad = q.grad
+     k_grad = k.grad
+     v_grad = v.grad
+     ab_grad = ab.grad
+
+     torch.manual_seed(42)
+     expected_q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+     expected_k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+     expected_v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+     expected_q.retain_grad()
+     expected_k.retain_grad()
+     expected_v.retain_grad()
+     expected_ab = torch.ones(4, 2, 128, 128).to("xla")
+     expected_ab = expected_ab.masked_fill(mask,
+                                           torch.finfo(
+                                               ab.dtype).min).requires_grad_()
+     expected_ab.retain_grad()
+     o = self._attention(expected_q, expected_k, expected_v, ab=expected_ab)
+     loss = o.sum()
+     loss.backward()
+     xm.mark_step()
+
+     for expected_tensor, actual_tensor_grad in [(expected_q, q_grad),
+                                                 (expected_k, k_grad),
+                                                 (expected_v, v_grad),
+                                                 (expected_ab, ab_grad)]:
+       self.assertTrue(
+           torch.allclose(
+               expected_tensor.grad.cpu(), actual_tensor_grad.cpu(), atol=1e-02))
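The three tests above all follow the same functorch AOTAutograd pattern: aot_function traces the forward and backward graphs, and make_boxed_func wraps the traced FX module so the pass-through compiler can hand it back unchanged. A minimal sketch of that pattern on a plain function, runnable on CPU without torch_xla or a TPU; the function names here are illustrative only.

import torch
from functorch.compile import aot_function, make_boxed_func

def compiler(gm, _):
  # gm is the traced FX GraphModule; boxing it returns it as-is.
  return make_boxed_func(gm)

def fn(x, y):
  return (x * y).sum()

compiled_fn = aot_function(fn, fw_compiler=compiler)
x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
compiled_fn(x, y).backward()
print(x.grad, y.grad)  # gradients flow through the traced forward/backward graphs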


if __name__ == '__main__':