Skip to content

Commit 77c3c84

Browse files
committed
fix setup.py
1 parent 6e21003 commit 77c3c84

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

setup.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,18 @@
22
import os
33
import os.path as osp
44
import platform
5+
import subprocess
56
import sys
67
from itertools import product
78

89
import torch
910
from setuptools import find_packages, setup
1011
from torch.__config__ import parallel_info
11-
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
12-
CUDAExtension)
1312

1413
__version__ = '2.1.2'
1514
URL = 'https://github.com/rusty1s/pytorch_scatter'
1615

16+
CUDA_HOME = os.environ.get("CUDA_HOME", None)
1717
WITH_CUDA = False
1818
if torch.cuda.is_available():
1919
WITH_CUDA = CUDA_HOME is not None or torch.version.hip
@@ -28,6 +28,10 @@
2828
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
2929
WITH_SYMBOLS = os.getenv('WITH_SYMBOLS', '0') == '1'
3030

31+
def get_cuda_bare_metal_version(cuda_home):
32+
output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True).split()
33+
release_idx = output.index("release")
34+
return output[release_idx+1].split(",")[0].replace('.', '')
3135

3236
def get_extensions():
3337
extensions = []
@@ -108,7 +112,7 @@ def get_extensions():
108112

109113

110114
install_requires = ["torch>=1.8.0"]
111-
extra_index_url = ["https://download.pytorch.org/whl/"]
115+
extra_index_url = ["https://download.pytorch.org/whl/cpu"] if suffices == ["cpu"] else [f"https://download.pytorch.org/whl/cu{get_cuda_bare_metal_version(CUDA_HOME)}"]
112116

113117
test_requires = [
114118
'pytest',

0 commit comments

Comments
 (0)