From de73e73f1310f8e44149266089f3e8442be30bd8 Mon Sep 17 00:00:00 2001
From: daniil-lyakhov
Date: Fri, 21 Feb 2025 17:21:56 +0100
Subject: [PATCH 1/6] WIP test quantization

---
 .../tests/quantizer/test_pt2e_quantization.py |  748 +++++++++++
 .../tests/quantizer/test_representation.py    |  311 +++++
 .../tests/quantizer/test_xnnpack_quantizer.py | 1090 +++++++++++++++++
 3 files changed, 2149 insertions(+)
 create mode 100644 backends/openvino/tests/quantizer/test_pt2e_quantization.py
 create mode 100644 backends/openvino/tests/quantizer/test_representation.py
 create mode 100644 backends/openvino/tests/quantizer/test_xnnpack_quantizer.py

diff --git a/backends/openvino/tests/quantizer/test_pt2e_quantization.py b/backends/openvino/tests/quantizer/test_pt2e_quantization.py
new file mode 100644
index 00000000000..8f15ada00bd
--- /dev/null
+++ b/backends/openvino/tests/quantizer/test_pt2e_quantization.py
@@ -0,0 +1,748 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree
+
+# pyre-unsafe
+
+from collections import Counter
+from typing import Dict, Optional, Tuple
+
+import torch
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+from executorch.backends.openvino.quantizer.quantizer import (
+    OpenVINOQuantizer,
+    QuantizationMode,
+)
+
+from torch.ao.quantization import (
+    compare_results,
+    CUSTOM_KEY,
+    default_per_channel_symmetric_qnnpack_qconfig,
+    extract_results_from_loggers,
+    generate_numeric_debug_handle,
+    NUMERIC_DEBUG_HANDLE_KEY,
+    observer,
+    prepare_for_propagation_comparison,
+)
+from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
+from torch.ao.quantization.qconfig import (
+    float_qparams_weight_only_qconfig,
+    per_channel_weight_observer_range_neg_127_to_127,
+    QConfig,
+    weight_observer_range_neg_127_to_127,
+)
+from torch.ao.quantization.qconfig_mapping import QConfigMapping
+from torch.ao.quantization.quantize_pt2e import (
+    convert_pt2e,
+    prepare_pt2e,
+    prepare_qat_pt2e,
+)
+from torch.ao.quantization.quantizer import Quantizer
+from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
+from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer
+from torch.export import export_for_training
+from torch.testing._internal.common_quantization import (
+    NodeSpec as ns,
+    PT2EQuantizationTestCase,
+    TestHelperModules,
+)
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    TemporaryFileName,
+    TestCase,
+)
+from nncf.torch import disable_patching
+
+
+class TestQuantizePT2E(PT2EQuantizationTestCase):
+
+    def run(self, result=None):
+        """
+        Disable NNCF patching for each test call.
+        """
+        with disable_patching():
+            super().run(result)
+
+    def _get_pt2e_quantized_linear(
+        self, mode: Optional[QuantizationMode] = None
+    ) -> torch.fx.GraphModule:
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        if mode is None:
+            quantizer = OpenVINOQuantizer()
+        else:
+            quantizer = OpenVINOQuantizer(mode=mode)
+        example_inputs = (torch.randn(2, 2),)
+        m = M().eval()
+        return self._quantize(m, quantizer, example_inputs)
+
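+    # For reference, the self._quantize helper inherited from
+    # PT2EQuantizationTestCase performs (roughly) the standard PT2E flow
+    # sketched below; this is an explanatory note, not executed test code:
+    #
+    #     m = export_for_training(m, example_inputs).module()
+    #     m = prepare_pt2e(m, quantizer)  # insert observers
+    #     m(*example_inputs)              # calibrate
+    #     m = convert_pt2e(m)             # lower to quantize/dequantize ops
+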
"""Test folding all ops that's before quantized operator: + Before: + get_attr(weight) -> transpose -> quantize -> dequantize + After: + get_attr(folded_weight) -> dequantize + """ + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.randn(2, 2) + + def forward(self, x): + t = self.weight.t() + return torch.nn.functional.linear(x, t) + + quantizer = OpenVINOQuantizer() + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = self._quantize(m, quantizer, example_inputs) + node_occurrence = { + # quantize op for weight node is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 1, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 1, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_composable_quantizer_throw(self) -> None: + class BadQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + for n in gm.graph.nodes: + n.meta["quantization_annotation"] = None + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + quantizer = OpenVINOQuantizer() + bad_quantizer = BadQuantizer() + composable_quantizer = ComposableQuantizer([quantizer, bad_quantizer]) + m_eager = TestHelperModules.ConvLinearWPermute().eval() + example_inputs = (torch.randn(2, 3, 4, 4),) + self.assertRaises( + RuntimeError, + lambda: self._test_quantizer( + m_eager, example_inputs, composable_quantizer, {} + ), + ) + + def test_composable_quantizer_linear_conv(self) -> None: + #TODO: impelment whe dynamic quantization will be supported by OpenVINOQuantizer + pass + + def test_embedding_conv_linear_quantization(self) -> None: + # Mark + m_eager = TestHelperModules.EmbeddingConvLinearModule().eval() + indices = torch.tensor( + [ + 9, + 6, + 5, + 7, + 8, + 8, + 9, + 2, + 8, + 6, + 6, + 9, + 1, + 6, + 8, + 8, + 3, + 2, + 3, + 6, + 3, + 6, + 5, + 7, + 0, + 8, + 4, + 6, + 5, + 8, + 2, + 3, + ] + ) + indices = torch.unsqueeze(indices, 0) + example_inputs = (indices,) + + embedding_quantizer = EmbeddingQuantizer() + dynamic_quantizer = XNNPACKQuantizer() + quantization_config_dynamic = get_symmetric_quantization_config( + is_per_channel=True, is_dynamic=True + ) + dynamic_quantizer.set_global(quantization_config_dynamic) + static_quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + static_quantizer.set_global(quantization_config) + composed_quantizer = ComposableQuantizer( + [embedding_quantizer, dynamic_quantizer, static_quantizer] + ) + + act_affine_quant_obs = observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + dynamic_qconfig = QConfig( + activation=act_affine_quant_obs, + weight=per_channel_weight_observer_range_neg_127_to_127, + ) + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig) + qconfig_mapping = qconfig_mapping.set_object_type( + torch.nn.Embedding, float_qparams_weight_only_qconfig + ) + + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + 
+        node_occurrence = {
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
+            # note: quantize op for weights are const propagated
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
+        }
+        self._test_quantizer(
+            m_eager,
+            example_inputs,
+            composed_quantizer,
+            node_occurrence,
+            [],
+            True,
+            qconfig_mapping,
+        )
+
+    def test_disallow_eval_train(self) -> None:
+        m = TestHelperModules.ConvWithBNRelu(relu=True)
+        example_inputs = (torch.rand(3, 3, 5, 5),)
+
+        # Before export: this is OK
+        m.eval()
+        m.train()
+
+        # After export: this is not OK
+        m = export_for_training(m, example_inputs).module()
+        with self.assertRaises(NotImplementedError):
+            m.eval()
+        with self.assertRaises(NotImplementedError):
+            m.train()
+
+        # After prepare: still not OK
+        quantizer = OpenVINOQuantizer()
+        m = prepare_qat_pt2e(m, quantizer)  # pyre-ignore[6]
+        with self.assertRaises(NotImplementedError):
+            m.eval()
+        with self.assertRaises(NotImplementedError):
+            m.train()
+
+        # After convert: still not OK
+        m = convert_pt2e(m)
+        with self.assertRaises(NotImplementedError):
+            m.eval()
+        with self.assertRaises(NotImplementedError):
+            m.train()
+
+    def _get_bn_train_eval_ops(
+        self,
+    ) -> Tuple[torch._ops.OpOverload, torch._ops.OpOverload]:
+        return (
+            torch.ops.aten.batch_norm.default,
+            torch.ops.aten.batch_norm.default,
+        )
+
+    def _get_node(
+        self, m: torch.fx.GraphModule, target: torch._ops.OpOverload
+    ) -> torch.fx.Node:
+        """
+        Return the first node matching the specified target, raising an
+        exception if no such node is found.
+        """
+        for n in m.graph.nodes:
+            if n.target == target:
+                return n
+        raise ValueError(f"Did not find node with target {target}")
+
+    def test_allow_exported_model_train_eval(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.bn = torch.nn.BatchNorm2d(3)
+                self.dropout = torch.nn.Dropout(0.5)
+
+            def forward(self, x):
+                x = self.bn(x)
+                x = self.dropout(x)
+                return x
+
+        m = M().train()
+        example_inputs = (torch.randn(1, 3, 3, 3),)
+        bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()  # pyre-ignore[23]
+        m = export_for_training(m, example_inputs).module()
+
+        def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None:
+            bn_op = bn_train_op if train else bn_eval_op
+            bn_node = self._get_node(m, bn_op)
+            self.assertTrue(bn_node is not None)
+            dropout_node = self._get_node(m, torch.ops.aten.dropout.default)
+            self.assertEqual(dropout_node.args[2], train)
+
+        # Before wrapping: this is not OK
+        with self.assertRaises(NotImplementedError):
+            m.eval()
+        with self.assertRaises(NotImplementedError):
+            m.train()
+
+        # After wrapping: does not error and swaps the ops accordingly
+        torch.ao.quantization.allow_exported_model_train_eval(m)  # pyre-ignore[6]
+        m.eval()
+        _assert_ops_are_correct(m, train=False)  # pyre-ignore[6]
+        m.train()
+        _assert_ops_are_correct(m, train=True)  # pyre-ignore[6]
+
+        # After prepare but before wrapping: this is not OK
+        quantizer = XNNPACKQuantizer()
+        m = prepare_qat_pt2e(m, quantizer)  # pyre-ignore[6]
+        with self.assertRaises(NotImplementedError):
+            m.eval()
+        with self.assertRaises(NotImplementedError):
+            m.train()
+
+        # After prepare and after wrapping: does not error and swaps the ops accordingly
+        torch.ao.quantization.allow_exported_model_train_eval(m)
+        m.eval()
+        _assert_ops_are_correct(m, train=False)
+        m.train()
+        _assert_ops_are_correct(m, train=True)
+
+        # After convert but before wrapping: this is not OK
+        m = convert_pt2e(m, 
fold_quantize=True) + with self.assertRaises(NotImplementedError): + m.eval() + with self.assertRaises(NotImplementedError): + m.train() + + # After convert and after wrapping: does not error and swaps the ops accordingly + torch.ao.quantization.allow_exported_model_train_eval(m) + m.eval() + _assert_ops_are_correct(m, train=False) + m.train() + _assert_ops_are_correct(m, train=True) + + def test_constant_prop_preserve_metadata(self) -> None: + """Test to make sure the get_attr node for const propagated weight Tensor gets the correct + metadata (from original get_attr node from weight) + """ + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config() + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = export_for_training( + m, + example_inputs, + ).module() + weight_meta = None + for n in m.graph.nodes: # pyre-ignore[16] + if ( + n.op == "get_attr" + and next(iter(n.users)).target == torch.ops.aten.linear.default + ): + weight_meta = n.meta + break + assert weight_meta is not None, "Expect to find metadata for weight node" + + m = prepare_pt2e(m, quantizer) # pyre-ignore[6] + m(*example_inputs) + m = convert_pt2e(m) + + for n in m.graph.nodes: + if n.op == "get_attr" and "frozen_param" in n.target: + for key in n.meta: + self.assertEqual(n.meta[key], weight_meta[key]) + + def test_reentrant(self) -> None: + """Test we can safely call quantization apis multiple times""" + m = TestHelperModules.ConvBnReLU2dAndLinearReLU() + example_inputs = (torch.randn(3, 3, 10, 10),) + + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=True, is_qat=True) + ) + m.conv_bn_relu = export_for_training( # pyre-ignore[8] + m.conv_bn_relu, example_inputs + ).module() + m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer) # pyre-ignore[6,8] + m(*example_inputs) + m.conv_bn_relu = convert_pt2e(m.conv_bn_relu) # pyre-ignore[6, 8] + + quantizer = XNNPACKQuantizer().set_module_type( + torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False) + ) + m = export_for_training(m, example_inputs).module() + m = prepare_pt2e(m, quantizer) # pyre-ignore[6] + m = convert_pt2e(m) + + node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 4, + # one for weight + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 5, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 1, + } + node_list = [ + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.conv2d.default), + ns.call_function(torch.ops.aten.relu.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.linear.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ] + self.checkGraphModuleNodes( + m, expected_node_occurrence=node_occurrence, expected_node_list=node_list + ) + + def test_groupwise_per_channel_quant(self) -> None: + m = TestHelperModules.GroupwiseConv2d() + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=True) + 
quantizer.set_global(operator_config)
+        example_inputs = m.example_inputs()
+        m = self._quantize(m, quantizer, example_inputs)
+        # make sure it runs
+        m(*example_inputs)
+
+    def test_preserve_nn_module_stack(self) -> None:
+        """Test we can preserve nn_module_stack on replaced pattern's nodes"""
+        m = TestHelperModules.ConvBnReLU2dAndLinearReLU()
+        example_inputs = (torch.randn(3, 3, 10, 10),)
+
+        quantizer = XNNPACKQuantizer().set_global(
+            get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
+        )
+
+        def check_nn_module(node: torch.fx.Node) -> None:
+            self.assertTrue("nn_module_stack" in node.meta)
+            self.assertTrue(
+                "ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1]
+            )
+
+        m.conv_bn_relu = export_for_training(  # pyre-ignore[8]
+            m.conv_bn_relu, example_inputs
+        ).module()
+        for node in m.conv_bn_relu.graph.nodes:  # pyre-ignore[16]
+            if node.op not in ["placeholder", "output", "get_attr"]:
+                check_nn_module(node)
+        m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)  # pyre-ignore[6,8]
+        for node in m.conv_bn_relu.graph.nodes:  # pyre-ignore[16]
+            if node.name == "mul":
+                check_nn_module(node)
+
+    def test_speed(self) -> None:
+        import time  # noqa: F401
+
+        def dynamic_quantize_pt2e(model, example_inputs) -> torch.fx.GraphModule:
+            torch._dynamo.reset()
+            model = export_for_training(model, example_inputs).module()
+            # Per channel quantization for weight
+            # Dynamic quantization for activation
+            # For details, see: https://fburl.com/code/30zds51q
+            embedding_quantizer = EmbeddingQuantizer()
+            dynamic_quantizer = XNNPACKQuantizer()
+            operator_config_dynamic = get_symmetric_quantization_config(
+                is_per_channel=True, is_dynamic=True
+            )
+            dynamic_quantizer.set_global(operator_config_dynamic)
+            composed_quantizer = ComposableQuantizer(
+                [embedding_quantizer, dynamic_quantizer]
+            )
+            # prev = time.time()
+            model = prepare_qat_pt2e(model, composed_quantizer)  # pyre-ignore[6]
+            # cur = time.time()
+            # print("prepare time:", cur - prev)
+            # Without calibration, scale/zero_point will keep their initialized value of 1.0.
+            # Per channel quantization needs a proper scale/zero shape/value to work properly.
+            # So we need to run calibration before converting to quantized model.
+            model(*example_inputs)
+            # prev = time.time()
+            model = convert_pt2e(model)
+            # cur = time.time()
+            # uncomment to see the time
+            # print("convert time:", cur - prev)
+            return model
+
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        m = M().eval()
+        example_inputs = (torch.randn(5, 5),)
+        _ = dynamic_quantize_pt2e(m, example_inputs)
+
+    def test_multi_users_without_output_observer(self) -> None:
+        """
+        Test the case in which a node is used by multiple users
+        and has its output observer removed.
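+        Convert should still succeed, and the converted model should run.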
+ """ + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + x = self.conv(x) + return x, x + 1 + + example_inputs = (torch.randn(1, 3, 5, 5),) + m = M() + m = export_for_training(m, example_inputs).module() + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(), + ) + m = prepare_pt2e(m, quantizer) # pyre-ignore[6] + m(*example_inputs) + + # Remove output observer + observer_to_remove = None + for n in m.graph.nodes: + if n.op == "output": + observer_to_remove = n.args[0][0] + assert observer_to_remove.op == "call_module" + assert observer_to_remove.target.startswith("activation_post_process_") + break + assert observer_to_remove is not None + observer_to_remove.replace_all_uses_with(observer_to_remove.args[0]) + m.graph.erase_node(observer_to_remove) + m.recompile() + + # Convert should succeed + m = convert_pt2e(m) + m(*example_inputs) + + def test_fold_quantize_sym(self) -> None: + """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)""" + m = self._get_pt2e_quantized_linear() + + node_occurrence = { + # quantize op for weight node is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 1, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 1, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 1 + } + + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_fold_quantize_mixed(self) -> None: + """Test to make sure the quantized model gets quantized weight (quantize_per_channel op is folded)""" + m = self._get_pt2e_quantized_linear(mode=QuantizationMode.INT8_MIXED) + node_occurrence = { + # quantize op for weight node is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 1, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 1, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 1 + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + + def test_save_load(self) -> None: + """Test save/load a quantized model""" + m = self._get_pt2e_quantized_linear() + example_inputs = (torch.randn(2, 2),) + ref_res = m(*example_inputs) + + with TemporaryFileName() as fname: + # serialization + quantized_ep = torch.export.export(m, example_inputs, strict=True) + torch.export.save(quantized_ep, fname) + # deserialization + loaded_ep = torch.export.load(fname) + loaded_quantized_model = loaded_ep.module() + res = loaded_quantized_model(*example_inputs) + self.assertEqual(ref_res, res) + + +instantiate_parametrized_tests(TestQuantizePT2E) + + +class TestNumericDebugger(TestCase): + + def _extract_debug_handles(self, model) -> Dict[str, int]: + debug_handle_map: Dict[str, int] = {} + + def _extract_debug_handles_from_node(node: torch.fx.Node) -> None: + nonlocal debug_handle_map + if ( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][ + NUMERIC_DEBUG_HANDLE_KEY + ] + + bfs_trace_with_node_process(model, _extract_debug_handles_from_node) + return debug_handle_map + + def _assert_each_node_has_debug_handle(self, model) -> None: + def _assert_node_has_debug_handle(node: torch.fx.Node) -> None: + self.assertTrue( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in 
node.meta[CUSTOM_KEY],
+                f"Node {node} doesn't have debug handle",
+            )
+
+        bfs_trace_with_node_process(model, _assert_node_has_debug_handle)
+
+    def test_quantize_pt2e_preserve_handle(self) -> None:
+        m = TestHelperModules.Conv2dThenConv1d()
+        example_inputs = m.example_inputs()
+        ep = export_for_training(m, example_inputs)
+        generate_numeric_debug_handle(ep)
+        m = ep.module()
+
+        quantizer = XNNPACKQuantizer().set_global(
+            get_symmetric_quantization_config(is_per_channel=False)
+        )
+        m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
+        debug_handle_map = self._extract_debug_handles(m)
+        res_counter = Counter(debug_handle_map.values())
+        repeated_debug_handle_ids = [1, 2, 3]
+        # 3 ids were repeated because we copy over the id from node to its output observer
+        # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
+        for dh_id in repeated_debug_handle_ids:
+            self.assertEqual(res_counter[dh_id], 2)
+
+        m(*example_inputs)
+        m = convert_pt2e(m)
+        self._assert_each_node_has_debug_handle(ep)
+        debug_handle_map = self._extract_debug_handles(m)
+        res_counter = Counter(debug_handle_map.values())
+        # the same set of ids is repeated, because we copy over the id from observer/fake_quant to
+        # dequantize node
+        repeated_debug_handle_ids = [1, 2, 3]
+        for dh_id in repeated_debug_handle_ids:
+            self.assertEqual(res_counter[dh_id], 2)
+
+    def test_extract_results_from_loggers(self) -> None:
+        m = TestHelperModules.Conv2dThenConv1d()
+        example_inputs = m.example_inputs()
+        ep = export_for_training(m, example_inputs)
+        generate_numeric_debug_handle(ep)
+        m = ep.module()
+        m_ref_logger = prepare_for_propagation_comparison(m)  # pyre-ignore[6]
+
+        quantizer = XNNPACKQuantizer().set_global(
+            get_symmetric_quantization_config(is_per_channel=False)
+        )
+        m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
+        m(*example_inputs)
+        m = convert_pt2e(m)
+        m_quant_logger = prepare_for_propagation_comparison(m)
+
+        m_ref_logger(*example_inputs)
+        m_quant_logger(*example_inputs)
+        ref_results = extract_results_from_loggers(m_ref_logger)
+        quant_results = extract_results_from_loggers(m_quant_logger)
+        comparison_results = compare_results(
+            ref_results, quant_results  # pyre-ignore[6]
+        )
+        for node_summary in comparison_results.values():
+            if len(node_summary.results) > 0:
+                self.assertGreaterEqual(
+                    node_summary.results[0].sqnr, 35  # pyre-ignore[6]
+                )
+
+    def test_extract_results_from_loggers_list_output(self) -> None:
+        m = TestHelperModules.Conv2dWithSplit()
+        example_inputs = m.example_inputs()
+        ep = export_for_training(m, example_inputs)
+        generate_numeric_debug_handle(ep)
+        m = ep.module()
+        m_ref_logger = prepare_for_propagation_comparison(m)  # pyre-ignore[6]
+
+        quantizer = XNNPACKQuantizer().set_global(
+            get_symmetric_quantization_config(is_per_channel=False)
+        )
+        m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
+        m(*example_inputs)
+        m = convert_pt2e(m)
+        m_quant_logger = prepare_for_propagation_comparison(m)
+
+        m_ref_logger(*example_inputs)
+        m_quant_logger(*example_inputs)
+        ref_results = extract_results_from_loggers(m_ref_logger)
+        quant_results = extract_results_from_loggers(m_quant_logger)
+        comparison_results = compare_results(
+            ref_results, quant_results  # pyre-ignore[6]
+        )
+        for node_summary in comparison_results.values():
+            if len(node_summary.results) > 0:
+                sqnr = node_summary.results[0].sqnr
+                if isinstance(sqnr, list):
+                    for sqnr_i in sqnr:
+                        self.assertGreaterEqual(sqnr_i, 35)
+                else:
+                    self.assertGreaterEqual(sqnr, 35)  # pyre-ignore[6]
diff --git 
a/backends/openvino/tests/quantizer/test_representation.py b/backends/openvino/tests/quantizer/test_representation.py new file mode 100644 index 00000000000..83cecaec5ad --- /dev/null +++ b/backends/openvino/tests/quantizer/test_representation.py @@ -0,0 +1,311 @@ +# Owner(s): ["oncall: quantization"] +import copy +from typing import Any, Optional + +import torch +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) +from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer import Quantizer +from torch.export import export_for_training +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, + QuantizationTestCase, + skipIfNoQNNPACK, + TestHelperModules, +) + + +@skipIfNoQNNPACK +class TestPT2ERepresentation(QuantizationTestCase): + def _test_representation( + self, + model: torch.nn.Module, + example_inputs: tuple[Any, ...], + quantizer: Quantizer, + ref_node_occurrence: dict[ns, int], + non_ref_node_occurrence: dict[ns, int], + fixed_output_tol: Optional[float] = None, + output_scale_idx: int = 2, + ) -> None: + # resetting dynamo cache + torch._dynamo.reset() + model = export_for_training( + model, + example_inputs, + ).module() + model_copy = copy.deepcopy(model) + + model = prepare_pt2e(model, quantizer) # pyre-ignore[6] + # Calibrate + model(*example_inputs) + model = convert_pt2e(model, use_reference_representation=True) + self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence) + # make sure it runs + pt2e_quant_output = model(*example_inputs) + + # TODO: torchdynamo times out when we do this, we can enable numerical checking + # after that is fixed + model_copy = prepare_pt2e(model_copy, quantizer) # pyre-ignore[6] + # Calibrate + model_copy(*example_inputs) + model_copy = convert_pt2e(model_copy, use_reference_representation=False) + self.checkGraphModuleNodes( + model_copy, expected_node_occurrence=non_ref_node_occurrence + ) + pt2e_quant_output_copy = model_copy(*example_inputs) + + output_tol = None + if fixed_output_tol is not None: + output_tol = fixed_output_tol + else: + idx = 0 + for n in model_copy.graph.nodes: + if ( + n.target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ): + idx += 1 + if idx == output_scale_idx: + output_tol = n.args[1] + assert output_tol is not None + + # make sure the result is off by one at most in the quantized integer representation + self.assertTrue( + torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) + <= (2 * output_tol + 1e-5) + ) + + def test_static_linear(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 5),) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_dynamic_linear(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=False, 
is_dynamic=True + ) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 5),) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + fixed_output_tol=1e-4, + ) + + def test_conv2d(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv2d = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + return self.conv2d(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(1, 3, 3, 3),) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_add(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + return x + y + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + M().eval() + + example_inputs = ( + torch.randn(1, 3, 3, 3), + torch.randn(1, 3, 3, 3), + ) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_add_relu(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + out = x + y + out = torch.nn.functional.relu(out) + return out + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + + example_inputs = ( + torch.randn(1, 3, 3, 3), + torch.randn(1, 3, 3, 3), + ) + ref_node_occurrence = { + ns.call_function(out_dtype): 2, + } + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence=ref_node_occurrence, + non_ref_node_occurrence={}, + ) + + def test_maxpool2d(self): + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + m_eager = TestHelperModules.ConvMaxPool2d().eval() + + example_inputs = (torch.randn(1, 2, 2, 2),) + + self._test_representation( + m_eager, + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_qdq_per_channel(self): + """Test representation for quantize_per_channel and dequantize_per_channel op""" + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + # use per channel quantization for weight + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + M().eval() + + inputs = [ + (torch.randn(1, 5),), + (torch.randn(1, 3, 5),), + (torch.randn(1, 3, 3, 5),), + (torch.randn(1, 3, 3, 3, 5),), + ] + for example_inputs in inputs: + ref_node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_channel.default + ): 0, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 0, + } + non_ref_node_occurrence = { + # quantize_per_channel is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_channel.default + ): 0, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 1, + } + + self._test_representation( + 
M().eval(), + example_inputs, + quantizer, + ref_node_occurrence, + non_ref_node_occurrence, + output_scale_idx=2, + ) + + def test_qdq(self): + """Test representation for quantize and dequantize op""" + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + return x + y + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + M().eval() + + example_inputs = ( + torch.randn(1, 3, 3, 3), + torch.randn(1, 3, 3, 3), + ) + ref_node_occurrence = { + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 0, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 0, + } + non_ref_node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 3, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 3, + } + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence, + non_ref_node_occurrence, + ) diff --git a/backends/openvino/tests/quantizer/test_xnnpack_quantizer.py b/backends/openvino/tests/quantizer/test_xnnpack_quantizer.py new file mode 100644 index 00000000000..57aacf55263 --- /dev/null +++ b/backends/openvino/tests/quantizer/test_xnnpack_quantizer.py @@ -0,0 +1,1090 @@ +# Owner(s): ["oncall: mobile"] +import copy +import operator + +import torch +import torch._dynamo as torchdynamo +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) +from torch.ao.ns.fx.utils import compute_sqnr +from torch.ao.quantization import ( + default_dynamic_fake_quant, + default_dynamic_qconfig, + observer, + QConfig, + QConfigMapping, +) +from torch.ao.quantization.backend_config import get_qnnpack_backend_config +from torch.ao.quantization.qconfig import ( + default_per_channel_symmetric_qnnpack_qconfig, + default_symmetric_qnnpack_qconfig, + per_channel_weight_observer_range_neg_127_to_127, + weight_observer_range_neg_127_to_127, +) +from torch.ao.quantization.quantize_fx import ( + _convert_to_reference_decomposed_fx, + convert_to_reference_fx, + prepare_fx, +) +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.export import export_for_training +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, + PT2EQuantizationTestCase, + skip_if_no_torchvision, + skipIfNoQNNPACK, + TestHelperModules, +) +from torch.testing._internal.common_quantized import override_quantized_engine + + +@skipIfNoQNNPACK +class TestXNNPACKQuantizer(PT2EQuantizationTestCase): + def test_conv1d(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv1d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + 
TestHelperModules.ConvWithBNRelu(dim=1, relu=False, bn=False), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_conv2d(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.ConvWithBNRelu(relu=False, bn=False), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_conv1d_with_conv2d(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv1d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + m = TestHelperModules.Conv2dThenConv1d() + self._test_quantizer( + m, + m.example_inputs(), + quantizer, + node_occurrence, + node_list, + ) + + def test_linear(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + # Test with 2d inputs + example_inputs_2d = (torch.randn(9, 8),) + example_inputs_3d = (torch.randn(9, 10, 8),) + example_inputs_4d = (torch.randn(9, 10, 11, 8),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_linear_relu(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m_eager 
= TestHelperModules.LinearReluModel().eval() + + # Test with 2d inputs + example_inputs_2d = (torch.randn(1, 5),) + example_inputs_3d = (torch.randn(1, 2, 5),) + example_inputs_4d = (torch.randn(1, 2, 3, 5),) + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + # There should not be extra quantize_per_tensor or dequantize_per_tensors for relu + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], # node_list + False, # executorch_backend_config() does not fuse linear-relu + qconfig_mapping, + ) + + def test_conv_linear_no_permute(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + self._test_quantizer( + TestHelperModules.Conv2dWithTwoLinear(), + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_conv_linear(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + self._test_quantizer( + TestHelperModules.Conv2dWithTwoLinearPermute(), + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_linear_with_dynamic_shape(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + # Test with 2d inputs + example_inputs_3d = (torch.randn(9, 10, 8),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + 
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + self._test_quantizer( + m_eager, + example_inputs_3d, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + export_with_dynamic_shape=True, + ) + + def test_obs_sharing_ops(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m = TestHelperModules.Conv2dWithObsSharingOps().eval() + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.hardtanh.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.mean.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) + + def test_set_module_name(self): + class Sub(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Sub() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_module_name("sub", quantization_config) + node_occurrence = { + torch.ops.aten.linear.default: 2, + # input and output for the second linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + } + node_list = [ + # first linear is not quantized + torch.ops.aten.linear.default, + # second linear is quantized + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) + + def test_set_module_name_with_underscores(self) -> None: + """Test that if a module name has an 
underscore, we can still quantize it"""
+
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                # This module name has underscores, which can be part of a mangled
+                # name.
+                self.foo_bar = torch.nn.Linear(2, 2)
+                self.baz = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.baz(self.foo_bar(x))
+
+        quantizer = XNNPACKQuantizer()
+        # Leave the global config unset (no quantization) and use per-channel
+        # quantization for a specific submodule.
+        quantizer.set_module_name(
+            "foo_bar", get_symmetric_quantization_config(is_per_channel=True)
+        )
+        example_inputs = (torch.randn(2, 2),)
+        m = M().eval()
+        m = export_for_training(m, example_inputs).module()
+        m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
+        # Use a linear count instead of names because the names might change, but
+        # the order should be the same.
+        count = 0
+        for n in m.graph.nodes:
+            if n.op == "call_function" and n.target == torch.ops.aten.linear.default:
+                # Get the weight observer to see the per-channel vs per-tensor.
+                weight_observer_node = n.args[1]
+                if count == 0:
+                    # The weight tensor should be quantized per-channel for foo_bar.
+                    self.assertEqual(weight_observer_node.op, "call_module")
+                    observer_instance = getattr(m, weight_observer_node.target)
+                    self.assertEqual(
+                        observer_instance.qscheme, torch.per_channel_symmetric
+                    )
+                else:
+                    # For baz it should have no observer at all.
+                    self.assertNotEqual(weight_observer_node.op, "call_module")
+                count += 1
+
+    def test_set_module_type(self):
+        class Sub(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+                self.sub = Sub()
+
+            def forward(self, x):
+                x = self.linear(x)
+                x = self.sub(x)
+                return x
+
+        m = M().eval()
+        example_inputs = (torch.randn(3, 5),)
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_module_type(Sub, quantization_config)
+        node_occurrence = {
+            torch.ops.aten.linear.default: 2,
+            # input and output for the second linear
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+        }
+        node_list = [
+            # first linear is not quantized
+            torch.ops.aten.linear.default,
+            # second linear is quantized
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.linear.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        ]
+        self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)
+
+    def test_set_module_type_case_2(self):
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.conv = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=3,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=True,
+                )
+                self.conv2 = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=3,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=True,
+                )
+                self.conv3 = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=3,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=True,
+                )
+                self.relu = torch.nn.ReLU()
+                self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
+                self.fc = torch.nn.Linear(3, 16)
+
+            def forward(self, x):
+                x1 = self.conv(x)
+                x2 = self.relu(self.conv2(x1) + self.conv3(x1))
+                x3 = 
self.avgpool(x2) + x4 = torch.flatten(x3, 1) + x5 = self.fc(x4) + return x5 + + m = M().eval() + example_inputs = (torch.randn(1, 3, 16, 16),) + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + # We only want to annotate Linear type + quantizer.set_module_type(torch.nn.Linear, quantization_config) + node_occurrence = { + torch.ops.aten.conv2d.default: 3, + torch.ops.aten.linear.default: 1, + # input and output for the linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + } + node_list = [ + # only the linear is quantized + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) + + def test_propagate_annotation(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m = TestHelperModules.Conv2dPropAnnotaton().eval() + example_inputs = (torch.randn(1, 3, 5, 5),) + + # program capture + m = export_for_training( + m, + example_inputs, + ).module() + + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + for n in m.graph.nodes: + if n.target in [ + torch.ops.aten.view.default, + torch.ops.aten.hardtanh.default, + ]: + input_act = getattr(m, n.args[0].target) + output_act = getattr(m, next(iter(n.users)).target) + self.assertIs(input_act, output_act) + + m = convert_pt2e(m) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 5, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 5, + # note: quantize op for weights are const propagated + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_channel.default + ): 0, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 2, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_dynamic_linear(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=True, is_dynamic=True + ) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + act_affine_quant_obs = observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + qconfig = QConfig( + activation=act_affine_quant_obs, + weight=per_channel_weight_observer_range_neg_127_to_127, + ) + qconfig_mapping = QConfigMapping().set_global(qconfig) + # Test with 2d inputs + example_inputs_2d = (torch.randn(9, 8),) + example_inputs_4d = (torch.randn(9, 10, 
11, 8),) + for example_inputs in [example_inputs_2d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_dynamic_linear_int4_weight(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=True, + weight_qmin=0, + weight_qmax=15, + ) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + act_affine_quant_obs = observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + qconfig = QConfig( + activation=act_affine_quant_obs, + weight=per_channel_weight_observer_range_neg_127_to_127.with_args( + quant_min=0, quant_max=15 + ), + ) + qconfig_mapping = QConfigMapping().set_global(qconfig) + # Test with 2d inputs + example_inputs_2d = (torch.randn(9, 8),) + example_inputs_4d = (torch.randn(9, 10, 11, 8),) + for example_inputs in [example_inputs_2d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_qat_dynamic_linear(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=True, + is_qat=True, + ) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + node_occurrence = { + torch.ops.quantized_decomposed.choose_qparams.tensor: 2, + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + act_affine_quant_obs = default_dynamic_fake_quant + qconfig = QConfig( + activation=act_affine_quant_obs, + weight=per_channel_weight_observer_range_neg_127_to_127, + ) + qconfig_mapping = QConfigMapping().set_global(qconfig) + # Test with 2d inputs + example_inputs_2d = (torch.randn(9, 8),) + example_inputs_4d = (torch.randn(9, 10, 11, 8),) + for example_inputs in [example_inputs_2d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + is_qat=True, + ) + + def test_dynamic_linear_with_conv(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=True + ) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.ConvLinearWPermute().eval() + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # note: quantize 
op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + } + + training_ir_node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + # In training IR, the decomposition is different. + # `torch.ops.quantized_decomposed.quantize_per_tensor.default` nodes becomes + # `torch.ops.quantized_decomposed.quantize_per_tensor.tensor` nodes. + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } + act_affine_quant_obs = observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + qconfig = QConfig( + activation=act_affine_quant_obs, + weight=weight_observer_range_neg_127_to_127, + ) + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + qconfig_mapping = QConfigMapping().set_global(qconfig) + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + training_ir_node_occurrence=training_ir_node_occurrence, + ) + + def test_gru(self): + """this is a test for annotating fp32 GRU so that it produces + q -> dq -> fp32_gru -> q -> dq, this is currently enough for our use cases, + but we may change the annotation to be more precise in the future + """ + + class RNNDynamicModel(torch.nn.Module): + def __init__(self, mod_type): + super().__init__() + self.qconfig = default_dynamic_qconfig + if mod_type == "GRU": + self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) + if mod_type == "LSTM": + self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) + + def forward(self, input_tensor, hidden_tensor): + input_tensor = 1 * input_tensor + hidden_tensor = 1 * hidden_tensor + output_tensor, hidden_out = self.mod(input_tensor, hidden_tensor) + return 1 * output_tensor, 1 * hidden_out + + with override_quantized_engine("qnnpack"): + model_fx = RNNDynamicModel("GRU") + niter = 10 + example_inputs = ( + # input_tensor + torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float) + .unsqueeze(0) + .repeat(niter, 1, 1), + # hidden_tensor + # (D * num_layers, N, H_out) + torch.tensor([[[100, -155]]], dtype=torch.float).repeat(1, 3, 1), + ) + model_graph = copy.deepcopy(model_fx) + + qconfig_mapping = QConfigMapping().set_object_type( + operator.mul, default_symmetric_qnnpack_qconfig + ) + model_fx = prepare_fx( + model_fx, + qconfig_mapping, + example_inputs, + backend_config=get_qnnpack_backend_config(), + ) + model_fx(*example_inputs) + model_fx = _convert_to_reference_decomposed_fx(model_fx) + + with torchdynamo.config.patch(allow_rnn=True): + model_graph = export_for_training( + model_graph, + example_inputs, + ).module() + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=False + ) + quantizer.set_global(quantization_config) + model_graph = prepare_pt2e(model_graph, quantizer) + model_graph(*example_inputs) + model_graph = convert_pt2e(model_graph) + self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs)) + + def test_linear_gru(self): + """this test is to make sure GRU annotation does not interfere with 
linear annotation""" + + class RNNDynamicModel(torch.nn.Module): + def __init__(self, mod_type): + super().__init__() + self.qconfig = default_dynamic_qconfig + self.linear = torch.nn.Linear(2, 2) + if mod_type == "GRU": + self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) + if mod_type == "LSTM": + self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) + + def forward(self, input_tensor, hidden_tensor): + input_tensor = self.linear(input_tensor) + input_tensor = 1 * input_tensor + hidden_tensor = 1 * hidden_tensor + output_tensor, hidden_out = self.mod(input_tensor, hidden_tensor) + return 1 * output_tensor, 1 * hidden_out + + with override_quantized_engine("qnnpack"): + model_fx = RNNDynamicModel("GRU") + niter = 10 + example_inputs = ( + # input_tensor + torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float) + .unsqueeze(0) + .repeat(niter, 1, 1), + # hidden_tensor + # (D * num_layers, N, H_out) + torch.tensor([[[100, -155]]], dtype=torch.float).repeat(1, 3, 1), + ) + model_graph = copy.deepcopy(model_fx) + + qconfig_mapping = ( + QConfigMapping() + .set_object_type(operator.mul, default_symmetric_qnnpack_qconfig) + .set_object_type(torch.nn.Linear, default_symmetric_qnnpack_qconfig) + ) + model_fx = prepare_fx( + model_fx, + qconfig_mapping, + example_inputs, + backend_config=get_qnnpack_backend_config(), + ) + model_fx(*example_inputs) + model_fx = _convert_to_reference_decomposed_fx(model_fx) + + with torchdynamo.config.patch(allow_rnn=True): + model_graph = export_for_training( + model_graph, + example_inputs, + ).module() + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=False + ) + quantizer.set_global(quantization_config) + model_graph = prepare_pt2e(model_graph, quantizer) + model_graph(*example_inputs) + model_graph = convert_pt2e(model_graph) + self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs)) + + def test_add_and_inplace_add(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = ( + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + ) + node_occurrence = { + # two input and one output for first add, and output for second add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # TODO torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.AddInplaceAdd(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_mul_and_inplace_mul(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = ( + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + ) + node_occurrence = { + # two input and one output for first add, and output for second add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + } + node_list = [ + 
torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.mul.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # TODO torch.ops.aten.mul.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.MulInplaceMul(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_add_mul_scalar(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + # two input and one output for first add, and output for second add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, + # TODO torch.ops.quantized_decomposed.dequantize_per_tensor.default: 9, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.mul.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # TODO torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # TODO torch.ops.aten.mul.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.AddMulScalar(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_mul_float32_max(self): + class M(torch.nn.Module): + def forward(self, x): + return x * 3.4028235e38 + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + # not quantized + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } + node_list = [ + torch.ops.aten.mul.Tensor, + ] + self._test_quantizer( + M(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_add_mul_long(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.t = torch.tensor([100]) + + def forward(self, x): + x = x + self.t + x = x * self.t + return x + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + # not quantized + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } + node_list = [ + torch.ops.aten.add.Tensor, + torch.ops.aten.mul.Tensor, + ] + self._test_quantizer( + M(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_cat_same_node(self): + """Ensure that concatenating the same node does not cause any unexpected behavior""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.cat([x, x]) + return x + + quantizer = XNNPACKQuantizer() + quantization_config = 
get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.cat.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer( + M(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + +# TODO: express this using self._test_quantizer, add test for inception_v4 +class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase): + @skip_if_no_torchvision + @skipIfNoQNNPACK + def test_resnet18(self): + import torchvision + + with override_quantized_engine("qnnpack"): + example_inputs = (torch.randn(1, 3, 224, 224),) + m = torchvision.models.resnet18().eval() + m_copy = copy.deepcopy(m) + # program capture + m = export_for_training( + m, + example_inputs, + ).module() + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m = prepare_pt2e(m, quantizer) + # checking that we inserted observers correctly for maxpool operator (input and + # output share observer instance) + self.assertEqual( + id(m.activation_post_process_3), id(m.activation_post_process_2) + ) + after_prepare_result = m(*example_inputs) + m = convert_pt2e(m) + + after_quant_result = m(*example_inputs) + + # comparing with existing fx graph mode quantization reference flow + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + backend_config = get_qnnpack_backend_config() + m_fx = prepare_fx( + m_copy, qconfig_mapping, example_inputs, backend_config=backend_config + ) + after_prepare_result_fx = m_fx(*example_inputs) + m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config) + + after_quant_result_fx = m_fx(*example_inputs) + + # the result matches exactly after prepare + # Note: this currently will always be true since we are inserting observers + # the check becomes useful when we add qat examples + # but we can still manully inspect the printed observers to make sure + # it matches + self.assertEqual(after_prepare_result, after_prepare_result_fx) + self.assertEqual( + compute_sqnr(after_prepare_result, after_prepare_result_fx), + torch.tensor(float("inf")), + ) + # there are slight differences after convert due to different implementations + # of quant/dequant + self.assertTrue( + torch.max(after_quant_result - after_quant_result_fx) < 1e-1 + ) + self.assertTrue( + compute_sqnr(after_quant_result, after_quant_result_fx) > 35 + ) From e9f36b63a37ea991a996eff9557731a15c71c330 Mon Sep 17 00:00:00 2001 From: daniil-lyakhov Date: Tue, 25 Feb 2025 17:13:43 +0100 Subject: [PATCH 2/6] WIP --- .../tests/quantizer/test_pt2e_quantization.py | 263 ++++++++++++++++-- .../test/quantizer/test_pt2e_quantization.py | 8 +- 2 files changed, 238 insertions(+), 33 deletions(-) diff --git a/backends/openvino/tests/quantizer/test_pt2e_quantization.py b/backends/openvino/tests/quantizer/test_pt2e_quantization.py index 8f15ada00bd..0c4b6121668 100644 --- a/backends/openvino/tests/quantizer/test_pt2e_quantization.py +++ 
b/backends/openvino/tests/quantizer/test_pt2e_quantization.py @@ -6,9 +6,19 @@ # pyre-unsafe +import copy from collections import Counter +import torch.ao.nn.quantized.reference as nnqr from typing import Dict, Tuple, Optional +from torch.ao.quantization.backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + DTypeWithConstraints, + ObservationType, +) +from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node import torch from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, @@ -17,6 +27,12 @@ from executorch.backends.openvino.quantizer.quantizer import OpenVINOQuantizer from executorch.backends.openvino.quantizer.quantizer import QuantizationMode +from torch.ao.quantization.quantize_pt2e import ( + _convert_to_reference_decomposed_fx, + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) from torch.ao.quantization import ( compare_results, CUSTOM_KEY, @@ -35,6 +51,7 @@ weight_observer_range_neg_127_to_127, ) from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.quantize_fx import prepare_fx from torch.ao.quantization.quantize_pt2e import ( convert_pt2e, prepare_pt2e, @@ -65,7 +82,6 @@ def run(self, result=None): """ with disable_patching(): super().run(result) - def _get_pt2e_quantized_linear( self, mode: Optional[QuantizationMode] = None @@ -143,7 +159,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: ) def test_composable_quantizer_linear_conv(self) -> None: - #TODO: impelment whe dynamic quantization will be supported by OpenVINOQuantizer + # TODO: implement once dynamic quantization is supported by OpenVINOQuantizer pass def test_embedding_conv_linear_quantization(self) -> None: @@ -574,43 +590,232 @@ def forward(self, x): def test_fold_quantize_sym(self) -> None: """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)""" m = self._get_pt2e_quantized_linear() - - node_occurrence = { - # quantize op for weight node is folded - ns.call_function( - torch.ops.quantized_decomposed.quantize_per_tensor.default - ): 1, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_channel.default - ): 1, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_tensor.default - ): 1 - } - self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + ref_q = { + "quantize_per_tensor_default": ( + None, + 0.010390480048954487, + 127, + 0, + 255, + torch.uint8, + ), + "dequantize_per_tensor_default": ( + None, + 0.010390480048954487, + 127, + 0, + 255, + torch.uint8, + ), + "dequantize_per_channel_default": ( + torch.tensor([[-78, -128], [-127, 76]], dtype=torch.int8), + torch.tensor([0.0029, 0.0036]), + torch.tensor([0, 0]), + 0, + -128, + 127, + torch.int8, + ), + } + self._check_quantization_with_ref(m, ref_q) def test_fold_quantize_mixed(self) -> None: """Test to make sure the quantized model gets quantized weight (quantize_per_channel op is folded)""" m = self._get_pt2e_quantized_linear(mode=QuantizationMode.INT8_MIXED) + ref_q = { + "quantize_per_tensor_default": ( + None, + 0.006073841359466314, + 37, + 0, + 255, + torch.uint8, + ), + "dequantize_per_tensor_default": ( + None, + 0.006073841359466314, + 37, + 0, + 255, + torch.uint8, + ), + "dequantize_per_channel_default": ( + torch.tensor([[-78, -128], [-127, 76]], dtype=torch.int8), + torch.tensor([0.0029, 0.0036]), + torch.tensor([0, 0]), + 0, + -128, + 127, + torch.int8, + ), + } + 
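For orientation, a minimal sketch of the positional layout the ref_q tuples above mirror (illustrative values only, not taken from the test; the registration import is an assumption about where torch.ops.quantized_decomposed gets registered):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # assumed to register torch.ops.quantized_decomposed

x = torch.randn(2, 2)
# per-tensor ops: (input, scale, zero_point, quant_min, quant_max, dtype)
q = torch.ops.quantized_decomposed.quantize_per_tensor.default(x, 0.0104, 127, 0, 255, torch.uint8)
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(q, 0.0104, 127, 0, 255, torch.uint8)
# per-channel ops: (input, scales, zero_points, axis, quant_min, quant_max, dtype)
w = torch.randn(2, 2)
scales, zero_points = torch.tensor([0.0029, 0.0036]), torch.tensor([0, 0])
wq = torch.ops.quantized_decomposed.quantize_per_channel.default(w, scales, zero_points, 0, -128, 127, torch.int8)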
self._check_quantization_with_ref(m, ref_q) + + def _check_quantization_with_ref(self, model: torch.fx.GraphModule, ref: Dict): + matches = 0 + for node in model.graph.nodes: + if node.name not in ref: + continue + matches += 1 + ref_values = ref[node.name] + for idx, ref_value in enumerate(ref_values): + if ref_value is None: + continue + if isinstance(ref_value, torch.Tensor): + self.assertEqual( + get_tensor_constant_from_node(node.args[idx], model), + ref_value, + atol=1e-3, + rtol=1e-3, + ) + continue + if isinstance(ref_value, float): + self.assertEqual(node.args[idx], ref_value, atol=1e-3, rtol=1e-3) + continue + assert node.args[idx] == ref_value + + assert len(ref) == matches + + def _get_backend_config(self): + def _get_linear_configs(): + observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + dtype_configs = [ + DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + ] + linear_configs: list[BackendPatternConfig] = [] + # linear module + linear_configs.append( + BackendPatternConfig(torch.nn.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + # functional linear + linear_configs.append( + BackendPatternConfig(torch.nn.functional.linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + return linear_configs + + def _get_conv_configs(): + pass + + return BackendConfig("OpenVINO").set_backend_pattern_configs( + _get_linear_configs() + ) + # .set_backend_pattern_configs(_get_conv_configs()) + + def _test_quantizer( + self, + model, + example_inputs, + quantizer, + expected_node_occurrence, + expected_node_list=None, + check_against_fx_quant=False, + fx_qconfig_mapping=None, + export_with_dynamic_shape=False, + is_qat=False, + is_debug_mode=False, + training_ir_node_occurrence=None, + ): + # resetting dynamo cache + torch._dynamo.reset() + m_eager = model.eval() + + # program capture + m = copy.deepcopy(m_eager) + dynamic_shapes = tuple( + {0: torch.export.Dim("dim")} if i == 0 else None + for i in range(len(example_inputs)) + ) + m = export_for_training( + m, + example_inputs, + dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + ).module() + + if is_qat: + m = prepare_qat_pt2e(m, quantizer) + else: + m = prepare_pt2e(m, quantizer) + if is_debug_mode: + print("prepared model:", m) + # Calibrate + m(*example_inputs) + m = convert_pt2e(m) + if is_debug_mode: + print("quantized model", m) + + pt2_quant_output = m(*example_inputs) node_occurrence = { - # quantize op for weight node is folded - ns.call_function( - torch.ops.quantized_decomposed.quantize_per_tensor.default - ): 1, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_channel.default - ): 1, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_tensor.default - ): 1 + ns.call_function(k): v for k, v in expected_node_occurrence.items() } - self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) - + if expected_node_list is None: + expected_node_list = [] + node_list = [ns.call_function(n) for n in expected_node_list] + self.checkGraphModuleNodes( + m, expected_node_occurrence=node_occurrence, expected_node_list=node_list + ) + if check_against_fx_quant: + qconfig_mapping = fx_qconfig_mapping + backend_config = self._get_backend_config() + 
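# The branch below cross-checks the PT2E result against the legacy FX flow:
# prepare_fx inserts observers according to the given QConfigMapping,
# _convert_to_reference_decomposed_fx lowers the calibrated model to the same
# quantized_decomposed q/dq ops, and the node counts and final outputs of the
# two paths are then compared.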
m_copy = copy.deepcopy(m_eager) + m_fx = prepare_fx( + m_copy, qconfig_mapping, example_inputs, backend_config=backend_config + ) + m_fx(*example_inputs) + m_fx = _convert_to_reference_decomposed_fx( + m_fx, backend_config=backend_config + ) + m_fx = export_for_training( + m_fx, + example_inputs, + dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + ).module() + node_occurrence = {} + for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items(): + if k in expected_node_occurrence: + node_occurrence[ns.call_function(v)] = expected_node_occurrence[k] + if training_ir_node_occurrence is not None: + node_occurrence = { + ns.call_function(k): v + for k, v in training_ir_node_occurrence.items() + } + self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence) + fx_quant_output = m_fx(*example_inputs) + self.assertEqual(fx_quant_output, pt2_quant_output) + return m + # activation_observer = observer.HistogramObserver + default_qconfig = QConfig( + activation=activation_observer, weight=weight_observer + ) + qconfig_mapping = QConfigMapping() + qconfig_mapping.set_global(QConfig(activation=None, weight=None)) + qconfig_mapping.set_object_type(torch.nn.Linear, default_qconfig) + self._quantize() + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + check_against_fx_quant=True, + fx_qconfig_mapping=qconfig_mapping, + ) + # self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence, ) def test_save_load(self) -> None: """Test save/load a quantized model""" - m = self._get_pt2e_quantized_linear() + m = self._get_linear() example_inputs = (torch.randn(2, 2),) ref_res = m(*example_inputs) diff --git a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py index ea6116a6f0a..477ac83ffd9 100644 --- a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py +++ b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py @@ -54,7 +54,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): - def _get_pt2e_quantized_linear( + def _get_linear( self, is_per_channel: bool = False ) -> torch.fx.GraphModule: class M(torch.nn.Module): @@ -651,7 +651,7 @@ def forward(self, x): def test_fold_quantize(self) -> None: """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)""" - m = self._get_pt2e_quantized_linear() + m = self._get_linear() node_occurrence = { # quantize op for weight node is folded ns.call_function( @@ -665,7 +665,7 @@ def test_fold_quantize(self) -> None: def test_fold_quantize_per_channel(self) -> None: """Test to make sure the quantized model gets quantized weight (quantize_per_channel op is folded)""" - m = self._get_pt2e_quantized_linear(is_per_channel=True) + m = self._get_linear(is_per_channel=True) node_occurrence = { # quantize op for weight node is folded ns.call_function( @@ -682,7 +682,7 @@ def test_fold_quantize_per_channel(self) -> None: def test_save_load(self) -> None: """Test save/load a quantized model""" - m = self._get_pt2e_quantized_linear() + m = self._get_linear() example_inputs = (torch.randn(2, 2),) ref_res = m(*example_inputs) From 467ae4a55d564324b9252ee1bbba1c8eb2b5b2eb Mon Sep 17 00:00:00 2001 From: daniil-lyakhov Date: Tue, 25 Feb 2025 19:13:56 +0100 Subject: [PATCH 3/6] WIP --- .../tests/quantizer/test_pt2e_quantization.py | 277 +++++------------- .../tests/quantizer/test_xnnpack_quantizer.py | 142 --------- 2 files changed, 81 insertions(+), 338 deletions(-) diff --git 
a/backends/openvino/tests/quantizer/test_pt2e_quantization.py b/backends/openvino/tests/quantizer/test_pt2e_quantization.py index 0c4b6121668..5e56bbbe4db 100644 --- a/backends/openvino/tests/quantizer/test_pt2e_quantization.py +++ b/backends/openvino/tests/quantizer/test_pt2e_quantization.py @@ -158,12 +158,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: ), ) - def test_composable_quantizer_linear_conv(self) -> None: - # TODO: implement once dynamic quantization is supported by OpenVINOQuantizer - pass - def test_embedding_conv_linear_quantization(self) -> None: - # Mark m_eager = TestHelperModules.EmbeddingConvLinearModule().eval() indices = torch.tensor( [ @@ -203,57 +198,87 @@ def test_embedding_conv_linear_quantization(self) -> None: ) indices = torch.unsqueeze(indices, 0) example_inputs = (indices,) + quantizer = OpenVINOQuantizer() - embedding_quantizer = EmbeddingQuantizer() - dynamic_quantizer = XNNPACKQuantizer() - quantization_config_dynamic = get_symmetric_quantization_config( - is_per_channel=True, is_dynamic=True - ) - dynamic_quantizer.set_global(quantization_config_dynamic) - static_quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - static_quantizer.set_global(quantization_config) - composed_quantizer = ComposableQuantizer( - [embedding_quantizer, dynamic_quantizer, static_quantizer] - ) + m = self._quantize(m_eager, quantizer, example_inputs, is_qat=False) - act_affine_quant_obs = observer.PlaceholderObserver.with_args( - dtype=torch.qint8, - qscheme=torch.per_tensor_affine, - quant_min=-128, - quant_max=127, - eps=2**-12, - is_dynamic=True, - ) - dynamic_qconfig = QConfig( - activation=act_affine_quant_obs, - weight=per_channel_weight_observer_range_neg_127_to_127, - ) - qconfig = default_per_channel_symmetric_qnnpack_qconfig - qconfig_mapping = QConfigMapping().set_global(qconfig) - qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig) - qconfig_mapping = qconfig_mapping.set_object_type( - torch.nn.Embedding, float_qparams_weight_only_qconfig - ) - - node_occurrence = { - torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, - torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, - # note: quantize op for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + ref_q = { + # First conv + "quantize_per_tensor_default": ( + None, + 0.01585131697356701, + 127, + 0, + 255, + torch.uint8, + ), + "dequantize_per_tensor_default": ( + None, + 0.01585131697356701, + 127, + 0, + 255, + torch.uint8, + ), + "dequantize_per_channel_default": ( + None, + torch.tensor( + [ + 0.0015, + 0.0015, + 0.0015, + 0.0016, + 0.0015, + 0.0016, + 0.0014, + 0.0014, + 0.0015, + 0.0015, + 0.0016, + 0.0015, + 0.0015, + 0.0016, + 0.0016, + 0.0015, + ] + ), + torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + 0, + -128, + 127, + torch.int8, + ), + # First linear + "quantize_per_tensor_default_1": ( + None, + 0.016017982736229897, + 127, + 0, + 255, + torch.uint8, + ), + "dequantize_per_tensor_default_1": ( + None, + 0.016017982736229897, + 127, + 0, + 255, + torch.uint8, + ), + "dequantize_per_channel_default_1": ( + None, + torch.tensor( + [0.0019, 0.0019, 0.0020, 0.0018, 0.0019, 0.0019, 0.0018, 0.0018] + ), + torch.tensor([0, 0, 0, 0, 0, 0, 0, 
0]), + 0, + -128, + 127, + torch.int8, + ), + # TODO: embedding } - self._test_quantizer( - m_eager, - example_inputs, - composed_quantizer, - node_occurrence, - [], - True, - qconfig_mapping, - ) + self._check_quantization_with_ref(m, ref_q) def test_disallow_eval_train(self) -> None: m = TestHelperModules.ConvWithBNRelu(relu=True) @@ -272,7 +297,7 @@ def test_disallow_eval_train(self) -> None: # After prepare: still not OK quantizer = OpenVINOQuantizer() - m = prepare_qat_pt2e(m, quantizer) # pyre-ignore[6] + m = prepare_pt2e(m, quantizer) # pyre-ignore[6] with self.assertRaises(NotImplementedError): m.eval() with self.assertRaises(NotImplementedError): @@ -308,11 +333,9 @@ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.bn = torch.nn.BatchNorm2d(3) - self.dropout = torch.nn.Dropout(0.5) def forward(self, x): x = self.bn(x) - x = self.dropout(x) return x m = M().train() @@ -324,8 +347,6 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: bn_op = bn_train_op if train else bn_eval_op bn_node = self._get_node(m, bn_op) self.assertTrue(bn_node is not None) - dropout_node = self._get_node(m, torch.ops.aten.dropout.default) - self.assertEqual(dropout_node.args[2], train) # Before wrapping: this is not OK with self.assertRaises(NotImplementedError): @@ -341,8 +362,8 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: _assert_ops_are_correct(m, train=True) # pyre-ignore[6] # After prepare but before wrapping: this is not OK - quantizer = XNNPACKQuantizer() - m = prepare_qat_pt2e(m, quantizer) # pyre-ignore[6] + quantizer = OpenVINOQuantizer() + m = prepare_pt2e(m, quantizer) # pyre-ignore[6] with self.assertRaises(NotImplementedError): m.eval() with self.assertRaises(NotImplementedError): @@ -677,142 +698,6 @@ def _check_quantization_with_ref(self, model: torch.fx.GraphModule, ref: Dict): assert len(ref) == matches - def _get_backend_config(self): - def _get_linear_configs(): - observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT - dtype_configs = [ - DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.float, - weight_dtype=torch.qint8, - bias_dtype=torch.float, - ) - ] - linear_configs: list[BackendPatternConfig] = [] - # linear module - linear_configs.append( - BackendPatternConfig(torch.nn.Linear) - .set_observation_type(observation_type) # noqa: E131 - .set_dtype_configs(dtype_configs) - .set_root_module(torch.nn.Linear) - .set_reference_quantized_module(nnqr.Linear) - ) - # functional linear - linear_configs.append( - BackendPatternConfig(torch.nn.functional.linear) - .set_observation_type(observation_type) # noqa: E131 - .set_dtype_configs(dtype_configs) - ._set_input_type_to_index({"weight": 1, "bias": 2}) - ) - return linear_configs - - def _get_conv_configs(): - pass - - return BackendConfig("OpenVINO").set_backend_pattern_configs( - _get_linear_configs() - ) - # .set_backend_pattern_configs(_get_conv_configs()) - - def _test_quantizer( - self, - model, - example_inputs, - quantizer, - expected_node_occurrence, - expected_node_list=None, - check_against_fx_quant=False, - fx_qconfig_mapping=None, - export_with_dynamic_shape=False, - is_qat=False, - is_debug_mode=False, - training_ir_node_occurrence=None, - ): - # resetting dynamo cache - torch._dynamo.reset() - m_eager = model.eval() - - # program capture - m = copy.deepcopy(m_eager) - dynamic_shapes = tuple( - {0: torch.export.Dim("dim")} if i == 0 else None - for i in range(len(example_inputs)) - ) - m = export_for_training( - m, 
- example_inputs, - dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, - ).module() - - if is_qat: - m = prepare_qat_pt2e(m, quantizer) - else: - m = prepare_pt2e(m, quantizer) - if is_debug_mode: - print("prepared model:", m) - # Calibrate - m(*example_inputs) - m = convert_pt2e(m) - if is_debug_mode: - print("quantized model", m) - - pt2_quant_output = m(*example_inputs) - node_occurrence = { - ns.call_function(k): v for k, v in expected_node_occurrence.items() - } - if expected_node_list is None: - expected_node_list = [] - node_list = [ns.call_function(n) for n in expected_node_list] - self.checkGraphModuleNodes( - m, expected_node_occurrence=node_occurrence, expected_node_list=node_list - ) - if check_against_fx_quant: - qconfig_mapping = fx_qconfig_mapping - backend_config = self._get_backend_config() - m_copy = copy.deepcopy(m_eager) - m_fx = prepare_fx( - m_copy, qconfig_mapping, example_inputs, backend_config=backend_config - ) - m_fx(*example_inputs) - m_fx = _convert_to_reference_decomposed_fx( - m_fx, backend_config=backend_config - ) - m_fx = export_for_training( - m_fx, - example_inputs, - dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, - ).module() - node_occurrence = {} - for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items(): - if k in expected_node_occurrence: - node_occurrence[ns.call_function(v)] = expected_node_occurrence[k] - if training_ir_node_occurrence is not None: - node_occurrence = { - ns.call_function(k): v - for k, v in training_ir_node_occurrence.items() - } - self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence) - fx_quant_output = m_fx(*example_inputs) - self.assertEqual(fx_quant_output, pt2_quant_output) - return m - # activation_observer = observer.HistogramObserver - default_qconfig = QConfig( - activation=activation_observer, weight=weight_observer - ) - qconfig_mapping = QConfigMapping() - qconfig_mapping.set_global(QConfig(activation=None, weight=None)) - qconfig_mapping.set_object_type(torch.nn.Linear, default_qconfig) - self._quantize() - self._test_quantizer( - m, - example_inputs, - quantizer, - node_occurrence, - check_against_fx_quant=True, - fx_qconfig_mapping=qconfig_mapping, - ) - # self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence, ) - def test_save_load(self) -> None: """Test save/load a quantized model""" m = self._get_linear() diff --git a/backends/openvino/tests/quantizer/test_xnnpack_quantizer.py b/backends/openvino/tests/quantizer/test_xnnpack_quantizer.py index 57aacf55263..34095a3c65a 100644 --- a/backends/openvino/tests/quantizer/test_xnnpack_quantizer.py +++ b/backends/openvino/tests/quantizer/test_xnnpack_quantizer.py @@ -575,148 +575,6 @@ def test_dynamic_linear(self): qconfig_mapping, ) - def test_dynamic_linear_int4_weight(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config( - is_per_channel=True, - is_dynamic=True, - weight_qmin=0, - weight_qmax=15, - ) - quantizer.set_global(quantization_config) - m_eager = TestHelperModules.TwoLinearModule().eval() - - node_occurrence = { - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, - # note: quantize op for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, - } - 
act_affine_quant_obs = observer.PlaceholderObserver.with_args( - dtype=torch.qint8, - qscheme=torch.per_tensor_affine, - quant_min=-128, - quant_max=127, - eps=2**-12, - is_dynamic=True, - ) - qconfig = QConfig( - activation=act_affine_quant_obs, - weight=per_channel_weight_observer_range_neg_127_to_127.with_args( - quant_min=0, quant_max=15 - ), - ) - qconfig_mapping = QConfigMapping().set_global(qconfig) - # Test with 2d inputs - example_inputs_2d = (torch.randn(9, 8),) - example_inputs_4d = (torch.randn(9, 10, 11, 8),) - for example_inputs in [example_inputs_2d, example_inputs_4d]: - self._test_quantizer( - m_eager, - example_inputs, - quantizer, - node_occurrence, - [], - True, - qconfig_mapping, - ) - - def test_qat_dynamic_linear(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config( - is_per_channel=True, - is_dynamic=True, - is_qat=True, - ) - quantizer.set_global(quantization_config) - m_eager = TestHelperModules.TwoLinearModule().eval() - - node_occurrence = { - torch.ops.quantized_decomposed.choose_qparams.tensor: 2, - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, - # note: quantize op for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, - } - act_affine_quant_obs = default_dynamic_fake_quant - qconfig = QConfig( - activation=act_affine_quant_obs, - weight=per_channel_weight_observer_range_neg_127_to_127, - ) - qconfig_mapping = QConfigMapping().set_global(qconfig) - # Test with 2d inputs - example_inputs_2d = (torch.randn(9, 8),) - example_inputs_4d = (torch.randn(9, 10, 11, 8),) - for example_inputs in [example_inputs_2d, example_inputs_4d]: - self._test_quantizer( - m_eager, - example_inputs, - quantizer, - node_occurrence, - [], - True, - qconfig_mapping, - is_qat=True, - ) - - def test_dynamic_linear_with_conv(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config( - is_per_channel=False, is_dynamic=True - ) - quantizer.set_global(quantization_config) - m_eager = TestHelperModules.ConvLinearWPermute().eval() - - node_occurrence = { - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, - # note: quantize op for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, - } - - training_ir_node_occurrence = { - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - # In training IR, the decomposition is different. - # `torch.ops.quantized_decomposed.quantize_per_tensor.default` nodes becomes - # `torch.ops.quantized_decomposed.quantize_per_tensor.tensor` nodes. 
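The training-IR comment above contrasts the `.default` overload (scale/zero_point burned into the graph as Python scalars) with the `.tensor` overload (scale/zero_point computed at runtime). A minimal sketch of the dynamic pattern, assuming the quantized_decomposed library is registered by the torch.ao.quantization.fx._decomposed import:

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # assumed registration point

x = torch.randn(4)
# choose_qparams derives scale/zero_point from the runtime tensor...
scale, zero_point = torch.ops.quantized_decomposed.choose_qparams.tensor(x, -128, 127, 2**-12, torch.int8)
# ...and the .tensor overload consumes them as 0-d tensors
q = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(x, scale, zero_point, -128, 127, torch.int8)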
- torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, - # note: quantize op for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, - } - act_affine_quant_obs = observer.PlaceholderObserver.with_args( - dtype=torch.qint8, - qscheme=torch.per_tensor_affine, - quant_min=-128, - quant_max=127, - eps=2**-12, - is_dynamic=True, - ) - qconfig = QConfig( - activation=act_affine_quant_obs, - weight=weight_observer_range_neg_127_to_127, - ) - # Test with 2d inputs - example_inputs = (torch.randn(2, 3, 4, 4),) - qconfig_mapping = QConfigMapping().set_global(qconfig) - self._test_quantizer( - m_eager, - example_inputs, - quantizer, - node_occurrence, - [], - True, - qconfig_mapping, - training_ir_node_occurrence=training_ir_node_occurrence, - ) - def test_gru(self): """this is a test for annotating fp32 GRU so that it produces q -> dq -> fp32_gru -> q -> dq, this is currently enough for our use cases, but we may change the annotation to be more precise in the future From 79b112f9fa83b2fd1621cfc273172156a109c1b1 Mon Sep 17 00:00:00 2001 From: daniil-lyakhov Date: Wed, 26 Feb 2025 15:32:51 +0100 Subject: [PATCH 4/6] OpenVINOQuantizer complete tests --- .../quantizer/test_openvino_quantizer.py | 558 +++++++++++ .../tests/quantizer/test_pt2e_quantization.py | 390 +------ .../tests/quantizer/test_representation.py | 188 +--- .../tests/quantizer/test_xnnpack_quantizer.py | 948 ------------------ .../test/quantizer/test_pt2e_quantization.py | 8 +- 5 files changed, 619 insertions(+), 1473 deletions(-) create mode 100644 backends/openvino/tests/quantizer/test_openvino_quantizer.py delete mode 100644 backends/openvino/tests/quantizer/test_xnnpack_quantizer.py diff --git a/backends/openvino/tests/quantizer/test_openvino_quantizer.py b/backends/openvino/tests/quantizer/test_openvino_quantizer.py new file mode 100644 index 00000000000..ebc10a29a73 --- /dev/null +++ b/backends/openvino/tests/quantizer/test_openvino_quantizer.py @@ -0,0 +1,558 @@ +# Owner(s): ["oncall: mobile"] +import copy +import unittest + +import torch +import torch._dynamo as torchdynamo +from executorch.backends.openvino.quantizer.quantizer import OpenVINOQuantizer + +from nncf.torch import disable_patching +from torch.ao.quantization import default_dynamic_qconfig + +from torch.export import export_for_training +from torch.testing._internal.common_quantization import ( + PT2EQuantizationTestCase, + TestHelperModules, +) + + +class TestOpenVINOQuantizer(PT2EQuantizationTestCase): + def run(self, result=None): + """ + Disable NNCF patching for each test call + """ + with disable_patching(): + super().run(result) + + def test_conv1d(self): + quantizer = OpenVINOQuantizer() + example_inputs = (torch.randn(1, 3, 5),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv1d.default, + ] + self._test_quantizer( + TestHelperModules.ConvWithBNRelu(dim=1, relu=False, bn=False), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_conv2d(self): + quantizer = OpenVINOQuantizer() + 
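A rough sketch of what the node_occurrence assertions in this file reduce to (gm stands for a hypothetical converted GraphModule; checkGraphModuleNodes from the PyTorch test utilities performs the real, stricter check):

from collections import Counter

def count_call_functions(gm):
    # tally call_function targets; occurrence checks compare these counts
    return Counter(n.target for n in gm.graph.nodes if n.op == "call_function")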
example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + ] + self._test_quantizer( + TestHelperModules.ConvWithBNRelu(relu=False, bn=False), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_conv1d_with_conv2d(self): + quantizer = OpenVINOQuantizer() + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv1d.default, + ] + m = TestHelperModules.Conv2dThenConv1d() + self._test_quantizer( + m, + m.example_inputs(), + quantizer, + node_occurrence, + node_list, + ) + + def test_linear(self): + quantizer = OpenVINOQuantizer() + m_eager = TestHelperModules.TwoLinearModule().eval() + + # Test with 2d inputs + example_inputs_2d = (torch.randn(9, 8),) + example_inputs_3d = (torch.randn(9, 10, 8),) + example_inputs_4d = (torch.randn(9, 10, 11, 8),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + ) + + def test_linear_relu(self): + quantizer = OpenVINOQuantizer() + m_eager = TestHelperModules.LinearReluModel().eval() + + # Test with 2d inputs + example_inputs_2d = (torch.randn(1, 5),) + example_inputs_3d = (torch.randn(1, 2, 5),) + example_inputs_4d = (torch.randn(1, 2, 3, 5),) + + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], # node_list + False, # executorch_backend_config() does not fuse linear-relu + ) + + def test_conv_linear_no_permute(self): + quantizer = OpenVINOQuantizer() + node_occurrence = { + 
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + self._test_quantizer( + TestHelperModules.Conv2dWithTwoLinear(), + example_inputs, + quantizer, + node_occurrence, + [], + ) + + def test_conv_linear(self): + quantizer = OpenVINOQuantizer() + + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.permute.default, + ] + self._test_quantizer( + TestHelperModules.Conv2dWithTwoLinearPermute(), + example_inputs, + quantizer, + node_occurrence, + expected_node_list=node_list, + ) + + @unittest.skip( + "Enable after the fix https://github.com/openvinotoolkit/nncf/pull/3225" + ) + def test_linear_with_dynamic_shape(self): + quantizer = OpenVINOQuantizer() + m_eager = TestHelperModules.TwoLinearModule().eval() + + # Test with 2d inputs + example_inputs_3d = (torch.randn(9, 10, 8),) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + self._test_quantizer( + m_eager, + example_inputs_3d, + quantizer, + node_occurrence, + [], + export_with_dynamic_shape=True, + ) + + def test_obs_sharing_ops(self): + quantizer = OpenVINOQuantizer() + m = TestHelperModules.Conv2dWithObsSharingOps().eval() + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.hardtanh.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.mean.default, + ] + self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) + + def test_propagate_annotation(self): + quantizer = OpenVINOQuantizer() + m = TestHelperModules.Conv2dPropAnnotaton().eval() + 
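# The expectation encoded below: the quantizer appears to propagate its
# annotations through shape/activation ops, so the conv2d -> view -> hardtanh
# chain shares one quantize/dequantize pair, inserted after hardtanh and
# before the following linear, rather than one pair per op.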
example_inputs = (torch.randn(1, 3, 5, 5),) + + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.view.default, + torch.ops.aten.hardtanh.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + expected_node_occurrence=node_occurrence, + expected_node_list=node_list, + ) + + @unittest.skip( + "Enable after the fix https://github.com/openvinotoolkit/nncf/pull/3225" + ) + def test_dynamic_linear(self): + quantizer = OpenVINOQuantizer() + m_eager = TestHelperModules.TwoLinearModule().eval() + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + # Test with 2d inputs + example_inputs_2d = (torch.randn(9, 8),) + example_inputs_4d = (torch.randn(9, 10, 11, 8),) + for example_inputs in [example_inputs_2d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + ) + + @unittest.skip("gru quantization is not supported yet by OpenVINOQuantizer") + def test_gru(self): + """this is a test for annotating fp32 GRU so that it produces + q -> dq -> fp32_gru -> q -> dq, this is currently enough for our use cases, + but we may change the annotation to be more precise in the future + """ + + class RNNDynamicModel(torch.nn.Module): + def __init__(self, mod_type): + super().__init__() + self.qconfig = default_dynamic_qconfig + if mod_type == "GRU": + self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) + if mod_type == "LSTM": + self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) + + def forward(self, input_tensor, hidden_tensor): + input_tensor = 1 * input_tensor + hidden_tensor = 1 * hidden_tensor + output_tensor, hidden_out = self.mod(input_tensor, hidden_tensor) + return 1 * output_tensor, 1 * hidden_out + + model_fx = RNNDynamicModel("GRU") + niter = 10 + example_inputs = ( + # input_tensor + torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float) + .unsqueeze(0) + .repeat(niter, 1, 1), + # hidden_tensor + # (D * num_layers, N, H_out) + torch.tensor([[[100, -155]]], dtype=torch.float).repeat(1, 3, 1), + ) + model_graph = copy.deepcopy(model_fx) + + with torchdynamo.config.patch(allow_rnn=True): + model_graph = export_for_training( + model_graph, + example_inputs, + ).module() + quantizer = OpenVINOQuantizer() + + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 0, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + + with 
torchdynamo.config.patch(allow_rnn=True): + self._test_quantizer( + model_graph, + example_inputs, + quantizer, + expected_node_occurrence=node_occurrence, + ) + + @unittest.skip("gru quantization is not supported yet by OpenVINOQuantizer") + def test_linear_gru(self): + """this test is to make sure GRU annotation does not interfere with linear annotation""" + + class RNNDynamicModel(torch.nn.Module): + def __init__(self, mod_type): + super().__init__() + self.qconfig = default_dynamic_qconfig + self.linear = torch.nn.Linear(2, 2) + if mod_type == "GRU": + self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) + if mod_type == "LSTM": + self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) + + def forward(self, input_tensor, hidden_tensor): + input_tensor = self.linear(input_tensor) + input_tensor = 1 * input_tensor + hidden_tensor = 1 * hidden_tensor + output_tensor, hidden_out = self.mod(input_tensor, hidden_tensor) + return 1 * output_tensor, 1 * hidden_out + + model_fx = RNNDynamicModel("GRU") + niter = 10 + example_inputs = ( + # input_tensor + torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float) + .unsqueeze(0) + .repeat(niter, 1, 1), + # hidden_tensor + # (D * num_layers, N, H_out) + torch.tensor([[[100, -155]]], dtype=torch.float).repeat(1, 3, 1), + ) + model_graph = copy.deepcopy(model_fx) + quantizer = OpenVINOQuantizer() + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 0, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + with torchdynamo.config.patch(allow_rnn=True): + self._test_quantizer( + model_graph, + example_inputs, + quantizer, + expected_node_occurrence=node_occurrence, + ) + + def test_add_and_inplace_add(self): + quantizer = OpenVINOQuantizer() + example_inputs = ( + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + ) + node_occurrence = { + # two input and one output for first add, and output for second add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.AddInplaceAdd(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_mul_and_inplace_mul(self): + quantizer = OpenVINOQuantizer() + example_inputs = ( + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + ) + node_occurrence = { + # two input and one output for first add, and output for second add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.mul.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # TODO torch.ops.aten.mul.Tensor, + ] + self._test_quantizer( + TestHelperModules.MulInplaceMul(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def 
test_add_mul_scalar(self): + quantizer = OpenVINOQuantizer() + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } + + self._test_quantizer( + TestHelperModules.AddMulScalar(), + example_inputs, + quantizer, + node_occurrence, + ) + + def test_mul_float32_max(self): + class M(torch.nn.Module): + def forward(self, x): + return x * 3.4028235e38 + + quantizer = OpenVINOQuantizer() + example_inputs = (torch.randn(1, 3, 5, 5),) + # not quantized + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } + node_list = [ + torch.ops.aten.mul.Tensor, + ] + + self._test_quantizer( + M(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_add_mul_long(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.t = torch.tensor([100]) + + def forward(self, x): + x = x + self.t + x = x * self.t + return x + + quantizer = OpenVINOQuantizer() + example_inputs = (torch.randn(1, 3, 5, 5),) + # not quantized + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } + node_list = [ + torch.ops.aten.add.Tensor, + torch.ops.aten.mul.Tensor, + ] + self._test_quantizer( + M(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_ignored_scope(self): + for kwargs in [ + {"types": ["linear"]}, + {"names": ["linear", "linear_1"]}, + {"patterns": ["linear*"]}, + ]: + quantizer = OpenVINOQuantizer() + quantizer.set_ignored_scope(**kwargs) + + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.permute.default, + ] + self._test_quantizer( + TestHelperModules.Conv2dWithTwoLinearPermute(), + example_inputs, + quantizer, + node_occurrence, + expected_node_list=node_list, + ) diff --git a/backends/openvino/tests/quantizer/test_pt2e_quantization.py b/backends/openvino/tests/quantizer/test_pt2e_quantization.py index 5e56bbbe4db..f8f747f74fe 100644 --- a/backends/openvino/tests/quantizer/test_pt2e_quantization.py +++ b/backends/openvino/tests/quantizer/test_pt2e_quantization.py @@ -6,60 +6,20 @@ # pyre-unsafe -import copy -from collections import Counter -import torch.ao.nn.quantized.reference as nnqr -from typing import Dict, Tuple, Optional - -from torch.ao.quantization.backend_config import ( - BackendConfig, - BackendPatternConfig, - DTypeConfig, - DTypeWithConstraints, - ObservationType, -) -from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node +import unittest +from typing import Dict, Optional, Tuple + import torch -from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, - XNNPACKQuantizer, -) -from executorch.backends.openvino.quantizer.quantizer import OpenVINOQuantizer -from executorch.backends.openvino.quantizer.quantizer 
import QuantizationMode - -from torch.ao.quantization.quantize_pt2e import ( - _convert_to_reference_decomposed_fx, - convert_pt2e, - prepare_pt2e, - prepare_qat_pt2e, -) -from torch.ao.quantization import ( - compare_results, - CUSTOM_KEY, - default_per_channel_symmetric_qnnpack_qconfig, - extract_results_from_loggers, - generate_numeric_debug_handle, - NUMERIC_DEBUG_HANDLE_KEY, - observer, - prepare_for_propagation_comparison, -) -from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process -from torch.ao.quantization.qconfig import ( - float_qparams_weight_only_qconfig, - per_channel_weight_observer_range_neg_127_to_127, - QConfig, - weight_observer_range_neg_127_to_127, -) -from torch.ao.quantization.qconfig_mapping import QConfigMapping -from torch.ao.quantization.quantize_fx import prepare_fx -from torch.ao.quantization.quantize_pt2e import ( - convert_pt2e, - prepare_pt2e, - prepare_qat_pt2e, +from executorch.backends.openvino.quantizer.quantizer import ( + OpenVINOQuantizer, + QuantizationMode, ) + +from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node +from nncf.torch import disable_patching +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer import Quantizer from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer -from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, @@ -69,9 +29,7 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, TemporaryFileName, - TestCase, ) -from nncf.torch import disable_patching class TestQuantizePT2E(PT2EQuantizationTestCase): @@ -158,43 +116,16 @@ def validate(self, model: torch.fx.GraphModule) -> None: ), ) + @unittest.skip( + "Enable after the embedding quantization fix: https://github.com/openvinotoolkit/nncf/pull/3237" + ) def test_embedding_conv_linear_quantization(self) -> None: m_eager = TestHelperModules.EmbeddingConvLinearModule().eval() indices = torch.tensor( - [ - 9, - 6, - 5, - 7, - 8, - 8, - 9, - 2, - 8, - 6, - 6, - 9, - 1, - 6, - 8, - 8, - 3, - 2, - 3, - 6, - 3, - 6, - 5, - 7, - 0, - 8, - 4, - 6, - 5, - 8, - 2, - 3, - ] + [9, 6, 5, 7, 8, 8, 9, 2, 8, 6] + + [6, 9, 1, 6, 8, 8, 3, 2, 3, 6] + + [3, 6, 5, 7, 0, 8, 4, 6, 5, 8] + + [2, 3] ) indices = torch.unsqueeze(indices, 0) example_inputs = (indices,) @@ -223,24 +154,10 @@ def test_embedding_conv_linear_quantization(self) -> None: "dequantize_per_channel_default": ( None, torch.tensor( - [ - 0.0015, - 0.0015, - 0.0015, - 0.0016, - 0.0015, - 0.0016, - 0.0014, - 0.0014, - 0.0015, - 0.0015, - 0.0016, - 0.0015, - 0.0015, - 0.0016, - 0.0016, - 0.0015, - ] + [0.0015, 0.0015, 0.0015, 0.0016, 0.0015] + + [0.0016, 0.0014, 0.0014, 0.0015, 0.0015] + + [0.0016, 0.0015, 0.0015, 0.0016, 0.0016] + + [0.0015] ), torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 0, @@ -276,7 +193,7 @@ def test_embedding_conv_linear_quantization(self) -> None: 127, torch.int8, ), - # TODO: embedding + # TODO: check embedding after the fix in NNCF } self._check_quantization_with_ref(m, ref_q) @@ -403,9 +320,7 @@ def __init__(self) -> None: def forward(self, x): return self.linear(x) - quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config() - quantizer.set_global(operator_config) + quantizer = OpenVINOQuantizer() example_inputs = (torch.randn(2, 2),) m = M().eval() m = 
export_for_training( @@ -436,19 +351,15 @@ def test_reentrant(self) -> None: m = TestHelperModules.ConvBnReLU2dAndLinearReLU() example_inputs = (torch.randn(3, 3, 10, 10),) - quantizer = XNNPACKQuantizer().set_global( - get_symmetric_quantization_config(is_per_channel=True, is_qat=True) - ) + quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT8_SYM) m.conv_bn_relu = export_for_training( # pyre-ignore[8] m.conv_bn_relu, example_inputs ).module() - m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer) # pyre-ignore[6,8] + m.conv_bn_relu = prepare_pt2e(m.conv_bn_relu, quantizer) # pyre-ignore[6,8] m(*example_inputs) m.conv_bn_relu = convert_pt2e(m.conv_bn_relu) # pyre-ignore[6, 8] - quantizer = XNNPACKQuantizer().set_module_type( - torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False) - ) + quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT8_MIXED) m = export_for_training(m, example_inputs).module() m = prepare_pt2e(m, quantizer) # pyre-ignore[6] m = convert_pt2e(m) @@ -456,41 +367,22 @@ def test_reentrant(self) -> None: node_occurrence = { ns.call_function( torch.ops.quantized_decomposed.quantize_per_tensor.default - ): 4, - # one for weight + ): 3, ns.call_function( torch.ops.quantized_decomposed.dequantize_per_tensor.default - ): 5, + ): 3, ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_channel.default + torch.ops.quantized_decomposed.quantize_per_channel.default ): 1, - } - node_list = [ - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_tensor.default - ), - ns.call_function(torch.ops.aten.conv2d.default), - ns.call_function(torch.ops.aten.relu.default), - ns.call_function( - torch.ops.quantized_decomposed.quantize_per_tensor.default - ), - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_tensor.default - ), - ns.call_function(torch.ops.aten.linear.default), ns.call_function( - torch.ops.quantized_decomposed.quantize_per_tensor.default - ), - ] - self.checkGraphModuleNodes( - m, expected_node_occurrence=node_occurrence, expected_node_list=node_list - ) + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 3, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) def test_groupwise_per_channel_quant(self) -> None: m = TestHelperModules.GroupwiseConv2d() - quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(operator_config) + quantizer = OpenVINOQuantizer() example_inputs = m.example_inputs() m = self._quantize(m, quantizer, example_inputs) # make sure it runs @@ -501,9 +393,7 @@ def test_preserve_nn_module_stack(self) -> None: m = TestHelperModules.ConvBnReLU2dAndLinearReLU() example_inputs = (torch.randn(3, 3, 10, 10),) - quantizer = XNNPACKQuantizer().set_global( - get_symmetric_quantization_config(is_per_channel=True, is_qat=True) - ) + quantizer = OpenVINOQuantizer() def check_nn_module(node: torch.fx.Node) -> None: self.assertTrue("nn_module_stack" in node.meta) @@ -517,97 +407,11 @@ def check_nn_module(node: torch.fx.Node) -> None: for node in m.conv_bn_relu.graph.nodes: # pyre-ignore[16] if node.op not in ["placeholder", "output", "get_attr"]: check_nn_module(node) - m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer) # pyre-ignore[6,8] + m.conv_bn_relu = prepare_pt2e(m.conv_bn_relu, quantizer) # pyre-ignore[6,8] for node in m.conv_bn_relu.graph.nodes: # pyre-ignore[16] if node.name == "mul": check_nn_module(node) - def test_speed(self) -> None: - import time # noqa: F401 - 
- def dynamic_quantize_pt2e(model, example_inputs) -> torch.fx.GraphModule: - torch._dynamo.reset() - model = export_for_training(model, example_inputs).module() - # Per channel quantization for weight - # Dynamic quantization for activation - # Please read a detail: https://fburl.com/code/30zds51q - embedding_quantizer = EmbeddingQuantizer() - dynamic_quantizer = XNNPACKQuantizer() - operator_config_dynamic = get_symmetric_quantization_config( - is_per_channel=True, is_dynamic=True - ) - dynamic_quantizer.set_global(operator_config_dynamic) - composed_quantizer = ComposableQuantizer( - [embedding_quantizer, dynamic_quantizer] - ) - # prev = time.time() - model = prepare_qat_pt2e(model, composed_quantizer) # pyre-ignore[6] - # cur = time.time() - # print("prepare time:", cur - prev) - # Without Calibraiton, scale/zero value will have an initialized value of 1.0 - # Per channel quantization needs a proper scale/zero shape/value to work properly. - # So we need to run calibration before converting to quantized model. - model(*example_inputs) - # prev = time.time() - model = convert_pt2e(model) - # cur = time.time() - # uncomment to see the time - # print("convert time:", cur - prev) - return model - - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear = torch.nn.Linear(5, 5) - - def forward(self, x): - return self.linear(x) - - m = M().eval() - example_inputs = (torch.randn(5, 5),) - _ = dynamic_quantize_pt2e(m, example_inputs) - - def test_multi_users_without_output_observer(self) -> None: - """ - Test the case in which a node is used by multiple users, - and had its output observer removed. - """ - - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 3) - - def forward(self, x): - x = self.conv(x) - return x, x + 1 - - example_inputs = (torch.randn(1, 3, 5, 5),) - m = M() - m = export_for_training(m, example_inputs).module() - quantizer = XNNPACKQuantizer().set_global( - get_symmetric_quantization_config(), - ) - m = prepare_pt2e(m, quantizer) # pyre-ignore[6] - m(*example_inputs) - - # Remove output observer - observer_to_remove = None - for n in m.graph.nodes: - if n.op == "output": - observer_to_remove = n.args[0][0] - assert observer_to_remove.op == "call_module" - assert observer_to_remove.target.startswith("activation_post_process_") - break - assert observer_to_remove is not None - observer_to_remove.replace_all_uses_with(observer_to_remove.args[0]) - m.graph.erase_node(observer_to_remove) - m.recompile() - - # Convert should succeed - m = convert_pt2e(m) - m(*example_inputs) - def test_fold_quantize_sym(self) -> None: """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)""" m = self._get_pt2e_quantized_linear() @@ -700,7 +504,7 @@ def _check_quantization_with_ref(self, model: torch.fx.GraphModule, ref: Dict): def test_save_load(self) -> None: """Test save/load a quantized model""" - m = self._get_linear() + m = self._get_pt2e_quantized_linear() example_inputs = (torch.randn(2, 2),) ref_res = m(*example_inputs) @@ -716,123 +520,3 @@ def test_save_load(self) -> None: instantiate_parametrized_tests(TestQuantizePT2E) - - -class TestNumericDebugger(TestCase): - - def _extract_debug_handles(self, model) -> Dict[str, int]: - debug_handle_map: Dict[str, int] = {} - - def _extract_debug_handles_from_node(node: torch.fx.Node) -> None: - nonlocal debug_handle_map - if ( - CUSTOM_KEY in node.meta - and NUMERIC_DEBUG_HANDLE_KEY in 
node.meta[CUSTOM_KEY] - ): - debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][ - NUMERIC_DEBUG_HANDLE_KEY - ] - - bfs_trace_with_node_process(model, _extract_debug_handles_from_node) - return debug_handle_map - - def _assert_each_node_has_debug_handle(self, model) -> None: - def _assert_node_has_debug_handle(node: torch.fx.Node) -> None: - self.assertTrue( - CUSTOM_KEY in node.meta - and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY], - f"Node {node} doesn't have debug handle", - ) - - bfs_trace_with_node_process(model, _assert_node_has_debug_handle) - - def test_quantize_pt2e_preserve_handle(self) -> None: - m = TestHelperModules.Conv2dThenConv1d() - example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) - generate_numeric_debug_handle(ep) - m = ep.module() - - quantizer = XNNPACKQuantizer().set_global( - get_symmetric_quantization_config(is_per_channel=False) - ) - m = prepare_pt2e(m, quantizer) # pyre-ignore[6] - debug_handle_map = self._extract_debug_handles(m) - res_counter = Counter(debug_handle_map.values()) - repeated_debug_handle_ids = [1, 2, 3] - # 3 ids were repeated because we copy over the id from node to its output observer - # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default - for dh_id in repeated_debug_handle_ids: - self.assertEqual(res_counter[dh_id], 2) - - m(*example_inputs) - m = convert_pt2e(m) - self._assert_each_node_has_debug_handle(ep) - debug_handle_map = self._extract_debug_handles(m) - res_counter = Counter(debug_handle_map.values()) - # same set of ids where repeated, because we copy over the id from observer/fake_quant to - # dequantize node - repeated_debug_handle_ids = [1, 2, 3] - for dh_id in repeated_debug_handle_ids: - self.assertEqual(res_counter[dh_id], 2) - - def test_extract_results_from_loggers(self) -> None: - m = TestHelperModules.Conv2dThenConv1d() - example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) - generate_numeric_debug_handle(ep) - m = ep.module() - m_ref_logger = prepare_for_propagation_comparison(m) # pyre-ignore[6] - - quantizer = XNNPACKQuantizer().set_global( - get_symmetric_quantization_config(is_per_channel=False) - ) - m = prepare_pt2e(m, quantizer) # pyre-ignore[6] - m(*example_inputs) - m = convert_pt2e(m) - m_quant_logger = prepare_for_propagation_comparison(m) - - m_ref_logger(*example_inputs) - m_quant_logger(*example_inputs) - ref_results = extract_results_from_loggers(m_ref_logger) - quant_results = extract_results_from_loggers(m_quant_logger) - comparison_results = compare_results( - ref_results, quant_results # pyre-ignore[6] - ) - for node_summary in comparison_results.values(): - if len(node_summary.results) > 0: - self.assertGreaterEqual( - node_summary.results[0].sqnr, 35 # pyre-ignore[6] - ) - - def test_extract_results_from_loggers_list_output(self) -> None: - m = TestHelperModules.Conv2dWithSplit() - example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) - generate_numeric_debug_handle(ep) - m = ep.module() - m_ref_logger = prepare_for_propagation_comparison(m) # pyre-ignore[6] - - quantizer = XNNPACKQuantizer().set_global( - get_symmetric_quantization_config(is_per_channel=False) - ) - m = prepare_pt2e(m, quantizer) # pyre-ignore[6] - m(*example_inputs) - m = convert_pt2e(m) - m_quant_logger = prepare_for_propagation_comparison(m) - - m_ref_logger(*example_inputs) - m_quant_logger(*example_inputs) - ref_results = extract_results_from_loggers(m_ref_logger) - quant_results = 
extract_results_from_loggers(m_quant_logger)
- comparison_results = compare_results(
- ref_results, quant_results # pyre-ignore[6]
- )
- for node_summary in comparison_results.values():
- if len(node_summary.results) > 0:
- sqnr = node_summary.results[0].sqnr
- if isinstance(sqnr, list):
- for sqnr_i in sqnr:
- self.assertGreaterEqual(sqnr_i, 35)
- else:
- self.assertGreaterEqual(sqnr, 35) # pyre-ignore[6]
diff --git a/backends/openvino/tests/quantizer/test_representation.py b/backends/openvino/tests/quantizer/test_representation.py
index 83cecaec5ad..5ca7941dc8d 100644
--- a/backends/openvino/tests/quantizer/test_representation.py
+++ b/backends/openvino/tests/quantizer/test_representation.py
@@ -1,26 +1,30 @@
# Owner(s): ["oncall: quantization"]
import copy
+import unittest
from typing import Any, Optional
import torch
-from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
- get_symmetric_quantization_config,
- XNNPACKQuantizer,
-)
-from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401
+from executorch.backends.openvino.quantizer.quantizer import OpenVINOQuantizer
+
+from nncf.torch import disable_patching
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.export import export_for_training
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
- skipIfNoQNNPACK,
TestHelperModules,
)
-@skipIfNoQNNPACK
class TestPT2ERepresentation(QuantizationTestCase):
+ def run(self, result=None):
+ """
+ Disable NNCF patching for each test call
+ """
+ with disable_patching():
+ super().run(result)
+
 def _test_representation(
 self,
 model: torch.nn.Module,
@@ -29,7 +33,7 @@ def _test_representation(
 ref_node_occurrence: dict[ns, int],
 non_ref_node_occurrence: dict[ns, int],
 fixed_output_tol: Optional[float] = None,
- output_scale_idx: int = 2,
+ output_scale_idx: int = 1,
 ) -> None:
 # resetting dynamo cache
 torch._dynamo.reset()
@@ -71,6 +75,7 @@ def _test_representation(
 idx += 1
 if idx == output_scale_idx:
 output_tol = n.args[1]
+ break
 assert output_tol is not None
 # make sure the result is off by one at most in the quantized integer representation
@@ -88,9 +93,7 @@ def __init__(self) -> None:
 def forward(self, x):
 return self.linear(x)
- quantizer = XNNPACKQuantizer()
- operator_config = get_symmetric_quantization_config(is_per_channel=False)
- quantizer.set_global(operator_config)
+ quantizer = OpenVINOQuantizer()
 example_inputs = (torch.randn(2, 5),)
 self._test_representation(
@@ -101,6 +104,9 @@ def forward(self, x):
 non_ref_node_occurrence={},
 )
+ @unittest.skip(
+ "Enable after the fix https://github.com/openvinotoolkit/nncf/pull/3225"
+ )
 def test_dynamic_linear(self):
 class M(torch.nn.Module):
 def __init__(self) -> None:
@@ -110,11 +116,7 @@ def __init__(self) -> None:
 def forward(self, x):
 return self.linear(x)
- quantizer = XNNPACKQuantizer()
- operator_config = get_symmetric_quantization_config(
- is_per_channel=False, is_dynamic=True
- )
- quantizer.set_global(operator_config)
+ quantizer = OpenVINOQuantizer()
 example_inputs = (torch.randn(2, 5),)
 self._test_representation(
@@ -135,9 +137,7 @@ def __init__(self) -> None:
 def forward(self, x):
 return self.conv2d(x)
- quantizer = XNNPACKQuantizer()
- operator_config = get_symmetric_quantization_config(is_per_channel=False)
- quantizer.set_global(operator_config)
+ quantizer = OpenVINOQuantizer()
 example_inputs = (torch.randn(1, 3, 3, 3),)
 self._test_representation(
@@ -148,66 +148,8 
@@ def forward(self, x): non_ref_node_occurrence={}, ) - def test_add(self): - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x, y): - return x + y - - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - M().eval() - - example_inputs = ( - torch.randn(1, 3, 3, 3), - torch.randn(1, 3, 3, 3), - ) - - self._test_representation( - M().eval(), - example_inputs, - quantizer, - ref_node_occurrence={}, - non_ref_node_occurrence={}, - ) - - def test_add_relu(self): - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x, y): - out = x + y - out = torch.nn.functional.relu(out) - return out - - quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(operator_config) - - example_inputs = ( - torch.randn(1, 3, 3, 3), - torch.randn(1, 3, 3, 3), - ) - ref_node_occurrence = { - ns.call_function(out_dtype): 2, - } - - self._test_representation( - M().eval(), - example_inputs, - quantizer, - ref_node_occurrence=ref_node_occurrence, - non_ref_node_occurrence={}, - ) - def test_maxpool2d(self): - quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(operator_config) + quantizer = OpenVINOQuantizer() m_eager = TestHelperModules.ConvMaxPool2d().eval() example_inputs = (torch.randn(1, 2, 2, 2),) @@ -219,93 +161,3 @@ def test_maxpool2d(self): ref_node_occurrence={}, non_ref_node_occurrence={}, ) - - def test_qdq_per_channel(self): - """Test representation for quantize_per_channel and dequantize_per_channel op""" - - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear = torch.nn.Linear(5, 5) - - def forward(self, x): - return self.linear(x) - - quantizer = XNNPACKQuantizer() - # use per channel quantization for weight - operator_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(operator_config) - M().eval() - - inputs = [ - (torch.randn(1, 5),), - (torch.randn(1, 3, 5),), - (torch.randn(1, 3, 3, 5),), - (torch.randn(1, 3, 3, 3, 5),), - ] - for example_inputs in inputs: - ref_node_occurrence = { - ns.call_function( - torch.ops.quantized_decomposed.quantize_per_channel.default - ): 0, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_channel.default - ): 0, - } - non_ref_node_occurrence = { - # quantize_per_channel is folded - ns.call_function( - torch.ops.quantized_decomposed.quantize_per_channel.default - ): 0, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_channel.default - ): 1, - } - - self._test_representation( - M().eval(), - example_inputs, - quantizer, - ref_node_occurrence, - non_ref_node_occurrence, - output_scale_idx=2, - ) - - def test_qdq(self): - """Test representation for quantize and dequantize op""" - - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x, y): - return x + y - - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - M().eval() - - example_inputs = ( - torch.randn(1, 3, 3, 3), - torch.randn(1, 3, 3, 3), - ) - ref_node_occurrence = { - ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 0, - ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 
0, - } - non_ref_node_occurrence = { - ns.call_function( - torch.ops.quantized_decomposed.quantize_per_tensor.default - ): 3, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_tensor.default - ): 3, - } - self._test_representation( - M().eval(), - example_inputs, - quantizer, - ref_node_occurrence, - non_ref_node_occurrence, - ) diff --git a/backends/openvino/tests/quantizer/test_xnnpack_quantizer.py b/backends/openvino/tests/quantizer/test_xnnpack_quantizer.py deleted file mode 100644 index 34095a3c65a..00000000000 --- a/backends/openvino/tests/quantizer/test_xnnpack_quantizer.py +++ /dev/null @@ -1,948 +0,0 @@ -# Owner(s): ["oncall: mobile"] -import copy -import operator - -import torch -import torch._dynamo as torchdynamo -from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, - XNNPACKQuantizer, -) -from torch.ao.ns.fx.utils import compute_sqnr -from torch.ao.quantization import ( - default_dynamic_fake_quant, - default_dynamic_qconfig, - observer, - QConfig, - QConfigMapping, -) -from torch.ao.quantization.backend_config import get_qnnpack_backend_config -from torch.ao.quantization.qconfig import ( - default_per_channel_symmetric_qnnpack_qconfig, - default_symmetric_qnnpack_qconfig, - per_channel_weight_observer_range_neg_127_to_127, - weight_observer_range_neg_127_to_127, -) -from torch.ao.quantization.quantize_fx import ( - _convert_to_reference_decomposed_fx, - convert_to_reference_fx, - prepare_fx, -) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.export import export_for_training -from torch.testing._internal.common_quantization import ( - NodeSpec as ns, - PT2EQuantizationTestCase, - skip_if_no_torchvision, - skipIfNoQNNPACK, - TestHelperModules, -) -from torch.testing._internal.common_quantized import override_quantized_engine - - -@skipIfNoQNNPACK -class TestXNNPACKQuantizer(PT2EQuantizationTestCase): - def test_conv1d(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - example_inputs = (torch.randn(1, 3, 5),) - node_occurrence = { - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, - } - node_list = [ - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.conv1d.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - ] - self._test_quantizer( - TestHelperModules.ConvWithBNRelu(dim=1, relu=False, bn=False), - example_inputs, - quantizer, - node_occurrence, - node_list, - ) - - def test_conv2d(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - example_inputs = (torch.randn(1, 3, 5, 5),) - node_occurrence = { - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - # quantize_per_channel for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - 
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, - } - node_list = [ - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.conv2d.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - ] - self._test_quantizer( - TestHelperModules.ConvWithBNRelu(relu=False, bn=False), - example_inputs, - quantizer, - node_occurrence, - node_list, - ) - - def test_conv1d_with_conv2d(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - node_occurrence = { - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, - } - node_list = [ - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.conv2d.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.conv1d.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - ] - m = TestHelperModules.Conv2dThenConv1d() - self._test_quantizer( - m, - m.example_inputs(), - quantizer, - node_occurrence, - node_list, - ) - - def test_linear(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - m_eager = TestHelperModules.TwoLinearModule().eval() - - # Test with 2d inputs - example_inputs_2d = (torch.randn(9, 8),) - example_inputs_3d = (torch.randn(9, 10, 8),) - example_inputs_4d = (torch.randn(9, 10, 11, 8),) - node_occurrence = { - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, - # quantize_per_channel for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, - } - qconfig = default_per_channel_symmetric_qnnpack_qconfig - qconfig_mapping = QConfigMapping().set_global(qconfig) - for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]: - self._test_quantizer( - m_eager, - example_inputs, - quantizer, - node_occurrence, - [], - True, - qconfig_mapping, - ) - - def test_linear_relu(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - m_eager = TestHelperModules.LinearReluModel().eval() - - # Test with 2d inputs - example_inputs_2d = (torch.randn(1, 5),) - example_inputs_3d = (torch.randn(1, 2, 5),) - example_inputs_4d = (torch.randn(1, 2, 3, 5),) - - node_occurrence = { - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - # There should not be extra quantize_per_tensor or dequantize_per_tensors for relu - torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - # quantize_per_channel for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 
1, - } - qconfig = default_per_channel_symmetric_qnnpack_qconfig - qconfig_mapping = QConfigMapping().set_global(qconfig) - for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]: - self._test_quantizer( - m_eager, - example_inputs, - quantizer, - node_occurrence, - [], # node_list - False, # executorch_backend_config() does not fuse linear-relu - qconfig_mapping, - ) - - def test_conv_linear_no_permute(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - node_occurrence = { - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, - # quantize_per_channel for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, - } - qconfig = default_per_channel_symmetric_qnnpack_qconfig - qconfig_mapping = QConfigMapping().set_global(qconfig) - # Test with 2d inputs - example_inputs = (torch.randn(2, 3, 4, 4),) - self._test_quantizer( - TestHelperModules.Conv2dWithTwoLinear(), - example_inputs, - quantizer, - node_occurrence, - [], - True, - qconfig_mapping, - ) - - def test_conv_linear(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - - # Test with 2d inputs - example_inputs = (torch.randn(2, 3, 4, 4),) - node_occurrence = { - torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, - # quantize_per_channel for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, - } - qconfig = default_per_channel_symmetric_qnnpack_qconfig - qconfig_mapping = QConfigMapping().set_global(qconfig) - self._test_quantizer( - TestHelperModules.Conv2dWithTwoLinearPermute(), - example_inputs, - quantizer, - node_occurrence, - [], - True, - qconfig_mapping, - ) - - def test_linear_with_dynamic_shape(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - m_eager = TestHelperModules.TwoLinearModule().eval() - - # Test with 2d inputs - example_inputs_3d = (torch.randn(9, 10, 8),) - node_occurrence = { - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, - # quantize_per_channel for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, - } - qconfig = default_per_channel_symmetric_qnnpack_qconfig - qconfig_mapping = QConfigMapping().set_global(qconfig) - self._test_quantizer( - m_eager, - example_inputs_3d, - quantizer, - node_occurrence, - [], - True, - qconfig_mapping, - export_with_dynamic_shape=True, - ) - - def test_obs_sharing_ops(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - m = 
TestHelperModules.Conv2dWithObsSharingOps().eval() - example_inputs = (torch.randn(1, 3, 5, 5),) - node_occurrence = { - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, - # quantize_per_channel for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, - } - node_list = [ - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.conv2d.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.adaptive_avg_pool2d.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.hardtanh.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.mean.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - ] - self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) - - def test_set_module_name(self): - class Sub(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear = torch.nn.Linear(5, 5) - - def forward(self, x): - return self.linear(x) - - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear = torch.nn.Linear(5, 5) - self.sub = Sub() - - def forward(self, x): - x = self.linear(x) - x = self.sub(x) - return x - - m = M().eval() - example_inputs = (torch.randn(3, 5),) - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_module_name("sub", quantization_config) - node_occurrence = { - torch.ops.aten.linear.default: 2, - # input and output for the second linear - torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - } - node_list = [ - # first linear is not quantized - torch.ops.aten.linear.default, - # second linear is quantized - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.linear.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - ] - self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) - - def test_set_module_name_with_underscores(self) -> None: - """Test that if a module name has an underscore, we can still quantize it""" - - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - # This module name has underscores, which can be part of a mangled - # name. - self.foo_bar = torch.nn.Linear(2, 2) - self.baz = torch.nn.Linear(2, 2) - - def forward(self, x): - return self.baz(self.foo_bar(x)) - - quantizer = XNNPACKQuantizer() - # Set global to no quantization and then per-channel for a specific submodule. 
- quantizer.set_module_name( - "foo_bar", get_symmetric_quantization_config(is_per_channel=True) - ) - example_inputs = (torch.randn(2, 2),) - m = M().eval() - m = export_for_training(m, example_inputs).module() - m = prepare_pt2e(m, quantizer) # pyre-ignore[6] - # Use a linear count instead of names because the names might change, but - # the order should be the same. - count = 0 - for n in m.graph.nodes: - if n.op == "call_function" and n.target == torch.ops.aten.linear.default: - # Get the weight observer to see the per-channel vs per-tensor. - weight_observer_node = n.args[1] - if count == 0: - # The weight tensor should be per-tensor and not per-channel - # for foo_bar. - self.assertEqual(weight_observer_node.op, "call_module") - observer_instance = getattr(m, weight_observer_node.target) - self.assertEqual( - observer_instance.qscheme, torch.per_channel_symmetric - ) - else: - # For baz it should have no observer at all. - self.assertNotEqual(weight_observer_node.op, "call_module") - count += 1 - - def test_set_module_type(self): - class Sub(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear = torch.nn.Linear(5, 5) - - def forward(self, x): - return self.linear(x) - - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear = torch.nn.Linear(5, 5) - self.sub = Sub() - - def forward(self, x): - x = self.linear(x) - x = self.sub(x) - return x - - m = M().eval() - example_inputs = (torch.randn(3, 5),) - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_module_type(Sub, quantization_config) - node_occurrence = { - torch.ops.aten.linear.default: 2, - # input and output for the second linear - torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - } - node_list = [ - # first linear is not quantized - torch.ops.aten.linear.default, - # second linear is quantized - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.linear.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - ] - self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) - - def test_set_module_type_case_2(self): - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.conv = torch.nn.Conv2d( - in_channels=3, - out_channels=3, - kernel_size=3, - stride=1, - padding=1, - bias=True, - ) - self.conv2 = torch.nn.Conv2d( - in_channels=3, - out_channels=3, - kernel_size=3, - stride=1, - padding=1, - bias=True, - ) - self.conv3 = torch.nn.Conv2d( - in_channels=3, - out_channels=3, - kernel_size=3, - stride=1, - padding=1, - bias=True, - ) - self.relu = torch.nn.ReLU() - self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) - self.fc = torch.nn.Linear(3, 16) - - def forward(self, x): - x1 = self.conv(x) - x2 = self.relu(self.conv2(x1) + self.conv3(x1)) - x3 = self.avgpool(x2) - x4 = torch.flatten(x3, 1) - x5 = self.fc(x4) - return x5 - - m = M().eval() - example_inputs = (torch.randn(1, 3, 16, 16),) - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - # We only want to annotate Linear type - quantizer.set_module_type(torch.nn.Linear, quantization_config) - node_occurrence = { - torch.ops.aten.conv2d.default: 3, - torch.ops.aten.linear.default: 
1, - # input and output for the linear - torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - } - node_list = [ - # only the linear is quantized - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.linear.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - ] - self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) - - def test_propagate_annotation(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - m = TestHelperModules.Conv2dPropAnnotaton().eval() - example_inputs = (torch.randn(1, 3, 5, 5),) - - # program capture - m = export_for_training( - m, - example_inputs, - ).module() - - m = prepare_pt2e(m, quantizer) - m(*example_inputs) - for n in m.graph.nodes: - if n.target in [ - torch.ops.aten.view.default, - torch.ops.aten.hardtanh.default, - ]: - input_act = getattr(m, n.args[0].target) - output_act = getattr(m, next(iter(n.users)).target) - self.assertIs(input_act, output_act) - - m = convert_pt2e(m) - node_occurrence = { - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - ns.call_function( - torch.ops.quantized_decomposed.quantize_per_tensor.default - ): 5, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_tensor.default - ): 5, - # note: quantize op for weights are const propagated - ns.call_function( - torch.ops.quantized_decomposed.quantize_per_channel.default - ): 0, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_channel.default - ): 2, - } - self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) - - def test_dynamic_linear(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config( - is_per_channel=True, is_dynamic=True - ) - quantizer.set_global(quantization_config) - m_eager = TestHelperModules.TwoLinearModule().eval() - - node_occurrence = { - # input and output are using quantize_per_tensor and weight is using quantize_per_channel - torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, - # note: quantize op for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, - } - act_affine_quant_obs = observer.PlaceholderObserver.with_args( - dtype=torch.qint8, - qscheme=torch.per_tensor_affine, - quant_min=-128, - quant_max=127, - eps=2**-12, - is_dynamic=True, - ) - qconfig = QConfig( - activation=act_affine_quant_obs, - weight=per_channel_weight_observer_range_neg_127_to_127, - ) - qconfig_mapping = QConfigMapping().set_global(qconfig) - # Test with 2d inputs - example_inputs_2d = (torch.randn(9, 8),) - example_inputs_4d = (torch.randn(9, 10, 11, 8),) - for example_inputs in [example_inputs_2d, example_inputs_4d]: - self._test_quantizer( - m_eager, - example_inputs, - quantizer, - node_occurrence, - [], - True, - qconfig_mapping, - ) - - def test_gru(self): - """this is a test for annotating fp32 GRU so that it produces - q -> dq -> fp32_gru -> q -> dq, this is currently enough for our use cases, - but we may change the annotation to be more precise in the future - """ - - class 
RNNDynamicModel(torch.nn.Module): - def __init__(self, mod_type): - super().__init__() - self.qconfig = default_dynamic_qconfig - if mod_type == "GRU": - self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) - if mod_type == "LSTM": - self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) - - def forward(self, input_tensor, hidden_tensor): - input_tensor = 1 * input_tensor - hidden_tensor = 1 * hidden_tensor - output_tensor, hidden_out = self.mod(input_tensor, hidden_tensor) - return 1 * output_tensor, 1 * hidden_out - - with override_quantized_engine("qnnpack"): - model_fx = RNNDynamicModel("GRU") - niter = 10 - example_inputs = ( - # input_tensor - torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float) - .unsqueeze(0) - .repeat(niter, 1, 1), - # hidden_tensor - # (D * num_layers, N, H_out) - torch.tensor([[[100, -155]]], dtype=torch.float).repeat(1, 3, 1), - ) - model_graph = copy.deepcopy(model_fx) - - qconfig_mapping = QConfigMapping().set_object_type( - operator.mul, default_symmetric_qnnpack_qconfig - ) - model_fx = prepare_fx( - model_fx, - qconfig_mapping, - example_inputs, - backend_config=get_qnnpack_backend_config(), - ) - model_fx(*example_inputs) - model_fx = _convert_to_reference_decomposed_fx(model_fx) - - with torchdynamo.config.patch(allow_rnn=True): - model_graph = export_for_training( - model_graph, - example_inputs, - ).module() - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config( - is_per_channel=False, is_dynamic=False - ) - quantizer.set_global(quantization_config) - model_graph = prepare_pt2e(model_graph, quantizer) - model_graph(*example_inputs) - model_graph = convert_pt2e(model_graph) - self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs)) - - def test_linear_gru(self): - """this test is to make sure GRU annotation does not interfere with linear annotation""" - - class RNNDynamicModel(torch.nn.Module): - def __init__(self, mod_type): - super().__init__() - self.qconfig = default_dynamic_qconfig - self.linear = torch.nn.Linear(2, 2) - if mod_type == "GRU": - self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) - if mod_type == "LSTM": - self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) - - def forward(self, input_tensor, hidden_tensor): - input_tensor = self.linear(input_tensor) - input_tensor = 1 * input_tensor - hidden_tensor = 1 * hidden_tensor - output_tensor, hidden_out = self.mod(input_tensor, hidden_tensor) - return 1 * output_tensor, 1 * hidden_out - - with override_quantized_engine("qnnpack"): - model_fx = RNNDynamicModel("GRU") - niter = 10 - example_inputs = ( - # input_tensor - torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float) - .unsqueeze(0) - .repeat(niter, 1, 1), - # hidden_tensor - # (D * num_layers, N, H_out) - torch.tensor([[[100, -155]]], dtype=torch.float).repeat(1, 3, 1), - ) - model_graph = copy.deepcopy(model_fx) - - qconfig_mapping = ( - QConfigMapping() - .set_object_type(operator.mul, default_symmetric_qnnpack_qconfig) - .set_object_type(torch.nn.Linear, default_symmetric_qnnpack_qconfig) - ) - model_fx = prepare_fx( - model_fx, - qconfig_mapping, - example_inputs, - backend_config=get_qnnpack_backend_config(), - ) - model_fx(*example_inputs) - model_fx = _convert_to_reference_decomposed_fx(model_fx) - - with torchdynamo.config.patch(allow_rnn=True): - model_graph = export_for_training( - model_graph, - example_inputs, - ).module() - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config( - 
is_per_channel=False, is_dynamic=False - ) - quantizer.set_global(quantization_config) - model_graph = prepare_pt2e(model_graph, quantizer) - model_graph(*example_inputs) - model_graph = convert_pt2e(model_graph) - self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs)) - - def test_add_and_inplace_add(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - example_inputs = ( - torch.randn(1, 3, 5, 5), - torch.randn(1, 3, 5, 5), - ) - node_occurrence = { - # two input and one output for first add, and output for second add - torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, - } - node_list = [ - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.add.Tensor, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - # TODO torch.ops.aten.add.Tensor, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - ] - self._test_quantizer( - TestHelperModules.AddInplaceAdd(), - example_inputs, - quantizer, - node_occurrence, - node_list, - ) - - def test_mul_and_inplace_mul(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - example_inputs = ( - torch.randn(1, 3, 5, 5), - torch.randn(1, 3, 5, 5), - ) - node_occurrence = { - # two input and one output for first add, and output for second add - torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, - } - node_list = [ - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.mul.Tensor, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - # TODO torch.ops.aten.mul.Tensor, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - ] - self._test_quantizer( - TestHelperModules.MulInplaceMul(), - example_inputs, - quantizer, - node_occurrence, - node_list, - ) - - def test_add_mul_scalar(self): - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - example_inputs = (torch.randn(1, 3, 5, 5),) - node_occurrence = { - # two input and one output for first add, and output for second add - torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, - # TODO torch.ops.quantized_decomposed.dequantize_per_tensor.default: 9, - } - node_list = [ - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.add.Tensor, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.mul.Tensor, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - # TODO torch.ops.aten.add.Tensor, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - # TODO torch.ops.aten.mul.Tensor, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - ] - 
self._test_quantizer( - TestHelperModules.AddMulScalar(), - example_inputs, - quantizer, - node_occurrence, - node_list, - ) - - def test_mul_float32_max(self): - class M(torch.nn.Module): - def forward(self, x): - return x * 3.4028235e38 - - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - example_inputs = (torch.randn(1, 3, 5, 5),) - # not quantized - node_occurrence = { - torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, - } - node_list = [ - torch.ops.aten.mul.Tensor, - ] - self._test_quantizer( - M(), - example_inputs, - quantizer, - node_occurrence, - node_list, - ) - - def test_add_mul_long(self): - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.t = torch.tensor([100]) - - def forward(self, x): - x = x + self.t - x = x * self.t - return x - - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - example_inputs = (torch.randn(1, 3, 5, 5),) - # not quantized - node_occurrence = { - torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, - } - node_list = [ - torch.ops.aten.add.Tensor, - torch.ops.aten.mul.Tensor, - ] - self._test_quantizer( - M(), - example_inputs, - quantizer, - node_occurrence, - node_list, - ) - - def test_cat_same_node(self): - """Ensure that concatenating the same node does not cause any unexpected behavior""" - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - x = torch.cat([x, x]) - return x - - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - example_inputs = (torch.randn(1, 3, 5, 5),) - node_occurrence = { - torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - } - node_list = [ - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.cat.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - ] - self._test_quantizer( - M(), - example_inputs, - quantizer, - node_occurrence, - node_list, - ) - - -# TODO: express this using self._test_quantizer, add test for inception_v4 -class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase): - @skip_if_no_torchvision - @skipIfNoQNNPACK - def test_resnet18(self): - import torchvision - - with override_quantized_engine("qnnpack"): - example_inputs = (torch.randn(1, 3, 224, 224),) - m = torchvision.models.resnet18().eval() - m_copy = copy.deepcopy(m) - # program capture - m = export_for_training( - m, - example_inputs, - ).module() - - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - m = prepare_pt2e(m, quantizer) - # checking that we inserted observers correctly for maxpool operator (input and - # output share observer instance) - self.assertEqual( - id(m.activation_post_process_3), id(m.activation_post_process_2) - ) - after_prepare_result = m(*example_inputs) - m = convert_pt2e(m) - - after_quant_result = m(*example_inputs) - - # 
comparing with existing fx graph mode quantization reference flow - qconfig = default_per_channel_symmetric_qnnpack_qconfig - qconfig_mapping = QConfigMapping().set_global(qconfig) - backend_config = get_qnnpack_backend_config() - m_fx = prepare_fx( - m_copy, qconfig_mapping, example_inputs, backend_config=backend_config - ) - after_prepare_result_fx = m_fx(*example_inputs) - m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config) - - after_quant_result_fx = m_fx(*example_inputs) - - # the result matches exactly after prepare - # Note: this currently will always be true since we are inserting observers - # the check becomes useful when we add qat examples - # but we can still manully inspect the printed observers to make sure - # it matches - self.assertEqual(after_prepare_result, after_prepare_result_fx) - self.assertEqual( - compute_sqnr(after_prepare_result, after_prepare_result_fx), - torch.tensor(float("inf")), - ) - # there are slight differences after convert due to different implementations - # of quant/dequant - self.assertTrue( - torch.max(after_quant_result - after_quant_result_fx) < 1e-1 - ) - self.assertTrue( - compute_sqnr(after_quant_result, after_quant_result_fx) > 35 - ) diff --git a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py index 477ac83ffd9..ea6116a6f0a 100644 --- a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py +++ b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py @@ -54,7 +54,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): - def _get_linear( + def _get_pt2e_quantized_linear( self, is_per_channel: bool = False ) -> torch.fx.GraphModule: class M(torch.nn.Module): @@ -651,7 +651,7 @@ def forward(self, x): def test_fold_quantize(self) -> None: """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)""" - m = self._get_linear() + m = self._get_pt2e_quantized_linear() node_occurrence = { # quantize op for weight node is folded ns.call_function( @@ -665,7 +665,7 @@ def test_fold_quantize(self) -> None: def test_fold_quantize_per_channel(self) -> None: """Test to make sure the quantized model gets quantized weight (quantize_per_channel op is folded)""" - m = self._get_linear(is_per_channel=True) + m = self._get_pt2e_quantized_linear(is_per_channel=True) node_occurrence = { # quantize op for weight node is folded ns.call_function( @@ -682,7 +682,7 @@ def test_fold_quantize_per_channel(self) -> None: def test_save_load(self) -> None: """Test save/load a quantized model""" - m = self._get_linear() + m = self._get_pt2e_quantized_linear() example_inputs = (torch.randn(2, 2),) ref_res = m(*example_inputs) From 089227a5d6147b54e185778af10f04f0b9c2f964 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 21 Mar 2025 16:49:37 +0100 Subject: [PATCH 5/6] Unskip OpenVINOQuantizer tests with dynamic shapes --- .../tests/quantizer/test_openvino_quantizer.py | 10 ++-------- .../openvino/tests/quantizer/test_pt2e_quantization.py | 3 --- .../openvino/tests/quantizer/test_representation.py | 3 --- backends/openvino/tests/test_runner.py | 2 +- 4 files changed, 3 insertions(+), 15 deletions(-) diff --git a/backends/openvino/tests/quantizer/test_openvino_quantizer.py b/backends/openvino/tests/quantizer/test_openvino_quantizer.py index ebc10a29a73..b022b16f804 100644 --- a/backends/openvino/tests/quantizer/test_openvino_quantizer.py +++ b/backends/openvino/tests/quantizer/test_openvino_quantizer.py @@ -191,9 +191,6 @@ def test_conv_linear(self): 
expected_node_list=node_list, ) - @unittest.skip( - "Enable after the fix https://github.com/openvinotoolkit/nncf/pull/3225" - ) def test_linear_with_dynamic_shape(self): quantizer = OpenVINOQuantizer() m_eager = TestHelperModules.TwoLinearModule().eval() @@ -270,17 +267,14 @@ def test_propagate_annotation(self): expected_node_list=node_list, ) - @unittest.skip( - "Enable after the fix https://github.com/openvinotoolkit/nncf/pull/3225" - ) def test_dynamic_linear(self): quantizer = OpenVINOQuantizer() m_eager = TestHelperModules.TwoLinearModule().eval() node_occurrence = { # input and output are using quantize_per_tensor and weight is using quantize_per_channel - torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, diff --git a/backends/openvino/tests/quantizer/test_pt2e_quantization.py b/backends/openvino/tests/quantizer/test_pt2e_quantization.py index f8f747f74fe..f40ef8d1e05 100644 --- a/backends/openvino/tests/quantizer/test_pt2e_quantization.py +++ b/backends/openvino/tests/quantizer/test_pt2e_quantization.py @@ -89,9 +89,6 @@ def forward(self, x): ns.call_function( torch.ops.quantized_decomposed.dequantize_per_tensor.default ): 1, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_channel.default - ): 1, } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) diff --git a/backends/openvino/tests/quantizer/test_representation.py b/backends/openvino/tests/quantizer/test_representation.py index 5ca7941dc8d..48a72f940d4 100644 --- a/backends/openvino/tests/quantizer/test_representation.py +++ b/backends/openvino/tests/quantizer/test_representation.py @@ -104,9 +104,6 @@ def forward(self, x): non_ref_node_occurrence={}, ) - @unittest.skip( - "Enable after the fix https://github.com/openvinotoolkit/nncf/pull/3225" - ) def test_dynamic_linear(self): class M(torch.nn.Module): def __init__(self) -> None: diff --git a/backends/openvino/tests/test_runner.py b/backends/openvino/tests/test_runner.py index 0bda8189b0d..f6607deae6e 100644 --- a/backends/openvino/tests/test_runner.py +++ b/backends/openvino/tests/test_runner.py @@ -47,7 +47,7 @@ def parse_arguments(): help="Specify the type of tests ('ops' or 'models')", type=str, default="ops", - choices={"ops", "models"}, + choices={"ops", "models", "quantizer"}, ) args, ns_args = parser.parse_known_args(namespace=unittest) From 9ca366e4647eb297d2833e032abedede54705df6 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 21 Mar 2025 17:02:00 +0100 Subject: [PATCH 6/6] NNCF version is updated --- backends/openvino/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/openvino/requirements.txt b/backends/openvino/requirements.txt index 316633e9004..95798494e31 100644 --- a/backends/openvino/requirements.txt +++ b/backends/openvino/requirements.txt @@ -1,2 +1,2 @@ transformers -git+https://github.com/openvinotoolkit/nncf@6b0fc1c#egg=nncf +git+https://github.com/openvinotoolkit/nncf@72936ab10b52b50fa5eef6acf0933685fa07cabe#egg=nncf
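
The tests in this series all drive the same PT2E flow. Below is a minimal sketch of that flow, not part of the patch itself: the toy Linear module and the single calibration pass are hypothetical stand-ins, while OpenVINOQuantizer, disable_patching, export_for_training, prepare_pt2e and convert_pt2e are the APIs used throughout the tests above.

    import torch
    from executorch.backends.openvino.quantizer.quantizer import OpenVINOQuantizer
    from nncf.torch import disable_patching
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.export import export_for_training

    # Hypothetical toy model; any eval-mode nn.Module is handled the same way.
    model = torch.nn.Sequential(torch.nn.Linear(2, 2)).eval()
    example_inputs = (torch.randn(2, 2),)

    # Mirrors the run() overrides in these tests: NNCF's eager-mode operator
    # patching is kept off, presumably so torch.export captures the unpatched model.
    with disable_patching():
        m = export_for_training(model, example_inputs).module()  # capture ATen graph
        m = prepare_pt2e(m, OpenVINOQuantizer())                 # insert observers
        m(*example_inputs)                                       # calibrate observers
        m = convert_pt2e(m)                                      # fold observers into Q/DQ nodes

The node_occurrence / node_list assertions in the tests then count the quantized_decomposed Q/DQ ops that convert_pt2e leaves in the resulting graph module.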