from dataclasses import dataclass
from typing import Any

+import openvino as ov
import pytest
import torch
+from openvino._pyopenvino.properties.hint import inference_precision
+from openvino.tools.ovc import convert_model
from pytest_mock import MockerFixture
from torch import nn

import nncf
import nncf.torch
from nncf.common.quantization.structs import QuantizationScheme
+from nncf.openvino.optimized_functions.models import _compile_ov_model
from nncf.parameters import CompressWeightsMode
from nncf.parameters import StripFormat
from nncf.torch.function_hook.wrapper import get_hook_storage
from nncf.torch.quantization.layers import AsymmetricLoraQuantizer
from nncf.torch.quantization.layers import BaseQuantizer
+from nncf.torch.quantization.layers import BaseWeightsDecompressor
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor as INT4AsymDQ
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor as INT4SymDQ
from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor as INT8AsymDQ
@@ -53,7 +58,8 @@ class ParamStripLora:
    mode: CompressWeightsMode
    decompressor_class: type
    torch_dtype: torch.dtype
-    atol: float
+    torch_atol: float
+    ov_atol: float
    weight_dtype: torch.dtype

    def __str__(self) -> str:
@@ -76,17 +82,14 @@ def num_call_pack_weight(self) -> int:
@pytest.mark.parametrize(
    ("param"),
    (
-        ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float32, 1e-3, torch.uint8),
-        ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float16, 1e-8, torch.uint8),
-        ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.bfloat16, 1e-2, torch.uint8),
-        ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float32, 1e-3, torch.uint8),
-        # torch.compile introduces bigger diff for sym
-        ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float16, 1e-3, torch.uint8),
-        ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.bfloat16, 1e-2, torch.uint8),
-        # int8 uses per-channel vs int4 group-wise
-        ParamStripLora(CompressWeightsMode.INT8_SYM, INT8SymDQ, torch.bfloat16, 1e-2, torch.int8),
-        # int8 uses per-channel vs int4 group-wise
-        ParamStripLora(CompressWeightsMode.INT8_ASYM, INT8AsymDQ, torch.bfloat16, 1e-8, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float32, 1e-3, 1e-3, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float16, 1e-3, 1e-3, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.bfloat16, 1e-8, 1e-1, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float32, 1e-3, 1e-3, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float16, 1e-8, 1e-3, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.bfloat16, 1e-8, 1e-2, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT8_SYM, INT8SymDQ, torch.bfloat16, 1e-2, 1e-3, torch.int8),
+        ParamStripLora(CompressWeightsMode.INT8_ASYM, INT8AsymDQ, torch.bfloat16, 1e-8, 1e-3, torch.uint8),
    ),
    ids=str,
)
@@ -114,12 +117,26 @@ def test_nncf_strip_lora_model(param: ParamStripLora, mocker: MockerFixture):
        compressed_model, do_copy=True, strip_format=StripFormat.DQ, example_input=example_input
    )
    stripped_output = strip_compressed_model(example_input)
-
    assert pack_weight_spy.call_count == param.num_call_pack_weight
    assert strip_compressed_model.linear.weight.dtype == param.weight_dtype

    check_compression_modules(strip_compressed_model, param.decompressor_class)
-    assert torch.allclose(compressed_output, stripped_output, atol=param.atol)
+    assert torch.allclose(compressed_output, stripped_output, atol=param.torch_atol)
+
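+    # Force every weight decompressor to emit f32 so the OpenVINO comparison below runs fully in float32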
+    example_input = example_input.type(torch.float32)
+    hook_storage = get_hook_storage(strip_compressed_model)
+    for _, module in hook_storage.named_hooks():
+        if isinstance(module, BaseWeightsDecompressor):
+            module.result_dtype = torch.float32
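+    # Convert the stripped model to OpenVINO, compile for CPU with f32 inference precision, and compare outputs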
+    ov_model = convert_model(strip_compressed_model, example_input=example_input)
+    compiled_model = _compile_ov_model(ov_model, device_name="CPU", config={inference_precision(): ov.Type.f32})
+    infer_request = compiled_model.create_infer_request()
+    res = infer_request.infer(example_input)
+    out_name = compiled_model.outputs[0]
+    ov_output = torch.from_numpy(res[out_name])
+    assert torch.allclose(compressed_output.type(torch.float32), ov_output, atol=param.ov_atol)


SIGNED_WEIGHT_SAMPLE = [-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75]
@@ -155,7 +170,7 @@ def test_sym_fq_to_decompressor(param: ParamSymFQ):

    scale_shape = (1, 1)
    scale = torch.tensor(SCALE_SAMPLE)
-    scale = scale.expand(scale_shape).to(torch.float16)
+    scale = scale.expand(scale_shape)

    # reference scale calculates with this formula:
    # levels = (2**num_bits)
@@ -246,10 +261,10 @@ def test_asym_fq_to_decompressor(param: ParamAsymFQ):
    ref_zero_point = ref_zero_point.expand(scale_shape).to(torch.uint8)

    input_low = torch.tensor(INPUT_LOW_SAMPLE)
-    input_low = input_low.expand(scale_shape).to(param.torch_dtype)
+    input_low = input_low.expand(scale_shape)

    input_range = torch.tensor(INPUT_RANGE_SAMPLE)
-    input_range = input_range.expand(scale_shape).to(param.torch_dtype)
+    input_range = input_range.expand(scale_shape)

    qspec = PTQuantizerSpec(
        num_bits=param.num_bits,