forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_per_overload_api.py
72 lines (54 loc) · 2.46 KB
/
test_per_overload_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Owner(s): ["module: unknown"]
import torch
import copy
from torch.testing._internal.common_utils import TestCase, run_tests
class TestPerOverloadAPI(TestCase):
def test_basics_opoverloadpacket(self):
# add is ony used as an example here. It is ok to update the test
# if the semantics of add are modified in the future.
add_packet = torch.ops.aten.add
# class attributes
self.assertEqual(add_packet.__name__, 'add')
self.assertEqual(str(add_packet), 'aten.add')
# callable
self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
# correct module
self.assertEqual(add_packet.__module__, add_packet.op.__module__)
# caching
another_add_packet = torch.ops.aten.add
self.assertEqual(id(add_packet), id(another_add_packet))
# deepcopy is a no-op
self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
# pretty print
self.assertEqual(repr(add_packet), "<OpOverloadPacket(op='aten.add')>")
self.assertRaises(AttributeError, lambda: add_packet.foo)
def test_basics_opoverload(self):
add_packet = torch.ops.aten.add
add_tensoroverload = add_packet.Tensor
# class attributes
self.assertEqual(str(add_tensoroverload), 'aten.add.Tensor')
self.assertEqual(add_tensoroverload.__name__, 'add.Tensor')
self.assertEqual(add_tensoroverload.overloadpacket, add_packet)
# deepcopy is a no-op
self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
# caching
another_add_tensoroverload = torch.ops.aten.add.Tensor
self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
# pretty print
self.assertEqual(repr(add_tensoroverload), "<OpOverload(op='aten.add', overload='Tensor')>")
# callable
self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
a = torch.tensor(2)
b = torch.tensor(0)
torch.ops.aten.add.out(a, a, out=b)
self.assertEqual(b, torch.tensor(4))
self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))
def test_decompose(self):
x = torch.randn(2, 3)
y = torch.randn(5, 3)
self.assertEqual(
torch.ops.aten.linear.default.decompose(x, y),
torch.ops.aten.linear.default(x, y)
)
# Script entry point: when executed directly, hand control to run_tests
# (imported from torch.testing._internal.common_utils).
if __name__ == "__main__":
    run_tests()