-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patharm_backend_monkey_patch.py
More file actions
689 lines (607 loc) · 28.6 KB
/
arm_backend_monkey_patch.py
File metadata and controls
689 lines (607 loc) · 28.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
# SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# This source code is licensed under the BSD-style license found in the
# LicenseRef-BSD-ExecuTorch.txt file in the top-level directory.
# SPDX-FileCopyrightText: <text>Copyright 2025-2026 Arm Limited and/or
# its affiliates <open-source-office@arm.com></text>
# SPDX-License-Identifier: Apache-2.0 AND LicenseRef-BSD-ExecuTorch
import torch
def apply_arm_backend_monkey_patch() -> None:
    """Install every Arm backend workaround patch used by this project.

    Each workaround lives in its own ``_patch_*`` helper and is guarded by its
    own ``_*_patch_applied`` marker attribute on the object it patches, so this
    function is safe to call repeatedly. An already-applied patch only skips
    *itself*; it never prevents the remaining patches from being installed.
    (Previously five of the guards ``return``-ed from this whole function,
    so one pre-existing marker silently disabled every later patch.)

    The installation order is the same as the original implementation. In
    particular, the TOSA partition-tag ordering patch is installed before the
    VGF output-DQ-tail patch, which captures ``VgfPartitioner.partition`` as it
    exists when that helper runs.

    Raises:
        ImportError: if one of the patched ExecuTorch Arm modules cannot be
            imported (propagated unchanged from the helper's imports).
    """
    _patch_fuse_constant_list_arg_qparams()
    _patch_skip_factory_output_annotation()
    _patch_rewrite_conv_output_qparams()
    _patch_duplicate_output_transpose()
    _patch_exact_int8_table_domain()
    _patch_deterministic_partition_tag_order()
    _patch_rewrite_upsample_rescale_meta()
    _patch_vgf_output_dq_tail()


def _patch_fuse_constant_list_arg_qparams() -> None:
    """Make FuseConstantArgsPass dequantize list-valued tensor args correctly.

    Arm's FoldAndAnnotateQParamsPass stores ``input_qparams`` by top-level FX
    argument index. That is fine for normal tensor arguments, but
    FuseConstantArgsPass later interprets those keys as if they were indexed
    by flattened ``node.all_input_nodes``. For list-valued tensor arguments
    such as ``aten.cat([tensor0, tensor1], dim=...)``, only key 0 exists, so
    the first tensor is dequantized during constant folding. Later tensors in
    the same list are treated as raw int8 codes cast to float, which produces
    a bad fused constant before TOSA serialization.

    The fix resolves qparams using the parent argument index instead, so every
    tensor inside the same list inherits the same dequantization domain. This
    matches FoldAndAnnotateQParamsPass.

    Upstream issue: https://github.com/pytorch/executorch/issues/18971

    Repro script:
    investigations/practical_rife/analyze_practical_rife_quantized_cat_constant_fold_repro.py
    """
    from executorch.backends.arm._passes import fuse_constant_ops_pass
    from executorch.backends.arm._passes.arm_pass_utils import (
        get_constant_placeholder_kind,
        get_first_fake_tensor,
        get_param_tensor,
        is_persistent_buffer,
    )
    from executorch.backends.transforms.utils import create_constant_placeholder

    fuse_constant_pass = fuse_constant_ops_pass.FuseConstantArgsPass
    if getattr(fuse_constant_pass, "_list_arg_input_qparams_patch_applied", False):
        return

    tosa_special_dtype = fuse_constant_ops_pass.TosaSpecialDtype

    def patched_fuse_nodes(self, node) -> bool:
        """Fold *node* into a constant placeholder; return True if folded."""
        # Nodes tagged with the SHAPE special dtype are never folded.
        if (
            node.meta.get(tosa_special_dtype.meta_key(), None)
            == tosa_special_dtype.SHAPE
        ):
            return False

        input_nodes = list(node.all_input_nodes)
        qparams = node.meta.get("input_qparams", None)

        def resolve_arg(arg, *, arg_index):
            # ``arg_index`` is the *top-level* position in ``node.args`` (None
            # for kwargs). Items nested in a list/tuple reuse their parent's
            # index, so they share the parent's qparams.
            if isinstance(arg, torch.fx.Node) and arg in input_nodes:
                tensor = get_param_tensor(self.exported_program, arg)
                if qparams and arg_index is not None and arg_index in qparams:
                    tensor = qparams[arg_index].dequantize_value(tensor)
                return tensor
            if isinstance(arg, tuple):
                return tuple(resolve_arg(item, arg_index=arg_index) for item in arg)
            if isinstance(arg, list):
                return [resolve_arg(item, arg_index=arg_index) for item in arg]
            return arg

        new_args = tuple(
            resolve_arg(arg, arg_index=index) for index, arg in enumerate(node.args)
        )
        new_kwargs = {
            key: resolve_arg(value, arg_index=None)
            for key, value in node.kwargs.items()
        }
        data = node.target(*new_args, **new_kwargs)

        # Bail out when the eagerly computed result has more elements than the
        # node's recorded output fake tensor.
        if data.numel() > get_first_fake_tensor(node).numel():
            return False

        if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0:
            q_params = node.meta["output_qparams"][0]
            data = q_params.quantize_value(data)

        # The new constant copies the kind and persistence of the first input
        # and is inserted just before it.
        insert_pos = input_nodes[0]
        input_kind = get_constant_placeholder_kind(self.exported_program, insert_pos)
        persistent_buffer = is_persistent_buffer(self.exported_program, insert_pos)
        with node.graph.inserting_before(insert_pos):
            const_node = create_constant_placeholder(
                exp_program=self.exported_program,
                graph=node.graph,
                kind=input_kind,
                name=node.name + "_fused_const",
                data=data,
                persistent_buffer=persistent_buffer,
            )
        self._propagate_special_dtype(input_nodes, const_node, data)
        node.replace_all_uses_with(const_node)
        return True

    fuse_constant_pass._fuse_nodes = patched_fuse_nodes
    fuse_constant_pass._list_arg_input_qparams_patch_applied = True


def _patch_skip_factory_output_annotation() -> None:
    """Stop the Arm quantization annotator from annotating factory ops.

    Arm's generic quantization annotator marks constant/factory producers like
    ``aten.full.default`` as quantized outputs. PT2E prepare then processes
    those nodes directly and asserts, because exported factory ops still carry
    kwargs such as dtype/device/pin_memory. XNNPACK avoids this by leaving the
    factory node unannotated and letting annotated consumers insert observers
    at the boundary instead. This patch does the same.

    Upstream issue: https://github.com/pytorch/executorch/issues/18322
    """
    from executorch.backends.arm.quantizer import quantization_annotator

    if getattr(
        quantization_annotator, "_skip_factory_output_annotation_patch_applied", False
    ):
        return

    original_get_quant_properties = quantization_annotator.get_quant_properties
    # NOTE(review): torch.ops.aten.full is an OpOverloadPacket; FX call_function
    # targets are usually OpOverloads, so that entry may never match. It is
    # kept for parity with the original list.
    skipped_factory_targets = frozenset(
        {
            torch.ops.aten.full.default,
            torch.ops.aten.full,
            torch.ops.aten.zeros.default,
            torch.ops.aten.ones.default,
            torch.ops.aten.fill_.Scalar,
            torch.ops.aten.scalar_tensor.default,
        }
    )

    def patched_get_quant_properties(node, gm, quantization_config):
        # Returning None leaves the factory node unannotated.
        if node.target in skipped_factory_targets:
            return None
        return original_get_quant_properties(node, gm, quantization_config)

    quantization_annotator.get_quant_properties = patched_get_quant_properties
    quantization_annotator._skip_factory_output_annotation_patch_applied = True


def _patch_rewrite_conv_output_qparams() -> None:
    """Read conv output qparams from the original conv (or its clamp user).

    Arm's RewriteConvPass inserts a TOSA rescale using qparams read from the
    rewritten TOSA conv node itself. For non-fuseable conv -> clamp branches,
    FoldAndAnnotateQParamsPass can leave the output qparams on the clamp
    instead of the conv. The rewritten TOSA conv then appears to have no
    output qparams, and lowering fails.

    The fix uses the original conv as the qparam source and falls back to its
    immediate clamp user when needed. ``RewriteConvPass.call`` is replaced with
    a copy that passes ``source_node=node`` to ``insert_output_rescale``.

    Upstream issue: https://github.com/pytorch/executorch/issues/18491
    """
    from executorch.backends.arm._passes import rewrite_conv_pass
    from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
        get_input_qparams,
        get_output_qparams,
    )
    from executorch.exir.dialects._ops import ops as exir_ops

    rewrite_pass = rewrite_conv_pass.RewriteConvPass
    if getattr(rewrite_pass, "_conv_clamp_output_q_patch_applied", False):
        return

    original_insert_output_rescale = rewrite_pass.insert_output_rescale

    def resolve_output_qparams(qparam_source):
        """Return output qparams of *qparam_source*, or of its sole clamp user.

        Re-raises the original ValueError when neither node has them.
        """
        try:
            return get_output_qparams(qparam_source)[0]
        except ValueError:
            users = list(qparam_source.users)
            if (
                len(users) == 1
                and users[0].target == exir_ops.edge.aten.clamp.default
                and "output_qparams" in users[0].meta
                and len(users[0].meta["output_qparams"]) > 0
            ):
                return get_output_qparams(users[0])[0]
            raise

    def patched_insert_output_rescale(self, graph_module, node, source_node=None):
        """Insert a TOSA RESCALE after *node* using *source_node*'s qparams.

        Without ``source_node``, the upstream implementation is used unchanged
        and no qparams are computed here.
        """
        if source_node is None:
            return original_insert_output_rescale(self, graph_module, node)

        conv_input_qparams = get_input_qparams(source_node)
        output_qparams = resolve_output_qparams(source_node)
        input_qparams = conv_input_qparams[0]
        weight_qparams = conv_input_qparams[1]

        if weight_qparams.per_channel:
            weight_scale = weight_qparams.get_scale_per_channel()
        else:
            weight_scale = [weight_qparams.get_scale_per_tensor()]
        input_scale = input_qparams.get_scale_per_tensor()
        output_scale = output_qparams.get_scale_per_tensor()
        # One scale per weight scale: (input_scale * weight_scale) / output_scale.
        post_conv2d_scale = [(input_scale * w) / output_scale for w in weight_scale]

        with graph_module.graph.inserting_after(node):
            rescale_node = rewrite_conv_pass.create_node(
                graph=graph_module.graph,
                op_target=exir_ops.backend.tosa.RESCALE.default,
                args=(
                    node,
                    output_qparams.dtype,
                    post_conv2d_scale,
                    0,
                    output_qparams.get_zp_per_tensor(),
                ),
                from_node=node,
            )
        return rescale_node

    def patched_call(self, graph_module):
        """Rewrite aten convolutions into TOSA conv ops.

        This mirrors upstream RewriteConvPass.call, except that output
        rescales are built from the original conv node's qparams.
        """
        modified = False
        for node in graph_module.graph.nodes:
            if (
                node.op != "call_function"
                or node.target != exir_ops.edge.aten.convolution.default
            ):
                continue
            modified = True

            (
                x,
                weight,
                bias,
                stride,
                pad,
                dilation,
                transposed,
                output_padding,
                group,
            ) = node.args

            input_fake_tensor = rewrite_conv_pass.get_first_fake_tensor(x)
            weight_fake_tensor = rewrite_conv_pass.get_first_fake_tensor(weight)
            input_shape = input_fake_tensor.shape
            weight_shape = weight_fake_tensor.shape
            spatial_rank = len(input_shape) - 2
            stride_list = rewrite_conv_pass.expand_around_channel(stride, spatial_rank)
            dilation_list = rewrite_conv_pass.expand_around_channel(
                dilation, spatial_rank
            )
            pad_list = rewrite_conv_pass.expand_around_channel(pad, spatial_rank)
            stride = tuple(stride_list)

            # Materialize a bias operand when the convolution has none.
            if bias is None:
                bias = self._add_bias(graph_module, node, weight)

            conv_args: tuple[object, ...]
            if transposed:
                if spatial_rank != 2:
                    raise RuntimeError(
                        "Only 2D transpose convolutions are supported in the Arm backend."
                    )
                if group != 1:
                    raise RuntimeError(
                        "Grouped transpose convolutions are not supported in the Arm backend."
                    )
                if any(d != 1 for d in dilation_list):
                    raise RuntimeError(
                        "Transpose convolutions with dilation are not supported in the Arm backend."
                    )
                output_padding_list = rewrite_conv_pass.expand_around_channel(
                    output_padding, spatial_rank
                )
                out_pad = [
                    -pad_list[0],
                    -pad_list[0] + output_padding_list[0],
                    -pad_list[1],
                    -pad_list[1] + output_padding_list[1],
                ]
                target_op = exir_ops.backend.tosa.TRANSPOSE_CONV2D.default
                conv_args = (x, weight, bias, out_pad, stride)
            else:
                # Symmetric padding per spatial axis, then let the pass adjust
                # the trailing pad of each axis.
                pad_attr: list[int] = []
                for value in pad_list:
                    pad_attr.extend([value, value])
                for axis_index in range(spatial_rank):
                    pad_index = axis_index * 2 + 1
                    pad_attr[pad_index] = self._adjust_pad_if_needed(
                        input_shape[axis_index + 2],
                        weight_shape[axis_index + 2],
                        stride_list[axis_index],
                        pad_attr[pad_index],
                        dilation_list[axis_index],
                    )
                dilation = tuple(dilation_list)
                pad = pad_attr

                if self._is_conv3d(len(input_shape), group):
                    target_op = exir_ops.backend.tosa.CONV3D.default
                elif self._is_depthwise_conv2d(node):
                    target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default
                    # Reshape the shared weights only once, i.e. when no
                    # depthwise TOSA op already consumes them.
                    if all(user.target != target_op for user in weight.users):
                        self._reshape_weights(weight, input_fake_tensor.shape[1])
                    weight_fake_tensor = rewrite_conv_pass.get_first_fake_tensor(weight)
                else:
                    target_op = exir_ops.backend.tosa.CONV2D.default

                conv_args = (x, weight, bias, stride, pad, dilation)

            with graph_module.graph.inserting_after(node):
                tosa_op = rewrite_conv_pass.create_node(
                    graph=graph_module.graph,
                    op_target=target_op,
                    args=conv_args,
                    from_node=node,
                    inherit_qparams=True,
                )

            bias_fake_tensor = (
                rewrite_conv_pass.get_first_fake_tensor(bias)
                if bias is not None
                else None
            )
            tosa_node_fake_tensor = target_op(
                input_fake_tensor,
                weight_fake_tensor,
                bias_fake_tensor,
                *conv_args[3:],
            )

            tosa_out_is_int32 = tosa_node_fake_tensor.dtype == torch.int32
            if tosa_out_is_int32 and input_fake_tensor.dtype == torch.int8:
                output_rescale = self.insert_output_rescale(
                    graph_module, tosa_op, source_node=node
                )
                node.replace_all_uses_with(output_rescale)
            elif tosa_out_is_int32 and input_fake_tensor.dtype == torch.int16:
                # Presumably input_qparams indices 0/1/2 are input/weight/bias.
                node_has_bias = len(node.meta["input_qparams"]) > 2
                if not node_has_bias:
                    output_rescale = self.insert_output_rescale(
                        graph_module, tosa_op, source_node=node
                    )
                    node.replace_all_uses_with(output_rescale)
                else:
                    node.replace_all_uses_with(tosa_op)
                # NOTE(review): the flattened original is ambiguous on whether
                # this applies to both int16 sub-branches or only the bias one;
                # reconstructed at branch level. Confirm against upstream
                # RewriteConvPass.call.
                tosa_op.meta[rewrite_conv_pass.TosaSpecialDtype.meta_key()] = (
                    rewrite_conv_pass.TosaSpecialDtype.INT48
                )
            else:
                node.replace_all_uses_with(tosa_op)

            graph_module.graph.erase_node(node)

        if modified:
            graph_module.recompile()
            graph_module = rewrite_conv_pass.ArmPass.call(
                self, graph_module
            ).graph_module
        return rewrite_conv_pass.PassResult(graph_module, modified)

    rewrite_pass.insert_output_rescale = patched_insert_output_rescale
    rewrite_pass.call = patched_call
    rewrite_pass._conv_clamp_output_q_patch_applied = True


def _patch_duplicate_output_transpose() -> None:
    """Give duplicated graph outputs their own output transpose.

    ToTosaMemoryFormatPass rewrites graph outputs back to the original memory
    format by calling ``node.replace_input_with(...)`` on the FX output node.
    If the output tuple contains the same FX node multiple times (for example
    after FuseEqualPlaceholdersPass merges equal constant placeholders), that
    helper rewrites every matching slot at once. Distinct logical outputs then
    collapse onto the same transpose node and lose per-slot identity.

    The fix patches output-node rewrites to update only the first matching
    tuple slot, so duplicate logical outputs each keep their own transpose.

    Upstream issue: https://github.com/pytorch/executorch/issues/18320
    """
    from executorch.backends.arm._passes import to_tosa_memory_format_pass

    tosa_memory_format_module = to_tosa_memory_format_pass
    tosa_memory_format = tosa_memory_format_module.ToTosaMemoryFormatPass
    if getattr(
        tosa_memory_format, "_duplicate_output_transpose_patch_applied", False
    ):
        return

    original_insert_input_transpose = tosa_memory_format.insert_input_transpose
    transpose_target = tosa_memory_format_module.exir_ops.backend.tosa.TRANSPOSE.default

    def _replace_first_output_slot(
        output_node, original_input_node, replacement_node
    ) -> None:
        """Replace only the first output slot holding *original_input_node*.

        The container type (list or tuple) is preserved.

        Raises:
            TypeError: if the output node's first arg is not a list/tuple.
            RuntimeError: if *original_input_node* is not among the outputs.
        """
        outputs = output_node.args[0]
        if not isinstance(outputs, (list, tuple)):
            raise TypeError(
                f"Expected output node args to be a list or tuple, got {type(outputs)}"
            )
        rewritten_outputs = list(outputs)
        for output_index, existing_output in enumerate(rewritten_outputs):
            if existing_output is original_input_node:
                rewritten_outputs[output_index] = replacement_node
                break
        else:
            raise RuntimeError(
                "Could not find the original output node while rewriting the output tuple."
            )
        replacement = (
            rewritten_outputs if isinstance(outputs, list) else tuple(rewritten_outputs)
        )
        output_node.args = (replacement,)

    def patched_insert_input_transpose(node, input_node, graph_module):
        # Non-output consumers keep the upstream behaviour.
        if node.op != "output":
            return original_insert_input_transpose(node, input_node, graph_module)

        # The output already comes from a TRANSPOSE: route this slot to the
        # transpose's input instead of adding another transpose.
        if input_node.op == "call_function" and input_node.target == transpose_target:
            pre_permute_node = input_node.all_input_nodes[0]
            _replace_first_output_slot(node, input_node, pre_permute_node)
            return

        rank = len(tosa_memory_format_module.get_first_fake_tensor(input_node).size())
        spatial_rank = input_node.meta["tosa_spatial_rank"]
        mem_format = tosa_memory_format._channels_last_inverse_order(
            rank, spatial_rank
        )
        assert sorted(mem_format) == list(
            range(rank)
        ), f"bad perm {mem_format} for rank {rank} in insert_input_transpose"
        with graph_module.graph.inserting_before(node):
            permute_node = tosa_memory_format_module.create_node(
                graph_module.graph,
                transpose_target,
                args=(input_node, list(mem_format)),
                from_node=node,
            )
            permute_node.meta["tosa_dim_order"] = tuple(
                range(len(input_node.meta["val"].size()))
            )
            permute_node.meta["tosa_spatial_rank"] = spatial_rank
        _replace_first_output_slot(node, input_node, permute_node)

    tosa_memory_format.insert_input_transpose = staticmethod(
        patched_insert_input_transpose
    )
    tosa_memory_format._duplicate_output_transpose_patch_applied = True


def _patch_exact_int8_table_domain() -> None:
    """Build int8 TABLE lookup values from all 256 exact int8 codes.

    Arm's quantized TABLE lowering builds the 256-entry int8 input domain with
    ``torch.linspace(..., dtype=torch.int8)``. For symmetric int8 qparams such
    as qmin=-127, qmax=127, that produces only 255 unique codes, duplicating 0
    and shifting half the LUT by one entry.

    The fix builds the TOSA int8 table domain from exact integer codes instead,
    so quantized sigmoid/tanh/etc. tables line up with the PT2E q/dq reference
    and with the backend's raw int8 indexing. Non-int8 inputs use the upstream
    implementation.

    Upstream issue: https://github.com/pytorch/executorch/issues/18496
    """
    from executorch.backends.arm._passes import insert_table_ops

    insert_table_pass = insert_table_ops.InsertTableOpsPass
    if getattr(insert_table_pass, "_exact_int8_table_domain_patch_applied", False):
        return

    original_generate_8bit_table_values = insert_table_pass.generate_8bit_table_values

    def patched_generate_8bit_table_values(
        self,
        torch_op,
        in_quantargs,
        out_quantargs,
    ):
        if in_quantargs.dtype != torch.int8:
            return original_generate_8bit_table_values(
                self, torch_op, in_quantargs, out_quantargs
            )

        # Every int8 code exactly once, ascending: -128, -127, ..., 127.
        table_codes = torch.arange(-128, 128, dtype=torch.int16).to(torch.int8)
        dequantized = in_quantargs.dequantize_value(table_codes)
        table = out_quantargs.quantize_value(torch_op(dequantized))
        return (table.to(dtype=torch.int8), 0)

    insert_table_pass.generate_8bit_table_values = patched_generate_8bit_table_values
    insert_table_pass._exact_int8_table_domain_patch_applied = True


def _patch_deterministic_partition_tag_order() -> None:
    """Sort TOSAPartitioner partition tags so lowering order is deterministic.

    Arm's TOSA partitioner accumulates partition tags in a Python ``set`` and
    then materializes ``partition_tags`` from that set. ExecuTorch backend
    lowering later iterates ``partition_tags.items()`` in insertion order when
    duplicating constants, creating submodules, and lowering delegates. That
    makes the exported graph/resource order depend on Python hash
    randomization.

    Sorting the tags here makes Arm/VGF lowering deterministic without
    requiring PYTHONHASHSEED in the environment.

    Upstream issue: https://github.com/pytorch/executorch/issues/19045

    Repro script:
    investigations/practical_rife/executorch_partition_tag_order_repro.py
    """
    from executorch.backends.arm.tosa import partitioner as tosa_partitioner_module

    tosa_partitioner = tosa_partitioner_module.TOSAPartitioner
    if getattr(
        tosa_partitioner, "_deterministic_partition_tag_order_patch_applied", False
    ):
        return

    original_tosa_partition = tosa_partitioner.partition

    def patched_tosa_partition(self, exported_program):
        result = original_tosa_partition(self, exported_program)
        if len(result.partition_tags) > 1:
            # Rebuild the dict so its insertion order is the sorted tag order.
            result.partition_tags = {
                tag: result.partition_tags[tag]
                for tag in sorted(result.partition_tags)
            }
        return result

    tosa_partitioner.partition = patched_tosa_partition
    tosa_partitioner._deterministic_partition_tag_order_patch_applied = True


def _patch_rewrite_upsample_rescale_meta() -> None:
    """Copy the resize node's metadata onto the RESCALE that RewriteUpsamplePass adds.

    Arm's RewriteUpsamplePass inserts a trailing TOSA RESCALE for quantized
    bilinear resize, but the inserted node does not inherit the original
    resize node's fake-tensor metadata. A later chained resize then crashes
    on ``get_first_fake_tensor(x)`` with ``KeyError: 'RewriteUpsamplePass: val'``.

    The fix preserves the original node metadata on the inserted RESCALE, so
    the chained case behaves like the single-resize case.

    Upstream issue: https://github.com/pytorch/executorch/issues/19068
    """
    from executorch.backends.arm._passes import rewrite_upsample

    rewrite_upsample_pass = rewrite_upsample.RewriteUpsamplePass
    if getattr(
        rewrite_upsample_pass, "_rewrite_upsample_rescale_meta_patch_applied", False
    ):
        return

    exir_ops = rewrite_upsample.exir_ops

    def patched_rewrite_upsample_call(self, graph_module):
        modified = False
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target not in self.targeted_ops:
                continue
            modified = True

            if node.target == exir_ops.edge.aten.upsample_bilinear2d.vec:
                x, output_size, align_corners, scale_factors = node.args
                resize_mode = "bilinear"
            else:
                x, output_size, scale_factors = node.args
                align_corners = False
                resize_mode = "nearest"

            input_fake = rewrite_upsample.get_first_fake_tensor(x)
            output_fake = rewrite_upsample.get_first_fake_tensor(node)
            with graph_module.graph.inserting_before(node):
                tosa_resize_node = rewrite_upsample.create_node(
                    graph_module.graph,
                    op_target=exir_ops.backend.tosa.RESIZE.default,
                    args=(x, output_size, align_corners, scale_factors),
                    kwargs={"resize_mode": resize_mode},
                    from_node=node,
                    inherit_qparams=True,
                )
                node.replace_all_uses_with(tosa_resize_node)
                graph_module.graph.erase_node(node)

            input_dtype = input_fake.dtype
            if input_dtype in (torch.int8, torch.int16) and resize_mode == "bilinear":
                input_size_xy = input_fake.shape[2:]
                output_size_xy = output_fake.shape[2:]
                scale_n_yx, _, _, _ = rewrite_upsample.get_resize_parameters(
                    input_size_xy=input_size_xy,
                    output_size_xy=output_size_xy,
                    resize_mode=1,
                    align_corners=align_corners,
                )
                output_dtype = output_fake.dtype
                output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1]))
                with graph_module.graph.inserting_after(tosa_resize_node):
                    rescale_node = rewrite_upsample.create_node(
                        graph_module.graph,
                        exir_ops.backend.tosa.RESCALE.default,
                    )
                    # The fix: reuse the (erased) upsample node's metadata so
                    # later get_first_fake_tensor() calls on the RESCALE succeed.
                    rescale_node.meta = dict(node.meta)
                    tosa_resize_node.replace_all_uses_with(rescale_node)
                    if input_dtype == torch.int16:
                        tosa_resize_node.meta[
                            rewrite_upsample.TosaSpecialDtype.meta_key()
                        ] = rewrite_upsample.TosaSpecialDtype.INT48
                    # Args are set only after replace_all_uses_with. Otherwise
                    # the RESCALE's own use of tosa_resize_node would be
                    # rewritten to point at itself.
                    rescale_node.args = (
                        tosa_resize_node,
                        output_dtype,
                        [output_scale],
                        0,
                        0,
                    )

        if modified:
            graph_module = rewrite_upsample.ArmPass.call(
                self, graph_module
            ).graph_module
        return rewrite_upsample.PassResult(graph_module, modified)

    rewrite_upsample_pass.call = patched_rewrite_upsample_call
    rewrite_upsample_pass._rewrite_upsample_rescale_meta_patch_applied = True


def _patch_vgf_output_dq_tail() -> None:
    """Let the VGF partitioner claim a lone graph-output dequantize node.

    Arm's VGF partitioner can leave a final graph-output dequantize at top
    level when the only post-backend tail is that lone DQ node. In our
    scenario-export flow we cannot fall back to an undelegated compute op, so
    the VGF partitioner is patched to claim that exact output tail as its own
    tiny partition.

    This is a narrow workaround for our VGF/scenario-export setup, not a
    general upstream ExecuTorch bug report candidate. Keep the scope narrow:
    only a terminal graph-output DQ is force-tagged, and only for VGF
    partitioning.
    """
    from executorch.backends.arm.constants import DQ_OPS
    from executorch.backends.arm.vgf import partitioner as vgf_partitioner_module

    vgf_partitioner = vgf_partitioner_module.VgfPartitioner
    if getattr(vgf_partitioner, "_output_dq_tail_patch_applied", False):
        return

    original_vgf_partition = vgf_partitioner.partition

    def patched_vgf_partition(self, exported_program):
        result = original_vgf_partition(self, exported_program)
        gm = result.tagged_exported_program.graph_module
        output_node = next(
            (node for node in gm.graph.nodes if node.op == "output"),
            None,
        )
        if output_node is None:
            return result

        existing_tags = set(result.partition_tags)
        tag_index = 0

        def next_tag() -> str:
            # First "arm_vgf_output_dq_tail_<n>" not already in use.
            nonlocal tag_index
            while True:
                tag = f"arm_vgf_output_dq_tail_{tag_index}"
                tag_index += 1
                if tag not in existing_tags:
                    existing_tags.add(tag)
                    return tag

        for output_parent in output_node.all_input_nodes:
            is_untagged_dq = (
                output_parent.op == "call_function"
                and output_parent.target in DQ_OPS
                and "delegation_tag" not in output_parent.meta
            )
            if not is_untagged_dq:
                continue
            # Only claim DQ nodes consumed exclusively by the graph output.
            if not output_parent.users or any(
                user.op != "output" for user in output_parent.users
            ):
                continue
            tag = next_tag()
            output_parent.meta["delegation_tag"] = tag
            result.partition_tags[tag] = self.delegation_spec
        return result

    vgf_partitioner.partition = patched_vgf_partition
    vgf_partitioner._output_dq_tail_patch_applied = True