
Commit b480ef0

WIP
1 parent 555dfc9 commit b480ef0

File tree

24 files changed (+8883 −8753 lines)


Diff for: nncf/common/quantization/quantizer_setup.py

+2 −1
@@ -247,12 +247,13 @@ def __init__(self) -> None:
         self._next_unified_scale_gid = 0
         self._next_shared_inputs_gid = 0
 
-    def add_independent_quantization_point(self, qp: QuantizationPointBase) -> None:
+    def add_independent_quantization_point(self, qp: QuantizationPointBase) -> int:
         if self.quantization_points.keys():
             new_id = max(self.quantization_points.keys()) + 1
         else:
             new_id = 0
         self.quantization_points[new_id] = qp
+        return new_id
 
     def register_unified_scale_group(self, qp_group: List[QuantizationPointId]) -> int:
         for qp_id in qp_group:
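
Returning the new id lets callers collect the ids of freshly added quantization points and group them afterwards, which the adapter change below relies on. A minimal sketch of the contract, using a toy stand-in for SingleConfigQuantizerSetup (the class and names below are illustrative, not NNCF API):

    from typing import Dict, List


    class ToySetup:
        """Toy stand-in mirroring the id-allocation rule of the patched method."""

        def __init__(self) -> None:
            self.quantization_points: Dict[int, str] = {}
            self.unified_scale_groups: List[List[int]] = []

        def add_independent_quantization_point(self, qp: str) -> int:
            # Same rule as the diff: max existing id + 1, or 0 for the first point.
            new_id = max(self.quantization_points) + 1 if self.quantization_points else 0
            self.quantization_points[new_id] = qp
            return new_id

        def register_unified_scale_group(self, qp_group: List[int]) -> int:
            self.unified_scale_groups.append(qp_group)
            return len(self.unified_scale_groups) - 1


    setup = ToySetup()
    ids = [setup.add_independent_quantization_point(qp) for qp in ("qp_a", "qp_b")]
    assert ids == [0, 1]
    setup.register_unified_scale_group(ids)  # returned ids feed the group directly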

Diff for: nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py

+93 −60
@@ -11,18 +11,18 @@
 
 
 from collections import defaultdict
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 import torch.fx
+from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id
+from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec
 from torch.ao.quantization.quantizer import Quantizer as TorchAOQuantizer
 from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
-from torch.ao.quantization.quantizer.quantizer import QuantizationSpecBase
 from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec
 
 import nncf
 from nncf.common.graph.graph import NNCFGraph
-from nncf.common.logging import nncf_logger
 from nncf.common.quantization.quantizer_setup import ActivationQuantizationInsertionPoint
 from nncf.common.quantization.quantizer_setup import QuantizationPointBase
 from nncf.common.quantization.quantizer_setup import SingleConfigQuantizationPoint
@@ -73,6 +73,15 @@ def _get_quantization_points(
         annotated_model: torch.fx.GraphModule,
         qconfig: QuantizerConfig,
     ) -> List[QuantizationPointBase]:
+        """
+        Creates quantization points based on the given nodes and edges.
+
+        :param from_node: The originating node in the computation graph.
+        :param to_nodes: The list of destination nodes of the from_node.
+        :param annotated_model: The torch.fx.GraphModule instance.
+        :param qconfig: The torch.ao quantization configuration.
+        :return: A list of NNCF quantization points.
+        """
         to_n = to_nodes[0]
         if from_node.op == "get_attr":
             _, metatype = GraphConverter.get_node_type_and_metatype(to_n, annotated_model)
@@ -95,78 +104,102 @@ def _get_quantization_points(
         return qps
 
     @staticmethod
-    def _get_node_args(node: torch.fx.Node):
+    def _get_node_args(node: torch.fx.Node) -> Tuple[Any, ...]:
+        """
+        Correctly retrieves arguments of the given node.
+
+        :param node: The given node.
+        :return: The arguments of the given node.
+        """
         if node.target == torch.ops.aten.cat.default:
             return node.args[0]
         return node.args
 
     @staticmethod
-    def get_quantizer_config_from_annotated_model(annotated_model: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
-        edge_or_node_to_qspec = _get_edge_or_node_to_qspec(annotated_model)
-
-        q_map = defaultdict(list)
-        for edge, qspec in edge_or_node_to_qspec.items():
-            if not isinstance(edge, tuple):
-                continue
-            from_n, to_n = edge
-            q_map[from_n].append(to_n)
+    def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
+        edge_or_node_to_qspec = _get_edge_or_node_to_qspec(annotated)
+        # A node key means all output edges of that node should be quantized;
+        # an edge key means only that particular edge should be quantized.
+        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
+
+        group_id_vs_edges = defaultdict(set)
+        group_id_vs_qspec = {}
+        for edge_or_node, group_id in edge_or_node_to_group_id.items():
+            target_edges = [edge_or_node]
+            if isinstance(edge_or_node, torch.fx.Node):
+                target_edges = []
+                for user in edge_or_node.users:
+                    target_edges.append((edge_or_node, user))
+            group_id_vs_edges[group_id].update(target_edges)
+            # All qspecs should be aligned after the _get_edge_or_node_to_group_id call.
+            group_id_vs_qspec[group_id] = _unwrap_shared_qspec_safe(
+                edge_or_node_to_qspec[edge_or_node], edge_or_node_to_qspec
+            )
 
         q_setup = SingleConfigQuantizerSetup()
-        for from_n, to_nodes in q_map.items():
-            to_n = to_nodes[0]
-            qspec = edge_or_node_to_qspec[(from_n, to_n)]
+        for group_id, edges in group_id_vs_edges.items():
+            qspec = group_id_vs_qspec[group_id]
             if qspec is None:
                 continue
-            if isinstance(qspec, QuantizationSpec):
-                if qspec.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
-                    per_channel = True
-                elif qspec.qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
-                    per_channel = False
-                else:
-                    msg = f"Unknown qscheme: {qspec.qscheme}"
-                    raise nncf.InternalError(msg)
-                signed = qspec.dtype is torch.int8
-                mode = (
-                    QuantizationMode.SYMMETRIC
-                    if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]
-                    else QuantizationMode.ASYMMETRIC
-                )
-                qconfig = QuantizerConfig(mode=mode, signedness_to_force=signed, per_channel=per_channel)
-
-                qps = TorchAOQuantizerAdapter._get_quantization_points(from_n, to_nodes, annotated_model, qconfig)
-                for qp in qps:
-                    q_setup.add_independent_quantization_point(qp)
-
-            elif isinstance(qspec, SharedQuantizationSpec):
-                # TODO(dlyakhov): Support SharedQuantizationSpec
-                nncf_logger.warning(
-                    f"SharedQuantizationSpec is not supported yet; edges {from_n} -> {to_nodes} won't be quantized."
-                )
-            else:
+            if not isinstance(qspec, QuantizationSpec):
                 msg = f"Unknown torch.ao quantization spec: {qspec}"
                 raise nncf.InternalError(msg)
 
+            if qspec.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
+                per_channel = True
+            elif qspec.qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                per_channel = False
+            else:
+                msg = f"Unknown qscheme: {qspec.qscheme}"
+                raise nncf.InternalError(msg)
+
+            signed = qspec.dtype is torch.int8
+            mode = (
+                QuantizationMode.SYMMETRIC
+                if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]
+                else QuantizationMode.ASYMMETRIC
+            )
+            narrow_range = qspec.quant_min % 2 != 0
+            qconfig = QuantizerConfig(
+                mode=mode, signedness_to_force=signed, per_channel=per_channel, narrow_range=narrow_range
+            )
+
+            joined_edges = defaultdict(list)
+            for edge in edges:
+                joined_edges[edge[0]].append(edge[1])
+
+            qps = []
+            for from_node, to_nodes in joined_edges.items():
+                qps.extend(TorchAOQuantizerAdapter._get_quantization_points(from_node, to_nodes, annotated, qconfig))
+            qp_ids = []
+            for qp in qps:
+                qp_ids.append(q_setup.add_independent_quantization_point(qp))
+            if len(qp_ids) > 1:
+                q_setup.register_unified_scale_group(qp_ids)
+
         return q_setup
 
 
-def _get_edge_or_node_to_qspec(
-    model: torch.fx.GraphModule,
-) -> Dict[EdgeOrNode, QuantizationSpecBase]:
+def _unwrap_shared_qspec_safe(qspec: QuantizationSpec, edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpec]):
     """
-    Get a map from EdgeOrNode to quantization spec based on annotations on the nodes.
+    Iteratively unwraps a given SharedQuantizationSpec to retrieve its actual QuantizationSpec.
+    It detects cyclic dependencies and enforces a maximum depth limit to prevent infinite recursion.
 
-    :param model: torch.fx.GraphModule instance.
-    :return: A map from EdgeOrNode to quantization spec based on annotations on the nodes.
+    :param qspec: The quantization specification to unwrap.
+    :param edge_or_node_to_qspec: A dictionary mapping EdgeOrNode instances to their respective QuantizationSpec.
+    :return: The resolved QuantizationSpec.
     """
-    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {}
-    for n in model.graph.nodes:
-        if hasattr(n, "meta") and "quantization_annotation" in n.meta:
-            qa = n.meta["quantization_annotation"]
-            for input_to_n, qspec in qa.input_qspec_map.items():
-                input_edge = (input_to_n, n)
-                edge_or_node_to_qspec[input_edge] = qspec
-            if qa.output_qspec is not None:
-                output_node = n
-                qspec = qa.output_qspec
-                edge_or_node_to_qspec[output_node] = qspec
-    return edge_or_node_to_qspec
+    MAX_DEPTH = 1000
+    i = 0
+    visited = []
+    while i < MAX_DEPTH and isinstance(qspec, SharedQuantizationSpec):
+        if qspec.edge_or_node in visited:
+            msg = f"A cyclic dependency of quantization specs is detected: {visited + [qspec.edge_or_node]}"
+            raise RuntimeError(msg)
+        visited.append(qspec.edge_or_node)
+        qspec = edge_or_node_to_qspec[qspec.edge_or_node]
+        i += 1
+    if i == MAX_DEPTH:
+        msg = f"Shared qspecs reference each other deeper than the limit: {MAX_DEPTH}"
+        raise RuntimeError(msg)
+    return qspec
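
The core of the adapter is the mapping from a torch.ao QuantizationSpec to an NNCF QuantizerConfig. A standalone sketch of the same decision table, assuming only torch (the returned dict stands in for QuantizerConfig; the odd-quant_min narrow-range heuristic mirrors the diff):

    import torch


    def qscheme_to_config(qscheme: torch.qscheme, dtype: torch.dtype, quant_min: int) -> dict:
        """Mirror the diff's QuantizationSpec -> quantizer-config mapping."""
        per_channel = qscheme in (torch.per_channel_affine, torch.per_channel_symmetric)
        if not per_channel and qscheme not in (torch.per_tensor_affine, torch.per_tensor_symmetric):
            raise ValueError(f"Unknown qscheme: {qscheme}")
        symmetric = qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric)
        return {
            "mode": "symmetric" if symmetric else "asymmetric",
            "signedness_to_force": dtype is torch.int8,
            "per_channel": per_channel,
            # An odd quant_min (e.g. -127 for narrow symmetric int8) signals narrow range.
            "narrow_range": quant_min % 2 != 0,
        }


    # Symmetric per-tensor int8 with quant_min=-127 -> signed, symmetric, narrow range.
    print(qscheme_to_config(torch.per_tensor_symmetric, torch.int8, -127))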
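
_unwrap_shared_qspec_safe follows chains of SharedQuantizationSpec references until a concrete spec is reached, refusing cycles. The same loop with toy types (SharedRef stands in for SharedQuantizationSpec; all names here are illustrative):

    from typing import Dict, List, Union


    class SharedRef:
        """Toy analogue of SharedQuantizationSpec: points at another table entry."""

        def __init__(self, edge_or_node: str) -> None:
            self.edge_or_node = edge_or_node


    def unwrap(qspec: Union[str, SharedRef], table: Dict[str, Union[str, SharedRef]]) -> str:
        visited: List[str] = []
        while isinstance(qspec, SharedRef):
            if qspec.edge_or_node in visited:
                raise RuntimeError(f"cyclic qspec dependency: {visited + [qspec.edge_or_node]}")
            visited.append(qspec.edge_or_node)
            qspec = table[qspec.edge_or_node]
        return qspec


    table = {"a": SharedRef("b"), "b": SharedRef("c"), "c": "concrete-int8-spec"}
    assert unwrap(table["a"], table) == "concrete-int8-spec"  # a -> b -> c resolves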

Diff for: tests/post_training/pipelines/image_classification_base.py

+46 −8
@@ -12,16 +12,37 @@
 import copy
 import os
 
+from torch.ao.quantization.observer import MovingAverageMinMaxObserver
+
+os.environ["TORCHINDUCTOR_FREEZING"] = "1"
+
+from itertools import islice
+from typing import Optional
+
 import numpy as np
 import openvino as ov
 import torch
+
+# from executorch.backends.arm.quantizer import arm_quantizer
+from executorch.backends.qualcomm.quantizer import quantizer as qualcom_q
 from sklearn.metrics import accuracy_score
+from torch.ao.quantization.quantize_pt2e import convert_pt2e
+from torch.ao.quantization.quantize_pt2e import prepare_pt2e
+from torch.ao.quantization.quantizer import xnnpack_quantizer
+from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
+from torch.ao.quantization.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config
 from torchvision import datasets
 
 import nncf
+from nncf import AdvancedQuantizationParameters
 from nncf.common.logging.track_progress import track
+from nncf.experimental.torch.fx import OpenVINOQuantizer
+from nncf.experimental.torch.fx import quantize_pt2e
+from nncf.torch import disable_patching
 from tests.post_training.pipelines.base import DEFAULT_VAL_THREADS
 from tests.post_training.pipelines.base import FX_BACKENDS
+from tests.post_training.pipelines.base import FX_EAGER_BACKENDS
+from tests.post_training.pipelines.base import BackendType
 from tests.post_training.pipelines.base import PTQTestPipeline
 
 
@@ -75,14 +96,31 @@ def process_result(request, userdata):
     def _validate_torch_compile(
         self, val_loader: torch.utils.data.DataLoader, predictions: np.ndarray, references: np.ndarray
     ):
-        compiled_model = torch.compile(self.compressed_model.cpu(), backend="openvino")
-        for i, (images, target) in enumerate(val_loader):
-            # W/A for memory leaks when using torch DataLoader and OpenVINO
-            pred = compiled_model(images)
-            pred = torch.argmax(pred, dim=1)
-            predictions[i] = pred.numpy()
-            references[i] = target.numpy()
-        return predictions, references
+        # compiled_model = torch.compile(self.compressed_model, backend="openvino")
+        q_num = 0
+        for node in self.compressed_model.graph.nodes:
+            if ".quantize_per" in str(node.target):
+                q_num += 1
+
+        print(f"Quantize ops num: {q_num}")
+
+        with disable_patching():
+            with torch.no_grad():
+                if self.backend in FX_EAGER_BACKENDS:
+                    # Run such models in eager mode
+                    compiled_model = self.compressed_model
+                    if self.backend in [BackendType.X86_QUANTIZER_AO, BackendType.X86_QUANTIZER_NNCF]:
+                        compiled_model = torch.compile(self.compressed_model)
+                else:
+                    compiled_model = torch.compile(self.compressed_model, backend="openvino")
+
+                for i, (images, target) in enumerate(val_loader):
+                    # W/A for memory leaks when using torch DataLoader and OpenVINO
+                    pred = compiled_model(images)
+                    pred = torch.argmax(pred, dim=1)
+                    predictions[i] = pred.numpy()
+                    references[i] = target.numpy()
+        return predictions, references
 
     def _validate(self) -> None:
         val_dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
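
The quantize-op count printed above is a quick sanity check that quantizer nodes actually landed in the exported graph. A minimal standalone version of the same check (the Tiny model is a placeholder; after convert_pt2e the count would be non-zero):

    import torch
    import torch.fx


    def count_quantize_ops(gm: torch.fx.GraphModule) -> int:
        # Same predicate as the diff: match quantize_per_tensor / quantize_per_channel calls.
        return sum(1 for node in gm.graph.nodes if ".quantize_per" in str(node.target))


    class Tiny(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(x)


    gm = torch.fx.symbolic_trace(Tiny())
    print(count_quantize_ops(gm))  # 0 for an unquantized graph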
