Skip to content

Commit 1a3febb

Browse files
authored
Model XLA Flags (#1052)
Moves XLA flags from model CI into their own files that can be sourced. Each file can be sourced and will print what it sets. Some files source other files; this is intentional, to avoid introducing symlinks into the repo, which can sometimes cause platform issues (e.g. on Windows). --------- Signed-off-by: Terry Kong <[email protected]>
1 parent ccededf commit 1a3febb

File tree

14 files changed

+111
-0
lines changed

14 files changed

+111
-0
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Sourceable XLA settings. Exports XLA_FLAGS and
# XLA_PYTHON_CLIENT_MEM_FRACTION into the calling shell; xtrace is switched on
# while they are set so the log shows the values, and switched off afterwards.
#
# The sizing inputs are function-local so that sourcing this file neither
# overwrites nor unsets NUM_NODES / NUM_GPUS / THRESHOLD_BYTES variables that
# the calling script may own.
__xla_flags_apply() {
  local num_nodes=1
  local num_gpus=8
  local threshold_bytes=1073741824  # 1 GiB, used as-is for all-reduce

  # all-gather combines at threshold / (nodes * gpus); reduce-scatter at half
  # of that. Continuation lines start at column 0 on purpose: any indentation
  # would become part of the XLA_FLAGS value.
  export XLA_FLAGS="\
--xla_gpu_enable_latency_hiding_scheduler=true \
--xla_gpu_enable_triton_gemm=false \
--xla_gpu_graph_level=0 \
--xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_all_reduce_combine_threshold_bytes=${threshold_bytes} \
--xla_gpu_all_gather_combine_threshold_bytes=$((threshold_bytes/(num_nodes*num_gpus))) \
--xla_gpu_reduce_scatter_combine_threshold_bytes=$((threshold_bytes/(num_nodes*num_gpus*2))) \
--xla_gpu_enable_pipelined_all_gather=true \
--xla_gpu_enable_pipelined_reduce_scatter=true \
--xla_gpu_enable_pipelined_all_reduce=true \
--xla_gpu_enable_while_loop_double_buffering=true \
--xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_enable_all_gather_combine_by_dim=false \
--xla_gpu_enable_reduce_scatter_combine_by_dim=false \
--xla_disable_hlo_passes=rematerialization \
"
  export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
}

set -x
__xla_flags_apply
set +x
# Do not leave the helper behind in the caller's shell.
unset -f __xla_flags_apply
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Sourceable XLA settings. Exports XLA_FLAGS and
# XLA_PYTHON_CLIENT_MEM_FRACTION into the calling shell; xtrace is switched on
# while they are set so the log shows the values, and switched off afterwards.
#
# The threshold is function-local so that sourcing this file neither
# overwrites nor unsets a THRESHOLD_BYTES variable owned by the caller.
__xla_flags_apply() {
  local threshold_bytes=51200  # 50 KiB all-reduce combine threshold

  # Continuation lines start at column 0 on purpose: any indentation would
  # become part of the XLA_FLAGS value.
  export XLA_FLAGS="\
--xla_gpu_enable_latency_hiding_scheduler=true \
--xla_allow_excess_precision \
--xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_all_reduce_combine_threshold_bytes=${threshold_bytes} \
--xla_gpu_graph_level=0 \
"
  export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8
}

set -x
__xla_flags_apply
set +x
# Do not leave the helper behind in the caller's shell.
unset -f __xla_flags_apply
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Wrapper env file: loads the common.env that sits next to this file.
# The directory is resolved from BASH_SOURCE, so it does not depend on the
# caller's working directory. The path is expanded inline and quoted so that
#   * a checkout path containing spaces still resolves, and
#   * no SCRIPT_DIR variable is assigned/unset in the caller's shell (callers
#     frequently keep their own SCRIPT_DIR).
source "$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )/common.env"
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Wrapper env file: loads the common.env that sits next to this file.
# The directory is resolved from BASH_SOURCE, so it does not depend on the
# caller's working directory. The path is expanded inline and quoted so that
#   * a checkout path containing spaces still resolves, and
#   * no SCRIPT_DIR variable is assigned/unset in the caller's shell (callers
#     frequently keep their own SCRIPT_DIR).
source "$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )/common.env"
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Sourceable XLA settings. Exports XLA_FLAGS and
# XLA_PYTHON_CLIENT_MEM_FRACTION into the calling shell; xtrace is switched on
# while they are set so the log shows the values, and switched off afterwards.
#
# The threshold is function-local so that sourcing this file neither
# overwrites nor unsets a THRESHOLD_BYTES variable owned by the caller.
__xla_flags_apply() {
  local threshold_bytes=33554432  # 32 MiB all-reduce combine threshold

  # Continuation lines start at column 0 on purpose: any indentation would
  # become part of the XLA_FLAGS value.
  export XLA_FLAGS="\
--xla_gpu_enable_latency_hiding_scheduler=true \
--xla_allow_excess_precision \
--xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_all_reduce_combine_threshold_bytes=${threshold_bytes} \
--xla_gpu_graph_level=0 \
--xla_gpu_enable_cudnn_fmha=false \
"
  export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8
}

set -x
__xla_flags_apply
set +x
# Do not leave the helper behind in the caller's shell.
unset -f __xla_flags_apply
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Wrapper env file: loads the common.env that sits next to this file.
# The directory is resolved from BASH_SOURCE, so it does not depend on the
# caller's working directory. The path is expanded inline and quoted so that
#   * a checkout path containing spaces still resolves, and
#   * no SCRIPT_DIR variable is assigned/unset in the caller's shell (callers
#     frequently keep their own SCRIPT_DIR).
source "$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )/common.env"
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Wrapper env file: loads the common.env that sits next to this file.
# The directory is resolved from BASH_SOURCE, so it does not depend on the
# caller's working directory. The path is expanded inline and quoted so that
#   * a checkout path containing spaces still resolves, and
#   * no SCRIPT_DIR variable is assigned/unset in the caller's shell (callers
#     frequently keep their own SCRIPT_DIR).
source "$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )/common.env"
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Sourceable XLA settings. Exports XLA_FLAGS and
# XLA_PYTHON_CLIENT_MEM_FRACTION into the calling shell; xtrace is switched on
# while they are set so the log shows the values, and switched off afterwards.
#
# The thresholds are function-local so that sourcing this file neither
# overwrites nor unsets same-named variables owned by the caller.
__xla_flags_apply() {
  local all_reduce_threshold_bytes=3221225472     # 3 GiB
  local all_gather_threshold_bytes=3221225472     # 3 GiB
  local reduce_scatter_threshold_bytes=402653184  # 384 MiB (3 GiB / 8)

  # Every flag line, including the last one, ends in " \" so the value stays a
  # single space-separated line with no embedded newline. Continuation lines
  # start at column 0 on purpose: indentation would become part of the value.
  export XLA_FLAGS="\
--xla_gpu_enable_latency_hiding_scheduler=true \
--xla_allow_excess_precision \
--xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_all_reduce_combine_threshold_bytes=${all_reduce_threshold_bytes} \
--xla_gpu_graph_level=0 \
--xla_gpu_all_gather_combine_threshold_bytes=${all_gather_threshold_bytes} \
--xla_gpu_reduce_scatter_combine_threshold_bytes=${reduce_scatter_threshold_bytes} \
--xla_gpu_enable_pipelined_all_gather=true \
--xla_gpu_enable_pipelined_reduce_scatter=true \
--xla_gpu_enable_pipelined_all_reduce=true \
--xla_gpu_enable_while_loop_double_buffering=true \
--xla_gpu_enable_all_gather_combine_by_dim=false \
--xla_gpu_enable_reduce_scatter_combine_by_dim=false \
--xla_disable_hlo_passes=rematerialization \
--xla_gpu_enable_custom_fusions=true \
"
  export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
}

set -x
__xla_flags_apply
set +x
# Do not leave the helper behind in the caller's shell.
unset -f __xla_flags_apply
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Wrapper env file: loads the common.env that sits next to this file.
# The directory is resolved from BASH_SOURCE, so it does not depend on the
# caller's working directory. The path is expanded inline and quoted so that
#   * a checkout path containing spaces still resolves, and
#   * no SCRIPT_DIR variable is assigned/unset in the caller's shell (callers
#     frequently keep their own SCRIPT_DIR).
source "$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )/common.env"
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Sourceable defaults for models with no model-specific XLA flags (presumably
# the common.env that the wrapper env files load; verify against the repo).
# XLA_FLAGS is left untouched and only its current value (empty if unset) is
# reported; XLA_PYTHON_CLIENT_MEM_FRACTION is exported. xtrace is on for the
# duration and switched off at the end.
#
# NOTE(review): when this file is sourced, $0 is the invoking shell or script,
# not this file; confirm that is the name the message is meant to show.
set -x
echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
set +x

0 commit comments

Comments
 (0)