-
Notifications
You must be signed in to change notification settings - Fork 235
Fix jax plugin build #1226
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix jax plugin build #1226
Changes from all commits
8ce2194
57fdf68
133efbf
2cd0b5c
35283b3
35b3b93
3840f6d
5862e25
49853c7
be4584e
7393636
116b748
cbe5af3
097aa8d
556be66
1e7f0fb
f0b81a6
8a633c2
eee39f0
d44c3f2
256a470
bc1911a
d95caf7
6896c59
b1fbd2c
a6e2f3e
8aeb56a
2b13c45
666bfd9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,12 +26,18 @@ on: | |
| workflow_dispatch: | ||
| inputs: | ||
| amdgpu_family: | ||
| required: true | ||
| type: string | ||
| type: choice | ||
| options: | ||
| - gfx110X-dgpu | ||
| - gfx1151 | ||
| - gfx120X-all | ||
| - gfx94X-dcgpu | ||
| - gfx950-dcgpu | ||
| default: gfx94X-dcgpu | ||
| python_versions: | ||
| required: true | ||
| type: string | ||
| default: | ||
| default: "3.12" | ||
| release_type: | ||
| description: The type of release to build ("nightly", or "dev") | ||
| required: true | ||
|
|
@@ -60,7 +66,7 @@ jobs: | |
| name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_version }} | ||
| runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }} | ||
| env: | ||
| PACKAGE_DIST_DIR: ${{ github.workspace }}/jax_rocm_plugin/wheelhouse | ||
| PACKAGE_DIST_DIR: ${{ github.workspace }}/jax/jax_rocm_plugin/wheelhouse | ||
| S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python" | ||
| steps: | ||
| - name: Checkout TheRock | ||
|
|
@@ -71,23 +77,35 @@ jobs: | |
| with: | ||
| path: jax | ||
| repository: rocm/rocm-jax | ||
| ref: ${{ inputs.jax_ref }} | ||
| ref: ${{ matrix.jax_ref }} | ||
|
|
||
| - name: Configure Git Identity | ||
| run: | | ||
| git config --global user.name "therockbot" | ||
| git config --global user.email "therockbot@amd.com" | ||
|
|
||
| - name: "Setting up Python" | ||
| uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 | ||
| with: | ||
| python-version: ${{ inputs.python_versions }} | ||
|
|
||
| - name: Build JAX Wheels | ||
| env: | ||
| ROCM_VERSION: ${{ inputs.rocm_version }} | ||
| run: | | ||
| pushd rocm-jax | ||
| ls -lah | ||
| pushd jax | ||
| python3 build/ci_build \ | ||
| --compiler=clang \ | ||
| --python-versions="${{ inputs.python_versions }}" \ | ||
| --rocm-version="${{ inputs.rocm_version }}" \ | ||
| --rocm-version="${ROCM_VERSION:0:5}" \ | ||
|
Comment on lines
-87
to
+101
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why truncate here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In a lot of places, the build scripts and our ROCm setup scripts assume that the build number is always going to be of the form |
||
| --therock-path="${{ inputs.tar_url }}" \ | ||
| dist_wheels | ||
|
|
||
| - name: Install AWS CLI | ||
| if: always() | ||
| run: bash ./dockerfiles/install_awscli.sh | ||
|
|
||
| - name: Configure AWS Credentials | ||
| if: always() | ||
| uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4.3.1 | ||
|
|
@@ -104,5 +122,7 @@ jobs: | |
| - name: (Re-)Generate Python package release index | ||
| if: ${{ github.repository_owner == 'ROCm' }} | ||
| run: | | ||
| pip install boto3 packaging | ||
| python ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }} | ||
| python3 -m venv .venv | ||
| source .venv/bin/activate | ||
| pip3 install boto3 packaging | ||
|
geomin12 marked this conversation as resolved.
|
||
| python3 ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }} | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -248,6 +248,24 @@ jobs: | |
| "rocm_version": "${{ needs.setup_metadata.outputs.version }}" | ||
| } | ||
|
|
||
| - name: URL-encode .tar URL | ||
| id: url-encode-tar | ||
| run: python -c "from urllib.parse import quote; print('tar_url=https://therock-${{ env.RELEASE_TYPE }}-tarball.s3.amazonaws.com/' + quote('therock-dist-linux-${{ matrix.target_bundle.amdgpu_family }}${{ inputs.package_suffix }}-${{ needs.setup_metadata.outputs.version }}.tar.gz'))" >> ${GITHUB_OUTPUT} | ||
|
|
||
| - name: Trigger build JAX wheels | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @marbre I think this is the right place to stick the trigger for the nightly build. Is there a way to get the URL of the latest TheRock .tar release though? I didn't see a simple way to do it through this workflow.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The base URL right now is always https://therock-nightly-tarball.s3.us-east-2.amazonaws.com/ (no CloudFront distribution planned). ROCm version is known by
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See https://github.com/ROCm/TheRock/blob/bb4372b9b502177915ade4d0d9f97397283212fd/.github/workflows/release_portable_linux_packages.yml#L117C53-L117C193 and maybe even better use for the filename to make it more robust.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I understand, that's the tarball that will sit on the worker's local filesystem, correct? I'm not super familiar with that specific workflow dispatch action, but is it guaranteed to run the wheel build workflow on the same runner as the workflow that called it? If not, we should keep using the URL
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh I was only referring to the where the variable is composed. The tarbal behind is uploaded to S3 here: https://github.com/ROCm/TheRock/blob/bb4372b9b502177915ade4d0d9f97397283212fd/.github/workflows/release_portable_linux_packages.yml#L210 Wanted to point out that the file name can have an the package_suffix in it if set. This does not apply to scheduled builds by default but can be the case for manually triggered builds.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ahh, gotcha. Also, right now this points to Could we also have for dev builds? (Notice the change in the first part of the URL)
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I used to encode this as build/release type variable so that can create either or URL. |
||
| if: ${{ github.repository_owner == 'ROCm' }} | ||
| uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4 | ||
| with: | ||
| workflow: build_linux_jax_wheels.yml | ||
| inputs: | | ||
| { "amdgpu_family": "${{ matrix.target_bundle.amdgpu_family }}", | ||
| "python_versions": "3.12", | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed Python version here? Should this use a matrix across versions somehow?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The latest release of the JAX plugin supports Python 3.10, 3.11, and 3.12 for the wheel build, but we only use 3.10 and 3.12 in our Ubuntu docker image builds. 3.10 support is going to be dropped next release, so I just left it as 3.12. |
||
| "release_type": "${{ env.RELEASE_TYPE }}", | ||
| "s3_subdir": "${{ env.S3_STAGING_SUBDIR }}", | ||
| "rocm_version": "${{ needs.setup_metadata.outputs.version }}", | ||
| "tar_url": "${{ steps.url-encode-tar.outputs.tar_url }}" | ||
| } | ||
|
|
||
| - name: Save cache | ||
| uses: actions/cache/save@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 | ||
| if: always() | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.