12 | 12 | import numpy as np
13 | 13 |
14 | 14 | if xr.device_type() == 'TPU':
15 | | -   from torch_xla.experimental.custom_kernel import jax_import_guard
16 | | -   jax_import_guard()
| 15 | +   # from torch_xla.experimental.custom_kernel import jax_import_guard
| 16 | +   # jax_import_guard()
| 17 | +   torch_xla._XLAC._init_computation_client()
17 | 18 |   import jax
18 | 19 |   import jax.numpy as jnp
19 | 20 |   from jax.experimental import pallas as pl
@@ -488,6 +489,79 @@ def test_flash_attention_backward(self):
488 | 489 |     self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
489 | 490 |     jax.config.update("jax_default_matmul_precision", "default")
490 | 491 |
| 492 | +
| 493 | +   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
| 494 | +                    "This test only works on TPUv4+.")
| 495 | +   def test_flash_attention_backward_aot_autograd_traceable(self):
| 496 | +     from functorch.compile import aot_function, make_boxed_func
| 497 | +     from torch_xla.experimental.custom_kernel import flash_attention, FlashAttention, flash_attention_compilable
| 498 | +     import torch_xla.core.xla_model as xm
| 499 | +     jax.config.update("jax_default_matmul_precision", "highest")
| 500 | +     def compiler(gm, _):
| 501 | +       print("Got graph:")
| 502 | +       print(gm.code)
| 503 | +       return make_boxed_func(gm)
| 504 | +
| 505 | +     torch.manual_seed(42)
| 506 | +     q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
| 507 | +     k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
| 508 | +     v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
| 509 | +     q.retain_grad()
| 510 | +     k.retain_grad()
| 511 | +     v.retain_grad()
| 512 | +     B, N, SEQ, H = q.size()
| 513 | +     causal = True
| 514 | +     q_segment_ids = None
| 515 | +     kv_segment_ids = None
| 516 | +     sm_scale = 1.0
| 517 | +     mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla")
| 518 | +     # ab = torch.ones(4, 2, 128, 128).to("xla")
| 519 | +     # ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min).requires_grad_(True)
| 520 | +     # ab.retain_grad()
| 521 | +     ab = None
| 522 | +     partition_spec = ('fsdp', 'tensor', None, None)
| 523 | +     # partition_spec = None
| 524 | +     import torch_xla.runtime as xr
| 525 | +     from torch_xla.distributed.spmd import Mesh
| 526 | +     xr.use_spmd()
| 527 | +     num_devices = xr.global_runtime_device_count()
| 528 | +     mesh_shape = (num_devices // 2, 2)
| 529 | +     device_ids = np.array(range(num_devices))
| 530 | +     mesh = Mesh(device_ids, mesh_shape, ('fsdp', 'tensor'))
| 531 | +
| 532 | +     def flash_attention_wrapper(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, partition_spec, mesh):
| 533 | +       return flash_attention_compilable(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab, partition_spec=partition_spec, mesh=mesh)
| 534 | +
| 535 | +
| 536 | +     # An AOT-compatible function only accepts the argument types listed in https://github.com/pytorch/pytorch/blob/82859f61857ef39898b34a5cdf0ae56ec25704d9/torch/_functorch/_aot_autograd/utils.py#L23-L34, so we serialize partition_spec and mesh into strings.
| 537 | +     # compiled_flash_attention = aot_function(
| 538 | +     #     flash_attention_wrapper, fw_compiler=compiler)
| 539 | +     # o_actual = compiled_flash_attention(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, str(partition_spec), str(mesh))
| 540 | +     o_actual = flash_attention(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab, partition_spec=partition_spec, mesh=mesh)
| 541 | +
| 542 | +     print(o_actual.sum())
| 543 | +     o_actual.sum().backward()
| 544 | +     print(q.grad)
| 545 | +
| 546 | +     # if causal:
| 547 | +     #   attention_mask = torch.triu(torch.ones(SEQ, SEQ), diagonal=1).to("xla")
| 548 | +     #   # attention_mask = attention_mask.view(1, 1, SEQ, SEQ)
| 549 | +     #   # attention_mask = attention_mask.expand(q.size(0), q.size(1), -1, -1)
| 550 | +     # else:
| 551 | +     #   attention_mask = None
| 552 | +     # print(attention_mask)
| 553 | +     # assert False
| 554 | +     # import torch_xla.distributed.spmd as xs
| 555 | +     # expected_output = self._attention(q, k, v, attn_mask=attention_mask)
| 556 | +     # print(expected_output)
| 557 | +     # self.assertTrue(
| 558 | +     #     torch.allclose(
| 559 | +     #         expected_output.cpu(),
| 560 | +     #         o_actual.cpu(),
| 561 | +     #         atol=1e-1,
| 562 | +     #         rtol=1e-1))
| 563 | +
| 564 | +
491 | 565 |   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
492 | 566 |                    "This test only works on TPUv4+.")
493 | 567 |   def test_paged_attention_wrapper(self):
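For context, below is a minimal, self-contained sketch (not part of the diff above) of the `aot_function`/`make_boxed_func` tracing pattern that the new `test_flash_attention_backward_aot_autograd_traceable` test exercises, together with the "pass non-tensor config as a string" idea from the comment about AOT-traceable argument types. `toy_attention` and `print_compiler` are illustrative names standing in for `flash_attention_compilable` and the test's `compiler`; the sketch uses plain CPU tensors, so it only approximates what the TPU test does.

```python
import ast

import torch
from functorch.compile import aot_function, make_boxed_func


# Hypothetical stand-in for the traced kernel call; plain CPU tensors so the
# sketch runs without a TPU. The non-tensor setting is passed as a string and
# parsed back, mirroring the "serialize partition_spec into a string" idea.
def toy_attention(q, k, v, partition_spec_str):
  partition_spec = ast.literal_eval(partition_spec_str)  # e.g. ('fsdp', 'tensor', None, None)
  del partition_spec  # a real kernel would use it for sharding; unused here
  scores = torch.einsum('bnsh,bnth->bnst', q, k).softmax(dim=-1)
  return torch.einsum('bnst,bnth->bnsh', scores, v)


# The compiler callback receives the traced FX GraphModule; printing gm.code
# shows the captured graph, and make_boxed_func adapts the module to
# AOTAutograd's boxed calling convention.
def print_compiler(gm, example_inputs):
  print(gm.code)
  return make_boxed_func(gm)


compiled = aot_function(
    toy_attention, fw_compiler=print_compiler, bw_compiler=print_compiler)

q = torch.randn(4, 2, 128, 8, requires_grad=True)
k = torch.randn(4, 2, 128, 8, requires_grad=True)
v = torch.randn(4, 2, 128, 8, requires_grad=True)
out = compiled(q, k, v, str(('fsdp', 'tensor', None, None)))
out.sum().backward()  # also exercises the compiled backward graph
```

The compiler callback returns `make_boxed_func(gm)` rather than `gm` itself because AOTAutograd calls compiled graphs with a single list of arguments (the boxed convention); strings are among the argument types AOTAutograd accepts, which is why the test proposes serializing `partition_spec` and `mesh` before tracing.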