@@ -40,27 +40,17 @@ jobs:
40
40
strategy :
41
41
fail-fast : false
42
42
matrix :
43
- # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
44
- # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
45
43
os : [ubuntu-20.04]
46
- python-version : ['3.9', '3.10', '3.11', '3.12']
47
- jax-version : ['0.4.24']
48
- cuda-version : ['11.8.0', '12.3.1']
44
+ python-version : ['cp39', 'cp310', 'cp311', 'cp312']
45
+ cuda-version : ['11.8', '12.3']
49
46
50
47
steps :
51
48
- name : Checkout
52
49
uses : actions/checkout@v3
53
50
54
- - name : Set up Python
55
- uses : actions/setup-python@v4
56
- with :
57
- python-version : ${{ matrix.python-version }}
58
-
59
51
- name : Set CUDA and PyTorch versions
60
52
run : |
61
53
echo "MATRIX_CUDA_MAJOR=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
62
- echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
63
- echo "MATRIX_JAX_VERSION=$(echo ${{ matrix.jax-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
64
54
65
55
- name : Free up disk space
66
56
if : ${{ runner.os == 'Linux' }}
@@ -77,53 +67,17 @@ jobs:
77
67
with :
78
68
swap-size-gb : 10
79
69
80
- - name : Install CUDA ${{ matrix.cuda-version }}
81
- if : ${{ matrix.cuda-version != 'cpu' }}
82
-
83
- id : cuda-toolkit
84
- with :
85
- cuda : ${{ matrix.cuda-version }}
86
- linux-local-args : ' ["--toolkit"]'
87
- # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
88
- # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
89
- method : ' network'
90
- # We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
91
- # not just nvcc
92
- # sub-packages: '["nvcc"]'
93
-
94
- - name : Install Jax ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
95
- run : |
96
- pip install --upgrade pip
97
- pip install --upgrade "jax[cuda${MATRIX_CUDA_MAJOR}_local] == ${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
98
- shell :
99
- bash
100
-
101
- - name : Build wheel
102
- run : |
103
- # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
104
- # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
105
- # However this still fails so I'm using a newer version of setuptools
106
- pip install setuptools==68.0.0
107
- # setuptools-cuda-cpp on pypi has a bug that breaks ninja
108
- pip install git+https://github.com/nshepperd/setuptools-cuda-cpp
109
- pip install ninja packaging wheel pybind11
110
- export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
111
- export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
112
- # Set MAX_JOBS to allocate 8GB per job, which should be enough to build comfortably
113
- free -h
114
- export MAX_JOBS=3
115
- echo "Building with ${MAX_JOBS} jobs"
116
- python setup.py bdist_wheel --dist-dir=dist
117
- tmpname=cu${MATRIX_CUDA_VERSION}jax${{ matrix.jax-version }}
118
- wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
119
- ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
120
- echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
121
- shell :
122
- bash
70
+ - name : Build wheels
71
+
72
+ env :
73
+ CIBW_BUILD : ${{ matrix.python-version }}-manylinux_x86_64
74
+ CIBW_MANYLINUX_X86_64_IMAGE : manylinux2014_x86_64_cuda_${{ matrix.cuda-version }}
123
75
124
76
- name : Log Built Wheels
125
77
run : |
126
- ls dist
78
+ ls wheelhouse
79
+ wheel_name=$(basename wheelhouse/*.whl)
80
+ echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
127
81
128
82
- name : Get the tag version
129
83
id : extract_branch
@@ -144,15 +98,15 @@ jobs:
144
98
GITHUB_TOKEN : ${{ secrets.GITHUB_TOKEN }}
145
99
with :
146
100
upload_url : ${{ steps.get_current_release.outputs.upload_url }}
147
- asset_path : ./dist /${{env.wheel_name}}
101
+ asset_path : ./wheelhouse /${{env.wheel_name}}
148
102
asset_name : ${{env.wheel_name}}
149
103
asset_content_type : application/*
150
104
151
105
- name : Upload Artifact
152
106
uses : actions/upload-artifact@v4
153
107
with :
154
108
name : ${{env.wheel_name}}
155
- path : ./dist /${{env.wheel_name}}
109
+ path : ./wheelhouse /${{env.wheel_name}}
156
110
157
111
publish_package :
158
112
name : Publish package
0 commit comments