@@ -66,15 +66,22 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
     return nvcc_cuda_version
 
-# Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
 compute_capabilities = set()
-device_count = torch.cuda.device_count()
-for i in range(device_count):
-    major, minor = torch.cuda.get_device_capability(i)
-    if major < 8:
-        warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
-        continue
-    compute_capabilities.add(f"{major}.{minor}")
+cuda_architectures = os.environ.get("CUDA_ARCHITECTURES")
+if cuda_architectures is not None:
+    for arch in cuda_architectures.split(","):
+        arch = arch.strip()
+        if arch:
+            compute_capabilities.add(arch)
+else:
+    # Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
+    device_count = torch.cuda.device_count()
+    for i in range(device_count):
+        major, minor = torch.cuda.get_device_capability(i)
+        if major < 8:
+            warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
+            continue
+        compute_capabilities.add(f"{major}.{minor}")
 
 nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
 if not compute_capabilities:
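With this change the target architectures can be pinned through the environment instead of probed from whatever GPUs are visible, for example by exporting `CUDA_ARCHITECTURES=8.0,8.9` before building. Below is a minimal sketch of how the added parsing behaves; the architecture values are illustrative only, not something the project mandates.

```python
import os

# Illustrative only: the same parsing the patch adds, run on an example value.
os.environ["CUDA_ARCHITECTURES"] = "8.0, 8.9,"   # stray spaces and a trailing comma are tolerated

compute_capabilities = set()
cuda_architectures = os.environ.get("CUDA_ARCHITECTURES")
if cuda_architectures is not None:
    for arch in cuda_architectures.split(","):
        arch = arch.strip()   # drop surrounding whitespace
        if arch:              # ignore empty entries such as a trailing comma
            compute_capabilities.add(arch)

print(compute_capabilities)   # {'8.0', '8.9'}
```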
@@ -119,54 +126,96 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
 ext_modules = []
 
 if HAS_SM80 or HAS_SM86 or HAS_SM89 or HAS_SM90 or HAS_SM120:
-    qattn_extension = CUDAExtension(
+    sm80_sources = [
+        "csrc/qattn/pybind_sm80.cpp",
+        "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
+    ]
+
+    qattn_extension_sm80 = CUDAExtension(
         name="sageattention._qattn_sm80",
-        sources=[
-            "csrc/qattn/pybind_sm80.cpp",
-            "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
-        ],
+        sources=sm80_sources,
         extra_compile_args={
             "cxx": CXX_FLAGS,
             "nvcc": NVCC_FLAGS,
         },
     )
-    ext_modules.append(qattn_extension)
+    ext_modules.append(qattn_extension_sm80)
 
 if HAS_SM89 or HAS_SM120:
-    qattn_extension = CUDAExtension(
+    sm89_sources = [
+        "csrc/qattn/pybind_sm89.cpp",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
+        #"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
+    ]
+
+    sm89_nvcc_flags = [flag for flag in NVCC_FLAGS]
+
+    filtered_flags = []
+    skip_next = False
+    for i, flag in enumerate(sm89_nvcc_flags):
+        if skip_next:
+            skip_next = False
+            continue
+        if flag == "-gencode":
+            if i + 1 < len(sm89_nvcc_flags):
+                arch_flag = sm89_nvcc_flags[i + 1]
+                if "compute_89" in arch_flag or "compute_90" in arch_flag or "compute_120" in arch_flag:
+                    filtered_flags.append(flag)
+                    filtered_flags.append(arch_flag)
+                skip_next = True
+        elif flag not in ["-gencode"]:
+            filtered_flags.append(flag)
+
+    qattn_extension_sm89 = CUDAExtension(
         name="sageattention._qattn_sm89",
-        sources=[
-            "csrc/qattn/pybind_sm89.cpp",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
-            #"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
-        ],
+        sources=sm89_sources,
         extra_compile_args={
             "cxx": CXX_FLAGS,
-            "nvcc": NVCC_FLAGS,
+            "nvcc": filtered_flags if filtered_flags else NVCC_FLAGS,
         },
     )
-    ext_modules.append(qattn_extension)
+    ext_modules.append(qattn_extension_sm89)
 
 if HAS_SM90:
-    qattn_extension = CUDAExtension(
+    sm90_sources = [
+        "csrc/qattn/pybind_sm90.cpp",
+        "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
+    ]
+
+    sm90_nvcc_flags = [flag for flag in NVCC_FLAGS]
+
+    filtered_flags = []
+    skip_next = False
+    for i, flag in enumerate(sm90_nvcc_flags):
+        if skip_next:
+            skip_next = False
+            continue
+        if flag == "-gencode":
+            if i + 1 < len(sm90_nvcc_flags):
+                arch_flag = sm90_nvcc_flags[i + 1]
+                if "compute_90" in arch_flag or "compute_120" in arch_flag:
+                    filtered_flags.append(flag)
+                    filtered_flags.append(arch_flag)
+                skip_next = True
+        elif flag not in ["-gencode"]:
+            filtered_flags.append(flag)
+
+    qattn_extension_sm90 = CUDAExtension(
         name="sageattention._qattn_sm90",
-        sources=[
-            "csrc/qattn/pybind_sm90.cpp",
-            "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
-        ],
+        sources=sm90_sources,
         extra_compile_args={
             "cxx": CXX_FLAGS,
-            "nvcc": NVCC_FLAGS,
+            "nvcc": filtered_flags if filtered_flags else NVCC_FLAGS,
         },
         extra_link_args=['-lcuda'],
     )
-    ext_modules.append(qattn_extension)
+    ext_modules.append(qattn_extension_sm90)
 
 # Fused kernels.
 fused_extension = CUDAExtension(
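The per-architecture extensions now receive a pruned set of `-gencode` flags. The loops above assume `NVCC_FLAGS` (assembled earlier in setup.py, outside this hunk) holds standard nvcc pairs of the form `-gencode arch=compute_XX,code=sm_XX`, and they keep only the pairs relevant to the extension being built. A condensed sketch of that behaviour with a hypothetical flag list:

```python
# Hypothetical NVCC_FLAGS for illustration; the real list is assembled earlier in setup.py.
NVCC_FLAGS = [
    "-O3", "--use_fast_math",
    "-gencode", "arch=compute_80,code=sm_80",
    "-gencode", "arch=compute_89,code=sm_89",
    "-gencode", "arch=compute_90a,code=sm_90a",
]

def keep_gencode(flags, wanted):
    """Condensed version of the loops above: keep non-gencode flags,
    and keep a -gencode pair only if its arch string matches `wanted`."""
    kept, skip_next = [], False
    for i, flag in enumerate(flags):
        if skip_next:
            skip_next = False
            continue
        if flag == "-gencode" and i + 1 < len(flags):
            arch_flag = flags[i + 1]
            if any(w in arch_flag for w in wanted):
                kept += [flag, arch_flag]
            skip_next = True
        else:
            kept.append(flag)
    return kept

# What the SM89 extension would compile with: the sm_80 pair is dropped.
print(keep_gencode(NVCC_FLAGS, ("compute_89", "compute_90", "compute_120")))
# ['-O3', '--use_fast_math',
#  '-gencode', 'arch=compute_89,code=sm_89',
#  '-gencode', 'arch=compute_90a,code=sm_90a']
```

The `filtered_flags if filtered_flags else NVCC_FLAGS` fallback keeps the original flags in play if no `-gencode` pair survives the filter.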
@@ -208,15 +257,23 @@ def compile_new(*args, **kwargs):
                     **kwargs,
                     "output_dir": os.path.join(
                         kwargs["output_dir"],
-                        self.thread_ext_name_map[threading.current_thread().ident]),
+                        self.thread_ext_name_map.get(threading.current_thread().ident, "default")),
                 })
             self.compiler.compile = compile_new
             self.compiler._compile_separate_output_dir = True
         self.thread_ext_name_map[threading.current_thread().ident] = ext.name
-        objects = super().build_extension(ext)
+
+        original_build_temp = self.build_temp
+        self.build_temp = os.path.join(original_build_temp, ext.name.replace(".", "_"))
+        os.makedirs(self.build_temp, exist_ok=True)
+
+        try:
+            objects = super().build_extension(ext)
+        finally:
+            self.build_temp = original_build_temp
+
         return objects
 
-
 setup(
     name='sageattention',
     version='2.2.0',
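Two things change in `build_extension` above: the `.get(..., "default")` lookup avoids a `KeyError` when the compiling thread has no entry in `thread_ext_name_map`, and each extension is now compiled into its own subdirectory of `build_temp`, so extensions built in parallel no longer overwrite one another's object files; the `try/finally` restores the original `build_temp` for the next extension. A minimal sketch of the directory derivation, with an example `build_temp` value:

```python
import os

# Example value; setuptools chooses the real build_temp at build time.
build_temp = "build/temp.linux-x86_64-cpython-310"
ext_name = "sageattention._qattn_sm89"

# Same transformation as in the patched build_extension: dots become underscores,
# giving every extension an isolated object-file directory.
per_ext_temp = os.path.join(build_temp, ext_name.replace(".", "_"))
print(per_ext_temp)   # build/temp.linux-x86_64-cpython-310/sageattention__qattn_sm89
```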