forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_appending_byte_serializer.py
85 lines (68 loc) · 2.29 KB
/
test_appending_byte_serializer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Owner(s): ["module: inductor"]
import dataclasses
from torch.testing._internal.common_utils import TestCase
from torch.utils._appending_byte_serializer import (
AppendingByteSerializer,
BytesReader,
BytesWriter,
)
class TestAppendingByteSerializer(TestCase):
def test_write_and_read_int(self) -> None:
def int_serializer(writer: BytesWriter, i: int) -> None:
writer.write_uint64(i)
def int_deserializer(reader: BytesReader) -> int:
return reader.read_uint64()
s = AppendingByteSerializer(serialize_fn=int_serializer)
data = [1, 2, 3, 4]
s.extend(data)
self.assertListEqual(
data,
AppendingByteSerializer.to_list(
s.to_bytes(), deserialize_fn=int_deserializer
),
)
data2 = [8, 9, 10, 11]
s.extend(data2)
self.assertListEqual(
data + data2,
AppendingByteSerializer.to_list(
s.to_bytes(), deserialize_fn=int_deserializer
),
)
def test_write_and_read_class(self) -> None:
@dataclasses.dataclass(frozen=True, eq=True)
class Foo:
x: int
y: str
z: bytes
@staticmethod
def serialize(writer: BytesWriter, cls: "Foo") -> None:
writer.write_uint64(cls.x)
writer.write_str(cls.y)
writer.write_bytes(cls.z)
@staticmethod
def deserialize(reader: BytesReader) -> "Foo":
x = reader.read_uint64()
y = reader.read_str()
z = reader.read_bytes()
return Foo(x, y, z)
a = Foo(5, "ok", bytes([15]))
b = Foo(10, "lol", bytes([25]))
s = AppendingByteSerializer(serialize_fn=Foo.serialize)
s.append(a)
self.assertListEqual(
[a],
AppendingByteSerializer.to_list(
s.to_bytes(), deserialize_fn=Foo.deserialize
),
)
s.append(b)
self.assertListEqual(
[a, b],
AppendingByteSerializer.to_list(
s.to_bytes(), deserialize_fn=Foo.deserialize
),
)
if __name__ == "__main__":
    # Script entry point: hand control to the inductor test runner.
    # NOTE(review): the import is kept inside the guard, presumably so that a
    # plain import of this module does not load the inductor test harness --
    # confirm before hoisting it to the top of the file.
    from torch._inductor import test_case as inductor_test_case

    inductor_test_case.run_tests()