@@ -112,11 +112,13 @@ for t in $*; do
112112 BAZEL_TARGET=" ${BAZEL_TARGET} $t "
113113done
114114
115+ TEST_TAG_FILTER_ARRAY=()
116+ TEST_TAG_FILTER_ARRAY+=(' -multiaccelerator' )
117+
115118COMMON_FLAGS=$( cat << EOF
116119--@local_config_cuda//:enable_cuda
117120--cache_test_results=${CACHE_TEST_RESULTS}
118121--test_timeout=600
119- --test_tag_filters=-multiaccelerator
120122--test_env=JAX_SKIP_SLOW_TESTS=1
121123--test_env=JAX_ACCELERATOR_COUNT=${NGPUS}
122124--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
@@ -138,7 +140,11 @@ case "${BATTERY}" in
138140 JOBS_PER_GPU=8
139141 JOBS=$(( NGPUS * JOBS_PER_GPU))
140142 EXTRA_FLAGS=" --local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
141- BAZEL_TARGET=" ${BAZEL_TARGET} //tests:gpu_tests"
143+ # collect from all tests subdirectories recursively,
144+ # use jax_test_gpu tag generated by jax_multiplatform_test rule:
145+ # https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265
146+ TEST_TAG_FILTER_ARRAY+=(' jax_test_gpu' )
147+ BAZEL_TARGET=" ${BAZEL_TARGET} //tests/..."
142148 ;;
143149 backend-independent)
144150 JOBS_PER_GPU=4
@@ -157,6 +163,8 @@ case "${BATTERY}" in
157163 ;;
158164esac
159165
166+ TEST_TAG_FILTERS=$( IFS=, ; echo " --test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]} " )
167+
160168print_var NCPUS
161169print_var NGPUS
162170print_var BATTERY
@@ -165,6 +173,7 @@ print_var JOBS_PER_GPU
165173print_var JOBS
166174print_var BUILD_JAXLIB
167175print_var BAZEL_TARGET
176+ print_var TEST_TAG_FILTERS
168177print_var COMMON_FLAGS
169178print_var EXTRA_FLAGS
170179
@@ -182,4 +191,4 @@ pip install matplotlib
182191
183192cd ` jax_source_dir`
184193python build/build.py --configure_only
185- bazel test ${BAZEL_TARGET} ${COMMON_FLAGS} ${EXTRA_FLAGS}
194+ bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${ COMMON_FLAGS} ${EXTRA_FLAGS}
0 commit comments