|
8 | 8 | import sys
|
9 | 9 | import torch
|
10 | 10 | from glob import glob
|
11 |
| -from torch.utils.cpp_extension import BuildExtension, CUDAExtension |
| 11 | +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME |
12 | 12 |
|
13 | 13 | SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
14 | 14 | ROOT_DIR = os.path.dirname(os.path.dirname(SCRIPT_DIR))
|
@@ -80,26 +80,64 @@ def find_cl_path():
|
80 | 80 | cpp_standard = 14
|
81 | 81 |
|
82 | 82 | # Get CUDA version and make sure the targeted compute capability is compatible
|
83 |
| -if os.system("nvcc --version") == 0: |
84 |
| - nvcc_out = subprocess.check_output(["nvcc", "--version"]).decode() |
85 |
| - cuda_version = re.search(r"release (\S+),", nvcc_out) |
86 |
| - |
87 |
| - if cuda_version: |
88 |
| - cuda_version = parse_version(cuda_version.group(1)) |
89 |
| - print(f"Detected CUDA version {cuda_version}") |
90 |
| - if cuda_version >= parse_version("11.0"): |
91 |
| - cpp_standard = 17 |
92 |
| - |
93 |
| - supported_compute_capabilities = [ |
94 |
| - cc for cc in compute_capabilities if cc >= min_supported_compute_capability(cuda_version) and cc <= max_supported_compute_capability(cuda_version) |
95 |
| - ] |
96 |
| - |
97 |
| - if not supported_compute_capabilities: |
98 |
| - supported_compute_capabilities = [max_supported_compute_capability(cuda_version)] |
99 |
| - |
100 |
| - if supported_compute_capabilities != compute_capabilities: |
101 |
| - print(f"WARNING: Compute capabilities {compute_capabilities} are not all supported by the installed CUDA version {cuda_version}. Targeting {supported_compute_capabilities} instead.") |
102 |
| - compute_capabilities = supported_compute_capabilities |
def _maybe_find_nvcc():
    """Locate the ``nvcc`` executable, or return ``None`` if it cannot be found.

    ``PATH`` is searched first. If that fails, ``<CUDA_HOME>/bin`` is searched, where
    ``CUDA_HOME`` is taken from ``torch.utils.cpp_extension``.

    Returns:
        The path to an existing, executable ``nvcc``, or ``None`` when none is found
        (including when torch's ``CUDA_HOME`` cannot be imported or is unset).
    """
    # Try PATH first
    nvcc_on_path = shutil.which("nvcc")
    if nvcc_on_path is not None:
        return nvcc_on_path

    # Then try CUDA_HOME from torch (cpp_extension.CUDA_HOME is undocumented, which is why we
    # only use it as a fallback)
    try:
        from torch.utils.cpp_extension import CUDA_HOME
    except ImportError:
        return None

    if not CUDA_HOME:
        return None

    # CUDA_HOME may name a directory that does not actually contain the compiler, so only
    # report a path that exists and is executable. shutil.which returns None otherwise and
    # also takes care of platform-specific executable suffixes.
    return shutil.which("nvcc", path=os.path.join(CUDA_HOME, "bin"))
| 101 | + |
def _maybe_nvcc_version():
    """Return the CUDA version reported by ``nvcc --version``, or ``None``.

    ``None`` is returned when nvcc cannot be located, cannot be launched, exits with a
    non-zero status, or prints output that does not contain a ``release <version>,`` marker.

    Returns:
        The result of ``parse_version`` applied to the release string, or ``None``.
    """
    maybe_nvcc = _maybe_find_nvcc()
    if maybe_nvcc is None:
        return None

    try:
        nvcc_version_result = subprocess.run(
            [maybe_nvcc, "--version"],
            text=True,
            check=False,
            stdout=subprocess.PIPE,
        )
    except OSError:
        # The located path may not exist or may not be executable; treat that the same as
        # "no nvcc available" instead of aborting the build script.
        return None

    if nvcc_version_result.returncode != 0:
        return None

    release_match = re.search(r"release (\S+),", nvcc_version_result.stdout)
    if not release_match:
        return None

    return parse_version(release_match.group(1))
| 124 | + |
# Query the installed CUDA version. When it is known, pick the C++ standard and restrict the
# targeted compute capabilities to the range accepted for that CUDA version.
cuda_version = _maybe_nvcc_version()
if cuda_version is not None:
    print(f"Detected CUDA version {cuda_version}")
    if cuda_version >= parse_version("11.0"):
        cpp_standard = 17

    supported_compute_capabilities = [
        cc
        for cc in compute_capabilities
        if cc >= min_supported_compute_capability(cuda_version)
        and cc <= max_supported_compute_capability(cuda_version)
    ]

    # None of the requested capabilities is in range: target the maximum supported one instead
    if len(supported_compute_capabilities) == 0:
        supported_compute_capabilities = [max_supported_compute_capability(cuda_version)]

    if supported_compute_capabilities != compute_capabilities:
        print(f"WARNING: Compute capabilities {compute_capabilities} are not all supported by the installed CUDA version {cuda_version}. Targeting {supported_compute_capabilities} instead.")
        compute_capabilities = supported_compute_capabilities
103 | 141 |
|
104 | 142 | min_compute_capability = min(compute_capabilities)
|
105 | 143 |
|
|
0 commit comments