
Commit a1f8373

Migrate pt2e qualcomm
Differential Revision: D75166700
Pull Request resolved: pytorch#11049
1 parent 6fafe7c commit a1f8373

File tree

17 files changed: +87 -74 lines changed

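At a glance, the commit moves the Qualcomm PT2E quantization stack off the deprecated `torch.ao.quantization` namespaces and onto `torchao.quantization.pt2e`, and the affine quantize/dequantize ops move from the `pt2e_quant` op namespace to `torchao`. A minimal sketch of the new import surface, collected only from the hunks below:

```python
# New torchao-based import surface used across the Qualcomm backend.
from torchao.quantization.pt2e import (            # was torch.ao.quantization.observer / .fake_quantize
    FixedQParamsFakeQuantize,
    FixedQParamsObserver,
    MinMaxObserver,
    UniformQuantizationObserverBase,
)
from torchao.quantization.pt2e.quantizer import (   # was torch.ao.quantization.quantizer(.utils)
    Quantizer,
    QuantizationSpec,
    QuantizationAnnotation,
    annotate_input_qspec_map,   # replaces the private _annotate_input_qspec_map
    annotate_output_qspec,      # replaces the private _annotate_output_qspec
)

# Op namespaces are renamed accordingly:
#   torch.ops.pt2e_quant.quantize_affine      -> torch.ops.torchao.quantize_affine
#   exir_ops.edge.pt2e_quant.quantize_affine  -> exir_ops.edge.torchao.quantize_affine
```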

.lintrunner.toml

Lines changed: 0 additions & 3 deletions
@@ -388,15 +388,12 @@ exclude_patterns = [
     # backends
     "backends/vulkan/quantizer/**",
     "backends/vulkan/test/**",
-    "backends/qualcomm/quantizer/**",
-    "examples/qualcomm/**",
     "backends/xnnpack/quantizer/**",
     "backends/xnnpack/test/**",
     "exir/tests/test_passes.py",
     "extension/llm/export/builder.py",
     "extension/llm/export/quantizer_lib.py",
     "exir/tests/test_memory_planning.py",
-    "backends/transforms/duplicate_dynamic_quant_chain.py",
     "exir/backend/test/demos/test_xnnpack_qnnpack.py",
 ]

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 2 additions & 2 deletions
@@ -135,8 +135,8 @@ def get_to_edge_transform_passes(
         from executorch.backends.qualcomm.builders import node_visitor
         from executorch.exir.dialects._ops import ops as exir_ops
 
-        node_visitor.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
-        node_visitor.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
+        node_visitor.q_ops.add(exir_ops.edge.torchao.quantize_affine.default)
+        node_visitor.dq_ops.add(exir_ops.edge.torchao.dequantize_affine.default)
 
         passes_job = (
             passes_job if passes_job is not None else get_capture_program_passes()
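`q_ops` and `dq_ops` are the sets the QNN node visitors consult to recognize quantize/dequantize nodes in the edge dialect; after this change the affine variants are looked up under the `torchao` edge namespace. A small sketch of how a pass might test a node against the updated sets (`is_affine_qdq` is a hypothetical helper, not part of the diff):

```python
from executorch.exir.dialects._ops import ops as exir_ops

def is_affine_qdq(node) -> bool:
    # Hypothetical helper: affine q/dq ops now live under exir_ops.edge.torchao,
    # no longer under exir_ops.edge.pt2e_quant.
    return node.target in {
        exir_ops.edge.torchao.quantize_affine.default,
        exir_ops.edge.torchao.dequantize_affine.default,
    }
```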

backends/qualcomm/builders/node_visitor.py

Lines changed: 2 additions & 2 deletions
@@ -265,8 +265,8 @@ def get_quant_encoding_conf(
         )
         # TODO: refactor this when target could be correctly detected
         per_block_encoding = {
-            exir_ops.edge.pt2e_quant.quantize_affine.default,
-            exir_ops.edge.pt2e_quant.dequantize_affine.default,
+            exir_ops.edge.torchao.quantize_affine.default,
+            exir_ops.edge.torchao.dequantize_affine.default,
         }
         if quant_attrs[QCOM_ENCODING] in per_block_encoding:
             return self.make_qnn_per_block_config(node, quant_attrs)

backends/qualcomm/partition/utils.py

Lines changed: 2 additions & 2 deletions
@@ -57,7 +57,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
         torch.ops.aten.upsample_bicubic2d.vec,
         # This request is ignored because it is in a blocklist. Refer to exir/program/_program.py
         torch.ops.aten.unbind.int,
-        torch.ops.pt2e_quant.quantize_affine.default,
-        torch.ops.pt2e_quant.dequantize_affine.default,
+        torch.ops.torchao.quantize_affine.default,
+        torch.ops.torchao.dequantize_affine.default,
     ]
     return do_not_decompose

backends/qualcomm/quantizer/annotators.py

Lines changed: 26 additions & 29 deletions
@@ -12,20 +12,17 @@
 from torch._ops import OpOverload
 
 from torch._subclasses import FakeTensor
-from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
+from torch.fx import Node
 
-from torch.ao.quantization.observer import FixedQParamsObserver
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e import FixedQParamsFakeQuantize, FixedQParamsObserver
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
     DerivedQuantizationSpec,
     QuantizationAnnotation,
     QuantizationSpec,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
-from torch.fx import Node
 
 from .qconfig import (
     get_16a16w_qnn_ptq_config,
@@ -643,19 +640,19 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> No
         return
 
     # TODO current only support 16a16w
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
 
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.input_activation,
     )
     nodes_to_mark_annotated = [node]
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
 
 
@@ -844,25 +841,25 @@ def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) ->
     if _is_annotated([node]):
         return
 
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.weight,
     )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias_node,
             quantization_config.bias,
         )
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
 
 
@@ -1027,12 +1024,12 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
     if _is_annotated([node]):
         return
 
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.weight,
@@ -1043,9 +1040,9 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
             bias_config = quantization_config.bias(node)
         else:
             bias_config = quantization_config.bias
-        _annotate_input_qspec_map(node, bias_node, bias_config)
+        annotate_input_qspec_map(node, bias_node, bias_config)
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
 
 # We use get_source_partition in pass, but it is the same source for MultiheadAttention, so we need to change its source_fn_stack.
@@ -1063,29 +1060,29 @@ def annotate_batch_and_instance_norm(
         return
 
     annotated_args = [act]
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act,
         quantization_config.input_activation,
     )
     # QNN requires uint8 instead of int8 in 'weight' config
     if weight is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight,
             quantization_config.input_activation,
         )
         annotated_args.append(weight)
 
     if bias is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias,
             quantization_config.bias,
         )
         annotated_args.append(bias)
 
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated([node, *annotated_args])
 
 
@@ -1095,7 +1092,7 @@ def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> Non
         return
 
     if _is_float_tensor(node):
-        _annotate_output_qspec(node, quantization_config.output_activation)
+        annotate_output_qspec(node, quantization_config.output_activation)
         _mark_nodes_as_annotated([node])
 
 
@@ -1111,32 +1108,32 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->
         return
     input_act_qspec = quantization_config.input_activation
 
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         input_act_qspec,
     )
     if input_act_qspec.dtype == torch.int32:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight_node,
             get_16a16w_qnn_ptq_config().weight,
         )
     else:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight_node,
             input_act_qspec,
         )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias_node,
             quantization_config.bias,
         )
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
 
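Across all of these annotator hunks the only change is the helper source: the private `_annotate_input_qspec_map` / `_annotate_output_qspec` from `torch.ao.quantization.quantizer.utils` become the public `annotate_input_qspec_map` / `annotate_output_qspec` from `torchao.quantization.pt2e.quantizer`, with identical call shapes. A minimal sketch of an annotator written against the new imports (the op and config fields are illustrative, not taken from this file):

```python
from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (
    annotate_input_qspec_map,
    annotate_output_qspec,
)

def annotate_simple_unary_op(node: Node, quantization_config) -> None:
    # Sketch only: mirrors the pattern used by annotate_layer_norm and friends.
    act_node = node.args[0]
    annotate_input_qspec_map(node, act_node, quantization_config.input_activation)
    annotate_output_qspec(node, quantization_config.output_activation)
    # The real annotators finish by calling _mark_nodes_as_annotated([node, ...]).
```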

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 3 additions & 3 deletions
@@ -17,13 +17,13 @@
     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver
-from torch.ao.quantization.quantizer import (
+from torch.fx import Node
+from torchao.quantization.pt2e import FixedQParamsObserver, MinMaxObserver
+from torchao.quantization.pt2e.quantizer import (
     QuantizationAnnotation,
     QuantizationSpec,
     SharedQuantizationSpec,
 )
-from torch.fx import Node
 
 
 def annotate_mimi_decoder(gm: torch.fx.GraphModule):

backends/qualcomm/quantizer/observers/per_block_param_observer.py

Lines changed: 2 additions & 2 deletions
@@ -7,8 +7,8 @@
 from typing import Tuple
 
 import torch
-from torch.ao.quantization.observer import MappingType, PerBlock
-from torch.ao.quantization.pt2e._affine_quantization import (
+from torchao.quantization.pt2e import MappingType, PerBlock
+from torchao.quantization.pt2e._affine_quantization import (
     _get_reduction_params,
     AffineQuantizedMinMaxObserver,
     choose_qparams_affine_with_min_max,

backends/qualcomm/quantizer/observers/per_channel_param_observer.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-from torch.ao.quantization.observer import UniformQuantizationObserverBase
+from torchao.quantization.pt2e import UniformQuantizationObserverBase
 
 
 # TODO move to torch/ao/quantization/observer.py.

backends/qualcomm/quantizer/qconfig.py

Lines changed: 6 additions & 5 deletions
@@ -7,18 +7,19 @@
     PerBlockParamObserver,
 )
 from torch import Tensor
-from torch.ao.quantization.fake_quantize import (
+from torch.fx import Node
+from torchao.quantization.pt2e import (
     FakeQuantize,
     FusedMovingAvgObsFakeQuantize,
-)
-from torch.ao.quantization.observer import (
     MinMaxObserver,
     MovingAverageMinMaxObserver,
     MovingAveragePerChannelMinMaxObserver,
     PerChannelMinMaxObserver,
 )
-from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec
-from torch.fx import Node
+from torchao.quantization.pt2e.quantizer import (
+    DerivedQuantizationSpec,
+    QuantizationSpec,
+)
 
 
 @dataclass(eq=True)
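The observer and fake-quantize class names are unchanged here; only their home module moves to `torchao.quantization.pt2e`, while `QuantizationSpec` and `DerivedQuantizationSpec` now come from `torchao.quantization.pt2e.quantizer`. A hedged sketch of a per-tensor activation spec built on the new imports (the field names are assumed to match the `torch.ao` dataclass they replace; not copied from this file):

```python
import torch
from torchao.quantization.pt2e import MovingAverageMinMaxObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec

# Assumption: QuantizationSpec keeps the same fields as its torch.ao predecessor.
act_qspec = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=MovingAverageMinMaxObserver,
)
```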

backends/qualcomm/quantizer/quantizer.py

Lines changed: 3 additions & 4 deletions
@@ -12,8 +12,9 @@
 from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
 
 from torch._ops import OpOverload
-from torch.ao.quantization.quantizer import Quantizer
 from torch.fx import GraphModule
+from torchao.quantization.pt2e import UniformQuantizationObserverBase
+from torchao.quantization.pt2e.quantizer import Quantizer
 
 from .annotators import OP_ANNOTATOR
 
@@ -130,9 +131,7 @@ class ModuleQConfig:
     is_qat: bool = False
     is_conv_per_channel: bool = False
     is_linear_per_channel: bool = False
-    act_observer: Optional[
-        torch.ao.quantization.observer.UniformQuantizationObserverBase
-    ] = None
+    act_observer: Optional[UniformQuantizationObserverBase] = None
 
     def __post_init__(self):
         if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:

backends/qualcomm/tests/utils.py

Lines changed: 4 additions & 3 deletions
@@ -14,6 +14,7 @@
 
 import numpy as np
 import torch
+import torchao
 from executorch import exir
 from executorch.backends.qualcomm.builders.node_visitor import dq_ops
 from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
@@ -537,8 +538,8 @@ def get_qdq_module(
             torch.ops.quantized_decomposed.dequantize_per_tensor.default,
             torch.ops.quantized_decomposed.quantize_per_channel.default,
             torch.ops.quantized_decomposed.dequantize_per_channel.default,
-            torch.ops.pt2e_quant.quantize_affine.default,
-            torch.ops.pt2e_quant.dequantize_affine.default,
+            torch.ops.torchao.quantize_affine.default,
+            torch.ops.torchao.dequantize_affine.default,
         }
         if not bypass_check:
             self.assertTrue(nodes.intersection(q_and_dq))
@@ -569,7 +570,7 @@ def get_prepared_qat_module(
         quantizer.set_submodule_qconfig_list(submodule_qconfig_list)
 
         prepared = prepare_qat_pt2e(m, quantizer)
-        return torch.ao.quantization.move_exported_model_to_train(prepared)
+        return torchao.quantization.pt2e.move_exported_model_to_train(prepared)
 
     def get_converted_sgd_trained_module(
         self,
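With `import torchao` at the top of the test utilities, the QAT path now stays inside torchao end to end: `prepare_qat_pt2e` produces the observed module and `torchao.quantization.pt2e.move_exported_model_to_train` replaces the old `torch.ao.quantization` entry point. A rough sketch of that flow, with the export step, import path for `prepare_qat_pt2e`, and argument names standing in for whatever the harness actually does:

```python
import torch
import torchao
from torchao.quantization.pt2e.quantize_pt2e import prepare_qat_pt2e  # assumed path, mirrors prepare_pt2e

def prepare_qat_sketch(model: torch.nn.Module, example_inputs, quantizer):
    # Assumption: the usual PT2E export step; the real helper also configures a QnnQuantizer.
    exported = torch.export.export_for_training(model, example_inputs).module()
    prepared = prepare_qat_pt2e(exported, quantizer)
    # Train/eval switching for exported models now lives under torchao.quantization.pt2e.
    return torchao.quantization.pt2e.move_exported_model_to_train(prepared)
```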

backends/transforms/duplicate_dynamic_quant_chain.py

Lines changed: 4 additions & 6 deletions
@@ -9,14 +9,12 @@
 
 import torch
 
-from torch.ao.quantization.pt2e.utils import (
-    _filter_sym_size_users,
-    _is_valid_annotation,
-)
-
 from torch.fx.node import map_arg
 from torch.fx.passes.infra.pass_base import PassBase, PassResult
 
+from torchao.quantization.pt2e.quantizer import is_valid_annotation
+from torchao.quantization.pt2e.utils import _filter_sym_size_users
+
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
@@ -129,7 +127,7 @@ def _maybe_duplicate_dynamic_quantize_chain(
         dq_node_users = list(dq_node.users.copy())
         for user in dq_node_users:
             annotation = user.meta.get("quantization_annotation", None)
-            if not _is_valid_annotation(annotation):
+            if not is_valid_annotation(annotation):
                 return
         with gm.graph.inserting_after(dq_node):
             new_node = gm.graph.node_copy(dq_node)

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@
 from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer
 from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
 
-from torch.ao.quantization.observer import MinMaxObserver
+from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
 sys.setrecursionlimit(4096)
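For the example scripts the migration is a one-line import swap; the PTQ flow around it is untouched. A condensed sketch of that flow using the imports llama.py now has (the model, inputs, export call, and quantizer wiring are placeholders, not taken from the script):

```python
import torch
from torchao.quantization.pt2e import MinMaxObserver  # available if a custom activation observer is wanted
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

def ptq_sketch(model: torch.nn.Module, calibration_inputs, quantizer):
    # Placeholder flow: export, insert observers, calibrate, then convert.
    exported = torch.export.export_for_training(model, calibration_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*calibration_inputs)  # calibration pass
    return convert_pt2e(prepared)
```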

examples/qualcomm/oss_scripts/moshi/mimi.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@
 from huggingface_hub import hf_hub_download
 from moshi.models import loaders
 
-from torch.ao.quantization.observer import MinMaxObserver
+from torchao.quantization.pt2e import MinMaxObserver
 
 
 def seed_all(seed):
