Skip to content

Commit 6bdb3de

Browse files
committed
Fix test.
1 parent 257a10b commit 6bdb3de

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

test/test_operations.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -2387,11 +2387,14 @@ def test_cummax_0_sized_dimension(self):
23872387
def test_conj_no_fallback(self):
23882388
met.clear_all()
23892389

2390-
tensor = torch.tensor([1 + 2j])
2391-
expected = torch.conj(tensor)
2392-
actual = torch.conj(tensor.to(xm.xla_device()))
2390+
def run(device):
2391+
tensor = torch.tensor([1 + 2j], device=device)
2392+
return torch.conj(tensor)
23932393

2394-
self.assertEqual(actual, expected)
2394+
actual = run("cpu")
2395+
expected = run(xm.xla_device())
2396+
2397+
self.assertEqual(actual, expected.cpu())
23952398
self.assertEqual(len(met.executed_fallback_ops()), 0)
23962399

23972400

0 commit comments

Comments
 (0)