
Enable blockwise FP8 dense training kernels on ROCm#4036

Open
brucechanglongxu wants to merge 2 commits into pytorch:main from brucechanglongxu:rocm-blockwise-fp8-dense-enablement

Conversation

@brucechanglongxu
Contributor

@brucechanglongxu brucechanglongxu commented Mar 10, 2026

The blockwise FP8 quantization kernels in torchao/prototype/blockwise_fp8_training/kernels.py had hardcoded FP8 max values (448.0) and dtype references to float8_e4m3fn, which prevented them from running on AMD GPUs that use float8_e4m3fnuz. This is the same class of issue fixed for the MoE path in #3996.

This PR parameterizes FP8_MAX as a tl.constexpr in the 5 Triton quantization kernels, updates the wrapper functions and reference implementations to derive limits from torch.finfo(dtype), and defaults the dtype to e4m3_dtype (platform-aware). The dtype assertions are widened to accept both float8_e4m3fn and float8_e4m3fnuz. Test skip guards are updated to use get_device_capability()[0] >= 9 which covers both NVIDIA SM >= 9.0 and ROCm GFX9xx+, and the @skip_if_rocm decorators on the quantization kernel tests are removed. The Float8BlockwiseLinear init guard is also updated.
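To make the 448.0 vs. 240.0 distinction concrete, here is a hedged sketch (not torchao code; `e4m3_max` is a hypothetical helper) that derives both limits from the E4M3 format parameters. In practice the PR gets these values from `torch.finfo(dtype).max`; this just shows where the two numbers come from.

```python
# Hypothetical helper: derive the largest finite E4M3 value from the
# format's exponent bias and its NaN-encoding convention.

def e4m3_max(bias: int, reserve_mantissa_for_nan: bool) -> float:
    """Largest finite value of an E4M3 float (4 exponent bits, 3 mantissa bits).

    The 'fn' (finite-only) variant keeps the all-ones exponent usable and
    reserves only the all-ones mantissa for NaN, so its max mantissa is 0b110.
    The 'fnuz' variant encodes NaN as sign=1 with an all-zero payload, so the
    full mantissa 0b111 is usable, but its bias is 8 instead of 7.
    """
    max_exp_field = 0b1111  # all-ones exponent still encodes a normal number
    mantissa = 0b110 if reserve_mantissa_for_nan else 0b111
    return 2.0 ** (max_exp_field - bias) * (1.0 + mantissa / 8.0)

# NVIDIA float8_e4m3fn: 2^(15-7) * 1.75 = 448.0
print(e4m3_max(bias=7, reserve_mantissa_for_nan=True))
# AMD float8_e4m3fnuz: 2^(15-8) * 1.875 = 240.0
print(e4m3_max(bias=8, reserve_mantissa_for_nan=False))
```

This is why hardcoding 448.0 silently over-scales on ROCm: values quantized against the fn limit can exceed the fnuz representable range.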

Addresses blockwise_fp8_training entries in #3666.

Tested on MI300X (gfx942) with float8_e4m3fnuz (FP8 max = 240.0):

Quantization kernel correctness (Triton vs PyTorch reference, block_size=128):

  M= 4096 K= 1024: PASS  (act_quant_lhs=ok  act_quant_rhs=ok  weight_quant_rhs=ok)
  M= 4096 K= 4096: PASS  (act_quant_lhs=ok  act_quant_rhs=ok  weight_quant_rhs=ok)
  M= 2048 K= 8192: PASS  (act_quant_lhs=ok  act_quant_rhs=ok  weight_quant_rhs=ok)

GEMM correctness (FP8 Triton GEMM vs BF16 matmul, SQNR):

  M=    2 N=  512 K=  128: PASS  SQNR=28.8 dB
  M=    2 N= 5120 K= 1280: PASS  SQNR=28.6 dB
  M=    4 N= 3584 K=  640: PASS  SQNR=28.6 dB
  M=  128 N= 4096 K= 4096: PASS  SQNR=28.7 dB
  M=  512 N= 4096 K= 4096: PASS  SQNR=28.7 dB
  M= 2048 N= 4096 K= 4096: PASS  SQNR=28.7 dB
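For reference, the SQNR figures above follow the conventional definition, signal power over quantization-noise power in dB. A minimal sketch (assumed definition, not the test harness's actual code):

```python
import math

def sqnr_db(ref, approx):
    """Signal-to-quantization-noise ratio: 10*log10(||x||^2 / ||x - x_q||^2)."""
    signal = sum(r * r for r in ref)
    noise = sum((r - a) ** 2 for r, a in zip(ref, approx))
    return 10.0 * math.log10(signal / noise)

# A uniform 1% relative error gives exactly 40 dB:
print(sqnr_db([1.0, 2.0, 3.0], [1.01, 2.02, 3.03]))
```

Around 28-29 dB is typical for an FP8 GEMM against a BF16 baseline, since E4M3's 3-bit mantissa bounds the per-element precision.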

GEMM kernel performance (FP8 Triton vs BF16 hipBLAS):

     M      N      K    FP8 (ms)  BF16 (ms)  FP8 TFLOPS  BF16 TFLOPS
   128   4096   4096       0.079      0.021       54.6       199.8
   512   4096   4096       0.078      0.043      220.3       396.2
  2048   4096   4096       0.234      0.134      294.1       513.7
  4096   4096   4096       0.459      0.243      299.3       565.8
  8192   4096   4096       0.898      0.460      306.2       597.7
  2048   8192   4096       0.303      0.260      453.4       529.5
  4096   8192   4096       0.584      0.463      470.5       594.0
  4096   5120   5120       0.409      0.361      525.0       594.8

The FP8 Triton GEMM doesn't outperform hipBLAS BF16 on raw throughput — hipBLAS is heavily tuned for MI300X. The value here is that the kernels now work correctly on ROCm, and the FP8 path provides memory/bandwidth savings during training.
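The TFLOPS columns in the table follow directly from the timings via the standard 2*M*N*K FLOP count for a GEMM. A quick sketch of that arithmetic (my reconstruction of the benchmark math, not the harness itself):

```python
def gemm_tflops(m: int, n: int, k: int, time_ms: float) -> float:
    """Achieved throughput for an MxNxK GEMM: 2*M*N*K FLOPs / seconds / 1e12."""
    return 2 * m * n * k / (time_ms * 1e-3) / 1e12

# Reproduces the M=N=K=4096 FP8 row: 2*4096^3 FLOPs in 0.459 ms -> ~299 TFLOPS
print(gemm_tflops(4096, 4096, 4096, 0.459))
```

The same formula applied to the BF16 column (0.243 ms) gives ~566 TFLOPS, matching the table, so the throughput gap is purely the timing ratio.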

Introduce rocm_device_capability() and is_rocm_gpu_at_least() to provide
structured (major, minor) capability tags for AMD GPUs, analogous to
is_sm_at_least_*() for NVIDIA. Parses GCN arch strings (gfx942 -> (9,4),
gfx950 -> (9,5), gfx1200 -> (12,0)) following the convention established
by vLLM (vllm/platforms/rocm.py).
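The parsing rule implied by those examples treats the last two digits of the arch suffix as minor and stepping, with everything before them as the major version. A hedged sketch of that convention (`parse_gfx_arch` is illustrative, not the PR's actual helper name or signature):

```python
def parse_gfx_arch(arch: str) -> tuple[int, int]:
    """Parse a GCN arch string into a (major, minor) capability tag.

    Assumed convention (as described in the commit, following vLLM): the
    final two digits are minor and stepping, the leading digits are major,
    e.g. gfx942 -> (9, 4), gfx950 -> (9, 5), gfx1200 -> (12, 0).
    """
    digits = arch.removeprefix("gfx")
    return (int(digits[:-2]), int(digits[-2]))

print(parse_gfx_arch("gfx942"))   # (9, 4)
print(parse_gfx_arch("gfx1200"))  # (12, 0)
```

This gives AMD GPUs a (major, minor) ordering comparable to CUDA's SM versions, so a single `get_device_capability()[0] >= 9` guard covers both vendors.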

Also adds is_MI325X() and is_MI350() to __all__, fixes the is_Navi4()
operator precedence bug, and adds type annotations to all ROCm helpers.
The blockwise FP8 quantization kernels and GEMM ops in
torchao/prototype/blockwise_fp8_training had hardcoded FP8 max values
(448.0) and dtype references to float8_e4m3fn, which prevented them
from running on AMD GPUs that use float8_e4m3fnuz.

This commit parameterizes the FP8_MAX value in all 5 Triton
quantization kernels, updates the 5 wrapper functions and 3 reference
implementations to derive limits from torch.finfo(dtype), and defaults
the dtype to e4m3_dtype (platform-aware: e4m3fn on NVIDIA, e4m3fnuz on
ROCm). Test skip guards are updated to use a device-capability check
that works on both NVIDIA SM >= 9.0 and ROCm, and the @skip_if_rocm
decorators on the quantization kernel tests are removed.

Addresses the blockwise_fp8_training entries in pytorch#3666.
@pytorch-bot

pytorch-bot bot commented Mar 10, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4036

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit a95b462 with merge base f0d0deb:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors must sign the CLA before a PR can be reviewed) on Mar 10, 2026
@brucechanglongxu
Contributor Author

cc @danielvegamyhre @BowenBao

@danielvegamyhre
Contributor

@brucechanglongxu some merge conflicts to resolve


Labels

CLA Signed, float8, module: rocm