
Commit e57ade9

Update GPU_performance.md for B200 systems (#1268)
@nouiz this PR has the updated documentation.


rosetta/docs/GPU_performance.md

Lines changed: 21 additions & 0 deletions
@@ -140,3 +140,24 @@ The following flags were previously used but are no longer required.
- --xla_gpu_enable_highest_priority_async_stream ; Turned on by default
- --xla_gpu_enable_triton_softmax_fusion ; Deprecated, no longer used

## Tips for Good LLM Training Performance on Blackwell (B200)
### **Support for Attention Mask Type**
MaxText uses the `padding_causal` mask type for [cuDNN Flash Attention](https://github.com/AI-Hypercomputer/maxtext/blob/6ec3368af31fff6e6d735ac9d5fb77f91fc0f784/MaxText/layers/attentions.py#L411). However, this mask type is not yet supported on Blackwell systems through TransformerEngine, so attention falls back to the `unfused_attention` backend, which may reduce performance. As a temporary workaround, use the `causal` mask type to keep the cuDNN flash-attention backend and maintain performance.
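As an illustrative sketch only (not MaxText's actual code), the mask type could be keyed off the detected GPU generation; the `pick_mask_type` helper and the `"B200"` substring check are assumptions made for this example:

```python
import jax

# Hypothetical helper: choose the attention mask type per GPU generation.
# The helper and the "B200" substring check are illustrative assumptions,
# not MaxText code.
def pick_mask_type() -> str:
    device_kind = jax.devices()[0].device_kind  # e.g. "NVIDIA B200"
    if "B200" in device_kind:
        # `padding_causal` currently falls back to unfused attention on
        # Blackwell; `causal` keeps the cuDNN flash-attention backend.
        return "causal"
    return "padding_causal"
```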
### **No Need to Set `CUDA_DEVICE_MAX_CONNECTIONS=1`**
Hopper required `CUDA_DEVICE_MAX_CONNECTIONS=1` to achieve good communication-compute overlap. This is not needed on Blackwell and is in fact slower. On Blackwell systems, kernels assigned to higher-priority streams can utilize SM (Streaming Multiprocessor) resources without waiting for lower-priority kernels to release them. Therefore, it is better to leave `CUDA_DEVICE_MAX_CONNECTIONS` at its default value.
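A minimal sketch of how a Python launcher might clear a Hopper-era setting; using `os.environ.pop` is just one way to do this, assuming the variable was exported by an existing launch script:

```python
import os

# Hopper-era launch scripts often export CUDA_DEVICE_MAX_CONNECTIONS=1.
# On B200, remove it so the CUDA driver default is used instead. This must
# run before the first CUDA context is created (i.e., before initializing
# JAX/XLA in this process).
os.environ.pop("CUDA_DEVICE_MAX_CONNECTIONS", None)
```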
### **Additional XLA Flags**
Enabling CUDA Graphs only for fusions and custom calls reduces CPU launch-latency overhead on B200. Ensure that you set the following XLA flag: `--xla_gpu_enable_command_buffer=FUSION,CUSTOM_CALL`
This configuration improved performance by leveraging efficient command-buffer execution in all the models tested on B200.
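One way to apply the flag from a Python launcher is sketched below; appending to any existing `XLA_FLAGS` value is an assumption about how the environment is managed, and the variable must be set before JAX initializes its GPU backend:

```python
import os

# Append the command-buffer flag to XLA_FLAGS before JAX starts up.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_gpu_enable_command_buffer=FUSION,CUSTOM_CALL"
).strip()

import jax  # JAX reads XLA_FLAGS from the environment at backend init
```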
### **Better Utilizing Additional Memory in Blackwell**
Blackwell (B200) GPUs have a memory capacity of 180 GB, more than double the 80 GB of an H100. To take full advantage of this additional memory and enhance performance:
- Adjust model parallelism configurations: less model parallelism may suffice to fit the same model in memory.
- Increase batch sizes where possible: larger batch sizes can improve GEMM kernel efficiency.
- Optimize activation checkpointing policies: with more memory available, fewer activation tensors need to be recomputed in the backward pass on B200.
Careful tuning of these parameters is essential when transitioning from H100 to B200 systems to fully utilize the available resources; the rough calculation below illustrates the kind of headroom involved.
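As a back-of-the-envelope illustration (the 70 GB/GPU footprint and the tensor-parallel degrees are made-up example numbers, not measurements):

```python
# Illustrative headroom arithmetic for H100 (80 GB) -> B200 (180 GB).
h100_gb, b200_gb = 80, 180

# Assume a model sharded tensor-parallel over 8 H100s at ~70 GB per GPU.
total_sharded_gb = 70 * 8               # ~560 GB of sharded state

# Halving tensor parallelism doubles the per-GPU footprint:
per_gpu_at_tp4 = total_sharded_gb / 4   # 140 GB
print(per_gpu_at_tp4 <= b200_gb)        # True: TP=4 may fit on B200,
                                        # freeing GPUs for data parallelism
```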
