diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index de41616451..2c59c3e3ad 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -22,6 +22,8 @@
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import TorchAOIntegrationTestCase, skip_if_rocm
 from torchao.utils import (
+    is_MI350,
+    is_ROCM,
     is_sm_at_least_89,
     is_sm_at_least_100,
     torch_version_at_least,
@@ -71,9 +73,6 @@ def cuda_kernel_profiler(kernel_pattern):
 @pytest.mark.parametrize("use_inference_mode", [True, False])
 @pytest.mark.parametrize("x_rank", [2, 3])
 @torch.no_grad()
-@skip_if_rocm(
-    "ROCm float4 gemm require gfx950"
-)  # TODO(future): deploy gfx950 in ROCM CI
 def test_inference_workflow_mx(
     elem_dtype,
     bias: bool,
@@ -85,19 +84,22 @@ def test_inference_workflow_mx(
     """
     Smoke test for inference compile
     """
-    # TODO(future): figure out why these CUDA capability conditions are not properly
-    # applied when inside `pytest.mark.skipif` for this test
-    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        if not is_sm_at_least_89():
-            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
-        elif not is_sm_at_least_100() and not emulate:
-            pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
-    elif elem_dtype == torch.float4_e2m1fn_x2:
-        if not is_sm_at_least_100() and not emulate:
-            pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
-        elif compile:
-            # TODO(future PR): investigate and fix this
-            pytest.skip("mxfp4 + compile currently does not work, low SQNR")
+    if is_ROCM():
+        if not emulate and not is_MI350():
+            pytest.skip("ROCm native MX gemm requires gfx950 (MI350)")
+        if elem_dtype == torch.float4_e2m1fn_x2 and compile:
+            pytest.skip("mxfp4 + compile currently does not work, low SQNR")
+    else:
+        if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+            if not is_sm_at_least_89():
+                pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+            elif not is_sm_at_least_100() and not emulate:
+                pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
+        elif elem_dtype == torch.float4_e2m1fn_x2:
+            if not is_sm_at_least_100() and not emulate:
+                pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
+            elif compile:
+                pytest.skip("mxfp4 + compile currently does not work, low SQNR")

     m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
     m_mx = copy.deepcopy(m)
diff --git a/test/prototype/mx_formats/test_mx_serialization.py b/test/prototype/mx_formats/test_mx_serialization.py
index a109b63aef..275e66de4a 100644
--- a/test/prototype/mx_formats/test_mx_serialization.py
+++ b/test/prototype/mx_formats/test_mx_serialization.py
@@ -19,6 +19,7 @@
 from torchao.quantization import quantize_
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.utils import (
+    is_ROCM,
     is_sm_at_least_100,
     torch_version_at_least,
 )
@@ -28,13 +29,16 @@


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not is_sm_at_least_100(), reason="needs CUDA capability 10.0+")
 @pytest.mark.parametrize("recipe_name", ["mxfp8", "nvfp4"])
 def test_serialization(recipe_name):
     """
     Ensure that only `import torchao.prototype.mx_formats` is needed to load
     MX and NV checkpoints.
     """
+    if recipe_name == "nvfp4" and not is_sm_at_least_100():
+        pytest.skip("NVFP4 requires CUDA capability 10.0+")
+    if recipe_name == "mxfp8" and not is_sm_at_least_100() and not is_ROCM():
+        pytest.skip("MXFP8 serialization requires CUDA capability 10.0+ or ROCm")
     m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")

     fname = None
diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index cf85a984a6..7cf7f7da73 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -27,6 +27,7 @@
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
+    is_ROCM,
     is_sm_at_least_89,
     is_sm_at_least_90,
     torch_version_at_least,
@@ -467,9 +468,8 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     """
     Verifies that compile does not change numerics of MX casts
     """
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        if not is_sm_at_least_89():
-            # separate ifs because flake8 is outsmarting me
-            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+        if not is_sm_at_least_89() and not is_ROCM():
+            pytest.skip("CUDA capability >= 8.9 or ROCm required for float8 in triton")
     shape = 4, 8
     if not all_zeros:
@@ -510,8 +510,8 @@

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
-    reason="float8 in triton requires CUDA capability 8.9 or greater",
+    not is_sm_at_least_89() and not is_ROCM(),
+    reason="float8 in triton requires CUDA capability 8.9 or ROCm",
 )
 def test_to_mx_inductor_single_kernel():
     """
@@ -528,7 +528,10 @@

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not is_sm_at_least_90(), reason="Need sm90+")
+@pytest.mark.skipif(
+    not is_sm_at_least_90() and not is_ROCM(),
+    reason="Need sm90+ or ROCm",
+)
 def test_index_select():
     """
     test that `x_0 = x[0]` works when `x` is a 3D `MXTensor`. This is
@@ -549,8 +552,8 @@

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
-    reason="float8 in triton requires CUDA capability 8.9 or greater",
+    not is_sm_at_least_89() and not is_ROCM(),
+    reason="float8 in triton requires CUDA capability 8.9 or ROCm",
 )
 def test_cast_to_float8_e4m3fn_saturation_behavior():
     # TODO(#1912): make the saturated cast work in eager mode and remove this
diff --git a/test/prototype/mx_formats/test_mxfp8_allgather.py b/test/prototype/mx_formats/test_mxfp8_allgather.py
index d68d2e7f43..4f9cfc66c8 100644
--- a/test/prototype/mx_formats/test_mxfp8_allgather.py
+++ b/test/prototype/mx_formats/test_mxfp8_allgather.py
@@ -3,7 +3,7 @@
 import torch.distributed as dist

 from torchao.prototype.mx_formats.mx_tensor import MXTensor
-from torchao.utils import is_sm_at_least_90, torch_version_at_least
+from torchao.utils import is_ROCM, is_sm_at_least_90, torch_version_at_least

 if not torch_version_at_least("2.7.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -99,7 +99,7 @@ def _test_allgather(local_rank):

 if __name__ == "__main__":
     local_rank = setup_distributed()
-    assert is_sm_at_least_90() == True, "SM must be > 9.0"
+    assert is_sm_at_least_90() or is_ROCM(), "SM must be >= 9.0 or ROCm"

     try:
         _test_allgather(local_rank)