
Commit 0d42595

Aligned bf16 tuning vs f32 inference for 4bit compression (#3493)
### Changes

Always cast the input to float32 inside FQ + LoRA.

Benchmark results with the new schema on https://github.com/ljaljushkin/nncf_pytorch/tree/nl/ref_benchmark with small modifications from @nikita-malininn's branch https://github.com/nikita-malininn/nncf/tree/nm/ref_benchmark:

| device | dtype | exec_type | tensor_type | granularity | symmetric | narrow_range | timing_mode | num_runs | input_size |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| cuda | bfloat16 | ExecutionType.REGULAR | TensorType.WEIGHTS | GranularityType.PER_CHANNEL | TRUE | FALSE | TimingMode.KERNEL | 1000 | [2048, 128256] |

| name | Mode | forward_avg, ms | backward_avg, ms | memory, Gb |
| --- | --- | --- | --- | --- |
| compile (PR) | sym | 6.5 | 10.7 | 3.9 |
| compile (PR) | asym | 6.8 | 10.7 | 3.9 |
| compile (before) | sym | 1.6 | 9.5 | 4.2 |
| compile (before) | asym | 1.9 | 9.5 | 3.9 |
| not compiled (PR) | sym | 19.0 | 46.6 | 5.9 |
| not compiled (PR) | asym | 19.6 | 47.0 | 5.9 |
| not compiled (before) | sym | 9.2 | 37.0 | 5.4 |
| not compiled (before) | asym | 9.6 | 37.0 | 5.4 |

There is an overhead on the forward pass, but it is largely offset by using torch.compile. Per epoch, the overhead is 1-6% on RTX; on A100, depending on the setup, there can even be a 6% speedup or a 3% slowdown.

![image](https://github.com/user-attachments/assets/a7eb98e9-a906-4462-98a3-c4e2e061eb5a)
![image](https://github.com/user-attachments/assets/1e803a51-1a2b-4a44-9603-b7987e940bb2)

### Reason for changes

Minimize the disparity in precision between the Torch model and its exported OV equivalent. Full alignment would be very inefficient, so this is a compromise: align accuracy with minimal overhead on the forward pass.

The e2e test on `facebook/opt-125m` confirms that the outputs now match within the default absolute tolerance (1e-8) instead of the previous 1e-2: https://github.com/openvinotoolkit/nncf/pull/3493/files#diff-7a4f90fe4f07d515df355d6fb618112d7d3fe88eb8ba777e502c695a7c715010R170

Previously, there were 3 problematic models with a significant difference in accuracy; now the results are much better aligned:

![image](https://github.com/user-attachments/assets/a22c855a-7c00-4f77-895f-91ce5713387c)

### Related tickets

166195

### Tests

test examples - https://github.com/openvinotoolkit/nncf/actions/runs/15024278726/job/42221028011
1 parent b3c9119 commit 0d42595
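
For orientation, a minimal sketch of the pattern this PR adopts in `nncf/torch/quantization/quantize_functions.py` (see the diff below). This is not the NNCF code itself: the function and variable names are illustrative, and it assumes PyTorch 2.x for `torch.compile`. The idea is to remember the incoming dtype, run the quantize-dequantize math in float32, and cast the result back, letting `torch.compile` absorb most of the extra cast overhead.

```python
import torch


def fake_quantize_sym(x: torch.Tensor, scale: torch.Tensor, levels: int = 16) -> torch.Tensor:
    """Quantize-dequantize x in float32 and return the result in x's original dtype."""
    dtype = x.dtype  # e.g. torch.bfloat16 during tuning
    x32 = x.to(torch.float32)  # do the fake-quantization math in float32
    s32 = scale.to(torch.float32)
    level_low, level_high = -levels // 2, levels // 2 - 1  # -8 .. 7 for INT4
    q = torch.clamp(torch.round(x32 / s32), level_low, level_high)
    return (q * s32).to(dtype)  # cast back so the rest of the graph stays in bf16


# torch.compile fuses the extra casts, which keeps the forward overhead small
fq = torch.compile(fake_quantize_sym)

w = torch.randn(8, 16, dtype=torch.bfloat16)
scale = w.abs().amax(dim=-1, keepdim=True).to(torch.float32) / 7
print(fq(w, scale).dtype)  # torch.bfloat16
```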

File tree

5 files changed: +75 / -57 lines

examples/llm_compression/torch/qat_with_lora/README.md

Lines changed: 31 additions & 22 deletions
@@ -64,28 +64,37 @@ Where:
 - `PPL_PTWC` is the perplexity after applying the best Post-Training Weight Compression method identified
   for each specific model: this was "AWQ + Scale Estimation + GPTQ" for "HuggingFaceTB/SmolLM-1.7B-Instruct",
   and "AWQ + Scale Estimation" for all other models evaluated.
-- `PPL_QAT+LoRA` is the perplexity after applying Quantization-Aware Training with LoRA.
+- `PPL_QAT+LoRA` is the perplexity after applying Quantization-Aware Training with LoRA for 10 epochs.
 
 All quantization methods compressed the models to `INT4_ASYM` precision with a group size of `64`.
 
-| Model                              | Precision         | Wikitext,<br>word_ppl | Improvement |
-|------------------------------------|-------------------|-----------------------|-------------|
-| google/gemma-2-2b-it               | BF16              | 15.02                 |             |
-| google/gemma-2-2b-it               | INT4 (QAT + LoRA) | 15.13                 | 86%         |
-| google/gemma-2-2b-it               | INT4 (best PTWC)  | 15.80                 |             |
-| microsoft/phi3-mini-4k-instruct    | BF16              | 9.49                  |             |
-| microsoft/phi3-mini-4k-instruct    | INT4 (QAT + LoRA) | 10.12                 | 27%         |
-| microsoft/phi3-mini-4k-instruct    | INT4 (best PTWC)  | 10.36                 |             |
-| Qwen/Qwen2.5-3B-Instruct           | BF16              | 11.01                 |             |
-| Qwen/Qwen2.5-3B-Instruct           | INT4 (QAT + LoRA) | 11.49                 | 25%         |
-| Qwen/Qwen2.5-3B-Instruct           | INT4 (best PTWC)  | 11.65                 |             |
-| HuggingFaceTB/SmolLM-1.7B-Instruct | BF16              | 19.11                 |             |
-| HuggingFaceTB/SmolLM-1.7B-Instruct | INT4 (QAT + LoRA) | 19.25                 | 79%         |
-| HuggingFaceTB/SmolLM-1.7B-Instruct | INT4 (best PTWC)  | 19.79                 |             |
-| mistralai/Mistral-7B-v0.3          | BF16              | 8.21                  |             |
-| mistralai/Mistral-7B-v0.3          | INT4 (QAT + LoRA) | 8.38                  | 12%         |
-| mistralai/Mistral-7B-v0.3          | INT4 (best PTWC)  | 8.40                  |             |
-| meta-llama/Llama-3.2-3B-Instruct   | BF16              | 12.67                 |             |
-| meta-llama/Llama-3.2-3B-Instruct   | INT4 (QAT + LoRA) | 12.82                 | 73%         |
-| meta-llama/Llama-3.2-3B-Instruct   | INT4 (best PTWC)  | 13.22                 |             |
-|                                    |                   | Average               | 50.4%       |
+| Model                               | Precision         | Wikitext,<br>word_ppl | Improvement |
+|-------------------------------------|-------------------|-----------------------|-------------|
+| google/gemma-2-2b-it                | BF16              | 15.02                 |             |
+| google/gemma-2-2b-it                | INT4 (QAT + LoRA) | 15.09                 | 91%         |
+| google/gemma-2-2b-it                | INT4 (best PTWC)  | 15.80                 |             |
+| microsoft/phi3-mini-4k-instruct     | BF16              | 9.49                  |             |
+| microsoft/phi3-mini-4k-instruct     | INT4 (QAT + LoRA) | 10.04                 | 37%         |
+| microsoft/phi3-mini-4k-instruct     | INT4 (best PTWC)  | 10.36                 |             |
+| Qwen/Qwen2.5-3B-Instruct            | BF16              | 11.01                 |             |
+| Qwen/Qwen2.5-3B-Instruct            | INT4 (QAT + LoRA) | 11.44                 | 33%         |
+| Qwen/Qwen2.5-3B-Instruct            | INT4 (best PTWC)  | 11.65                 |             |
+| HuggingFaceTB/SmolLM-1.7B-Instruct  | BF16              | 19.11                 |             |
+| HuggingFaceTB/SmolLM-1.7B-Instruct  | INT4 (QAT + LoRA) | 19.34                 | 66%         |
+| HuggingFaceTB/SmolLM-1.7B-Instruct  | INT4 (best PTWC)  | 19.79                 |             |
+| mistralai/Mistral-7B-v0.3           | BF16              | 8.21                  |             |
+| mistralai/Mistral-7B-v0.3           | INT4 (QAT + LoRA) | 8.36                  | 20%         |
+| mistralai/Mistral-7B-v0.3           | INT4 (best PTWC)  | 8.40                  |             |
+| meta-llama/Llama-3.2-1B-Instruct    | BF16              | 16.30                 |             |
+| meta-llama/Llama-3.2-1B-Instruct    | INT4 (QAT + LoRA) | 17.12                 | 40%         |
+| meta-llama/Llama-3.2-1B-Instruct    | INT4 (best PTWC)  | 17.67                 |             |
+| meta-llama/Llama-3.2-3B-Instruct    | BF16              | 12.67                 |             |
+| meta-llama/Llama-3.2-3B-Instruct    | INT4 (QAT + LoRA) | 13.00                 | 39%         |
+| meta-llama/Llama-3.2-3B-Instruct    | INT4 (best PTWC)  | 13.22                 |             |
+| meta-llama/Meta-Llama-3-8B-Instruct | BF16              | 10.22                 |             |
+| meta-llama/Meta-Llama-3-8B-Instruct | INT4 (QAT + LoRA) | 10.30                 | 62%         |
+| meta-llama/Meta-Llama-3-8B-Instruct | INT4 (best PTWC)  | 10.45                 |             |
+| microsoft/phi3.5-mini-instruct      | BF16              | 10.00                 |             |
+| microsoft/phi3.5-mini-instruct      | INT4 (QAT + LoRA) | 10.53                 | 37%         |
+| microsoft/phi3.5-mini-instruct      | INT4 (best PTWC)  | 10.71                 |             |
+|                                     |                   | Average               | 46%         |
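
As a reading aid for the updated table (an assumption based on the `Where:` legend this hunk extends, not something stated in the diff itself): the `Improvement` column appears to be the share of the perplexity gap between the best PTWC result and BF16 that QAT + LoRA closes. A quick check against the `google/gemma-2-2b-it` row:

```python
# Hypothetical reconstruction of the "Improvement" metric from the rounded table values.
ppl_bf16, ppl_qat_lora, ppl_ptwc = 15.02, 15.09, 15.80
improvement = (ppl_ptwc - ppl_qat_lora) / (ppl_ptwc - ppl_bf16)
print(f"{improvement:.0%}")  # 91%, matching the table; other rows may differ slightly due to rounding
```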

nncf/torch/quantization/quantize_functions.py

Lines changed: 6 additions & 13 deletions
@@ -126,23 +126,19 @@ def forward(ctx, input_, input_shape, scale, level_low, level_high, levels):
         input_low = torch.where(scale > 0, -scale, -scale / level_low * level_high)
         # 15/8 * scale or (2-1/8) * scale
         input_range = torch.abs((2 + 1 / level_low) * scale)
-
-        if input_.dtype in [torch.bfloat16, torch.float16]:
-            input_low = input_low.type(input_.dtype)
-            input_range = input_range.type(input_.dtype)
-
+        dtype = input_.dtype
         original_shape = input_.shape
         input_ = input_.reshape(input_shape)
 
-        output = RQ.Quantize_forward(input_, input_low, input_range, levels)
+        output = RQ.Quantize_forward(input_.type(torch.float32), input_low, input_range, levels)
 
         ctx.save_for_backward(input_, input_low, input_range)
         ctx.level_low = level_low
         ctx.level_high = level_high
         ctx.levels = levels
 
         output = output.reshape(original_shape)
-        return output
+        return output.type(dtype)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -168,14 +164,11 @@ def backward(ctx, grad_output):
 class QuantizeAsymmetricTorch(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_, input_shape, input_low, input_range, level_low, level_high, levels):
-        if input_.dtype in [torch.bfloat16, torch.float16]:
-            input_low = input_low.type(input_.dtype)
-            input_range = input_range.type(input_.dtype)
-
+        dtype = input_.dtype
         original_shape = input_.shape
         input_ = input_.reshape(input_shape)
 
-        output = RQ.Quantize_forward(input_, input_low, input_range, levels)
+        output = RQ.Quantize_forward(input_.type(torch.float32), input_low, input_range, levels)
 
         # Save tensors for backward pass
         ctx.save_for_backward(input_, input_low, input_range)
@@ -184,7 +177,7 @@ def forward(ctx, input_, input_shape, input_low, input_range, level_low, level_h
         ctx.levels = levels
 
         output = output.reshape(original_shape)
-        return output
+        return output.type(dtype)
 
     @staticmethod
     def backward(ctx, grad_output):
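
The comment `# 15/8 * scale or (2-1/8) * scale` in the first hunk above can be sanity-checked in isolation. The snippet below uses illustrative values and is not part of the PR; it reproduces the symmetric range computation for signed INT4, where `level_low = -8` and `level_high = 7`:

```python
import torch

scale = torch.tensor([2.0])
level_low, level_high = -8, 7

# Same formulas as in the diff above
input_low = torch.where(scale > 0, -scale, -scale / level_low * level_high)
input_range = torch.abs((2 + 1 / level_low) * scale)

assert torch.equal(input_low, torch.tensor([-2.0]))            # input_low = -scale
assert torch.equal(input_range, torch.tensor([2.0 * 15 / 8]))  # (2 - 1/8) * scale
print(input_low + input_range)  # tensor([1.7500]) == 7/8 * scale, the top of the quantization range
```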

tests/cross_fw/examples/example_scope.json

Lines changed: 2 additions & 2 deletions
@@ -282,8 +282,8 @@
         "cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz",
         "accuracy_tolerance": 0.1,
         "accuracy_metrics": {
-            "perplexity_diff_torch": 0.6,
-            "best_ov_perplexity": 35.1
+            "perplexity_diff_torch": 0.75,
+            "best_ov_perplexity": 34.94
         }
     },
     "quantization_aware_training_tensorflow_mobilenet_v2": {

tests/torch2/function_hook/quantization/strip/test_strip_dequantize.py

Lines changed: 32 additions & 17 deletions
@@ -12,19 +12,24 @@
 from dataclasses import dataclass
 from typing import Any
 
+import openvino as ov
 import pytest
 import torch
+from openvino._pyopenvino.properties.hint import inference_precision
+from openvino.tools.ovc import convert_model
 from pytest_mock import MockerFixture
 from torch import nn
 
 import nncf
 import nncf.torch
 from nncf.common.quantization.structs import QuantizationScheme
+from nncf.openvino.optimized_functions.models import _compile_ov_model
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import StripFormat
 from nncf.torch.function_hook.wrapper import get_hook_storage
 from nncf.torch.quantization.layers import AsymmetricLoraQuantizer
 from nncf.torch.quantization.layers import BaseQuantizer
+from nncf.torch.quantization.layers import BaseWeightsDecompressor
 from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor as INT4AsymDQ
 from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor as INT4SymDQ
 from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor as INT8AsymDQ
@@ -53,7 +58,8 @@ class ParamStripLora:
     mode: CompressWeightsMode
     decompressor_class: type
     torch_dtype: torch.dtype
-    atol: float
+    torch_atol: float
+    ov_atol: float
     weight_dtype: torch.dtype
 
     def __str__(self) -> str:
@@ -76,17 +82,14 @@ def num_call_pack_weight(self) -> int:
 @pytest.mark.parametrize(
     ("param"),
     (
-        ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float32, 1e-3, torch.uint8),
-        ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float16, 1e-8, torch.uint8),
-        ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.bfloat16, 1e-2, torch.uint8),
-        ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float32, 1e-3, torch.uint8),
-        # torch.compile introduces bigger diff for sym
-        ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float16, 1e-3, torch.uint8),
-        ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.bfloat16, 1e-2, torch.uint8),
-        # int8 uses per-channel vs int4 group-wise
-        ParamStripLora(CompressWeightsMode.INT8_SYM, INT8SymDQ, torch.bfloat16, 1e-2, torch.int8),
-        # int8 uses per-channel vs int4 group-wise
-        ParamStripLora(CompressWeightsMode.INT8_ASYM, INT8AsymDQ, torch.bfloat16, 1e-8, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float32, 1e-3, 1e-3, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float16, 1e-3, 1e-3, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.bfloat16, 1e-8, 1e-1, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float32, 1e-3, 1e-3, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float16, 1e-8, 1e-3, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.bfloat16, 1e-8, 1e-2, torch.uint8),
+        ParamStripLora(CompressWeightsMode.INT8_SYM, INT8SymDQ, torch.bfloat16, 1e-2, 1e-3, torch.int8),
+        ParamStripLora(CompressWeightsMode.INT8_ASYM, INT8AsymDQ, torch.bfloat16, 1e-8, 1e-3, torch.uint8),
     ),
     ids=str,
 )
@@ -114,12 +117,24 @@ def test_nncf_strip_lora_model(param: ParamStripLora, mocker: MockerFixture):
         compressed_model, do_copy=True, strip_format=StripFormat.DQ, example_input=example_input
     )
     stripped_output = strip_compressed_model(example_input)
-
     assert pack_weight_spy.call_count == param.num_call_pack_weight
     assert strip_compressed_model.linear.weight.dtype == param.weight_dtype
 
     check_compression_modules(strip_compressed_model, param.decompressor_class)
-    assert torch.allclose(compressed_output, stripped_output, atol=param.atol)
+    assert torch.allclose(compressed_output, stripped_output, atol=param.torch_atol)
+
+    example_input = example_input.type(torch.float32)
+    hook_storage = get_hook_storage(strip_compressed_model)
+    for _, module in hook_storage.named_hooks():
+        if isinstance(module, BaseWeightsDecompressor):
+            module.result_dtype = torch.float32
+    ov_model = convert_model(strip_compressed_model, example_input=example_input)
+    compiled_model = _compile_ov_model(ov_model, device_name="CPU", config={inference_precision(): ov.Type.f32})
+    infer_request = compiled_model.create_infer_request()
+    res = infer_request.infer(example_input)
+    out_name = compiled_model.outputs[0]
+    ov_output = torch.from_numpy(res[out_name])
+    assert torch.allclose(compressed_output.type(torch.float32), ov_output, atol=param.ov_atol)
 
 
 SIGNED_WEIGHT_SAMPLE = [-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75]
@@ -155,7 +170,7 @@ def test_sym_fq_to_decompressor(param: ParamSymFQ):
 
     scale_shape = (1, 1)
     scale = torch.tensor(SCALE_SAMPLE)
-    scale = scale.expand(scale_shape).to(torch.float16)
+    scale = scale.expand(scale_shape)
 
     # reference scale calculates with this formula:
     # levels = (2 ** num_bits)
@@ -246,10 +261,10 @@ def test_asym_fq_to_decompressor(param: ParamAsymFQ):
     ref_zero_point = ref_zero_point.expand(scale_shape).to(torch.uint8)
 
     input_low = torch.tensor(INPUT_LOW_SAMPLE)
-    input_low = input_low.expand(scale_shape).to(param.torch_dtype)
+    input_low = input_low.expand(scale_shape)
 
     input_range = torch.tensor(INPUT_RANGE_SAMPLE)
-    input_range = input_range.expand(scale_shape).to(param.torch_dtype)
+    input_range = input_range.expand(scale_shape)
 
     qspec = PTQuantizerSpec(
         num_bits=param.num_bits,

tests/torch2/function_hook/quantization/test_fq_lora.py

Lines changed: 4 additions & 3 deletions
@@ -166,9 +166,10 @@ def test_fq_lora_tuning(tmp_path, mode, backup_mode, compression_kwargs, ref_num
     tuned_vs_stripped = vm.calculate_similarity(tuned_output, stripped_output)
     tuned_vs_stripped_ov = vm.calculate_similarity(tuned_output, stripped_ov_output)
 
-    atol = 0.03 if mode == nncf.CompressWeightsMode.INT4_SYM else 0.01  # torch.compile introduces bigger diff
-    assert torch.allclose(tuned_vs_stripped, vm.validation_ref, atol=atol)
-    assert torch.allclose(tuned_vs_stripped_ov, vm.validation_ref, atol=atol)
+    # torch.compiled version of FQ+LoRA leads to a small error
+    atol = 1e-2 if mode == nncf.CompressWeightsMode.INT4_SYM else 1e-8
+    assert torch.allclose(tuned_vs_stripped, vm.validation_ref, atol)
+    assert torch.allclose(tuned_vs_stripped_ov, vm.validation_ref, atol)
 
 
 def test_checkpoint_loading(tmp_path: Path, use_cuda: bool):
