Skip to content

Commit 68a8353

Browse files
committed
Update GPU/TPU test scripts
1 parent b04da6e commit 68a8353

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

.github/workflows/pytest_tpu.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ on:
66
- main
77

88
jobs:
9-
build:
9+
run_tests:
1010
strategy:
1111
matrix:
1212
runner: ["linux-x86-ct5lp-224-8tpu"]
@@ -43,4 +43,7 @@ jobs:
4343
JAXCI_PYTHON: python${{ matrix.python }}
4444
run: $JAXCI_PYTHON -m pip install --pre libtpu-nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
4545
- name: Run Pytest TPU tests
46+
env:
47+
JAX_PLATFORMS: tpu,cpu
48+
PY_COLORS: 1
4649
run: ./ci/run_pytest.sh "ci/envs/run_tests/pytest_tpu.env"

ci/run_bazel_test.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,13 @@ if [[ $JAXCI_RUN_BAZEL_TEST_GPU_RBE == 1 ]]; then
6868
//tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
6969
fi
7070

71-
# Run Bazel GPU tests locally.
71+
# Run Non-RBE Bazel GPU tests (single-accelerator and multiaccelerator tests).
7272
if [[ $JAXCI_RUN_BAZEL_TEST_GPU_NON_RBE == 1 ]]; then
7373
export NCCL_DEBUG=WARN
7474
nvidia-smi
75-
echo "Running local GPU tests..."
75+
echo "Running (non-RBE) GPU tests..."
7676

77-
# Runs non-multiaccelerator tests with one GPU apiece.
77+
# Runs single-accelerator tests with one GPU apiece.
7878
# It appears --run_under needs an absolute path.
7979
# The product of the `JAX_ACCELERATOR_COUNT` and `JAX_TESTS_PER_ACCELERATOR`
8080
# should match the VM's CPU core count (set in `--local_test_jobs`).

ci/run_pytest.sh

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ if [[ $JAXCI_RUN_PYTEST_CPU == 1 ]]; then
2929
fi
3030

3131
if [[ $JAXCI_RUN_PYTEST_GPU == 1 ]]; then
32+
nvidia-smi
33+
export NCCL_DEBUG=WARN
34+
export TF_CPP_MIN_LOG_LEVEL=0
35+
3236
echo "Running GPU tests..."
3337
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
3438
export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
@@ -41,11 +45,17 @@ if [[ $JAXCI_RUN_PYTEST_GPU == 1 ]]; then
4145
fi
4246

4347
if [[ $JAXCI_RUN_PYTEST_TPU == 1 ]]; then
48+
49+
"$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)'
50+
"$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)'
51+
"$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
52+
strings $HOME/.local/lib/"$JAXCI_PYTHON"/site-packages/libtpu/libtpu.so | grep 'Built on'
53+
"$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)'
54+
4455
echo "Running TPU tests..."
4556
# Run single-accelerator tests in parallel
4657
export JAX_ENABLE_TPU_XDIST=true
4758

48-
"$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)'
4959
"$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
5060
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
5161
--maxfail=20 -m "not multiaccelerator" tests examples

0 commit comments

Comments
 (0)