Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions test/prototype/mx_formats/test_inference_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from torchao.quantization.utils import compute_error
from torchao.testing.utils import TorchAOIntegrationTestCase, skip_if_rocm
from torchao.utils import (
is_MI350,
is_ROCM,
is_sm_at_least_89,
is_sm_at_least_100,
torch_version_at_least,
Expand Down Expand Up @@ -71,9 +73,6 @@ def cuda_kernel_profiler(kernel_pattern):
@pytest.mark.parametrize("use_inference_mode", [True, False])
@pytest.mark.parametrize("x_rank", [2, 3])
@torch.no_grad()
@skip_if_rocm(
"ROCm float4 gemm require gfx950"
) # TODO(future): deploy gfx950 in ROCM CI
def test_inference_workflow_mx(
elem_dtype,
bias: bool,
Expand All @@ -85,19 +84,22 @@ def test_inference_workflow_mx(
"""
Smoke test for inference compile
"""
# TODO(future): figure out why these CUDA capability conditions are not properly
# applied when inside `pytest.mark.skipif` for this test
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
elif not is_sm_at_least_100() and not emulate:
pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
elif elem_dtype == torch.float4_e2m1fn_x2:
if not is_sm_at_least_100() and not emulate:
pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
elif compile:
# TODO(future PR): investigate and fix this
if is_ROCM():
if not emulate and not is_MI350():
pytest.skip("ROCm native MX gemm requires gfx950 (MI350)")
if elem_dtype == torch.float4_e2m1fn_x2 and compile:
pytest.skip("mxfp4 + compile currently does not work, low SQNR")
else:
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
elif not is_sm_at_least_100() and not emulate:
pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
elif elem_dtype == torch.float4_e2m1fn_x2:
if not is_sm_at_least_100() and not emulate:
pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
elif compile:
pytest.skip("mxfp4 + compile currently does not work, low SQNR")

m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
m_mx = copy.deepcopy(m)
Expand Down
6 changes: 5 additions & 1 deletion test/prototype/mx_formats/test_mx_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torchao.quantization import quantize_
from torchao.quantization.quantize_.common import KernelPreference
from torchao.utils import (
is_ROCM,
is_sm_at_least_100,
torch_version_at_least,
)
Expand All @@ -28,13 +29,16 @@


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not is_sm_at_least_100(), reason="needs CUDA capability 10.0+")
@pytest.mark.parametrize("recipe_name", ["mxfp8", "nvfp4"])
def test_serialization(recipe_name):
"""
Ensure that only `import torchao.prototype.mx_formats` is needed to load MX
and NV checkpoints.
"""
if recipe_name == "nvfp4" and not is_sm_at_least_100():
pytest.skip("NVFP4 requires CUDA capability 10.0+")
if recipe_name == "mxfp8" and not is_sm_at_least_100() and not is_ROCM():
pytest.skip("MXFP8 serialization requires CUDA capability 10.0+ or ROCm")

m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
fname = None
Expand Down
19 changes: 11 additions & 8 deletions test/prototype/mx_formats/test_mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from torchao.quantization.quantize_.common import KernelPreference
from torchao.quantization.utils import compute_error
from torchao.utils import (
is_ROCM,
is_sm_at_least_89,
is_sm_at_least_90,
torch_version_at_least,
Expand Down Expand Up @@ -467,9 +468,8 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
Verifies that compile does not change numerics of MX casts
"""
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
# separate ifs because flake8 is outsmarting me
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
if not is_sm_at_least_89() and not is_ROCM():
pytest.skip("CUDA capability >= 8.9 or ROCm required for float8 in triton")

shape = 4, 8
if not all_zeros:
Expand Down Expand Up @@ -510,8 +510,8 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not is_sm_at_least_89(),
reason="float8 in triton requires CUDA capability 8.9 or greater",
not is_sm_at_least_89() and not is_ROCM(),
reason="float8 in triton requires CUDA capability 8.9 or ROCm",
)
def test_to_mx_inductor_single_kernel():
"""
Expand All @@ -528,7 +528,10 @@ def test_to_mx_inductor_single_kernel():


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Need sm90+")
@pytest.mark.skipif(
not is_sm_at_least_90() and not is_ROCM(),
reason="Need sm90+ or ROCm",
)
def test_index_select():
"""
test that `x_0 = x[0]` works when `x` is a 3D `MXTensor`. This is
Expand All @@ -549,8 +552,8 @@ def test_index_select():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not is_sm_at_least_89(),
reason="float8 in triton requires CUDA capability 8.9 or greater",
not is_sm_at_least_89() and not is_ROCM(),
reason="float8 in triton requires CUDA capability 8.9 or ROCm",
)
def test_cast_to_float8_e4m3fn_saturation_behavior():
# TODO(#1912): make the saturated cast work in eager mode and remove this
Expand Down
4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_mxfp8_allgather.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.distributed as dist

from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.utils import is_sm_at_least_90, torch_version_at_least
from torchao.utils import is_ROCM, is_sm_at_least_90, torch_version_at_least

if not torch_version_at_least("2.7.0"):
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
Expand Down Expand Up @@ -99,7 +99,7 @@ def _test_allgather(local_rank):
if __name__ == "__main__":
local_rank = setup_distributed()

assert is_sm_at_least_90() == True, "SM must be > 9.0"
assert is_sm_at_least_90() or is_ROCM(), "SM must be >= 9.0 or ROCm"

try:
_test_allgather(local_rank)
Expand Down
Loading