Skip to content

Commit dd8119b

Browse files
authored
[release/2.9] Only skip linalg.eig assertion in test_torch_return_types_returns (#3095)
1 parent 495cb03 commit dd8119b

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

test/functorch/test_vmap.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
onlyCUDA,
5959
OpDTypes,
6060
ops,
61-
skipCUDAIfNoMagma,
6261
tol,
6362
toleranceOverride,
6463
)
@@ -71,6 +70,7 @@
7170
run_tests,
7271
skipIfTorchDynamo,
7372
subtest,
73+
TEST_WITH_ROCM,
7474
TEST_WITH_TORCHDYNAMO,
7575
TestCase,
7676
unMarkDynamoStrictTest,
@@ -5043,7 +5043,6 @@ def op(inp, running_mean, running_var, weight, bias, training):
50435043

50445044
test(self, op, tuple(inputs), in_dims=tuple(in_dims))
50455045

5046-
@skipCUDAIfNoMagma
50475046
def test_torch_return_types_returns(self, device):
50485047
t = torch.randn(3, 2, 2, device=device)
50495048
self.assertTrue(
@@ -5057,9 +5056,12 @@ def test_torch_return_types_returns(self, device):
50575056
vmap(torch.topk, (0, None, None))(t, 1, 0), torch.return_types.topk
50585057
)
50595058
)
5060-
self.assertTrue(
5061-
isinstance(vmap(torch.linalg.eig, (0))(t), torch.return_types.linalg_eig)
5062-
)
5059+
if not (TEST_WITH_ROCM and not torch.cuda.has_magma):
5060+
self.assertTrue(
5061+
isinstance(
5062+
vmap(torch.linalg.eig, (0))(t), torch.return_types.linalg_eig
5063+
)
5064+
)
50635065

50645066
def test_namedtuple_returns(self, device):
50655067
Point = namedtuple("Point", ["x", "y"])

0 commit comments

Comments
 (0)