# Build Portable Linux JAX Wheels (workflow run #3)

name: Build Portable Linux JAX Wheels

# Triggers: this workflow can be invoked from another workflow
# (workflow_call) or started manually (workflow_dispatch). Both triggers
# declare the same six inputs; only the manual trigger supplies defaults.
on:
  workflow_call:
    inputs:
      # GPU family; used in the job name and as the last S3 path segment.
      amdgpu_family:
        required: true
        type: string
      # Passed verbatim to the build script's --python-versions option.
      python_versions:
        required: true
        type: string
      # Selects the S3 bucket and the IAM role (therock-<release_type>-...).
      release_type:
        description: The type of release to build ("nightly", or "dev")
        required: true
        type: string
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        required: true
        type: string
      rocm_version:
        description: ROCm version to install
        type: string
      tar_url:
        description: URL to TheRock tarball to build against
        type: string
  workflow_dispatch:
    inputs:
      amdgpu_family:
        required: true
        type: string
      # The original carried a bare `default:` key here, which parses as
      # null. It is dropped so the input is simply required with no default.
      python_versions:
        required: true
        type: string
      release_type:
        description: The type of release to build ("nightly", or "dev")
        required: true
        type: string
        default: "dev"
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        type: string
        default: "v2"
      rocm_version:
        description: ROCm version to install
        type: string
      tar_url:
        description: URL to TheRock tarball to build against
        type: string
# Token scopes granted to every job in this workflow.
permissions:
  # Allows the job to request an OIDC token; the AWS credentials step
  # assumes an IAM role rather than using stored keys.
  id-token: write
  # Read-only repository access, sufficient for the checkout steps.
  contents: read
jobs:
  # Checks out TheRock and rocm/rocm-jax, builds the JAX wheels with the
  # rocm-jax ci_build script, and (only in the ROCm organization) uploads the
  # wheels to S3 and regenerates the Python package index.
  build_jax_wheels:
    strategy:
      matrix:
        jax_ref: [master]
    # Fixed: the declared input is `python_versions` (plural). The singular
    # name does not exist, so the job name previously rendered it as empty.
    name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_versions }}
    # Self-hosted runner inside the ROCm org, GitHub-hosted runner on forks.
    runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }}
    env:
      # Fixed: rocm-jax is checked out into `jax/`, so the wheel directory is
      # below that checkout rather than directly under the workspace root.
      # NOTE(review): assumes ci_build writes wheels to
      # jax_rocm_plugin/wheelhouse relative to the rocm-jax root - confirm.
      PACKAGE_DIST_DIR: ${{ github.workspace }}/jax/jax_rocm_plugin/wheelhouse
      S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
    steps:
      - name: Checkout TheRock
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Checkout JAX
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          path: jax
          repository: rocm/rocm-jax
          # Fixed: `jax_ref` is a matrix variable, not a workflow input.
          # `inputs.jax_ref` evaluated to empty and ignored the matrix pin.
          ref: ${{ matrix.jax_ref }}

      - name: Configure Git Identity
        run: |
          git config --global user.name "therockbot"
          git config --global user.email "therockbot@amd.com"

      # Fixed: enter `jax`, the directory the checkout step above actually
      # creates; `rocm-jax` does not exist in the workspace.
      - name: Build JAX Wheels
        run: |
          pushd jax
          python3 build/ci_build \
            --compiler=clang \
            --python-versions="${{ inputs.python_versions }}" \
            --rocm-version="${{ inputs.rocm_version }}" \
            --therock-path="${{ inputs.tar_url }}" \
            dist_wheels

      # Fixed: was `if: always()`, which attempted the role assumption on
      # forks and after a failed build. The credentials are only consumed by
      # the ROCm-only upload steps below, so use the same condition.
      - name: Configure AWS Credentials
        if: ${{ github.repository_owner == 'ROCm' }}
        uses: aws-actions/configure-aws-credentials@b47578312673ae6fa5b5096b330d9fbac3d116df # v4.2.1
        with:
          aws-region: us-east-2
          role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}-releases

      # Copies only *.whl files from the wheel directory to
      # s3://<bucket>/<s3_subdir>/<amdgpu_family>/.
      - name: Upload wheels to S3
        if: ${{ github.repository_owner == 'ROCm' }}
        run: |
          aws s3 cp ${{ env.PACKAGE_DIST_DIR }}/ s3://${{ env.S3_BUCKET_PY }}/${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}/ \
            --recursive --exclude "*" --include "*.whl"

      # Runs TheRock's s3_management script against the uploaded prefix.
      - name: (Re-)Generate Python package release index
        if: ${{ github.repository_owner == 'ROCm' }}
        run: |
          pip install boto3 packaging
          python ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}