Skip to content

Commit 8f27938

Browse files
[training] skip rocm and distributed tests pending solution
1 parent 605a22e commit 8f27938

File tree

5 files changed

+25
-6
lines changed

5 files changed

+25
-6
lines changed

test/prototype/blockwise_fp8_training/test_blockwise_kernels.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import pytest
88
import torch
99

10+
from torchao.utils import is_ROCM
11+
1012
triton = pytest.importorskip("triton", reason="Triton required to run this test")
1113

1214
from packaging import version
@@ -37,6 +39,11 @@
3739
(67, 6656, 1408),
3840
]
3941

42+
if is_ROCM():
43+
pytest.skip(
44+
"ROCM not yet supported, tests failing",
45+
allow_module_level=True,
46+
)
4047

4148
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
4249
@pytest.mark.skipif(

test/prototype/moe_training/test_tensor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,18 @@
88
import torch
99
import torch.nn.functional as F
1010

11-
from torchao.utils import torch_version_at_least
11+
from torchao.utils import is_ROCM, torch_version_at_least
1212

1313
# Skip module if basic requirements aren't met
1414
if not (torch_version_at_least("2.7.0") and torch.cuda.is_available()):
1515
pytest.skip("CUDA and PyTorch 2.7.0+ required", allow_module_level=True)
1616

17+
if is_ROCM():
18+
pytest.skip(
19+
"ROCM not yet supported, tests failing",
20+
allow_module_level=True,
21+
)
22+
1723
from torchao.prototype.moe_training.config import (
1824
MXFP8TrainingOpConfig,
1925
MXFP8TrainingRecipe,

test/prototype/moe_training/test_training.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from torch import nn
66
from torch.nn import functional as F
77

8+
from torchao.testing.utils import skip_if_rocm
9+
810
# this feature requires CUDA and SM89+
911
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
1012
pytest.skip(
@@ -30,6 +32,7 @@
3032
torch._dynamo.config.cache_size_limit = 1000
3133

3234

35+
@skip_if_rocm
3336
@pytest.mark.parametrize(
3437
"target_fqns", [["experts"], ["shared_experts"], ["experts", "shared_experts"]]
3538
)

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,15 @@
2222
from torchao.quantization.utils import compute_error
2323
from torchao.testing.utils import TorchAOIntegrationTestCase, skip_if_rocm
2424
from torchao.utils import (
25+
is_ROCM,
2526
is_sm_at_least_89,
2627
is_sm_at_least_100,
2728
torch_version_at_least,
2829
)
2930

3031
torch.manual_seed(2)
3132

32-
if not torch_version_at_least("2.8.0"):
33+
if not torch_version_at_least("2.8.0") or is_ROCM():
3334
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
3435

3536

test/prototype/mx_formats/test_mx_dtensor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@
1111
"""
1212

1313
import os
14+
import sys
1415

15-
import pytest
1616
import torch
1717

18-
from torchao.utils import torch_version_at_least
18+
# TODO: re-enable once mx training refactor is complete
19+
_SKIP_MSG = "DTensor support incomplete, MXFP8 training refactor is not yet complete, see: https://github.com/pytorch/ao/pull/3985"
1920

20-
if not torch_version_at_least("2.7.0"):
21-
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
21+
if __name__ == "__main__":
22+
print(f"SKIPPED: {_SKIP_MSG}")
23+
sys.exit(0)
2224

2325
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
2426
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

0 commit comments

Comments
 (0)