[MRG] relax the FP8 CUDA arch limitation to SM89 (#549)
closes: #548

> NVIDIA Ada Lovelace GPUs (e.g., RTX 4090, L20, L40) with SM89 also support FP8 MMA; hence, it is recommended to relax the CUDA architecture limitation to enable FP8 training on a broader range of devices.
>
> and the [CUDA 12.0 announcement](https://developer.nvidia.com/blog/cuda-toolkit-12-0-released-for-general-availability/) says that it supports the Ada Lovelace architecture:
> '*CUDA 12.0 exposes programmable functionality for many features of
the NVIDIA Hopper and NVIDIA Ada Lovelace architectures: ...32x Ultra
xMMA (including FP8 and FP16)*'
> 
> - https://developer.nvidia.com/blog/cuda-toolkit-12-0-released-for-general-availability/
> - https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
> - https://github.com/NVIDIA/cutlass/blob/c4e3e122e266644c61b4af33d0cc09f4c391a64b/include/cutlass/arch/mma_sm89.h#L57
>
> ![image](https://github.com/user-attachments/assets/3c11736c-2e84-4bd6-a49c-5af8b0e3e6ac)

After relaxing the CUDA architecture limitation for FP8, my environment with **4 x L40 GPUs (SM89)** can still successfully train Llama in float8 precision.


![image](https://github.com/user-attachments/assets/1337e041-0d0d-49b5-8c11-00e67f4df41f)
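For reference, here is a minimal, illustrative check (not part of this commit; it only assumes PyTorch with a CUDA build) that a local GPU clears the relaxed capability gate:

```python
import torch

# Ada Lovelace parts (e.g., RTX 4090, L20, L40) report compute capability (8, 9);
# Hopper parts (e.g., H100) report (9, 0). Both satisfy the relaxed >= (8, 9) check.
if torch.cuda.is_available():
    cap = torch.cuda.get_device_capability()
    print(f"{torch.cuda.get_device_name()}: SM{cap[0]}{cap[1]}")
    print("FP8 eligible:", cap >= (8, 9))
else:
    print("CUDA is not available; the float8 path would be skipped.")
```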

---------

Co-authored-by: Andrew Gu <[email protected]>
leeeizhang and awgu authored Aug 21, 2024
1 parent 40210ea commit 90c889e
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions torchtitan/float8.py
@@ -23,9 +23,9 @@
 from torchtitan.parallelisms import ParallelDims
 
 
-def _is_sm90_or_later():
-    # Float8 is only supported on H100+ GPUs
-    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+def _is_sm89_or_later():
+    # Float8 is only supported on SM89 or later (H100+ GPUs)
+    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
 class Float8Handler:
@@ -35,9 +35,9 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         float8_config = job_config.float8
         if not float8_config.enable_float8_linear:
             return
-        if not _is_sm90_or_later():
+        if not _is_sm89_or_later():
             logger.warning(
-                "Failed to swap to Float8Linear because SM90 or later is not available",
+                "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
             )
             return
         try:
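As a side note, the gate relies on Python's lexicographic tuple comparison, so Hopper still passes while pre-Ada parts remain excluded. A tiny illustrative check (the device names are examples, not taken from this commit):

```python
# The capability check compares (major, minor) tuples lexicographically,
# so the relaxed threshold (8, 9) admits Ada and Hopper but not older architectures.
assert (9, 0) >= (8, 9)      # H100 (SM90) still passes
assert (8, 9) >= (8, 9)      # RTX 4090 / L40 (SM89) now pass
assert not (8, 6) >= (8, 9)  # RTX 3090 (SM86) remains excluded
assert not (8, 0) >= (8, 9)  # A100 (SM80) remains excluded
print("capability gate behaves as expected")
```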
