From 5d43ae2bb8eca566d5d029e46cb373adf44dec25 Mon Sep 17 00:00:00 2001
From: lucylq
Date: Fri, 17 Jan 2025 14:12:52 -0800
Subject: [PATCH 1/3] [executorch][flat_tensor] Serialize flat tensor tests

Pull Request resolved: https://github.com/pytorch/executorch/pull/7269

Introduce _deserialize_to_flat_tensor, which interprets a flat tensor blob
according to the flat_tensor schema. Use this for more comprehensive testing
of flat tensor serialization, and later for deserialization.
ghstack-source-id: 261976100
@exported-using-ghexport

Differential Revision: [D67007821](https://our.internmc.facebook.com/intern/diff/D67007821/)
---
 extension/flat_tensor/serialize/serialize.py |  36 +++++-
 extension/flat_tensor/test/test_serialize.py | 110 ++++++++++++++++++-
 2 files changed, 136 insertions(+), 10 deletions(-)

diff --git a/extension/flat_tensor/serialize/serialize.py b/extension/flat_tensor/serialize/serialize.py
index 9e3df6aafce..6a07892eb5d 100644
--- a/extension/flat_tensor/serialize/serialize.py
+++ b/extension/flat_tensor/serialize/serialize.py
@@ -14,9 +14,9 @@
 import pkg_resources
 from executorch.exir._serialize._cord import Cord
 
-from executorch.exir._serialize._dataclass import _DataclassEncoder
+from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
 
-from executorch.exir._serialize._flatbuffer import _flatc_compile
+from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
 from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
 
 from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
@@ -33,8 +33,8 @@
 )
 
 
-def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
-    """Converts a FlatTensor to a flatbuffer and returns the serialized data."""
+def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
+    """Serializes a FlatTensor to a flatbuffer and returns the serialized data."""
     flat_tensor_json = json.dumps(flat_tensor, cls=_DataclassEncoder)
     with tempfile.TemporaryDirectory() as d:
         schema_path = os.path.join(d, "flat_tensor.fbs")
@@ -57,6 +57,32 @@ def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
         return Cord(output_file.read())
 
 
+def _deserialize_to_flat_tensor(flatbuffer: bytes) -> FlatTensor:
+    """Deserializes a flatbuffer to a FlatTensor and returns the dataclass."""
+    with tempfile.TemporaryDirectory() as d:
+        schema_path = os.path.join(d, "flat_tensor.fbs")
+        with open(schema_path, "wb") as schema_file:
+            schema_file.write(
+                pkg_resources.resource_string(__name__, "flat_tensor.fbs")
+            )
+
+        scalar_type_path = os.path.join(d, "scalar_type.fbs")
+        with open(scalar_type_path, "wb") as scalar_type_file:
+            scalar_type_file.write(
+                pkg_resources.resource_string(__name__, "scalar_type.fbs")
+            )
+
+        bin_path = os.path.join(d, "flat_tensor.bin")
+        with open(bin_path, "wb") as bin_file:
+            bin_file.write(flatbuffer)
+
+        _flatc_decompile(d, schema_path, bin_path, ["--raw-binary"])
+
+        json_path = os.path.join(d, "flat_tensor.json")
+        with open(json_path, "rb") as output_file:
+            return _json_to_dataclass(json.load(output_file), cls=FlatTensor)
+
+
 @dataclass
 class FlatTensorConfig:
     tensor_alignment: int = 16
@@ -244,7 +270,7 @@ def serialize(
             segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
         )
 
-        flatbuffer_payload = _convert_to_flatbuffer(flat_tensor)
+        flatbuffer_payload = _serialize_to_flatbuffer(flat_tensor)
         padded_flatbuffer_length: int = aligned_size(
             input_size=len(flatbuffer_payload),
             alignment=self.config.tensor_alignment,
         )
diff --git a/extension/flat_tensor/test/test_serialize.py b/extension/flat_tensor/test/test_serialize.py
index d0235672748..57dbdb8c192 100644
--- a/extension/flat_tensor/test/test_serialize.py
+++ b/extension/flat_tensor/test/test_serialize.py
@@ -8,6 +8,8 @@
 
 import unittest
 
+from typing import List
+
 from executorch.exir._serialize.data_serializer import (
     DataPayload,
     DataSerializer,
@@ -18,15 +20,17 @@
 from executorch.exir._serialize.padding import aligned_size
 from executorch.exir.schema import ScalarType
 
+from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorMetadata
 from executorch.extension.flat_tensor.serialize.serialize import (
+    _deserialize_to_flat_tensor,
     FlatTensorConfig,
     FlatTensorHeader,
     FlatTensorSerializer,
 )
 
 # Test artifacts.
-TEST_TENSOR_BUFFER = [b"tensor"]
+TEST_TENSOR_BUFFER: List[bytes] = [b"\x11" * 4, b"\x22" * 32]
 TEST_TENSOR_MAP = {
     "fqn1": TensorEntry(
         buffer_index=0,
@@ -44,6 +48,14 @@
             dim_order=[0, 1, 2],
         ),
     ),
+    "fqn3": TensorEntry(
+        buffer_index=1,
+        layout=TensorLayout(
+            scalar_type=ScalarType.INT,
+            sizes=[2, 2, 2],
+            dim_order=[0, 1, 2],
+        ),
+    ),
 }
 TEST_DATA_PAYLOAD = DataPayload(
     buffers=TEST_TENSOR_BUFFER,
@@ -52,13 +64,24 @@
 
 
 class TestSerialize(unittest.TestCase):
+    # TODO(T211851359): improve test coverage.
+    def check_tensor_metadata(
+        self, tensor_layout: TensorLayout, tensor_metadata: TensorMetadata
+    ) -> None:
+        self.assertEqual(tensor_layout.scalar_type, tensor_metadata.scalar_type)
+        self.assertEqual(tensor_layout.sizes, tensor_metadata.sizes)
+        self.assertEqual(tensor_layout.dim_order, tensor_metadata.dim_order)
+
     def test_serialize(self) -> None:
         config = FlatTensorConfig()
         serializer: DataSerializer = FlatTensorSerializer(config)
 
-        data = bytes(serializer.serialize(TEST_DATA_PAYLOAD))
+        serialized_data = bytes(serializer.serialize(TEST_DATA_PAYLOAD))
 
-        header = FlatTensorHeader.from_bytes(data[0 : FlatTensorHeader.EXPECTED_LENGTH])
+        # Check header.
+        header = FlatTensorHeader.from_bytes(
+            serialized_data[0 : FlatTensorHeader.EXPECTED_LENGTH]
+        )
         self.assertTrue(header.is_valid())
 
         # Header is aligned to config.segment_alignment, which is where the flatbuffer starts.
@@ -77,9 +100,86 @@ def test_serialize(self) -> None:
         self.assertTrue(header.segment_base_offset, expected_segment_base_offset)
 
         # TEST_TENSOR_BUFFER is aligned to config.segment_alignment.
-        self.assertEqual(header.segment_data_size, config.segment_alignment)
+        expected_segment_data_size = aligned_size(
+            sum(len(buffer) for buffer in TEST_TENSOR_BUFFER), config.segment_alignment
+        )
+        self.assertEqual(header.segment_data_size, expected_segment_data_size)
 
         # Confirm the flatbuffer magic is present.
         self.assertEqual(
-            data[header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8], b"FT01"
+            serialized_data[
+                header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8
+            ],
+            b"FT01",
         )
+
+        # Check flat tensor data.
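+        # The flatbuffer payload is decompiled back into the FlatTensor
+        # dataclass (via flatc and _json_to_dataclass), so schema fields
+        # can be asserted on directly.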
+        flat_tensor_bytes = serialized_data[
+            header.flatbuffer_offset : header.flatbuffer_offset + header.flatbuffer_size
+        ]
+
+        flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes)
+
+        self.assertEqual(flat_tensor.version, 0)
+        self.assertEqual(flat_tensor.tensor_alignment, config.tensor_alignment)
+
+        tensors = flat_tensor.tensors
+        self.assertEqual(len(tensors), 3)
+        self.assertEqual(tensors[0].fully_qualified_name, "fqn1")
+        self.check_tensor_metadata(TEST_TENSOR_MAP["fqn1"].layout, tensors[0])
+        self.assertEqual(tensors[0].segment_index, 0)
+        self.assertEqual(tensors[0].offset, 0)
+
+        self.assertEqual(tensors[1].fully_qualified_name, "fqn2")
+        self.check_tensor_metadata(TEST_TENSOR_MAP["fqn2"].layout, tensors[1])
+        self.assertEqual(tensors[1].segment_index, 0)
+        self.assertEqual(tensors[1].offset, 0)
+
+        self.assertEqual(tensors[2].fully_qualified_name, "fqn3")
+        self.check_tensor_metadata(TEST_TENSOR_MAP["fqn3"].layout, tensors[2])
+        self.assertEqual(tensors[2].segment_index, 0)
+        self.assertEqual(tensors[2].offset, config.tensor_alignment)
+
+        segments = flat_tensor.segments
+        self.assertEqual(len(segments), 1)
+        self.assertEqual(segments[0].offset, 0)
+        self.assertEqual(segments[0].size, config.tensor_alignment * 3)
+
+        # Length of serialized_data matches segment_base_offset + segment_data_size.
+        self.assertEqual(
+            header.segment_base_offset + header.segment_data_size, len(serialized_data)
+        )
+        self.assertTrue(segments[0].size <= header.segment_data_size)
+
+        # Check the contents of the segment. Expecting two tensors from
+        # TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32]
+        segment_data = serialized_data[
+            header.segment_base_offset : header.segment_base_offset + segments[0].size
+        ]
+
+        # Tensor: b"\x11" * 4
+        t0_start = 0
+        t0_len = len(TEST_TENSOR_BUFFER[0])
+        t0_end = t0_start + aligned_size(t0_len, config.tensor_alignment)
+        self.assertEqual(
+            segment_data[t0_start : t0_start + t0_len], TEST_TENSOR_BUFFER[0]
+        )
+        padding = b"\x00" * (t0_end - t0_len)
+        self.assertEqual(segment_data[t0_start + t0_len : t0_end], padding)
+
+        # Tensor: b"\x22" * 32
+        t1_start = t0_end
+        t1_len = len(TEST_TENSOR_BUFFER[1])
+        t1_end = t1_start + aligned_size(t1_len, config.tensor_alignment)
+        self.assertEqual(
+            segment_data[t1_start : t1_start + t1_len],
+            TEST_TENSOR_BUFFER[1],
+        )
+        padding = b"\x00" * (t1_end - (t1_len + t1_start))
+        self.assertEqual(segment_data[t1_start + t1_len : t1_end], padding)
+
+        # Check that the segment size is as expected.
+        self.assertEqual(
+            segments[0].size, aligned_size(t1_end, config.segment_alignment)
+        )
+        self.assertEqual(segments[0].size, header.segment_data_size)

From 3523843a58367bde8ddcd3842f2400fb21ac1712 Mon Sep 17 00:00:00 2001
From: lucylq
Date: Fri, 17 Jan 2025 16:42:18 -0800
Subject: [PATCH 2/3] [executorch][serialization] Serialize PTD files.

Pull Request resolved: https://github.com/pytorch/executorch/pull/7270

Introduce a top-level serialization file that calls:
- _serialize_pte_binary for the PTE file
- FlatTensorSerializer.serialize for PTD files
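A rough usage sketch of the new entry point (illustrative only, not part of
the diff; assumes an EmitterOutput `emitter_output` produced by emit_program
and a writable working directory):

    from executorch.exir._serialize._serialize import serialize_for_executorch
    from executorch.exir.capture._config import ExecutorchBackendConfig
    from executorch.extension.flat_tensor.serialize.serialize import (
        FlatTensorSerializer,
    )

    # Returns one Cord for the PTE program plus {filename: Cord} for
    # external tensor data.
    pte, ptd_files = serialize_for_executorch(
        emitter_output, ExecutorchBackendConfig(), FlatTensorSerializer()
    )
    with open("model.pte", "wb") as f:
        pte.write_to_file(f)
    for name, cord in ptd_files.items():
        with open(f"{name}.ptd", "wb") as f:
            cord.write_to_file(f)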
ghstack-source-id: 262004271 @exported-using-ghexport Differential Revision: [D66523267](https://our.internmc.facebook.com/intern/diff/D66523267/) --- exir/_serialize/TARGETS | 1 + exir/_serialize/_serialize.py | 91 ++++++++++++++++++++++++++++++++++ exir/program/TARGETS | 1 + exir/program/_program.py | 62 ++++++++++++++++------- extension/export_util/utils.py | 3 ++ 5 files changed, 139 insertions(+), 19 deletions(-) create mode 100644 exir/_serialize/_serialize.py diff --git a/exir/_serialize/TARGETS b/exir/_serialize/TARGETS index cd6a4bc5a2f..cc6f16d78d8 100644 --- a/exir/_serialize/TARGETS +++ b/exir/_serialize/TARGETS @@ -33,6 +33,7 @@ runtime.python_library( "_dataclass.py", "_flatbuffer.py", "_program.py", + "_serialize.py", "data_serializer.py", "padding.py", ], diff --git a/exir/_serialize/_serialize.py b/exir/_serialize/_serialize.py new file mode 100644 index 00000000000..c311274922f --- /dev/null +++ b/exir/_serialize/_serialize.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +from typing import Dict, Tuple + +from executorch.exir._serialize import _serialize_pte_binary + +from executorch.exir._serialize._cord import Cord +from executorch.exir._serialize.data_serializer import ( + DataPayload, + DataSerializer, + TensorEntry, + TensorLayout, +) + +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.emit import EmitterOutput +from executorch.exir.schema import Tensor, TensorDataLocation + + +def serialize_for_executorch( + emitter_output: EmitterOutput, + config: ExecutorchBackendConfig, + data_serializer: DataSerializer, +) -> Tuple[Cord, Dict[str, Cord]]: + """Serialize the output from Emitter into ExecuTorch artifacts; PTE and PTD files.""" + + # Serialize PTE file. + pte: Cord = _serialize_pte_binary( + program=emitter_output.program, + mutable_data=emitter_output.mutable_data, + extract_delegate_segments=config.extract_delegate_segments, + segment_alignment=config.segment_alignment, + constant_tensor_alignment=config.constant_tensor_alignment, + delegate_alignment=config.delegate_alignment, + ) + + # Serialize PTD files. + ptd_files: Dict[str, Cord] = {} + + # Find all external tensors and organize into {fqn: TensorLayout}. + fqn_to_tensor_layout: Dict[str, TensorLayout] = {} + for plan in emitter_output.program.execution_plan: + for evalue in plan.values: + if isinstance(evalue.val, Tensor): + tensor = evalue.val + if ( + tensor.extra_tensor_info is not None + and tensor.extra_tensor_info.fully_qualified_name is not None + and tensor.extra_tensor_info.location is TensorDataLocation.EXTERNAL + ): + fqn_to_tensor_layout[ + tensor.extra_tensor_info.fully_qualified_name + ] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order) + + if len(fqn_to_tensor_layout) > 0: + # emitter_output.external_constant_map contains the mapping from + # {file: {fqn: index into external_constant_buffer}} + # Contains the locations of the tensor buffers, and must be non-empty + # if there are external tensors to serialize. + assert emitter_output.external_constant_map is not None + for ( + filename, + fqn_to_index, + ) in ( + # pyre-ignore Undefined attribute [16]: Optional type has no attribute `items`. + emitter_output.external_constant_map.items() + ): + # Create a TensorEntry for each external tensor. 
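+            # Each entry pairs the tensor's index into
+            # external_constant_buffer with the layout captured above.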
+ fqn_to_tensor_entry: Dict[str, TensorEntry] = {} + for fqn, index in fqn_to_index.items(): + assert fqn in fqn_to_tensor_layout + fqn_to_tensor_entry[fqn] = TensorEntry( + buffer_index=index, + layout=fqn_to_tensor_layout[fqn], + ) + + ptd_files[filename] = data_serializer.serialize( + DataPayload( + buffers=emitter_output.external_constant_buffer, + fqn_to_tensor=fqn_to_tensor_entry, + ) + ) + + return pte, ptd_files diff --git a/exir/program/TARGETS b/exir/program/TARGETS index 674d7baa35e..33e417e7326 100644 --- a/exir/program/TARGETS +++ b/exir/program/TARGETS @@ -44,6 +44,7 @@ python_library( "//executorch/exir/passes:spec_prop_pass", "//executorch/exir/passes:weights_to_outputs_pass", "//executorch/exir/verification:verifier", + "//executorch/extension/flat_tensor/serialize:serialize", ] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else []) ) diff --git a/exir/program/_program.py b/exir/program/_program.py index 7dbf97a047b..e8cee0b5da8 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -9,12 +9,14 @@ import copy import io import logging +import os from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union import torch import torch._export -from executorch.exir._serialize import _serialize_pte_binary from executorch.exir._serialize._cord import Cord +from executorch.exir._serialize._serialize import serialize_for_executorch +from executorch.exir._serialize.data_serializer import DataSerializer from executorch.exir._warnings import experimental from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.partitioner import Partitioner @@ -59,6 +61,7 @@ EXIREdgeDialectVerifier, get_aten_verifier, ) +from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass from torch.export import ExportedProgram from torch.export._remove_auto_functionalized_pass import ( @@ -497,6 +500,7 @@ def __init__( ) self.exported_program = exir_exported_program.exported_program self._pte_data: Optional[Cord] = None + self._tensor_data: Optional[Dict[str, Cord]] = None self._buffer: Optional[bytes] = None self._emitter_output: Optional[EmitterOutput] = None self._emit_stacktrace: bool = emit_stacktrace @@ -504,16 +508,23 @@ def __init__( self._segment_alignment: int = segment_alignment self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment self._delegate_alignment: Optional[int] = delegate_alignment + self._data_serializer: DataSerializer = FlatTensorSerializer() + + def _get_emitter_output(self) -> EmitterOutput: + if self._emitter_output is None: + self._emitter_output = emit_program( + self.exported_program, self._emit_stacktrace + ) + return self._emitter_output def _get_pte_data(self) -> Cord: if self._pte_data is None: - self._pte_data = _serialize_pte_binary( - program=self.program, - extract_delegate_segments=self._extract_delegate_segments, - segment_alignment=self._segment_alignment, - constant_tensor_alignment=self._constant_tensor_alignment, - delegate_alignment=self._delegate_alignment, + self._pte_data, self._tensor_data = serialize_for_executorch( + self._get_emitter_output(), + ExecutorchBackendConfig(), + self._data_serializer, ) + assert self._pte_data is not None return self._pte_data @property @@ -532,11 +543,7 @@ def buffer(self) -> bytes: @property def program(self) -> Program: - if self._emitter_output is None: - self._emitter_output = emit_program( - self.exported_program, 
self._emit_stacktrace
-            )
-        return self._emitter_output.program
+        return self._get_emitter_output().program
 
     @property
     def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
@@ -571,6 +578,17 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
         """
         self._get_pte_data().write_to_file(open_file)
 
+    def write_tensor_data_to_file(self, outdir: str) -> None:
+        """
+        Writes the serialized ExecuTorch data files to the directory at `outdir`.
+        """
+        assert self._tensor_data is not None
+        # pyre-ignore[16]: `Optional` has no attribute `items`.
+        for filename, cord in self._tensor_data.items():
+            with open(os.path.join(outdir, f"{filename}.ptd"), "wb") as f:
+                logging.info(f"Writing data file to {filename}.ptd")
+                cord.write_to_file(f)
+
 
 def _get_aten_to_edge_passes(config: EdgeCompileConfig):
     # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
@@ -1453,13 +1471,9 @@ def __init__(
         )
 
         # Serialize emitter output, ready to be written to a file.
-        self._pte_data: Cord = _serialize_pte_binary(
-            program=self._emitter_output.program,
-            mutable_data=self._emitter_output.mutable_data,
-            extract_delegate_segments=backend_config.extract_delegate_segments,
-            segment_alignment=backend_config.segment_alignment,
-            constant_tensor_alignment=backend_config.constant_tensor_alignment,
-            delegate_alignment=backend_config.delegate_alignment,
+        self._data_serializer = FlatTensorSerializer()
+        self._pte_data, self._tensor_data = serialize_for_executorch(
+            self._emitter_output, ExecutorchBackendConfig(), self._data_serializer
         )
         self._buffer: Optional[bytes] = None
@@ -1542,6 +1556,16 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
         """
         self._pte_data.write_to_file(open_file)
 
+    def write_tensor_data_to_file(self, outdir: str) -> None:
+        """
+        Writes the serialized ExecuTorch data files to the directory at `outdir`.
+        """
+        assert self._tensor_data is not None
+        for filename, cord in self._tensor_data.items():
+            with open(os.path.join(outdir, f"{filename}.ptd"), "wb") as f:
+                logging.info(f"Writing data file to {filename}.ptd")
+                cord.write_to_file(f)
+
     def save(self, path: str) -> None:
         """
         Saves the serialized ExecuTorch binary to the file at `path`.
diff --git a/extension/export_util/utils.py b/extension/export_util/utils.py
index a289355919e..2679930178a 100644
--- a/extension/export_util/utils.py
+++ b/extension/export_util/utils.py
@@ -135,9 +135,12 @@ def save_pte_program(
     filename = os.path.join(output_dir, f"{model_name}.pte")
 
     try:
+        # Write program to file.
         with open(filename, "wb") as file:
             prog.write_to_file(file)
             logging.info(f"Saved exported program to {filename}")
+        # Write data to file(s).
+        prog.write_tensor_data_to_file(outdir=output_dir)
     except Exception as e:
         logging.error(f"Error while saving to {filename}: {e}")

From de1e754f6fe1739ade4690bb7fe751d046c7c29e Mon Sep 17 00:00:00 2001
From: lucylq
Date: Fri, 17 Jan 2025 16:44:55 -0800
Subject: [PATCH 3/3] [executorch][core] Add TensorLayout to core

Introduce the TensorLayout class, used to describe external tensors.
Currently contains:
- scalar_type
- sizes
- dim_order

Differential Revision: [D67048723](https://our.internmc.facebook.com/intern/diff/D67048723/)

ghstack-source-id: 262004745
Pull Request resolved: https://github.com/pytorch/executorch/pull/7761
---
 runtime/core/targets.bzl                 |  1 +
 runtime/core/tensor_layout.h             | 91 ++++++++++++++++++++++++
 runtime/core/test/targets.bzl            |  9 +++
 runtime/core/test/tensor_layout_test.cpp | 40 +++++++++++
 4 files changed, 141 insertions(+)
 create mode 100644 runtime/core/tensor_layout.h
 create mode 100644 runtime/core/test/tensor_layout_test.cpp

diff --git a/runtime/core/targets.bzl b/runtime/core/targets.bzl
index 7e0aeb5d28c..4b8a7869afe 100644
--- a/runtime/core/targets.bzl
+++ b/runtime/core/targets.bzl
@@ -34,6 +34,7 @@ def define_common_targets():
             "freeable_buffer.h",
             "result.h",
             "span.h",
+            "tensor_layout.h",
         ],
         visibility = [
             "//executorch/...",
diff --git a/runtime/core/tensor_layout.h b/runtime/core/tensor_layout.h
new file mode 100644
index 00000000000..84238561412
--- /dev/null
+++ b/runtime/core/tensor_layout.h
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
+#include <executorch/runtime/core/span.h>
+
+namespace executorch {
+namespace runtime {
+
+namespace {
+size_t calculate_nbytes(
+    const Span<const int32_t>& sizes,
+    const exec_aten::ScalarType& scalar_type) {
+  ssize_t n = 1;
+  for (ssize_t i = 0; i < sizes.size(); i++) {
+    ET_CHECK(sizes[i] >= 0);
+    n *= sizes[i];
+  }
+  // Use the full namespace to disambiguate from c10::elementSize.
+  return n * executorch::runtime::elementSize(scalar_type);
+}
+} // namespace
+
+/**
+ * Metadata describing the layout of external tensors (tensors that are not
+ * stored in the PTE file).
+ *
+ * The NamedDataMap used to create the TensorLayout must outlive the
+ * TensorLayout.
+ */
+class TensorLayout {
+ public:
+  TensorLayout(
+      executorch::aten::ScalarType scalar_type,
+      Span<const int32_t> sizes,
+      Span<const uint8_t> dim_order)
+      : sizes_(sizes),
+        dim_order_(dim_order),
+        scalar_type_(scalar_type),
+        nbytes_(calculate_nbytes(sizes_, scalar_type_)) {}
+
+  TensorLayout(const TensorLayout&) = default;
+  TensorLayout(TensorLayout&&) = default;
+  TensorLayout& operator=(const TensorLayout&) = default;
+  TensorLayout& operator=(TensorLayout&& other) = default;
+  ~TensorLayout() = default;
+
+  /// Returns the sizes of the tensor.
+  Span<const int32_t> sizes() const {
+    return sizes_;
+  }
+
+  /// Returns the dim order of the tensor.
+  Span<const uint8_t> dim_order() const {
+    return dim_order_;
+  }
+
+  /// Returns the scalar type of the tensor.
+  executorch::aten::ScalarType scalar_type() const {
+    return scalar_type_;
+  }
+
+  /// Returns the size of the tensor in bytes.
+  size_t nbytes() const {
+    return nbytes_;
+  }
+
+ private:
+  /// The sizes of the tensor.
+  Span<const int32_t> sizes_;
+
+  /// The dim order of the tensor.
+  Span<const uint8_t> dim_order_;
+
+  /// The scalar type of the tensor.
+  executorch::aten::ScalarType scalar_type_;
+
+  /// The size in bytes of the tensor.
+  size_t nbytes_;
+};
+
+} // namespace runtime
+} // namespace executorch
diff --git a/runtime/core/test/targets.bzl b/runtime/core/test/targets.bzl
index 2857de308bd..4cc1290c5e0 100644
--- a/runtime/core/test/targets.bzl
+++ b/runtime/core/test/targets.bzl
@@ -15,6 +15,15 @@ def define_common_targets():
         ],
     )
 
+    runtime.cxx_test(
+        name = "tensor_layout_test",
+        srcs = ["tensor_layout_test.cpp"],
+        deps = [
+            "//executorch/runtime/core:core",
+            "//executorch/runtime/core/exec_aten:lib",
+        ],
+    )
+
     runtime.cxx_test(
         name = "error_handling_test",
         srcs = [
diff --git a/runtime/core/test/tensor_layout_test.cpp b/runtime/core/test/tensor_layout_test.cpp
new file mode 100644
index 00000000000..e9065f93cce
--- /dev/null
+++ b/runtime/core/test/tensor_layout_test.cpp
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/tensor_layout.h>
+
+#include <gtest/gtest.h>
+
+using namespace ::testing;
+using executorch::aten::ScalarType;
+using executorch::runtime::Span;
+using executorch::runtime::TensorLayout;
+
+TEST(TestTensorLayout, Ctor) {
+  int32_t sizes[2] = {1, 2};
+  uint8_t dim_order[2] = {0, 1};
+
+  Span<const int32_t> sizes_span = {sizes, 2};
+  Span<const uint8_t> dim_order_span = {dim_order, 2};
+
+  TensorLayout layout =
+      TensorLayout(ScalarType::Float, sizes_span, dim_order_span);
+
+  EXPECT_EQ(layout.scalar_type(), ScalarType::Float);
+
+  EXPECT_EQ(layout.sizes().size(), sizes_span.size());
+  EXPECT_EQ(layout.sizes()[0], sizes_span[0]);
+  EXPECT_EQ(layout.sizes()[1], sizes_span[1]);
+
+  EXPECT_EQ(layout.dim_order().size(), dim_order_span.size());
+  EXPECT_EQ(layout.dim_order()[0], dim_order_span[0]);
+  EXPECT_EQ(layout.dim_order()[1], dim_order_span[1]);
+
+  EXPECT_EQ(layout.nbytes(), 8);
+}
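
A minimal usage sketch for the new class (illustrative only, not part of the
patch; assumes a target that can include the header, and uses the same
Span(pointer, length) constructor exercised in the test above):

    #include <cstdio>

    #include <executorch/runtime/core/tensor_layout.h>

    using executorch::aten::ScalarType;
    using executorch::runtime::Span;
    using executorch::runtime::TensorLayout;

    int main() {
      // Caller-owned metadata; TensorLayout's spans are non-owning, so
      // these arrays must outlive `layout`.
      const int32_t sizes[2] = {2, 3};
      const uint8_t dim_order[2] = {0, 1};

      TensorLayout layout(
          ScalarType::Float,
          Span<const int32_t>(sizes, 2),
          Span<const uint8_t>(dim_order, 2));

      // nbytes() = 2 * 3 * sizeof(float) = 24.
      std::printf("nbytes: %zu\n", layout.nbytes());
      return 0;
    }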