Skip to content

Commit a5e5311

Browse files
authored
Merge pull request #7 from mvpatel2000/mvpatel2000/update-seutp
Update setup to be more flexible in CUDA builds
2 parents 108009a + 35034ac commit a5e5311

1 file changed

Lines changed: 18 additions & 12 deletions

File tree

setup.py

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -4,13 +4,24 @@
44
import torch
55
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
66

7-
if not torch.cuda.is_available():
8-
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
9-
os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
7+
if os.environ.get("TORCH_CUDA_ARCH_LIST"):
8+
# Let the PyTorch builder choose which device to target.
9+
device_capability = ""
10+
else:
11+
device_capability = torch.cuda.get_device_capability()
12+
device_capability = f"{device_capability[0]}{device_capability[1]}"
1013

1114
cwd = Path(os.path.dirname(os.path.abspath(__file__)))
12-
_dc = torch.cuda.get_device_capability()
13-
_dc = f"{_dc[0]}{_dc[1]}"
15+
16+
nvcc_flags = [
17+
"-std=c++17", # NOTE: CUTLASS requires c++17
18+
]
19+
20+
if device_capability:
21+
nvcc_flags.extend([
22+
f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}",
23+
f"-DGROUPED_GEMM_DEVICE_CAPABILITY={device_capability}",
24+
])
1425

1526
ext_modules = [
1627
CUDAExtension(
@@ -24,12 +35,7 @@
2435
"cxx": [
2536
"-fopenmp", "-fPIC", "-Wno-strict-aliasing"
2637
],
27-
"nvcc": [
28-
f"--generate-code=arch=compute_{_dc},code=sm_{_dc}",
29-
f"-DGROUPED_GEMM_DEVICE_CAPABILITY={_dc}",
30-
# NOTE: CUTLASS requires c++17.
31-
"-std=c++17",
32-
],
38+
"nvcc": nvcc_flags,
3339
}
3440
)
3541
]
@@ -44,7 +50,7 @@
4450

4551
setup(
4652
name="grouped_gemm",
47-
version="0.0.1",
53+
version="0.1.1",
4854
author="Trevor Gale",
4955
author_email="tgale@stanford.edu",
5056
description="Grouped GEMM",

0 commit comments

Comments
 (0)