diff --git a/torchtitan/float8.py b/torchtitan/float8.py
index 043b1832..244fba47 100644
--- a/torchtitan/float8.py
+++ b/torchtitan/float8.py
@@ -23,9 +23,9 @@
 from torchtitan.parallelisms import ParallelDims
 
 
-def _is_sm90_or_later():
-    # Float8 is only supported on H100+ GPUs
-    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+def _is_sm89_or_later():
+    # Float8 is only supported on SM89 or later (H100+ GPUs)
+    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
 class Float8Handler:
@@ -35,9 +35,9 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         float8_config = job_config.float8
         if not float8_config.enable_float8_linear:
             return
-        if not _is_sm90_or_later():
+        if not _is_sm89_or_later():
             logger.warning(
-                "Failed to swap to Float8Linear because SM90 or later is not available",
+                "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
             )
             return
         try: