-
Notifications
You must be signed in to change notification settings - Fork 499
/
Copy path__init__.py
100 lines (89 loc) · 3.67 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Type
from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
ChannelsLastTaggedReshapePass,
)
from executorch.backends.xnnpack._passes.conv1d_unsqueeze_pass import (
Conv1dUnsqueezePass,
)
from executorch.backends.xnnpack._passes.convert_squeeze_to_view_pass import (
ConvertSqueezeToViewPass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass
from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
ConvertToUpsampleBilinear2d,
)
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
FuseBatchNormWithConvPass,
)
from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
TagImplicitQDqPass,
)
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes.const_prop_pass import ConstPropPass
from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
from executorch.exir.program._program import _transform
from torch._export.pass_base import PassType
from torch.export import ExportedProgram
class XNNPACKPassManager:
    """Run a sequence of XNNPACK lowering passes over an ExportedProgram.

    Passes are applied in list order by ``transform``; each pass receives the
    program produced by the previous one, so the order of ``self.passes`` is
    significant.
    """

    def __init__(
        self,
        exported_program: ExportedProgram,
        passes: Optional[List[Type[PassType]]] = None,
    ) -> None:
        """
        A helper class to run multiple XNNPACK passes on a program.

        Args:
            exported_program: The program the passes are applied to.
            passes: Pass classes (not instances) to run, in order. If this is
                ``None`` or empty, the full default XNNPACK pass list is run;
                otherwise only the passes in this list are run.
        """
        self._exported_program = exported_program
        if not passes:
            # All the XNNPACK passes, in the order they are applied.
            self.passes = [
                # TODO - remove this pass once we have a better support for dim_order ops lowering
                DimOrderOpsRevertPass,
                ConvertToUpsampleBilinear2d,
                ConvertToLinearPass,
                ConvertToSDPAPass,
                ConstPropPass,
                FuseBatchNormWithConvPass,
                FuseActivationPass,
                DecomposeConcatenate,
                RemoveGetItemPass,
                Conv1dUnsqueezePass,
                ConvertSqueezeToViewPass,
                PReLUReshapePass,
                ChannelsLastTaggedReshapePass,
                TagImplicitQDqPass,
            ]
        else:
            self.passes = passes

    @property
    def exported_program(self) -> ExportedProgram:
        """The ExportedProgram this manager was constructed with."""
        return self._exported_program

    def transform(self) -> ExportedProgram:
        """
        Returns a transformed ExportedProgram.

        Each entry of ``self.passes`` is instantiated and applied with
        ``_transform``. The program returned for one pass is the input to the
        next, and the program produced by the last pass is returned.

        Raises:
            RuntimeError: If an entry is not a subclass of ``XNNPACKPass`` or
                ``ExportPass``, including entries that are not classes at all.
        """
        ep = self.exported_program
        for pass_ in self.passes:
            # issubclass() raises TypeError for non-class arguments (e.g. a
            # pass instance or a plain function). Guard so such entries reach
            # the descriptive RuntimeError below instead.
            is_class = isinstance(pass_, type)
            # XNNPACKPass is tested first so its subclasses get the program
            # argument (presumably XNNPACKPass derives from ExportPass, which
            # would make this order matter; confirm).
            if is_class and issubclass(pass_, XNNPACKPass):
                # XNNPACK passes are constructed with the current program.
                transform_pass = pass_(ep)
            elif is_class and issubclass(pass_, ExportPass):
                transform_pass = pass_()
            else:
                raise RuntimeError(
                    f"Expecting ExportPass or XNNPACKPass subclass, but got pass: {pass_} with type: {type(pass_)}"
                )
            ep = _transform(ep, transform_pass)
        return ep