Skip to content

Commit 5f6fa23

Browse files
pytorchbotlucylq
andauthored
[executorch][flat_tensor] Serialize flat tensor tests
Pull Request resolved: #7269 Introduce _convert_to_flat_tensor, which interprets a flat_tensor blob as a flat_tensor schema. Use this for more comprehensive testing for flat tensor serialization, and later for deserialization. ghstack-source-id: 261976100 @exported-using-ghexport Differential Revision: [D67007821](https://our.internmc.facebook.com/intern/diff/D67007821/) Co-authored-by: lucylq <[email protected]>
1 parent ff7f0c8 commit 5f6fa23

File tree

2 files changed

+136
-10
lines changed

2 files changed

+136
-10
lines changed

extension/flat_tensor/serialize/serialize.py

+31-5
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
import pkg_resources
1616
from executorch.exir._serialize._cord import Cord
17-
from executorch.exir._serialize._dataclass import _DataclassEncoder
17+
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
1818

19-
from executorch.exir._serialize._flatbuffer import _flatc_compile
19+
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
2020
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
2121

2222
from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
@@ -33,8 +33,8 @@
3333
)
3434

3535

36-
def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
37-
"""Converts a FlatTensor to a flatbuffer and returns the serialized data."""
36+
def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
37+
"""Serializes a FlatTensor to a flatbuffer and returns the serialized data."""
3838
flat_tensor_json = json.dumps(flat_tensor, cls=_DataclassEncoder)
3939
with tempfile.TemporaryDirectory() as d:
4040
schema_path = os.path.join(d, "flat_tensor.fbs")
@@ -57,6 +57,32 @@ def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
5757
return Cord(output_file.read())
5858

5959

60+
def _deserialize_to_flat_tensor(flatbuffer: bytes) -> FlatTensor:
61+
"""Deserializes a flatbuffer to a FlatTensor and returns the dataclass."""
62+
with tempfile.TemporaryDirectory() as d:
63+
schema_path = os.path.join(d, "flat_tensor.fbs")
64+
with open(schema_path, "wb") as schema_file:
65+
schema_file.write(
66+
pkg_resources.resource_string(__name__, "flat_tensor.fbs")
67+
)
68+
69+
scalar_type_path = os.path.join(d, "scalar_type.fbs")
70+
with open(scalar_type_path, "wb") as scalar_type_file:
71+
scalar_type_file.write(
72+
pkg_resources.resource_string(__name__, "scalar_type.fbs")
73+
)
74+
75+
bin_path = os.path.join(d, "flat_tensor.bin")
76+
with open(bin_path, "wb") as bin_file:
77+
bin_file.write(flatbuffer)
78+
79+
_flatc_decompile(d, schema_path, bin_path, ["--raw-binary"])
80+
81+
json_path = os.path.join(d, "flat_tensor.json")
82+
with open(json_path, "rb") as output_file:
83+
return _json_to_dataclass(json.load(output_file), cls=FlatTensor)
84+
85+
6086
@dataclass
6187
class FlatTensorConfig:
6288
tensor_alignment: int = 16
@@ -244,7 +270,7 @@ def serialize(
244270
segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
245271
)
246272

247-
flatbuffer_payload = _convert_to_flatbuffer(flat_tensor)
273+
flatbuffer_payload = _serialize_to_flatbuffer(flat_tensor)
248274
padded_flatbuffer_length: int = aligned_size(
249275
input_size=len(flatbuffer_payload),
250276
alignment=self.config.tensor_alignment,

extension/flat_tensor/test/test_serialize.py

+105-5
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import unittest
1010

11+
from typing import List
12+
1113
from executorch.exir._serialize.data_serializer import (
1214
DataPayload,
1315
DataSerializer,
@@ -18,15 +20,17 @@
1820
from executorch.exir._serialize.padding import aligned_size
1921

2022
from executorch.exir.schema import ScalarType
23+
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorMetadata
2124

2225
from executorch.extension.flat_tensor.serialize.serialize import (
26+
_deserialize_to_flat_tensor,
2327
FlatTensorConfig,
2428
FlatTensorHeader,
2529
FlatTensorSerializer,
2630
)
2731

2832
# Test artifacts.
29-
TEST_TENSOR_BUFFER = [b"tensor"]
33+
TEST_TENSOR_BUFFER: List[bytes] = [b"\x11" * 4, b"\x22" * 32]
3034
TEST_TENSOR_MAP = {
3135
"fqn1": TensorEntry(
3236
buffer_index=0,
@@ -44,6 +48,14 @@
4448
dim_order=[0, 1, 2],
4549
),
4650
),
51+
"fqn3": TensorEntry(
52+
buffer_index=1,
53+
layout=TensorLayout(
54+
scalar_type=ScalarType.INT,
55+
sizes=[2, 2, 2],
56+
dim_order=[0, 1],
57+
),
58+
),
4759
}
4860
TEST_DATA_PAYLOAD = DataPayload(
4961
buffers=TEST_TENSOR_BUFFER,
@@ -52,13 +64,24 @@
5264

5365

5466
class TestSerialize(unittest.TestCase):
67+
# TODO(T211851359): improve test coverage.
68+
def check_tensor_metadata(
69+
self, tensor_layout: TensorLayout, tensor_metadata: TensorMetadata
70+
) -> None:
71+
self.assertEqual(tensor_layout.scalar_type, tensor_metadata.scalar_type)
72+
self.assertEqual(tensor_layout.sizes, tensor_metadata.sizes)
73+
self.assertEqual(tensor_layout.dim_order, tensor_metadata.dim_order)
74+
5575
def test_serialize(self) -> None:
5676
config = FlatTensorConfig()
5777
serializer: DataSerializer = FlatTensorSerializer(config)
5878

59-
data = bytes(serializer.serialize(TEST_DATA_PAYLOAD))
79+
serialized_data = bytes(serializer.serialize(TEST_DATA_PAYLOAD))
6080

61-
header = FlatTensorHeader.from_bytes(data[0 : FlatTensorHeader.EXPECTED_LENGTH])
81+
# Check header.
82+
header = FlatTensorHeader.from_bytes(
83+
serialized_data[0 : FlatTensorHeader.EXPECTED_LENGTH]
84+
)
6285
self.assertTrue(header.is_valid())
6386

6487
# Header is aligned to config.segment_alignment, which is where the flatbuffer starts.
@@ -77,9 +100,86 @@ def test_serialize(self) -> None:
77100
self.assertTrue(header.segment_base_offset, expected_segment_base_offset)
78101

79102
# TEST_TENSOR_BUFFER is aligned to config.segment_alignment.
80-
self.assertEqual(header.segment_data_size, config.segment_alignment)
103+
expected_segment_data_size = aligned_size(
104+
sum(len(buffer) for buffer in TEST_TENSOR_BUFFER), config.segment_alignment
105+
)
106+
self.assertEqual(header.segment_data_size, expected_segment_data_size)
81107

82108
# Confirm the flatbuffer magic is present.
83109
self.assertEqual(
84-
data[header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8], b"FT01"
110+
serialized_data[
111+
header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8
112+
],
113+
b"FT01",
114+
)
115+
116+
# Check flat tensor data.
117+
flat_tensor_bytes = serialized_data[
118+
header.flatbuffer_offset : header.flatbuffer_offset + header.flatbuffer_size
119+
]
120+
121+
flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes)
122+
123+
self.assertEqual(flat_tensor.version, 0)
124+
self.assertEqual(flat_tensor.tensor_alignment, config.tensor_alignment)
125+
126+
tensors = flat_tensor.tensors
127+
self.assertEqual(len(tensors), 3)
128+
self.assertEqual(tensors[0].fully_qualified_name, "fqn1")
129+
self.check_tensor_metadata(TEST_TENSOR_MAP["fqn1"].layout, tensors[0])
130+
self.assertEqual(tensors[0].segment_index, 0)
131+
self.assertEqual(tensors[0].offset, 0)
132+
133+
self.assertEqual(tensors[1].fully_qualified_name, "fqn2")
134+
self.check_tensor_metadata(TEST_TENSOR_MAP["fqn2"].layout, tensors[1])
135+
self.assertEqual(tensors[1].segment_index, 0)
136+
self.assertEqual(tensors[1].offset, 0)
137+
138+
self.assertEqual(tensors[2].fully_qualified_name, "fqn3")
139+
self.check_tensor_metadata(TEST_TENSOR_MAP["fqn3"].layout, tensors[2])
140+
self.assertEqual(tensors[2].segment_index, 0)
141+
self.assertEqual(tensors[2].offset, config.tensor_alignment)
142+
143+
segments = flat_tensor.segments
144+
self.assertEqual(len(segments), 1)
145+
self.assertEqual(segments[0].offset, 0)
146+
self.assertEqual(segments[0].size, config.tensor_alignment * 3)
147+
148+
# Length of serialized_data matches segment_base_offset + segment_data_size.
149+
self.assertEqual(
150+
header.segment_base_offset + header.segment_data_size, len(serialized_data)
151+
)
152+
self.assertTrue(segments[0].size <= header.segment_data_size)
153+
154+
# Check the contents of the segment. Expecting two tensors from
155+
# TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32]
156+
segment_data = serialized_data[
157+
header.segment_base_offset : header.segment_base_offset + segments[0].size
158+
]
159+
160+
# Tensor: b"\x11" * 4
161+
t0_start = 0
162+
t0_len = len(TEST_TENSOR_BUFFER[0])
163+
t0_end = t0_start + aligned_size(t0_len, config.tensor_alignment)
164+
self.assertEqual(
165+
segment_data[t0_start : t0_start + t0_len], TEST_TENSOR_BUFFER[0]
166+
)
167+
padding = b"\x00" * (t0_end - t0_len)
168+
self.assertEqual(segment_data[t0_start + t0_len : t0_end], padding)
169+
170+
# Tensor: b"\x22" * 32
171+
t1_start = t0_end
172+
t1_len = len(TEST_TENSOR_BUFFER[1])
173+
t1_end = t1_start + aligned_size(t1_len, config.tensor_alignment)
174+
self.assertEqual(
175+
segment_data[t1_start : t1_start + t1_len],
176+
TEST_TENSOR_BUFFER[1],
177+
)
178+
padding = b"\x00" * (t1_end - (t1_len + t1_start))
179+
self.assertEqual(segment_data[t1_start + t1_len : t1_start + t1_end], padding)
180+
181+
# Check length of the segment is expected.
182+
self.assertEqual(
183+
segments[0].size, aligned_size(t1_end, config.segment_alignment)
85184
)
185+
self.assertEqual(segments[0].size, header.segment_data_size)

0 commit comments

Comments
 (0)