[ROCm] Enable MXFP8/MXFP4 emulation tests on ROCm (MI300+)#4041
Open
brucechanglongxu wants to merge 1 commit into pytorch:main from
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4041
Note: links to docs will display an error until the docs builds have completed.
❌ 2 New Failures as of commit 9c1764c with merge base 4ae435e.
The MX emulation path (`KernelPreference.EMULATED`) performs quantization
and matmul entirely via PyTorch ops -- no native MX hardware kernels are
involved. Despite this, most MX tests were gated behind CUDA SM checks
(`is_sm_at_least_89/90/100`) or blanket `@skip_if_rocm` decorators that
prevented them from running on any ROCm GPU, including MI300X where the
emulation path works correctly.
This patch makes the skip conditions ROCm-aware across four test files so
that emulation-path tests run on ROCm while native-kernel tests remain
correctly gated.
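The core of the change can be sketched as a small predicate. This is an illustrative stand-in, not the actual torchao test code: `should_skip`, `emulate`, and `cuda_sm_ok` are hypothetical names, with `emulate` mirroring the `KernelPreference.EMULATED` configs and `cuda_sm_ok` standing in for the `is_sm_at_least_*` helpers.

```python
# Hypothetical sketch of the gating change described above.
def should_skip(emulate: bool, cuda_sm_ok: bool) -> bool:
    if emulate:
        # Emulated path is pure PyTorch ops: runs on any GPU, CUDA or ROCm.
        return False
    # Native-kernel path keeps its hardware gate.
    return not cuda_sm_ok
```

The old decorator-level gates effectively ignored `emulate` and skipped whenever the CUDA capability check failed, which is what blocked ROCm.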
test_inference_workflow.py:
Remove `@skip_if_rocm("ROCm float4 gemm require gfx950")` from
`test_inference_workflow_mx`. The decorator was intended for the native
float4/float8 gemm path but also blocked the `emulate=True` parameter
sweep. Replace with in-body logic: skip native path unless MI350
(gfx950), preserve the mxfp4+compile skip, let everything else run.
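That in-body logic might look roughly like the following sketch. Function and parameter names here are hypothetical (the real test reads the arch from torch at runtime), and whether the mxfp4+compile skip is ROCm-only or unconditional is not spelled out above; this sketch applies it unconditionally.

```python
def skip_reason(emulate: bool, is_rocm: bool, gfx_arch: str,
                elem_dtype: str, compiled: bool):
    """Return a skip message or None; hypothetical sketch of the
    in-body gating described above."""
    if is_rocm and not emulate and gfx_arch != "gfx950":
        # Native float4/float8 gemm path needs MI350 (gfx950).
        return "native gemm requires gfx950"
    if elem_dtype == "mxfp4" and compiled:
        # Preserved pre-existing skip.
        return "mxfp4 + compile not supported"
    return None  # everything else, including emulated configs, runs
```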
test_mx_tensor.py:
Widen the `is_sm_at_least_89/90` skip conditions on four tests to also
pass on ROCm: `test_to_mx_from_mx_compile_numerics` (float8 compile
numerics), `test_to_mx_inductor_single_kernel` (inductor fusion),
`test_index_select` (3D MXTensor indexing, no compile involved), and
`test_cast_to_float8_e4m3fn_saturation_behavior` (triton float8 cast).
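The widening amounts to adding a ROCm disjunct to each skip condition, roughly as below. `is_rocm` is a stand-in for however the suite detects ROCm (e.g. checking `torch.version.hip`); the function name is illustrative.

```python
def run_allowed(is_rocm: bool, sm_ok: bool) -> bool:
    # Before: only CUDA SM89/SM90+ qualified.
    # After: ROCm also passes, since these tests exercise the
    # emulated / compile path rather than native MX kernels.
    return is_rocm or sm_ok
```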
test_mx_serialization.py:
The `is_sm_at_least_100` decorator-level skip prevented the mxfp8
recipe from running on ROCm even though it uses `EMULATED` mode and
just tests checkpoint save/load. Move the skip into the test body so
that mxfp8 runs on ROCm while nvfp4 remains gated on SM100.
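Moving the gate into the test body lets it branch per recipe, along these hypothetical lines (names are illustrative, not the actual test code):

```python
def serialization_skip(recipe: str, sm100: bool):
    # mxfp8 uses EMULATED mode and only round-trips a checkpoint,
    # so it runs everywhere; nvfp4 keeps its SM100 gate.
    if recipe == "nvfp4" and not sm100:
        return "nvfp4 requires SM100"
    return None
```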
test_mxfp8_allgather.py:
Widen the SM90 assert to also accept ROCm. The allgather test is pure
tensor data transfer with no compute dependency on SM version.
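The widened assert reduces to an `or` over the two backends, sketched here with stand-in flags (the real test derives them from the torch device properties):

```python
def check_platform(is_rocm: bool, sm90: bool) -> None:
    # All-gather only moves tensor bytes, so any ROCm GPU or a
    # CUDA SM90+ device qualifies.
    assert is_rocm or sm90, "requires CUDA SM90+ or ROCm"
```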
Verified on MI300X (gfx942): 16 inference workflow tests pass, 90 basic
MX tensor tests pass, serialization and index_select pass, existing
test_mx_linear TORCH-cast tests still pass (123 passed, 0 regressions).
Force-pushed from b7874f4 to 9c1764c.
danielvegamyhre approved these changes on Mar 11, 2026.
To add the ciflow label: this helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.
Warning: unknown label. Please add the new label to .github/pytorch-probot.yml
Files not changed (with rationale):
test_mx_linear.py: already runs TORCH-cast eager tests on ROCm with no skip; the `is_sm_at_least_89` gate only applies to non-TORCH cast kernels.
test_mx_dtensor.py: already runs emulated tests; the dim1 triton/cuda tests correctly require SM100 since they use PTX inline assembly.
test_kernels.py: the triton mxfp8 kernels are CUDA-only (PTX).
test_mx_mm.py: tests native `scaled_mm`, which requires SM100 hardware.
All native-path tests correctly skip on MI300X.