Commit 7cfd33d

Manually specify flags if no arch set
stack-info: PR: #2219, branch: drisspg/stack/55
1 parent 5549da8 commit 7cfd33d

File tree

2 files changed: +31 -0 lines changed

CUDA_ARCH_NOTES.md

Whitespace-only changes.

setup.py

Lines changed: 31 additions & 0 deletions
@@ -291,6 +291,37 @@ def get_extensions():
    use_cuda = torch.version.cuda and (CUDA_HOME is not None or ROCM_HOME is not None)
    extension = CUDAExtension if use_cuda else CppExtension

+   # =====================================================================================
+   # CUDA Architecture Settings
+   # =====================================================================================
+   # If TORCH_CUDA_ARCH_LIST is not set during compilation, PyTorch tries to automatically
+   # detect architectures from available GPUs. This can fail when:
+   # 1. No GPU is visible to PyTorch
+   # 2. CUDA is available but no device is detected
+   #
+   # To resolve this, you can manually set CUDA architecture targets:
+   #     export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6+PTX"
+   #
+   # Adding "+PTX" to the last architecture enables JIT compilation for future GPUs.
+   # =====================================================================================
+   if use_cuda and "TORCH_CUDA_ARCH_LIST" not in os.environ and torch.version.cuda:
+       # Set to common architectures for CUDA 12.x compatibility
+       cuda_arch_list = "7.0;7.5;8.0;8.6;8.9;9.0"
+
+       # Only add SM10.0 (Blackwell) flags when using CUDA 12.8 or newer
+       cuda_version = torch.version.cuda
+       if cuda_version and cuda_version.startswith("12.8"):
+           print("Detected CUDA 12.8 - adding SM10.0 architectures to build list")
+           cuda_arch_list += ";10.0"
+
+       # Add PTX to the last architecture for future compatibility
+       cuda_arch_list += "+PTX"
+
+       os.environ["TORCH_CUDA_ARCH_LIST"] = cuda_arch_list
+       print(
+           f"Setting default TORCH_CUDA_ARCH_LIST={os.environ['TORCH_CUDA_ARCH_LIST']}"
+       )
+
    extra_link_args = []
    extra_compile_args = {
        "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
