File tree Expand file tree Collapse file tree 3 files changed +18
-5
lines changed Expand file tree Collapse file tree 3 files changed +18
-5
lines changed Original file line number Diff line number Diff line change 66 - main
77
88jobs :
9- build :
9+ run_tests :
1010 strategy :
1111 matrix :
1212 runner : ["linux-x86-ct5lp-224-8tpu"]
4343 JAXCI_PYTHON : python${{ matrix.python }}
4444 run : $JAXCI_PYTHON -m pip install --pre libtpu-nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
4545 - name : Run Pytest TPU tests
46+ env :
47+ JAX_PLATFORMS : tpu,cpu
48+ PY_COLORS : 1
4649 run : ./ci/run_pytest.sh "ci/envs/run_tests/pytest_tpu.env"
Original file line number Diff line number Diff line change @@ -68,13 +68,13 @@ if [[ $JAXCI_RUN_BAZEL_TEST_GPU_RBE == 1 ]]; then
6868 //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
6969fi
7070
71- # Run Bazel GPU tests locally .
71+ # Run Non-RBE Bazel GPU tests (single accelerator and multiaccelerator tetsts) .
7272if [[ $JAXCI_RUN_BAZEL_TEST_GPU_NON_RBE == 1 ]]; then
7373 export NCCL_DEBUG=WARN
7474 nvidia-smi
75- echo " Running local GPU tests..."
75+ echo " Running (non-RBE) GPU tests..."
7676
77- # Runs non-multiaccelerator tests with one GPU apiece.
77+ # Runs single accelerator tests with one GPU apiece.
7878 # It appears --run_under needs an absolute path.
7979 # The product of the `JAX_ACCELERATOR_COUNT`` and `JAX_TESTS_PER_ACCELERATOR`
8080 # should match the VM's CPU core count (set in `--local_test_jobs`).
Original file line number Diff line number Diff line change @@ -29,6 +29,10 @@ if [[ $JAXCI_RUN_PYTEST_CPU == 1 ]]; then
2929fi
3030
3131if [[ $JAXCI_RUN_PYTEST_GPU == 1 ]]; then
32+ nvidia-smi
33+ export NCCL_DEBUG=WARN
34+ export TF_CPP_MIN_LOG_LEVEL=0
35+
3236 echo " Running GPU tests..."
3337 export XLA_PYTHON_CLIENT_ALLOCATOR=platform
3438 export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
@@ -41,11 +45,17 @@ if [[ $JAXCI_RUN_PYTEST_GPU == 1 ]]; then
4145fi
4246
4347if [[ $JAXCI_RUN_PYTEST_TPU == 1 ]]; then
48+
49+ " $JAXCI_PYTHON " -c ' import sys; print("python version:", sys.version)'
50+ " $JAXCI_PYTHON " -c ' import jax; print("jax version:", jax.__version__)'
51+ " $JAXCI_PYTHON " -c ' import jaxlib; print("jaxlib version:", jaxlib.__version__)'
52+ strings $HOME /.local/lib/" $JAXCI_PYTHON " /site-packages/libtpu/libtpu.so | grep ' Built on'
53+ " $JAXCI_PYTHON " -c ' import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)'
54+
4455 echo " Running TPU tests..."
4556 # Run single-accelerator tests in parallel
4657 export JAX_ENABLE_TPU_XDIST=true
4758
48- " $JAXCI_PYTHON " -c ' import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)'
4959 " $JAXCI_PYTHON " -m pytest -n=" $JAXCI_TPU_CORES " --tb=short \
5060 --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
5161 --maxfail=20 -m " not multiaccelerator" tests examples
You can’t perform that action at this time.
0 commit comments