@@ -112,11 +112,13 @@ for t in $*; do
112112 BAZEL_TARGET=" ${BAZEL_TARGET} $t "
113113done
114114
115+ TEST_TAG_FILTER_ARRAY=()
116+ TEST_TAG_FILTER_ARRAY+=(' -multiaccelerator' )
117+
115118COMMON_FLAGS=$( cat << EOF
116119--@local_config_cuda//:enable_cuda
117120--cache_test_results=${CACHE_TEST_RESULTS}
118121--test_timeout=600
119- --test_tag_filters=-multiaccelerator
120122--test_env=JAX_SKIP_SLOW_TESTS=1
121123--test_env=JAX_ACCELERATOR_COUNT=${NGPUS}
122124--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
@@ -141,7 +143,8 @@ case "${BATTERY}" in
141143 # collect from all tests subdirectories recursively,
142144 # use jax_test_gpu tag generated by jax_multiplatform_test rule:
143145 # https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265
144- BAZEL_TARGET=" ${BAZEL_TARGET} //tests/... --test_tag_filters=jax_test_gpu"
146+ TEST_TAG_FILTER_ARRAY+=(' jax_test_gpu' )
147+ BAZEL_TARGET=" ${BAZEL_TARGET} //tests/...
145148 ;;
146149 backend-independent)
147150 JOBS_PER_GPU=4
@@ -160,6 +163,8 @@ case "${BATTERY}" in
160163 ;;
161164esac
162165
166+ TEST_TAG_FILTERS=$( IFS=, echo " --test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]} " )
167+
163168print_var NCPUS
164169print_var NGPUS
165170print_var BATTERY
@@ -168,6 +173,7 @@ print_var JOBS_PER_GPU
168173print_var JOBS
169174print_var BUILD_JAXLIB
170175print_var BAZEL_TARGET
176+ print_var TEST_TAG_FILTERS
171177print_var COMMON_FLAGS
172178print_var EXTRA_FLAGS
173179
@@ -185,4 +191,4 @@ pip install matplotlib
185191
186192cd ` jax_source_dir`
187193python build/build.py --configure_only
188- bazel test ${BAZEL_TARGET} ${COMMON_FLAGS} ${EXTRA_FLAGS}
194+ bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${ COMMON_FLAGS} ${EXTRA_FLAGS}
0 commit comments