
Commit ce1f3f3

Merge branch 'main' into sbosisio/cuda-dl-base

2 parents b2b5bcd + e57ade9

File tree: 12 files changed, +32 −15 lines

.github/container/Dockerfile.jax

Lines changed: 1 addition & 1 deletion
```diff
@@ -102,7 +102,7 @@ RUN <<"EOF" bash -ex
 for component in $(ls ${BUILD_PATH_JAXLIB}); do
   echo "-e file://${BUILD_PATH_JAXLIB}/${component}" >> /opt/pip-tools.d/requirements-jax.in;
 done
-echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/requirements-jax.in
+echo "-e file://${SRC_PATH_JAX}[k8s]" >> /opt/pip-tools.d/requirements-jax.in
 echo "numpy<2.0.0" >> /opt/pip-tools.d/requirements-jax.in
 EOF
```
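The change makes the generated `/opt/pip-tools.d/requirements-jax.in` request JAX's optional `k8s` extra alongside the editable source install. A minimal stand-alone sketch of the equivalent manual install (the `SRC_PATH_JAX` value here is an assumption for illustration, not taken from this diff):

```bash
# Hypothetical equivalent of the generated requirements entry: an editable
# install of JAX from a local checkout, including the optional "k8s" extra.
SRC_PATH_JAX=/opt/jax   # assumed path; point this at your JAX checkout
pip install -e "${SRC_PATH_JAX}[k8s]"
```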

.github/container/test-maxtext.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -233,7 +233,7 @@ fi

 export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
 --xla_gpu_enable_triton_gemm=false
---xla_gpu_graph_level=0
+--xla_gpu_enable_command_buffer=
 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
 --xla_gpu_all_gather_combine_threshold_bytes=1073741824
 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
```
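Two details worth noting: the `${BASE_XLA_FLAGS:-...}` expansion means the defaults above apply only when the caller has not already set `BASE_XLA_FLAGS`, and the empty value in `--xla_gpu_enable_command_buffer=` disables command buffers, replacing the older `--xla_gpu_graph_level=0` spelling. A minimal usage sketch (the override value is an illustrative assumption):

```bash
# Override the script's default XLA flag set; with no override, the
# defaults shown in the diff above take effect instead.
export BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
bash .github/container/test-maxtext.sh
```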

rosetta/docs/GPU_performance.md

Lines changed: 21 additions & 0 deletions
```diff
@@ -140,3 +140,24 @@ The following flags were previously used but are no longer required.
 - --xla_gpu_enable_highest_priority_async_stream ; Turned on by default
 - --xla_gpu_enable_triton_softmax_fusion ; Deprecated, no longer used
```

All 21 added lines render as the following new section:

## Tips for Good LLM Training Performance on Blackwell (B200)

### **Support for Attention Mask Type**

MaxText uses the `padding_causal` mask type for [cuDNN Flash Attention](https://github.com/AI-Hypercomputer/maxtext/blob/6ec3368af31fff6e6d735ac9d5fb77f91fc0f784/MaxText/layers/attentions.py#L411). However, this mask type is not yet supported on Blackwell systems through TransformerEngine; using `padding_causal` therefore falls back to the `unfused_attention` backend, which may reduce performance. As a temporary workaround, use the `causal` mask type to maintain performance.
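A hypothetical sketch of the workaround as a one-line change at the linked call site (the exact variable name and quoting are assumptions, not copied from MaxText source; see the linked `attentions.py` for the real call):

```diff
-attn_mask_type = 'padding_causal'
+attn_mask_type = 'causal'
```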
### **No Need to Set `CUDA_DEVICE_MAX_CONNECTIONS=1`**

Hopper required `CUDA_DEVICE_MAX_CONNECTIONS=1` to achieve good communication-compute overlap. This is not needed on Blackwell and is in fact slower. On Blackwell systems, kernels assigned to higher-priority streams can use SM (Streaming Multiprocessor) resources without waiting for lower-priority kernels to release them, so it is best to leave `CUDA_DEVICE_MAX_CONNECTIONS` at its default value.
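In shell terms, a minimal sketch of the two regimes:

```bash
# Hopper (H100): commonly set to improve communication-compute overlap.
export CUDA_DEVICE_MAX_CONNECTIONS=1

# Blackwell (B200): leave the variable at its default; unset it if an
# H100-era launch script exported it.
unset CUDA_DEVICE_MAX_CONNECTIONS
```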
150+
151+
### **Additional XLA Flags**
152+
Enabling CUDA Graphs only for Fusions and Custom Calls reduces CPU launch latency overheads on B200, ensure that you set the following XLA flags: `--xla_gpu_enable_command_buffer=FUSION,CUSTOM_CALL`
153+
154+
This configuration improves performance on Blackwell systems by leveraging efficient command buffer execution in all the models tested on B200.
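A minimal sketch of appending the setting to whatever flags are already exported:

```bash
# Enable command buffers (CUDA Graphs) only for fusion and custom-call
# commands, as recommended above for B200.
export XLA_FLAGS="${XLA_FLAGS:-} --xla_gpu_enable_command_buffer=FUSION,CUSTOM_CALL"
```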
155+
156+
### **Better Utilizing Additional Memory in Blackwell**
157+
Blackwell (B200) GPUs have a memory capacity of 180GB, significantly more than H100 GPUs. To take full advantage of this additional memory and enhance performance:
158+
159+
- Adjust model parallelism configurations: can use less model parallelism to fit the same model in memory.
160+
- Increase batch sizes where possible: larger batch sizes can improve GeMM kernel efficiency.
161+
- Optimize activation checkpointing policies: fewer activation tensors need to be recomputed in the backward pass on B200.
162+
163+
Careful tuning of these parameters is essential when transitioning from H100 to B200 systems to fully utilize the available resources.
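As a concrete illustration, these knobs map to MaxText config overrides. The values below are placeholders showing the shape of the tuning, not recommendations from this commit:

```bash
# Hypothetical MaxText launch tuned for B200's larger memory: less tensor
# parallelism, a larger per-device batch size, and a cheaper
# rematerialization (activation checkpointing) policy. All values are
# illustrative assumptions to be tuned per model.
python3 MaxText/train.py MaxText/configs/base.yml \
    run_name=my_b200_run \
    ici_tensor_parallelism=1 \
    per_device_batch_size=4 \
    remat_policy=minimal
```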

rosetta/docs/PGLE.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -62,7 +62,7 @@ In order to get the best performance with PGLE, here is a list of all recommended
 ```
 export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
 --xla_gpu_enable_triton_gemm=false
---xla_gpu_graph_level=0
+--xla_gpu_enable_command_buffer=
 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
 --xla_gpu_all_gather_combine_threshold_bytes=1073741824
 --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824
````
Lines changed: 1 addition & 1 deletion
```diff
@@ -1,2 +1,2 @@
 # These XLA flags are meant to be used with the JAX version in the imagen container
-export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_async_all_gather=false --xla_gpu_enable_async_reduce_scatter=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_async_all_reduce=false ${XLA_FLAGS}"
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_async_all_gather=false --xla_gpu_enable_async_reduce_scatter=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_async_all_reduce=false ${XLA_FLAGS}"
```

rosetta/rosetta/projects/maxtext/README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -69,7 +69,7 @@ The [GPU Performance document](../../../docs/GPU_performance.md) provides a deta
 ```
 XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
 --xla_gpu_enable_triton_gemm=false
---xla_gpu_graph_level=0
+--xla_gpu_enable_command_buffer=
 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
 --xla_gpu_all_gather_combine_threshold_bytes=1073741824
 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
````

rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub

Lines changed: 1 addition & 1 deletion
```diff
@@ -54,7 +54,7 @@ export NCCL_IB_SL=1
 # Set XLA Flags
 export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
 --xla_gpu_enable_triton_gemm=false
---xla_gpu_graph_level=0
+--xla_gpu_enable_command_buffer=
 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
 --xla_gpu_all_gather_combine_threshold_bytes=1073741824
 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
```

rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env

Lines changed: 1 addition & 2 deletions
```diff
@@ -5,7 +5,7 @@ THRESHOLD_BYTES=1073741824
 export XLA_FLAGS="\
 --xla_gpu_enable_latency_hiding_scheduler=true \
 --xla_gpu_enable_triton_gemm=false \
---xla_gpu_graph_level=0 \
+--xla_gpu_enable_command_buffer= \
 --xla_gpu_enable_highest_priority_async_stream=true \
 --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \
 --xla_gpu_all_gather_combine_threshold_bytes=$((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS))) \
@@ -14,7 +14,6 @@ export XLA_FLAGS="\
 --xla_gpu_enable_pipelined_reduce_scatter=true \
 --xla_gpu_enable_pipelined_all_reduce=true \
 --xla_gpu_enable_while_loop_double_buffering=true \
---xla_gpu_enable_triton_softmax_fusion=false \
 --xla_gpu_enable_all_gather_combine_by_dim=false \
 --xla_gpu_enable_reduce_scatter_combine_by_dim=false \
 --xla_disable_hlo_passes=rematerialization \
```
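Since the all-gather combine threshold is computed from the node and GPU counts, it can help to sanity-check the arithmetic. A worked example, assuming `NUM_NODES=1` and `NUM_GPUS=8` as the `1N8G` file name suggests (those assignments are assumptions; they are defined elsewhere in the env file, not in this diff):

```bash
# Reproduce the combine-threshold arithmetic from the flags above.
NUM_NODES=1                  # assumed from the "1N8G" file name
NUM_GPUS=8                   # assumed from the "1N8G" file name
THRESHOLD_BYTES=1073741824   # 1 GiB, as set at the top of the env file
echo $((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS)))   # 134217728 bytes = 128 MiB
```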

rosetta/rosetta/projects/pax/README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -141,7 +141,7 @@ For the 126M model, we recommend setting `--xla_gpu_all_reduce_combine_thres
 BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
 --xla_gpu_enable_triton_gemm=false
 --xla_gpu_all_reduce_combine_threshold_bytes=33554432
---xla_gpu_graph_level=0" bash run_pile_multinode.sh ...
+--xla_gpu_enable_command_buffer=" bash run_pile_multinode.sh ...
 ```

 # Configs
````

rosetta/rosetta/projects/pax/xla_flags/common.env

Lines changed: 1 addition & 2 deletions
```diff
@@ -4,9 +4,8 @@ export XLA_FLAGS="\
 --xla_gpu_enable_latency_hiding_scheduler=true \
 --xla_allow_excess_precision \
 --xla_gpu_enable_highest_priority_async_stream=true \
---xla_gpu_enable_triton_softmax_fusion=false \
 --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \
---xla_gpu_graph_level=0 \
+--xla_gpu_enable_command_buffer= \
 "
 export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8
 unset THRESHOLD_BYTES
```
