forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompiler.py
330 lines (278 loc) · 10.8 KB
/
compiler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import logging
from pathlib import Path
from typing import Callable, cast, Optional
import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
from executorch.backends.cadence.aot.memory_planning import (
CadenceMemoryPlanning,
print_memory_planning_info,
)
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import (
CadenceDefaultQuantizer,
CadenceQuantizer,
)
from executorch.backends.cadence.aot.utils import (
get_default_memory_config,
MemoryConfig,
)
from executorch.devtools import generate_etrecord
from executorch.exir import (
EdgeCompileConfig,
EdgeProgramManager,
ExecutorchBackendConfig,
ExecutorchProgramManager,
to_edge,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from executorch.exir.passes import ToOutVarPass
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
from torch._inductor.decomposition import remove_decompositions
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export
from torch.export.exported_program import ExportedProgram
from .passes import get_cadence_passes
from .utils import print_ops_info
# Note: this is not meant as a primary API since it can create inconsistencies
# if the quantizer here is different from the quantizer used to fuse. It is
# however useful for unit tests to separate the converted model from the fused
# model, to be able to get reference numerics.
# If this does not apply, please use quantize_pt2 instead.
def convert_pt2(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    quantizer: CadenceQuantizer,
) -> torch.fx.GraphModule:
    """
    Prepare and convert a model using the given quantizer.

    The quantizer must be supplied and be the same as the one used to
    fuse the model later, if applicable. If you do not expect that behavior,
    please use quantize_pt2 instead, which will instantiate a default
    quantizer for you if needed.

    Args:
        model: The model to quantize.
        inputs: Example inputs. They are used both to export the model and
            as the single calibration batch.
        quantizer: Quantizer handed to ``prepare_pt2e`` to annotate the graph.

    Returns:
        A GraphModule with the converted (not yet fused) model.
    """
    # Start from the default export decompositions...
    decomp_table = torch.export.default_decompositions()

    # ...then remove the entries for these ops so they are kept intact
    # (not decomposed) by run_decompositions below.
    ops_to_keep = [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.layer_norm.default,
        torch.ops.aten.linear.default,
        torch.ops.aten.matmul.default,
    ]
    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
    remove_decompositions(decomp_table, ops_to_keep)

    # Export with dynamo
    model_gm = (
        torch.export.export_for_training(model, inputs)
        .run_decompositions(decomp_table)
        .module()
    )

    # Prepare
    prepared_model = prepare_pt2e(model_gm, quantizer)

    # Calibrate: run the example inputs through the prepared model once.
    prepared_model(*inputs)

    # Convert
    converted_model = convert_pt2e(prepared_model)
    return converted_model
# Note: this is not meant as a primary API since it can create inconsistencies
# if the quantizer here is different from the quantizer used to convert. It is
# however useful for unit tests to separate the converted model from the fused
# model, to be able to get reference numerics.
# If this does not apply, please use quantize_pt2 instead.
def fuse_pt2(
    converted_graph_module: torch.fx.GraphModule,
    quantizer: CadenceQuantizer,
) -> torch.fx.GraphModule:
    """
    Fuse a converted graph module using the given quantizer.

    The quantizer must be the same as the one used to convert the model.
    If you do not expect that behavior, please use quantize_pt2 instead,
    which will instantiate a default quantizer for you if needed.

    Args:
        converted_graph_module: A converted model, e.g. the output of
            convert_pt2.
        quantizer: The quantizer used for conversion. The ``pattern`` of each
            of its sub-quantizers (``quantizer.quantizers``) drives the fusion.

    Returns:
        The same GraphModule object that was passed in. The return value of
        the QuantFusion call is discarded, so this relies on the pass
        mutating the module in place.
    """
    # Get patterns and apply fusion of dq -> op -> q to qop
    # pyre-ignore[16]: no attribute
    patterns = [q.pattern for q in quantizer.quantizers]
    QuantFusion(patterns)(converted_graph_module)
    return converted_graph_module
# Note: this is the one-liner API to quantize and fuse a model.
def quantize_pt2(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    quantizer: Optional[CadenceQuantizer] = None,
) -> torch.fx.GraphModule:
    """
    Prepare, convert and fuse the model using the given quantizer.

    Args:
        model: The model to quantize. It is switched to eval mode in place,
            which affects the caller's object.
        inputs: Example inputs used for export and calibration.
        quantizer: Quantizer to use. When None, a CadenceDefaultQuantizer is
            instantiated.

    Returns:
        A GraphModule with the quantized (converted and fused) model.
    """
    # Make the model inference mode by calling model.eval()
    model.eval()

    # Test for None explicitly instead of relying on the quantizer's truthiness.
    if quantizer is None:
        quantizer = CadenceDefaultQuantizer()

    # The same quantizer instance must be used for conversion and fusion.
    converted_gm = convert_pt2(model, inputs, quantizer)
    fused_gm = fuse_pt2(converted_gm, quantizer)
    return fused_gm
# Export the model and lower it to an ExportedProgram (in aten IR)
def export_program(
model: torch.nn.Module,
inputs: tuple[object, ...],
dump_graphs: bool = False,
) -> ExportedProgram:
assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
# Prevent mkldnn decompositions
torch._C._set_mkldnn_enabled(False)
# Export the model and return it.
expo_program = export(model, inputs, strict=True)
if dump_graphs:
logging.info("Exported graph:")
expo_program.graph_module.graph.print_tabular()
return expo_program
# Export the model and lower it to an EdgeProgramManager (in edge IR).
def export_to_edge(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    constant_methods: Optional[dict[str, object]] = None,
) -> EdgeProgramManager:
    """
    Export ``model`` to aten IR via export_program, then convert it to edge IR.

    Args:
        model: The module to export. Anything else fails an assertion.
        inputs: Example positional inputs for export.
        dump_graphs: When True, both the exported graph and the edge graph
            are printed.
        constant_methods: Forwarded unchanged to ``to_edge``.

    Returns:
        The EdgeProgramManager produced by ``to_edge``.
    """
    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

    aten_program = export_program(model, inputs, dump_graphs=dump_graphs)

    # Specific non-core aten ops that are allowed to remain in the IR.
    allowed_non_core_ops = [
        torch.ops.aten._native_batch_norm_legit_functional.default,
        torch.ops.aten.linear.default,
        torch.ops.aten.linalg_vector_norm.default,
        torch.ops.aten.unfold.default,
        torch.ops.aten.angle.default,
        # cadence replaced to_dim_order_copy with _to_copy for performance;
        # skip the _to_copy op to get around the dim order check.
        # We should remove this op once cadence can support dim order.
        exir_ops.edge.aten._to_copy.default,
    ]
    edge_config = EdgeCompileConfig(
        _core_aten_ops_exception_list=allowed_non_core_ops,
    )

    # Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
    manager = to_edge(
        aten_program,
        compile_config=edge_config,
        constant_methods=constant_methods,
    )

    if dump_graphs:
        logging.info("Edge graph:")
        manager.exported_program().graph_module.graph.print_tabular()

    return manager
def export_to_cadence(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    opt_level: int = 1,
) -> EdgeProgramManager:
    """
    Lower ``model`` to edge IR and apply the Cadence pass pipeline.

    Args:
        model: The module to export.
        inputs: Example positional inputs for export.
        dump_graphs: Forwarded to export_to_edge.
        opt_level: Selects the pass list via get_cadence_passes.

    Returns:
        The EdgeProgramManager returned by ``transform`` after the Cadence
        passes have run.
    """
    edge_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)

    passes = cast(
        list[Callable[[torch.fx.GraphModule], Optional[PassResult]]],
        get_cadence_passes(opt_level),
    )

    # Run a couple required passes for quant/dequant ops
    return edge_manager.transform(passes)
def quantize_and_export_to_cadence(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    opt_level: int = 1,
    quantizer: Optional[CadenceQuantizer] = None,
) -> EdgeProgramManager:
    """
    Quantize ``model`` with quantize_pt2, then lower it with export_to_cadence.

    Args:
        model: The model to quantize. quantize_pt2 switches it to eval mode.
        inputs: Example inputs used for quantization and export.
        dump_graphs: Forwarded to export_to_cadence.
        opt_level: Forwarded to export_to_cadence.
        quantizer: Optional quantizer forwarded to quantize_pt2. When None
            (the default, matching the previous behavior), quantize_pt2
            instantiates a CadenceDefaultQuantizer.

    Returns:
        The EdgeProgramManager produced by export_to_cadence.
    """
    quantized_model = quantize_pt2(model, inputs, quantizer)
    return export_to_cadence(
        quantized_model,
        inputs,
        opt_level=opt_level,
        dump_graphs=dump_graphs,
    )
# Export the model to edge IR, apply the passes specific to Cadence DSP
# execution, run Cadence memory planning and lower to an ExecuTorch program.
# Ops info for the edge graph and the Cadence-transformed graph is printed so
# the differences can be inspected; only the ExecuTorch program is returned.
def export_to_executorch_gen_etrecord(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    output_dir: Optional[str] = None,
    opt_level: int = 1,
    mem_algo: int = 0,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
    memory_config: Optional[MemoryConfig] = None,
    dump_graphs: bool = False,
) -> ExecutorchProgramManager:
    """
    Lower ``model`` to an ExecuTorch program for Cadence, optionally writing
    an ETRecord.

    Args:
        model: The module to export.
        inputs: Example positional inputs for export.
        output_dir: Directory to write ``etrecord.bin`` into. When falsy
            (None or ""), ETRecord generation is skipped with a warning.
        opt_level: Forwarded to get_cadence_passes, CadenceMemoryPlanning and
            print_memory_planning_info.
        mem_algo: Forwarded to CadenceMemoryPlanning.
        alloc_graph_input: Forwarded to CadenceMemoryPlanning and
            print_memory_planning_info.
        alloc_graph_output: Forwarded to CadenceMemoryPlanning and
            print_memory_planning_info.
        memory_config: Memory configuration for planning. Defaults to
            get_default_memory_config() when None.
        dump_graphs: Forwarded to export_to_edge.

    Returns:
        The ExecutorchProgramManager produced by ``to_executorch``.
    """
    cadence_passes = get_cadence_passes(opt_level)
    edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)

    # Run a couple required passes for quant/dequant ops
    cadence_prog_manager = edge_prog_manager.transform(
        cast(
            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]],
            cadence_passes,
        )
    )

    # Print some information to terminal
    print_ops_info(
        edge_prog_manager.exported_program().graph_module,
        cadence_prog_manager.exported_program().graph_module,
    )

    if memory_config is None:
        memory_config = get_default_memory_config()

    memory_planning_pass = CadenceMemoryPlanning(
        memory_config,
        opt_level=opt_level,
        mem_algo=mem_algo,
        alloc_graph_input=alloc_graph_input,
        alloc_graph_output=alloc_graph_output,
    )

    # Get executorch program after Cadence specific passes
    exec_prog: ExecutorchProgramManager = cadence_prog_manager.to_executorch(
        ExecutorchBackendConfig(
            memory_planning_pass=memory_planning_pass,
            emit_stacktrace=False,
            to_out_var_pass=ToOutVarPass(),
            extract_delegate_segments=False,
            sym_shape_eval_pass=HintBasedSymShapeEvalPass(),
        ),
    )

    print_memory_planning_info(
        exec_prog,
        memory_config,
        opt_level,
        alloc_graph_input,
        alloc_graph_output,
    )

    if output_dir:
        _gen_etrecord(edge_prog_manager, exec_prog, Path(output_dir))
    else:
        logging.warning("No output directory provided, skipping ETRecord generation")

    return exec_prog
def _gen_etrecord(
    edge_program: EdgeProgramManager,
    et_program: ExecutorchProgramManager,
    output_dir: Path,
) -> None:
    """
    Write ``etrecord.bin`` into ``output_dir`` for the given edge and
    ExecuTorch programs.

    This is best-effort. Any exception raised by ``generate_etrecord`` is
    logged with its traceback and swallowed, so it never propagates to the
    caller.
    """
    record_path = output_dir / "etrecord.bin"
    try:
        generate_etrecord(
            et_record=record_path,
            edge_dialect_program=edge_program,
            executorch_program=et_program,
        )
    except Exception:
        # Any errors here shouldn't block the rest of the flow
        logging.exception("Encountered exception while generating ETRecord")
    else:
        logging.info(f"Generated ETRecord at {record_path}")