File tree Expand file tree Collapse file tree 2 files changed +51
-1
lines changed Expand file tree Collapse file tree 2 files changed +51
-1
lines changed Original file line number Diff line number Diff line change 1+ name : Run Pytest TPU tests
2+
3+ on :
4+ pull_request :
5+ branches :
6+ - main
7+
8+ jobs :
9+ build :
10+ strategy :
11+ matrix :
12+ runner : ["linux-x86-ct5lp-224-8tpu"]
13+ python : ["3.10"]
14+
15+ runs-on : ${{ matrix.runner }}
16+ container :
17+ image : " gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
18+
19+ env :
20+ # GitHub actions run in Docker by defaut. Disable running the `setup_docker.sh` script.
21+ JAXCI_RUN_DOCKER_CONTAINER : 0
22+ # Use RBE to build the artifacts.
23+ JAXCI_BUILD_ARTIFACT_WITH_RBE : 1
24+ # Setup the test environment (disable x64 mode and clone XLA at HEAD)
25+ JAXCI_SETUP_TEST_ENVIRONMENT : 1
26+ JAXCI_HERMETIC_PYTHON_VERSION : ${{ matrix.python }}
27+
28+ steps :
29+ - uses : actions/checkout@v3
30+ # Halt for testing
31+ - name : Wait For Connection
32+ uses : ./actions/ci_connection/
33+ - name : Build jaxlib
34+ run : ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib"
35+ - name : Install pytest
36+ env :
37+ JAXCI_PYTHON : python${{ matrix.python }}
38+ run : $JAXCI_PYTHON -m pip install pytest
39+ - name : Install Test requirements
40+ env :
41+ JAXCI_PYTHON : python${{ matrix.python }}
42+ run : |
43+ $JAXCI_PYTHON -m pip install -r build/test-requirements.txt
44+ $JAXCI_PYTHON -m pip install -r build/collect-profile-requirements.txt
45+ - name : Install Libtpu
46+ run : $JAXCI_PYTHON -m pip install --pre libtpu-nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
47+ - name : Run Pytest TPU tests
48+ run : ./ci/run_pytest.sh "ci/envs/run_tests/pytest_tpu"
Original file line number Diff line number Diff line change @@ -41,6 +41,8 @@ if [[ $JAXCI_RUN_PYTEST_TPU == 1 ]]; then
4141 echo " Running TPU tests..."
4242 # Run single-accelerator tests in parallel
4343 export JAX_ENABLE_TPU_XDIST=true
44+
45+ check_if_to_run_in_docker " $JAXCI_PYTHON " -c ' import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)'
4446 check_if_to_run_in_docker " $JAXCI_PYTHON " -m pytest -n=" $JAXCI_TPU_CORES " --tb=short \
4547 --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
4648 --maxfail=20 -m " not multiaccelerator" tests examples
@@ -51,4 +53,4 @@ if [[ $JAXCI_RUN_PYTEST_TPU == 1 ]]; then
5153
5254 # Run multi-accelerator across all chips
5355 check_if_to_run_in_docker " $JAXCI_PYTHON " -m pytest --tb=short --maxfail=20 -m " multiaccelerator" tests
54- fi
56+ fi
You can’t perform that action at this time.
0 commit comments