@@ -66,22 +66,56 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
6666 nvcc_cuda_version = parse (output [release_idx ].split ("," )[0 ])
6767 return nvcc_cuda_version
6868
69- # Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
def filter_nvcc_flags_for_arch(nvcc_flags, arch_substrings):
    """Filter NVCC flags, keeping only ``-gencode`` pairs for the given architectures.

    Non-gencode flags are passed through unchanged.  A ``-gencode`` flag and
    the architecture spec that follows it are kept only when the spec contains
    one of *arch_substrings*; a trailing ``-gencode`` with no following
    argument is silently dropped (matching the previous behavior).

    Args:
        nvcc_flags: Flat list of NVCC command-line flags, where each
            ``-gencode`` is immediately followed by its arch spec
            (e.g. ``"arch=compute_89,code=sm_89"``).
        arch_substrings: Substrings identifying the architectures to keep.

    Returns:
        A new list containing all non-gencode flags plus only the matching
        ``-gencode`` pairs, in their original order.
    """
    filtered_flags = []
    i = 0
    n = len(nvcc_flags)
    # Walk with an explicit index instead of a skip_next flag: a "-gencode"
    # consumes two slots (the flag and its arch spec), everything else one.
    while i < n:
        flag = nvcc_flags[i]
        if flag == "-gencode":
            if i + 1 < n:
                arch_flag = nvcc_flags[i + 1]
                if any(sub in arch_flag for sub in arch_substrings):
                    filtered_flags.extend((flag, arch_flag))
                i += 2
            else:
                # Dangling "-gencode" at the end of the list: drop it.
                i += 1
        else:
            filtered_flags.append(flag)
            i += 1
    return filtered_flags
87+
7088compute_capabilities = set ()
71- device_count = torch .cuda .device_count ()
72- for i in range (device_count ):
73- major , minor = torch .cuda .get_device_capability (i )
74- if major < 8 :
75- warnings .warn (f"skipping GPU { i } with compute capability { major } .{ minor } " )
76- continue
77- compute_capabilities .add (f"{ major } .{ minor } " )
89+ cuda_architectures = os .environ .get ("CUDA_ARCHITECTURES" )
90+ if cuda_architectures is not None :
91+ for arch in cuda_architectures .split ("," ):
92+ arch = arch .strip ()
93+ if arch :
94+ compute_capabilities .add (arch )
95+ else :
96+    # Iterate over all GPUs on the current machine. You can also modify this section to build for specific GPU architectures.
97+ device_count = torch .cuda .device_count ()
98+ for i in range (device_count ):
99+ major , minor = torch .cuda .get_device_capability (i )
100+ if major < 8 :
101+ warnings .warn (f"skipping GPU { i } with compute capability { major } .{ minor } " )
102+ continue
103+ compute_capabilities .add (f"{ major } .{ minor } " )
78104
79- nvcc_cuda_version = get_nvcc_cuda_version (CUDA_HOME )
80105if not compute_capabilities :
81106 raise RuntimeError ("No GPUs found. Please specify the target GPU architectures or build on a machine with GPUs." )
82107else :
108+
109+ unsupported_archs = compute_capabilities - SUPPORTED_ARCHS
110+ if unsupported_archs :
111+ warnings .warn (f"Unsupported GPU architectures detected: { unsupported_archs } . Supported architectures: { SUPPORTED_ARCHS } " )
112+ compute_capabilities = compute_capabilities & SUPPORTED_ARCHS
113+ if not compute_capabilities :
114+ raise RuntimeError (f"No supported GPU architectures found. Detected: { compute_capabilities | unsupported_archs } , Supported: { SUPPORTED_ARCHS } " )
115+
83116 print (f"Detect GPUs with compute capabilities: { compute_capabilities } " )
84117
118+ nvcc_cuda_version = get_nvcc_cuda_version (CUDA_HOME )
85119# Validate the NVCC CUDA version.
86120if nvcc_cuda_version < Version ("12.0" ):
87121 raise RuntimeError ("CUDA 12.0 or higher is required to build the package." )
@@ -119,54 +153,66 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
119153ext_modules = []
120154
121155if HAS_SM80 or HAS_SM86 or HAS_SM89 or HAS_SM90 or HAS_SM120 :
122- qattn_extension = CUDAExtension (
156+ sm80_sources = [
157+ "csrc/qattn/pybind_sm80.cpp" ,
158+ "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu" ,
159+ ]
160+
161+ qattn_extension_sm80 = CUDAExtension (
123162 name = "sageattention._qattn_sm80" ,
124- sources = [
125- "csrc/qattn/pybind_sm80.cpp" ,
126- "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu" ,
127- ],
163+ sources = sm80_sources ,
128164 extra_compile_args = {
129165 "cxx" : CXX_FLAGS ,
130166 "nvcc" : NVCC_FLAGS ,
131167 },
132168 )
133- ext_modules .append (qattn_extension )
169+ ext_modules .append (qattn_extension_sm80 )
134170
135171if HAS_SM89 or HAS_SM120 :
136- qattn_extension = CUDAExtension (
172+ sm89_sources = [
173+ "csrc/qattn/pybind_sm89.cpp" ,
174+ "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu" ,
175+ "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu" ,
176+ "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu" ,
177+ "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu" ,
178+ "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu" ,
179+ "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu" ,
180+ "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
181+ #"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
182+ ]
183+
184+ arch_substrings = ["sm_89" , "compute_89" , "sm_90a" , "compute_90a" , "sm_120" , "compute_120" ]
185+ filtered_flags = filter_nvcc_flags_for_arch (NVCC_FLAGS , arch_substrings )
186+
187+ qattn_extension_sm89 = CUDAExtension (
137188 name = "sageattention._qattn_sm89" ,
138- sources = [
139- "csrc/qattn/pybind_sm89.cpp" ,
140- "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu" ,
141- "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu" ,
142- "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu" ,
143- "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu" ,
144- "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu" ,
145- "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu" ,
146- "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
147- #"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
148- ],
189+ sources = sm89_sources ,
149190 extra_compile_args = {
150191 "cxx" : CXX_FLAGS ,
151- "nvcc" : NVCC_FLAGS ,
192+ "nvcc" : filtered_flags if filtered_flags else NVCC_FLAGS ,
152193 },
153194 )
154- ext_modules .append (qattn_extension )
195+ ext_modules .append (qattn_extension_sm89 )
155196
156197if HAS_SM90 :
157- qattn_extension = CUDAExtension (
198+ sm90_sources = [
199+ "csrc/qattn/pybind_sm90.cpp" ,
200+ "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu" ,
201+ ]
202+
203+ arch_substrings = ["sm_90a" , "compute_90a" ]
204+ filtered_flags = filter_nvcc_flags_for_arch (NVCC_FLAGS , arch_substrings )
205+
206+ qattn_extension_sm90 = CUDAExtension (
158207 name = "sageattention._qattn_sm90" ,
159- sources = [
160- "csrc/qattn/pybind_sm90.cpp" ,
161- "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu" ,
162- ],
208+ sources = sm90_sources ,
163209 extra_compile_args = {
164210 "cxx" : CXX_FLAGS ,
165- "nvcc" : NVCC_FLAGS ,
211+ "nvcc" : filtered_flags if filtered_flags else NVCC_FLAGS ,
166212 },
167213 extra_link_args = ['-lcuda' ],
168214 )
169- ext_modules .append (qattn_extension )
215+ ext_modules .append (qattn_extension_sm90 )
170216
171217# Fused kernels.
172218fused_extension = CUDAExtension (
@@ -208,24 +254,32 @@ def compile_new(*args, **kwargs):
208254 ** kwargs ,
209255 "output_dir" : os .path .join (
210256 kwargs ["output_dir" ],
211- self .thread_ext_name_map [ threading .current_thread ().ident ] ),
257+ self .thread_ext_name_map . get ( threading .current_thread ().ident , f"thread_ { threading . current_thread (). ident } " ) ),
212258 })
213259 self .compiler .compile = compile_new
214260 self .compiler ._compile_separate_output_dir = True
215261 self .thread_ext_name_map [threading .current_thread ().ident ] = ext .name
216- objects = super ().build_extension (ext )
217- return objects
218262
263+ original_build_temp = self .build_temp
264+ self .build_temp = os .path .join (original_build_temp , ext .name .replace ("." , "_" ))
265+ os .makedirs (self .build_temp , exist_ok = True )
266+
267+ try :
268+ objects = super ().build_extension (ext )
269+ finally :
270+ self .build_temp = original_build_temp
271+
272+ return objects
219273
220274setup (
221- name = 'sageattention' ,
222- version = '2.2.0' ,
275+ name = 'sageattention' ,
276+ version = '2.2.0' ,
223277 author = 'SageAttention team' ,
224- license = 'Apache 2.0 License' ,
225- description = 'Accurate and efficient plug-and-play low-bit attention.' ,
226- long_description = open ('README.md' , encoding = 'utf-8' ).read (),
227- long_description_content_type = 'text/markdown' ,
228- url = 'https://github.com/thu-ml/SageAttention' ,
278+ license = 'Apache 2.0 License' ,
279+ description = 'Accurate and efficient plug-and-play low-bit attention.' ,
280+ long_description = open ('README.md' , encoding = 'utf-8' ).read (),
281+ long_description_content_type = 'text/markdown' ,
282+ url = 'https://github.com/thu-ml/SageAttention' ,
229283 packages = find_packages (),
230284 python_requires = '>=3.9' ,
231285 ext_modules = ext_modules ,
0 commit comments