Skip to content

Commit aa01d86

Browse files
committed
rework tpu job and scripts to match upstream
1 parent d14a236 commit aa01d86

File tree

2 files changed

+50
-40
lines changed

2 files changed

+50
-40
lines changed

.github/workflows/pytest_tpu.yml

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
name: Run Pytest TPU tests
22

33
on:
4-
# pull_request:
5-
# branches:
6-
# - main
4+
pull_request:
5+
branches:
6+
- main
77
workflow_dispatch:
88
inputs:
99
halt-for-connection:
@@ -16,19 +16,26 @@ on:
1616
- 'no'
1717

1818
jobs:
19-
run_tests:
19+
run_tpu_tests:
2020
strategy:
21+
fail-fast: false
2122
matrix:
22-
runner: ["linux-x86-ct5lp-224-8tpu"]
23-
tpu_cores: ["8"]
23+
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
24+
tpu: [
25+
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
26+
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
27+
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
28+
]
2429
python: ["3.10"]
2530

26-
runs-on: ${{ matrix.runner }}
27-
container:
28-
image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
31+
runs-on: ${{ matrix.tpu.runner }}
32+
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
33+
34+
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
2935

3036
env:
31-
JAXCI_CLONE_MAIN_XLA: 1
37+
LIBTPU_OLDEST_VERSION_DATE: 20240722
38+
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
3239
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
3340

3441
steps:
@@ -38,23 +45,33 @@ jobs:
3845
uses: google-ml-infra/actions/ci_connection@main
3946
with:
4047
halt-dispatch-input: ${{ inputs.halt-for-connection }}
41-
- name: Build jaxlib
42-
run: ./ci/build_artifacts.sh "jaxlib"
43-
- name: Install pytest
44-
env:
45-
JAXCI_PYTHON: python${{ matrix.python }}
46-
run: $JAXCI_PYTHON -m pip install pytest
47-
- name: Install Test requirements
48-
env:
49-
JAXCI_PYTHON: python${{ matrix.python }}
50-
run: |
51-
$JAXCI_PYTHON -m pip install -r build/test-requirements.txt
52-
$JAXCI_PYTHON -m pip install -r build/collect-profile-requirements.txt
53-
- name: Install Libtpu
48+
- name: Install JAX test requirements
5449
env:
5550
JAXCI_PYTHON: python${{ matrix.python }}
56-
run: $JAXCI_PYTHON -m pip install --pre libtpu-nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
51+
run: |
52+
pip install -U -r build/test-requirements.txt
53+
pip install -U -r build/collect-profile-requirements.txt
54+
- name: Install JAX
55+
run: |
56+
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
57+
pip install .[tpu] \
58+
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
59+
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
60+
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
61+
pip install --pre libtpu \
62+
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
63+
pip install requests
64+
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
65+
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
66+
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
67+
pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
68+
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
69+
pip install requests
70+
else
71+
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
72+
exit 1
73+
fi
5774
- name: Run Pytest TPU tests
5875
env:
59-
JAXCI_TPU_CORES: ${{ matrix.tpu_cores }}
76+
JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}
6077
run: ./ci/run_pytest_tpu.sh

ci/run_pytest_tpu.sh

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,37 +26,30 @@ set -exu -o history -o allexport
2626
# Inherit default JAXCI environment variables.
2727
source ci/envs/default.env
2828

29-
# Install jaxlib wheel on the system. Requires a jaxlib wheel to be present
30-
# inside $JAXCI_OUTPUT_DIR (../dist)
31-
echo "Installing wheels locally..."
32-
source ./ci/utilities/install_wheels_locally.sh
33-
3429
# Set up the build environment.
3530
source "ci/utilities/setup_build_environment.sh"
3631

37-
export PY_COLORS=1
38-
export JAX_SKIP_SLOW_TESTS=true
39-
4032
"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
41-
4233
"$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)'
4334
"$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)'
4435
"$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
4536
strings /usr/local/lib/"$JAXCI_PYTHON"/site-packages/libtpu/libtpu.so | grep 'Built on'
4637
"$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)'
4738

48-
echo "Running TPU tests..."
39+
# Set up common test environment variables
40+
export PY_COLORS=1
41+
export JAX_SKIP_SLOW_TESTS=true
4942
export JAX_PLATFORMS=tpu,cpu
50-
# Run single-accelerator tests in parallel
51-
export JAX_ENABLE_TPU_XDIST=true
43+
# End of common test environment variable setup
5244

53-
"$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
45+
echo "Running TPU tests..."
46+
# Run single-accelerator tests in parallel
47+
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
5448
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
5549
--maxfail=20 -m "not multiaccelerator" tests examples
5650

5751
# Run Pallas printing tests, which need to run with I/O capturing disabled.
58-
export TPU_STDERR_LOG_LEVEL=0
59-
"$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
52+
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
6053

6154
# Run multi-accelerator across all chips
6255
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests

0 commit comments

Comments
 (0)