Skip to content

Commit 948fba6

Browse files
pytorchbotlucylq
andauthored
[executorch][serialization] Serialize PTD files.
Pull Request resolved: #7270 Introduce top-level serialization file that calls: - serialize_pte_binary for PTE file - FlatTensor.serialize_tensors for PTD files. ghstack-source-id: 262004271 @exported-using-ghexport Differential Revision: [D66523267](https://our.internmc.facebook.com/intern/diff/D66523267/) --------- Co-authored-by: lucylq <[email protected]>
1 parent 5f6fa23 commit 948fba6

File tree

5 files changed

+139
-19
lines changed

5 files changed

+139
-19
lines changed

exir/_serialize/TARGETS

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ runtime.python_library(
3333
"_dataclass.py",
3434
"_flatbuffer.py",
3535
"_program.py",
36+
"_serialize.py",
3637
"data_serializer.py",
3738
"padding.py",
3839
],

exir/_serialize/_serialize.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
10+
from typing import Dict, Tuple
11+
12+
from executorch.exir._serialize import _serialize_pte_binary
13+
14+
from executorch.exir._serialize._cord import Cord
15+
from executorch.exir._serialize.data_serializer import (
16+
DataPayload,
17+
DataSerializer,
18+
TensorEntry,
19+
TensorLayout,
20+
)
21+
22+
from executorch.exir.capture._config import ExecutorchBackendConfig
23+
from executorch.exir.emit import EmitterOutput
24+
from executorch.exir.schema import Tensor, TensorDataLocation
25+
26+
27+
def serialize_for_executorch(
28+
emitter_output: EmitterOutput,
29+
config: ExecutorchBackendConfig,
30+
data_serializer: DataSerializer,
31+
) -> Tuple[Cord, Dict[str, Cord]]:
32+
"""Serialize the output from Emitter into ExecuTorch artifacts; PTE and PTD files."""
33+
34+
# Serialize PTE file.
35+
pte: Cord = _serialize_pte_binary(
36+
program=emitter_output.program,
37+
mutable_data=emitter_output.mutable_data,
38+
extract_delegate_segments=config.extract_delegate_segments,
39+
segment_alignment=config.segment_alignment,
40+
constant_tensor_alignment=config.constant_tensor_alignment,
41+
delegate_alignment=config.delegate_alignment,
42+
)
43+
44+
# Serialize PTD files.
45+
ptd_files: Dict[str, Cord] = {}
46+
47+
# Find all external tensors and organize into {fqn: TensorLayout}.
48+
fqn_to_tensor_layout: Dict[str, TensorLayout] = {}
49+
for plan in emitter_output.program.execution_plan:
50+
for evalue in plan.values:
51+
if isinstance(evalue.val, Tensor):
52+
tensor = evalue.val
53+
if (
54+
tensor.extra_tensor_info is not None
55+
and tensor.extra_tensor_info.fully_qualified_name is not None
56+
and tensor.extra_tensor_info.location is TensorDataLocation.EXTERNAL
57+
):
58+
fqn_to_tensor_layout[
59+
tensor.extra_tensor_info.fully_qualified_name
60+
] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order)
61+
62+
if len(fqn_to_tensor_layout) > 0:
63+
# emitter_output.external_constant_map contains the mapping from
64+
# {file: {fqn: index into external_constant_buffer}}
65+
# Contains the locations of the tensor buffers, and must be non-empty
66+
# if there are external tensors to serialize.
67+
assert emitter_output.external_constant_map is not None
68+
for (
69+
filename,
70+
fqn_to_index,
71+
) in (
72+
# pyre-ignore Undefined attribute [16]: Optional type has no attribute `items`.
73+
emitter_output.external_constant_map.items()
74+
):
75+
# Create a TensorEntry for each external tensor.
76+
fqn_to_tensor_entry: Dict[str, TensorEntry] = {}
77+
for fqn, index in fqn_to_index.items():
78+
assert fqn in fqn_to_tensor_layout
79+
fqn_to_tensor_entry[fqn] = TensorEntry(
80+
buffer_index=index,
81+
layout=fqn_to_tensor_layout[fqn],
82+
)
83+
84+
ptd_files[filename] = data_serializer.serialize(
85+
DataPayload(
86+
buffers=emitter_output.external_constant_buffer,
87+
fqn_to_tensor=fqn_to_tensor_entry,
88+
)
89+
)
90+
91+
return pte, ptd_files

exir/program/TARGETS

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ python_library(
4444
"//executorch/exir/passes:spec_prop_pass",
4545
"//executorch/exir/passes:weights_to_outputs_pass",
4646
"//executorch/exir/verification:verifier",
47+
"//executorch/extension/flat_tensor/serialize:serialize",
4748
] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else [])
4849
)
4950

exir/program/_program.py

+43-19
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
import copy
1010
import io
1111
import logging
12+
import os
1213
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union
1314

1415
import torch
1516
import torch._export
16-
from executorch.exir._serialize import _serialize_pte_binary
1717
from executorch.exir._serialize._cord import Cord
18+
from executorch.exir._serialize._serialize import serialize_for_executorch
19+
from executorch.exir._serialize.data_serializer import DataSerializer
1820
from executorch.exir._warnings import experimental
1921
from executorch.exir.backend.backend_api import to_backend
2022
from executorch.exir.backend.partitioner import Partitioner
@@ -59,6 +61,7 @@
5961
EXIREdgeDialectVerifier,
6062
get_aten_verifier,
6163
)
64+
from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer
6265
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
6366
from torch.export import ExportedProgram
6467
from torch.export._remove_auto_functionalized_pass import (
@@ -497,23 +500,31 @@ def __init__(
497500
)
498501
self.exported_program = exir_exported_program.exported_program
499502
self._pte_data: Optional[Cord] = None
503+
self._tensor_data: Optional[Dict[str, Cord]] = None
500504
self._buffer: Optional[bytes] = None
501505
self._emitter_output: Optional[EmitterOutput] = None
502506
self._emit_stacktrace: bool = emit_stacktrace
503507
self._extract_delegate_segments: bool = extract_delegate_segments
504508
self._segment_alignment: int = segment_alignment
505509
self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment
506510
self._delegate_alignment: Optional[int] = delegate_alignment
511+
self._data_serializer: DataSerializer = FlatTensorSerializer()
512+
513+
def _get_emitter_output(self) -> EmitterOutput:
514+
if self._emitter_output is None:
515+
self._emitter_output = emit_program(
516+
self.exported_program, self._emit_stacktrace
517+
)
518+
return self._emitter_output
507519

508520
def _get_pte_data(self) -> Cord:
509521
if self._pte_data is None:
510-
self._pte_data = _serialize_pte_binary(
511-
program=self.program,
512-
extract_delegate_segments=self._extract_delegate_segments,
513-
segment_alignment=self._segment_alignment,
514-
constant_tensor_alignment=self._constant_tensor_alignment,
515-
delegate_alignment=self._delegate_alignment,
522+
self._pte_data, self._tensor_data = serialize_for_executorch(
523+
self._get_emitter_output(),
524+
ExecutorchBackendConfig(),
525+
self._data_serializer,
516526
)
527+
assert self._pte_data is not None
517528
return self._pte_data
518529

519530
@property
@@ -532,11 +543,7 @@ def buffer(self) -> bytes:
532543

533544
@property
534545
def program(self) -> Program:
535-
if self._emitter_output is None:
536-
self._emitter_output = emit_program(
537-
self.exported_program, self._emit_stacktrace
538-
)
539-
return self._emitter_output.program
546+
return self._get_emitter_output().program
540547

541548
@property
542549
def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
@@ -571,6 +578,17 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
571578
"""
572579
self._get_pte_data().write_to_file(open_file)
573580

581+
def write_tensor_data_to_file(self, outdir) -> None:
582+
"""
583+
Writes the serialized ExecuTorch data files to the directory at `outdir`.
584+
"""
585+
assert self._tensor_data is not None
586+
# pyre-ignore[16]: `Optional` has no attribute `items`.
587+
for filename, cord in self._tensor_data.items():
588+
with open(os.path.join(outdir, f"{filename}.ptd"), "wb") as f:
589+
logging.info(f"Writing data file to {filename}.ptd")
590+
cord.write_to_file(f)
591+
574592

575593
def _get_aten_to_edge_passes(config: EdgeCompileConfig):
576594
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
@@ -1453,13 +1471,9 @@ def __init__(
14531471
)
14541472

14551473
# Serialize emitter output, ready to be written to a file.
1456-
self._pte_data: Cord = _serialize_pte_binary(
1457-
program=self._emitter_output.program,
1458-
mutable_data=self._emitter_output.mutable_data,
1459-
extract_delegate_segments=backend_config.extract_delegate_segments,
1460-
segment_alignment=backend_config.segment_alignment,
1461-
constant_tensor_alignment=backend_config.constant_tensor_alignment,
1462-
delegate_alignment=backend_config.delegate_alignment,
1474+
self._data_serializer = FlatTensorSerializer()
1475+
self._pte_data, self._tensor_data = serialize_for_executorch(
1476+
self._emitter_output, ExecutorchBackendConfig(), self._data_serializer
14631477
)
14641478
self._buffer: Optional[bytes] = None
14651479

@@ -1542,6 +1556,16 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
15421556
"""
15431557
self._pte_data.write_to_file(open_file)
15441558

1559+
def write_tensor_data_to_file(self, outdir) -> None:
1560+
"""
1561+
Writes the serialized ExecuTorch data files to the directory at `outdir`.
1562+
"""
1563+
assert self._tensor_data is not None
1564+
for filename, cord in self._tensor_data.items():
1565+
with open(os.path.join(outdir, f"{filename}.ptd"), "wb") as f:
1566+
logging.info(f"Writing data file to {filename}")
1567+
cord.write_to_file(f)
1568+
15451569
def save(self, path: str) -> None:
15461570
"""
15471571
Saves the serialized ExecuTorch binary to the file at `path`.

extension/export_util/utils.py

+3
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,12 @@ def save_pte_program(
135135
filename = os.path.join(output_dir, f"{model_name}.pte")
136136

137137
try:
138+
# Write program to file.
138139
with open(filename, "wb") as file:
139140
prog.write_to_file(file)
140141
logging.info(f"Saved exported program to {filename}")
142+
# Write data to file/s.
143+
prog.write_tensor_data_to_file(outdir=output_dir)
141144
except Exception as e:
142145
logging.error(f"Error while saving to {filename}: {e}")
143146

0 commit comments

Comments
 (0)