1
1
#! /bin/bash
2
2
set -exo pipefail
3
+
4
+ # Absolute path to the directory of this script.
3
5
CDIR=" $( cd " $( dirname " $0 " ) " ; pwd -P) "
6
+
7
+ # Import utilities.
8
+ source " ${CDIR} /utils/run_tests_utils.sh"
9
+
10
+ # Default option values. Can be overridden via commandline flags.
4
11
LOGFILE=/tmp/pytorch_py_test.log
5
12
MAX_GRAPH_SIZE=500
6
13
GRAPH_CHECK_FREQUENCY=100
7
14
VERBOSITY=2
8
15
9
- # Utils file
10
- source " ${CDIR} /utils/run_tests_utils.sh"
11
-
12
- # Note [Keep Going]
13
- #
14
- # Set the `CONTINUE_ON_ERROR` flag to `true` to make the CI tests continue on error.
15
- # This will allow you to see all the failures on your PR, not stopping with the first
16
- # test failure like the default behavior.
17
- CONTINUE_ON_ERROR=" ${CONTINUE_ON_ERROR:- 0} "
18
- if [[ " $CONTINUE_ON_ERROR " == " 1" ]]; then
19
- set +e
20
- fi
21
-
22
- while getopts ' LM:C:V:' OPTION
16
+ # Parse commandline flags:
17
+ # -L
18
+ # disable writing to the log file at $LOGFILE.
19
+ # -M max_graph_size
20
+ # -C graph_check_frequency
21
+ # -V verbosity
22
+ # -h
23
+ # print the help string
24
+ while getopts ' LM:C:V:h' OPTION
23
25
do
24
26
case $OPTION in
25
27
L)
34
36
V)
35
37
VERBOSITY=$OPTARG
36
38
;;
39
+ h)
40
+ echo -e " Usage: $0 TEST_FILTER...\nwhere TEST_FILTERs are globs match .py test files. If no test filter is provided, runs all tests."
41
+ exit 0
42
+ ;;
43
+ \? ) # This catches all invalid options.
44
+ echo " ERROR: Invalid commandline flag."
45
+ exit 1
37
46
esac
38
47
done
39
48
shift $(( $OPTIND - 1 ))
40
49
50
+ # Set the `CONTINUE_ON_ERROR` flag to `1` to make the CI tests continue on error.
51
+ # This will allow you to see all the failures on your PR, not stopping with the first
52
+ # test failure like the default behavior.
53
+ CONTINUE_ON_ERROR=" ${CONTINUE_ON_ERROR:- 0} "
54
+ if [[ " $CONTINUE_ON_ERROR " == " 1" ]]; then
55
+ set +e
56
+ fi
57
+
41
58
export TRIM_GRAPH_SIZE=$MAX_GRAPH_SIZE
42
59
export TRIM_GRAPH_CHECK_FREQUENCY=$GRAPH_CHECK_FREQUENCY
43
60
export TORCH_TEST_DEVICES=" $CDIR /pytorch_test_base.py"
@@ -48,7 +65,36 @@ export CPU_NUM_DEVICES=4
48
65
TORCH_XLA_DIR=$( cd ~ ; dirname " $( python -c ' import torch_xla; print(torch_xla.__file__)' ) " )
49
66
COVERAGE_FILE=" $CDIR /../.coverage"
50
67
68
+ # Given $1 as a (possibly not normalized) test filepath, returns successfully
69
+ # if it matches any of the space-separated globs $_TEST_FILTER. If
70
+ # $_TEST_FILTER is empty, returns successfully.
71
+ function test_is_selected {
72
+ if [[ -z " $_TEST_FILTER " ]]; then
73
+ return 0 # success
74
+ fi
75
+
76
+ # _TEST_FILTER is a space-separate list of globs. Loop through the
77
+ # list elements.
78
+ for _FILTER in $_TEST_FILTER ; do
79
+ # realpath normalizes the paths (e.g. resolving `..` and relative paths)
80
+ # so that they can be compared.
81
+ case ` realpath $1 ` in
82
+ ` realpath $_FILTER ` )
83
+ return 0 # success
84
+ ;;
85
+ * )
86
+ # No match
87
+ ;;
88
+ esac
89
+ done
90
+
91
+ return 1 # failure
92
+ }
93
+
51
94
function run_coverage {
95
+ if ! test_is_selected " $1 " ; then
96
+ return
97
+ fi
52
98
if [ " ${USE_COVERAGE:- 0} " != " 0" ]; then
53
99
coverage run --source=" $TORCH_XLA_DIR " -p " $@ "
54
100
else
@@ -57,6 +103,9 @@ function run_coverage {
57
103
}
58
104
59
105
function run_test {
106
+ if ! test_is_selected " $1 " ; then
107
+ return
108
+ fi
60
109
echo " Running in PjRt runtime: $@ "
61
110
if [ -x " $( command -v nvidia-smi) " ] && [ " $XLA_CUDA " != " 0" ]; then
62
111
PJRT_DEVICE=CUDA run_coverage " $@ "
@@ -67,6 +116,9 @@ function run_test {
67
116
}
68
117
69
118
function run_device_detection_test {
119
+ if ! test_is_selected " $1 " ; then
120
+ return
121
+ fi
70
122
echo " Running in PjRt runtime: $@ "
71
123
current_device=$PJRT_DEVICE
72
124
current_num_gpu_devices=$GPU_NUM_DEVICES
@@ -81,36 +133,57 @@ function run_device_detection_test {
81
133
}
82
134
83
135
function run_test_without_functionalization {
136
+ if ! test_is_selected " $1 " ; then
137
+ return
138
+ fi
84
139
echo " Running with XLA_DISABLE_FUNCTIONALIZATION: $@ "
85
140
XLA_DISABLE_FUNCTIONALIZATION=1 run_test " $@ "
86
141
}
87
142
88
143
function run_use_bf16 {
144
+ if ! test_is_selected " $1 " ; then
145
+ return
146
+ fi
89
147
echo " Running with XLA_USE_BF16: $@ "
90
148
XLA_USE_BF16=1 run_test " $@ "
91
149
}
92
150
93
151
function run_downcast_bf16 {
152
+ if ! test_is_selected " $1 " ; then
153
+ return
154
+ fi
94
155
echo " Running with XLA_DOWNCAST_BF16: $@ "
95
156
XLA_DOWNCAST_BF16=1 run_test " $@ "
96
157
}
97
158
98
159
function run_dynamic {
160
+ if ! test_is_selected " $1 " ; then
161
+ return
162
+ fi
99
163
echo " Running in DynamicShape mode: $@ "
100
164
XLA_EXPERIMENTAL=" nonzero:masked_select:masked_scatter:nms" run_test " $@ "
101
165
}
102
166
103
167
function run_eager_debug {
168
+ if ! test_is_selected " $1 " ; then
169
+ return
170
+ fi
104
171
echo " Running in Eager Debug mode: $@ "
105
172
XLA_USE_EAGER_DEBUG_MODE=1 run_test " $@ "
106
173
}
107
174
108
175
function run_pt_xla_debug {
176
+ if ! test_is_selected " $1 " ; then
177
+ return
178
+ fi
109
179
echo " Running in save tensor file mode: $@ "
110
180
PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE=" /tmp/pt_xla_debug.txt" run_test " $@ "
111
181
}
112
182
113
183
function run_pt_xla_debug_level1 {
184
+ if ! test_is_selected " $1 " ; then
185
+ return
186
+ fi
114
187
echo " Running in save tensor file mode: $@ "
115
188
PT_XLA_DEBUG_LEVEL=1 PT_XLA_DEBUG_FILE=" /tmp/pt_xla_debug.txt" run_test " $@ "
116
189
}
@@ -121,6 +194,9 @@ function run_pt_xla_debug_level2 {
121
194
}
122
195
123
196
function run_torchrun {
197
+ if ! test_is_selected " $1 " ; then
198
+ return
199
+ fi
124
200
if [ -x " $( command -v nvidia-smi) " ] && [ " $XLA_CUDA " != " 0" ]; then
125
201
echo " Running torchrun test for GPU $@ "
126
202
num_devices=$( nvidia-smi --list-gpus | wc -l)
@@ -412,6 +488,17 @@ function run_tests {
412
488
fi
413
489
}
414
490
491
+ if [[ $# -ge 1 ]]; then
492
+ # There are positional arguments - set $_TEST_FILTER to them.
493
+ _TEST_FILTER=$@
494
+ # Sometimes a test may fail even if it doesn't match _TEST_FILTER. Therefore,
495
+ # we need to set this to be able to get to the test(s) we want to run.
496
+ CONTINUE_ON_ERROR=1
497
+ else
498
+ # No positional argument - run all tests.
499
+ _TEST_FILTER=" "
500
+ fi
501
+
415
502
if [ " $LOGFILE " != " " ]; then
416
503
run_tests 2>&1 | tee $LOGFILE
417
504
else
0 commit comments