Skip to content

Commit c65c2b2

Browse files
committed
modify scripts to run tpu tests
1 parent 460cb11 commit c65c2b2

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

.github/workflows/pytest_tpu.yml

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
name: Run Pytest TPU tests
2+
3+
on:
4+
pull_request:
5+
branches:
6+
- main
7+
8+
jobs:
9+
build:
10+
strategy:
11+
matrix:
12+
runner: ["linux-x86-ct5lp-224-8tpu"]
13+
python: ["3.10"]
14+
15+
runs-on: ${{ matrix.runner }}
16+
container:
17+
image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
18+
19+
env:
20+
# GitHub actions run in Docker by defaut. Disable running the `setup_docker.sh` script.
21+
JAXCI_RUN_DOCKER_CONTAINER: 0
22+
# Use RBE to build the artifacts.
23+
JAXCI_BUILD_ARTIFACT_WITH_RBE: 1
24+
# Setup the test environment (disable x64 mode and clone XLA at HEAD)
25+
JAXCI_SETUP_TEST_ENVIRONMENT: 1
26+
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
27+
28+
steps:
29+
- uses: actions/checkout@v3
30+
# Halt for testing
31+
- name: Wait For Connection
32+
uses: ./actions/ci_connection/
33+
- name: Build jaxlib
34+
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib"
35+
- name: Install pytest
36+
env:
37+
JAXCI_PYTHON: python${{ matrix.python }}
38+
run: $JAXCI_PYTHON -m pip install pytest
39+
- name: Install Test requirements
40+
env:
41+
JAXCI_PYTHON: python${{ matrix.python }}
42+
run: |
43+
$JAXCI_PYTHON -m pip install -r build/test-requirements.txt
44+
$JAXCI_PYTHON -m pip install -r build/collect-profile-requirements.txt
45+
- name: Install Libtpu
46+
run: $JAXCI_PYTHON -m pip install --pre libtpu-nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
47+
- name: Run Pytest TPU tests
48+
run: ./ci/run_pytest.sh "ci/envs/run_tests/pytest_tpu"

ci/run_pytest.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ if [[ $JAXCI_RUN_PYTEST_TPU == 1 ]]; then
4141
echo "Running TPU tests..."
4242
# Run single-accelerator tests in parallel
4343
export JAX_ENABLE_TPU_XDIST=true
44+
45+
check_if_to_run_in_docker "$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)'
4446
check_if_to_run_in_docker "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
4547
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
4648
--maxfail=20 -m "not multiaccelerator" tests examples
@@ -51,4 +53,4 @@ if [[ $JAXCI_RUN_PYTEST_TPU == 1 ]]; then
5153

5254
# Run multi-accelerator across all chips
5355
check_if_to_run_in_docker "$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
54-
fi
56+
fi

0 commit comments

Comments
 (0)