@@ -4,13 +4,24 @@
 import torch
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
-if not torch.cuda.is_available():
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
+if os.environ.get("TORCH_CUDA_ARCH_LIST"):
+    # Let the PyTorch builder choose which devices to target.
+    device_capability = ""
+else:
+    device_capability = torch.cuda.get_device_capability()
+    device_capability = f"{device_capability[0]}{device_capability[1]}"
 
 cwd = Path(os.path.dirname(os.path.abspath(__file__)))
-_dc = torch.cuda.get_device_capability()
-_dc = f"{_dc[0]}{_dc[1]}"
+
+nvcc_flags = [
+    "-std=c++17",  # NOTE: CUTLASS requires c++17
+]
+
+if device_capability:
+    nvcc_flags.extend([
+        f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}",
+        f"-DGROUPED_GEMM_DEVICE_CAPABILITY={device_capability}",
+    ])
 
 ext_modules = [
     CUDAExtension(
@@ -24,12 +35,7 @@
             "cxx": [
                 "-fopenmp", "-fPIC", "-Wno-strict-aliasing"
            ],
-            "nvcc": [
-                f"--generate-code=arch=compute_{_dc},code=sm_{_dc}",
-                f"-DGROUPED_GEMM_DEVICE_CAPABILITY={_dc}",
-                # NOTE: CUTLASS requires c++17.
-                "-std=c++17",
-            ],
+            "nvcc": nvcc_flags,
         }
     )
 ]
@@ -44,7 +50,7 @@
 
 setup(
     name="grouped_gemm",
-    version="0.0.1",
+    version="0.1.1",
     author="Trevor Gale",
     author_email="tgale@stanford.edu",
     description="Grouped GEMM",
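For reference, a minimal standalone sketch of how the new flag logic in this setup script resolves (the `build_nvcc_flags` helper is hypothetical, added here only for illustration; it is not part of the repo):

```python
import os

import torch


def build_nvcc_flags():
    """Mirror the patched logic: with TORCH_CUDA_ARCH_LIST set, leave
    architecture selection to PyTorch's build system; otherwise target
    the compute capability of the GPU visible at build time."""
    flags = ["-std=c++17"]  # CUTLASS requires C++17.
    if not os.environ.get("TORCH_CUDA_ARCH_LIST"):
        # Raises if no CUDA device is visible, so CUDA-less build hosts
        # must export TORCH_CUDA_ARCH_LIST instead.
        major, minor = torch.cuda.get_device_capability()
        cap = f"{major}{minor}"
        flags += [
            f"--generate-code=arch=compute_{cap},code=sm_{cap}",
            f"-DGROUPED_GEMM_DEVICE_CAPABILITY={cap}",
        ]
    return flags


# On a GPU with compute capability 8.0 and the variable unset, this yields
#   ['-std=c++17',
#    '--generate-code=arch=compute_80,code=sm_80',
#    '-DGROUPED_GEMM_DEVICE_CAPABILITY=80']
# With e.g. TORCH_CUDA_ARCH_LIST="8.0;9.0" exported before `pip install .`,
# only ['-std=c++17'] is emitted and PyTorch's BuildExtension generates the
# per-architecture flags itself.
```

One behavioral change worth noting: the old code defaulted TORCH_CUDA_ARCH_LIST to "9.0" when CUDA was unavailable, whereas the new fallback calls torch.cuda.get_device_capability() unconditionally, so building on a host without a visible GPU now requires exporting the variable explicitly.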