11name : Run Pytest TPU tests
22
33on :
4- # pull_request:
5- # branches:
6- # - main
4+ pull_request :
5+ branches :
6+ - main
77 workflow_dispatch :
88 inputs :
99 halt-for-connection :
1616 - ' no'
1717
1818jobs :
19- run_tests :
19+ run_tpu_tests :
2020 strategy :
21+ fail-fast : false
2122 matrix :
22- runner : ["linux-x86-ct5lp-224-8tpu"]
23- tpu_cores : ["8"]
23+ jaxlib-version : ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
24+ tpu : [
25+ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
26+ {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
27+ {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
28+ ]
2429 python : ["3.10"]
2530
26- runs-on : ${{ matrix.runner }}
27- container :
28- image : " gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
31+ runs-on : ${{ matrix.tpu.runner }}
32+ container : " us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
33+
34+ name : " TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
2935
3036 env :
31- JAXCI_CLONE_MAIN_XLA : 1
37+ LIBTPU_OLDEST_VERSION_DATE : 20240722
38+ ENABLE_PJRT_COMPATIBILITY : ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
3239 JAXCI_HERMETIC_PYTHON_VERSION : ${{ matrix.python }}
3340
3441 steps :
@@ -38,23 +45,33 @@ jobs:
3845 uses : google-ml-infra/actions/ci_connection@main
3946 with :
4047 halt-dispatch-input : ${{ inputs.halt-for-connection }}
41- - name : Build jaxlib
42- run : ./ci/build_artifacts.sh "jaxlib"
43- - name : Install pytest
44- env :
45- JAXCI_PYTHON : python${{ matrix.python }}
46- run : $JAXCI_PYTHON -m pip install pytest
47- - name : Install Test requirements
48- env :
49- JAXCI_PYTHON : python${{ matrix.python }}
50- run : |
51- $JAXCI_PYTHON -m pip install -r build/test-requirements.txt
52- $JAXCI_PYTHON -m pip install -r build/collect-profile-requirements.txt
53- - name : Install Libtpu
48+ - name : Install JAX test requirements
5449 env :
5550 JAXCI_PYTHON : python${{ matrix.python }}
56- run : $JAXCI_PYTHON -m pip install --pre libtpu-nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
51+ run : |
52+ pip install -U -r build/test-requirements.txt
53+ pip install -U -r build/collect-profile-requirements.txt
54+ - name : Install JAX
55+ run : |
56+ if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
57+ pip install .[tpu] \
58+ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
59+ elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
60+ pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
61+ pip install --pre libtpu \
62+ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
63+ pip install requests
64+ elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
65+ pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
66+ # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
67+ pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
68+ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
69+ pip install requests
70+ else
71+ echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
72+ exit 1
73+ fi
5774 - name : Run Pytest TPU tests
5875 env :
59- JAXCI_TPU_CORES : ${{ matrix.tpu_cores }}
76+ JAXCI_TPU_CORES : ${{ matrix.tpu.cores }}
6077 run : ./ci/run_pytest_tpu.sh
0 commit comments