Skip to content

Add CPU wheel build and tests in JAX presubmit. #5

Add CPU wheel build and tests in JAX presubmit.

Add CPU wheel build and tests in JAX presubmit. #5

# CI - Bazel build wheels (RBE)

Check failure on line 1 in .github/workflows/bazel_build_wheels.yml

View workflow run for this annotation

GitHub Actions / .github/workflows/bazel_build_wheels.yml

Invalid workflow file

(Line: 53, Col: 18): Unrecognized named-value: 'github'. Located at position 1 within expression: github.workflow
#
# This workflow builds a JAX wheel artifact using Bazel. Which artifact is built (jax, jaxlib,
# or the CUDA plugin/PJRT artifacts) is selected by the `artifact` input.
#
# It consists of the following jobs:
# build-jax-artifact:
# - Builds the jax artifact from source via build.py invocation. Uses build_artifacts.yml.
# build-jaxlib-artifact:
# - Builds the jaxlib artifact from source via build.py invocation. Uses build_artifacts.yml.
# build-cuda-artifacts:
# - Builds the CUDA artifacts from source via build.py invocation. Uses build_artifacts.yml.
name: CI - Bazel build wheels (RBE)
on:
  workflow_dispatch:  # allows triggering the workflow run manually
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: choice
        default: "linux-x86-n2-16"
        options:
          - "linux-x86-n2-16"
          - "linux-arm64-t2a-48"
          - "linux-arm64-c4a-16"
          - "windows-x86-n2-16"
      artifact:
        description: "Which JAX artifact to build?"
        type: choice
        default: "jaxlib"
        options:
          - "jax"
          - "jaxlib"
          - "jax-cuda-plugin"
          - "jax-cuda-pjrt"
      python:
        description: "Which python version should the artifact be built for?"
        type: choice
        default: "3.12"
        options:
          - "3.11"
          - "3.12"
          - "3.13"
          - "3.14"
          - "3.13-nogil"
          - "3.14-nogil"
      upload_artifacts_to_gcs:
        description: "Should the artifacts be uploaded to a GCS bucket?"
        default: true
        type: boolean
      gcs_upload_uri:
        # Expressions (and the `github` context) are not allowed in workflow_dispatch input
        # defaults, so the default is empty and the jobs below fall back to a per-run location.
        description: "GCS location prefix to where the artifacts should be uploaded (leave empty for a per-run location)"
        default: ''
        type: string
      clone_main_xla:
        description: "Should latest XLA be used?"
        type: choice
        default: "0"
        options:
          - "1"
          - "0"
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        default: 'no'
        options:
          - 'yes'
          - 'no'
  workflow_call:
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        default: "linux-x86-n2-16"
      artifact:
        description: "Which JAX artifact to build?"
        type: string
        default: "jaxlib"
      python:
        description: "Which python version should the artifact be built for?"
        type: string
        default: "3.12"
      upload_artifacts_to_gcs:
        description: "Should the artifacts be uploaded to a GCS bucket?"
        default: true
        type: boolean
      gcs_upload_uri:
        description: "GCS location prefix to where the artifacts should be uploaded"
        # The `github` context is available in workflow_call input defaults (unlike
        # workflow_dispatch), so the per-run location can be computed here.
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
        type: string
      clone_main_xla:
        description: "Should latest XLA be used?"
        type: string
        default: "0"
    outputs:
      gcs_upload_uri:
        description: "GCS location prefix to where the artifacts were uploaded"
        # Only one of the build jobs runs (selected by `inputs.artifact`); skipped jobs expose
        # empty outputs, so take the value from whichever job actually ran.
        value: ${{ jobs.build-jax-artifact.outputs.gcs_upload_uri || jobs.build-jaxlib-artifact.outputs.gcs_upload_uri || jobs.build-cuda-artifacts.outputs.gcs_upload_uri }}
permissions: {}
jobs:
  build-jax-artifact:
    # NOTE(review): restricted to linux-x86 runners, presumably so callers that fan out over
    # several runners build the pure-python jax artifact only once — confirm.
    if: ${{ startsWith(inputs.runner, 'linux-x86') && inputs.artifact == 'jax' }}
    uses: ./.github/workflows/build_artifacts.yml
    name: "Build jax artifact"
    with:
      # Note that since jax is a pure python package, the runner OS and Python values do not
      # matter. In addition, cloning main XLA also has no effect.
      runner: ${{ inputs.runner }}
      artifact: "jax"
      upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
      # Manual (workflow_dispatch) runs default gcs_upload_uri to '', so fall back to a
      # per-run location here, where the `github` context is available.
      gcs_upload_uri: ${{ inputs.gcs_upload_uri || format('gs://general-ml-ci-transient/jax-github-actions/jax/{0}/{1}/{2}', github.workflow, github.run_number, github.run_attempt) }}
  build-jaxlib-artifact:
    if: ${{ inputs.artifact == 'jaxlib' }}
    uses: ./.github/workflows/build_artifacts.yml
    # Note: For reasons unknown, GitHub Actions groups jobs with the same top-level name in the
    # dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
    # values to the name and creates a separate entry for each matrix combination.
    name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
    with:
      runner: ${{ inputs.runner }}
      artifact: "jaxlib"
      python: ${{ inputs.python }}
      clone_main_xla: ${{ inputs.clone_main_xla }}
      upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
      # Manual (workflow_dispatch) runs default gcs_upload_uri to '', so fall back to a
      # per-run location here, where the `github` context is available.
      gcs_upload_uri: ${{ inputs.gcs_upload_uri || format('gs://general-ml-ci-transient/jax-github-actions/jax/{0}/{1}/{2}', github.workflow, github.run_number, github.run_attempt) }}
  build-cuda-artifacts:
    if: ${{ startsWith(inputs.runner, 'linux') && (inputs.artifact == 'jax-cuda-plugin' || inputs.artifact == 'jax-cuda-pjrt') }}
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Both CUDA artifacts are built whenever either one is requested via `inputs.artifact`.
        # NOTE(review): confirm that building both (rather than only the selected one) is intended.
        artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
    name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
    with:
      runner: ${{ inputs.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ inputs.python }}
      clone_main_xla: ${{ inputs.clone_main_xla }}
      upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
      # Manual (workflow_dispatch) runs default gcs_upload_uri to '', so fall back to a
      # per-run location here, where the `github` context is available.
      gcs_upload_uri: ${{ inputs.gcs_upload_uri || format('gs://general-ml-ci-transient/jax-github-actions/jax/{0}/{1}/{2}', github.workflow, github.run_number, github.run_attempt) }}