# Triton per-thread INT8 quantization kernel (pure-Triton, always importable).
from .triton.quant_per_thread import per_thread_int8 as per_thread_int8_triton

# The compiled CUDA extensions below are optional: each is built only for the
# GPU architectures selected at install time.  A missing module merely
# disables the corresponding kernel path; it must not break package import.
# NOTE: narrowed from a bare `except:` — a bare except also swallows
# KeyboardInterrupt/SystemExit, which should never be masked at import time.
try:
    from . import _qattn_sm80  # kernels dispatched by the *_fp16_cuda path
    SM80_ENABLED = True
except Exception:
    SM80_ENABLED = False

try:
    from . import _qattn_sm89  # kernels dispatched by the *_fp8_cuda path
    SM89_ENABLED = True
except Exception:
    SM89_ENABLED = False

try:
    from . import _qattn_sm90  # kernels dispatched by the *_fp8_cuda_sm90 path
    SM90_ENABLED = True
except Exception:
    SM90_ENABLED = False
5252from typing import Any , List , Literal , Optional , Tuple , Union
5353import warnings
5454
55+
5556import subprocess
5657import re
57-
58-
5958def get_cuda_version ():
6059 try :
6160 output = subprocess .check_output (['nvcc' , '--version' ]).decode ()
@@ -67,15 +66,13 @@ def get_cuda_version():
6766 print ("Failed to get CUDA version:" , e )
6867 return None , None
6968
70-
def get_cuda_arch_versions():
    """Return the compute-capability tag of every visible CUDA device.

    Each entry is the string ``"sm<major><minor>"`` (e.g. ``"sm80"``),
    ordered by device index; the list is empty when no CUDA device is visible.
    """
    return [
        "sm{}{}".format(*torch.cuda.get_device_capability(device))
        for device in range(torch.cuda.device_count())
    ]
7775
78-
7976def sageattn (
8077 q : torch .Tensor ,
8178 k : torch .Tensor ,
@@ -154,7 +151,7 @@ def sageattn(
154151 else :
155152 raise ValueError (f"Unsupported CUDA architecture: { arch } " )
156153
157-
154+ @ torch . compiler . disable
158155def sageattn_qk_int8_pv_fp16_triton (
159156 q : torch .Tensor ,
160157 k : torch .Tensor ,
@@ -328,7 +325,7 @@ def sageattn_qk_int8_pv_fp16_triton(
328325 else :
329326 return o
330327
331-
328+ @ torch . compiler . disable
332329def sageattn_varlen (
333330 q : torch .Tensor ,
334331 k : torch .Tensor ,
@@ -445,7 +442,7 @@ def sageattn_varlen(
445442
446443 return o
447444
448-
445+ @ torch . compiler . disable
449446def sageattn_qk_int8_pv_fp16_cuda (
450447 q : torch .Tensor ,
451448 k : torch .Tensor ,
@@ -609,17 +606,17 @@ def sageattn_qk_int8_pv_fp16_cuda(
609606
610607 if pv_accum_dtype == 'fp32' :
611608 v = v .to (torch .float16 )
612- lse = sm80_compile .qk_int8_sv_f16_accum_f32_attn (q_int8 , k_int8 , v , o , q_scale , k_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
609+ lse = _qattn_sm80 .qk_int8_sv_f16_accum_f32_attn (q_int8 , k_int8 , v , o , q_scale , k_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
613610 elif pv_accum_dtype == "fp16" :
614611 if smooth_v :
615612 smoothed_v , vm = sub_mean (v , tensor_layout = tensor_layout )
616- lse = sm80_compile .qk_int8_sv_f16_accum_f16_fuse_v_mean_attn (q_int8 , k_int8 , smoothed_v , o , q_scale , k_scale , vm , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
613+ lse = _qattn_sm80 .qk_int8_sv_f16_accum_f16_fuse_v_mean_attn (q_int8 , k_int8 , smoothed_v , o , q_scale , k_scale , vm , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
617614 else :
618615 v = v .to (torch .float16 )
619- lse = sm80_compile .qk_int8_sv_f16_accum_f16_attn (q_int8 , k_int8 , v , o , q_scale , k_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
616+ lse = _qattn_sm80 .qk_int8_sv_f16_accum_f16_attn (q_int8 , k_int8 , v , o , q_scale , k_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
620617 elif pv_accum_dtype == "fp16+fp32" :
621618 v = v .to (torch .float16 )
622- lse = sm80_compile .qk_int8_sv_f16_accum_f16_attn_inst_buf (q_int8 , k_int8 , v , o , q_scale , k_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
619+ lse = _qattn_sm80 .qk_int8_sv_f16_accum_f16_attn_inst_buf (q_int8 , k_int8 , v , o , q_scale , k_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
623620 else :
624621 raise ValueError (f"Unsupported pv_accum_dtype: { pv_accum_dtype } " )
625622
@@ -630,7 +627,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
630627 else :
631628 return o
632629
633-
630+ @ torch . compiler . disable
634631def sageattn_qk_int8_pv_fp8_cuda (
635632 q : torch .Tensor ,
636633 k : torch .Tensor ,
@@ -808,13 +805,13 @@ def sageattn_qk_int8_pv_fp8_cuda(
808805
809806 if pv_accum_dtype == "fp32" :
810807 if smooth_v :
811- lse = sm89_compile .qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn (q_int8 , k_int8 , v_fp8 , o , q_scale , k_scale , v_scale , vm , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
808+ lse = _qattn_sm89 .qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn (q_int8 , k_int8 , v_fp8 , o , q_scale , k_scale , v_scale , vm , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
812809 else :
813- lse = sm89_compile .qk_int8_sv_f8_accum_f32_fuse_v_scale_attn (q_int8 , k_int8 , v_fp8 , o , q_scale , k_scale , v_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
810+ lse = _qattn_sm89 .qk_int8_sv_f8_accum_f32_fuse_v_scale_attn (q_int8 , k_int8 , v_fp8 , o , q_scale , k_scale , v_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
814811 elif pv_accum_dtype == "fp32+fp32" :
815- lse = sm89_compile .qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf (q_int8 , k_int8 , v_fp8 , o , q_scale , k_scale , v_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
812+ lse = _qattn_sm89 .qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf (q_int8 , k_int8 , v_fp8 , o , q_scale , k_scale , v_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
816813 elif pv_accum_dtype == "fp32+fp16" :
817- lse = sm89_compile .qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf (q_int8 , k_int8 , v_fp8 , o , q_scale , k_scale , v_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
814+ lse = _qattn_sm89 .qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf (q_int8 , k_int8 , v_fp8 , o , q_scale , k_scale , v_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
818815
819816 o = o [..., :head_dim_og ]
820817
@@ -823,7 +820,7 @@ def sageattn_qk_int8_pv_fp8_cuda(
823820 else :
824821 return o
825822
826-
823+ @ torch . compiler . disable
827824def sageattn_qk_int8_pv_fp8_cuda_sm90 (
828825 q : torch .Tensor ,
829826 k : torch .Tensor ,
@@ -982,13 +979,13 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
982979
983980 if pv_accum_dtype == "fp32" :
984981 raise NotImplementedError ("Please use pv_accum_dtype='fp32+fp32' for sm90." )
985- lse = sm90_compile .qk_int8_sv_f8_accum_f32_fuse_v_scale_attn (q_int8 , k_int8 , v_fp8 , o , q_scale , k_scale , v_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
982+ lse = _qattn_sm90 .qk_int8_sv_f8_accum_f32_fuse_v_scale_attn (q_int8 , k_int8 , v_fp8 , o , q_scale , k_scale , v_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
986983 elif pv_accum_dtype == "fp32+fp32" :
987- lse = sm90_compile .qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf (q_int8 , k_int8 , v_fp8 , o , q_scale , k_scale , v_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
984+ lse = _qattn_sm90 .qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf (q_int8 , k_int8 , v_fp8 , o , q_scale , k_scale , v_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
988985
989986 o = o [..., :head_dim_og ]
990987
991988 if return_lse :
992989 return o , lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
993990 else :
994- return o
991+ return o