|
| 1 | +import pytest |
| 2 | +import torch |
| 3 | +from torch.testing._internal.optests import fake_check |
| 4 | + |
| 5 | +from sageattention.core import ( |
| 6 | + SM80_ENABLED, |
| 7 | + SM89_ENABLED, |
| 8 | + SM90_ENABLED, |
| 9 | + sageattn_qk_int8_pv_fp16_cuda, |
| 10 | + sageattn_qk_int8_pv_fp8_cuda, |
| 11 | + sageattn_qk_int8_pv_fp8_cuda_sm90, |
| 12 | +) |
| 13 | + |
| 14 | +def run_fake_check(fn): |
| 15 | + def wrapper(*args, **kwargs): |
| 16 | + fake_check(fn, args, kwargs) |
| 17 | + return wrapper |
| 18 | + |
| 19 | + |
@pytest.mark.skipif(not SM80_ENABLED, reason="SM80 not enabled")
class TestSM80:
    """Fake-tensor checks for the SM80 (Ampere) INT8-QK / FP16-PV kernel."""

    def get_kernel(self, pv_accum_dtype):
        # pv_accum_dtype is accepted for signature parity with the other
        # test classes; SM80 exposes a single kernel entry point.
        return sageattn_qk_int8_pv_fp16_cuda

    @pytest.mark.parametrize("is_causal", (False, True))
    @pytest.mark.parametrize("seq_len", (64, 128))
    @pytest.mark.parametrize("head", (32,))
    @pytest.mark.parametrize("batch", (4,))
    @pytest.mark.parametrize("headdim", (32, 64))
    @pytest.mark.parametrize("quant_gran", ("per_warp", "per_thread"))
    @pytest.mark.parametrize("pv_accum_dtype", ("fp16", "fp16+fp32", "fp32"))
    @pytest.mark.parametrize("tensor_layout", ("NHD", "HND"))
    @pytest.mark.parametrize("smooth_k", (False, True))
    @pytest.mark.parametrize("smooth_v", (False, True))
    @pytest.mark.parametrize("return_lse", (False, True))
    @pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16))
    def test_SM80(self, is_causal, seq_len, head, batch, headdim, quant_gran, pv_accum_dtype, tensor_layout, smooth_k, smooth_v, return_lse, dtype):
        # Build q/k/v in the layout the kernel is told to expect.  The
        # previous version always used NHD shapes, which mismatched the
        # kernel's expectation whenever tensor_layout == "HND".
        if tensor_layout == "HND":
            shape = (batch, head, seq_len, headdim)
        else:  # NHD
            shape = (batch, seq_len, head, headdim)

        # q/k are quantization-style integer payloads stored in a float
        # dtype; range kept at [-95, 95) as in the original test.
        q = torch.randint(-95, 95, shape, dtype=dtype, device="cuda")
        k = torch.randint(-95, 95, shape, dtype=dtype, device="cuda")
        v = torch.randn(shape, dtype=dtype, device="cuda")

        # Standard attention scaling factor 1/sqrt(d).
        sm_scale = 1 / (headdim ** 0.5)

        kernel = self.get_kernel(pv_accum_dtype)
        run_fake_check(kernel)(q, k, v, tensor_layout, is_causal, quant_gran,
                               sm_scale, pv_accum_dtype, smooth_k, smooth_v,
                               return_lse)
| 48 | + |
| 49 | + |
@pytest.mark.skipif(not SM89_ENABLED, reason="SM89 not enabled")
class TestSM89:
    """Fake-tensor checks for the SM89 (Ada) INT8-QK / FP8-PV kernel."""

    def get_kernel(self):
        return sageattn_qk_int8_pv_fp8_cuda

    @pytest.mark.parametrize("is_causal", (False, True))
    @pytest.mark.parametrize("seq_len", (64, 128))
    @pytest.mark.parametrize("head", (32,))
    @pytest.mark.parametrize("batch", (4,))
    @pytest.mark.parametrize("headdim", (32, 64))
    @pytest.mark.parametrize("quant_gran", ("per_warp", "per_thread"))
    @pytest.mark.parametrize("pv_accum_dtype", ("fp32+fp32", "fp32+fp16", "fp32"))
    @pytest.mark.parametrize("tensor_layout", ("NHD", "HND"))
    @pytest.mark.parametrize("smooth_k", (False, True))
    @pytest.mark.parametrize("smooth_v", (False, True))
    @pytest.mark.parametrize("return_lse", (False, True))
    @pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16))
    def test_kernels(self, is_causal, seq_len, head, batch, headdim, quant_gran, pv_accum_dtype, tensor_layout, smooth_k, smooth_v, return_lse, dtype):
        kernel = self.get_kernel()

        # Pick the tensor shape matching the requested memory layout.
        if tensor_layout == "HND":
            shape = (batch, head, seq_len, headdim)
        else:  # NHD
            shape = (batch, seq_len, head, headdim)

        # q/k carry full int8-range integer payloads in a float dtype.
        q = torch.randint(-128, 127, shape, dtype=dtype, device="cuda")
        k = torch.randint(-128, 127, shape, dtype=dtype, device="cuda")
        v = torch.randn(shape, dtype=dtype, device="cuda")

        # Standard attention scaling factor 1/sqrt(d).
        sm_scale = 1.0 / (headdim ** 0.5)

        run_fake_check(kernel)(q, k, v, tensor_layout, is_causal, quant_gran,
                               sm_scale, pv_accum_dtype, smooth_k, smooth_v,
                               return_lse)
| 86 | + |
| 87 | + |
@pytest.mark.skipif(not SM90_ENABLED, reason="SM90 not enabled")
class TestSM90:
    """Fake-tensor checks for the SM90 (Hopper) INT8-QK / FP8-PV kernel.

    Note: the SM90 kernel takes no ``smooth_v`` argument, unlike the
    SM80/SM89 entry points.
    """

    def get_kernel(self):
        return sageattn_qk_int8_pv_fp8_cuda_sm90

    @pytest.mark.parametrize("is_causal", (False, True))
    @pytest.mark.parametrize("seq_len", (64, 128))
    @pytest.mark.parametrize("head", (32,))
    @pytest.mark.parametrize("batch", (4,))
    @pytest.mark.parametrize("headdim", (32, 64))
    @pytest.mark.parametrize("quant_gran", ("per_warp", "per_thread"))
    @pytest.mark.parametrize("pv_accum_dtype", ("fp32+fp32",))
    @pytest.mark.parametrize("tensor_layout", ("NHD", "HND"))
    @pytest.mark.parametrize("smooth_k", (False, True))
    @pytest.mark.parametrize("return_lse", (False, True))
    @pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16))
    def test_kernels(self, is_causal, seq_len, head, batch, headdim, quant_gran, pv_accum_dtype, tensor_layout, smooth_k, return_lse, dtype):
        kernel = self.get_kernel()

        # Build ALL of q/k/v in the layout the kernel is told to expect.
        # The previous version only made v layout-aware; q and k were
        # always NHD, mismatching the kernel under "HND" parametrization.
        if tensor_layout == "HND":
            shape = (batch, head, seq_len, headdim)
        else:  # NHD
            shape = (batch, seq_len, head, headdim)

        q = torch.randint(-128, 127, shape, dtype=dtype, device="cuda")
        k = torch.randint(-128, 127, shape, dtype=dtype, device="cuda")
        v = torch.randn(shape, dtype=dtype, device="cuda")

        # Standard attention scaling factor 1/sqrt(d).
        sm_scale = 1.0 / (headdim ** 0.5)

        run_fake_check(kernel)(q, k, v, tensor_layout, is_causal, quant_gran,
                               sm_scale, pv_accum_dtype, smooth_k, return_lse)
0 commit comments