Add CPU wheel build and tests in JAX presubmit. #6

# CI - Bazel build wheels (RBE)
#
# This workflow builds all jax wheels using Bazel.
#
# It consists of the following jobs:
# build-jax-artifact:
# - Builds the jax artifact from source via build.py invocation. Uses build_artifacts.yml.
# build-jaxlib-artifact:
# - Builds the jaxlib artifact from source via build.py invocation. Uses build_artifacts.yml.
# build-cuda-artifacts:
# - Builds the CUDA artifacts (jax-cuda-plugin and jax-cuda-pjrt) from source via build.py invocation. Uses build_artifacts.yml.
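#
# Manual trigger example (a sketch, assuming the GitHub CLI `gh` is installed and
# authenticated for this repository; the input values below are illustrative only).
# List-valued inputs are passed as JSON strings, as the input descriptions specify.
# Add `--ref <branch>` to run the workflow file from a branch other than the default.
#   gh workflow run bazel_build_wheels.yml \
#     -f runners='["linux-x86-n4-16"]' \
#     -f artifacts='["jax", "jaxlib"]' \
#     -f python-versions='["3.12"]' \
#     -f clone_main_xla=0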
name: CI - Bazel build wheels (RBE)
on:
  workflow_dispatch: # allows triggering the workflow run manually
    inputs:
      runners:
        description: "Which runners should the workflow run on? Format example: '[\"linux-x86-n4-16\", \"linux-arm64-c4a-16\"]'"
        type: string
        default: '["linux-x86-n4-16"]'
      artifacts:
        description: "Which JAX artifacts to build? Format example: '[\"jaxlib\", \"jax\"]'"
        type: string
        default: '["jaxlib"]'
      python-versions:
        description: "Which Python versions should the artifacts be built for? Format example: '[\"3.12\", \"3.13\"]'"
        type: string
        default: '["3.12"]'
      cuda-versions:
        description: "Which CUDA versions should the artifacts be built for? Format example: '[\"12\", \"13\"]'"
        type: string
        default: '["12"]'
      clone_main_xla:
        description: "Should the latest XLA be used?"
        type: choice
        default: "0"
        options:
          - "1"
          - "0"
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        default: 'no'
        options:
          - 'yes'
          - 'no'
  workflow_call:
    inputs:
      runners:
        description: "Which runners should the workflow run on? Pass a JSON list, e.g. '[\"linux-x86-n2-16\"]'"
        type: string
        default: '["linux-x86-n2-16"]'
      artifacts:
        description: "Which JAX artifacts to build? Pass a JSON list, e.g. '[\"jaxlib\"]'"
        type: string
        default: '["jaxlib"]'
      python-versions:
        description: "Which Python versions should the artifacts be built for? Pass a JSON list, e.g. '[\"3.12\"]'"
        type: string
        default: '["3.12"]'
      cuda-versions:
        description: "Which CUDA versions should the artifacts be built for? Pass a JSON list, e.g. '[\"12\"]'"
        type: string
        default: '["12"]'
      upload_artifacts_to_gcs:
        description: "Should the artifacts be uploaded to a GCS bucket?"
        default: true
        type: boolean
      gcs_upload_uri:
        description: "GCS location prefix where the artifacts should be uploaded"
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
        type: string
      clone_main_xla:
        description: "Should the latest XLA be used?"
        type: string
        default: "0"
    outputs:
      gcs_upload_uri:
        description: "GCS location prefix where the artifacts were uploaded"
        value: ${{ jobs.build-jax-artifact.outputs.gcs_upload_uri }}
permissions: {}
jobs:
  build-jax-artifact:
    # The list inputs are JSON strings: a substring check on `runners` detects any linux-x86
    # runner, while `fromJSON` gives an exact match on the artifact name (a plain substring
    # check for "jax" would also match "jaxlib").
    if: ${{ contains(inputs.runners, 'linux-x86') && contains(fromJSON(inputs.artifacts), 'jax') }}
    uses: ./.github/workflows/build_artifacts.yml
    name: "Build jax artifact"
    with:
      # Note that since jax is a pure Python package, the runner OS and Python values do not
      # matter. In addition, cloning main XLA also has no effect.
      runner: "linux-x86-n4-16"
      artifact: "jax"
      upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
      gcs_upload_uri: ${{ inputs.gcs_upload_uri }}
  build-jaxlib-artifact:
    if: ${{ contains(fromJSON(inputs.artifacts), 'jaxlib') }}
    uses: ./.github/workflows/build_artifacts.yml
    # Note: For reasons unknown, GitHub Actions groups jobs with the same top-level name in the
    # dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
    # values to the name and creates a separate entry for each matrix combination.
    name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
    strategy:
      fail-fast: false # don't cancel all jobs on failure
      matrix:
        # Matrix dimensions must be arrays, so parse the JSON-list string inputs.
        runner: ${{ fromJSON(inputs.runners) }}
        python: ${{ fromJSON(inputs.python-versions) }}
    with:
      runner: ${{ matrix.runner }}
      artifact: "jaxlib"
      python: ${{ matrix.python }}
      clone_main_xla: ${{ inputs.clone_main_xla }}
      upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
      gcs_upload_uri: ${{ inputs.gcs_upload_uri }}
  build-cuda-artifacts:
    if: ${{ contains(inputs.runners, 'linux') && (contains(fromJSON(inputs.artifacts), 'jax-cuda-plugin') || contains(fromJSON(inputs.artifacts), 'jax-cuda-pjrt')) }}
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false # don't cancel all jobs on failure
      matrix:
        runner: ${{ fromJSON(inputs.runners) }}
        python: ${{ fromJSON(inputs.python-versions) }}
        artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
        cuda-version: ${{ fromJSON(inputs.cuda-versions) }}
    name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
    with:
      runner: ${{ matrix.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ matrix.python }}
      cuda-version: ${{ matrix.cuda-version }}
      clone_main_xla: ${{ inputs.clone_main_xla }}
      upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
      gcs_upload_uri: ${{ inputs.gcs_upload_uri }}
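
Since the workflow also declares a `workflow_call` trigger, another workflow, such as the CPU presubmit this PR is working toward, could reuse it. The following is a hypothetical caller, not part of this PR: the workflow name, the job names `build-wheels` and `report-upload-location`, and the input values are illustrative assumptions. It passes the list-valued inputs as JSON strings, the format documented in the `workflow_dispatch` input descriptions.

```yaml
# Hypothetical caller workflow (sketch only); names and input values are illustrative.
name: CI - Presubmit (example)
on:
  pull_request:
permissions:
  contents: read
jobs:
  build-wheels:
    # Reuse the Bazel wheel-build workflow via its workflow_call trigger.
    uses: ./.github/workflows/bazel_build_wheels.yml
    with:
      runners: '["linux-x86-n2-16"]'
      artifacts: '["jax", "jaxlib"]'
      python-versions: '["3.12"]'
      upload_artifacts_to_gcs: true
  report-upload-location:
    needs: build-wheels
    runs-on: ubuntu-latest
    steps:
      # Read the reusable workflow's gcs_upload_uri output through the needs context.
      - name: Print the GCS upload prefix
        run: echo "Artifacts uploaded under ${{ needs.build-wheels.outputs.gcs_upload_uri }}"
```

One consequence worth checking: the workflow-level `gcs_upload_uri` output is wired only to `jobs.build-jax-artifact.outputs.gcs_upload_uri`, so a caller whose inputs skip the `build-jax-artifact` job will most likely see an empty value, assuming `build_artifacts.yml` exposes that output at all.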