Skip to content

Commit e475f56

Browse files
committed
Try cibuildwheel.
1 parent f4ff626 commit e475f56

File tree

3 files changed

+37
-67
lines changed

3 files changed

+37
-67
lines changed

Diff for: .github/workflows/publish.yml

+12-58
Original file line numberDiff line numberDiff line change
@@ -40,27 +40,17 @@ jobs:
4040
strategy:
4141
fail-fast: false
4242
matrix:
43-
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
44-
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
4543
os: [ubuntu-20.04]
46-
python-version: ['3.9', '3.10', '3.11', '3.12']
47-
jax-version: ['0.4.24']
48-
cuda-version: ['11.8.0', '12.3.1']
44+
python-version: ['cp39', 'cp310', 'cp311', 'cp312']
45+
cuda-version: ['11.8', '12.3']
4946

5047
steps:
5148
- name: Checkout
5249
uses: actions/checkout@v3
5350

54-
- name: Set up Python
55-
uses: actions/setup-python@v4
56-
with:
57-
python-version: ${{ matrix.python-version }}
58-
5951
- name: Set CUDA and PyTorch versions
6052
run: |
6153
echo "MATRIX_CUDA_MAJOR=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
62-
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
63-
echo "MATRIX_JAX_VERSION=$(echo ${{ matrix.jax-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
6454
6555
- name: Free up disk space
6656
if: ${{ runner.os == 'Linux' }}
@@ -77,53 +67,17 @@ jobs:
7767
with:
7868
swap-size-gb: 10
7969

80-
- name: Install CUDA ${{ matrix.cuda-version }}
81-
if: ${{ matrix.cuda-version != 'cpu' }}
82-
uses: Jimver/[email protected]
83-
id: cuda-toolkit
84-
with:
85-
cuda: ${{ matrix.cuda-version }}
86-
linux-local-args: '["--toolkit"]'
87-
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
88-
# method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
89-
method: 'network'
90-
# We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
91-
# not just nvcc
92-
# sub-packages: '["nvcc"]'
93-
94-
- name: Install Jax ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
95-
run: |
96-
pip install --upgrade pip
97-
pip install --upgrade "jax[cuda${MATRIX_CUDA_MAJOR}_local] == ${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
98-
shell:
99-
bash
100-
101-
- name: Build wheel
102-
run: |
103-
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
104-
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
105-
# However this still fails so I'm using a newer version of setuptools
106-
pip install setuptools==68.0.0
107-
# setuptools-cuda-cpp on pypi has a bug that breaks ninja
108-
pip install git+https://github.com/nshepperd/setuptools-cuda-cpp
109-
pip install ninja packaging wheel pybind11
110-
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
111-
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
112-
# Set MAX_JOBS to allocate 8GB per job, which should be enough to build comfortably
113-
free -h
114-
export MAX_JOBS=3
115-
echo "Building with ${MAX_JOBS} jobs"
116-
python setup.py bdist_wheel --dist-dir=dist
117-
tmpname=cu${MATRIX_CUDA_VERSION}jax${{ matrix.jax-version }}
118-
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
119-
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
120-
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
121-
shell:
122-
bash
70+
- name: Build wheels
71+
uses: pypa/[email protected]
72+
env:
73+
CIBW_BUILD: ${{ matrix.python-version }}-manylinux_x86_64
74+
CIBW_MANYLINUX_X86_64_IMAGE: manylinux2014_x86_64_cuda_${{ matrix.cuda-version }}
12375

12476
- name: Log Built Wheels
12577
run: |
126-
ls dist
78+
ls wheelhouse
79+
wheel_name=$(basename wheelhouse/*.whl)
80+
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
12781
12882
- name: Get the tag version
12983
id: extract_branch
@@ -144,15 +98,15 @@ jobs:
14498
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
14599
with:
146100
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
147-
asset_path: ./dist/${{env.wheel_name}}
101+
asset_path: ./wheelhouse/${{env.wheel_name}}
148102
asset_name: ${{env.wheel_name}}
149103
asset_content_type: application/*
150104

151105
- name: Upload Artifact
152106
uses: actions/upload-artifact@v4
153107
with:
154108
name: ${{env.wheel_name}}
155-
path: ./dist/${{env.wheel_name}}
109+
path: ./wheelhouse/${{env.wheel_name}}
156110

157111
publish_package:
158112
name: Publish package

Diff for: pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
[build-system]
2-
requires = ["setuptools", "wheel", "setuptools-cuda-cpp", "packaging", "pybind11"]
2+
requires = ["setuptools", "wheel", "setuptools-cuda-cpp @ git+https://github.com/nshepperd/setuptools-cuda-cpp", "packaging", "pybind11"]

Diff for: setup.py

+24-8
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from setuptools import setup, find_packages
1212
from setuptools_cuda_cpp import CUDAExtension, BuildExtension, fix_dll
13-
# from setuptools_cuda.inspections import find_cuda_home
1413
import pybind11
1514

1615
import subprocess
@@ -53,14 +52,31 @@ def get_platform():
5352
else:
5453
raise ValueError("Unsupported platform: {}".format(sys.platform))
5554

56-
57-
def get_cuda_bare_metal_version(cuda_dir):
55+
def locate_cuda():
56+
if 'sdist' in sys.argv:
57+
return None
58+
cuda_dir = os.environ.get("CUDA_HOME", None)
59+
if cuda_dir is None:
60+
if os.path.exists("/usr/local/cuda"):
61+
cuda_dir = "/usr/local/cuda"
62+
os.environ["CUDA_HOME"] = cuda_dir
63+
elif os.path.exists("/opt/cuda"):
64+
cuda_dir = "/opt/cuda"
65+
os.environ["CUDA_HOME"] = cuda_dir
66+
else:
67+
raise RuntimeError("CUDA_HOME not set and no CUDA installation found")
68+
return cuda_dir
69+
70+
71+
def get_cuda_version():
72+
cuda_dir = locate_cuda()
73+
if cuda_dir is None:
74+
return ""
5875
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
5976
output = raw_output.split()
6077
release_idx = output.index("release") + 1
61-
bare_metal_version = parse(output[release_idx].split(",")[0])
62-
63-
return raw_output, bare_metal_version
78+
version = output[release_idx].split(",")[0].split('.')[0] # should be 11 or 12
79+
return f'+cu{version}'
6480

6581

6682
def append_nvcc_threads(nvcc_extra_args):
@@ -180,9 +196,9 @@ def get_package_version():
180196
public_version = ast.literal_eval(version_match.group(1))
181197
local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
182198
if local_version:
183-
return f"{public_version}+{local_version}"
199+
return f"{public_version}+{local_version}{get_cuda_version()}"
184200
else:
185-
return str(public_version)
201+
return f"{public_version}{get_cuda_version()}"
186202

187203

188204
class NinjaBuildExtension(BuildExtension):

0 commit comments

Comments
 (0)