# Build Portable Linux JAX Wheels
#
# Builds JAX wheels from rocm/rocm-jax against a TheRock tarball and, when run
# in the ROCm organization, uploads them to S3 and regenerates the Python
# package index.
name: Build Portable Linux JAX Wheels

# `on` is a YAML 1.1 boolean to generic parsers; GitHub's loader reads it as a key.
on:  # yamllint disable-line rule:truthy
  workflow_call:
    inputs:
      amdgpu_family:
        description: AMDGPU family (e.g. gfx94X-dcgpu); used in the job name and as the final S3 path component
        required: true
        type: string
      python_versions:
        description: Python version(s) passed to setup-python and to the JAX build
        required: true
        type: string
      release_type:
        description: The type of release to build ("nightly" or "dev")
        required: true
        type: string
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        required: true
        type: string
      rocm_version:
        description: ROCm version to install
        type: string
      tar_url:
        description: URL to TheRock tarball to build against
        type: string
  workflow_dispatch:
    inputs:
      amdgpu_family:
        description: AMDGPU family; used in the job name and as the final S3 path component
        type: choice
        options:
          - gfx110X-dgpu
          - gfx1151
          - gfx120X-all
          - gfx94X-dcgpu
          - gfx950-dcgpu
        default: gfx94X-dcgpu
      python_versions:
        description: Python version(s) passed to setup-python and to the JAX build
        required: true
        type: string
        # Quoted so YAML does not read the version as a float (e.g. "3.10" -> 3.1).
        default: "3.12"
      release_type:
        description: The type of release to build ("nightly" or "dev")
        required: true
        type: string
        default: "dev"
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        type: string
        default: "v2"
      rocm_version:
        description: ROCm version to install
        type: string
      tar_url:
        description: URL to TheRock tarball to build against
        type: string
permissions:
  # Read-only repository access, as used by the actions/checkout steps.
  contents: read
  # Allows the job to request an OIDC token, which
  # aws-actions/configure-aws-credentials exchanges for the `role-to-assume` role.
  id-token: write
jobs:
  # Checks out TheRock and rocm/rocm-jax, builds wheels with build/ci_build, and
  # (in the ROCm organization only) uploads the *.whl files to S3 and
  # regenerates the package index for `<s3_subdir>/<amdgpu_family>`.
  build_jax_wheels:
    strategy:
      matrix:
        jax_ref: [master]
    # The declared input is `python_versions` (plural).
    name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_versions }}
    runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }}
    env:
      PACKAGE_DIST_DIR: ${{ github.workspace }}/jax/jax_rocm_plugin/wheelhouse
      S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
    steps:
      - name: Checkout TheRock
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

      - name: Checkout JAX
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
        with:
          path: jax
          repository: rocm/rocm-jax
          ref: ${{ matrix.jax_ref }}

      - name: Configure Git Identity
        run: |
          git config --global user.name "therockbot"
          git config --global user.email "therockbot@amd.com"

      - name: "Setting up Python"
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
        with:
          python-version: ${{ inputs.python_versions }}

      # Inputs are handed to the script through env vars instead of being
      # interpolated with ${{ }}, so their values never become script text.
      - name: Build JAX Wheels
        env:
          ROCM_VERSION: ${{ inputs.rocm_version }}
          PYTHON_VERSIONS: ${{ inputs.python_versions }}
          THEROCK_TAR_URL: ${{ inputs.tar_url }}
        run: |
          ls -lah
          pushd jax
          # NOTE(review): `${ROCM_VERSION:0:5}` keeps only the first 5 characters
          # (e.g. "6.4.1"); a version like "6.10.0" would become "6.10." -- confirm
          # this matches the expected rocm_version format.
          python3 build/ci_build \
            --compiler=clang \
            --python-versions="${PYTHON_VERSIONS}" \
            --rocm-version="${ROCM_VERSION:0:5}" \
            --therock-path="${THEROCK_TAR_URL}" \
            dist_wheels

      # AWS setup is only consumed by the owner-gated upload steps below.
      # Gating it the same way avoids running it on forks (which run on
      # ubuntu-24.04 and presumably cannot assume the ROCm role) and after a
      # failed build (when the upload steps are skipped anyway).
      - name: Install AWS CLI
        if: ${{ github.repository_owner == 'ROCm' }}
        run: bash ./dockerfiles/install_awscli.sh

      - name: Configure AWS Credentials
        if: ${{ github.repository_owner == 'ROCm' }}
        uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4.3.1
        with:
          aws-region: us-east-2
          role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}-releases

      - name: Upload wheels to S3
        if: ${{ github.repository_owner == 'ROCm' }}
        env:
          S3_SUBDIR: ${{ inputs.s3_subdir }}
          AMDGPU_FAMILY: ${{ inputs.amdgpu_family }}
        run: |
          aws s3 cp "${PACKAGE_DIST_DIR}/" "s3://${S3_BUCKET_PY}/${S3_SUBDIR}/${AMDGPU_FAMILY}/" \
            --recursive --exclude "*" --include "*.whl"

      - name: (Re-)Generate Python package release index
        if: ${{ github.repository_owner == 'ROCm' }}
        env:
          S3_SUBDIR: ${{ inputs.s3_subdir }}
          AMDGPU_FAMILY: ${{ inputs.amdgpu_family }}
        run: |
          python3 -m venv .venv
          source .venv/bin/activate
          pip3 install boto3 packaging
          python3 ./build_tools/third_party/s3_management/manage.py "${S3_SUBDIR}/${AMDGPU_FAMILY}"