def _deserialize_to_flat_tensor(flatbuffer: bytes) -> FlatTensor:
    """Decode serialized flatbuffer bytes into a FlatTensor dataclass.

    The packaged ``flat_tensor.fbs`` and ``scalar_type.fbs`` schemas and the
    payload are staged in a temporary directory, ``_flatc_decompile`` is run
    over them with ``--raw-binary``, and the ``flat_tensor.json`` found in
    that directory afterwards is parsed into a ``FlatTensor``.

    Args:
        flatbuffer: Raw flatbuffer bytes to decode.

    Returns:
        The FlatTensor built from the decompiled JSON.
    """
    with tempfile.TemporaryDirectory() as workdir:
        # Stage both schema files side by side in the scratch directory
        # (presumably flat_tensor.fbs pulls in scalar_type.fbs -- verify).
        for fbs_name in ("flat_tensor.fbs", "scalar_type.fbs"):
            with open(os.path.join(workdir, fbs_name), "wb") as fbs_file:
                fbs_file.write(pkg_resources.resource_string(__name__, fbs_name))

        payload_path = os.path.join(workdir, "flat_tensor.bin")
        with open(payload_path, "wb") as payload_file:
            payload_file.write(flatbuffer)

        _flatc_decompile(
            workdir,
            os.path.join(workdir, "flat_tensor.fbs"),
            payload_path,
            ["--raw-binary"],
        )

        # Read the JSON back while the scratch directory still exists.
        with open(os.path.join(workdir, "flat_tensor.json"), "rb") as decoded_file:
            decoded = json.load(decoded_file)
        return _json_to_dataclass(decoded, cls=FlatTensor)
# TODO(T211851359): improve test coverage.
def check_tensor_metadata(
    self, tensor_layout: TensorLayout, tensor_metadata: TensorMetadata
) -> None:
    """Assert that the metadata carries the same layout fields as the source.

    Compares ``scalar_type``, ``sizes`` and ``dim_order`` (in that order)
    between the input layout and the metadata using ``assertEqual``, with
    the layout value passed as the first argument.
    """
    for field_name in ("scalar_type", "sizes", "dim_order"):
        from_layout = getattr(tensor_layout, field_name)
        from_metadata = getattr(tensor_metadata, field_name)
        self.assertEqual(from_layout, from_metadata)
self.assertEqual( - data[header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8], b"FT01" + serialized_data[ + header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8 + ], + b"FT01", + ) + + # Check flat tensor data. + flat_tensor_bytes = serialized_data[ + header.flatbuffer_offset : header.flatbuffer_offset + header.flatbuffer_size + ] + + flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes) + + self.assertEqual(flat_tensor.version, 0) + self.assertEqual(flat_tensor.tensor_alignment, config.tensor_alignment) + + tensors = flat_tensor.tensors + self.assertEqual(len(tensors), 3) + self.assertEqual(tensors[0].fully_qualified_name, "fqn1") + self.check_tensor_metadata(TEST_TENSOR_MAP["fqn1"].layout, tensors[0]) + self.assertEqual(tensors[0].segment_index, 0) + self.assertEqual(tensors[0].offset, 0) + + self.assertEqual(tensors[1].fully_qualified_name, "fqn2") + self.check_tensor_metadata(TEST_TENSOR_MAP["fqn2"].layout, tensors[1]) + self.assertEqual(tensors[1].segment_index, 0) + self.assertEqual(tensors[1].offset, 0) + + self.assertEqual(tensors[2].fully_qualified_name, "fqn3") + self.check_tensor_metadata(TEST_TENSOR_MAP["fqn3"].layout, tensors[2]) + self.assertEqual(tensors[2].segment_index, 0) + self.assertEqual(tensors[2].offset, config.tensor_alignment) + + segments = flat_tensor.segments + self.assertEqual(len(segments), 1) + self.assertEqual(segments[0].offset, 0) + self.assertEqual(segments[0].size, config.tensor_alignment * 3) + + # Length of serialized_data matches segment_base_offset + segment_data_size. + self.assertEqual( + header.segment_base_offset + header.segment_data_size, len(serialized_data) + ) + self.assertTrue(segments[0].size <= header.segment_data_size) + + # Check the contents of the segment. 
Expecting two tensors from + # TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32] + segment_data = serialized_data[ + header.segment_base_offset : header.segment_base_offset + segments[0].size + ] + + # Tensor: b"\x11" * 4 + t0_start = 0 + t0_len = len(TEST_TENSOR_BUFFER[0]) + t0_end = t0_start + aligned_size(t0_len, config.tensor_alignment) + self.assertEqual( + segment_data[t0_start : t0_start + t0_len], TEST_TENSOR_BUFFER[0] + ) + padding = b"\x00" * (t0_end - t0_len) + self.assertEqual(segment_data[t0_start + t0_len : t0_end], padding) + + # Tensor: b"\x22" * 32 + t1_start = t0_end + t1_len = len(TEST_TENSOR_BUFFER[1]) + t1_end = t1_start + aligned_size(t1_len, config.tensor_alignment) + self.assertEqual( + segment_data[t1_start : t1_start + t1_len], + TEST_TENSOR_BUFFER[1], + ) + padding = b"\x00" * (t1_end - (t1_len + t1_start)) + self.assertEqual(segment_data[t1_start + t1_len : t1_start + t1_end], padding) + + # Check length of the segment is expected. + self.assertEqual( + segments[0].size, aligned_size(t1_end, config.segment_alignment) ) + self.assertEqual(segments[0].size, header.segment_data_size)