Skip to content

Commit bdc89f5

Browse files
authored
[release/2.10] Only skip linalg.eig assertion in test_torch_return_types_returns (#3096)
1 parent 1f8cea4 commit bdc89f5

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

test/functorch/test_vmap.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5052,10 +5052,6 @@ def op(inp, running_mean, running_var, weight, bias, training):
50525052

50535053
test(self, op, tuple(inputs), in_dims=tuple(in_dims))
50545054

5055-
@unittest.skipIf(
5056-
TEST_WITH_ROCM and not torch.cuda.has_magma,
5057-
"ROCm hipsolver backend does not currently support eig",
5058-
)
50595055
def test_torch_return_types_returns(self, device):
50605056
t = torch.randn(3, 2, 2, device=device)
50615057
self.assertTrue(
@@ -5069,9 +5065,12 @@ def test_torch_return_types_returns(self, device):
50695065
vmap(torch.topk, (0, None, None))(t, 1, 0), torch.return_types.topk
50705066
)
50715067
)
5072-
self.assertTrue(
5073-
isinstance(vmap(torch.linalg.eig, (0))(t), torch.return_types.linalg_eig)
5074-
)
5068+
if not (TEST_WITH_ROCM and not torch.cuda.has_magma):
5069+
self.assertTrue(
5070+
isinstance(
5071+
vmap(torch.linalg.eig, (0))(t), torch.return_types.linalg_eig
5072+
)
5073+
)
50755074

50765075
def test_namedtuple_returns(self, device):
50775076
Point = namedtuple("Point", ["x", "y"])

0 commit comments

Comments
 (0)