Skip to content

Dl/fx/dont use nncf q #3487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion nncf/common/quantization/quantizer_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,13 @@ def __init__(self) -> None:
self._next_unified_scale_gid = 0
self._next_shared_inputs_gid = 0

def add_independent_quantization_point(self, qp: QuantizationPointBase) -> int:
    """
    Stores the given quantization point under a freshly allocated id.

    :param qp: Quantization point to register in this setup.
    :return: The id assigned to the registered quantization point.
    """
    # Ids are allocated as max(existing) + 1 so removed points' ids are never reused.
    existing_ids = self.quantization_points.keys()
    new_id = max(existing_ids) + 1 if existing_ids else 0
    self.quantization_points[new_id] = qp
    return new_id

def register_unified_scale_group(self, qp_group: list[QuantizationPointId]) -> int:
for qp_id in qp_group:
Expand Down
81 changes: 79 additions & 2 deletions nncf/experimental/quantization/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,91 @@

from abc import ABC
from abc import abstractmethod
from typing import TypeVar
from enum import Enum
from typing import Any, Optional, TypeVar

from nncf.common.graph.graph import NNCFGraph
from nncf.common.quantization.quantizer_setup import QuantizationPointBase
from nncf.common.quantization.quantizer_setup import QuantizationPointId
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup

TModel = TypeVar("TModel")


# Integer dtypes a quantizer output can take; declared via the functional
# Enum API — member values intentionally mirror member names so that
# IntDtype[name] and IntDtype(value) resolve to the same member.
IntDtype = Enum("IntDtype", {"INT8": "INT8", "UINT8": "UINT8"})


class ExtendedQuantizerSetup(ABC, SingleConfigQuantizerSetup):
    """
    Abstract quantizer setup that augments SingleConfigQuantizerSetup with
    per-quantization-point extra parameters required by backend-specific
    quantizer insertion.
    """

    @abstractmethod
    def get_extra_params(self) -> dict[QuantizationPointId, dict[str, Any]]:
        """
        Returns extra parameters keyed by quantization point id; each value
        maps a parameter name to its value for that quantization point.
        """


class ExtendedFXQuantizerSetup(ExtendedQuantizerSetup):
    """
    Quantizer setup with additional info required to insert
    quantizers to torch.fx models.
    """

    QUANTIZER_DTYPE_NAME = "quantizer_dtype"

    def __init__(self) -> None:
        super().__init__()
        # Maps a quantization point id to the integer dtype the inserted
        # quantizer should produce; None means "use the backend default".
        self._quantization_dtypes: dict[QuantizationPointId, Optional[IntDtype]] = {}

    def add_independent_quantization_point(
        self, qp: QuantizationPointBase, intermediate_dtype: Optional[IntDtype] = None
    ) -> QuantizationPointId:
        """
        Registers a quantization point together with its target integer dtype.

        :param qp: Quantization point to register.
        :param intermediate_dtype: Integer dtype of the inserted quantizer output,
            or None to use the backend default. Defaults to None so the base-class
            signature remains call-compatible.
        :return: The id assigned to the registered quantization point.
        """
        qp_id = super().add_independent_quantization_point(qp)
        self._quantization_dtypes[qp_id] = intermediate_dtype
        return qp_id

    def get_extra_params(self) -> dict[QuantizationPointId, dict[str, Any]]:
        """
        Returns per-quantization-point extra parameters; currently only the
        quantizer dtype under the QUANTIZER_DTYPE_NAME key.
        """
        return {k: {self.QUANTIZER_DTYPE_NAME: v} for k, v in self._quantization_dtypes.items()}

    def get_state(self) -> dict[str, Any]:
        """
        Returns a dictionary with Python data structures (dict, list, tuple, str, int, float, True, False, None) that
        represents state of the object.

        :return: state of the object
        """
        base_state = super().get_state()
        # Serialize dtypes by enum member name (member values mirror names);
        # None (backend-default dtype) is kept as-is so `from_state` can
        # round-trip it. Iterate the dtype map itself, not quantization_points.
        base_state[self.QUANTIZER_DTYPE_NAME] = {
            qp_id: None if dtype is None else dtype.value
            for qp_id, dtype in self._quantization_dtypes.items()
        }
        return base_state

    @classmethod
    def from_state(cls, state: dict[str, Any]) -> "ExtendedFXQuantizerSetup":
        """
        Creates the object from its state.

        :param state: Output of `get_state()` method.
        """
        # Split the dtype payload off before delegating the rest of the state
        # to the base-class deserializer.
        state_ = state.copy()
        dtype_names = state_.pop(cls.QUANTIZER_DTYPE_NAME)
        super_setup = super().from_state(state_)
        setup = ExtendedFXQuantizerSetup()

        setup.quantization_points = super_setup.quantization_points
        setup.unified_scale_groups = super_setup.unified_scale_groups
        setup.shared_input_operation_set_groups = super_setup.shared_input_operation_set_groups
        setup._quantization_dtypes = {
            qp_id: None if name is None else IntDtype[name] for qp_id, name in dtype_names.items()
        }

        return setup


class Quantizer(ABC):
"""
Quantizer is an interface for the RangeEstimator algorithm
Expand All @@ -35,7 +112,7 @@ def transform_prior_quantization(self, model: TModel) -> TModel:
"""

@abstractmethod
def get_quantization_setup(self, model: TModel, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
def get_quantization_setup(self, model: TModel, nncf_graph: NNCFGraph) -> ExtendedFXQuantizerSetup:
"""
Builds SingleConfigQuantizerSetup for the given model.

Expand Down
4 changes: 2 additions & 2 deletions nncf/experimental/torch/fx/quantization/quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import torch
import torch.fx
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
from torch.ao.quantization.pt2e.utils import _disallow_eval_train
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
Expand All @@ -33,6 +32,7 @@
from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer
from nncf.experimental.torch.fx.quantization.quantizer.torch_ao_adapter import TorchAOQuantizerAdapter
from nncf.experimental.torch.fx.transformations import QUANTIZE_NODE_TARGETS
from nncf.experimental.torch.fx.transformations import DuplicateDQPassNoAnnotations
from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
Expand Down Expand Up @@ -132,7 +132,7 @@ def quantize_pt2e(
else:
constant_fold(quantized_model, _quant_node_constraint)

pm = PassManager([DuplicateDQPass()])
pm = PassManager([DuplicateDQPassNoAnnotations()])

quantized_model = pm(quantized_model).graph_module
pm = PassManager([PortNodeMetaForQDQ()])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch.fx

from nncf.common.graph.graph import NNCFGraph
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
from nncf.experimental.quantization.quantizer import ExtendedFXQuantizerSetup
from nncf.experimental.quantization.quantizer import Quantizer
from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer

Expand All @@ -28,5 +28,5 @@ def __init__(self, quantizer: OpenVINOQuantizer):
def transform_prior_quantization(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Delegates the pre-quantization model transformation to the wrapped quantizer.

    :param model: Model to transform before annotation/quantization.
    :return: The transformed model produced by the wrapped quantizer.
    """
    transformed = self._quantizer.transform_for_annotation(model)
    return transformed

def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> ExtendedFXQuantizerSetup:
    """
    Builds the quantizer setup for the given model by delegating to the
    wrapped OpenVINO quantizer.

    :param model: Model to build the quantizer setup for.
    :param nncf_graph: NNCFGraph built for the given model.
    :return: Quantizer setup produced by the wrapped quantizer.
    """
    setup = self._quantizer.get_nncf_quantization_setup(model, nncf_graph)
    return setup
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
from nncf.common.quantization.structs import QuantizationScheme
from nncf.common.utils.api_marker import api
from nncf.experimental.quantization.quantizer import ExtendedFXQuantizerSetup
from nncf.experimental.quantization.quantizer import IntDtype
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.quantization.advanced_parameters import FP8QuantizationParameters
Expand Down Expand Up @@ -135,9 +137,16 @@ def set_ignored_scope(

def get_nncf_quantization_setup(
    self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph
) -> ExtendedFXQuantizerSetup:
    """
    Builds an ExtendedFXQuantizerSetup for the given model: finds the base
    quantization setup via the min-max algorithm and attaches a quantizer
    dtype to each quantization point.

    :param model: Model to build the quantizer setup for.
    :param nncf_graph: NNCFGraph built for the given model.
    :return: Quantizer setup extended with per-point quantizer dtypes.
    """
    self._min_max_algo._set_backend_entity(model)
    base_setup = self._min_max_algo.find_quantization_setup(model, nncf_graph)

    # Symmetric points keep the backend-default dtype (None); everything
    # else is materialized as an unsigned 8-bit quantizer.
    dtype_map = {
        point_id: None if point.qconfig.mode == QuantizationScheme.SYMMETRIC else IntDtype.UINT8.value
        for point_id, point in base_setup.quantization_points.items()
    }

    # Rebuild as the extended setup by round-tripping through the state dict.
    state = base_setup.get_state()
    state[ExtendedFXQuantizerSetup.QUANTIZER_DTYPE_NAME] = dtype_map
    return ExtendedFXQuantizerSetup.from_state(state)

def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Expand Down
Loading
Loading