-
Notifications
You must be signed in to change notification settings - Fork 493
/
Copy path_program.py
1581 lines (1366 loc) · 62.7 KB
/
_program.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import copy
import io
import logging
import os
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union
import torch
import torch._export
from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._serialize import serialize_for_executorch
from executorch.exir._serialize.data_serializer import DataSerializer
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.emit import emit_program, EmitterOutput
from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
from executorch.exir.error import ExportError
from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.pass_base import PassBase
from executorch.exir.pass_manager import PassType
from executorch.exir.passes import (
base_post_op_replace_passes,
base_pre_op_replace_passes,
dead_code_elimination_pass,
EdgeToBackendOpsPass,
MemoryFormatOpsPass,
OpReplacePass,
)
from executorch.exir.passes.external_constants_pass import external_constants_pass
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
insert_write_back_for_buffers_pass,
)
from executorch.exir.passes.normalize_view_copy_base_pass import (
NormalizeViewCopyBasePass,
)
from executorch.exir.passes.remove_graph_asserts_pass import (
RemoveGraphAssertsPass,
RemoveNonCoreAtenOpGraphAssertsPass,
)
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
from executorch.exir.passes.replace_aten_with_edge_pass import aten_to_edge
from executorch.exir.passes.replace_view_copy_with_view_pass import (
ReplaceViewCopyWithViewPass,
)
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass
from executorch.exir.print_program import pretty_print, print_program
from executorch.exir.schema import Program
from executorch.exir.tracer import _default_decomposition_table
from executorch.exir.verification.verifier import (
EXIRATenDialectVerifier,
EXIREdgeDialectVerifier,
get_aten_verifier,
)
from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch.export import ExportedProgram
from torch.export._remove_auto_functionalized_pass import (
unsafe_remove_auto_functionalized_pass,
)
from torch.export.exported_program import (
ConstantArgument,
ExportGraphSignature,
InputKind,
InputSpec,
OutputSpec,
TensorArgument,
)
from torch.fx import _pytree as fx_pytree
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_manager import PassManager
from torch.utils import _pytree as pytree
Val = Any
from typing import Any, Callable
from torch.library import Library

try:
    from executorch.exir.program.fb.logger import et_logger
except ImportError:
    import functools

    # Fallback used when the internal logger module is unavailable: a no-op
    # decorator factory so `@et_logger(...)` call sites keep working.
    def et_logger(api_name: str) -> Callable[[Any], Any]:
        """Return a decorator that calls the decorated method unchanged.

        ``api_name`` is accepted for call-site compatibility and ignored here.
        """

        def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
            # functools.wraps preserves __name__/__doc__/__wrapped__ of the
            # decorated method so introspection and generated docs still work.
            @functools.wraps(func)
            def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
                return func(self, *args, **kwargs)

            return wrapper

        return decorator
# Reserved operator namespace. Aten ops registered here are prevented from
# being decomposed during to_edge_transform_and_lower.
edge_no_decomp_namespace = "EDGE_DO_NOT_DECOMP"
lib = Library(edge_no_decomp_namespace, "DEF")

# Cache: aten op -> its copy registered in edge_no_decomp_namespace.
aten_op_to_transform_op: Dict[Any, Any] = {}
# Reverse cache: str(copy in edge_no_decomp_namespace) -> original aten op.
transform_op_to_aten_op: Dict[Any, Any] = {}
def _get_updated_range_constraints(gm):
def get_shape_env(gm):
vals = [
node.meta["val"]
for node in gm.graph.nodes
if node.meta.get("val", None) is not None
]
from torch._guards import detect_fake_mode # type: ignore[21]
fake_mode = detect_fake_mode(vals)
if fake_mode is not None:
return fake_mode.shape_env
for v in vals:
if isinstance(v, torch.SymInt):
return v.node.shape_env
shape_env = get_shape_env(gm)
if shape_env is None:
return {}
range_constraints = {
k: v
for k, v in shape_env.var_to_range.items()
if k not in shape_env.replacements
}
# Only when we have an unbacked symint, and it's used as constructor inputs,
# runtime_var_to_range will make a difference compated to var_to_range.
# e.g. [2, oo) -> [0, oo)
for k, v in shape_env.var_to_range.items():
if k not in shape_env.replacements:
range_constraints[k] = v
return range_constraints
def _get_updated_graph_signature(
old_signature: ExportGraphSignature,
new_gm: torch.fx.GraphModule,
) -> ExportGraphSignature:
"""
Update the graph signature's user_input/user_outputs.
"""
new_input_specs = []
i = 0
for node in new_gm.graph.nodes:
if node.op != "placeholder":
continue
assert i < len(
old_signature.input_specs
), "Number of inputs changed after transformation"
old_input_spec = old_signature.input_specs[i]
arg = (
old_input_spec.arg
if isinstance(old_input_spec.arg, ConstantArgument)
# pyre-fixme[20]: Argument `class_fqn` expected.
else type(old_input_spec.arg)(node.name)
)
new_input_specs.append(
InputSpec(
old_input_spec.kind,
arg,
old_input_spec.target,
persistent=old_input_spec.persistent,
)
)
i += 1
output_node = list(new_gm.graph.nodes)[-1]
assert output_node.op == "output"
new_output_specs = []
for i, node in enumerate(output_node.args[0]):
assert i < len(
old_signature.output_specs
), "Number of outputs changed after transformation"
old_output_spec = old_signature.output_specs[i]
arg = (
old_output_spec.arg
if isinstance(old_output_spec.arg, ConstantArgument)
# pyre-fixme[20]: Argument `class_fqn` expected.
else type(old_output_spec.arg)(node.name)
)
new_output_specs.append(
OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
)
new_signature = ExportGraphSignature(
input_specs=new_input_specs, output_specs=new_output_specs
)
return new_signature
def _transform(self, *passes: PassType) -> "ExportedProgram":
    """Run ``passes`` over ``self.graph_module`` and return the resulting program.

    Module-level helper: ``self`` is the ``ExportedProgram`` to transform.

    Returns:
        ``self`` unchanged when the pass manager produced no result, or when it
        returned the same graph module and reported no modification. Otherwise
        a new ``ExportedProgram`` around the transformed graph module, with an
        updated graph signature and range constraints, and with graph-module
        meta merged from the original and then the transformed module.
    """
    pm = PassManager(list(passes))
    res = pm(self.graph_module)
    # Treat a missing PassResult as "nothing changed" instead of crashing on
    # `res.modified` / `res.graph_module` below.
    if res is None:
        return self
    transformed_gm = res.graph_module
    assert transformed_gm is not None
    if transformed_gm is self.graph_module and not res.modified:
        return self
    transformed_ep = ExportedProgram(
        root=transformed_gm,
        graph=transformed_gm.graph,
        graph_signature=_get_updated_graph_signature(
            self.graph_signature, transformed_gm
        ),
        state_dict=self.state_dict,
        range_constraints=_get_updated_range_constraints(transformed_gm),
        module_call_graph=copy.deepcopy(self._module_call_graph),
        example_inputs=self.example_inputs,
        constants=self.constants,
        verifiers=[self.verifier],
    )
    transformed_ep.graph_module.meta.update(self.graph_module.meta)
    transformed_ep.graph_module.meta.update(transformed_gm.meta)
    return transformed_ep
def _copy_module(new_prog, new_gm):
    """Make the graph module ``new_prog`` mirror ``new_gm`` in place.

    Merges ``new_gm.meta`` into ``new_prog.meta``, adopts ``new_gm``'s graph,
    replaces ``new_prog``'s child modules with those of ``new_gm``, and copies
    every tensor that a ``get_attr`` node names directly on ``new_gm``.
    """
    new_prog.meta.update(new_gm.meta)
    new_prog.graph = new_gm.graph

    # Take a snapshot of the child names first so deletion doesn't disturb iteration.
    stale_children = tuple(child_name for child_name, _ in new_prog.named_children())
    for child_name in stale_children:
        delattr(new_prog, child_name)
    for child_name, child in new_gm.named_children():
        setattr(new_prog, child_name, child)

    attr_targets = (n.target for n in new_gm.graph.nodes if n.op == "get_attr")
    for target in attr_targets:
        value = getattr(new_gm, target, None)
        if isinstance(value, torch.Tensor):
            setattr(new_prog, target, value)
def lift_constant_tensor_pass(ep):
"""
Takes an ExportedProgram and returns the ExportedProgram modified in-place,
with the constant tensors as buffers.
"""
if len([node for node in ep.graph.nodes if node.op == "placeholder"]) == 0:
return ep
graph_signature = ep.graph_signature
buffers = list(graph_signature.buffers)
fake_mode = list(ep.graph.nodes)[0].meta["val"].fake_mode
first_user_input = None
lifted_constants = []
for node in ep.graph.nodes:
if node.op == "placeholder" and node.name in graph_signature.user_inputs:
first_user_input = node
break
for node in ep.graph.nodes:
if node.op == "get_attr":
constant_tensor = getattr(ep.graph_module, node.target)
if not isinstance(constant_tensor, torch.Tensor):
continue
constant_tensor_fqn = f"_lifted_tensor_constant{len(buffers)}"
with ep.graph.inserting_before(first_user_input):
# Insert the constant node before the first user input
const_placeholder_node = ep.graph.placeholder(constant_tensor_fqn)
for k, v in node.meta.items():
const_placeholder_node.meta[k] = v
if fake_mode is not None:
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
constant_tensor, static_shapes=True
)
else:
const_placeholder_node.meta["val"] = constant_tensor
const_placeholder_node.meta["val"].constant = constant_tensor
node.replace_all_uses_with(const_placeholder_node)
ep.graph.erase_node(node)
# Add the constant as a buffer to the graph signature
lifted_constants.append(
InputSpec(
kind=InputKind.BUFFER,
arg=TensorArgument(name=const_placeholder_node.name),
target=constant_tensor_fqn,
persistent=True,
)
)
buffers.append(constant_tensor_fqn)
ep.state_dict[constant_tensor_fqn] = constant_tensor
new_input_specs = []
for s in graph_signature.input_specs:
if s.kind == InputKind.USER_INPUT and len(lifted_constants) > 0:
new_input_specs.extend(lifted_constants)
lifted_constants.clear()
new_input_specs.append(s)
ep.graph_signature.input_specs = new_input_specs
ep.graph_module.recompile()
return ep
# Stub to ease migration from `transform` to private `_transform`
def transform_exported_program(ep, *passes: PassType) -> ExportedProgram:
    """Apply ``passes`` to ``ep`` via whichever transform API it exposes.

    Uses the private ``ep._transform`` when the attribute exists, and falls
    back to the public ``ep.transform`` otherwise.
    """
    if not hasattr(ep, "_transform"):
        return ep.transform(*passes)
    return ep._transform(*passes)
class HackedUpExportedProgramDONOTUSE(ExportedProgram):
    """ExportedProgram subclass whose ``__call__`` interprets the graph directly.

    As the name says, this is a hack and is not meant to be used.
    """

    def __init__(
        self,
        root,
        graph,
        graph_signature,
        call_spec,
        state_dict,
        range_constraints,
        module_call_graph,
        example_inputs,
        verifier,
    ):
        # NOTE(review): `call_spec` is accepted but neither stored nor forwarded;
        # `__call__` reads `self.call_spec`, which presumably comes from the
        # ExportedProgram base class — confirm.
        # NOTE(review): this passes `verifier=`, whereas the rest of this file
        # constructs ExportedProgram with `verifiers=[...]`; confirm the installed
        # torch version still accepts this keyword.
        super().__init__(
            root=root,
            graph=graph,
            graph_signature=graph_signature,
            state_dict=state_dict,
            range_constraints=range_constraints,
            module_call_graph=module_call_graph,
            example_inputs=example_inputs,
            verifier=verifier,
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the graph on ``args`` with ``torch.fx.Interpreter``.

        If ``call_spec.in_spec`` is set, the user args are flattened with it.
        Parameters, then buffers, from ``state_dict`` are prepended to the
        inputs. If ``call_spec.out_spec`` is set:
          * the leading outputs for mutated buffers are stripped and written
            back into ``state_dict`` (in a ``finally``, so even if
            unflattening fails);
          * outputs from the assertion dependency token onwards are dropped;
          * the rest is unflattened with ``out_spec``.
        ``kwargs`` are not used.

        Raises:
            torch._export.error.InternalError: if inputs or outputs don't match
                the recorded tree specs.
        """
        import torch._export.error as error

        if self.call_spec.in_spec is not None:
            user_args = args
            try:
                args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec)  # type: ignore[assignment]
            except Exception:
                _, received_spec = pytree.tree_flatten(user_args)
                raise error.InternalError(
                    "Trying to flatten user inputs with exported input tree spec: \n"
                    f"{self.call_spec.in_spec}\n"
                    "but actually got inputs with tree spec of: \n"
                    f"{received_spec}"
                )

        ordered_params = tuple(
            self.state_dict[name] for name in self.graph_signature.parameters
        )
        ordered_buffers = tuple(
            self.state_dict[name] for name in self.graph_signature.buffers
        )

        with torch.no_grad():
            # NOTE: calling convention is first params, then buffers, then args as user supplied them.
            # See: torch/_functorch/aot_autograd.py#L1034
            res = torch.fx.Interpreter(self.graph_module).run(
                *ordered_params, *ordered_buffers, *args, enable_io_processing=False
            )

        if self.call_spec.out_spec is not None:
            # Outputs start with one value per mutated buffer.
            mutation = self.graph_signature.buffers_to_mutate
            num_mutated = len(mutation)
            mutated_buffers = res[:num_mutated]

            # Exclude dependency token from final result.
            assertion_dep_token = self.graph_signature.assertion_dep_token
            if assertion_dep_token is not None:
                # The token's output index; everything from there on is cut.
                assertion_dep_token_index = list(assertion_dep_token.keys())[0]
                res = res[:assertion_dep_token_index]

            res = res[num_mutated:]
            try:
                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise error.InternalError(
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{self.call_spec.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
            finally:
                # Write mutated buffer values back, in buffers_to_mutate order.
                ix = 0
                for buffer in self.graph_signature.buffers_to_mutate.values():
                    self.state_dict[buffer] = mutated_buffers[ix]
                    ix += 1

        return res
@compatibility(is_backward_compatible=False)
class ExirExportedProgram:
    """Legacy wrapper pairing a torch ``ExportedProgram`` with a to_edge flag.

    ``after_to_edge_passes`` records whether ``to_edge`` has run.
    ``to_executorch`` refuses to lower a program for which it is False.
    """

    def __init__(
        self,
        exported_program: ExportedProgram,
        after_to_edge_passes: bool,
    ):
        self.exported_program = exported_program
        # Flag denoting whether to_edge has been called on this program, so
        # misuse (calling to_executorch directly without to_edge) is detected.
        self.after_to_edge_passes = after_to_edge_passes

    def transform(self, *passes: PassType) -> "ExirExportedProgram":
        """Apply ``passes`` to the wrapped program (replacing it) and return ``self``."""
        self.exported_program = _transform(self.exported_program, *passes)
        return self

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped program through ``exported_program.module()``.

        Both positional and keyword arguments are forwarded.
        """
        return self.exported_program.module()(*args, **kwargs)

    # TODO(ycao): Change this to a composable function.
    def to_edge(
        self, config: Optional[EdgeCompileConfig] = None
    ) -> "ExirExportedProgram":
        """Lower to Edge dialect using ``config`` (default ``EdgeCompileConfig()``)."""
        config = config or EdgeCompileConfig()
        assert isinstance(
            self.exported_program.graph_module, torch.fx.GraphModule
        ), f"type is instead: {type(self.exported_program.graph_module).__name__}"

        return _to_edge(self, config)

    def dump(self) -> None:
        """Print the wrapped program's graph."""
        print(self.exported_program.graph_module.graph)

    def to_executorch(
        self,
        config: Optional[ExecutorchBackendConfig] = None,
    ) -> "ExecutorchProgram":
        """Lower this edge program to an ``ExecutorchProgram``.

        Runs ``edge_to_executorch_passes(config)`` and then
        ``config.memory_planning_pass`` on the graph module, installs the
        result into a deep copy of this program, and wraps that copy in an
        ``ExecutorchProgram`` configured from ``config`` (default
        ``ExecutorchBackendConfig()``). Graph-module meta is merged from the
        planned module and then from this program's module.

        Raises:
            RuntimeError: if ``to_edge`` has not been run on this program.
        """
        if not self.after_to_edge_passes:
            raise RuntimeError("Must run to_edge before to_executorch.")
        config = config or ExecutorchBackendConfig()
        new_gm = self.exported_program.graph_module
        for p in edge_to_executorch_passes(config):
            new_gm_res = p(new_gm)
            assert new_gm_res is not None
            new_gm = new_gm_res.graph_module

        # This is tech debt on tech debt. memory planning pass inherits from some pass infra for GMs.
        # This isnt enough info now so i cant use call I have to use some new function 'run'.
        # Existing user passes dont use run so Im just cheating here because they dont need to work on mutable buffers yet.
        # After exir.capture is gone I will clean up the memory planning infra to be consistent.
        # Frankly all of exir has big code quality issues because of the migrations that need to be addressed.
        new_gm_res = config.memory_planning_pass(new_gm)  # pyre-ignore[29]
        assert new_gm_res is not None
        new_gm = new_gm_res.graph_module
        new_prog = ExirExportedProgram(
            copy.deepcopy(self.exported_program), self.after_to_edge_passes
        )
        _copy_module(new_prog.exported_program.graph_module, new_gm)
        executorch_prog = ExecutorchProgram(
            new_prog,
            emit_stacktrace=config.emit_stacktrace,
            extract_delegate_segments=config.extract_delegate_segments,
            segment_alignment=config.segment_alignment,
            constant_tensor_alignment=config.constant_tensor_alignment,
            delegate_alignment=config.delegate_alignment,
        )
        executorch_prog.graph_module.meta.update(new_gm.meta)
        executorch_prog.graph_module.meta.update(
            self.exported_program.graph_module.meta
        )
        return executorch_prog

    def __deepcopy__(
        self, memo: Optional[Dict[int, Any]] = None
    ) -> "ExirExportedProgram":
        """Deep-copy the wrapped program; the to_edge flag is copied as-is."""
        new_eep = ExirExportedProgram(
            copy.deepcopy(self.exported_program, memo),
            self.after_to_edge_passes,
        )
        return new_eep
@compatibility(is_backward_compatible=False)
class ExecutorchProgram:
    """Serializable ExecuTorch program built from an edge-dialect program.

    Emission (``emit_program``) and serialization (``serialize_for_executorch``)
    happen lazily on first use, and their results are cached on the instance.
    """

    def __init__(
        self,
        exir_exported_program: ExirExportedProgram,
        emit_stacktrace: bool,
        extract_delegate_segments: bool,
        segment_alignment: int,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
    ) -> None:
        if not exir_exported_program.after_to_edge_passes:
            raise RuntimeError(
                "Need to call prog.to_edge prior to constructing ExecutorchProgram."
            )
        self.exported_program = exir_exported_program.exported_program
        self._pte_data: Optional[Cord] = None
        self._tensor_data: Optional[Dict[str, Cord]] = None
        self._buffer: Optional[bytes] = None
        self._emitter_output: Optional[EmitterOutput] = None
        self._emit_stacktrace: bool = emit_stacktrace
        self._extract_delegate_segments: bool = extract_delegate_segments
        self._segment_alignment: int = segment_alignment
        self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment
        self._delegate_alignment: Optional[int] = delegate_alignment
        self._data_serializer: DataSerializer = FlatTensorSerializer()

    def _get_emitter_output(self) -> EmitterOutput:
        """Emit the program once and cache the result."""
        if self._emitter_output is None:
            self._emitter_output = emit_program(
                self.exported_program, self._emit_stacktrace
            )
        return self._emitter_output

    def _get_serialize_config(self) -> ExecutorchBackendConfig:
        """Build the backend config used for serialization from this instance's options.

        Serialization reads segment/alignment settings from this config, so the
        options given to ``__init__`` must be forwarded rather than using
        library defaults.
        """
        options: Dict[str, Any] = {
            "emit_stacktrace": self._emit_stacktrace,
            "extract_delegate_segments": self._extract_delegate_segments,
            "segment_alignment": self._segment_alignment,
        }
        # Only override the optional alignments when explicitly set, so the
        # config's own defaults apply otherwise.
        if self._constant_tensor_alignment is not None:
            options["constant_tensor_alignment"] = self._constant_tensor_alignment
        if self._delegate_alignment is not None:
            options["delegate_alignment"] = self._delegate_alignment
        return ExecutorchBackendConfig(**options)

    def _get_pte_data(self) -> Cord:
        """Serialize the program (and external tensor data) once and cache both."""
        if self._pte_data is None:
            self._pte_data, self._tensor_data = serialize_for_executorch(
                self._get_emitter_output(),
                self._get_serialize_config(),
                self._data_serializer,
            )
        assert self._pte_data is not None
        return self._pte_data

    @property
    def buffer(self) -> bytes:
        """Returns the serialized ExecuTorch binary as a byte string.

        Note that the call to `buffer` may allocate a very large amount of
        contiguous memory, depending on the model size. If writing to a file,
        use `write_to_file` which won't incur additional copies.
        """
        # TODO(T181494963): update pybinding to remove buffer cache, which can consume large
        # amounts of memory longer than necessary.
        if self._buffer is None:
            self._buffer = bytes(self._get_pte_data())
        return self._buffer

    @property
    def program(self) -> Program:
        """The emitted ``Program`` (emits on first access)."""
        return self._get_emitter_output().program

    @property
    def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
        """Debug handle map from emission; empty until the program has been emitted."""
        if self._emitter_output is not None:
            return self._emitter_output.debug_handle_map
        return {}

    @property
    def delegate_map(
        self,
    ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
        """Per-method delegate debug-id map; empty until the program has been emitted."""
        if self._emitter_output is not None:
            return self._emitter_output.method_to_delegate_debug_id_map
        return {}

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.exported_program.graph_module

    # TODO (zhxchen17) Change this to property.
    def dump_graph_module(self) -> torch.fx.GraphModule:
        return self.exported_program.graph_module

    def dump_exported_program(self) -> ExportedProgram:
        return self.exported_program

    def write_to_file(self, open_file: io.BufferedIOBase) -> None:
        """
        Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over
        `buffer`, as it writes to file without copying into a contiguous block of memory first,
        reducing the peak memory usage.
        """
        self._get_pte_data().write_to_file(open_file)

    def write_tensor_data_to_file(self, outdir) -> None:
        """
        Writes the serialized ExecuTorch data files to the directory at `outdir`.

        Each entry produced by serialization is written to ``<outdir>/<name>.ptd``.
        Serialization runs here if it hasn't happened yet, so this may be
        called before ``buffer`` / ``write_to_file``.
        """
        self._get_pte_data()
        assert self._tensor_data is not None
        # pyre-ignore[16]: `Optional` has no attribute `items`.
        for filename, cord in self._tensor_data.items():
            with open(os.path.join(outdir, f"{filename}.ptd"), "wb") as f:
                logging.info("Writing data file to %s.ptd", filename)
                cord.write_to_file(f)
def _get_aten_to_edge_passes(config: EdgeCompileConfig):
    """Return the ``(pre_op_replace_passes, post_op_replace_passes)`` lists for aten -> edge.

    The pre list is ``base_pre_op_replace_passes`` plus a
    ``RemoveMixedTypeOperators()`` pass unless ``config._skip_type_promotion``
    is set. The post list is ``base_post_op_replace_passes``.
    """
    # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
    # use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
    # well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
    # It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost.
    type_promotion_passes = (
        [RemoveMixedTypeOperators()] if not config._skip_type_promotion else []
    )
    return (
        base_pre_op_replace_passes + type_promotion_passes,
        base_post_op_replace_passes,
    )
def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
    """Lower the ``ExirExportedProgram`` ``ep`` to Edge dialect (legacy capture path).

    Steps:
      1. If ``config._check_ir_validity``, verify the graph is core ATen
         (logging remediation hints and re-raising ``ExportError`` on failure).
      2. If the program's dialect is "ATEN", re-wrap it with an ATen verifier.
      3. Run the pre-op-replace passes on a deep copy; for ATen input, lift
         tensor constants to buffers.
      4. Run OpReplacePass (when ``config._use_edge_ops``; MemoryFormatOpsPass
         then also runs unless ``config._skip_dim_order``), followed by the
         post-op-replace passes.
      5. Rebuild the ExportedProgram with an Edge-dialect verifier.

    Returns the new wrapper with ``after_to_edge_passes`` set to True.
    """

    def _run(graph_pass, gm):
        # Each pass returns a PassResult; keep its (possibly new) graph module.
        result = graph_pass(gm)
        assert result is not None
        return result.graph_module

    if config._check_ir_validity:
        try:
            EXIRATenDialectVerifier()(ep.exported_program.graph_module)
        except ExportError:
            logging.info(
                "If a particular operator failed core ATen IR check, please consider adding it to the exception list. "
                "Add the operator to _core_aten_ops_exception_list in EdgeCompileConfig. This is the recommended way "
                "to resolve this type of failure, so that the rest of the IR validation check can still be performed.\n"
                "If you'd like to disable IR validation checking, please set _check_ir_validity in EdgeCompileConfig, "
                "like *.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))."
            )
            raise

    is_aten = ep.exported_program.dialect == "ATEN"
    if is_aten:
        source = ep.exported_program
        ep = ExirExportedProgram(
            ExportedProgram(
                root=source.graph_module,
                graph=source.graph_module.graph,
                graph_signature=source.graph_signature,
                state_dict=source.state_dict,
                range_constraints=source.range_constraints,
                module_call_graph=source.module_call_graph,
                example_inputs=source.example_inputs,
                constants=source.constants,
                verifiers=[get_aten_verifier(config=config)],
            ),
            False,
        )

    pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)
    new_ep = copy.deepcopy(ep).transform(*pre_op_replace_passes)
    if is_aten:
        new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program)

    edge_gm = new_ep.exported_program.graph_module
    if config._use_edge_ops:
        edge_gm = _run(OpReplacePass(), edge_gm)
        # MemoryFormatOpsPass is only applied when edge ops are in use.
        if not config._skip_dim_order:
            edge_gm = _run(MemoryFormatOpsPass(), edge_gm)

    for post_pass in post_op_replace_passes:
        edge_gm = _run(post_pass, edge_gm)

    current = new_ep.exported_program
    new_ep.exported_program = ExportedProgram(
        root=edge_gm,
        graph=edge_gm.graph,
        graph_signature=_get_updated_graph_signature(current.graph_signature, edge_gm),
        state_dict=current.state_dict,
        range_constraints=current.range_constraints,
        module_call_graph=current.module_call_graph,
        example_inputs=current.example_inputs,
        constants=current.constants,
        verifiers=[
            EXIREdgeDialectVerifier(
                edge_compile_config=config,
                class_only=True,
            )
        ],
    )
    new_ep.after_to_edge_passes = True
    return new_ep
def pre_memory_planning_passes(
    config: ExecutorchBackendConfig, name: Optional[str] = None
) -> List[PassType]:
    """
    Returns a list of passes to run before memory planning.

    Get the sym shape eval pass based on the method name, if the pass is not in the dict, use the default pass.
    When ``config.sym_shape_eval_pass`` is a dict and ``name`` is empty or
    ``None``, the default pass (from a fresh ``ExecutorchBackendConfig``) is
    used. When ``config.remove_view_copy`` is set, view-copy normalization,
    dead code elimination and view-copy replacement run first. The list always
    ends with the sym shape eval pass and ``config.to_out_var_pass``.

    Raises:
        RuntimeError: if ``config.sym_shape_eval_pass`` is neither a dict nor a PassBase.
    """
    # Handle symbolic shape eval pass
    if isinstance(config.sym_shape_eval_pass, dict):
        default_pass = ExecutorchBackendConfig().sym_shape_eval_pass
        if not name:
            # No method name to look up: use the default pass.
            sym_shape_eval_pass = default_pass
        else:
            # pyre-ignore: Undefined attribute [16]
            sym_shape_eval_pass = config.sym_shape_eval_pass.get(name, default_pass)
    elif isinstance(config.sym_shape_eval_pass, PassBase):
        sym_shape_eval_pass = config.sym_shape_eval_pass
    else:
        raise RuntimeError(
            f"sym_shape_eval_pass must be a dict or a PassBase, got {config.sym_shape_eval_pass}"
        )
    if config.remove_view_copy:
        return [
            NormalizeViewCopyBasePass(),
            dead_code_elimination_pass,
            ReplaceViewCopyWithViewPass(),
            sym_shape_eval_pass,
            config.to_out_var_pass,
        ]
    else:
        return [
            sym_shape_eval_pass,
            config.to_out_var_pass,
        ]
def edge_to_executorch_passes(
    config: ExecutorchBackendConfig, name: Optional[str] = None
) -> List[PassType]:
    """
    Returns the ordered passes that lower an edge program to ExecuTorch.

    The order is: ``config.passes``, ``SpecPropPass``, ``EdgeToBackendOpsPass``,
    ``RemoveGraphAssertsPass``, then ``pre_memory_planning_passes(config, name)``,
    whose sym shape eval pass may be chosen per method ``name``.
    """
    passes: List[PassType] = list(config.passes)
    passes.append(SpecPropPass())
    # ExecuTorch backend ops are unable to handle unbacked symints. So after
    # this pass, passes cannot be Interpreter-based, because it will fail if
    # there exists an unbacked symint operation.
    passes.append(EdgeToBackendOpsPass())
    passes.append(RemoveGraphAssertsPass())
    passes.extend(pre_memory_planning_passes(config, name))
    return passes
def _generate_edge_program(
    name: str,
    config: EdgeCompileConfig,
    program: ExportedProgram,
    ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
) -> ExportedProgram:
    """Convert the ATen-dialect ``program`` for method ``name`` into an Edge-dialect program.

    Steps:
      1. Strip non-core-ATen graph asserts.
      2. Optionally verify ATen validity (``config._check_ir_validity``).
      3. Run the view-copy, pre-op-replace, OpReplace and MemoryFormatOps passes.
      4. Wrap the result with an Edge verifier.
      5. Lift tensor constants, then apply the post-op-replace passes.

    ``ops_set_to_not_decompose`` is passed to both verifiers as their
    exception list.

    Raises:
        ExportError: if validity checking is on and the graph is not in ATen dialect.
    """
    # Remove invalid assert ops, such as _assert_tensor_metadata
    assert_result = RemoveNonCoreAtenOpGraphAssertsPass()(program.graph_module)
    assert assert_result is not None
    graph_module = assert_result.graph_module

    if config._check_ir_validity:
        try:
            EXIRATenDialectVerifier(
                edge_compile_config=config,
                class_only=False,
                exception_list=ops_set_to_not_decompose,
            )(graph_module)
        except ExportError as e:
            logging.info(f"Input program {name} is not in ATen dialect.")
            raise e

    pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)

    # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
    edge_passes = [ReplaceViewOpsWithViewCopyOpsPass(), *pre_op_replace_passes]
    if config._use_edge_ops:
        edge_passes.append(OpReplacePass())
        # MemoryFormatOpsPass is only added when edge ops are in use.
        if not config._skip_dim_order:
            edge_passes.append(MemoryFormatOpsPass())

    for graph_pass in edge_passes:
        pass_result = graph_pass(graph_module)
        assert pass_result is not None
        graph_module = pass_result.graph_module

    edge_program = ExportedProgram(
        root=graph_module,
        graph=graph_module.graph,
        graph_signature=_get_updated_graph_signature(
            program.graph_signature, graph_module
        ),
        state_dict=program.state_dict,
        range_constraints=program.range_constraints,
        module_call_graph=program.module_call_graph,
        example_inputs=program.example_inputs,
        constants=program.constants,
        verifiers=[
            EXIREdgeDialectVerifier(
                edge_compile_config=config,
                class_only=True,
                exception_list=ops_set_to_not_decompose,
            )
        ],
    )
    # Lift the tensor constants created in ScalarToTensorPass
    edge_program = lift_constant_tensor_pass(edge_program)
    return _transform(edge_program, *post_op_replace_passes)
def _replace_aten_ops_with_transformed_ops(
name: str,
program: ExportedProgram,
partitioner,
):
ops_to_not_decompose = set()
partitioners = partitioner.get(name)
if partitioners is None:
return
# Iterate through the graph and replace the aten ops with the corresponding
# transformed ops.
for partitioner in partitioners:
ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose(
program
)
for op_aten in ops_set_to_not_decompose:
_register_no_decomp_op(op_aten)
for node in program.graph.nodes:
is_op_supported = check_op_support(node) if check_op_support else True
if (
node.op == "call_function"
and node.target in ops_set_to_not_decompose
and is_op_supported
):
ops_to_not_decompose.add(node.target)
node.target = aten_op_to_transform_op[node.target]
for _, submod, _ in get_control_flow_submodules(program.graph_module):
for node in submod.graph.nodes:
is_op_supported = check_op_support(node) if check_op_support else True
if (
node.op == "call_function"
and node.target in ops_set_to_not_decompose
and is_op_supported
):
ops_to_not_decompose.add(node.target)
node.target = aten_op_to_transform_op[node.target]
return ops_to_not_decompose
def _restore_transformed_ops_to_aten_ops(program: ExportedProgram):
    """Point transformed ops in ``program`` back at their aten ops, in place.

    Any ``call_function`` node, in the main graph or in a control-flow
    submodule, whose ``str(target)`` is a key of ``transform_op_to_aten_op``
    is retargeted to the mapped aten op.
    """

    def _restore(nodes):
        for node in nodes:
            if node.op != "call_function":
                continue
            key = str(node.target)
            if key in transform_op_to_aten_op:
                node.target = transform_op_to_aten_op[key]

    _restore(program.graph.nodes)
    for _, submod, _ in get_control_flow_submodules(program.graph_module):
        _restore(submod.graph.nodes)
def _get_transformed_op(op_aten):
    """Return the counterpart of ``op_aten`` registered in ``edge_no_decomp_namespace``.

    Resolves ``torch.ops.<edge_no_decomp_namespace>.<op name>.<overload name>``
    using the operator name (with its namespace prefix removed) and the
    overload name from ``op_aten._schema``.
    """
    schema = op_aten._schema
    base_name = schema.name.split("::")[1]
    overload = schema.overload_name
    assert hasattr(
        torch.ops, edge_no_decomp_namespace
    ), f"Couldn't find {edge_no_decomp_namespace} in torch.ops. Please make sure the Library has been registered."
    namespace_ops = getattr(torch.ops, edge_no_decomp_namespace)
    packet = getattr(namespace_ops, base_name)
    return getattr(packet, overload)
def _register_no_decomp_op(op_aten):
    """Register a copy of ``op_aten`` in ``edge_no_decomp_namespace``, at most once per op.

    Only ``torch._ops.OpOverload`` inputs that are not already cached are
    handled; anything else is ignored. The copy is defined with the aten
    schema (namespace prefix removed) and implemented by ``op_aten`` itself
    under the ``CompositeExplicitAutograd`` key. The pair is then cached in
    both ``aten_op_to_transform_op`` and ``transform_op_to_aten_op``.
    """
    already_cached = aten_op_to_transform_op.get(op_aten) is not None
    if already_cached or not isinstance(op_aten, torch._ops.OpOverload):
        return

    schema = op_aten._schema
    # Define the op in the no-decomp namespace using the aten schema text.
    lib.define(str(schema).split("::")[1])

    # The implementation is the aten op itself, registered under "name" or
    # "name.overload".
    impl_name = schema.name.split("::")[1]
    if schema.overload_name != "":
        impl_name = f"{impl_name}.{schema.overload_name}"
    lib.impl(impl_name, op_aten, "CompositeExplicitAutograd")

    transformed = _get_transformed_op(op_aten)
    aten_op_to_transform_op[op_aten] = transformed
    transform_op_to_aten_op[str(transformed)] = op_aten
def _sanity_check_graph_for_non_decomp_ops(
    name: str,
    program: ExportedProgram,
    ops_set_to_not_decompose,
    check_op_support,
    generate_error=False,
    partitioner_name=None,
):
    """Report ops registered as "do not decompose" that are still in the graph.

    Ops a partitioner asked to keep un-decomposed are expected to have been
    consumed by transform passes or backends by the time this check runs. Any
    remaining ``call_function`` node is reported if both of these hold:
      - its target is in ``ops_set_to_not_decompose``, in either aten form or
        ``aten_to_edge`` form;
      - ``check_op_support`` accepts it.
    The top-level graph and all control-flow submodule graphs are scanned.

    Args:
        name: Name of the program/method being checked (used in the message).
        program: Exported program to scan.
        ops_set_to_not_decompose: Collection of aten ops registered to not be
            decomposed.
        check_op_support: Optional predicate ``node -> bool``. When falsy, every
            matching node counts as supported.
        generate_error: If True, raise on the first offending node instead of
            logging a warning.
        partitioner_name: Optional partitioner name included in the message.

    Raises:
        RuntimeError: If ``generate_error`` is True and an offending node exists.
    """
    warning_str_end = ""
    if partitioner_name is not None:
        warning_str_end += f"This op was registered by the partitioner {partitioner_name} to not be decomposed.\n"
    # The message lists the ops as passed in, before edge variants are merged.
    warning_str_end += f"The following ops: {ops_set_to_not_decompose} were specified to not be decomposed in {name}."
    # Check that the ops that were registered to not be decomposed are not present in the
    # graph anymore as the transform passes and backends should have consumed them by now.
    ops_set_to_not_decompose = {
        aten_to_edge(op) for op in ops_set_to_not_decompose
    }.union(ops_set_to_not_decompose)

    def _graphs():
        # Top-level graph first, then control-flow submodules (lazily).
        yield program.graph_module.graph
        for _, submod, _ in get_control_flow_submodules(program.graph_module):
            yield submod.graph

    for graph in _graphs():
        for node in graph.nodes:
            if node.op != "call_function" or node.target not in ops_set_to_not_decompose:
                continue
            # Consult the partitioner's predicate only for candidate nodes: it is
            # meant to refine preserved-op matches, and calling it on every
            # placeholder/output node is wasteful and may fail on nodes it was
            # not written to handle.
            if check_op_support and not check_op_support(node):
                continue
            warning_str = (
                f"Node {node} with op {node.target} was not decomposed or delegated.\n"
                + warning_str_end
            )
            if generate_error:
                raise RuntimeError(warning_str)
            logging.warning(warning_str)
def _gen_edge_manager_for_partitioners(
partitioner: Dict[str, List[Partitioner]],
aten_programs: Dict[str, ExportedProgram],
config: EdgeCompileConfig,
constant_methods: Optional[Dict[str, Any]],
) -> "EdgeProgramManager":
"""
Generates EdgeProgramManager for subsequent lowering to the
partitioners specified by partitioner. The EdgeProgramManager is generated from
aten_programs.
Partitioners specify what nodes should not be decomposed from the original aten programs.
This is done through two passes of run_decompositions.
- First pass preserves all aten_targets specified by partitioners to preserve
them from nested decompositions
- Second pass uses check_op fn provided by partitioners to perform additional checks
on nodes with preserved aten targets. They are then replaces with transformed ops to
keep them through the second pass of decompositions
"""
ops_set_to_not_decompose_by_program = {}
edge_programs: Dict[str, ExportedProgram] = {}
for name, program in aten_programs.items():
if partitioner is not None:
# preserve all ops listed by all partitioners first
all_ops_no_decomp = set()
for curr_partitioner in partitioner.get(name, []):
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
all_ops_no_decomp |= set(curr_ops_no_decomp)
table = _default_decomposition_table()
for op in all_ops_no_decomp:
table.pop(op, None)
program = program.run_decompositions(table)
# Among all the preserved aten ops, use the check_op_fn to do an additional
# check on which ops need to be preserved and which ops need to be decomposed
# Those which are truly preserved will be replaced with transformed ops
ops_set_to_not_decompose_by_program[name] = (
_replace_aten_ops_with_transformed_ops(name, program, partitioner) or []
)
program = program.run_decompositions(_default_decomposition_table())
_restore_transformed_ops_to_aten_ops(program)
edge_programs[name] = program
edge_programs[name] = _generate_edge_program(
name,
config,
program,