Dynamo/AOTAutograd traceable flash attention #8654

Merged
merged 4 commits on Feb 1, 2025
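This PR makes the Pallas flash attention kernel traceable by Dynamo/AOTAutograd. As a minimal sketch (shapes illustrative; imports, argument order, and the boxed-compiler pattern taken from the new tests in test_pallas.py below), the traceable kernel is driven through AOTAutograd like this:

    import torch
    import torch_xla.core.xla_model as xm
    from functorch.compile import aot_function, make_boxed_func
    from torch_xla.experimental.custom_kernel import flash_attention

    def compiler(gm, _):
        # Identity compiler: AOTAutograd hands us the traced forward graph,
        # which only needs to be wrapped for the boxed calling convention.
        return make_boxed_func(gm)

    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    q.retain_grad()
    k.retain_grad()
    v.retain_grad()

    # Positional arguments follow the order used by the new tests:
    # causal, q_segment_ids, kv_segment_ids, sm_scale.
    compiled_flash_attention = aot_function(flash_attention, fw_compiler=compiler)
    o = compiled_flash_attention(q, k, v, False, None, None, 1.0)
    o.sum().backward()
    xm.mark_step()  # materialize the lazy XLA graph
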
184 changes: 161 additions & 23 deletions test/test_pallas.py
@@ -1,5 +1,7 @@
import logging
import sys
import unittest
from absl.testing import parameterized

import torch
from torch import nn as nn
@@ -19,7 +21,20 @@
from jax.experimental import pallas as pl


class PallasTest(unittest.TestCase):
def with_jax_high_precision(func):
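# Test decorator: runs the wrapped test with JAX matmul precision forced to
# "highest", then restores the "default" setting afterwards so the precision
# bump does not leak into other tests.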

def wrapper(*args, **kwargs):
jax.config.update('jax_default_matmul_precision', "highest")
try:
result = func(*args, **kwargs)
finally:
jax.config.update('jax_default_matmul_precision', "default")
return result

return wrapper


class PallasTest(parameterized.TestCase):

# This is to create a diagonal mask where only elements within the same segment
# can attend to each other. Since the mask is to mask out the unrelevant parts,
@@ -33,12 +48,11 @@ def _make_attention_mask_from_segment_ids(self, q_segment_ids,

def _attention(self, q, k, v, *, attn_mask=None, ab=None):
attn_weight = q @ k.transpose(-2, -1)
if attn_mask is not None:
# Masked out the unrelevant parts.
attn_weight = attn_weight.masked_fill(attn_mask,
torch.finfo(attn_weight.dtype).min)
if ab is not None:
attn_weight = attn_weight + ab
if attn_mask is not None:
attn_weight = attn_weight.masked_fill(attn_mask.bool(),
torch.finfo(attn_weight.dtype).min)
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output
@@ -216,8 +230,8 @@ def test_tpu_custom_call_pallas_wrap_flash_attention(self):

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
@@ -227,12 +241,11 @@ def test_flash_attention_wrapper(self):
o = flash_attention(q, k, v)
expected_o = self._attention(q, k, v)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper_with_dynamo(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

def flash_attention_wrapper(q, k, v, causal=False):
@@ -253,12 +266,11 @@ def flash_attention_wrapper(q, k, v, causal=False):
# therefore it speeds up the compute but also changes the output.
self.assertFalse(
torch.allclose(o_with_causal.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper_causal(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
@@ -270,7 +282,6 @@ def test_flash_attention_wrapper_causal(self):
o = flash_attention(q, k, v, causal=True)
expected_o = self._attention(q, k, v)
self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_multiple_returns(self):
@@ -450,8 +461,8 @@ def test__flash_attention_bwd_dkv(self):

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_backward(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

torch.manual_seed(42)
@@ -486,7 +497,6 @@ def test_flash_attention_backward(self):

for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
@@ -1026,8 +1036,8 @@ def test_flash_attention_wrapper_segment_ids_1(self):

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper_segment_ids_2(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
@@ -1093,12 +1103,11 @@ def test_flash_attention_backward_segment_ids(self):

for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper_sm_scale(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
@@ -1109,12 +1118,11 @@ def test_flash_attention_wrapper_sm_scale(self):

expected_o = self._attention(q * sm_scale, k, v)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_sm_scale_backward(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

torch.manual_seed(42)
@@ -1151,12 +1159,11 @@ def test_flash_attention_sm_scale_backward(self):
# Hmm, the gradients are the same even the autograd graph seems different.
for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_ab(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
@@ -1208,12 +1215,11 @@ def test_flash_attention_ab_backward_1(self):

for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_ab_backward_2(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

torch.manual_seed(42)
@@ -1251,7 +1257,139 @@ def test_flash_attention_ab_backward_2(self):

for i in [(q, q_grad), (k, k_grad), (v, v_grad), (ab, ab_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@parameterized.named_parameters(('off', False), ('on', True))
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
@with_jax_high_precision
def test_flash_attention_forward_aot_autograd_traceable_causal(self, causal):
from functorch.compile import aot_function, make_boxed_func
from torch_xla.experimental.custom_kernel import flash_attention
import torch_xla.core.xla_model as xm

def compiler(gm, _):
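# fw_compiler must return a callable that follows AOTAutograd's boxed
# (list-of-args) calling convention; make_boxed_func wraps the traced
# graph module accordingly without changing its behavior.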
return make_boxed_func(gm)

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
q.retain_grad()
k.retain_grad()
v.retain_grad()
B, N, SEQ, H = q.size()
q_segment_ids = None
kv_segment_ids = None
sm_scale = 1.0

compiled_flash_attention = aot_function(
flash_attention, fw_compiler=compiler)
o_actual = compiled_flash_attention(q, k, v, causal, q_segment_ids,
kv_segment_ids, sm_scale)
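# Cut the lazy XLA graph here so the compiled kernel actually executes
# before the comparison below.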
xm.mark_step()
if causal:
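# Upper-triangular mask (diagonal excluded) marks the future positions a
# causal query must not attend to; _attention fills them with the dtype minimum.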
attention_mask = torch.triu(torch.ones(SEQ, SEQ), diagonal=1).to("xla")
else:
attention_mask = None

expected_output = self._attention(q, k, v, attn_mask=attention_mask)
xm.mark_step()
self.assertTrue(
torch.allclose(o_actual.cpu(), expected_output.cpu(), atol=1e-5))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
@with_jax_high_precision
def test_flash_attention_forward_aot_autograd_traceable_ab(self):
from functorch.compile import aot_function, make_boxed_func
from torch_xla.experimental.custom_kernel import flash_attention
import torch_xla.core.xla_model as xm

def compiler(gm, _):
return make_boxed_func(gm)

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8).to("xla")
k = torch.randn(4, 2, 128, 8).to("xla")
v = torch.randn(4, 2, 128, 8).to("xla")
B, N, SEQ, H = q.size()
causal = False
q_segment_ids = None
kv_segment_ids = None
sm_scale = 1.0
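# Build a random additive-bias tensor: roughly half of the logits get the
# dtype minimum, which effectively masks them out after the softmax.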
mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla")
ab = torch.ones(4, 2, 128, 128).to("xla")
ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min)

compiled_flash_attention = aot_function(
flash_attention, fw_compiler=compiler)
o_actual = compiled_flash_attention(
q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab)
xm.mark_step()

expected_output = self._attention(q, k, v, ab=ab)
xm.mark_step()
self.assertTrue(
torch.allclose(o_actual.cpu(), expected_output.cpu(), atol=1e-5))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
@with_jax_high_precision
def test_flash_attention_backward_aot_autograd_traceable(self):
from functorch.compile import aot_function, make_boxed_func
from torch_xla.experimental.custom_kernel import flash_attention
import torch_xla.core.xla_model as xm

def compiler(gm, _):
return make_boxed_func(gm)

torch.manual_seed(42)
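# Compiled pass: run flash_attention through AOTAutograd and collect
# gradients on q, k, v, and ab.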
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
q.retain_grad()
k.retain_grad()
v.retain_grad()
B, N, SEQ, H = q.size()
mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla")
ab = torch.ones(4, 2, 128, 128).to("xla")
ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min).requires_grad_()
ab.retain_grad()

causal = False
q_segment_ids = None
kv_segment_ids = None
sm_scale = 1.0
compiled_flash_attention = aot_function(
flash_attention, fw_compiler=compiler)
o_actual = compiled_flash_attention(
q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab)
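# Reduce to a scalar so backward() produces gradients for q, k, v, and ab
# through the compiled kernel.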
loss = o_actual.sum()
loss.backward()
xm.mark_step()
q_grad = q.grad
k_grad = k.grad
v_grad = v.grad
ab_grad = ab.grad

torch.manual_seed(42)
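# Reference pass: regenerate the same inputs (same seed, same mask) and
# compute gradients with the plain PyTorch attention for comparison.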
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
q.retain_grad()
k.retain_grad()
v.retain_grad()
ab = torch.ones(4, 2, 128, 128).to("xla")
ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min).requires_grad_()
ab.retain_grad()

o = self._attention(q, k, v, ab=ab)
loss = o.sum()
loss.backward()
xm.mark_step()

for i in [(q, q_grad), (k, k_grad), (v, v_grad), (ab, ab_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-02))


if __name__ == '__main__':