# Build Portable Linux JAX Wheels (run #23): workflow file for this run.
# Note: GitHub flagged this file as possibly containing hidden or bidirectional
# Unicode text that may be interpreted or compiled differently than it appears.
# Review it in an editor that reveals hidden Unicode characters.
name: Build Portable Linux JAX Wheels

on:
  # Entry point when another workflow calls this one.
  workflow_call:
    inputs:
      amdgpu_family:
        description: AMD GPU family to build for; also the S3 subdirectory under s3_subdir
        required: true
        type: string
      python_versions:
        description: Python version(s) to build wheels for; passed to actions/setup-python and build/ci_build
        required: true
        type: string
      release_type:
        description: The type of release to build ("nightly", or "dev")
        required: true
        type: string
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        required: true
        type: string
      rocm_version:
        description: ROCm version to install
        type: string
      tar_url:
        description: URL to TheRock tarball to build against
        type: string

  # Manual entry point; mirrors the workflow_call inputs with UI defaults.
  workflow_dispatch:
    inputs:
      amdgpu_family:
        description: AMD GPU family to build for; also the S3 subdirectory under s3_subdir
        type: choice
        options:
          - gfx110X-dgpu
          - gfx1151
          - gfx120X-all
          - gfx94X-dcgpu
          - gfx950-dcgpu
        default: gfx94X-dcgpu
      python_versions:
        description: Python version(s) to build wheels for; passed to actions/setup-python and build/ci_build
        required: true
        type: string
        # Quoted so the version stays a string (an unquoted 3.10 would become 3.1).
        default: "3.12"
      release_type:
        description: The type of release to build ("nightly", or "dev")
        required: true
        type: string
        default: "dev"
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        type: string
        default: "v2"
      rocm_version:
        description: ROCm version to install
        type: string
      tar_url:
        description: URL to TheRock tarball to build against
        type: string

permissions:
  # Read-only access to the repository, sufficient for actions/checkout.
  contents: read
  # Allows the job to request an OIDC token; configure-aws-credentials uses it
  # to assume the therock-<release_type>-releases IAM role.
  id-token: write

jobs:
  build_jax_wheels:
    strategy:
      matrix:
        # rocm/rocm-jax refs to build wheels from.
        jax_ref: [master]
    name: "Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_versions }}"
    runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }}
    env:
      # Where build/ci_build leaves the finished wheels (uploaded below).
      PACKAGE_DIST_DIR: ${{ github.workspace }}/jax/jax_rocm_plugin/wheelhouse
      S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
    steps:
      - name: Checkout TheRock
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

      - name: Checkout JAX
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
        with:
          path: jax
          repository: rocm/rocm-jax
          ref: ${{ matrix.jax_ref }}

      - name: Configure Git Identity
        run: |
          git config --global user.name "therockbot"
          git config --global user.email "therockbot@amd.com"

      - name: "Setting up Python"
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
        with:
          python-version: ${{ inputs.python_versions }}

      - name: Build JAX Wheels
        # Inputs are passed through the environment rather than spliced into
        # the script with ${{ }}, so their values are never parsed as shell.
        env:
          ROCM_VERSION: ${{ inputs.rocm_version }}
          PYTHON_VERSIONS: ${{ inputs.python_versions }}
          THEROCK_TAR_URL: ${{ inputs.tar_url }}
        run: |
          ls -lah
          pushd jax
          # ${ROCM_VERSION:0:5} keeps only the first five characters of the
          # version. NOTE(review): assumes an "X.Y.Z" prefix with single-digit
          # components (e.g. "6.4.10" would become "6.4.1"); confirm upstream.
          python3 build/ci_build \
            --compiler=clang \
            --python-versions="${PYTHON_VERSIONS}" \
            --rocm-version="${ROCM_VERSION:0:5}" \
            --therock-path="${THEROCK_TAR_URL}" \
            dist_wheels

      # The AWS setup below is consumed only by the upload/index steps, which
      # run only in the ROCm org and only after a successful build. These steps
      # are gated the same way instead of always(), so forks (which presumably
      # cannot assume the release role) and failed builds skip them.
      - name: Install AWS CLI
        if: ${{ github.repository_owner == 'ROCm' }}
        run: bash ./dockerfiles/install_awscli.sh

      - name: Configure AWS Credentials
        if: ${{ github.repository_owner == 'ROCm' }}
        uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4.3.1
        with:
          aws-region: us-east-2
          role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}-releases

      - name: Upload wheels to S3
        if: ${{ github.repository_owner == 'ROCm' }}
        env:
          S3_SUBDIR: ${{ inputs.s3_subdir }}
          AMDGPU_FAMILY: ${{ inputs.amdgpu_family }}
        run: |
          # Copy only *.whl files from the wheelhouse into <bucket>/<subdir>/<family>/.
          aws s3 cp "${PACKAGE_DIST_DIR}/" "s3://${S3_BUCKET_PY}/${S3_SUBDIR}/${AMDGPU_FAMILY}/" \
            --recursive --exclude "*" --include "*.whl"

      - name: (Re-)Generate Python package release index
        if: ${{ github.repository_owner == 'ROCm' }}
        env:
          S3_SUBDIR: ${{ inputs.s3_subdir }}
          AMDGPU_FAMILY: ${{ inputs.amdgpu_family }}
        run: |
          python3 -m venv .venv
          source .venv/bin/activate
          pip3 install boto3 packaging
          python3 ./build_tools/third_party/s3_management/manage.py "${S3_SUBDIR}/${AMDGPU_FAMILY}"