# Add CPU wheel build and tests in JAX presubmit. #5
# Workflow file for this run
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# CI - Bazel build wheels (RBE)
#
# This workflow builds a single JAX artifact, selected by the `artifact` input, using Bazel.
#
# It consists of the following jobs; only the job matching the `artifact` input runs:
# build-jax-artifact:
#   - Builds the jax artifact from source via build.py invocation. Uses build_artifacts.yml.
#     Runs only when the selected runner is a linux-x86 runner.
# build-jaxlib-artifact:
#   - Builds the jaxlib artifact from source via build.py invocation. Uses build_artifacts.yml.
# build-cuda-artifacts:
#   - Builds both the jax-cuda-plugin and jax-cuda-pjrt artifacts from source via build.py
#     invocation when either CUDA artifact is selected. Runs only on linux runners.
#     Uses build_artifacts.yml.
name: CI - Bazel build wheels (RBE)
on:
  workflow_dispatch: # allows triggering the workflow run manually
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: choice
        default: "linux-x86-n2-16"
        options:
          - "linux-x86-n2-16"
          - "linux-arm64-t2a-48"
          - "linux-arm64-c4a-16"
          - "windows-x86-n2-16"
      artifact:
        description: "Which JAX artifact to build?"
        type: choice
        default: "jaxlib"
        options:
          - "jax"
          - "jaxlib"
          - "jax-cuda-plugin"
          - "jax-cuda-pjrt"
      python:
        description: "Which python version should the artifact be built for?"
        type: choice
        default: "3.12"
        options:
          - "3.11"
          - "3.12"
          - "3.13"
          - "3.14"
          - "3.13-nogil"
          - "3.14-nogil"
      upload_artifacts_to_gcs:
        description: "Should the artifacts be uploaded to a GCS bucket?"
        default: true
        type: boolean
      gcs_upload_uri:
        description: "GCS location prefix to where the artifacts should be uploaded"
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
        type: string
      clone_main_xla:
        description: "Should latest XLA be used?"
        type: choice
        default: "0"
        options:
          - "1"
          - "0"
      # NOTE(review): this input is not passed to any job in this workflow, so choosing 'yes'
      # has no visible effect here. Confirm whether it should be forwarded to build_artifacts.yml.
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        default: 'no'
        options:
          - 'yes'
          - 'no'
  workflow_call:
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        default: "linux-x86-n2-16"
      artifact:
        description: "Which JAX artifact to build?"
        type: string
        default: "jaxlib"
      python:
        description: "Which python version should the artifact be built for?"
        type: string
        default: "3.12"
      upload_artifacts_to_gcs:
        description: "Should the artifacts be uploaded to a GCS bucket?"
        default: true
        type: boolean
      gcs_upload_uri:
        description: "GCS location prefix to where the artifacts should be uploaded"
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
        type: string
      clone_main_xla:
        description: "Should latest XLA be used?"
        type: string
        default: "0"
    outputs:
      gcs_upload_uri:
        description: "GCS location prefix to where the artifacts were uploaded"
        # At most one build job runs for a given `artifact` input; the others are skipped and
        # expose empty outputs. Take the value from whichever job actually ran, instead of only
        # build-jax-artifact (which is skipped for jaxlib and CUDA builds).
        value: ${{ jobs.build-jax-artifact.outputs.gcs_upload_uri || jobs.build-jaxlib-artifact.outputs.gcs_upload_uri || jobs.build-cuda-artifacts.outputs.gcs_upload_uri }}
# Grant the GITHUB_TOKEN no permissions at the workflow level.
permissions: {}
jobs:
  build-jax-artifact:
    # jax is only built when a linux-x86 runner is selected; with any other runner this job is
    # skipped and nothing is built for artifact == 'jax'.
    if: ${{ startsWith(inputs.runner, 'linux-x86') && inputs.artifact == 'jax' }}
    uses: ./.github/workflows/build_artifacts.yml
    name: "Build jax artifact"
    with:
      # Note that since jax is a pure python package, the runner OS and Python values do not
      # matter. In addition, cloning main XLA also has no effect.
      runner: ${{ inputs.runner }}
      artifact: "jax"
      upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
      gcs_upload_uri: ${{ inputs.gcs_upload_uri }}
  build-jaxlib-artifact:
    if: ${{ inputs.artifact == 'jaxlib' }}
    uses: ./.github/workflows/build_artifacts.yml
    # Note: For reasons unknown, GitHub Actions groups jobs with the same top-level name in the
    # dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
    # values to the name and creates a separate entry for each matrix combination.
    name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
    with:
      runner: ${{ inputs.runner }}
      artifact: "jaxlib"
      python: ${{ inputs.python }}
      clone_main_xla: ${{ inputs.clone_main_xla }}
      upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
      gcs_upload_uri: ${{ inputs.gcs_upload_uri }}
  build-cuda-artifacts:
    # Runs only on linux runners, and only when one of the CUDA artifacts is selected.
    if: ${{ startsWith(inputs.runner, 'linux') && (inputs.artifact == 'jax-cuda-plugin' || inputs.artifact == 'jax-cuda-pjrt') }}
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false # don't cancel all jobs on failure
      matrix:
        # Both CUDA artifacts are always built, regardless of which one was selected via the
        # `artifact` input.
        artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
    # Expression in "name" keeps the matrix jobs grouped under one dashboard entry.
    name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
    with:
      runner: ${{ inputs.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ inputs.python }}
      clone_main_xla: ${{ inputs.clone_main_xla }}
      upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
      gcs_upload_uri: ${{ inputs.gcs_upload_uri }}