Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@ on:
schedule:
- cron: "0 14 * * *" # daily at 7am PST
workflow_dispatch: # allows triggering the workflow run manually
# # TODO: remove pull request trigger
pull_request:
branches:
- main
# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
# TODO - remove concurrency for normal usage. Its here for presubmit testing
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
Expand All @@ -26,15 +34,19 @@ jobs:
matrix:
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
tpu: [
{type: "v3-8", cores: "4"},
{type: "v4-8", cores: "4"},
# {type: "v3-8", cores: "4"},
# {type: "v4-8", cores: "4"},
{type: "v5e-8", cores: "8"}
]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
env:
LIBTPU_OLDEST_VERSION_DATE: 20240228
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"]
runs-on: ["arc-linux-x86-ct5lp-224-8tpu"]
container:
# TODO repin
# We run on a bare python image to best replicate enduser usage
image: python:3.10-bookworm
timeout-minutes: 120
defaults:
run:
Expand Down Expand Up @@ -75,7 +87,6 @@ jobs:
python3 -c 'import sys; print("python version:", sys.version)'
python3 -c 'import jax; print("jax version:", jax.__version__)'
python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
strings $HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so | grep 'Built on'
python3 -c 'import jax; print("libtpu version:",
jax.lib.xla_bridge.get_backend().platform_version)'
- name: Run tests
Expand All @@ -92,12 +103,13 @@ jobs:
tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
# Run multi-accelerator across all chips
python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
- name: Send chat on failure
# Don't notify when testing the workflow from a branch.
if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
run: |
curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \
--header 'Content-Type: application/json' \
--data-raw "{
'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu.type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
}"
# TODO: reenable
# - name: Send chat on failure
# # Don't notify when testing the workflow from a branch.
# if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
# run: |
# curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \
# --header 'Content-Type: application/json' \
# --data-raw "{
# 'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu.type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
# }"
74 changes: 74 additions & 0 deletions .github/workflows/cloud-tpu-presubmit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Cloud TPU CI
name: Cloud TPU Presubmit
# Run on pull_request that is labeled as "optional_ci_tpu" or workflow dispatch
on:
pull_request:
branches:
- main
types: [labeled, synchronize]
workflow_dispatch:
# Cancel any previous iterations if a new commit is pushed
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
cloud-tpu-test:
# TODO: confirm final naming for optional label
if: contains(github.event.pull_request.labels.*.name, 'optional_ci_tpu')
name: "TPU v5e x 8 Presubmit"
env:
ENABLE_PJRT_COMPATIBILITY: 1
# TODO: Needs final runs-on value
runs-on: arc-linux-x86-ct5lp-224-8tpu
container:
# TODO: Needs newer, light weight image
image: index.docker.io/tensorflow/build@sha256:7fb38f0319bda36393cad7f40670aa22352b44421bb906f5cf34d543acd8e1d2 # ratchet:tensorflow/build:latest-python3.11
timeout-minutes: 120
defaults:
run:
shell: bash -ex {0}
steps:
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Install JAX test requirements
run: |
pip install -U -r build/test-requirements.txt
# TODO: build jax should be done on a step prior or we should just bazel test
- name: Build JAX
run: |
pip uninstall -y jaxlib
python3 build/build.py --use_clang
pip install -e .
ls -la dist/*.whl
pip install dist/*.whl
# Note the version it installs! Should be today's date
pip install -U --no-index --pre libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python3 -c 'import sys; print("python version:", sys.version)'
python3 -c 'import jax; print("jax version:", jax.__version__)'
python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
python3 -c 'import jax; print("libtpu version:",
jax.lib.xla_bridge.get_backend().platform_version)'
- name: Run tests
env:
JAX_PLATFORMS: tpu,cpu
PY_COLORS: 1
NUM_TESTS: 8
JAX_NUM_GENERATED_CASES: 25
run: |
# Run single-accelerator tests in parallel
mkdir results
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=$NUM_TESTS --tb=short \
--junitxml=results/singlejunit.xml --maxfail=20 -m "not multiaccelerator" tests examples
# Run multi-accelerator across all chips
python3 -m pytest --tb=short --junitxml=results/multijunit.xml \
--maxfail=20 -m "multiaccelerator" tests
# - name: 'Upload Artifact'
# if: success() || failure()
# uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # ratchet:actions/upload-artifact@v4
# with:
# name: junit
# path: |
# results/singlejunit.xml
# results/multijunit.xml
# retention-days: 1