Skip to content

Commit 691420f

Browse files
committed
add torch as torch_scatter requirements
1 parent 96aa2e3 commit 691420f

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

setup.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import os.path as osp
44
import platform
5+
import subprocess
56
import sys
67
from itertools import product
78

@@ -14,6 +15,7 @@
1415
__version__ = '2.1.2'
1516
URL = 'https://github.com/rusty1s/pytorch_scatter'
1617

18+
CUDA_HOME = os.environ.get("CUDA_HOME", None)
1719
WITH_CUDA = False
1820
if torch.cuda.is_available():
1921
WITH_CUDA = CUDA_HOME is not None or torch.version.hip
@@ -28,6 +30,10 @@
2830
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
2931
WITH_SYMBOLS = os.getenv('WITH_SYMBOLS', '0') == '1'
3032

33+
def get_cuda_bare_metal_version(cuda_home):
34+
output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True).split()
35+
release_idx = output.index("release")
36+
return output[release_idx+1].split(",")[0].replace('.', '')
3137

3238
def get_extensions():
3339
extensions = []
@@ -107,7 +113,8 @@ def get_extensions():
107113
return extensions
108114

109115

110-
install_requires = []
116+
install_requires = ["torch>=1.8.0"]
117+
extra_index_url = ["https://download.pytorch.org/whl/cpu"] if suffices == ["cpu"] else [f"https://download.pytorch.org/whl/cu{get_cuda_bare_metal_version(CUDA_HOME)}"]
111118

112119
test_requires = [
113120
'pytest',
@@ -130,6 +137,7 @@ def get_extensions():
130137
keywords=['pytorch', 'scatter', 'segment', 'gather'],
131138
python_requires='>=3.8',
132139
install_requires=install_requires,
140+
extra_index_url=extra_index_url,
133141
extras_require={
134142
'test': test_requires,
135143
},

0 commit comments

Comments
 (0)