
[ROCm] Enable MXFP8/MXFP4 emulation tests on ROCm (MI300+)#4041

Open
brucechanglongxu wants to merge 1 commit into pytorch:main from brucechanglongxu:rocm-mxfp-emulation-tests

Conversation

@brucechanglongxu (Contributor)

The MX emulation path (KernelPreference.EMULATED) performs quantization and matmul entirely via PyTorch ops with no native MX hardware kernels involved. Despite this, most MX tests were gated behind CUDA SM checks (is_sm_at_least_89/90/100) or blanket @skip_if_rocm decorators that prevented them from running on any ROCm GPU, including MI300X where the emulation path works fine.

This patch makes the skip conditions ROCm-aware across four test files so that emulation-path tests run on ROCm while native-kernel tests remain correctly gated.

test_inference_workflow.py: the @skip_if_rocm("ROCm float4 gemm require gfx950") decorator on test_inference_workflow_mx was intended for the native float4/float8 gemm path but also blocked the emulate=True parameter sweep. Replaced with in-body logic that skips the native path unless MI350 (gfx950) is present, preserves the mxfp4+compile skip, and lets all emulation configs through. The CUDA path is unchanged (same is_sm_at_least_89/100 checks as before, just nested under an else branch).
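The in-body skip decision described above can be sketched as a small pure-logic helper. The function name, parameters, and structure below are hypothetical, for illustration only; the actual test calls `pytest.skip` inline, and the mxfp4+compile skip is omitted from this sketch:

```python
def should_skip_mx_inference(on_rocm: bool, gfx_arch: str,
                             emulate: bool, sm_at_least_89: bool) -> bool:
    """Hypothetical sketch of the ROCm-aware skip logic for
    test_inference_workflow_mx. Returns True if the config should skip."""
    if emulate:
        # Emulation path is pure PyTorch ops -- runs on any GPU.
        return False
    if on_rocm:
        # Native float4/float8 gemm path needs MI350 (gfx950) on ROCm.
        return gfx_arch != "gfx950"
    # CUDA path unchanged: gated on SM 8.9+ as before.
    return not sm_at_least_89
```

This keeps the native-kernel gating intact while letting every `emulate=True` config through regardless of platform.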

test_mx_tensor.py: widened the is_sm_at_least_89/90 skip conditions on four tests to also pass on ROCm -- test_to_mx_from_mx_compile_numerics (float8 compile numerics), test_to_mx_inductor_single_kernel (inductor fusion), test_index_select (3D MXTensor indexing, no compile involved at all), and test_cast_to_float8_e4m3fn_saturation_behavior (triton float8 saturated cast).
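The widened gate has a simple shape: pass on ROCm or on a sufficiently new CUDA SM. A minimal sketch, assuming `torch.version.hip` as the ROCm check (it is `None` on CUDA builds); the decorator shown in comments is illustrative, not the PR's exact code:

```python
# Decorator shape (illustrative):
#   @pytest.mark.skipif(
#       not (torch.version.hip is not None or is_sm_at_least_89()),
#       reason="needs ROCm or CUDA SM 8.9+",
#   )

def gate_passes(on_rocm: bool, sm_at_least_89: bool) -> bool:
    """Pure-logic stand-in for the widened skip condition:
    the test runs if either platform qualifies."""
    return on_rocm or sm_at_least_89
```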

test_mx_serialization.py: the is_sm_at_least_100 decorator-level skip prevented the mxfp8 recipe from running on ROCm even though it uses EMULATED mode and just tests checkpoint save/load. Moved the skip into the test body so that mxfp8 runs on ROCm while nvfp4 stays gated on SM100.
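Moving the skip from the decorator into the test body lets it branch per recipe. A hypothetical sketch (names are illustrative):

```python
def should_skip_recipe(recipe: str, sm_at_least_100: bool) -> bool:
    """Hypothetical sketch of the per-recipe skip after moving it
    off the decorator. Returns True if the recipe should skip."""
    if recipe == "mxfp8":
        # EMULATED mode; only exercises checkpoint save/load,
        # so it runs on ROCm and older CUDA alike.
        return False
    # nvfp4 needs native SM100 kernels and stays gated.
    return not sm_at_least_100
```

A decorator-level skip is all-or-nothing across the parameter sweep; the in-body form is what allows one recipe through while keeping the other gated.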

test_mxfp8_allgather.py: widened the SM90 assert to also accept ROCm. The allgather test is pure tensor data transfer with no compute dependency on SM version.
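The widened assert can be sketched as follows (hypothetical function name; the before/after shapes are shown in comments):

```python
def check_allgather_platform(on_rocm: bool, sm_at_least_90: bool) -> None:
    """Hypothetical sketch of the widened platform assert.

    Before: assert sm_at_least_90
    After:  ROCm also accepted -- allgather is pure tensor data
            movement with no compute dependency on SM version.
    """
    assert on_rocm or sm_at_least_90, "needs ROCm or CUDA SM 9.0+"
```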

Files not changed (with rationale): test_mx_linear.py already runs TORCH-cast eager tests on ROCm with no skip (the is_sm_at_least_89 gate only applies to non-TORCH cast kernels). test_mx_dtensor.py already runs emulated tests; the dim1 triton/cuda tests correctly require SM100 since they use PTX inline assembly. test_kernels.py triton mxfp8 kernels are CUDA-only (PTX). test_mx_mm.py tests native scaled_mm which requires SM100 hardware.

Tested on MI300X (gfx942): 16 inference workflow emulation tests pass, 90 basic MX tensor tests pass, serialization[mxfp8] passes, index_select passes. Existing test_mx_linear TORCH-cast tests still pass (123 passed, 0 regressions). All native-path tests correctly skip on MI300X.

pytorch-bot bot commented Mar 10, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4041

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 9c1764c with merge base 4ae435e:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label Mar 10, 2026
brucechanglongxu force-pushed the rocm-mxfp-emulation-tests branch from b7874f4 to 9c1764c on March 10, 2026 at 23:37
danielvegamyhre added the module: training, quantize_ api, training flow, and ciflow/rocm labels Mar 11, 2026
pytorch-bot bot commented Mar 11, 2026

To add the ciflow label ciflow/rocm please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

pytorch-bot bot commented Mar 11, 2026

Warning: Unknown label ciflow/rocm-mi300.
Currently recognized labels are:

  • ciflow/benchmark
  • ciflow/tutorials
  • ciflow/rocm
  • ciflow/4xh100
  • ciflow/xpu

Please add the new label to .github/pytorch-probot.yml


Labels

ciflow/rocm, ciflow/rocm-mi300, CLA Signed, module: rocm, module: training, quantize_ api, training flow
