Skip to content

Commit 2cf9b68

Browse files
[training] skip rocm and distributed tests pending solution
1 parent 605a22e commit 2cf9b68

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

test/prototype/moe_training/test_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
 import torch
 import torch.nn.functional as F

+from torchao.testing.utils import skip_if_rocm
 from torchao.utils import torch_version_at_least

 # Skip module if basic requirements aren't met
@@ -22,6 +23,7 @@
 from torchao.quantization.utils import compute_error


+@skip_if_rocm
 @pytest.mark.parametrize("op_name", ["mm", "matmul", "linear"])
 @pytest.mark.parametrize("batch_size", [None, 2, 4])
 def test_mxfp8_training_tensor_ops_fwd_bwd(op_name, batch_size):

test/prototype/moe_training/test_training.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
 from torch import nn
 from torch.nn import functional as F

+from torchao.testing.utils import skip_if_rocm
+
 # this feature requires CUDA and SM89+
 if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
     pytest.skip(
@@ -30,6 +32,7 @@
 torch._dynamo.config.cache_size_limit = 1000


+@skip_if_rocm
 @pytest.mark.parametrize(
     "target_fqns", [["experts"], ["shared_experts"], ["experts", "shared_experts"]]
 )

test/prototype/mx_formats/test_mx_dtensor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
 import pytest
 import torch

-from torchao.utils import torch_version_at_least
-
-
-if not torch_version_at_least("2.7.0"):
-    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+# TODO: re-enable once mx training refactor is complete
+pytest.skip(
+    "DTensor support incomplete, MXFP8 training refactor is not yet complete, see: https://github.com/pytorch/ao/pull/3985",
+    allow_module_level=True,
+)

 from torch.distributed._tensor import DTensor, Shard, distribute_tensor
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

0 commit comments

Comments (0)