
Aligned bf16 tuning vs f32 inference for 4bit compression #3493

Open · wants to merge 1 commit into `develop`

53 changes: 31 additions & 22 deletions examples/llm_compression/torch/qat_with_lora/README.md
@@ -64,28 +64,37 @@ Where:
- `PPL_PTWC` is the perplexity after applying the best Post-Training Weight Compression method identified
for each specific model: this was "AWQ + Scale Estimation + GPTQ" for "HuggingFaceTB/SmolLM-1.7B-Instruct",
and "AWQ + Scale Estimation" for all other models evaluated.
- `PPL_QAT+LoRA` is the perplexity after applying Quantization-Aware Training with LoRA.
- `PPL_QAT+LoRA` is the perplexity after applying Quantization-Aware Training with LoRA for 10 epochs.

All quantization methods compressed the models to `INT4_ASYM` precision with a group size of `64`.

| Model | Precision | Wikitext,<br>word_ppl | Improvement |
|------------------------------------|-------------------|-----------------------|-------------|
| google/gemma-2-2b-it | BF16 | 15.02 | |
| google/gemma-2-2b-it | INT4 (QAT + LoRA) | 15.13 | 86% |
| google/gemma-2-2b-it | INT4 (best PTWC) | 15.80 | |
| microsoft/phi3-mini-4k-instruct | BF16 | 9.49 | |
| microsoft/phi3-mini-4k-instruct | INT4 (QAT + LoRA) | 10.12 | 27% |
| microsoft/phi3-mini-4k-instruct | INT4 (best PTWC) | 10.36 | |
| Qwen/Qwen2.5-3B-Instruct | BF16 | 11.01 | |
| Qwen/Qwen2.5-3B-Instruct | INT4 (QAT + LoRA) | 11.49 | 25% |
| Qwen/Qwen2.5-3B-Instruct | INT4 (best PTWC) | 11.65 | |
| HuggingFaceTB/SmolLM-1.7B-Instruct | BF16 | 19.11 | |
| HuggingFaceTB/SmolLM-1.7B-Instruct | INT4 (QAT + LoRA) | 19.25 | 79% |
| HuggingFaceTB/SmolLM-1.7B-Instruct | INT4 (best PTWC) | 19.79 | |
| mistralai/Mistral-7B-v0.3 | BF16 | 8.21 | |
| mistralai/Mistral-7B-v0.3 | INT4 (QAT + LoRA) | 8.38 | 12% |
| mistralai/Mistral-7B-v0.3 | INT4 (best PTWC) | 8.40 | |
| meta-llama/Llama-3.2-3B-Instruct | BF16 | 12.67 | |
| meta-llama/Llama-3.2-3B-Instruct | INT4 (QAT + LoRA) | 12.82 | 73% |
| meta-llama/Llama-3.2-3B-Instruct | INT4 (best PTWC) | 13.22 | |
| | | Average | 50.4% |
| Model | Precision | Wikitext,<br>word_ppl | Improvement |
|-------------------------------------|-------------------|-----------------------|-------------|
| google/gemma-2-2b-it | BF16 | 15.02 | |
| google/gemma-2-2b-it | INT4 (QAT + LoRA) | 15.09 | 91% |
| google/gemma-2-2b-it | INT4 (best PTWC) | 15.80 | |
| microsoft/phi3-mini-4k-instruct | BF16 | 9.49 | |
| microsoft/phi3-mini-4k-instruct | INT4 (QAT + LoRA) | 10.04 | 37% |
| microsoft/phi3-mini-4k-instruct | INT4 (best PTWC) | 10.36 | |
| Qwen/Qwen2.5-3B-Instruct | BF16 | 11.01 | |
| Qwen/Qwen2.5-3B-Instruct | INT4 (QAT + LoRA) | 11.44 | 33% |
| Qwen/Qwen2.5-3B-Instruct | INT4 (best PTWC) | 11.65 | |
| HuggingFaceTB/SmolLM-1.7B-Instruct | BF16 | 19.11 | |
| HuggingFaceTB/SmolLM-1.7B-Instruct | INT4 (QAT + LoRA) | 19.34 | 66% |
| HuggingFaceTB/SmolLM-1.7B-Instruct | INT4 (best PTWC) | 19.79 | |
| mistralai/Mistral-7B-v0.3 | BF16 | 8.21 | |
| mistralai/Mistral-7B-v0.3 | INT4 (QAT + LoRA) | 8.36 | 20% |
| mistralai/Mistral-7B-v0.3 | INT4 (best PTWC) | 8.40 | |
| meta-llama/Llama-3.2-1B-Instruct | BF16 | 16.30 | |
| meta-llama/Llama-3.2-1B-Instruct | INT4 (QAT + LoRA) | 17.12 | 40% |
| meta-llama/Llama-3.2-1B-Instruct | INT4 (best PTWC) | 17.67 | |
| meta-llama/Llama-3.2-3B-Instruct | BF16 | 12.67 | |
| meta-llama/Llama-3.2-3B-Instruct | INT4 (QAT + LoRA) | 13.00 | 39% |
| meta-llama/Llama-3.2-3B-Instruct | INT4 (best PTWC) | 13.22 | |
| meta-llama/Meta-Llama-3-8B-Instruct | BF16 | 10.22 | |
| meta-llama/Meta-Llama-3-8B-Instruct | INT4 (QAT + LoRA) | 10.30 | 62% |
| meta-llama/Meta-Llama-3-8B-Instruct | INT4 (best PTWC) | 10.45 | |
| microsoft/phi3.5-mini-instruct | BF16 | 10.00 | |
| microsoft/phi3.5-mini-instruct | INT4 (QAT + LoRA) | 10.53 | 37% |
| microsoft/phi3.5-mini-instruct | INT4 (best PTWC) | 10.71 | |
| | | Average | 46% |
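
For reference, the "Improvement" column appears to measure how much of the perplexity gap between the best PTWC result and the BF16 baseline is closed by QAT + LoRA. A minimal sketch of that computation (the formula is inferred from the table values such as the gemma row, not quoted from the README, so treat it as an assumption):

```python
def improvement(ppl_bf16: float, ppl_qat_lora: float, ppl_ptwc: float) -> float:
    """Share of the PTWC-vs-BF16 perplexity gap recovered by QAT + LoRA, in percent."""
    return 100 * (ppl_ptwc - ppl_qat_lora) / (ppl_ptwc - ppl_bf16)

# google/gemma-2-2b-it row: (15.80 - 15.09) / (15.80 - 15.02) ≈ 91%
print(round(improvement(15.02, 15.09, 15.80)))  # 91
```
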
19 changes: 6 additions & 13 deletions nncf/torch/quantization/quantize_functions.py
@@ -126,23 +126,19 @@ def forward(ctx, input_, input_shape, scale, level_low, level_high, levels):
input_low = torch.where(scale > 0, -scale, -scale / level_low * level_high)
# 15/8 * scale or (2-1/8) * scale
input_range = torch.abs((2 + 1 / level_low) * scale)

if input_.dtype in [torch.bfloat16, torch.float16]:
input_low = input_low.type(input_.dtype)
input_range = input_range.type(input_.dtype)

dtype = input_.dtype
original_shape = input_.shape
input_ = input_.reshape(input_shape)

output = RQ.Quantize_forward(input_, input_low, input_range, levels)
output = RQ.Quantize_forward(input_.type(torch.float32), input_low, input_range, levels)

ctx.save_for_backward(input_, input_low, input_range)
ctx.level_low = level_low
ctx.level_high = level_high
ctx.levels = levels

output = output.reshape(original_shape)
return output
return output.type(dtype)

@staticmethod
def backward(ctx, grad_output):
Expand All @@ -168,14 +164,11 @@ def backward(ctx, grad_output):
class QuantizeAsymmetricTorch(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, input_shape, input_low, input_range, level_low, level_high, levels):
if input_.dtype in [torch.bfloat16, torch.float16]:
input_low = input_low.type(input_.dtype)
input_range = input_range.type(input_.dtype)

dtype = input_.dtype
original_shape = input_.shape
input_ = input_.reshape(input_shape)

output = RQ.Quantize_forward(input_, input_low, input_range, levels)
output = RQ.Quantize_forward(input_.type(torch.float32), input_low, input_range, levels)

# Save tensors for backward pass
ctx.save_for_backward(input_, input_low, input_range)
@@ -184,7 +177,7 @@ def forward(ctx, input_, input_shape, input_low, input_range, level_low, level_h
ctx.levels = levels

output = output.reshape(original_shape)
return output
return output.type(dtype)

@staticmethod
def backward(ctx, grad_output):
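
For context, the hunks above replace "downcast the quantization parameters to the activation dtype (bf16/f16)" with "run the fake-quantization math in float32 and cast the result back". A minimal standalone sketch of that pattern (not the NNCF implementation, which calls `RQ.Quantize_forward`; the simplified asymmetric quantization math and names are illustrative):

```python
import torch

def fake_quantize_asym(x: torch.Tensor, input_low: torch.Tensor,
                       input_range: torch.Tensor, levels: int) -> torch.Tensor:
    # Remember the tuning dtype (bf16/f16/f32), do the arithmetic in f32,
    # then cast the result back so the rest of the graph keeps its dtype.
    dtype = x.dtype
    x = x.type(torch.float32)
    scale = input_range / (levels - 1)
    q = torch.clamp(torch.round((x - input_low) / scale), 0, levels - 1)
    return (input_low + q * scale).type(dtype)
```
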
4 changes: 2 additions & 2 deletions tests/cross_fw/examples/example_scope.json
@@ -282,8 +282,8 @@
"cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz",
"accuracy_tolerance": 0.1,
"accuracy_metrics": {
"perplexity_diff_torch": 0.6,
"best_ov_perplexity": 35.1
"perplexity_diff_torch": 0.75,
"best_ov_perplexity": 34.94
}
},
"quantization_aware_training_tensorflow_mobilenet_v2": {
@@ -12,19 +12,24 @@
from dataclasses import dataclass
from typing import Any

import openvino as ov
import pytest
import torch
from openvino._pyopenvino.properties.hint import inference_precision
from openvino.tools.ovc import convert_model
from pytest_mock import MockerFixture
from torch import nn

import nncf
import nncf.torch
from nncf.common.quantization.structs import QuantizationScheme
from nncf.openvino.optimized_functions.models import _compile_ov_model
from nncf.parameters import CompressWeightsMode
from nncf.parameters import StripFormat
from nncf.torch.function_hook.wrapper import get_hook_storage
from nncf.torch.quantization.layers import AsymmetricLoraQuantizer
from nncf.torch.quantization.layers import BaseQuantizer
from nncf.torch.quantization.layers import BaseWeightsDecompressor
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor as INT4AsymDQ
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor as INT4SymDQ
from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor as INT8AsymDQ
@@ -53,7 +58,8 @@ class ParamStripLora:
mode: CompressWeightsMode
decompressor_class: type
torch_dtype: torch.dtype
atol: float
torch_atol: float
ov_atol: float
weight_dtype: torch.dtype

def __str__(self) -> str:
@@ -76,17 +82,14 @@ def num_call_pack_weight(self) -> int:
@pytest.mark.parametrize(
("param"),
(
ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float32, 1e-3, torch.uint8),
ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float16, 1e-8, torch.uint8),
ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.bfloat16, 1e-2, torch.uint8),
ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float32, 1e-3, torch.uint8),
# torch.compile introduces bigger diff for sym
ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float16, 1e-3, torch.uint8),
ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.bfloat16, 1e-2, torch.uint8),
# int8 uses per-channel vs int4 group-wise
ParamStripLora(CompressWeightsMode.INT8_SYM, INT8SymDQ, torch.bfloat16, 1e-2, torch.int8),
# int8 uses per-channel vs int4 group-wise
ParamStripLora(CompressWeightsMode.INT8_ASYM, INT8AsymDQ, torch.bfloat16, 1e-8, torch.uint8),
ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float32, 1e-3, 1e-3, torch.uint8),
ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float16, 1e-3, 1e-3, torch.uint8),
ParamStripLora(CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.bfloat16, 1e-8, 1e-1, torch.uint8),
ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float32, 1e-3, 1e-3, torch.uint8),
ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float16, 1e-8, 1e-3, torch.uint8),
ParamStripLora(CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.bfloat16, 1e-8, 1e-2, torch.uint8),
ParamStripLora(CompressWeightsMode.INT8_SYM, INT8SymDQ, torch.bfloat16, 1e-2, 1e-3, torch.int8),
ParamStripLora(CompressWeightsMode.INT8_ASYM, INT8AsymDQ, torch.bfloat16, 1e-8, 1e-3, torch.uint8),
Comment on lines +91 to +92 (Contributor Author):

Note: it's expected that `ov_atol` is higher than `torch_atol`, since the OV model executes in f32 while the torch model runs with bf16 or f16 activations during tuning. Even though `ov_atol` isn't very small, in a few cases it was larger before this PR.
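
A quick way to see the scale of this effect (an illustrative snippet, not part of the PR): even a plain bf16 round-trip loses noticeably more precision than f16, which motivates the looser `ov_atol` for the f32 OV-vs-bf16 torch comparison.

```python
import torch

x = torch.randn(1000)
bf16_err = (x.to(torch.bfloat16).float() - x).abs().max().item()
fp16_err = (x.to(torch.float16).float() - x).abs().max().item()
print(f"max bf16 round-trip error: {bf16_err:.1e}")  # roughly 1e-3 .. 1e-2
print(f"max fp16 round-trip error: {fp16_err:.1e}")  # roughly 1e-4 .. 1e-3
```
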

),
ids=str,
)
@@ -114,12 +117,24 @@ def test_nncf_strip_lora_model(param: ParamStripLora, mocker: MockerFixture):
compressed_model, do_copy=True, strip_format=StripFormat.DQ, example_input=example_input
)
stripped_output = strip_compressed_model(example_input)

assert pack_weight_spy.call_count == param.num_call_pack_weight
assert strip_compressed_model.linear.weight.dtype == param.weight_dtype

check_compression_modules(strip_compressed_model, param.decompressor_class)
assert torch.allclose(compressed_output, stripped_output, atol=param.atol)
assert torch.allclose(compressed_output, stripped_output, atol=param.torch_atol)

example_input = example_input.type(torch.float32)
hook_storage = get_hook_storage(strip_compressed_model)
for _, module in hook_storage.named_hooks():
if isinstance(module, BaseWeightsDecompressor):
module.result_dtype = torch.float32
ov_model = convert_model(strip_compressed_model, example_input=example_input)
compiled_model = _compile_ov_model(ov_model, device_name="CPU", config={inference_precision(): ov.Type.f32})
infer_request = compiled_model.create_infer_request()
res = infer_request.infer(example_input)
out_name = compiled_model.outputs[0]
ov_output = torch.from_numpy(res[out_name])
assert torch.allclose(compressed_output.type(torch.float32), ov_output, atol=param.ov_atol)


SIGNED_WEIGHT_SAMPLE = [-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75]
@@ -155,7 +170,7 @@ def test_sym_fq_to_decompressor(param: ParamSymFQ):

scale_shape = (1, 1)
scale = torch.tensor(SCALE_SAMPLE)
scale = scale.expand(scale_shape).to(torch.float16)
scale = scale.expand(scale_shape)
Comment (Contributor Author):

This doesn't influence the result; it just aligns better with the default precision in FQ (float32).


# the reference scale is calculated with this formula:
# levels = (2 ** num_bits)
@@ -246,10 +261,10 @@ def test_asym_fq_to_decompressor(param: ParamAsymFQ):
ref_zero_point = ref_zero_point.expand(scale_shape).to(torch.uint8)

input_low = torch.tensor(INPUT_LOW_SAMPLE)
input_low = input_low.expand(scale_shape).to(param.torch_dtype)
input_low = input_low.expand(scale_shape)

input_range = torch.tensor(INPUT_RANGE_SAMPLE)
input_range = input_range.expand(scale_shape).to(param.torch_dtype)
input_range = input_range.expand(scale_shape)

qspec = PTQuantizerSpec(
num_bits=param.num_bits,
7 changes: 4 additions & 3 deletions tests/torch2/function_hook/quantization/test_fq_lora.py
@@ -166,9 +166,10 @@ def test_fq_lora_tuning(tmp_path, mode, backup_mode, compression_kwargs, ref_num
tuned_vs_stripped = vm.calculate_similarity(tuned_output, stripped_output)
tuned_vs_stripped_ov = vm.calculate_similarity(tuned_output, stripped_ov_output)

atol = 0.03 if mode == nncf.CompressWeightsMode.INT4_SYM else 0.01 # torch.compile introduces bigger diff
assert torch.allclose(tuned_vs_stripped, vm.validation_ref, atol=atol)
assert torch.allclose(tuned_vs_stripped_ov, vm.validation_ref, atol=atol)
# the torch.compile-d version of FQ + LoRA introduces a small numerical error
atol = 1e-2 if mode == nncf.CompressWeightsMode.INT4_SYM else 1e-8
assert torch.allclose(tuned_vs_stripped, vm.validation_ref, atol=atol)
assert torch.allclose(tuned_vs_stripped_ov, vm.validation_ref, atol=atol)


def test_checkpoint_loading(tmp_path: Path, use_cuda: bool):