Skip to content

Commit edc1a88

Browse files
Allow run_tests to run a subset of the tests. (#9190)
Co-authored-by: Zhanyong Wan <[email protected]>
1 parent f39434a commit edc1a88

File tree

1 file changed

+101
-14
lines changed

1 file changed

+101
-14
lines changed

test/run_tests.sh

Lines changed: 101 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
11
#!/bin/bash
22
set -exo pipefail
3+
4+
# Absolute path to the directory of this script.
35
CDIR="$(cd "$(dirname "$0")" ; pwd -P)"
6+
7+
# Import utilities.
8+
source "${CDIR}/utils/run_tests_utils.sh"
9+
10+
# Default option values. Can be overridden via commandline flags.
411
LOGFILE=/tmp/pytorch_py_test.log
512
MAX_GRAPH_SIZE=500
613
GRAPH_CHECK_FREQUENCY=100
714
VERBOSITY=2
815

9-
# Utils file
10-
source "${CDIR}/utils/run_tests_utils.sh"
11-
12-
# Note [Keep Going]
13-
#
14-
# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CI tests continue on error.
15-
# This will allow you to see all the failures on your PR, not stopping with the first
16-
# test failure like the default behavior.
17-
CONTINUE_ON_ERROR="${CONTINUE_ON_ERROR:-0}"
18-
if [[ "$CONTINUE_ON_ERROR" == "1" ]]; then
19-
set +e
20-
fi
21-
22-
while getopts 'LM:C:V:' OPTION
16+
# Parse commandline flags:
17+
# -L
18+
# disable writing to the log file at $LOGFILE.
19+
# -M max_graph_size
20+
# -C graph_check_frequency
21+
# -V verbosity
22+
# -h
23+
# print the help string
24+
while getopts 'LM:C:V:h' OPTION
2325
do
2426
case $OPTION in
2527
L)
@@ -34,10 +36,25 @@ do
3436
V)
3537
VERBOSITY=$OPTARG
3638
;;
39+
h)
40+
echo -e "Usage: $0 TEST_FILTER...\nwhere TEST_FILTERs are globs match .py test files. If no test filter is provided, runs all tests."
41+
exit 0
42+
;;
43+
\?) # This catches all invalid options.
44+
echo "ERROR: Invalid commandline flag."
45+
exit 1
3746
esac
3847
done
3948
shift $(($OPTIND - 1))
4049

50+
# Set the `CONTINUE_ON_ERROR` flag to `1` to make the CI tests continue on error.
51+
# This will allow you to see all the failures on your PR, not stopping with the first
52+
# test failure like the default behavior.
53+
CONTINUE_ON_ERROR="${CONTINUE_ON_ERROR:-0}"
54+
if [[ "$CONTINUE_ON_ERROR" == "1" ]]; then
55+
set +e
56+
fi
57+
4158
export TRIM_GRAPH_SIZE=$MAX_GRAPH_SIZE
4259
export TRIM_GRAPH_CHECK_FREQUENCY=$GRAPH_CHECK_FREQUENCY
4360
export TORCH_TEST_DEVICES="$CDIR/pytorch_test_base.py"
@@ -48,7 +65,36 @@ export CPU_NUM_DEVICES=4
4865
TORCH_XLA_DIR=$(cd ~; dirname "$(python -c 'import torch_xla; print(torch_xla.__file__)')")
4966
COVERAGE_FILE="$CDIR/../.coverage"
5067

68+
# Given $1 as a (possibly not normalized) test filepath, returns successfully
69+
# if it matches any of the space-separated globs $_TEST_FILTER. If
70+
# $_TEST_FILTER is empty, returns successfully.
71+
function test_is_selected {
72+
if [[ -z "$_TEST_FILTER" ]]; then
73+
return 0 # success
74+
fi
75+
76+
# _TEST_FILTER is a space-separate list of globs. Loop through the
77+
# list elements.
78+
for _FILTER in $_TEST_FILTER; do
79+
# realpath normalizes the paths (e.g. resolving `..` and relative paths)
80+
# so that they can be compared.
81+
case `realpath $1` in
82+
`realpath $_FILTER`)
83+
return 0 # success
84+
;;
85+
*)
86+
# No match
87+
;;
88+
esac
89+
done
90+
91+
return 1 # failure
92+
}
93+
5194
function run_coverage {
95+
if ! test_is_selected "$1"; then
96+
return
97+
fi
5298
if [ "${USE_COVERAGE:-0}" != "0" ]; then
5399
coverage run --source="$TORCH_XLA_DIR" -p "$@"
54100
else
@@ -57,6 +103,9 @@ function run_coverage {
57103
}
58104

59105
function run_test {
106+
if ! test_is_selected "$1"; then
107+
return
108+
fi
60109
echo "Running in PjRt runtime: $@"
61110
if [ -x "$(command -v nvidia-smi)" ] && [ "$XLA_CUDA" != "0" ]; then
62111
PJRT_DEVICE=CUDA run_coverage "$@"
@@ -67,6 +116,9 @@ function run_test {
67116
}
68117

69118
function run_device_detection_test {
119+
if ! test_is_selected "$1"; then
120+
return
121+
fi
70122
echo "Running in PjRt runtime: $@"
71123
current_device=$PJRT_DEVICE
72124
current_num_gpu_devices=$GPU_NUM_DEVICES
@@ -81,36 +133,57 @@ function run_device_detection_test {
81133
}
82134

83135
function run_test_without_functionalization {
136+
if ! test_is_selected "$1"; then
137+
return
138+
fi
84139
echo "Running with XLA_DISABLE_FUNCTIONALIZATION: $@"
85140
XLA_DISABLE_FUNCTIONALIZATION=1 run_test "$@"
86141
}
87142

88143
function run_use_bf16 {
144+
if ! test_is_selected "$1"; then
145+
return
146+
fi
89147
echo "Running with XLA_USE_BF16: $@"
90148
XLA_USE_BF16=1 run_test "$@"
91149
}
92150

93151
function run_downcast_bf16 {
152+
if ! test_is_selected "$1"; then
153+
return
154+
fi
94155
echo "Running with XLA_DOWNCAST_BF16: $@"
95156
XLA_DOWNCAST_BF16=1 run_test "$@"
96157
}
97158

98159
function run_dynamic {
160+
if ! test_is_selected "$1"; then
161+
return
162+
fi
99163
echo "Running in DynamicShape mode: $@"
100164
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter:nms" run_test "$@"
101165
}
102166

103167
function run_eager_debug {
168+
if ! test_is_selected "$1"; then
169+
return
170+
fi
104171
echo "Running in Eager Debug mode: $@"
105172
XLA_USE_EAGER_DEBUG_MODE=1 run_test "$@"
106173
}
107174

108175
function run_pt_xla_debug {
176+
if ! test_is_selected "$1"; then
177+
return
178+
fi
109179
echo "Running in save tensor file mode: $@"
110180
PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
111181
}
112182

113183
function run_pt_xla_debug_level1 {
184+
if ! test_is_selected "$1"; then
185+
return
186+
fi
114187
echo "Running in save tensor file mode: $@"
115188
PT_XLA_DEBUG_LEVEL=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
116189
}
@@ -121,6 +194,9 @@ function run_pt_xla_debug_level2 {
121194
}
122195

123196
function run_torchrun {
197+
if ! test_is_selected "$1"; then
198+
return
199+
fi
124200
if [ -x "$(command -v nvidia-smi)" ] && [ "$XLA_CUDA" != "0" ]; then
125201
echo "Running torchrun test for GPU $@"
126202
num_devices=$(nvidia-smi --list-gpus | wc -l)
@@ -412,6 +488,17 @@ function run_tests {
412488
fi
413489
}
414490

491+
if [[ $# -ge 1 ]]; then
492+
# There are positional arguments - set $_TEST_FILTER to them.
493+
_TEST_FILTER=$@
494+
# Sometimes a test may fail even if it doesn't match _TEST_FILTER. Therefore,
495+
# we need to set this to be able to get to the test(s) we want to run.
496+
CONTINUE_ON_ERROR=1
497+
else
498+
# No positional argument - run all tests.
499+
_TEST_FILTER=""
500+
fi
501+
415502
if [ "$LOGFILE" != "" ]; then
416503
run_tests 2>&1 | tee $LOGFILE
417504
else

0 commit comments

Comments
 (0)