Skip to content

Commit 06418c2

Browse files
committed
Revert "Remove torch.compile related tests"
This reverts commit 32ed32f.
1 parent e5bf6ee commit 06418c2

File tree

3 files changed

+131
-0
lines changed

3 files changed

+131
-0
lines changed

bench/bench_qk_int8_pv_fp16_cuda.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from torch.testing._internal.optests import fake_check
23
from flash_attn.utils.benchmark import benchmark_forward
34

45
import sageattention._qattn_sm80 as qattn

script.sh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#!/bin/bash
# Build the sageattention extension and run the SM89-marked test suite.
# Fix: the original shebang was "#!bin/bash" (missing leading slash), which
# is not a valid interpreter path.

# Abort immediately if any command fails.
set -e

(
  # pdbp provides a friendlier breakpoint() if a test drops into the debugger.
  export PYTHONBREAKPOINT="pdbp.set_trace"
  python setup.py install
  (
    # Run from tests/ so pytest picks up the local test files;
    # -x stops on first failure, -k SM89 selects only SM89 tests.
    cd tests
    python -m pytest --tb=short -rs -sv -x -k SM89
  )
)

tests/test_torch_compile.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import pytest
2+
import torch
3+
from torch.testing._internal.optests import fake_check
4+
5+
from sageattention.core import (
6+
SM80_ENABLED,
7+
SM89_ENABLED,
8+
SM90_ENABLED,
9+
sageattn_qk_int8_pv_fp16_cuda,
10+
sageattn_qk_int8_pv_fp8_cuda,
11+
sageattn_qk_int8_pv_fp8_cuda_sm90,
12+
)
13+
14+
def run_fake_check(fn):
    """Wrap *fn* so that calling the wrapper runs torch's ``fake_check`` on it.

    ``fake_check(fn, args, kwargs)`` validates that the fake/meta registration
    of *fn* agrees with its real implementation; any return value of *fn* is
    discarded by the wrapper.

    Fix: apply ``functools.wraps`` so the wrapper keeps *fn*'s name and
    docstring — without it every wrapped kernel reports as ``wrapper`` in
    tracebacks and test output.
    """
    from functools import wraps  # local import keeps this change self-contained

    @wraps(fn)
    def wrapper(*args, **kwargs):
        fake_check(fn, args, kwargs)
    return wrapper
18+
19+
20+
@pytest.mark.skipif(not SM80_ENABLED, reason="SM80 not enabled")
class TestSM80:
    """fake_check coverage for the SM80 qk-int8 / pv-fp16 CUDA kernel."""

    def get_kernel(self, pv_accum_dtype):
        # SM80 exposes a single kernel regardless of accumulator dtype.
        return sageattn_qk_int8_pv_fp16_cuda

    @pytest.mark.parametrize("is_causal", (False, True))
    @pytest.mark.parametrize("seq_len", (64, 128))
    @pytest.mark.parametrize("head", (32,))
    @pytest.mark.parametrize("batch", (4,))
    @pytest.mark.parametrize("headdim", (32, 64))
    @pytest.mark.parametrize("quant_gran", ("per_warp", "per_thread"))
    @pytest.mark.parametrize("pv_accum_dtype", ("fp16", "fp16+fp32", "fp32"))
    @pytest.mark.parametrize("tensor_layout", ("NHD", "HND"))
    @pytest.mark.parametrize("smooth_k", (False, True))
    @pytest.mark.parametrize("smooth_v", (False, True))
    @pytest.mark.parametrize("return_lse", (False, True))
    @pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16))
    def test_SM80(self, is_causal, seq_len, head, batch, headdim, quant_gran, pv_accum_dtype, tensor_layout, smooth_k, smooth_v, return_lse, dtype):
        # q/k hold integer-valued activations stored in a float dtype;
        # v is ordinary gaussian noise.
        shape = (batch, seq_len, head, headdim)
        q = torch.randint(-95, 95, shape, dtype=dtype, device="cuda")
        k = torch.randint(-95, 95, shape, dtype=dtype, device="cuda")

        v = torch.randn(*shape, dtype=dtype, device="cuda")
        sm_scale = 1 / (headdim ** 0.5)

        checked = run_fake_check(self.get_kernel(pv_accum_dtype))
        checked(q, k, v, tensor_layout, is_causal, quant_gran,
                sm_scale, pv_accum_dtype, smooth_k, smooth_v,
                return_lse)
48+
49+
50+
@pytest.mark.skipif(not SM89_ENABLED, reason="SM89 not enabled")
class TestSM89:
    """fake_check coverage for the SM89 qk-int8 / pv-fp8 CUDA kernel."""

    def get_kernel(self):
        return sageattn_qk_int8_pv_fp8_cuda

    @pytest.mark.parametrize("is_causal", (False, True))
    @pytest.mark.parametrize("seq_len", (64, 128))
    @pytest.mark.parametrize("head", (32,))
    @pytest.mark.parametrize("batch", (4,))
    @pytest.mark.parametrize("headdim", (32, 64))
    @pytest.mark.parametrize("quant_gran", ("per_warp", "per_thread"))
    @pytest.mark.parametrize("pv_accum_dtype", ("fp32+fp32", "fp32+fp16", "fp32"))
    @pytest.mark.parametrize("tensor_layout", ("NHD", "HND"))
    @pytest.mark.parametrize("smooth_k", (False, True))
    @pytest.mark.parametrize("smooth_v", (False, True))
    @pytest.mark.parametrize("return_lse", (False, True))
    @pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16))
    def test_kernels(self, is_causal, seq_len, head, batch, headdim, quant_gran, pv_accum_dtype, tensor_layout, smooth_k, smooth_v, return_lse, dtype):
        kernel = self.get_kernel()

        # Layout decides the tensor shape: HND is (B, H, S, D), NHD is (B, S, H, D).
        if tensor_layout == "HND":
            shape = (batch, head, seq_len, headdim)
        else:  # NHD
            shape = (batch, seq_len, head, headdim)

        q = torch.randint(-128, 127, shape, dtype=dtype, device="cuda")
        k = torch.randint(-128, 127, shape, dtype=dtype, device="cuda")
        v = torch.randn(*shape, dtype=dtype, device="cuda")

        sm_scale = 1.0 / (headdim ** 0.5)

        run_fake_check(kernel)(q, k, v, tensor_layout, is_causal, quant_gran,
                               sm_scale, pv_accum_dtype, smooth_k, smooth_v,
                               return_lse)
86+
87+
88+
@pytest.mark.skipif(not SM90_ENABLED, reason="SM90 not enabled")
class TestSM90:
    """fake_check coverage for the SM90 qk-int8 / pv-fp8 CUDA kernel."""

    def get_kernel(self):
        return sageattn_qk_int8_pv_fp8_cuda_sm90

    @pytest.mark.parametrize("is_causal", (False, True))
    @pytest.mark.parametrize("seq_len", (64, 128))
    @pytest.mark.parametrize("head", (32,))
    @pytest.mark.parametrize("batch", (4,))
    @pytest.mark.parametrize("headdim", (32, 64))
    @pytest.mark.parametrize("quant_gran", ("per_warp", "per_thread"))
    @pytest.mark.parametrize("pv_accum_dtype", ("fp32+fp32",))
    @pytest.mark.parametrize("tensor_layout", ("NHD", "HND"))
    @pytest.mark.parametrize("smooth_k", (False, True))
    @pytest.mark.parametrize("return_lse", (False, True))
    @pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16))
    def test_kernels(self, is_causal, seq_len, head, batch, headdim, quant_gran, pv_accum_dtype, tensor_layout, smooth_k, return_lse, dtype):
        kernel = self.get_kernel()

        # NOTE(review): q/k are built NHD-shaped even when tensor_layout == "HND",
        # unlike TestSM89 where all three tensors follow the layout — confirm
        # this asymmetry is intentional.
        qk_shape = (batch, seq_len, head, headdim)
        q = torch.randint(-128, 127, qk_shape, dtype=dtype, device="cuda")
        k = torch.randint(-128, 127, qk_shape, dtype=dtype, device="cuda")

        if tensor_layout == "HND":
            v_shape = (batch, head, seq_len, headdim)
        else:  # NHD
            v_shape = (batch, seq_len, head, headdim)
        v = torch.randn(*v_shape, dtype=dtype, device="cuda")

        sm_scale = 1.0 / (headdim ** 0.5)

        run_fake_check(kernel)(q, k, v, tensor_layout, is_causal, quant_gran,
                               sm_scale, pv_accum_dtype, smooth_k, return_lse)

0 commit comments

Comments
 (0)