diff --git a/.github/workflows/executorch.yml b/.github/workflows/executorch.yml
new file mode 100644
index 00000000000..8149334d6c6
--- /dev/null
+++ b/.github/workflows/executorch.yml
@@ -0,0 +1,54 @@
+name: ExecuTorch
+permissions: read-all
+
+on:
+  workflow_dispatch:
+  schedule:
+    - cron: '0 0 * * *'
+  pull_request:
+    paths:
+      - 'src/nncf/experimental/quantization/algorithms/range_estimator/*'
+      - 'src/nncf/experimental/quantization/algorithms/post_training/*'
+      - 'src/nncf/experimental/quantization/algorithms/weight_compression/*'
+      - 'tests/executorch*'
+      - 'src/nncf/experimental/torch/fx/*'
+      - 'src/nncf/quantization/algorithms/algorithm.py'
+
+jobs:
+  executorch:
+    timeout-minutes: 40
+    runs-on: ubuntu-latest-8-cores
+    defaults:
+      run:
+        shell: bash
+    env:
+      DEBIAN_FRONTEND: noninteractive
+    steps:
+      - name: Install dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get --assume-yes install gcc g++ build-essential ninja-build libgl1-mesa-dev libglib2.0-0
+      - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
+        with:
+          lfs: true
+      - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
+        with:
+          python-version: "3.10.14"
+      - name: Runner info
+        continue-on-error: true
+        run: |
+          cat /etc/*release
+          cat /proc/cpuinfo
+      - name: Install NNCF and test requirements
+        run: |
+          # Torchao installation requires pytorch to be installed first.
+          pip install . -r tests/executorch/requirements.txt
+          pip install --pre torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cpu
+          # Executorch
+          # Editable install due to https://github.com/pytorch/executorch/issues/6475
+          pip install --no-build-isolation -e git+https://github.com/anzr299/executorch.git@an/openvino/nncf_compress_pt2e#egg=executorch
+      - name: Print installed modules
+        run: pip list
+      - name: Run ExecuTorch precommit test scope
+        run: |
+          pytest -ra tests/executorch
\ No newline at end of file
diff --git a/src/nncf/common/tensor_statistics/collectors.py b/src/nncf/common/tensor_statistics/collectors.py
index 4a44d3bf592..df65825a112 100644
--- a/src/nncf/common/tensor_statistics/collectors.py
+++ b/src/nncf/common/tensor_statistics/collectors.py
@@ -938,7 +938,7 @@ def _aggregate_impl(self) -> Tensor:
 
 class HistogramAggregator(AggregatorBase):
     """
-    NNCF implementation of the torch.ao.quantization.observer.HistogramObserver.
+    NNCF implementation of the torchao.quantization.pt2e.observer.HistogramObserver.
     Intended to be combined with a single RawReducer. The aggregator
     records the running histogram of the input tensor values along with min/max values.
     Only the reduction_axis==None is supported.
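For reference, a minimal sketch of the calibration loop of the observer that the HistogramAggregator above mirrors. It assumes the stock HistogramObserver API that torchao ports from torch.ao (bins/dtype/qscheme constructor arguments and `calculate_qparams()`); the tensor shapes are illustrative:

```python
import torch
from torchao.quantization.pt2e.observer import HistogramObserver

# Per-tensor affine uint8 observer; bins=2048 is the stock upstream default.
obs = HistogramObserver(bins=2048, dtype=torch.quint8, qscheme=torch.per_tensor_affine)

# Each forward call folds the batch into a running histogram plus running
# min/max -- the same state the NNCF HistogramAggregator records.
for _ in range(4):
    obs(torch.randn(16, 8))

scale, zero_point = obs.calculate_qparams()
print(scale.item(), zero_point.item())
```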
diff --git a/src/nncf/experimental/torch/fx/__init__.py b/src/nncf/experimental/torch/fx/__init__.py
index 79350c12855..0c6cfb97597 100644
--- a/src/nncf/experimental/torch/fx/__init__.py
+++ b/src/nncf/experimental/torch/fx/__init__.py
@@ -11,4 +11,3 @@
 
 from nncf.experimental.torch.fx.quantization.quantize_pt2e import compress_pt2e as compress_pt2e
 from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e as quantize_pt2e
-from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer as OpenVINOQuantizer
diff --git a/src/nncf/experimental/torch/fx/quantization/quantize_model.py b/src/nncf/experimental/torch/fx/quantization/quantize_model.py
index 17f895f54ff..ca543c9ede6 100644
--- a/src/nncf/experimental/torch/fx/quantization/quantize_model.py
+++ b/src/nncf/experimental/torch/fx/quantization/quantize_model.py
@@ -14,11 +14,11 @@
 
 import torch
 import torch.fx
-from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
-from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat
-from torch.ao.quantization.pt2e.utils import _disallow_eval_train
 from torch.fx import GraphModule
 from torch.fx.passes.infra.pass_manager import PassManager
+from torchao.quantization.pt2e.qat_utils import _fold_conv_bn_qat
+from torchao.quantization.pt2e.quantizer import PortNodeMetaForQDQ
+from torchao.quantization.pt2e.utils import _disallow_eval_train
 
 import nncf
 from nncf.common.factory import build_graph
diff --git a/src/nncf/experimental/torch/fx/quantization/quantize_pt2e.py b/src/nncf/experimental/torch/fx/quantization/quantize_pt2e.py
index 396dd2e87cb..5e1850fd5e0 100644
--- a/src/nncf/experimental/torch/fx/quantization/quantize_pt2e.py
+++ b/src/nncf/experimental/torch/fx/quantization/quantize_pt2e.py
@@ -8,18 +8,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from copy import deepcopy
 from typing import Optional
 
 import torch
 import torch.fx
-from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
-from torch.ao.quantization.pt2e.utils import _disallow_eval_train
-from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
-from torch.ao.quantization.quantizer import Quantizer
 from torch.fx import GraphModule
 from torch.fx.passes.infra.pass_manager import PassManager
+from torchao.quantization.pt2e.quantizer import PortNodeMetaForQDQ
+from torchao.quantization.pt2e.quantizer.quantizer import Quantizer
+from torchao.quantization.pt2e.utils import _disallow_eval_train
+from torchao.quantization.pt2e.utils import _fuse_conv_bn_
 
 import nncf
 from nncf import AdvancedCompressionParameters
@@ -32,7 +31,6 @@
 from nncf.experimental.quantization.algorithms.weight_compression.algorithm import WeightsCompression
 from nncf.experimental.torch.fx.constant_folding import constant_fold
 from nncf.experimental.torch.fx.quantization.quantizer.openvino_adapter import OpenVINOQuantizerAdapter
-from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer
 from nncf.experimental.torch.fx.quantization.quantizer.torch_ao_adapter import TorchAOQuantizerAdapter
 from nncf.experimental.torch.fx.transformations import QUANTIZE_NODE_TARGETS
 from nncf.experimental.torch.fx.transformations import DuplicateDQPassNoAnnotations
@@ -42,6 +40,19 @@
 from nncf.quantization.range_estimator import RangeEstimatorParameters
 
 
+def _is_openvino_quantizer_instance(obj) -> bool:
+    """
+    Safely check whether `obj` is an instance of the executorch OpenVINOQuantizer.
+    The import is performed lazily to avoid a circular import. If executorch is
+    not installed, `obj` cannot be an OpenVINOQuantizer, so False is returned.
+    """
+    try:
+        from executorch.backends.openvino.quantizer.quantizer import OpenVINOQuantizer
+    except ModuleNotFoundError:
+        return False
+    return isinstance(obj, OpenVINOQuantizer)
+
+
 @api(canonical_alias="nncf.experimental.torch.fx.quantize_pt2e")
 def quantize_pt2e(
     model: torch.fx.GraphModule,
@@ -60,7 +71,7 @@
 ) -> torch.fx.GraphModule:
     """
     Applies post-training quantization to the torch.fx.GraphModule provided model
-    using provided torch.ao quantizer.
+    using the given torchao quantizer.
 
     :param model: A torch.fx.GraphModule instance to be quantized.
     :param quantizer: Torch ao quantizer to annotate nodes in the graph with quantization setups
@@ -103,7 +114,7 @@
         model = deepcopy(model)
 
     _fuse_conv_bn_(model)
-    if isinstance(quantizer, OpenVINOQuantizer) or hasattr(quantizer, "get_nncf_quantization_setup"):
+    if _is_openvino_quantizer_instance(quantizer) or hasattr(quantizer, "get_nncf_quantization_setup"):
         quantizer = OpenVINOQuantizerAdapter(quantizer)
     else:
         quantizer = TorchAOQuantizerAdapter(quantizer)
@@ -130,7 +141,8 @@
     quantized_model = GraphModule(quantized_model, quantized_model.graph)
 
     if fold_quantize:
+        # `quantizer` was wrapped in an adapter above, so the adapter type is checked here.
         if isinstance(quantizer, OpenVINOQuantizerAdapter):
             compress_post_quantize_transformation(quantized_model)
         else:
             constant_fold(quantized_model, _quant_node_constraint)
@@ -178,7 +189,7 @@ def compress_pt2e(
     advanced_parameters: Optional[AdvancedCompressionParameters] = None,
 ) -> torch.fx.GraphModule:
     """
-    Applies Weight Compression to the torch.fx.GraphModule model using provided torch.ao quantizer.
+    Applies Weight Compression to the torch.fx.GraphModule model using the given torchao quantizer.
 
     :param model: A torch.fx.GraphModule instance to be quantized.
     :param quantizer: Torch ao quantizer to annotate nodes in the graph with quantization setups
@@ -196,7 +207,7 @@
         preserve the accuracy of the model, the more sensitive layers receive a higher precision.
     :param advanced_parameters: Advanced parameters for algorithms in the compression pipeline.
     """
-    if isinstance(quantizer, OpenVINOQuantizer) or hasattr(quantizer, "get_nncf_weight_compression_parameters"):
+    if _is_openvino_quantizer_instance(quantizer) or hasattr(quantizer, "get_nncf_weight_compression_parameters"):
         quantizer = OpenVINOQuantizerAdapter(quantizer)
         compression_format = nncf.CompressionFormat.DQ
     else:
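For context, a minimal end-to-end sketch of the `quantize_pt2e` entry point above driven by the executorch OpenVINOQuantizer. It assumes PyTorch >= 2.5 (for `torch.export.export_for_training`) and an ExecuTorch build that ships the OpenVINO quantizer; the model and example input are illustrative:

```python
import torch
import torchvision.models as models
from executorch.backends.openvino.quantizer.quantizer import OpenVINOQuantizer

import nncf
from nncf.experimental.torch.fx import quantize_pt2e

model = models.resnet18(weights=None).eval()
example_input = torch.ones(1, 3, 224, 224)

# Capture the model into a torch.fx.GraphModule in ATen IR.
captured_model = torch.export.export_for_training(model, (example_input,)).module()

# OpenVINOQuantizer exposes get_nncf_quantization_setup(), so quantize_pt2e
# dispatches it to the OpenVINOQuantizerAdapter branch above.
quantized_model = quantize_pt2e(captured_model, OpenVINOQuantizer(), nncf.Dataset([example_input]))
```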
diff --git a/src/nncf/experimental/torch/fx/quantization/quantizer/openvino_adapter.py b/src/nncf/experimental/torch/fx/quantization/quantizer/openvino_adapter.py
index 63c4c4c6ff1..90f3ff47e93 100644
--- a/src/nncf/experimental/torch/fx/quantization/quantizer/openvino_adapter.py
+++ b/src/nncf/experimental/torch/fx/quantization/quantizer/openvino_adapter.py
@@ -9,16 +9,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
 
 import torch.fx
 
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
 from nncf.experimental.quantization.quantizer import Quantizer
-from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
 
+if TYPE_CHECKING:
+    from executorch.backends.openvino.quantizer.quantizer import OpenVINOQuantizer
+
 
 class OpenVINOQuantizerAdapter(Quantizer):
     """
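The adapter above now needs OpenVINOQuantizer only as a type annotation, so the executorch import is confined to a TYPE_CHECKING block, with `from __future__ import annotations` keeping the annotation lazy. A generic sketch of this pattern (`heavy_module` and `HeavyType` are illustrative placeholders, not real packages):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by static type checkers only; never executed at runtime.
    from heavy_module import HeavyType


def describe(obj: HeavyType) -> str:
    # With postponed evaluation of annotations, "HeavyType" stays an
    # unevaluated string at runtime, so importing this module never
    # pulls in heavy_module.
    return repr(obj)
```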
diff --git a/src/nncf/experimental/torch/fx/quantization/quantizer/openvino_quantizer.py b/src/nncf/experimental/torch/fx/quantization/quantizer/openvino_quantizer.py
deleted file mode 100644
index 170abca504b..00000000000
--- a/src/nncf/experimental/torch/fx/quantization/quantizer/openvino_quantizer.py
+++ /dev/null
@@ -1,363 +0,0 @@
-# Copyright (c) 2026 Intel Corporation
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#      http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections import defaultdict
-from typing import Optional, Union
-
-import torch.fx
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.observer import PerChannelMinMaxObserver
-from torch.ao.quantization.quantizer.quantizer import EdgeOrNode
-from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation as TorchAOQuantizationAnnotation
-from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as TorchAOQuantizationSpec
-from torch.ao.quantization.quantizer.quantizer import QuantizationSpecBase as TorchAOQuantizationSpecBase
-from torch.ao.quantization.quantizer.quantizer import Quantizer as TorchAOQuantizer
-from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec as TorchAOSharedQuantizationSpec
-
-import nncf
-from nncf import IgnoredScope
-from nncf import ModelType
-from nncf import OverflowFix
-from nncf import QuantizationMode
-from nncf import QuantizationPreset
-from nncf import TargetDevice
-from nncf.common.graph.graph import NNCFGraph
-from nncf.common.logging import nncf_logger
-from nncf.common.quantization.quantizer_propagation.structs import QuantizerPropagationRule
-from nncf.common.quantization.quantizer_setup import QuantizationPointBase
-from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
-from nncf.common.quantization.structs import QuantizationScheme
-from nncf.common.utils.api_marker import api
-from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
-from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
-from nncf.quantization.advanced_parameters import FP8QuantizationParameters
-from nncf.quantization.advanced_parameters import QuantizationParameters
-from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
-from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
-
-QUANT_ANNOTATION_KEY = "quantization_annotation"
-
-
-@api(canonical_alias="nncf.experimental.torch.fx.OpenVINOQuantizer")
-class OpenVINOQuantizer(TorchAOQuantizer):
-    """
-    Implementation of the Torch AO quantizer which annotates models with quantization annotations
-    optimally for the inference via OpenVINO.
-
-    :param mode: Defines optimization mode for the algorithm. None by default.
-    :param preset: A preset controls the quantization mode (symmetric and asymmetric).
-        It can take the following values:
-        - `performance`: Symmetric quantization of weights and activations.
-        - `mixed`: Symmetric quantization of weights and asymmetric quantization of activations.
-        Default value is None. In this case, `mixed` preset is used for `transformer`
-        model type otherwise `performance`.
-    :param target_device: A target device the specificity of which will be taken
-        into account while compressing in order to obtain the best performance
-        for this type of device, defaults to TargetDevice.ANY.
-    :param model_type: Model type is needed to specify additional patterns
-        in the model. Supported only `transformer` now.
-    :param ignored_scope: An ignored scope that defined the list of model control
-        flow graph nodes to be ignored during quantization.
-    :param overflow_fix: This option controls whether to apply the overflow issue
-        fix for the 8-bit quantization.
-    :param quantize_outputs: Whether to insert additional quantizers right before
-        each of the model outputs.
-    :param activations_quantization_params: Quantization parameters for model
-        activations.
-    :param weights_quantization_params: Quantization parameters for model weights.
-    :param quantizer_propagation_rule: The strategy to be used while propagating and merging quantizers.
-        MERGE_ALL_IN_ONE by default.
-    """
-
-    def __init__(
-        self,
-        *,
-        mode: Optional[QuantizationMode] = None,
-        preset: Optional[QuantizationPreset] = None,
-        target_device: TargetDevice = TargetDevice.ANY,
-        model_type: Optional[ModelType] = None,
-        ignored_scope: Optional[IgnoredScope] = None,
-        overflow_fix: Optional[OverflowFix] = None,
-        quantize_outputs: bool = False,
-        activations_quantization_params: Optional[Union[QuantizationParameters, FP8QuantizationParameters]] = None,
-        weights_quantization_params: Optional[Union[QuantizationParameters, FP8QuantizationParameters]] = None,
-        quantizer_propagation_rule: QuantizerPropagationRule = QuantizerPropagationRule.MERGE_ALL_IN_ONE,
-    ):
-        self._min_max_algo = MinMaxQuantization(
-            mode=mode,
-            preset=preset,
-            target_device=target_device,
-            model_type=model_type,
-            ignored_scope=ignored_scope,
-            overflow_fix=overflow_fix,
-            quantize_outputs=quantize_outputs,
-            activations_quantization_params=activations_quantization_params,
-            weights_quantization_params=weights_quantization_params,
-            quantizer_propagation_rule=quantizer_propagation_rule,
-        )
-
-    def set_ignored_scope(
-        self,
-        names: Optional[list[str]] = None,
-        patterns: Optional[list[str]] = None,
-        types: Optional[list[str]] = None,
-        subgraphs: Optional[list[tuple[list[str], list[str]]]] = None,
-        validate: bool = True,
-    ) -> None:
-        """
-        Provides an option to specify portions of model to be excluded from compression.
-        The ignored scope defines model sub-graphs that should be excluded from the quantization process.
-
-        :param names: List of ignored node names.
-        :param patterns: List of regular expressions that define patterns for names of ignored nodes.
-        :param types: List of ignored operation types.
-        :param subgraphs: List of ignored subgraphs.
-        :param validate: If set to True, then a RuntimeError will be raised if any ignored scope does not match
-            in the model graph.
-        """
-        self._min_max_algo.set_ignored_scope(
-            nncf.IgnoredScope(
-                names=names or [],
-                patterns=patterns or [],
-                types=types or [],
-                subgraphs=subgraphs or [],
-                validate=validate,
-            )
-        )
-
-    def get_nncf_quantization_setup(
-        self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph
-    ) -> SingleConfigQuantizerSetup:
-        self._min_max_algo._set_backend_entity(model)
-        return self._min_max_algo.find_quantization_setup(model, nncf_graph)
-
-    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-        """
-        Adds quantization annotations to the nodes in the model graph in-place.
-
-        :param model: A torch.fx.GraphModule to annotate.
-        :return: The torch.fx.GraphModule with updated annotations.
- """ - nncf_graph = GraphConverter.create_nncf_graph(model) - quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph) - - graph = model.graph - node_vs_torch_annotation = defaultdict(TorchAOQuantizationAnnotation) - - for qp in quantization_setup.quantization_points.values(): - edge_or_node, annotation = self._get_edge_or_node_and_annotation( - graph, nncf_graph, qp, node_vs_torch_annotation - ) - qspec = self._get_torch_ao_qspec_from_qp(qp) - self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) - - for quantizer_ids in quantization_setup.unified_scale_groups.values(): - root_quantizer_id = self._get_unified_scales_root_quantizer_id( - nncf_graph, quantizer_ids, quantization_setup - ) - root_qp = quantization_setup.quantization_points[root_quantizer_id] - - if any(root_qp.qconfig != quantization_setup.quantization_points[q_id].qconfig for q_id in quantizer_ids): - qps = [quantization_setup.quantization_points[q_id] for q_id in quantizer_ids] - msg = ( - "Different quantization configs are set to one unified scale group:" - f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}" - ) - raise nncf.InternalError(msg) - - root_target_node = get_graph_node_by_name(graph, root_qp.insertion_point.target_node_name) - root_edge_or_node = self._get_edge_or_node(root_target_node, root_qp, nncf_graph) - - for quantizer_id in quantizer_ids: - if quantizer_id == root_quantizer_id: - continue - - qspec = TorchAOSharedQuantizationSpec(root_edge_or_node) - qp = quantization_setup.quantization_points[quantizer_id] - edge_or_node, annotation = self._get_edge_or_node_and_annotation( - graph, nncf_graph, qp, node_vs_torch_annotation - ) - self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) - - for node, annotation in node_vs_torch_annotation.items(): - assert QUANT_ANNOTATION_KEY not in node.meta - node.meta[QUANT_ANNOTATION_KEY] = annotation - return model - - @staticmethod - def _get_unified_scales_root_quantizer_id( - nncf_graph: NNCFGraph, quantizer_ids: list[int], quantizer_setup: SingleConfigQuantizerSetup - ) -> int: - """ - Identifies the earliest quantizer node ID based on the corresponding `nncf_node.node_id` - in the given NNCFGraph. This is required by the `_get_obs_or_fq_map` function. - Refer to: https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/pt2e/prepare.py#L291 - - :param nncf_graph: The NNCFGraph instance. - :param quantizer_ids: The list of quantizer IDs to evaluate. - :param quantizer_setup: The instance of SingleConfigQuantizerSetup. - :return: The ID of the earliest quantizer node in terms of `nncf_node.node_id`. - """ - nncf_node_quantizer_id = None - root_quantizer_id = None - for quantizer_id in quantizer_ids: - target_node_name = quantizer_setup.quantization_points[quantizer_id].insertion_point.target_node_name - nncf_node = nncf_graph.get_node_by_name(target_node_name) - if nncf_node_quantizer_id is None or nncf_node.node_id < nncf_node_quantizer_id: - root_quantizer_id = quantizer_id - nncf_node_quantizer_id = nncf_node.node_id - return root_quantizer_id - - @staticmethod - def _get_edge_or_node_and_annotation( - graph: torch.fx.Graph, - nncf_graph: NNCFGraph, - qp: QuantizationPointBase, - node_vs_torch_annotation: dict[torch.fx.Node, TorchAOQuantizationAnnotation], - ) -> tuple[EdgeOrNode, TorchAOQuantizationAnnotation]: - """ - Retrieves the edge or node and its corresponding TorchAOQuantizationAnnotation based on the given graph, - quantization point, and node-to-annotation mapping. 
-
-        :param graph: torch.fx.Graph instance.
-        :param nncf_graph: NNCFGraph instance.
-        :param qp: QuantizationPointBase instance.
-        :param node_vs_torch_annotation: A dictionary mapping torch.fx.GraphNode objects to their respective
-            TorchAOQuantizationAnnotations.
-        :return: A tuple containing the EdgeOrNode and its associated TorchAOQuantizationAnnotation.
-        """
-        target_node = get_graph_node_by_name(graph, qp.insertion_point.target_node_name)
-        annotation = node_vs_torch_annotation[target_node]
-        edge_or_node = OpenVINOQuantizer._get_edge_or_node(target_node, qp, nncf_graph)
-        return edge_or_node, annotation
-
-    @staticmethod
-    def _get_edge_or_node(target_node: torch.fx.Node, qp: QuantizationPointBase, nncf_graph: NNCFGraph) -> EdgeOrNode:
-        """
-        Returns the edge or node based on the given target node and quantization point.
-
-        :param target_node: Target node instance.
-        :param qp: QuantizationPointBase instance.
-        :param graph: NNCFGraph instance.
-        :return: The corresponding EdgeOrNode derived from the target node and quantization point.
-        """
-        ip = qp.insertion_point
-        if qp.is_weight_quantization_point():
-            nncf_node = nncf_graph.get_node_by_name(target_node.name)
-            weights_ports_ids = get_weight_tensor_port_ids(nncf_node, nncf_graph)
-            if len(weights_ports_ids) > 1:
-                # TODO(dlyakhov): support quantization for nodes with several weights
-                nncf_logger.warning(
-                    f"Quantization of the weighted node {target_node.name}"
-                    " is not yet supported by the OpenVINOQuantizer."
-                    f" Only the weight on port ID {weights_ports_ids[0]} will be quantized."
-                    f" Quantizable weights are located on ports: {weights_ports_ids}."
-                )
-            weight_node = target_node.all_input_nodes[weights_ports_ids[0]]
-            return (weight_node, target_node)
-
-        if ip.input_port_id is None:
-            return target_node
-
-        node = target_node.all_input_nodes[ip.input_port_id]
-        return (node, target_node)
-
-    @staticmethod
-    def _fill_torch_ao_annotation(
-        edge_or_node: EdgeOrNode,
-        qspec: TorchAOQuantizationSpecBase,
-        annotation_to_update: TorchAOQuantizationAnnotation,
-    ) -> None:
-        """
-        Helper method to update the annotation_to_update based on the specified edge_or_node and qspec.
-
-        :param edge_or_node: The target EdgeOrNode to be used for the update.
-        :param qspec: An instance of TorchAOQuantizationSpecBase representing the quantization specification to apply.
-        :param annotation_to_update: The annotation to update based on the edge_or_node and qspec.
-        """
-        if isinstance(edge_or_node, torch.fx.Node):
-            annotation_to_update.output_qspec = qspec
-        else:
-            annotation_to_update.input_qspec_map[edge_or_node[0]] = qspec
-
-    @staticmethod
-    def _get_torch_ao_qspec_from_qp(qp: QuantizationPointBase) -> TorchAOQuantizationSpec:
-        """
-        Retrieves the quantization configuration from the given quantization point and
-        converts it into a TorchAOQuantizationSpec.
-
-        :param qp: An instance of QuantizationPointBase.
-        :return: A TorchAOQuantizationSpec retrieved and converted from the quantization point.
- """ - # Eps value is copied from nncf/torch/quantization/layers.py - extra_args = {"eps": 1e-16} - qconfig = qp.qconfig - is_weight = qp.is_weight_quantization_point() - - if qconfig.per_channel: - torch_qscheme = ( - torch.per_channel_symmetric - if qconfig.mode is QuantizationScheme.SYMMETRIC - else torch.per_channel_affine - ) - else: - torch_qscheme = ( - torch.per_tensor_symmetric if qconfig.mode is QuantizationScheme.SYMMETRIC else torch.per_tensor_affine - ) - if is_weight: - observer = PerChannelMinMaxObserver - quant_min = -128 - quant_max = 127 - dtype = torch.int8 - channel_axis = 0 - else: - observer = ( - HistogramObserver - if torch_qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine] - else PerChannelMinMaxObserver - ) - quant_min = 0 - quant_max = 255 - dtype = torch.int8 if qconfig.signedness_to_force else torch.uint8 - channel_axis = 1 # channel dim for activations - return TorchAOQuantizationSpec( - dtype=dtype, - observer_or_fake_quant_ctr=observer.with_args(**extra_args), - quant_min=quant_min, - quant_max=quant_max, - qscheme=torch_qscheme, - ch_axis=channel_axis, - is_dynamic=False, - ) - - def validate(self, model: torch.fx.GraphModule) -> None: - """ - Validates the annotated model before the insertion of FakeQuantizers / observers. - - :param model: Annotated torch.fx.GraphModule to validate after the annotation. - """ - pass - - def transform_for_annotation(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - """ - Allows for user defined transforms to run before annotating the graph. - This allows quantizer to allow quantizing part of the model that are otherwise not quantizable. - For example quantizer can - a) decompose a compound operator like scaled dot product attention, - into bmm and softmax if quantizer knows how to quantize bmm/softmax but not sdpa - or b) transform scalars to tensor to allow quantizing scalars. - - Note: this is an optional method - - :param model: Given torch.fx.GraphModule to transform before the annotation. - :return: The transformed torch.fx.GraphModule ready for the annotation. - """ - return model diff --git a/src/nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py b/src/nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py index cf46603257c..96b8ad50902 100644 --- a/src/nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py +++ b/src/nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py @@ -15,11 +15,11 @@ import torch import torch.fx -from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id -from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec -from torch.ao.quantization.quantizer import Quantizer as TorchAOQuantizer -from torch.ao.quantization.quantizer.quantizer import QuantizationSpec -from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec +from torchao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id +from torchao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec +from torchao.quantization.pt2e.quantizer import Quantizer as TorchAOQuantizer +from torchao.quantization.pt2e.quantizer.quantizer import QuantizationSpec +from torchao.quantization.pt2e.quantizer.quantizer import SharedQuantizationSpec import nncf from nncf.common.graph.graph import NNCFGraph @@ -41,7 +41,7 @@ class TorchAOQuantizerAdapter(Quantizer): """ - Implementation of the NNCF Quantizer interface for any given torch.ao quantizer. 
diff --git a/src/nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py b/src/nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py
index cf46603257c..96b8ad50902 100644
--- a/src/nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py
+++ b/src/nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py
@@ -15,11 +15,11 @@
 
 import torch
 import torch.fx
-from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id
-from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec
-from torch.ao.quantization.quantizer import Quantizer as TorchAOQuantizer
-from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
-from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec
+from torchao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id
+from torchao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec
+from torchao.quantization.pt2e.quantizer import Quantizer as TorchAOQuantizer
+from torchao.quantization.pt2e.quantizer.quantizer import QuantizationSpec
+from torchao.quantization.pt2e.quantizer.quantizer import SharedQuantizationSpec
 
 import nncf
 from nncf.common.graph.graph import NNCFGraph
@@ -41,7 +41,7 @@
 
 class TorchAOQuantizerAdapter(Quantizer):
     """
-    Implementation of the NNCF Quantizer interface for any given torch.ao quantizer.
+    Implementation of the NNCF Quantizer interface for any given torchao quantizer.
     """
 
     def __init__(self, quantizer: TorchAOQuantizer):
@@ -110,7 +110,7 @@ def _get_quantization_points(
 def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
     """
     Process a torch.fx.GraphModule annotated with quantization specifications
-    (e.g., via torch.ao observers) and generates a corresponding NNCF quantization setup object,
+    (e.g., via torchao observers) and generate a corresponding NNCF quantization setup object,
     which maps quantization configurations to graph edges.
 
     :param annotated: A torch.fx.GraphModule that has been annotated with Torch quantization observers.
@@ -139,7 +139,7 @@
         if qspec is None:
             continue
         if not isinstance(qspec, QuantizationSpec):
-            msg = f"Unknown torch.ao quantization spec: {qspec}"
+            msg = f"Unknown torchao quantization spec: {qspec}"
             raise nncf.InternalError(msg)
 
         if qspec.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
@@ -156,9 +156,8 @@
             if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]
             else QuantizationMode.ASYMMETRIC
         )
-        # QuantizationSpec may have quant_min and quant_max attributes set to None.
-        # torch.ao.prepare_pt2e treats such occurrences as a signal
+        # When quant_min and quant_max are set to None, torchao's prepare_pt2e treats this as a signal
         # that the full range of values should be used for quant_min and quant_max.
         # Therefore, the narrow_range parameter is set to False in this case.
         if qspec.quant_min is None or qspec.quant_max is None:
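The rewritten comment above concerns how narrow_range is derived from a spec's bounds. A self-contained sketch of that derivation with illustrative values (the exact NNCF code below this hunk is not shown in the diff):

```python
import torch

# Illustrative int8 spec that uses 255 of the 256 representable levels.
quant_min, quant_max, num_bits = -127, 127, 8

if quant_min is None or quant_max is None:
    # torchao's prepare_pt2e reads missing bounds as "use the full range".
    narrow_range = False
else:
    narrow_range = quant_max - quant_min == 2**num_bits - 2

print(narrow_range)  # True for [-127, 127]
```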
diff --git a/src/nncf/experimental/torch/fx/transformations.py b/src/nncf/experimental/torch/fx/transformations.py
index 49579afb906..3f50c3c69ad 100644
--- a/src/nncf/experimental/torch/fx/transformations.py
+++ b/src/nncf/experimental/torch/fx/transformations.py
@@ -15,12 +15,12 @@
 
 import torch
 import torch.fx
-from torch.ao.quantization.fx.utils import create_getattr_from_value
-from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
 from torch.fx.node import map_arg
 from torch.fx.passes.infra.pass_base import PassBase
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.quantization.fake_quantize import FakeQuantize
+from torchao.quantization.pt2e.utils import _fuse_conv_bn_
+from torchao.quantization.pt2e.utils import create_getattr_from_value
 
 import nncf
 import nncf.torch
@@ -382,7 +382,7 @@ def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, qua
         target node.
     :param quantizer: Quantizer module to inherit quantization parameters from.
     """
-    # Copied from torch.ao.quantization.quantize_pt2e.convert_pt2e
+    # Copied from torchao.quantization.pt2e.quantize_pt2e.convert_pt2e
     # 1. extract information for inserting q/dq node from activation_post_process
     node_type = "call_function"
     quantize_op: Optional[Callable] = None
diff --git a/src/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/src/nncf/quantization/algorithms/min_max/torch_fx_backend.py
index ad411e0e4fa..a0ba0dc19e2 100644
--- a/src/nncf/quantization/algorithms/min_max/torch_fx_backend.py
+++ b/src/nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -12,6 +12,7 @@
 from typing import Optional
 
 import torch
+import torchao
 from torch.quantization.fake_quantize import FakeQuantize
 
 import nncf
@@ -203,9 +204,9 @@ def _create_quantizer(
         )
 
         if per_channel:
-            observer = torch.ao.quantization.observer.PerChannelMinMaxObserver
+            observer = torchao.quantization.pt2e.observer.PerChannelMinMaxObserver
         else:
-            observer = torch.ao.quantization.observer.MinMaxObserver
+            observer = torchao.quantization.pt2e.observer.MinMaxObserver
 
         if dtype is TensorDataType.int8:
             level_high = 127
diff --git a/src/nncf/torch/quantization/strip.py b/src/nncf/torch/quantization/strip.py
index 1e071ad2729..9dde31caa43 100644
--- a/src/nncf/torch/quantization/strip.py
+++ b/src/nncf/torch/quantization/strip.py
@@ -49,11 +49,12 @@ def convert_to_torch_fakequantizer(nncf_quantizer: BaseQuantizer) -> FakeQuantiz
     scale_shape = nncf_quantizer.scale_shape
     ch_axis = int(np.argmax(scale_shape))
     dtype = torch.qint8 if nncf_quantizer.level_low < 0 else torch.quint8
+    import torchao
 
     if per_channel:
-        observer = torch.ao.quantization.observer.PerChannelMinMaxObserver
+        observer = torchao.quantization.pt2e.observer.PerChannelMinMaxObserver
    else:
-        observer = torch.ao.quantization.observer.MinMaxObserver
+        observer = torchao.quantization.pt2e.observer.MinMaxObserver
 
     if isinstance(nncf_quantizer, SymmetricQuantizer):
         qscheme = torch.per_channel_symmetric if per_channel else torch.per_tensor_symmetric
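Both backends above now build torch FakeQuantize modules around torchao's pt2e observers instead of the torch.ao ones. A hedged sketch of that construction with illustrative per-tensor symmetric int8 parameters; the real qscheme, bounds, and eps come from the NNCF quantizer being converted, and the kwargs are assumed to be forwarded to the observer as in torch.ao:

```python
import torch
from torch.quantization.fake_quantize import FakeQuantize
from torchao.quantization.pt2e.observer import MinMaxObserver

fake_quantizer = FakeQuantize(
    observer=MinMaxObserver,  # torchao's port of the torch.ao observer
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
)
fake_quantizer(torch.randn(4, 8))  # observes, then fake-quantizes the input
```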
diff --git a/tests/torch/fx/test_quantizer.py b/tests/executorch/test_ptq.py
similarity index 92%
rename from tests/torch/fx/test_quantizer.py
rename to tests/executorch/test_ptq.py
index 90734f60b0f..0a50f3081c4 100644
--- a/tests/torch/fx/test_quantizer.py
+++ b/tests/executorch/test_ptq.py
@@ -17,22 +17,19 @@
 import pytest
 import torch
 import torch.fx
-import torch.nn.parallel
-import torch.optim
-import torch.utils.data
-import torch.utils.data.distributed
 import torchvision.models as models
-from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
-from torch.ao.quantization.quantize_pt2e import convert_pt2e
-from torch.ao.quantization.quantize_pt2e import prepare_pt2e
-from torch.ao.quantization.quantizer import xnnpack_quantizer
-from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
-from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as TorchAOQuantizationSpec
-from torch.ao.quantization.quantizer.quantizer import Quantizer
-from torch.ao.quantization.quantizer.quantizer import Quantizer as TorchAOQuantizer
-from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec as TorchAOSharedQuantizationSpec
-from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
-from torch.ao.quantization.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config
+from executorch.backends.openvino.quantizer.quantizer import OpenVINOQuantizer
+from executorch.backends.xnnpack.quantizer import xnnpack_quantizer
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
+from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e
+from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
+from torchao.quantization.pt2e.quantizer import QuantizationSpec as TorchAOQuantizationSpec
+from torchao.quantization.pt2e.quantizer import Quantizer
+from torchao.quantization.pt2e.quantizer import Quantizer as TorchAOQuantizer
+from torchao.quantization.pt2e.quantizer import SharedQuantizationSpec as TorchAOSharedQuantizationSpec
+from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import X86InductorQuantizer
+from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config
+from torchao.quantization.pt2e.utils import _fuse_conv_bn_
 
 import nncf
 from nncf.common.graph import NNCFGraph
@@ -41,7 +38,6 @@
 from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
 from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
 from nncf.experimental.torch.fx.quantization.quantizer.openvino_adapter import OpenVINOQuantizerAdapter
-from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer
 from nncf.experimental.torch.fx.quantization.quantizer.torch_ao_adapter import TorchAOQuantizerAdapter
 from nncf.experimental.torch.fx.quantization.quantizer.torch_ao_adapter import _get_edge_or_node_to_qspec
 from nncf.tensor.definitions import TensorDataType
@@ -55,7 +51,7 @@
 from tests.torch.test_models.synthetic import SimpleConcatModel
 from tests.torch.test_models.synthetic import YOLO11N_SDPABlock
 
-FX_QUANTIZED_DIR_NAME = TEST_ROOT / "torch" / "data" / "fx"
+FX_QUANTIZED_DIR_NAME = TEST_ROOT / "executorch" / "data" / "fx"
 
 
 @dataclass
@@ -171,7 +167,7 @@ def test_quantized_model(
     )
 
     # Uncomment to visualize torch fx graph
-    # from tests.torch.fx.helpers import visualize_fx_model
+    # from tests.torch2.fx.helpers import visualize_fx_model
     # visualize_fx_model(quantized_model, f"{quantizer.__class__.__name__}_{model_case.model_id}_int8.svg")
 
     nncf_graph = GraphConverter.create_nncf_graph(quantized_model)
@@ -181,9 +177,9 @@
     compare_nx_graph_with_reference(nx_graph, path_to_dot.as_posix())
 
     # Uncomment to visualize reference graphs
-    # from torch.ao.quantization.quantize_pt2e import convert_pt2e
-    # from torch.ao.quantization.quantize_pt2e import prepare_pt2e
-    # from tests.torch.fx.helpers import visualize_fx_model
+    # from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
+    # from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e
+    # from tests.torch2.fx.helpers import visualize_fx_model
     # prepared_model = prepare_pt2e(fx_model, quantizer)
     # prepared_model(example_input)
     # ao_quantized_model = convert_pt2e(prepared_model)
diff --git a/tests/post_training/pipelines/image_classification_base.py b/tests/post_training/pipelines/image_classification_base.py
index 129cb875f71..1171088e7e6 100644
--- a/tests/post_training/pipelines/image_classification_base.py
+++ b/tests/post_training/pipelines/image_classification_base.py
@@ -24,9 +24,9 @@
 import openvino as ov
 import torch
 from sklearn.metrics import accuracy_score
-from torch.ao.quantization.quantize_pt2e import convert_pt2e
-from torch.ao.quantization.quantize_pt2e import prepare_pt2e
-from torch.ao.quantization.quantizer.quantizer import Quantizer as TorchAOQuantizer
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
+from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e
+from torchao.quantization.pt2e.quantizer import Quantizer as TorchAOQuantizer
 from torchvision import datasets
 
 import nncf
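The commented reference flow in the test above is the native torchao PT2E pipeline that the post-training pipelines also import. Spelled out, with `fx_model`, `quantizer`, and `example_input` assumed to be in scope:

```python
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

prepared_model = prepare_pt2e(fx_model, quantizer)  # insert observers
prepared_model(example_input)                       # calibrate on sample data
ao_quantized_model = convert_pt2e(prepared_model)   # lower observers to q/dq nodes
```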
diff --git a/tests/torch/fx/test_model_transformer.py b/tests/torch/fx/test_model_transformer.py
index 05e195e4299..046c0ff3fff 100644
--- a/tests/torch/fx/test_model_transformer.py
+++ b/tests/torch/fx/test_model_transformer.py
@@ -15,12 +15,11 @@
 
 import pytest
 import torch
-import torch.ao.quantization
 import torch.fx
-from torch.ao.quantization.fx.utils import create_getattr_from_value
-from torch.ao.quantization.observer import MinMaxObserver
-from torch.ao.quantization.observer import PerChannelMinMaxObserver
 from torch.quantization.fake_quantize import FakeQuantize
+from torchao.quantization.pt2e.observer import MinMaxObserver
+from torchao.quantization.pt2e.observer import PerChannelMinMaxObserver
+from torchao.quantization.pt2e.utils import create_getattr_from_value
 
 import nncf
 import nncf.common