File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change 2
2
import os
3
3
import os .path as osp
4
4
import platform
5
+ import subprocess
5
6
import sys
6
7
from itertools import product
7
8
8
9
import torch
9
10
from setuptools import find_packages , setup
10
11
from torch .__config__ import parallel_info
11
- from torch .utils .cpp_extension import (CUDA_HOME , BuildExtension , CppExtension ,
12
- CUDAExtension )
13
12
14
13
__version__ = '2.1.2'
15
14
URL = 'https://github.com/rusty1s/pytorch_scatter'
16
15
16
+ CUDA_HOME = os .environ .get ("CUDA_HOME" , None )
17
17
WITH_CUDA = False
18
18
if torch .cuda .is_available ():
19
19
WITH_CUDA = CUDA_HOME is not None or torch .version .hip
28
28
BUILD_DOCS = os .getenv ('BUILD_DOCS' , '0' ) == '1'
29
29
WITH_SYMBOLS = os .getenv ('WITH_SYMBOLS' , '0' ) == '1'
30
30
31
+ def get_cuda_bare_metal_version (cuda_home ):
32
+ output = subprocess .check_output ([cuda_home + "/bin/nvcc" , "-V" ], universal_newlines = True ).split ()
33
+ release_idx = output .index ("release" )
34
+ return output [release_idx + 1 ].split ("," )[0 ].replace ('.' , '' )
31
35
32
36
def get_extensions ():
33
37
extensions = []
@@ -108,7 +112,7 @@ def get_extensions():
108
112
109
113
110
114
install_requires = ["torch>=1.8.0" ]
111
- extra_index_url = ["https://download.pytorch.org/whl/" ]
115
+ extra_index_url = ["https://download.pytorch.org/whl/cpu" ] if suffices == [ "cpu" ] else [ f"https://download.pytorch.org/whl/cu { get_cuda_bare_metal_version ( CUDA_HOME ) } " ]
112
116
113
117
test_requires = [
114
118
'pytest' ,
You can’t perform that action at this time.
0 commit comments