
Fix jax plugin build #1226

Merged
charleshofer merged 29 commits into main from fix-jax-plugin-build on Sep 29, 2025
Conversation

@charleshofer (Contributor) commented Aug 11, 2025:

Motivation

Fixes problems with the JAX CI workflow that was created in #1033

Technical Details

Fixes minor problems that made the workflow crash, and ensures that the command-line tools we need are installed.

Test Plan

Make sure the workflow works when run with Actions: https://github.com/ROCm/TheRock/actions/workflows/build_linux_jax_wheels.yml

@marbre (Member) left a comment:

Drive-by: I noticed jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl and jax_rocm7_plugin-0.6.0-cp310-cp310-manylinux_2_28_x86_64.whl got uploaded to v2/gfx942 in the dev bucket but they must go into v2/gfx94X-dcgpu instead.

@charleshofer (Contributor, Author) replied:

> Drive-by: I noticed jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl and jax_rocm7_plugin-0.6.0-cp310-cp310-manylinux_2_28_x86_64.whl got uploaded to v2/gfx942 in the dev bucket but they must go into v2/gfx94X-dcgpu instead.

Gotcha. Looks like the amdgpu_family gets used in two different ways in the workflow. Once as an input to the JAX plugin build, and another time as an input to the upload script. I assume it'd be bad to hardcode the upload to v2/gfx94X-dcgpu, right? Because we might want to upload to gfx95X. The JAX build requires specific GFX numbers. It can't take something like gfx94X. Maybe we just make two inputs to the workflow? One for the build and one for the bucket?

@marbre (Member) commented Aug 19, 2025:

> Drive-by: I noticed jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl and jax_rocm7_plugin-0.6.0-cp310-cp310-manylinux_2_28_x86_64.whl got uploaded to v2/gfx942 in the dev bucket but they must go into v2/gfx94X-dcgpu instead.

> Gotcha. Looks like the amdgpu_family gets used in two different ways in the workflow. Once as an input to the JAX plugin build, and another time as an input to the upload script. I assume it'd be bad to hardcode the upload to v2/gfx94X-dcgpu, right? Because we might want to upload to gfx95X. The JAX build requires specific GFX numbers. It can't take something like gfx94X. Maybe we just make two inputs to the workflow? One for the build and one for the bucket?

Other workflows incorporate scripts that use https://github.com/ROCm/TheRock/blob/main/build_tools/github_actions/amdgpu_family_matrix.py. @geomin12 has more context here, but normally we avoid having two inputs.

@geomin12 (Contributor) replied:

> Drive-by: I noticed jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl and jax_rocm7_plugin-0.6.0-cp310-cp310-manylinux_2_28_x86_64.whl got uploaded to v2/gfx942 in the dev bucket but they must go into v2/gfx94X-dcgpu instead.

> Gotcha. Looks like the amdgpu_family gets used in two different ways in the workflow. Once as an input to the JAX plugin build, and another time as an input to the upload script. I assume it'd be bad to hardcode the upload to v2/gfx94X-dcgpu, right? Because we might want to upload to gfx95X. The JAX build requires specific GFX numbers. It can't take something like gfx94X. Maybe we just make two inputs to the workflow? One for the build and one for the bucket?

Hi! Yes, for inputs we figured the abbreviation would be easier for workflow_dispatch triggers. In general, we typically have a setup step/job that determines which targets to run, where we get these from amdgpu_family_matrix.py. We use gfx94X as a key, then get the corresponding full family name (gfx94X -> gfx94X-dcgpu).

I would recommend creating a setup job/step and using GitHub outputs (setup.yml as an example) to pass this information to the JAX plugin segment, so it can be a bit more dynamic.
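A rough sketch of that suggestion (job, step, and input names here are hypothetical, and the inline mapping stands in for the real lookup in amdgpu_family_matrix.py):

```yaml
jobs:
  setup:
    runs-on: ubuntu-latest
    outputs:
      amdgpu_family: ${{ steps.map.outputs.amdgpu_family }}
    steps:
      - name: Map family key to full family name
        id: map
        run: |
          # Hypothetical inline mapping; the real lookup would use
          # build_tools/github_actions/amdgpu_family_matrix.py
          case "${{ inputs.amdgpu_family }}" in
            gfx94X) echo "amdgpu_family=gfx94X-dcgpu" >> "$GITHUB_OUTPUT" ;;
            *)      echo "amdgpu_family=${{ inputs.amdgpu_family }}" >> "$GITHUB_OUTPUT" ;;
          esac

  build_jax_plugin:
    needs: setup
    runs-on: ubuntu-latest
    steps:
      - name: Show resolved family
        # Downstream jobs read the resolved family via needs.setup.outputs
        run: echo "Building for ${{ needs.setup.outputs.amdgpu_family }}"
```

This way workflow_dispatch keeps the short key as its single input, while the upload step still receives the full family name.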

@amd-chrissosa (Contributor) commented:

Any progress needed here? I understand JAX support is critical for a specific set of users.

@charleshofer charleshofer marked this pull request as ready for review August 28, 2025 21:40
@charleshofer (Contributor, Author) commented Aug 28, 2025:

Do you guys want me to set this up to run on a schedule? Nightly? Weekly? @gabeweisz

@gabeweisz replied:

> Do you guys want me to set this up to run on a schedule? Nightly? Weekly? @gabeweisz

Nightly please

@@ -247,6 +247,20 @@ jobs:
"rocm_version": "${{ needs.setup_metadata.outputs.version }}"
}

- name: Trigger build JAX wheels
Contributor (Author):

@marbre I think this is the right place to stick the trigger for the nightly build. Is there a way to get the URL of the latest TheRock .tar release though? I didn't see a simple way to do it through this workflow.

Member:

The base URL right now is always https://therock-nightly-tarball.s3.us-east-2.amazonaws.com/ (no CloudFront distribution is planned). The ROCm version is known via ${{ needs.setup_metadata.outputs.version }} and the GPU family via ${{ matrix.target_bundle.amdgpu_family }}. Thus the URL should be

https://therock-nightly-tarball.s3.us-east-2.amazonaws.com/therock-dist-linux-${{ matrix.target_bundle.amdgpu_family }}-${{ needs.setup_metadata.outputs.version }}.tar.gz

Member:

See https://github.com/ROCm/TheRock/blob/bb4372b9b502177915ade4d0d9f97397283212fd/.github/workflows/release_portable_linux_packages.yml#L117C53-L117C193 and maybe even better use

therock-dist-linux-${{ matrix.target_bundle.amdgpu_family }}${{ inputs.package_suffix }}-${{ needs.setup_metadata.outputs.version }}.tar.gz

for the filename to make it more robust.
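Put together, a step that composes that filename into a variable could look roughly like this (the step name and the TAG_URL variable are illustrative, not part of the existing workflow):

```yaml
- name: Compose release tarball URL
  run: |
    # package_suffix is empty for scheduled builds, but may be set for manual runs
    echo "TAG_URL=https://therock-nightly-tarball.s3.us-east-2.amazonaws.com/therock-dist-linux-${{ matrix.target_bundle.amdgpu_family }}${{ inputs.package_suffix }}-${{ needs.setup_metadata.outputs.version }}.tar.gz" >> "$GITHUB_ENV"
```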

Contributor (Author):

If I understand correctly, that's the tarball that will sit on the worker's local filesystem? I'm not super familiar with that specific workflow dispatch action, but is it guaranteed to run the wheel build workflow on the same runner as the workflow that called it? If not, we should keep using the URL.

@marbre (Member) commented Sep 4, 2025:

Oh, I was only referring to where the variable is composed. The tarball behind it is uploaded to S3 here: https://github.com/ROCm/TheRock/blob/bb4372b9b502177915ade4d0d9f97397283212fd/.github/workflows/release_portable_linux_packages.yml#L210

Wanted to point out that the file name can have the package_suffix in it if set. This does not apply to scheduled builds by default, but can be the case for manually triggered builds.

@charleshofer (Contributor, Author) commented Sep 4, 2025:

Ahh, gotcha. Also, right now this points to

https://therock-nightly-tarball.s3.us-east-2.amazonaws.com/therock-dist-linux-${{ matrix.target_bundle.amdgpu_family }}${{ inputs.package_suffix }}-${{ needs.setup_metadata.outputs.version }}.tar.gz

Could we also have

https://therock-dev-tarball.s3.us-east-2.amazonaws.com/therock-dist-linux-${{ matrix.target_bundle.amdgpu_family }}${{ inputs.package_suffix }}-${{ needs.setup_metadata.outputs.version }}.tar.gz

for dev builds? (Notice the change in the first part of the URL)

Member:

I used to encode this as a build/release type variable so that one can create either URL.
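One possible sketch of that, assuming a release_type input with values like "nightly" and "dev" (the input name and TARBALL_BASE_URL variable are assumptions):

```yaml
- name: Select tarball base URL by release type
  run: |
    # Hypothetical: dev builds read from the dev bucket, everything else from nightly
    if [ "${{ inputs.release_type }}" = "dev" ]; then
      echo "TARBALL_BASE_URL=https://therock-dev-tarball.s3.us-east-2.amazonaws.com" >> "$GITHUB_ENV"
    else
      echo "TARBALL_BASE_URL=https://therock-nightly-tarball.s3.us-east-2.amazonaws.com" >> "$GITHUB_ENV"
    fi
```

Later steps could then build the full URL as ${TARBALL_BASE_URL}/therock-dist-linux-...tar.gz without hardcoding the bucket.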

Comment on lines +98 to +105
- name: Install AWS CLI
  if: always()
  run: |
    curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
    unzip awscliv2.zip
    ./aws/install -i ./aws-cli-files -b ./aws-bin
    sudo ./aws/install

Contributor:

If you use this container (image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:543ba2609de3571d2c64f3872e5f1af42fdfa90d074a7baccb1db120c9514be2), it already comes with the AWS CLI installed! I would recommend replacing this and just using the container :)

Contributor (Author):

We cannot use the container that has AWS pre-installed. The build/ci_build script that does the JAX wheel build creates a docker image and does the actual build inside of that so that we're manylinux compliant.

Contributor:

in that case, we do have this script (https://github.com/ROCm/TheRock/blob/main/dockerfiles/install_awscli.sh) that does the same thing here! less duplication!
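Swapping the inline install for the shared script might look like this (assuming the repository is checked out and the script is self-contained at that path):

```yaml
- name: Install AWS CLI
  if: always()
  run: |
    # Reuse the shared installer instead of duplicating the curl/unzip steps
    bash ./dockerfiles/install_awscli.sh
```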

Contributor (Author):

I'll give that a shot

Comment thread .github/workflows/build_linux_jax_wheels.yml
@@ -104,5 +119,6 @@ jobs:
- name: (Re-)Generate Python package release index
if: ${{ github.repository_owner == 'ROCm' }}
run: |
pip install boto3 packaging
python ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}
sudo apt install python3-venv python3-pip -y
Contributor:

can be removed if container is used!

you can also use:

- name: "Setting up Python"
  uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
  with:
    python-version: 3.11

to help with python versioning!

Contributor (Author):

We can't use the container, but I'll switch it to using the actions step

Comment thread .github/workflows/build_linux_jax_wheels.yml

"release_type": "${{ env.RELEASE_TYPE }}",
"s3_subdir": "${{ env.S3_SUBDIR }}",
"rocm_version": "${{ needs.setup_metadata.outputs.version }}",
"tag_url": "${{ ??? }}"
Member:

CI fails with invalid workflow on this.

- name: "Setting up Python"
  uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
  with:
    python-version: 3.11
Contributor:

want to use the input here? inputs.python_versions
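Wired up to the input, that step would read roughly as follows (assuming python_versions carries a single version string that setup-python accepts):

```yaml
- name: "Setting up Python"
  uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
  with:
    # Take the version from the workflow input instead of hardcoding 3.11
    python-version: ${{ inputs.python_versions }}
```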

"release_type": "${{ env.RELEASE_TYPE }}",
"s3_subdir": "${{ env.S3_SUBDIR }}",
"rocm_version": "${{ needs.setup_metadata.outputs.version }}",
"tag_url": "https://therock-nightly-tarball.s3.us-east-2.amazonaws.com/therock-dist-linux-${{ matrix.target_bundle.amdgpu_family }}-${{ needs.setup_metadata.outputs.version }}.tar.gz"
Contributor:

nit: pretty sure you can remove the us-east-2 part

Member:

Yeah, should be possible as well.

@geomin12 (Contributor) left a comment:

lgtm! although, can we test this by triggering release_portable_linux_packages using workflow_dispatch? just to make sure everything works?

@charleshofer charleshofer force-pushed the fix-jax-plugin-build branch 2 times, most recently from beec084 to 107c6bf Compare September 4, 2025 21:13
@charleshofer (Contributor, Author) commented:

I kicked off a nightly build with this branch: https://github.com/ROCm/TheRock/actions/runs/17479247009

@charleshofer (Contributor, Author) commented Sep 26, 2025:

Comment on lines 116 to 120
- name: Upload wheels to S3
  if: ${{ github.repository_owner == 'ROCm' }}
  run: |
    aws s3 cp ${{ env.PACKAGE_DIST_DIR }}/ s3://${{ env.S3_BUCKET_PY }}/${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}/ \
      --recursive --exclude "*" --include "*.whl"
Member:

How close are these wheels to being ready to advertise? If not ready yet, should we upload them to s3_staging_subdir instead of s3_subdir? We started uploading pytorch wheels to "staging" first, then copying them out once tests complete.

Once ready, we should advertise JAX support here:

Contributor (Author):

These are not ready. I don't have the automated testing in place yet to make sure that they're good. Will switch to using the staging directory.

Member:

Thanks. You might want to put a TODO in there to

  1. Plumb through both the staging and "prod" URLs
  2. First upload to staging, then run tests, then copy to "prod" if tests passed (matching what we do for torch)
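A rough sketch of that staged flow, assuming an s3_staging_subdir input alongside the existing s3_subdir (step names and the promote condition are illustrative):

```yaml
# TODO sketch: upload to staging, run tests, then promote (mirrors the torch flow)
- name: Upload wheels to S3 staging
  if: ${{ github.repository_owner == 'ROCm' }}
  run: |
    aws s3 cp ${{ env.PACKAGE_DIST_DIR }}/ \
      s3://${{ env.S3_BUCKET_PY }}/${{ inputs.s3_staging_subdir }}/${{ inputs.amdgpu_family }}/ \
      --recursive --exclude "*" --include "*.whl"

# ... run wheel tests against the staged artifacts here ...

- name: Promote wheels to prod
  if: ${{ success() && github.repository_owner == 'ROCm' }}
  run: |
    aws s3 cp \
      s3://${{ env.S3_BUCKET_PY }}/${{ inputs.s3_staging_subdir }}/${{ inputs.amdgpu_family }}/ \
      s3://${{ env.S3_BUCKET_PY }}/${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}/ \
      --recursive --exclude "*" --include "*.whl"
```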

@charleshofer charleshofer merged commit 341d943 into main Sep 29, 2025
29 of 32 checks passed
@charleshofer charleshofer deleted the fix-jax-plugin-build branch September 29, 2025 08:15
@github-project-automation github-project-automation Bot moved this from TODO to Done in TheRock Triage Sep 29, 2025