Skip to content

Commit dcef513

Browse files
committed
Look for nvcc in CUDA_HOME
1 parent daf9628 commit dcef513

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

bindings/torch/setup.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22

3+
from pathlib import Path
34
import re
45
from setuptools import setup
56
from pkg_resources import parse_version
@@ -8,7 +9,7 @@
89
import sys
910
import torch
1011
from glob import glob
11-
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
12+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
1213

1314
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
1415
ROOT_DIR = os.path.dirname(os.path.dirname(SCRIPT_DIR))
@@ -80,9 +81,14 @@ def find_cl_path():
8081
cpp_standard = 14
8182

8283
# Get CUDA version and make sure the targeted compute capability is compatible
83-
if os.system("nvcc --version") == 0:
84-
nvcc_out = subprocess.check_output(["nvcc", "--version"]).decode()
85-
cuda_version = re.search(r"release (\S+),", nvcc_out)
84+
nvcc_version_result = subprocess.run(
85+
[str(Path(CUDA_HOME) / "bin" / "nvcc"), "--version"],
86+
text=True,
87+
check=False,
88+
stdout=subprocess.PIPE,
89+
)
90+
if nvcc_version_result.returncode == 0:
91+
cuda_version = re.search(r"release (\S+),", nvcc_version_result.stdout)
8692

8793
if cuda_version:
8894
cuda_version = parse_version(cuda_version.group(1))

0 commit comments

Comments
 (0)