Skip to content

CI - Wheel Tests (Continuous) #1164

CI - Wheel Tests (Continuous)

CI - Wheel Tests (Continuous) #1164

# CI - Wheel Tests (Continuous)
#
# This workflow builds JAX artifacts and runs CPU/TPU/CUDA tests.
#
# It orchestrates the following:
# 1. build-jax-artifact / build-jaxlib-artifact: Call the `build_artifacts.yml` workflow to build
# the jax and jaxlib artifacts and upload them to a GCS bucket.
# 2. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel
# that was built in the previous step and runs CPU tests.
# 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and
# uploads them to a GCS bucket.
# 4. run-bazel-test-cpu-py-import: Calls the `bazel_cpu.yml` workflow which
# runs Bazel CPU tests with py_import on RBE.
# 5. run-bazel-test-cuda-py-import: Calls the `bazel_cuda.yml` workflow which
# runs Bazel CUDA tests with py_import on non-RBE.
# 6. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA
# artifacts that were built in the previous steps and runs the CUDA tests.
# 7. run-bazel-test-cuda: Calls the `bazel_cuda.yml` workflow which downloads the jaxlib
# and CUDA artifacts that were built in the previous steps and runs the
# CUDA tests using Bazel.
# 8. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow which downloads the jaxlib wheel
# that was built in the previous step and runs TPU tests.
# 9. run-bazel-test-tpu: Calls the `bazel_test_tpu.yml` workflow which
# runs Bazel TPU tests with py_import.
# Workflow display name shown in the GitHub Actions UI.
name: CI - Wheel Tests (Continuous)
# Restrict the workflow token to read-only access to repository contents.
permissions:
  contents: read
on:
  schedule:
    - cron: "0 0 * * *" # Run once per day
  workflow_dispatch: # allows triggering the workflow run manually
concurrency:
  # One concurrency group per workflow and ref. `github.head_ref` is only set for
  # pull request events, so for the triggers above the group falls back to `github.ref`.
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  # Never cancel in-progress runs on release branches or on main. `github.ref` is the
  # fully qualified ref (e.g. `refs/heads/main`), so it has to be compared against
  # `refs/heads/main`; a comparison with the bare branch name `main` is always unequal
  # and would leave runs on main cancellable.
  cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'refs/heads/main' }}
jobs:
build-jax-artifact:
  # Builds the `jax` artifact through the reusable build_artifacts.yml workflow and
  # requests an upload to the per-run GCS location below.
  name: "Build jax artifact"
  uses: ./.github/workflows/build_artifacts.yml
  with:
    artifact: "jax"
    # jax is a pure Python package, so neither the runner OS nor the Python version
    # matters for this build. Cloning XLA at main would have no effect either.
    runner: "linux-x86-n4-16"
    upload_artifacts_to_gcs: true
    gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
build-jaxlib-artifact:
  # Builds the `jaxlib` artifact for every runner in the matrix through the reusable
  # build_artifacts.yml workflow and requests an upload to the per-run GCS location.
  #
  # The "name" field uses an expression on purpose: GitHub Actions only groups jobs that
  # share a top-level name in the dashboard when the name contains an expression.
  # Without one, the matrix values are appended to the name and every matrix
  # combination gets its own entry.
  name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
  uses: ./.github/workflows/build_artifacts.yml
  strategy:
    fail-fast: false # a failing matrix job must not cancel the remaining ones
    matrix:
      # Keep the runner OS and Python values in sync with the matrix strategy of the
      # CPU tests job.
      runner:
        - "linux-x86-n4-16"
        - "linux-arm64-t2a-48"
        - "windows-x86-n2-16"
      artifact:
        - "jaxlib"
      python:
        - "3.11"
  with:
    runner: ${{ matrix.runner }}
    artifact: ${{ matrix.artifact }}
    python: ${{ matrix.python }}
    clone_main_xla: 1
    upload_artifacts_to_gcs: true
    gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
build-cuda-artifacts:
  # Builds the CUDA plugin and PJRT artifacts for each CUDA major version in the matrix
  # through the reusable build_artifacts.yml workflow and requests an upload to the
  # per-run GCS location. The expression in "name" keeps all matrix combinations grouped
  # under a single dashboard entry.
  name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
  uses: ./.github/workflows/build_artifacts.yml
  strategy:
    fail-fast: false # a failing matrix job must not cancel the remaining ones
    matrix:
      # Keep the Python values in sync with the matrix strategy of the CUDA tests job below.
      runner:
        - "linux-x86-n4-16"
      artifact:
        - "jax-cuda-plugin"
        - "jax-cuda-pjrt"
      python:
        - "3.11"
      cuda-version:
        - "12"
        - "13"
  with:
    runner: ${{ matrix.runner }}
    artifact: ${{ matrix.artifact }}
    python: ${{ matrix.python }}
    cuda-version: ${{ matrix.cuda-version }}
    clone_main_xla: 1
    upload_artifacts_to_gcs: true
    gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
run-pytest-cpu:
  # Runs the reusable pytest_cpu.yml workflow for each runner / x64 combination, pointing
  # it at the GCS location reported by the jaxlib build job.
  name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
  # Keep running the tests even when a build job fails, so that one unrelated build
  # failure (e.g. only the Windows build breaks) does not cost us test coverage on the
  # other platforms.
  if: ${{ !cancelled() }}
  needs: [build-jax-artifact, build-jaxlib-artifact]
  uses: ./.github/workflows/pytest_cpu.yml
  strategy:
    fail-fast: false # a failing matrix job must not cancel the remaining ones
    matrix:
      # Keep the runner OS and Python values in sync with the matrix strategy of the
      # build-jaxlib-artifact job above.
      runner:
        - "linux-x86-n4-64"
        - "linux-arm64-t2a-48"
        - "windows-x86-n2-64"
      python:
        - "3.11"
      enable-x64:
        - 1
        - 0
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    enable-x64: ${{ matrix.enable-x64 }}
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
run-pytest-cuda:
  # Runs the reusable pytest_cuda.yml workflow for each CUDA configuration / x64
  # combination, pointing it at the GCS location reported by the jaxlib build job.
  name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})"
  # Keep running the tests even when a build job fails, so that one unrelated build
  # failure (e.g. only the Windows build breaks) does not cost us test coverage on the
  # other platforms.
  if: ${{ !cancelled() }}
  needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
  uses: ./.github/workflows/pytest_cuda.yml
  strategy:
    fail-fast: false # a failing matrix job must not cancel the remaining ones
    matrix:
      # Keep the Python values in sync with the matrix strategy of the artifact build
      # jobs above. See the exclusions for what is fully tested.
      # The H100 and B200 runners are disabled on jax-fork to save resources.
      runner:
        - "linux-x86-g2-48-l4-4gpu"
        # - "linux-x86-a3-8g-h100-8gpu"
        # - "linux-x86-a4-224-b200-1gpu"
      python:
        - "3.11"
      cuda:
        - version: "12.1"
          use-nvidia-pip-wheels: false
        - version: "12.9"
          use-nvidia-pip-wheels: true
        - version: "13"
          use-nvidia-pip-wheels: true
      enable-x64:
        - 1
        - 0
      # Exclusions that belong to the disabled runners above; re-enable them together.
      # exclude:
      #   # H100 runs only a single config, CUDA 12.9 Enable x64 1
      #   - runner: "linux-x86-a3-8g-h100-8gpu"
      #     cuda:
      #       version: "12.1"
      #   - runner: "linux-x86-a3-8g-h100-8gpu"
      #     enable-x64: "0"
      #   # B200 runs only a single config, CUDA 12.9 Enable x64 1
      #   - runner: "linux-x86-a4-224-b200-1gpu"
      #     cuda:
      #       version: "12.1"
      #   - runner: "linux-x86-a4-224-b200-1gpu"
      #     enable-x64: "0"
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    cuda-version: ${{ matrix.cuda.version }}
    use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }}
    enable-x64: ${{ matrix.enable-x64 }}
    # Both artifact build jobs upload to the same GCS URI, so the jaxlib job's output
    # is sufficient here.
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
run-bazel-test-cpu-py-import:
  # Runs the reusable bazel_cpu.yml workflow for each runner / x64 combination with
  # both `build_jaxlib` and `build_jax` set to "wheel". The job declares no `needs`,
  # so it does not wait for any of the artifact build jobs.
  name: "Bazel CPU tests with ${{ format('{0}', 'build_jaxlib=wheel') }}"
  uses: ./.github/workflows/bazel_cpu.yml
  strategy:
    fail-fast: false # a failing matrix job must not cancel the remaining ones
    matrix:
      runner:
        - "linux-x86-n4-16"
        - "linux-arm64-t2a-48"
        - "windows-x86-n2-16"
      python:
        - "3.11"
      enable-x64:
        - 1
        - 0
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    enable-x64: ${{ matrix.enable-x64 }}
    build_jaxlib: "wheel"
    build_jax: "wheel"
run-bazel-test-cuda:
  # Runs the reusable bazel_cuda.yml workflow with `build_jaxlib` and `build_jax` set
  # to "false", pointing it at the GCS location reported by the jaxlib build job. The
  # matrix covers both the "head" and "pypi_latest" jaxlib versions.
  name: "Bazel CUDA Non-RBE with build_jaxlib=false, (jax version = ${{ format('{0}', 'head') }})"
  # Keep running the tests even when a build job fails, so that one unrelated build
  # failure (e.g. only the Windows build breaks) does not cost us test coverage on the
  # other platforms.
  if: ${{ !cancelled() }}
  needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
  uses: ./.github/workflows/bazel_cuda.yml
  strategy:
    fail-fast: false # a failing matrix job must not cancel the remaining ones
    matrix:
      # Keep the Python values in sync with the matrix strategy of the build artifacts
      # jobs above.
      runner:
        - "linux-x86-g2-48-l4-4gpu"
      python:
        - "3.11"
      cuda-version:
        - "12"
        - "13"
      jaxlib-version:
        - "head"
        - "pypi_latest"
      enable-x64:
        - 1
        - 0
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    cuda-version: ${{ matrix.cuda-version }}
    enable-x64: ${{ matrix.enable-x64 }}
    jaxlib-version: ${{ matrix.jaxlib-version }}
    # Both artifact build jobs upload to the same GCS URI, so the jaxlib job's output
    # is sufficient here.
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
    build_jaxlib: "false"
    build_jax: "false"
    write_to_bazel_remote_cache: 1
    run_multiaccelerator_tests: "true"
run-bazel-test-cuda-py-import:
  # Runs the reusable bazel_cuda.yml workflow with `build_jaxlib` and `build_jax` set
  # to "wheel" and the jaxlib version pinned to "head".
  name: "Bazel CUDA Non-RBE with ${{ format('{0}', 'build_jaxlib=wheel') }}"
  # `!cancelled()` lets the job run unless the workflow run was cancelled.
  # NOTE(review): this job declares no `needs`, so there is no upstream build job whose
  # failure could cause it to be skipped; the condition looks carried over from the
  # artifact-based test jobs.
  if: ${{ !cancelled() }}
  uses: ./.github/workflows/bazel_cuda.yml
  strategy:
    fail-fast: false # a failing matrix job must not cancel the remaining ones
    matrix:
      # Only the x64-enabled configuration is exercised by this job.
      runner:
        - "linux-x86-g2-48-l4-4gpu"
      python:
        - "3.11"
      cuda-version:
        - "12"
        - "13"
      enable-x64:
        - 1
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    cuda-version: ${{ matrix.cuda-version }}
    enable-x64: ${{ matrix.enable-x64 }}
    build_jaxlib: "wheel"
    build_jax: "wheel"
    jaxlib-version: "head"
    write_to_bazel_remote_cache: 1
    run_multiaccelerator_tests: "true"
run-pytest-tpu:
  # Runs the reusable pytest_tpu.yml workflow on each TPU spec in the matrix, pointing
  # it at the GCS location reported by the jaxlib build job.
  name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
  # Keep running the tests even when a build job fails, so that one unrelated build
  # failure (e.g. only the Windows build breaks) does not cost us test coverage on the
  # other platforms.
  if: ${{ !cancelled() }}
  needs: [build-jax-artifact, build-jaxlib-artifact]
  uses: ./.github/workflows/pytest_tpu.yml
  strategy:
    fail-fast: false # a failing matrix job must not cancel the remaining ones
    matrix:
      python:
        - "3.11"
      # Each spec bundles the TPU type, its core count and the runner that provides it.
      tpu-specs:
        # - type: "v3-8" # Enable when we have the v3 type available
        #   cores: "4"
        - type: "v5e-8"
          cores: "8"
          runner: "linux-x86-ct5lp-224-8tpu"
        - type: "v6e-8"
          cores: "8"
          runner: "linux-x86-ct6e-180-8tpu"
      libtpu-version-type:
        - "nightly"
  with:
    runner: ${{ matrix.tpu-specs.runner }}
    cores: ${{ matrix.tpu-specs.cores }}
    tpu-type: ${{ matrix.tpu-specs.type }}
    python: ${{ matrix.python }}
    run-full-tpu-test-suite: "1"
    libtpu-version-type: ${{ matrix.libtpu-version-type }}
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
run-bazel-test-tpu:
  # Runs the reusable bazel_test_tpu.yml workflow on each TPU spec in the matrix with
  # `build_jaxlib` and `build_jax` set to "wheel".
  name: "Bazel tests TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
  # `!cancelled()` lets the job run unless the workflow run was cancelled.
  # NOTE(review): this job declares no `needs`, so there is no upstream build job whose
  # failure could cause it to be skipped.
  if: ${{ !cancelled() }}
  uses: ./.github/workflows/bazel_test_tpu.yml
  strategy:
    fail-fast: false # a failing matrix job must not cancel the remaining ones
    matrix:
      python:
        - "3.11"
      # Each spec bundles the TPU type, its core count and the runner that provides it.
      tpu-specs:
        - type: "v4-8"
          cores: "4"
          runner: "linux-x86-ct4p-240-4tpu"
        - type: "v5e-8"
          cores: "8"
          runner: "linux-x86-ct5lp-224-8tpu"
      libtpu-version-type:
        - "nightly"
  with:
    runner: ${{ matrix.tpu-specs.runner }}
    cores: ${{ matrix.tpu-specs.cores }}
    tpu-type: ${{ matrix.tpu-specs.type }}
    python: ${{ matrix.python }}
    run-full-tpu-test-suite: "1"
    libtpu-version-type: ${{ matrix.libtpu-version-type }}
    # The `needs` context only holds jobs listed under this job's `needs`. None are
    # declared here, so no build job output can be referenced and the input is passed
    # as an explicit empty string instead of a dangling `needs.*` expression.
    # NOTE(review): if this job is meant to consume the uploaded head artifacts, add
    # `needs: [build-jax-artifact, build-jaxlib-artifact]` and pass
    # `needs.build-jaxlib-artifact.outputs.gcs_upload_uri` here instead.
    gcs_download_uri: ""
    build_jaxlib: "wheel"
    build_jax: "wheel"
    clone_main_xla: 1