@@ -291,6 +291,37 @@ def get_extensions():
     use_cuda = torch.version.cuda and (CUDA_HOME is not None or ROCM_HOME is not None)
     extension = CUDAExtension if use_cuda else CppExtension
 
+    # =====================================================================================
+    # CUDA Architecture Settings
+    # =====================================================================================
+    # If TORCH_CUDA_ARCH_LIST is not set during compilation, PyTorch tries to automatically
+    # detect architectures from available GPUs. This can fail when:
+    #   1. No GPU is visible to PyTorch
+    #   2. CUDA is available but no device is detected
+    #
+    # To resolve this, you can manually set CUDA architecture targets:
+    #     export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6+PTX"
+    #
+    # Adding "+PTX" to the last architecture enables JIT compilation for future GPUs.
+    # =====================================================================================
+    if use_cuda and "TORCH_CUDA_ARCH_LIST" not in os.environ and torch.version.cuda:
+        # Set to common architectures for CUDA 12.x compatibility
+        cuda_arch_list = "7.0;7.5;8.0;8.6;8.9;9.0"
+
+        # Only add SM10.0 (Blackwell) flags when using CUDA 12.8 or newer
+        cuda_version = torch.version.cuda
+        if cuda_version and tuple(int(v) for v in cuda_version.split(".")[:2]) >= (12, 8):
+            print(f"Detected CUDA {cuda_version} - adding SM10.0 architectures to build list")
+            cuda_arch_list += ";10.0"
+
+        # Add PTX to the last architecture for future compatibility
+        cuda_arch_list += "+PTX"
+
+        os.environ["TORCH_CUDA_ARCH_LIST"] = cuda_arch_list
+        print(
+            f"Setting default TORCH_CUDA_ARCH_LIST={os.environ['TORCH_CUDA_ARCH_LIST']}"
+        )
+
     extra_link_args = []
     extra_compile_args = {
         "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],