Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,13 @@ for t in $*; do
BAZEL_TARGET="${BAZEL_TARGET} $t"
done

TEST_TAG_FILTER_ARRAY=()
TEST_TAG_FILTER_ARRAY+=('-multiaccelerator')

COMMON_FLAGS=$(cat << EOF
--@local_config_cuda//:enable_cuda
--cache_test_results=${CACHE_TEST_RESULTS}
--test_timeout=600
--test_tag_filters=-multiaccelerator
--test_env=JAX_SKIP_SLOW_TESTS=1
--test_env=JAX_ACCELERATOR_COUNT=${NGPUS}
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
Expand All @@ -138,7 +140,11 @@ case "${BATTERY}" in
JOBS_PER_GPU=8
JOBS=$((NGPUS * JOBS_PER_GPU))
EXTRA_FLAGS="--local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
BAZEL_TARGET="${BAZEL_TARGET} //tests:gpu_tests"
# collect from all tests subdirectories recursively,
# use jax_test_gpu tag generated by jax_multiplatform_test rule:
# https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265
TEST_TAG_FILTER_ARRAY+=('jax_test_gpu')
BAZEL_TARGET="${BAZEL_TARGET} //tests/..."
;;
backend-independent)
JOBS_PER_GPU=4
Expand All @@ -157,6 +163,8 @@ case "${BATTERY}" in
;;
esac

TEST_TAG_FILTERS=$(IFS=, ; echo "--test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]}")

print_var NCPUS
print_var NGPUS
print_var BATTERY
Expand All @@ -165,6 +173,7 @@ print_var JOBS_PER_GPU
print_var JOBS
print_var BUILD_JAXLIB
print_var BAZEL_TARGET
print_var TEST_TAG_FILTERS
print_var COMMON_FLAGS
print_var EXTRA_FLAGS

Expand All @@ -182,4 +191,4 @@ pip install matplotlib

cd `jax_source_dir`
python build/build.py --configure_only
bazel test ${BAZEL_TARGET} ${COMMON_FLAGS} ${EXTRA_FLAGS}
bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${COMMON_FLAGS} ${EXTRA_FLAGS}
Loading