1313name : CI - Cloud TPU (nightly)
1414on :
1515 schedule :
16- - cron : " 0 14 * * *" # daily at 7am PST
16+ - cron : " 0 */2 * * *" # Run every 2 hours
1717 workflow_dispatch : # allows triggering the workflow run manually
1818# This should also be set to read-only in the project settings, but it's nice to
1919# document and enforce the permissions here.
@@ -24,17 +24,20 @@ jobs:
2424 strategy :
2525 fail-fast : false # don't cancel all jobs on failure
2626 matrix :
27- jaxlib-version : ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
27+ jaxlib-version : ["head", " pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
2828 tpu : [
29- {type: "v3-8", cores: "4"},
30- {type: "v4-8", cores: "4"},
31- {type: "v5e-8", cores: "8"}
29+ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
30+ {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu" },
31+ {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu" }
3232 ]
33+ python-version : ["3.10"]
3334 name : " TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
3435 env :
3536 LIBTPU_OLDEST_VERSION_DATE : 20240722
3637 ENABLE_PJRT_COMPATIBILITY : ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
37- runs-on : ["self-hosted", "tpu", "${{ matrix.tpu.type }}"]
38+ PYTHON : python${{ matrix.python-version }}
39+ runs-on : ${{ matrix.tpu.runner }}
40+ container : " us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
3841 timeout-minutes : 120
3942 defaults :
4043 run :
@@ -44,54 +47,74 @@ jobs:
4447 # mandates using a specific commit for non-Google actions. We use
4548 # https://github.com/sethvargo/ratchet to pin specific versions.
4649 - uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
50+ # Checkout XLA at head, if we're building jaxlib at head.
51+ - name : Checkout XLA at head
52+ uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
53+ if : ${{ matrix.jaxlib-version == 'head' }}
54+ with :
55+ repository : openxla/xla
56+ path : xla
4757 - name : Install JAX test requirements
4858 run : |
49- pip install -U -r build/test-requirements.txt
50- pip install -U -r build/collect-profile-requirements.txt
59+ $PYTHON -m pip install -U -r build/test-requirements.txt
60+ $PYTHON -m pip install -U -r build/collect-profile-requirements.txt
5161 - name : Install JAX
5262 run : |
53- pip uninstall -y jax jaxlib libtpu
54- if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
55- pip install .[tpu] \
63+ $PYTHON -m pip uninstall -y jax jaxlib libtpu
64+ if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
65+ # Build and install jaxlib at head
66+ $PYTHON build/build.py --bazel_options=--config=rbe_linux_x86_64 \
67+ --bazel_options="--override_repository=xla=$(pwd)/xla" \
68+ --bazel_options=--color=yes
69+ $PYTHON -m pip install dist/*.whl
70+
71+ # Install "jax" at head
72+ $PYTHON -m pip install -U -e .
73+
74+ # Install libtpu
75+ $PYTHON -m pip install --pre libtpu \
76+ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
77+ elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
78+ $PYTHON -m pip install .[tpu] \
5679 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
5780
5881 elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
59- pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
60- pip install --pre libtpu \
82+ $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
83+ $PYTHON -m pip install --pre libtpu \
6184 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
62- pip install requests
85+ $PYTHON -m pip install requests
6386
6487 elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
65- pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
88+ $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
6689 # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
67- pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
90+ $PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
6891 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
69- pip install requests
92+ $PYTHON -m pip install requests
7093 else
7194 echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
7295 exit 1
7396 fi
7497
75- python3 -c 'import sys; print("python version:", sys.version)'
76- python3 -c 'import jax; print("jax version:", jax.__version__)'
77- python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
78- strings $HOME/. local/lib/python3.10/site -packages/libtpu/libtpu.so | grep 'Built on'
79- python3 -c 'import jax; print("libtpu version:",
98+ $PYTHON -c 'import sys; print("python version:", sys.version)'
99+ $PYTHON -c 'import jax; print("jax version:", jax.__version__)'
100+ $PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
101+ strings /usr/ local/lib/"$PYTHON"/dist -packages/libtpu/libtpu.so | grep 'Built on'
102+ $PYTHON -c 'import jax; print("libtpu version:",
80103 jax.lib.xla_bridge.get_backend().platform_version)'
81104 - name : Run tests
82105 env :
83106 JAX_PLATFORMS : tpu,cpu
84107 PY_COLORS : 1
85108 run : |
86109 # Run single-accelerator tests in parallel
87- JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
110+ JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
88111 --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
89112 --maxfail=20 -m "not multiaccelerator" tests examples
90113 # Run Pallas printing tests, which need to run with I/O capturing disabled.
91- TPU_STDERR_LOG_LEVEL=0 python3 -m pytest -s \
114+ TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \
92115 tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
93116 # Run multi-accelerator across all chips
94- python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
117+ $PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
95118 - name : Send chat on failure
96119 # Don't notify when testing the workflow from a branch.
97120 if : ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
0 commit comments