Commit e5bf6ee

Revert "Merge pull request #218 from guilhermeleobas/guilhermeleobas/torch-compile"

This reverts commit 15c0e22, reversing changes made to 68de379.

Parent: 2fff52e

File tree: 5 files changed (+21, -469)

sageattention/core.py (20 additions, 23 deletions)
@@ -27,19 +27,19 @@
 from .triton.quant_per_thread import per_thread_int8 as per_thread_int8_triton
 
 try:
-    from . import sm80_compile
+    from . import _qattn_sm80
     SM80_ENABLED = True
 except:
     SM80_ENABLED = False
 
 try:
-    from . import sm89_compile
+    from . import _qattn_sm89
     SM89_ENABLED = True
 except:
     SM89_ENABLED = False
 
 try:
-    from . import sm90_compile
+    from . import _qattn_sm90
     SM90_ENABLED = True
 except:
     SM90_ENABLED = False
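
This hunk restores the guarded imports of the prebuilt extension modules (_qattn_sm80, _qattn_sm89, _qattn_sm90) in place of the per-arch *_compile wrappers added by PR #218. As a minimal sketch of the same optional-extension pattern (module path assumed, and with a narrower exception clause than the bare except above):

    import importlib

    # Try to load an optional compiled extension and record whether it is usable.
    try:
        _qattn_sm80 = importlib.import_module("sageattention._qattn_sm80")
        SM80_ENABLED = True
    except ImportError:
        _qattn_sm80 = None
        SM80_ENABLED = False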
@@ -52,10 +52,9 @@
 from typing import Any, List, Literal, Optional, Tuple, Union
 import warnings
 
+
 import subprocess
 import re
-
-
 def get_cuda_version():
     try:
         output = subprocess.check_output(['nvcc', '--version']).decode()
@@ -67,15 +66,13 @@ def get_cuda_version():
         print("Failed to get CUDA version:", e)
         return None, None
 
-
 def get_cuda_arch_versions():
     cuda_archs = []
     for i in range(torch.cuda.device_count()):
         major, minor = torch.cuda.get_device_capability(i)
         cuda_archs.append(f"sm{major}{minor}")
     return cuda_archs
 
-
 def sageattn(
     q: torch.Tensor,
     k: torch.Tensor,
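
get_cuda_version() shells out to nvcc and parses the release string; the body of the parsing is elided from the hunk above. A hedged sketch of how that parsing typically looks (the regex is an assumption, not the file's actual code):

    import re
    import subprocess

    def get_cuda_version_sketch():
        # Assumed parsing: nvcc prints a line like "... release 12.4, V12.4.131".
        try:
            output = subprocess.check_output(['nvcc', '--version']).decode()
            match = re.search(r'release (\d+)\.(\d+)', output)
            if match:
                return int(match.group(1)), int(match.group(2))
        except Exception as e:
            print("Failed to get CUDA version:", e)
        return None, None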
@@ -154,7 +151,7 @@ def sageattn(
     else:
         raise ValueError(f"Unsupported CUDA architecture: {arch}")
 
-
+@torch.compiler.disable
 def sageattn_qk_int8_pv_fp16_triton(
     q: torch.Tensor,
     k: torch.Tensor,
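
This hunk (and the four analogous ones below, for sageattn_varlen, sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp8_cuda, and sageattn_qk_int8_pv_fp8_cuda_sm90) restores @torch.compiler.disable on each public entry point. The decorator marks a function as opaque to TorchDynamo: a caller compiled with torch.compile graph-breaks around it and runs it eagerly instead of tracing into the custom kernels. A minimal self-contained illustration:

    import torch

    @torch.compiler.disable
    def eager_region(x):
        # Dynamo never traces into this function; it always runs eagerly.
        return torch.sin(x)

    @torch.compile
    def f(x):
        # Compilation graph-breaks around eager_region instead of tracing it.
        return eager_region(x) + 1

    print(f(torch.randn(4)))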
@@ -328,7 +325,7 @@ def sageattn_qk_int8_pv_fp16_triton(
     else:
         return o
 
-
+@torch.compiler.disable
 def sageattn_varlen(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -445,7 +442,7 @@ def sageattn_varlen(
 
     return o
 
-
+@torch.compiler.disable
 def sageattn_qk_int8_pv_fp16_cuda(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -609,17 +606,17 @@ def sageattn_qk_int8_pv_fp16_cuda(
 
     if pv_accum_dtype == 'fp32':
         v = v.to(torch.float16)
-        lse = sm80_compile.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+        lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
     elif pv_accum_dtype == "fp16":
         if smooth_v:
             smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
-            lse = sm80_compile.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+            lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
         else:
             v = v.to(torch.float16)
-            lse = sm80_compile.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+            lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
     elif pv_accum_dtype == "fp16+fp32":
         v = v.to(torch.float16)
-        lse = sm80_compile.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+        lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
     else:
         raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")
 
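Each pv_accum_dtype selects a different accumulator for the PV product on SM80: 'fp32' accumulates in FP32, 'fp16' accumulates in FP16 (optionally subtracting V's mean first when smooth_v is set), and 'fp16+fp32' uses an FP16 accumulator with an FP32 instruction buffer. An illustrative call (keyword names follow the variables visible in the hunk; is_causal and the HND layout default are assumptions):

    import torch
    from sageattention.core import sageattn_qk_int8_pv_fp16_cuda

    # "HND" layout assumed: (batch, heads, seq_len, head_dim).
    q = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
    o = sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout="HND",
                                      is_causal=True, pv_accum_dtype="fp32")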
@@ -630,7 +627,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
     else:
         return o
 
-
+@torch.compiler.disable
 def sageattn_qk_int8_pv_fp8_cuda(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -808,13 +805,13 @@ def sageattn_qk_int8_pv_fp8_cuda(
 
     if pv_accum_dtype == "fp32":
         if smooth_v:
-            lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+            lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
         else:
-            lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+            lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
     elif pv_accum_dtype == "fp32+fp32":
-        lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+        lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
     elif pv_accum_dtype == "fp32+fp16":
-        lse = sm89_compile.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+        lse = _qattn_sm89.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
 
     o = o[..., :head_dim_og]
 
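The trailing o = o[..., :head_dim_og] slice undoes head-dimension padding: the wrapper pads head_dim up to a size the kernel supports and slices the output back afterwards. A minimal sketch of that pad-then-slice pattern (the multiple of 64 is an assumption):

    import torch
    import torch.nn.functional as F

    def pad_head_dim(x: torch.Tensor, multiple: int = 64):
        # Pad the last (head) dimension up to a kernel-supported multiple,
        # keeping the original size so the output can be sliced back.
        head_dim_og = x.size(-1)
        pad = (-head_dim_og) % multiple
        return F.pad(x, (0, pad)), head_dim_og

    x = torch.randn(1, 8, 16, 96)
    x_padded, head_dim_og = pad_head_dim(x)   # head dim 96 -> 128
    o = x_padded                              # stand-in for the kernel output
    o = o[..., :head_dim_og]                  # slice back to 96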
@@ -823,7 +820,7 @@ def sageattn_qk_int8_pv_fp8_cuda(
     else:
         return o
 
-
+@torch.compiler.disable
 def sageattn_qk_int8_pv_fp8_cuda_sm90(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -982,13 +979,13 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
 
     if pv_accum_dtype == "fp32":
         raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
-        lse = sm90_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+        lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
     elif pv_accum_dtype == "fp32+fp32":
-        lse = sm90_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+        lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
 
     o = o[..., :head_dim_og]
 
     if return_lse:
         return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
     else:
-        return o
+        return o
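
In the return path, the constant 1.44269504 is log2(e). The kernels evidently maintain the log-sum-exp in base 2 (exp2/log2 are cheaper on GPUs), so dividing by log2(e) converts it to the natural-log LSE; under smooth_k the per-query correction times sm_scale is added back. A quick numeric check of the base conversion:

    import torch

    s = torch.randn(8)
    lse_natural = torch.logsumexp(s, dim=0)
    # A base-2 kernel scales scores by log2(e) and uses exp2/log2 throughout:
    lse_base2 = torch.log2(torch.exp2(s * 1.44269504).sum(dim=0))
    # Dividing by log2(e) recovers the natural-log value.
    assert torch.allclose(lse_base2 / 1.44269504, lse_natural)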

sageattention/sm80_compile.py (0 additions, 173 deletions)

This file was deleted.
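
The deleted sm80_compile.py (along with its sm89/sm90 siblings) held the torch.compile-facing wrappers that PR #218 introduced; their contents are not shown in this view. Purely as a hedged illustration of what such a wrapper module typically provides, here is a toy custom op using torch.library.custom_op (PyTorch 2.4+; the namespace and logic are hypothetical, not the deleted code):

    import torch

    @torch.library.custom_op("sageattn_demo::toy_attn", mutates_args=())
    def toy_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Stand-in for a call into an opaque compiled extension such as _qattn_sm80.
        return torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v

    @toy_attn.register_fake
    def _(q, k, v):
        # Shape/dtype inference so torch.compile can trace without running the kernel.
        return torch.empty_like(q)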
