Skip to content

Commit 4b704af

Browse files
committed
Build cuda12 version with Hopper support.
1 parent 047b437 commit 4b704af

File tree

4 files changed

+61
-38
lines changed

4 files changed

+61
-38
lines changed

.github/workflows/publish.yml

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ jobs:
4848
- name: Checkout
4949
uses: actions/checkout@v3
5050

51+
- name: Set up python
52+
uses: actions/setup-python@v4
53+
with:
54+
python-version: '3.10'
55+
5156
- name: Set CUDA and PyTorch versions
5257
run: |
5358
echo "MATRIX_CUDA_MAJOR=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
@@ -76,6 +81,7 @@ jobs:
7681

7782
- name: Log Built Wheels
7883
run: |
84+
python3 set_tag_in_wheels.py "+cu$MATRIX_CUDA_MAJOR" wheelhouse/*.whl
7985
ls wheelhouse
8086
wheel_name=$(basename wheelhouse/*.whl)
8187
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
@@ -109,36 +115,41 @@ jobs:
109115
name: ${{env.wheel_name}}
110116
path: ./wheelhouse/${{env.wheel_name}}
111117

112-
# publish_package:
113-
# name: Publish package
114-
# needs: [build_wheels]
118+
publish_package:
119+
name: Publish package
120+
needs: [build_wheels]
115121

116-
# runs-on: ubuntu-latest
117-
# permissions:
118-
# id-token: write
122+
runs-on: ubuntu-latest
123+
permissions:
124+
id-token: write
119125

120-
# steps:
121-
# - uses: actions/checkout@v3
126+
steps:
127+
- uses: actions/checkout@v3
122128

123-
# - uses: actions/setup-python@v4
124-
# with:
125-
# python-version: '3.10'
129+
- uses: actions/setup-python@v4
130+
with:
131+
python-version: '3.10'
126132

127-
# - name: Install dependencies
128-
# run: |
129-
# pip install setuptools==68.0.0
130-
# pip install git+https://github.com/nshepperd/setuptools-cuda-cpp
131-
# pip install ninja packaging wheel pybind11
133+
- name: Install dependencies
134+
run: |
135+
pip install setuptools==68.0.0
136+
pip install git+https://github.com/nshepperd/setuptools-cuda-cpp
137+
pip install ninja packaging wheel pybind11
132138
133-
# - name: Build core package
134-
# run: |
135-
# CUDA_HOME=/ python setup.py sdist --dist-dir=dist
139+
- name: Build core package
140+
run: |
141+
CUDA_HOME=/ python setup.py sdist --dist-dir=dist
136142
137-
# - name: Retrieve release distributions
138-
# uses: actions/download-artifact@v4
139-
# with:
140-
# path: dist/
141-
# merge-multiple: true
143+
- name: Retrieve release distributions
144+
uses: actions/download-artifact@v4
145+
with:
146+
path: dist/
147+
merge-multiple: true
148+
pattern: '*+cu12*.whl'
149+
150+
- name: Remove version tag for pypi
151+
run: |
152+
python3 set_tag_in_wheels.py "" dist/*+cu12*.whl
142153
143-
# - name: Publish release distributions to PyPI
144-
# uses: pypa/gh-action-pypi-publish@release/v1
154+
- name: Publish release distributions to PyPI
155+
uses: pypa/gh-action-pypi-publish@release/v1

set_tag_in_wheels.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
"""Rewrite the local version tag embedded in wheel filenames.

Usage::

    python3 set_tag_in_wheels.py TAG WHEEL [WHEEL ...]

Each WHEEL is renamed in place so that the version component of its
filename becomes ``<public version><TAG>``.  Any existing ``+local`` suffix
on the version is dropped first, so passing an empty TAG strips the local
version entirely (e.g. ``"+cu12"`` adds a tag, ``""`` removes it).

NOTE: only the filename is changed; nothing inside the wheel archive is
touched.
"""
import os
import sys


def retag_wheel_name(basename, tag):
    """Return *basename* with its version's local tag replaced by *tag*.

    A wheel filename is ``name-version[-build]-python-abi-platform.whl``,
    i.e. five or six ``-``-separated fields with the version always second.
    Returns ``None`` when *basename* does not have that shape, so callers
    can skip files that are not wheels.
    """
    parts = basename.split('-')
    if len(parts) not in (5, 6):
        return None
    # Drop any existing "+local" suffix before appending the new tag.
    public_version = parts[1].split('+')[0]
    parts[1] = f'{public_version}{tag}'
    return '-'.join(parts)


def main(argv):
    """Rename every wheel path in ``argv[2:]`` using the tag in ``argv[1]``.

    Returns a process exit status: 0 on success, 2 on a usage error.
    """
    if len(argv) < 2:
        print(f'usage: {argv[0] if argv else "set_tag_in_wheels.py"} '
              'TAG WHEEL [WHEEL ...]', file=sys.stderr)
        return 2

    tag = argv[1]
    for wheel in argv[2:]:
        dirname, basename = os.path.split(wheel)
        new_basename = retag_wheel_name(basename, tag)
        if new_basename is None:
            # Not a wheel-shaped filename; leave it alone.
            continue
        # os.path.join keeps a bare filename relative; building the path as
        # f'{dirname}/{name}' would turn an empty dirname into '/<name>'.
        os.rename(wheel, os.path.join(dirname, new_basename))
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv))

setup.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,7 @@ def get_cuda_version():
7676
output = raw_output.split()
7777
release_idx = output.index("release") + 1
7878
version = output[release_idx].split(",")[0].split('.')[0] # should be 11 or 12
79-
return f'+cu{version}'
80-
79+
return version
8180

8281
def append_nvcc_threads(nvcc_extra_args):
8382
return nvcc_extra_args + ["--threads", "4"]
@@ -92,6 +91,8 @@ def append_nvcc_threads(nvcc_extra_args):
9291

9392
SKIP_CUDA_BUILD = False
9493

94+
cuda12 = get_cuda_version() != "11"
95+
9596
# CUDA_HOME = find_cuda_home()
9697
if not SKIP_CUDA_BUILD:
9798

@@ -100,10 +101,9 @@ def append_nvcc_threads(nvcc_extra_args):
100101
# cc_flag.append("arch=compute_75,code=sm_75")
101102
cc_flag.append("-gencode")
102103
cc_flag.append("arch=compute_80,code=sm_80")
103-
# if CUDA_HOME is not None:
104-
# if bare_metal_version >= Version("11.8"):
105-
# cc_flag.append("-gencode")
106-
# cc_flag.append("arch=compute_90,code=sm_90")
104+
if cuda12:
105+
cc_flag.append("-gencode")
106+
cc_flag.append("arch=compute_90,code=sm_90")
107107
ext_modules.append(
108108
CUDAExtension(
109109
name="flash_attn_jax_lib.flash_api",
@@ -171,6 +171,7 @@ def append_nvcc_threads(nvcc_extra_args):
171171
"-U__CUDA_NO_HALF_CONVERSIONS__",
172172
"-U__CUDA_NO_HALF2_OPERATORS__",
173173
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
174+
"-DFLASHATTENTION_DISABLE_DROPOUT=1",
174175
"--expt-relaxed-constexpr",
175176
"--expt-extended-lambda",
176177
"--use_fast_math",
@@ -194,11 +195,7 @@ def get_package_version():
194195
with open(Path(this_dir) / "src" / "flash_attn_jax" / "__init__.py", "r") as f:
195196
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
196197
public_version = ast.literal_eval(version_match.group(1))
197-
local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
198-
if local_version:
199-
return f"{public_version}+{local_version}{get_cuda_version()}"
200-
else:
201-
return f"{public_version}{get_cuda_version()}"
198+
return public_version
202199

203200

204201
class NinjaBuildExtension(BuildExtension):

src/flash_attn_jax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .flash import flash_mha
2-
__version__ = 'v0.2.1'
2+
__version__ = 'v0.2.2'

0 commit comments

Comments
 (0)