9
9
import copy
10
10
import io
11
11
import logging
12
+ import os
12
13
from typing import Any , Dict , List , Optional , Sequence , Set , TextIO , Tuple , Union
13
14
14
15
import torch
15
16
import torch ._export
16
- from executorch .exir ._serialize import _serialize_pte_binary
17
17
from executorch .exir ._serialize ._cord import Cord
18
+ from executorch .exir ._serialize ._serialize import serialize_for_executorch
19
+ from executorch .exir ._serialize .data_serializer import DataSerializer
18
20
from executorch .exir ._warnings import experimental
19
21
from executorch .exir .backend .backend_api import to_backend
20
22
from executorch .exir .backend .partitioner import Partitioner
59
61
EXIREdgeDialectVerifier ,
60
62
get_aten_verifier ,
61
63
)
64
+ from executorch .extension .flat_tensor .serialize .serialize import FlatTensorSerializer
62
65
from torch ._export .passes import ReplaceViewOpsWithViewCopyOpsPass
63
66
from torch .export import ExportedProgram
64
67
from torch .export ._remove_auto_functionalized_pass import (
@@ -497,23 +500,31 @@ def __init__(
497
500
)
498
501
self .exported_program = exir_exported_program .exported_program
499
502
self ._pte_data : Optional [Cord ] = None
503
+ self ._tensor_data : Optional [Dict [str , Cord ]] = None
500
504
self ._buffer : Optional [bytes ] = None
501
505
self ._emitter_output : Optional [EmitterOutput ] = None
502
506
self ._emit_stacktrace : bool = emit_stacktrace
503
507
self ._extract_delegate_segments : bool = extract_delegate_segments
504
508
self ._segment_alignment : int = segment_alignment
505
509
self ._constant_tensor_alignment : Optional [int ] = constant_tensor_alignment
506
510
self ._delegate_alignment : Optional [int ] = delegate_alignment
511
+ self ._data_serializer : DataSerializer = FlatTensorSerializer ()
512
+
513
+ def _get_emitter_output (self ) -> EmitterOutput :
514
+ if self ._emitter_output is None :
515
+ self ._emitter_output = emit_program (
516
+ self .exported_program , self ._emit_stacktrace
517
+ )
518
+ return self ._emitter_output
507
519
508
520
def _get_pte_data (self ) -> Cord :
509
521
if self ._pte_data is None :
510
- self ._pte_data = _serialize_pte_binary (
511
- program = self .program ,
512
- extract_delegate_segments = self ._extract_delegate_segments ,
513
- segment_alignment = self ._segment_alignment ,
514
- constant_tensor_alignment = self ._constant_tensor_alignment ,
515
- delegate_alignment = self ._delegate_alignment ,
522
+ self ._pte_data , self ._tensor_data = serialize_for_executorch (
523
+ self ._get_emitter_output (),
524
+ ExecutorchBackendConfig (),
525
+ self ._data_serializer ,
516
526
)
527
+ assert self ._pte_data is not None
517
528
return self ._pte_data
518
529
519
530
@property
@@ -532,11 +543,7 @@ def buffer(self) -> bytes:
532
543
533
544
@property
534
545
def program (self ) -> Program :
535
- if self ._emitter_output is None :
536
- self ._emitter_output = emit_program (
537
- self .exported_program , self ._emit_stacktrace
538
- )
539
- return self ._emitter_output .program
546
+ return self ._get_emitter_output ().program
540
547
541
548
@property
542
549
def debug_handle_map (self ) -> Dict [int , Union [int , List [int ]]]:
@@ -571,6 +578,17 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
571
578
"""
572
579
self ._get_pte_data ().write_to_file (open_file )
573
580
581
+ def write_tensor_data_to_file (self , outdir ) -> None :
582
+ """
583
+ Writes the serialized ExecuTorch data files to the directory at `outdir`.
584
+ """
585
+ assert self ._tensor_data is not None
586
+ # pyre-ignore[16]: `Optional` has no attribute `items`.
587
+ for filename , cord in self ._tensor_data .items ():
588
+ with open (os .path .join (outdir , f"{ filename } .ptd" ), "wb" ) as f :
589
+ logging .info (f"Writing data file to { filename } .ptd" )
590
+ cord .write_to_file (f )
591
+
574
592
575
593
def _get_aten_to_edge_passes (config : EdgeCompileConfig ):
576
594
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
@@ -1453,13 +1471,9 @@ def __init__(
1453
1471
)
1454
1472
1455
1473
# Serialize emitter output, ready to be written to a file.
1456
- self ._pte_data : Cord = _serialize_pte_binary (
1457
- program = self ._emitter_output .program ,
1458
- mutable_data = self ._emitter_output .mutable_data ,
1459
- extract_delegate_segments = backend_config .extract_delegate_segments ,
1460
- segment_alignment = backend_config .segment_alignment ,
1461
- constant_tensor_alignment = backend_config .constant_tensor_alignment ,
1462
- delegate_alignment = backend_config .delegate_alignment ,
1474
+ self ._data_serializer = FlatTensorSerializer ()
1475
+ self ._pte_data , self ._tensor_data = serialize_for_executorch (
1476
+ self ._emitter_output , ExecutorchBackendConfig (), self ._data_serializer
1463
1477
)
1464
1478
self ._buffer : Optional [bytes ] = None
1465
1479
@@ -1542,6 +1556,16 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
1542
1556
"""
1543
1557
self ._pte_data .write_to_file (open_file )
1544
1558
1559
+ def write_tensor_data_to_file (self , outdir ) -> None :
1560
+ """
1561
+ Writes the serialized ExecuTorch data files to the directory at `outdir`.
1562
+ """
1563
+ assert self ._tensor_data is not None
1564
+ for filename , cord in self ._tensor_data .items ():
1565
+ with open (os .path .join (outdir , f"{ filename } .ptd" ), "wb" ) as f :
1566
+ logging .info (f"Writing data file to { filename } " )
1567
+ cord .write_to_file (f )
1568
+
1545
1569
def save (self , path : str ) -> None :
1546
1570
"""
1547
1571
Saves the serialized ExecuTorch binary to the file at `path`.
0 commit comments