Skip to content

Commit d675990

Browse files
[TorchFX] Do not use torch quantizer in MinMax
1 parent 8c28aa2 commit d675990

26 files changed

+11084
-10212
lines changed

nncf/experimental/quantization/quantizer.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,91 @@
1111

1212
from abc import ABC
1313
from abc import abstractmethod
14-
from typing import TypeVar
14+
from enum import Enum
15+
from typing import Any, Optional, TypeVar
1516

1617
from nncf.common.graph.graph import NNCFGraph
18+
from nncf.common.quantization.quantizer_setup import QuantizationPointBase
19+
from nncf.common.quantization.quantizer_setup import QuantizationPointId
1720
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
1821

1922
TModel = TypeVar("TModel")
2023

2124

25+
class IntDtype(Enum):
    """Integer dtypes a quantizer may emit; the member value is the serialized form."""

    INT8 = "INT8"
    UINT8 = "UINT8"
28+
29+
30+
class ExtendedQuantizerSetup(ABC, SingleConfigQuantizerSetup):
    """
    Abstract quantizer setup extending SingleConfigQuantizerSetup with
    per-quantization-point extra parameters required to insert
    quantizers to torch.fx models.
    """

    @abstractmethod
    def get_extra_params(self) -> dict[QuantizationPointId, dict[str, Any]]:
        """
        Returns a mapping from quantization point id to a dict of extra
        parameters for that point (e.g. the target integer dtype).
        """
41+
42+
43+
class ExtendedFXQuantizerSetup(ExtendedQuantizerSetup):
    """
    Quantizer setup with additional info (target integer dtype per
    quantization point) required to insert quantizers to torch.fx models.
    """

    # State-dict key under which the per-point dtype map is serialized.
    QUANTIZER_DTYPE_NAME = "quantizer_dtype"

    def __init__(self) -> None:
        super().__init__()
        # Maps quantization point id -> target integer dtype, or None when
        # no specific dtype is requested for that point.
        self._quantization_dtypes: dict[QuantizationPointId, Optional[IntDtype]] = {}

    def add_independent_quantization_point(
        self, qp: QuantizationPointBase, intermediate_dtype: Optional[IntDtype] = None
    ) -> QuantizationPointId:
        """
        Registers the quantization point and records its target dtype.

        :param qp: Quantization point to register.
        :param intermediate_dtype: Target integer dtype for the point, or None.
            Defaults to None so base-class call sites keep working.
        :return: Id assigned to the registered quantization point.
        """
        # NOTE: `id` (builtin-shadowing) renamed to `qp_id`.
        qp_id = super().add_independent_quantization_point(qp)
        self._quantization_dtypes[qp_id] = intermediate_dtype
        return qp_id

    def get_extra_params(self) -> dict[QuantizationPointId, dict[str, Any]]:
        """
        Returns a mapping from quantization point id to
        {QUANTIZER_DTYPE_NAME: dtype} extra-parameter dicts.
        """
        return {k: {self.QUANTIZER_DTYPE_NAME: v} for k, v in self._quantization_dtypes.items()}

    def get_state(self) -> dict[str, Any]:
        """
        Returns a dictionary with Python data structures (dict, list, tuple, str, int, float, True, False, None) that
        represents state of the object.

        :return: state of the object
        """
        base_state = super().get_state()
        # Fix: serialize the dtype map from self._quantization_dtypes (the
        # previous code iterated self.quantization_points and would call
        # `.value` on quantization-point objects), tolerate None entries,
        # and actually return the assembled state.
        base_state[self.QUANTIZER_DTYPE_NAME] = {
            qp_id: None if dtype is None else dtype.value
            for qp_id, dtype in self._quantization_dtypes.items()
        }
        return base_state

    @classmethod
    def from_state(cls, state: dict[str, Any]) -> "ExtendedFXQuantizerSetup":
        """
        Creates the object from its state.

        :param state: Output of `get_state()` method.
        """
        state_ = state.copy()
        # Pull out the dtype map before delegating so the base class only
        # sees the keys it knows about.
        dtype_names = state_.pop(cls.QUANTIZER_DTYPE_NAME)
        super_setup = super().from_state(state_)
        setup = cls()

        setup.quantization_points = super_setup.quantization_points
        setup.unified_scale_groups = super_setup.unified_scale_groups
        setup.shared_input_operation_set_groups = super_setup.shared_input_operation_set_groups
        setup._quantization_dtypes = {
            qp_id: None if name is None else IntDtype[name] for qp_id, name in dtype_names.items()
        }

        return setup
97+
98+
2299
class Quantizer(ABC):
23100
"""
24101
Quantizer is an interface for the RangeEstimator algorithm
@@ -35,7 +112,7 @@ def transform_prior_quantization(self, model: TModel) -> TModel:
35112
"""
36113

37114
@abstractmethod
38-
def get_quantization_setup(self, model: TModel, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
115+
def get_quantization_setup(self, model: TModel, nncf_graph: NNCFGraph) -> ExtendedFXQuantizerSetup:
39116
"""
40117
Builds SingleConfigQuantizerSetup for the given model.
41118

nncf/experimental/torch/fx/quantization/quantize_pt2e.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import torch
1616
import torch.fx
17-
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
1817
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
1918
from torch.ao.quantization.pt2e.utils import _disallow_eval_train
2019
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
@@ -33,6 +32,7 @@
3332
from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer
3433
from nncf.experimental.torch.fx.quantization.quantizer.torch_ao_adapter import TorchAOQuantizerAdapter
3534
from nncf.experimental.torch.fx.transformations import QUANTIZE_NODE_TARGETS
35+
from nncf.experimental.torch.fx.transformations import DuplicateDQPassNoAnnotations
3636
from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
3737
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
3838
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
@@ -132,7 +132,7 @@ def quantize_pt2e(
132132
else:
133133
constant_fold(quantized_model, _quant_node_constraint)
134134

135-
pm = PassManager([DuplicateDQPass()])
135+
pm = PassManager([DuplicateDQPassNoAnnotations()])
136136

137137
quantized_model = pm(quantized_model).graph_module
138138
pm = PassManager([PortNodeMetaForQDQ()])

nncf/experimental/torch/fx/quantization/quantizer/openvino_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch.fx
1313

1414
from nncf.common.graph.graph import NNCFGraph
15-
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
15+
from nncf.experimental.quantization.quantizer import ExtendedFXQuantizerSetup
1616
from nncf.experimental.quantization.quantizer import Quantizer
1717
from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer
1818

@@ -28,5 +28,5 @@ def __init__(self, quantizer: OpenVINOQuantizer):
2828
def transform_prior_quantization(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
2929
return self._quantizer.transform_for_annotation(model)
3030

31-
def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> ExtendedFXQuantizerSetup:
    """
    Builds the quantizer setup for the model by delegating to the wrapped
    OpenVINOQuantizer.

    :param model: Model to build the setup for.
    :param nncf_graph: NNCFGraph built for the model.
    :return: Quantizer setup with per-point dtype info.
    """
    return self._quantizer.get_nncf_quantization_setup(model, nncf_graph)

nncf/experimental/torch/fx/quantization/quantizer/openvino_quantizer.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
3737
from nncf.common.quantization.structs import QuantizationScheme
3838
from nncf.common.utils.api_marker import api
39+
from nncf.experimental.quantization.quantizer import ExtendedFXQuantizerSetup
40+
from nncf.experimental.quantization.quantizer import IntDtype
3941
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
4042
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
4143
from nncf.quantization.advanced_parameters import FP8QuantizationParameters
@@ -135,9 +137,16 @@ def set_ignored_scope(
135137

136138
def get_nncf_quantization_setup(
    self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph
) -> ExtendedFXQuantizerSetup:
    """
    Runs the MinMax algorithm's setup search and converts the result into an
    ExtendedFXQuantizerSetup, tagging every non-symmetric quantization point
    with the UINT8 dtype and leaving symmetric points untagged (None).

    :param model: Model to build the setup for.
    :param nncf_graph: NNCFGraph built for the model.
    :return: Extended setup carrying the per-point dtype map.
    """
    self._min_max_algo._set_backend_entity(model)
    base_setup = self._min_max_algo.find_quantization_setup(model, nncf_graph)
    # Asymmetric points are serialized as UINT8; symmetric ones keep the default.
    dtype_map = {
        point_id: None if point.qconfig.mode == QuantizationScheme.SYMMETRIC else IntDtype.UINT8.value
        for point_id, point in base_setup.quantization_points.items()
    }
    state = base_setup.get_state()
    state[ExtendedFXQuantizerSetup.QUANTIZER_DTYPE_NAME] = dtype_map
    return ExtendedFXQuantizerSetup.from_state(state)
141150

142151
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
143152
"""

nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@
2626
from nncf.common.quantization.quantizer_setup import ActivationQuantizationInsertionPoint
2727
from nncf.common.quantization.quantizer_setup import QuantizationPointBase
2828
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizationPoint
29-
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
3029
from nncf.common.quantization.quantizer_setup import WeightQuantizationInsertionPoint
3130
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
3231
from nncf.common.quantization.structs import QuantizerConfig
32+
from nncf.experimental.quantization.quantizer import ExtendedFXQuantizerSetup
33+
from nncf.experimental.quantization.quantizer import IntDtype
3334
from nncf.experimental.quantization.quantizer import Quantizer
3435
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
3536

@@ -47,7 +48,7 @@ def __init__(self, quantizer: TorchAOQuantizer):
4748
def transform_prior_quantization(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
4849
return self._quantizer.transform_for_annotation(model)
4950

50-
def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
51+
def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> ExtendedFXQuantizerSetup:
5152
# Save model and nodes meta before the annotation
5253
original_meta = model.meta.copy()
5354
node_name_vs_meta = {}
@@ -116,14 +117,14 @@ def _get_node_args(node: torch.fx.Node) -> tuple[Any, ...]:
116117
return node.args
117118

118119
@staticmethod
119-
def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
120+
def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -> ExtendedFXQuantizerSetup:
120121
"""
121122
Process a torch.fx.GraphModule annotated with quantization specifications
122123
(e.g., via torch.ao observers) and generates a corresponding NNCF quantization setup object,
123124
which maps quantization configurations to graph edges.
124125
125126
:param annotated: A torch.fx.GraphModule that has been annotated with Torch quantization observers.
126-
:return: A SingleConfigQuantizerSetup containing quantization points derived from the annotated model.
127+
:return: A ExtendedFXQuantizerSetup containing quantization points derived from the annotated model.
127128
"""
128129
edge_or_node_to_qspec = _get_edge_or_node_to_qspec(annotated)
129130
# Node means all output edges should be quantized.
@@ -142,7 +143,7 @@ def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -
142143
edge_or_node_to_qspec[edge_or_node], edge_or_node_to_qspec
143144
)
144145

145-
q_setup = SingleConfigQuantizerSetup()
146+
q_setup = ExtendedFXQuantizerSetup()
146147
for group_id, edges in group_id_vs_edges.items():
147148
qspec = group_id_vs_qspec[group_id]
148149
if qspec is None:
@@ -159,15 +160,15 @@ def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -
159160
msg = f"Unknown qscheme: {qspec.qscheme}"
160161
raise nncf.InternalError(msg)
161162

162-
signed = qspec.dtype is torch.int8
163+
dtype = IntDtype.INT8 if qspec.dtype is torch.int8 else IntDtype.UINT8
163164
mode = (
164165
QuantizationMode.SYMMETRIC
165166
if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]
166167
else QuantizationMode.ASYMMETRIC
167168
)
168169
narrow_range = qspec.quant_min % 2 != 0
169170
qconfig = QuantizerConfig(
170-
mode=mode, signedness_to_force=signed, per_channel=per_channel, narrow_range=narrow_range
171+
mode=mode, signedness_to_force=False, per_channel=per_channel, narrow_range=narrow_range
171172
)
172173

173174
joined_edges = defaultdict(list)
@@ -179,7 +180,7 @@ def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -
179180
qps.extend(TorchAOQuantizerAdapter._get_quantization_points(from_node, to_nodes, annotated, qconfig))
180181
qp_ids = []
181182
for qp in qps:
182-
qp_ids.append(q_setup.add_independent_quantization_point(qp))
183+
qp_ids.append(q_setup.add_independent_quantization_point(qp, dtype))
183184
if len(qp_ids) > 1:
184185
q_setup.register_unified_scale_group(qp_ids)
185186

nncf/experimental/torch/fx/transformations.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,17 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
import operator
1213
from copy import copy
1314
from typing import Any, Callable, Optional, Union
1415

1516
import torch
1617
import torch.fx
1718
from torch.ao.quantization.fx.utils import create_getattr_from_value
1819
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
20+
from torch.fx.node import map_arg
21+
from torch.fx.passes.infra.pass_base import PassBase
22+
from torch.fx.passes.infra.pass_base import PassResult
1923
from torch.quantization.fake_quantize import FakeQuantize
2024

2125
import nncf
@@ -741,3 +745,65 @@ def constraint_fn(node: torch.fx.Node):
741745
return node.op != "call_function" or node.target not in QUANTIZE_NODE_TARGETS + DEQUANTIZE_NODE_TARGETS
742746

743747
constant_fold(model, constraint_fn=constraint_fn)
748+
749+
750+
def _duplicate_dq(gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node):
751+
with gm.graph.inserting_after(dq_node):
752+
new_node = gm.graph.node_copy(dq_node)
753+
754+
def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
755+
if n == dq_node:
756+
return new_node
757+
else:
758+
return n
759+
760+
new_args = map_arg(user.args, maybe_replace_node)
761+
new_kwargs = map_arg(user.kwargs, maybe_replace_node)
762+
user.args = new_args
763+
user.kwargs = new_kwargs
764+
765+
766+
def _is_sym_size_node(node: torch.fx.Node):
767+
return (
768+
node.op == "call_function"
769+
and node.target == torch.ops.aten.sym_size.default
770+
or node.target == torch.ops.aten.sym_numel.default
771+
or node.target == torch.ops.aten.sym_numel
772+
or node.target == torch.ops.aten.sym_size
773+
)
774+
775+
776+
def _filter_sym_size_users(node: torch.fx.Node) -> list[torch.fx.Node]:
    """Returns the users of `node`, excluding symbolic size/numel nodes."""
    return [user for user in node.users if not _is_sym_size_node(user)]
779+
780+
781+
class DuplicateDQPassNoAnnotations(PassBase):
    """
    Graph pass that gives each consumer of a dequantize (dq) node its own copy
    of that node. Judging by the name, unlike torch's DuplicateDQPass it does
    not consult quantization annotations in node meta — TODO confirm against
    the upstream pass.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Duplicates every multi-user dq node, except those produced by the
        dynamic-quantization pattern; returns the modified graph module.
        """
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target in DEQUANTIZE_NODE_TARGETS:
                # Users that only read symbolic sizes do not need their own dq copy.
                dq_users = _filter_sym_size_users(node)
                if len(dq_users) <= 1:
                    continue
                # Do not duplicate dq for dynamic quantization
                # Pattern: choose_qparam - getitem - q - dq
                q_node = node.args[0]
                if q_node.op == "call_function" and q_node.target in QUANTIZE_NODE_TARGETS:
                    # q's second arg is the scale; in the dynamic pattern it is a
                    # getitem over choose_qparams output.
                    getitem_node = q_node.args[1]
                    if (
                        isinstance(getitem_node, torch.fx.node.Node)
                        and getitem_node.op == "call_function"
                        and getitem_node.target == operator.getitem
                    ):
                        choose_qparam_node = getitem_node.args[0]
                        if (
                            isinstance(choose_qparam_node, torch.fx.node.Node)
                            and choose_qparam_node.op == "call_function"
                            and choose_qparam_node.target == torch.ops.quantized_decomposed.choose_qparams.tensor
                        ):
                            continue
                # Every remaining user gets a private copy of the dq node; the
                # original becomes dead and is removed below.
                for user in dq_users:
                    _duplicate_dq(graph_module, node, user)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

0 commit comments

Comments
 (0)