Skip to content

Commit 341d943

Browse files
Fix jax plugin build (#1226)
## Motivation Fixes problems with the JAX CI workflow that was created in #1033 ## Technical Details Fixes minor problems that make the workflow crash, and ensures some command-line tools that we need are installed. ## Test Plan Make sure the workflow works when run with Actions: https://github.com/ROCm/TheRock/actions/workflows/build_linux_jax_wheels.yml --------- Co-authored-by: Scott Todd <scott.todd0@gmail.com>
1 parent bb36a58 commit 341d943

2 files changed

Lines changed: 47 additions & 9 deletions

File tree

.github/workflows/build_linux_jax_wheels.yml

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,18 @@ on:
2626
workflow_dispatch:
2727
inputs:
2828
amdgpu_family:
29-
required: true
30-
type: string
29+
type: choice
30+
options:
31+
- gfx110X-dgpu
32+
- gfx1151
33+
- gfx120X-all
34+
- gfx94X-dcgpu
35+
- gfx950-dcgpu
36+
default: gfx94X-dcgpu
3137
python_versions:
3238
required: true
3339
type: string
34-
default:
40+
default: "3.12"
3541
release_type:
3642
description: The type of release to build ("nightly", or "dev")
3743
required: true
@@ -60,7 +66,7 @@ jobs:
6066
name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_version }}
6167
runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }}
6268
env:
63-
PACKAGE_DIST_DIR: ${{ github.workspace }}/jax_rocm_plugin/wheelhouse
69+
PACKAGE_DIST_DIR: ${{ github.workspace }}/jax/jax_rocm_plugin/wheelhouse
6470
S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
6571
steps:
6672
- name: Checkout TheRock
@@ -71,23 +77,35 @@ jobs:
7177
with:
7278
path: jax
7379
repository: rocm/rocm-jax
74-
ref: ${{ inputs.jax_ref }}
80+
ref: ${{ matrix.jax_ref }}
7581

7682
- name: Configure Git Identity
7783
run: |
7884
git config --global user.name "therockbot"
7985
git config --global user.email "therockbot@amd.com"
8086
87+
- name: "Setting up Python"
88+
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
89+
with:
90+
python-version: ${{ inputs.python_versions }}
91+
8192
- name: Build JAX Wheels
93+
env:
94+
ROCM_VERSION: ${{ inputs.rocm_version }}
8295
run: |
83-
pushd rocm-jax
96+
ls -lah
97+
pushd jax
8498
python3 build/ci_build \
8599
--compiler=clang \
86100
--python-versions="${{ inputs.python_versions }}" \
87-
--rocm-version="${{ inputs.rocm_version }}" \
101+
--rocm-version="${ROCM_VERSION:0:5}" \
88102
--therock-path="${{ inputs.tar_url }}" \
89103
dist_wheels
90104
105+
- name: Install AWS CLI
106+
if: always()
107+
run: bash ./dockerfiles/install_awscli.sh
108+
91109
- name: Configure AWS Credentials
92110
if: always()
93111
uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4.3.1
@@ -104,5 +122,7 @@ jobs:
104122
- name: (Re-)Generate Python package release index
105123
if: ${{ github.repository_owner == 'ROCm' }}
106124
run: |
107-
pip install boto3 packaging
108-
python ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}
125+
python3 -m venv .venv
126+
source .venv/bin/activate
127+
pip3 install boto3 packaging
128+
python3 ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}

.github/workflows/release_portable_linux_packages.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,24 @@ jobs:
252252
"ref": "${{ inputs.ref || '' }}"
253253
}
254254
255+
- name: URL-encode .tar URL
256+
id: url-encode-tar
257+
run: python -c "from urllib.parse import quote; print('tar_url=https://therock-${{ env.RELEASE_TYPE }}-tarball.s3.amazonaws.com/' + quote('therock-dist-linux-${{ matrix.target_bundle.amdgpu_family }}${{ inputs.package_suffix }}-${{ needs.setup_metadata.outputs.version }}.tar.gz'))" >> ${GITHUB_OUTPUT}
258+
259+
- name: Trigger build JAX wheels
260+
if: ${{ github.repository_owner == 'ROCm' }}
261+
uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4
262+
with:
263+
workflow: build_linux_jax_wheels.yml
264+
inputs: |
265+
{ "amdgpu_family": "${{ matrix.target_bundle.amdgpu_family }}",
266+
"python_versions": "3.12",
267+
"release_type": "${{ env.RELEASE_TYPE }}",
268+
"s3_subdir": "${{ env.S3_STAGING_SUBDIR }}",
269+
"rocm_version": "${{ needs.setup_metadata.outputs.version }}",
270+
"tar_url": "${{ steps.url-encode-tar.outputs.tar_url }}"
271+
}
272+
255273
- name: Save cache
256274
uses: actions/cache/save@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4
257275
if: ${{ !cancelled() }}

0 commit comments

Comments
 (0)