8
8
9
9
import unittest
10
10
11
+ from typing import List
12
+
11
13
from executorch .exir ._serialize .data_serializer import (
12
14
DataPayload ,
13
15
DataSerializer ,
18
20
from executorch .exir ._serialize .padding import aligned_size
19
21
20
22
from executorch .exir .schema import ScalarType
23
+ from executorch .extension .flat_tensor .serialize .flat_tensor_schema import TensorMetadata
21
24
22
25
from executorch .extension .flat_tensor .serialize .serialize import (
26
+ _deserialize_to_flat_tensor ,
23
27
FlatTensorConfig ,
24
28
FlatTensorHeader ,
25
29
FlatTensorSerializer ,
26
30
)
27
31
28
32
# Test artifacts.
29
- TEST_TENSOR_BUFFER = [b"tensor" ]
33
+ TEST_TENSOR_BUFFER : List [ bytes ] = [b"\x11 " * 4 , b" \x22 " * 32 ]
30
34
TEST_TENSOR_MAP = {
31
35
"fqn1" : TensorEntry (
32
36
buffer_index = 0 ,
44
48
dim_order = [0 , 1 , 2 ],
45
49
),
46
50
),
51
+ "fqn3" : TensorEntry (
52
+ buffer_index = 1 ,
53
+ layout = TensorLayout (
54
+ scalar_type = ScalarType .INT ,
55
+ sizes = [2 , 2 , 2 ],
56
+ dim_order = [0 , 1 ],
57
+ ),
58
+ ),
47
59
}
48
60
TEST_DATA_PAYLOAD = DataPayload (
49
61
buffers = TEST_TENSOR_BUFFER ,
52
64
53
65
54
66
class TestSerialize (unittest .TestCase ):
67
+ # TODO(T211851359): improve test coverage.
68
+ def check_tensor_metadata (
69
+ self , tensor_layout : TensorLayout , tensor_metadata : TensorMetadata
70
+ ) -> None :
71
+ self .assertEqual (tensor_layout .scalar_type , tensor_metadata .scalar_type )
72
+ self .assertEqual (tensor_layout .sizes , tensor_metadata .sizes )
73
+ self .assertEqual (tensor_layout .dim_order , tensor_metadata .dim_order )
74
+
55
75
def test_serialize (self ) -> None :
56
76
config = FlatTensorConfig ()
57
77
serializer : DataSerializer = FlatTensorSerializer (config )
58
78
59
- data = bytes (serializer .serialize (TEST_DATA_PAYLOAD ))
79
+ serialized_data = bytes (serializer .serialize (TEST_DATA_PAYLOAD ))
60
80
61
- header = FlatTensorHeader .from_bytes (data [0 : FlatTensorHeader .EXPECTED_LENGTH ])
81
+ # Check header.
82
+ header = FlatTensorHeader .from_bytes (
83
+ serialized_data [0 : FlatTensorHeader .EXPECTED_LENGTH ]
84
+ )
62
85
self .assertTrue (header .is_valid ())
63
86
64
87
# Header is aligned to config.segment_alignment, which is where the flatbuffer starts.
@@ -77,9 +100,86 @@ def test_serialize(self) -> None:
77
100
self .assertTrue (header .segment_base_offset , expected_segment_base_offset )
78
101
79
102
# TEST_TENSOR_BUFFER is aligned to config.segment_alignment.
80
- self .assertEqual (header .segment_data_size , config .segment_alignment )
103
+ expected_segment_data_size = aligned_size (
104
+ sum (len (buffer ) for buffer in TEST_TENSOR_BUFFER ), config .segment_alignment
105
+ )
106
+ self .assertEqual (header .segment_data_size , expected_segment_data_size )
81
107
82
108
# Confirm the flatbuffer magic is present.
83
109
self .assertEqual (
84
- data [header .flatbuffer_offset + 4 : header .flatbuffer_offset + 8 ], b"FT01"
110
+ serialized_data [
111
+ header .flatbuffer_offset + 4 : header .flatbuffer_offset + 8
112
+ ],
113
+ b"FT01" ,
114
+ )
115
+
116
+ # Check flat tensor data.
117
+ flat_tensor_bytes = serialized_data [
118
+ header .flatbuffer_offset : header .flatbuffer_offset + header .flatbuffer_size
119
+ ]
120
+
121
+ flat_tensor = _deserialize_to_flat_tensor (flat_tensor_bytes )
122
+
123
+ self .assertEqual (flat_tensor .version , 0 )
124
+ self .assertEqual (flat_tensor .tensor_alignment , config .tensor_alignment )
125
+
126
+ tensors = flat_tensor .tensors
127
+ self .assertEqual (len (tensors ), 3 )
128
+ self .assertEqual (tensors [0 ].fully_qualified_name , "fqn1" )
129
+ self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn1" ].layout , tensors [0 ])
130
+ self .assertEqual (tensors [0 ].segment_index , 0 )
131
+ self .assertEqual (tensors [0 ].offset , 0 )
132
+
133
+ self .assertEqual (tensors [1 ].fully_qualified_name , "fqn2" )
134
+ self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn2" ].layout , tensors [1 ])
135
+ self .assertEqual (tensors [1 ].segment_index , 0 )
136
+ self .assertEqual (tensors [1 ].offset , 0 )
137
+
138
+ self .assertEqual (tensors [2 ].fully_qualified_name , "fqn3" )
139
+ self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn3" ].layout , tensors [2 ])
140
+ self .assertEqual (tensors [2 ].segment_index , 0 )
141
+ self .assertEqual (tensors [2 ].offset , config .tensor_alignment )
142
+
143
+ segments = flat_tensor .segments
144
+ self .assertEqual (len (segments ), 1 )
145
+ self .assertEqual (segments [0 ].offset , 0 )
146
+ self .assertEqual (segments [0 ].size , config .tensor_alignment * 3 )
147
+
148
+ # Length of serialized_data matches segment_base_offset + segment_data_size.
149
+ self .assertEqual (
150
+ header .segment_base_offset + header .segment_data_size , len (serialized_data )
151
+ )
152
+ self .assertTrue (segments [0 ].size <= header .segment_data_size )
153
+
154
+ # Check the contents of the segment. Expecting two tensors from
155
+ # TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32]
156
+ segment_data = serialized_data [
157
+ header .segment_base_offset : header .segment_base_offset + segments [0 ].size
158
+ ]
159
+
160
+ # Tensor: b"\x11" * 4
161
+ t0_start = 0
162
+ t0_len = len (TEST_TENSOR_BUFFER [0 ])
163
+ t0_end = t0_start + aligned_size (t0_len , config .tensor_alignment )
164
+ self .assertEqual (
165
+ segment_data [t0_start : t0_start + t0_len ], TEST_TENSOR_BUFFER [0 ]
166
+ )
167
+ padding = b"\x00 " * (t0_end - t0_len )
168
+ self .assertEqual (segment_data [t0_start + t0_len : t0_end ], padding )
169
+
170
+ # Tensor: b"\x22" * 32
171
+ t1_start = t0_end
172
+ t1_len = len (TEST_TENSOR_BUFFER [1 ])
173
+ t1_end = t1_start + aligned_size (t1_len , config .tensor_alignment )
174
+ self .assertEqual (
175
+ segment_data [t1_start : t1_start + t1_len ],
176
+ TEST_TENSOR_BUFFER [1 ],
177
+ )
178
+ padding = b"\x00 " * (t1_end - (t1_len + t1_start ))
179
+ self .assertEqual (segment_data [t1_start + t1_len : t1_start + t1_end ], padding )
180
+
181
+ # Check length of the segment is expected.
182
+ self .assertEqual (
183
+ segments [0 ].size , aligned_size (t1_end , config .segment_alignment )
85
184
)
185
+ self .assertEqual (segments [0 ].size , header .segment_data_size )
0 commit comments