Enable blockwise FP8 dense training kernels on ROCm#4036
Open
brucechanglongxu wants to merge 2 commits into pytorch:main from
Conversation
Introduce `rocm_device_capability()` and `is_rocm_gpu_at_least()` to provide structured (major, minor) capability tags for AMD GPUs, analogous to `is_sm_at_least_*()` for NVIDIA. The parser follows the GCN arch string convention established by vLLM (`vllm/platforms/rocm.py`): gfx942 -> (9, 4), gfx950 -> (9, 5), gfx1200 -> (12, 0). Also adds `is_MI325X()` and `is_MI350()` to `__all__`, fixes the operator precedence bug in `is_Navi4()`, and adds type annotations to all ROCm helpers.
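A minimal sketch of how such an arch-string parser might look, assuming the gfx mapping described above (`parse_gfx_arch` is an illustrative name, not the torchao API, and letter steppings like gfx90a are ignored here):

```python
import re
from typing import Optional, Tuple

def parse_gfx_arch(arch: str) -> Optional[Tuple[int, int]]:
    """Parse an AMD GCN arch string into a (major, minor) capability tag,
    following the vLLM convention: gfx942 -> (9, 4), gfx950 -> (9, 5),
    gfx1200 -> (12, 0). Returns None for unrecognized strings."""
    m = re.match(r"gfx(\d{3,})", arch)
    if m is None:
        return None
    digits = m.group(1)
    # The last two digits encode minor/stepping; everything before them
    # is the major version: "942" -> (9, 4); "1200" -> (12, 0).
    return int(digits[:-2]), int(digits[-2])
```

An `is_rocm_gpu_at_least()`-style check then reduces to a tuple comparison against the parsed capability.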
The blockwise FP8 quantization kernels and GEMM ops in torchao/prototype/blockwise_fp8_training had hardcoded FP8 max values (448.0) and dtype references to float8_e4m3fn, which prevented them from running on AMD GPUs that use float8_e4m3fnuz. This commit parameterizes the FP8_MAX value in all 5 Triton quantization kernels, updates the 5 wrapper functions and 3 reference implementations to derive limits from torch.finfo(dtype), and defaults the dtype to e4m3_dtype (platform-aware: e4m3fn on NVIDIA, e4m3fnuz on ROCm). Test skip guards are updated to use a device-capability check that works on both NVIDIA SM >= 9.0 and ROCm, and the @skip_if_rocm decorators on the quantization kernel tests are removed. Addresses the blockwise_fp8_training entries in pytorch#3666.
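In the patch the limit is derived at runtime from `torch.finfo(dtype).max`; as a pure-Python sketch of where the two constants originate (`fp8_e4m3_max` is a hypothetical helper, not part of torchao):

```python
def fp8_e4m3_max(fnuz: bool) -> float:
    """Largest finite value of an 8-bit e4m3 float.

    e4m3fn (OCP spec, NVIDIA): bias 7; the all-ones exponent field is
    still finite except when the mantissa is 0b111 (NaN), so the max
    finite value is (2 - 2**-2) * 2**8 = 448.0.
    e4m3fnuz (MI300-class ROCm GPUs): bias 8, no infinities, a single
    NaN encoding, so the max finite value is (2 - 2**-3) * 2**7 = 240.0.
    """
    return (2 - 2**-3) * 2**7 if fnuz else (2 - 2**-2) * 2**8
```

Deriving the limit from `torch.finfo` rather than hardcoding 448.0 is what lets the same kernel wrapper serve both formats.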
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4036. Note: links to docs will display an error until the docs builds have been completed.

❌ 1 new failure as of commit a95b462 with merge base f0d0deb.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@brucechanglongxu some merge conflicts to resolve
The blockwise FP8 quantization kernels in `torchao/prototype/blockwise_fp8_training/kernels.py` had hardcoded FP8 max values (448.0) and dtype references to `float8_e4m3fn`, which prevented them from running on AMD GPUs that use `float8_e4m3fnuz`. The same class of issue was fixed for the MoE path in #3996.

This PR parameterizes `FP8_MAX` as a `tl.constexpr` in the 5 Triton quantization kernels, updates the wrapper functions and reference implementations to derive limits from `torch.finfo(dtype)`, and defaults the dtype to `e4m3_dtype` (platform-aware). The dtype assertions are widened to accept both `float8_e4m3fn` and `float8_e4m3fnuz`. Test skip guards are updated to use `get_device_capability()[0] >= 9`, which covers both NVIDIA SM >= 9.0 and ROCm GFX9xx+, and the `@skip_if_rocm` decorators on the quantization kernel tests are removed. The `Float8BlockwiseLinear` init guard is also updated.

Addresses the `blockwise_fp8_training` entries in #3666.

Tested on MI300X (gfx942) with
`float8_e4m3fnuz` (FP8 max = 240.0):

- Quantization kernel correctness (Triton vs. PyTorch reference, block_size=128)
- GEMM correctness (FP8 Triton GEMM vs. BF16 matmul, SQNR)
- GEMM kernel performance (FP8 Triton vs. BF16 hipBLAS)
The FP8 Triton GEMM doesn't outperform hipBLAS BF16 on raw throughput, since hipBLAS is heavily tuned for MI300X. The value here is that the kernels now work correctly on ROCm, and the FP8 path provides memory/bandwidth savings during training.
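For context, the per-block scaling that the quantization kernels and their PyTorch references perform can be sketched in pure Python. This is a simplified illustration under the block_size scheme above, not torchao's actual code (which operates on tensors and casts to FP8); `quantize_blockwise` is a hypothetical name:

```python
from typing import List, Tuple

def quantize_blockwise(
    x: List[float], fp8_max: float, block_size: int = 128
) -> Tuple[List[float], List[float]]:
    """Per-block scaling: each block of `block_size` elements gets its own
    scale = fp8_max / amax(block), so the block's largest magnitude lands
    exactly at the FP8 max. Returns (scaled values, per-block reciprocal
    scales used for dequantization)."""
    out, scales = [], []
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        amax = max(abs(v) for v in block) or 1.0  # guard all-zero blocks
        scale = fp8_max / amax
        out.extend(v * scale for v in block)  # a real kernel casts to FP8 here
        scales.append(1.0 / scale)            # dequant multiplies by this
    return out, scales
```

Parameterizing `fp8_max` (448.0 for `float8_e4m3fn`, 240.0 for `float8_e4m3fnuz`) is the one-line change that makes the same scaling logic portable across NVIDIA and ROCm.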