Skip to content

Commit 0c628c4

Browse files
Conformance: nncf.quantize_pt2e and OpenVINOQuantizer support
No grad during the TorchFX model validation; quantization params are being forwarded to quantize_pt2e/OpenVINOQuantizer
1 parent d675990 commit 0c628c4

File tree

5 files changed

+167
-5
lines changed

5 files changed

+167
-5
lines changed

nncf/quantization/algorithms/min_max/backend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def create_quantizer_insertion_command(
174174
:param target_point: Target location for the quantizer insertion.
175175
:param quantizer_config: QuantizerConfig instance for the current layer.
176176
:param parameters: FakeQuantizeParameters to calculate activation quantization parameters.
177+
:param extra_params: Additional backend-specific parameters to initiate a quantizer insertion command.
177178
:return: Backend-specific Command for the quantizer insertion operation.
178179
"""
179180

@@ -193,6 +194,7 @@ def create_unified_scales_quantizers_insertion_commands(
193194
:param target_points: List of target locations for the quantizers insertion.
194195
:param quantizer_config: QuantizerConfig instance for the current layer.
195196
:param parameters: FakeQuantizeParameters to calculate activation quantization parameters.
197+
:param extra_params: Additional backend-specific parameters to initiate a quantizer insertion command.
196198
:return: List of backend-specific Commands
197199
for the quantizers with unified scales insertion operations.
198200
"""

tests/post_training/data/ptq_reference_data.yaml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ torchvision/resnet18_backend_FX_TORCH:
4343
error_message: "Openvino Model Files Not Found!"
4444
message: "Issue-166847"
4545
torchvision/resnet18_backend_CUDA_FX_TORCH:
46+
torchvision/resnet18_backend_OV_QUANTIZER_NNCF:
47+
metric_value: 0.6946
48+
torchvision/resnet18_backend_OV_QUANTIZER_AO:
49+
metric_value: 0.6946
50+
torchvision/resnet18_backend_X86_QUANTIZER_NNCF:
51+
metric_value: 0.6946
52+
torchvision/resnet18_backend_X86_QUANTIZER_AO:
4653
metric_value: 0.6946
4754
exception_xfail_reason:
4855
type: "FileNotFoundError"
@@ -66,6 +73,14 @@ torchvision/mobilenet_v3_small_BC_backend_CUDA_FX_TORCH:
6673
type: "FileNotFoundError"
6774
error_message: "Openvino Model Files Not Found!"
6875
message: "Issue-166847"
76+
torchvision/mobilenet_v3_small_BC_backend_OV_QUANTIZER_NNCF:
77+
metric_value: 0.6679
78+
torchvision/mobilenet_v3_small_BC_backend_OV_QUANTIZER_AO:
79+
metric_value: 0.6679
80+
torchvision/mobilenet_v3_small_BC_backend_X86_QUANTIZER_NNCF:
81+
metric_value: 0.6679
82+
torchvision/mobilenet_v3_small_BC_backend_X86_QUANTIZER_AO:
83+
metric_value: 0.6679
6984
torchvision/vit_b_16_backend_FP32:
7085
metric_value: 0.8107
7186
torchvision/vit_b_16_backend_OV:
@@ -77,6 +92,13 @@ torchvision/vit_b_16_backend_FX_TORCH:
7792
error_message: "Openvino Model Files Not Found!"
7893
message: "Issue-166847"
7994
torchvision/vit_b_16_backend_CUDA_FX_TORCH:
95+
torchvision/vit_b_16_backend_OV_QUANTIZER_NNCF:
96+
metric_value: 0.80922
97+
torchvision/vit_b_16_backend_OV_QUANTIZER_AO:
98+
metric_value: 0.80922
99+
torchvision/vit_b_16_backend_X86_QUANTIZER_NNCF:
100+
metric_value: 0.80922
101+
torchvision/vit_b_16_backend_X86_QUANTIZER_AO:
80102
metric_value: 0.80922
81103
exception_xfail_reason:
82104
type: "FileNotFoundError"
@@ -93,6 +115,13 @@ torchvision/swin_v2_s_backend_FX_TORCH:
93115
error_message: "Openvino Model Files Not Found!"
94116
message: "Issue-166847"
95117
torchvision/swin_v2_s_backend_CUDA_FX_TORCH:
118+
torchvision/swin_v2_s_backend_OV_QUANTIZER_NNCF:
119+
metric_value: 0.8360
120+
torchvision/swin_v2_s_backend_OV_QUANTIZER_AO:
121+
metric_value: 0.8360
122+
torchvision/swin_v2_s_backend_X86_QUANTIZER_NNCF:
123+
metric_value: 0.8360
124+
torchvision/swin_v2_s_backend_X86_QUANTIZER_AO:
96125
metric_value: 0.8360
97126
exception_xfail_reason:
98127
type: "FileNotFoundError"

tests/post_training/model_scope.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from nncf.quantization.advanced_parameters import AdvancedScaleEstimationParameters
2424
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
2525
from tests.post_training.pipelines.base import ALL_PTQ_BACKENDS
26+
from tests.post_training.pipelines.base import FX_BACKENDS
2627
from tests.post_training.pipelines.base import NNCF_PTQ_BACKENDS
2728
from tests.post_training.pipelines.base import BackendType
2829
from tests.post_training.pipelines.causal_language_model import CausalLMHF
@@ -107,7 +108,7 @@
107108
"fast_bias_correction": False,
108109
"preset": QuantizationPreset.MIXED,
109110
},
110-
"backends": [BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH, BackendType.OV, BackendType.ONNX],
111+
"backends": FX_BACKENDS + [BackendType.OV, BackendType.ONNX],
111112
"batch_size": 128,
112113
},
113114
{
@@ -118,7 +119,7 @@
118119
"model_type": ModelType.TRANSFORMER,
119120
"advanced_parameters": AdvancedQuantizationParameters(smooth_quant_alpha=0.15),
120121
},
121-
"backends": [BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH, BackendType.OV],
122+
"backends": FX_BACKENDS + [BackendType.OV],
122123
"batch_size": 1,
123124
},
124125
{
@@ -129,7 +130,7 @@
129130
"model_type": ModelType.TRANSFORMER,
130131
"advanced_parameters": AdvancedQuantizationParameters(smooth_quant_alpha=0.5),
131132
},
132-
"backends": [BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH, BackendType.OV],
133+
"backends": FX_BACKENDS + [BackendType.OV],
133134
"batch_size": 1,
134135
},
135136
# Timm models

tests/post_training/pipelines/base.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ class BackendType(Enum):
5757
CUDA_TORCH = "CUDA_TORCH"
5858
FX_TORCH = "FX_TORCH"
5959
CUDA_FX_TORCH = "CUDA_FX_TORCH"
60+
OV_QUANTIZER_NNCF = "OV_QUANTIZER_NNCF"
61+
OV_QUANTIZER_AO = "OV_QUANTIZER_AO"
62+
X86_QUANTIZER_NNCF = "X86_QUANTIZER_NNCF"
63+
X86_QUANTIZER_AO = "X86_QUANTIZER_AO"
6064
ONNX = "ONNX"
6165
OV = "OV"
6266
OPTIMUM = "OPTIMUM"
@@ -65,7 +69,14 @@ class BackendType(Enum):
6569
NNCF_PTQ_BACKENDS = [BackendType.TORCH, BackendType.CUDA_TORCH, BackendType.ONNX, BackendType.OV]
6670
ALL_PTQ_BACKENDS = NNCF_PTQ_BACKENDS
6771
PT_BACKENDS = [BackendType.TORCH, BackendType.CUDA_TORCH]
68-
FX_BACKENDS = [BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH]
72+
FX_BACKENDS = [
73+
BackendType.FX_TORCH,
74+
BackendType.CUDA_FX_TORCH,
75+
BackendType.OV_QUANTIZER_NNCF,
76+
BackendType.OV_QUANTIZER_AO,
77+
BackendType.X86_QUANTIZER_NNCF,
78+
BackendType.X86_QUANTIZER_AO,
79+
]
6980
OV_BACKENDS = [BackendType.OV, BackendType.OPTIMUM]
7081

7182
LIMIT_LENGTH_OF_STATUS = 120

tests/post_training/pipelines/image_classification_base.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,30 @@
1212
import copy
1313
import os
1414

15+
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
16+
17+
from itertools import islice
18+
1519
import numpy as np
1620
import openvino as ov
1721
import torch
1822
from sklearn.metrics import accuracy_score
23+
from torch.ao.quantization.quantize_pt2e import convert_pt2e
24+
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
25+
from torch.ao.quantization.quantizer.quantizer import Quantizer as TorchAOQuantizer
26+
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
27+
from torch.ao.quantization.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config
1928
from torchvision import datasets
2029

2130
import nncf
31+
from nncf import AdvancedQuantizationParameters
2232
from nncf.common.logging.track_progress import track
33+
from nncf.experimental.torch.fx import OpenVINOQuantizer
34+
from nncf.experimental.torch.fx import quantize_pt2e
35+
from nncf.torch import disable_patching
2336
from tests.post_training.pipelines.base import DEFAULT_VAL_THREADS
2437
from tests.post_training.pipelines.base import FX_BACKENDS
38+
from tests.post_training.pipelines.base import BackendType
2539
from tests.post_training.pipelines.base import PTQTestPipeline
2640

2741

@@ -75,7 +89,17 @@ def process_result(request, userdata):
7589
def _validate_torch_compile(
7690
self, val_loader: torch.utils.data.DataLoader, predictions: np.ndarray, references: np.ndarray
7791
):
78-
compiled_model = torch.compile(self.compressed_model.cpu(), backend="openvino", options={"aot_autograd": True})
92+
if self.backend in [
93+
BackendType.FX_TORCH,
94+
BackendType.CUDA_FX_TORCH,
95+
BackendType.OV_QUANTIZER_AO,
96+
BackendType.OV_QUANTIZER_NNCF,
97+
]:
98+
compiled_model = torch.compile(
99+
self.compressed_model.cpu(), backend="openvino", options={"aot_autograd": True}
100+
)
101+
else:
102+
compiled_model = torch.compile(self.compressed_model)
79103
for i, (images, target) in enumerate(val_loader):
80104
# W/A for memory leaks when using torch DataLoader and OpenVINO
81105
pred = compiled_model(images)
@@ -103,3 +127,98 @@ def _validate(self) -> None:
103127

104128
self.run_info.metric_name = "Acc@1"
105129
self.run_info.metric_value = acc_top1
130+
return []
131+
132+
def _compress_torch_ao(self, quantizer):
    """
    Quantize self.model with the TorchAO PT2E flow.

    The model is prepared with `prepare_pt2e`, called on calibration samples,
    and converted with `convert_pt2e`. The result is stored in self.compressed_model.

    :param quantizer: Quantizer instance passed to `prepare_pt2e`.
    """
    # Number of calibration samples; falls back to 300 when "subset_size" is not configured.
    num_samples = self.compression_params.get("subset_size", 300)
    with torch.no_grad(), disable_patching():
        observed_model = prepare_pt2e(self.model, quantizer)
        calibration_samples = islice(self.calibration_dataset.get_inference_data(), num_samples)
        for sample in calibration_samples:
            observed_model(sample)
        self.compressed_model = convert_pt2e(observed_model)
139+
140+
def _compress_nncf_pt2e(self, quantizer):
    """
    Quantize self.model with `quantize_pt2e` and store the result in self.compressed_model.

    "subset_size" and "fast_bias_correction" are forwarded from self.compression_params
    when present. The smooth quant, bias correction and range estimator settings are taken
    from "advanced_parameters" (a default-constructed AdvancedQuantizationParameters is
    used when the key is absent). SmoothQuant is enabled only when "model_type" is
    nncf.ModelType.TRANSFORMER.

    :param quantizer: Quantizer instance passed to `quantize_pt2e`.
    """
    pt2e_kwargs = {
        key: self.compression_params[key]
        for key in ("subset_size", "fast_bias_correction")
        if key in self.compression_params
    }

    advanced_parameters: AdvancedQuantizationParameters = self.compression_params.get(
        "advanced_parameters", AdvancedQuantizationParameters()
    )

    # Work on a copy: the advanced parameters object comes from self.compression_params,
    # so overriding the alphas in place would persist on that object after this call.
    sq_params = copy.deepcopy(advanced_parameters.smooth_quant_alphas)
    sq_alpha = advanced_parameters.smooth_quant_alpha
    if sq_alpha is not None:
        if sq_alpha < 0:
            # A negative single alpha sets both per-operation alphas to -1.
            sq_params.convolution = -1
            sq_params.matmul = -1
        else:
            sq_params.matmul = sq_alpha
    pt2e_kwargs["smooth_quant_params"] = sq_params
    pt2e_kwargs["bias_correction_params"] = advanced_parameters.bias_correction_params
    pt2e_kwargs["activations_range_estimator_params"] = advanced_parameters.activations_range_estimator_params
    pt2e_kwargs["weights_range_estimator_params"] = advanced_parameters.weights_range_estimator_params

    smooth_quant = self.compression_params.get("model_type") == nncf.ModelType.TRANSFORMER

    with disable_patching(), torch.no_grad():
        self.compressed_model = quantize_pt2e(
            self.model,
            quantizer,
            self.calibration_dataset,
            smooth_quant=smooth_quant,
            fold_quantize=False,
            **pt2e_kwargs,
        )
179+
180+
def _compress(self):
    """
    Quantize self.model using the flow that matches self.backend.

    Backends outside FX_BACKENDS use the parent implementation unchanged.
    FX_TORCH and CUDA_FX_TORCH use the parent implementation with patching
    disabled and gradients turned off. The remaining FX backends build a
    quantizer and run either the NNCF or the TorchAO PT2E flow.
    """
    backend = self.backend
    if backend not in FX_BACKENDS:
        super()._compress()
    elif backend in (BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH):
        with disable_patching(), torch.no_grad():
            super()._compress()
    else:
        pt2e_quantizer = self._build_quantizer()
        nncf_pt2e_backends = (BackendType.OV_QUANTIZER_NNCF, BackendType.X86_QUANTIZER_NNCF)
        if backend in nncf_pt2e_backends:
            self._compress_nncf_pt2e(pt2e_quantizer)
        else:
            self._compress_torch_ao(pt2e_quantizer)
200+
def _build_quantizer(self) -> TorchAOQuantizer:
    """
    Create the quantizer instance for the current backend.

    For the X86 quantizer backends an X86InductorQuantizer with the default
    x86 inductor quantization config set globally is returned. Otherwise an
    OpenVINOQuantizer is built from the values found in self.compression_params.

    :return: Quantizer instance for self.backend.
    """
    x86_backends = (BackendType.X86_QUANTIZER_AO, BackendType.X86_QUANTIZER_NNCF)
    if self.backend in x86_backends:
        x86_quantizer = X86InductorQuantizer()
        x86_quantizer.set_global(get_default_x86_inductor_quantization_config())
        return x86_quantizer

    params = self.compression_params
    # Top-level compression parameters are forwarded only when explicitly configured.
    forwarded_keys = ("mode", "preset", "target_device", "model_type", "ignored_scope")
    quantizer_kwargs = {name: params[name] for name in forwarded_keys if name in params}

    advanced: AdvancedQuantizationParameters = params.get("advanced_parameters", AdvancedQuantizationParameters())
    # Advanced parameters are always forwarded under the same names they have on the object.
    advanced_fields = (
        "overflow_fix",
        "quantize_outputs",
        "activations_quantization_params",
        "weights_quantization_params",
        "quantizer_propagation_rule",
    )
    for field_name in advanced_fields:
        quantizer_kwargs[field_name] = getattr(advanced, field_name)

    return OpenVINOQuantizer(**quantizer_kwargs)

0 commit comments

Comments
 (0)