
Commit 90c889e

leeeizhang and awgu authored
[MRG] relax the FP8 CUDA arch limitation to SM89 (#549)
closes: #548

> Nvidia Ada Lovelace GPUs (e.g., RTX 4090, L20, L40) with compute capability SM89 also support FP8 MMA, so it is recommended to relax the CUDA architecture limitation to enable FP8 training on a broader range of devices.
>
> The [CUDA 12.0 announcement](https://developer.nvidia.com/blog/cuda-toolkit-12-0-released-for-general-availability/) states that it supports the Ada Lovelace architecture:
> "*CUDA 12.0 exposes programmable functionality for many features of the NVIDIA Hopper and NVIDIA Ada Lovelace architectures: ... 32x Ultra xMMA (including FP8 and FP16)*"
>
> - https://developer.nvidia.com/blog/cuda-toolkit-12-0-released-for-general-availability/
> - https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
> - https://github.com/NVIDIA/cutlass/blob/c4e3e122e266644c61b4af33d0cc09f4c391a64b/include/cutlass/arch/mma_sm89.h#L57
>
> ![image](https://github.com/user-attachments/assets/3c11736c-2e84-4bd6-a49c-5af8b0e3e6ac)

After relaxing the CUDA architecture limitation for FP8, my environment with **4 x L40 GPUs (SM89)** can still successfully train Llama under float8 precision.

![image](https://github.com/user-attachments/assets/1337e041-0d0d-49b5-8c11-00e67f4df41f)

---------

Co-authored-by: Andrew Gu <[email protected]>
1 parent 40210ea commit 90c889e
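For reference, the relaxed gate relies on Python's lexicographic tuple comparison: SM 8.9 (Ada Lovelace) and SM 9.0 (Hopper) both clear the `(8, 9)` threshold, while Ampere (SM 8.0/8.6) still does not. A minimal sketch, not part of this commit (the helper name mirrors `_is_sm89_or_later` from the diff; the comparison table is illustrative):

```python
import torch

SM89 = (8, 9)

def is_sm89_or_later() -> bool:
    # get_device_capability() returns (major, minor), e.g. (8, 9) on an L40
    # and (9, 0) on an H100; tuples compare lexicographically.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= SM89

# Which architectures clear the relaxed threshold:
for name, cap in [
    ("A100 (SM80)", (8, 0)),
    ("RTX 3090 (SM86)", (8, 6)),
    ("L40 / RTX 4090 (SM89)", (8, 9)),
    ("H100 (SM90)", (9, 0)),
]:
    print(f"{name}: float8 eligible = {cap >= SM89}")
```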

File tree

1 file changed: +5 −5

torchtitan/float8.py

Lines changed: 5 additions & 5 deletions
@@ -23,9 +23,9 @@
 from torchtitan.parallelisms import ParallelDims
 
 
-def _is_sm90_or_later():
-    # Float8 is only supported on H100+ GPUs
-    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+def _is_sm89_or_later():
+    # Float8 is only supported on SM89 or later (H100+ GPUs)
+    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
 class Float8Handler:
@@ -35,9 +35,9 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         float8_config = job_config.float8
         if not float8_config.enable_float8_linear:
             return
-        if not _is_sm90_or_later():
+        if not _is_sm89_or_later():
             logger.warning(
-                "Failed to swap to Float8Linear because SM90 or later is not available",
+                "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
             )
             return
         try:
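A rough sketch of how the relaxed gate plays out in the handler's constructor. This is simplified and hypothetical: the real `Float8Handler` takes a `JobConfig` and `ParallelDims`, and only `_is_sm89_or_later` and the warning text follow the diff above.

```python
import logging
import torch

logger = logging.getLogger(__name__)

def _is_sm89_or_later():
    # Float8 is only supported on SM89 or later (Ada Lovelace / Hopper and newer).
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

class Float8Handler:
    def __init__(self, enable_float8_linear: bool):
        self.enabled = False
        if not enable_float8_linear:
            return
        if not _is_sm89_or_later():
            # Before this change the gate required SM90, which excluded
            # L40 / RTX 4090 class GPUs even though they support FP8 MMA.
            logger.warning(
                "Failed to swap to Float8Linear because float8 is only supported on SM89 or later"
            )
            return
        self.enabled = True

# On a 4 x L40 node this now enables float8; on Ampere it warns and falls back.
handler = Float8Handler(enable_float8_linear=True)
print("float8 enabled:", handler.enabled)
```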

0 commit comments
