Skip to content

Commit 5d2ef2f

Browse files
committed
Enhance CUDA architecture support in setup.py by allowing user-defined architectures via environment variable. Refactor GPU capability checks and streamline NVCC flags for SM89 and SM90 extensions. Improve build process by creating separate output directories for extensions.
1 parent 798c791 commit 5d2ef2f

File tree

1 file changed

+100
-46
lines changed

1 file changed

+100
-46
lines changed

setup.py

Lines changed: 100 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -66,22 +66,56 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
6666
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
6767
return nvcc_cuda_version
6868

69-
# Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
69+
def filter_nvcc_flags_for_arch(nvcc_flags, arch_substrings):
    """Filter NVCC flags, only keep gencode flags for specified architectures.

    Each ``-gencode`` flag is treated as a pair with the architecture
    specification that immediately follows it (e.g.
    ``"arch=compute_89,code=sm_89"``).  A pair is kept only when the arch
    spec contains one of *arch_substrings*; a non-matching pair is dropped
    as a whole.  All other flags are passed through unchanged.

    Note: the previous version did not consume the arch argument of a
    non-matching pair, so the orphaned spec string leaked into the output
    without its ``-gencode`` prefix, yielding an invalid NVCC command line.

    Args:
        nvcc_flags: Flat list of NVCC command-line flags.
        arch_substrings: Substrings (e.g. ``"sm_89"``, ``"compute_90a"``)
            identifying the architectures to keep.

    Returns:
        A new list containing only the surviving flags.
    """
    filtered_flags = []
    skip_next = False
    for i, flag in enumerate(nvcc_flags):
        if skip_next:
            skip_next = False
            continue
        if flag == "-gencode":
            if i + 1 < len(nvcc_flags):
                arch_flag = nvcc_flags[i + 1]
                # Always consume the paired arch argument so a non-matching
                # pair is dropped together (bug fix); keep the pair only
                # when it matches a requested architecture.
                skip_next = True
                if any(sub in arch_flag for sub in arch_substrings):
                    filtered_flags.append(flag)
                    filtered_flags.append(arch_flag)
            # A trailing "-gencode" with no argument is silently dropped.
        else:
            filtered_flags.append(flag)
    return filtered_flags
87+
7088
compute_capabilities = set()
71-
device_count = torch.cuda.device_count()
72-
for i in range(device_count):
73-
major, minor = torch.cuda.get_device_capability(i)
74-
if major < 8:
75-
warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
76-
continue
77-
compute_capabilities.add(f"{major}.{minor}")
89+
cuda_architectures = os.environ.get("CUDA_ARCHITECTURES")
90+
if cuda_architectures is not None:
91+
for arch in cuda_architectures.split(","):
92+
arch = arch.strip()
93+
if arch:
94+
compute_capabilities.add(arch)
95+
else:
96+
#Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
97+
device_count = torch.cuda.device_count()
98+
for i in range(device_count):
99+
major, minor = torch.cuda.get_device_capability(i)
100+
if major < 8:
101+
warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
102+
continue
103+
compute_capabilities.add(f"{major}.{minor}")
78104

79-
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
80105
if not compute_capabilities:
81106
raise RuntimeError("No GPUs found. Please specify the target GPU architectures or build on a machine with GPUs.")
82107
else:
108+
109+
unsupported_archs = compute_capabilities - SUPPORTED_ARCHS
110+
if unsupported_archs:
111+
warnings.warn(f"Unsupported GPU architectures detected: {unsupported_archs}. Supported architectures: {SUPPORTED_ARCHS}")
112+
compute_capabilities = compute_capabilities & SUPPORTED_ARCHS
113+
if not compute_capabilities:
114+
raise RuntimeError(f"No supported GPU architectures found. Detected: {compute_capabilities | unsupported_archs}, Supported: {SUPPORTED_ARCHS}")
115+
83116
print(f"Detect GPUs with compute capabilities: {compute_capabilities}")
84117

118+
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
85119
# Validate the NVCC CUDA version.
86120
if nvcc_cuda_version < Version("12.0"):
87121
raise RuntimeError("CUDA 12.0 or higher is required to build the package.")
@@ -119,54 +153,66 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
119153
ext_modules = []
120154

121155
if HAS_SM80 or HAS_SM86 or HAS_SM89 or HAS_SM90 or HAS_SM120:
122-
qattn_extension = CUDAExtension(
156+
sm80_sources = [
157+
"csrc/qattn/pybind_sm80.cpp",
158+
"csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
159+
]
160+
161+
qattn_extension_sm80 = CUDAExtension(
123162
name="sageattention._qattn_sm80",
124-
sources=[
125-
"csrc/qattn/pybind_sm80.cpp",
126-
"csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
127-
],
163+
sources=sm80_sources,
128164
extra_compile_args={
129165
"cxx": CXX_FLAGS,
130166
"nvcc": NVCC_FLAGS,
131167
},
132168
)
133-
ext_modules.append(qattn_extension)
169+
ext_modules.append(qattn_extension_sm80)
134170

135171
if HAS_SM89 or HAS_SM120:
136-
qattn_extension = CUDAExtension(
172+
sm89_sources = [
173+
"csrc/qattn/pybind_sm89.cpp",
174+
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
175+
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
176+
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
177+
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
178+
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
179+
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
180+
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
181+
#"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
182+
]
183+
184+
arch_substrings = ["sm_89", "compute_89", "sm_90a", "compute_90a", "sm_120", "compute_120"]
185+
filtered_flags = filter_nvcc_flags_for_arch(NVCC_FLAGS, arch_substrings)
186+
187+
qattn_extension_sm89 = CUDAExtension(
137188
name="sageattention._qattn_sm89",
138-
sources=[
139-
"csrc/qattn/pybind_sm89.cpp",
140-
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
141-
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
142-
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
143-
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
144-
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
145-
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
146-
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
147-
#"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
148-
],
189+
sources=sm89_sources,
149190
extra_compile_args={
150191
"cxx": CXX_FLAGS,
151-
"nvcc": NVCC_FLAGS,
192+
"nvcc": filtered_flags if filtered_flags else NVCC_FLAGS,
152193
},
153194
)
154-
ext_modules.append(qattn_extension)
195+
ext_modules.append(qattn_extension_sm89)
155196

156197
if HAS_SM90:
157-
qattn_extension = CUDAExtension(
198+
sm90_sources = [
199+
"csrc/qattn/pybind_sm90.cpp",
200+
"csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
201+
]
202+
203+
arch_substrings = ["sm_90a", "compute_90a"]
204+
filtered_flags = filter_nvcc_flags_for_arch(NVCC_FLAGS, arch_substrings)
205+
206+
qattn_extension_sm90 = CUDAExtension(
158207
name="sageattention._qattn_sm90",
159-
sources=[
160-
"csrc/qattn/pybind_sm90.cpp",
161-
"csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
162-
],
208+
sources=sm90_sources,
163209
extra_compile_args={
164210
"cxx": CXX_FLAGS,
165-
"nvcc": NVCC_FLAGS,
211+
"nvcc": filtered_flags if filtered_flags else NVCC_FLAGS,
166212
},
167213
extra_link_args=['-lcuda'],
168214
)
169-
ext_modules.append(qattn_extension)
215+
ext_modules.append(qattn_extension_sm90)
170216

171217
# Fused kernels.
172218
fused_extension = CUDAExtension(
@@ -208,24 +254,32 @@ def compile_new(*args, **kwargs):
208254
**kwargs,
209255
"output_dir": os.path.join(
210256
kwargs["output_dir"],
211-
self.thread_ext_name_map[threading.current_thread().ident]),
257+
self.thread_ext_name_map.get(threading.current_thread().ident, f"thread_{threading.current_thread().ident}")),
212258
})
213259
self.compiler.compile = compile_new
214260
self.compiler._compile_separate_output_dir = True
215261
self.thread_ext_name_map[threading.current_thread().ident] = ext.name
216-
objects = super().build_extension(ext)
217-
return objects
218262

263+
original_build_temp = self.build_temp
264+
self.build_temp = os.path.join(original_build_temp, ext.name.replace(".", "_"))
265+
os.makedirs(self.build_temp, exist_ok=True)
266+
267+
try:
268+
objects = super().build_extension(ext)
269+
finally:
270+
self.build_temp = original_build_temp
271+
272+
return objects
219273

220274
setup(
221-
name='sageattention',
222-
version='2.2.0',
275+
name='sageattention',
276+
version='2.2.0',
223277
author='SageAttention team',
224-
license='Apache 2.0 License',
225-
description='Accurate and efficient plug-and-play low-bit attention.',
226-
long_description=open('README.md', encoding='utf-8').read(),
227-
long_description_content_type='text/markdown',
228-
url='https://github.com/thu-ml/SageAttention',
278+
license='Apache 2.0 License',
279+
description='Accurate and efficient plug-and-play low-bit attention.',
280+
long_description=open('README.md', encoding='utf-8').read(),
281+
long_description_content_type='text/markdown',
282+
url='https://github.com/thu-ml/SageAttention',
229283
packages=find_packages(),
230284
python_requires='>=3.9',
231285
ext_modules=ext_modules,

0 commit comments

Comments
 (0)