
Aligned bf16 tuning vs f32 inference for 4bit compression #3493


Open · wants to merge 1 commit into develop

Conversation

@ljaljushkin (Contributor) commented May 14, 2025:

Changes

Always cast the input to float32 inside FQ (FakeQuantize) + LoRA.
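
A minimal sketch of the idea, assuming a toy symmetric fake-quantizer (the class and helper names below are illustrative, not NNCF's actual API):

```python
import torch


def fake_quantize(w: torch.Tensor, levels: int = 16) -> torch.Tensor:
    # Toy symmetric per-tensor fake-quantization, for illustration only.
    scale = w.abs().max() / (levels // 2 - 1)
    return torch.round(w / scale).clamp(-(levels // 2), levels // 2 - 1) * scale


class FQLoRALinear(torch.nn.Module):
    # Illustrative FQ + LoRA linear layer: all quantization and adapter math
    # runs in float32, matching the f32 execution of the exported OV model.
    def __init__(self, weight, lora_a, lora_b):
        super().__init__()
        self.weight = torch.nn.Parameter(weight)
        self.lora_a = torch.nn.Parameter(lora_a)
        self.lora_b = torch.nn.Parameter(lora_b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x32 = x.to(torch.float32)  # always cast the input to float32
        w32 = fake_quantize(self.weight.to(torch.float32))
        w32 = w32 + self.lora_b.to(torch.float32) @ self.lora_a.to(torch.float32)
        y = torch.nn.functional.linear(x32, w32)
        return y.to(orig_dtype)  # cast back so the rest of the model stays in bf16
```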

Benchmark results with the new schema, obtained on https://github.com/ljaljushkin/nncf_pytorch/tree/nl/ref_benchmark with small modifications from @nikita-malininn's branch https://github.com/nikita-malininn/nncf/tree/nm/ref_benchmark:

Benchmark configuration:

| device | dtype | exec_type | tensor_type | granularity | symmetric | narrow_range | timing_mode | num_runs | input_size |
|--------|-------|-----------|-------------|-------------|-----------|--------------|-------------|----------|------------|
| cuda | bfloat16 | ExecutionType.REGULAR | TensorType.WEIGHTS | GranularityType.PER_CHANNEL | TRUE | FALSE | TimingMode.KERNEL | 1000 | [2048, 128256] |

Results:

| name | mode | forward_avg, ms | backward_avg, ms | memory, GB |
|------|------|-----------------|------------------|------------|
| compile (PR) | sym | 6.5 | 10.7 | 3.9 |
| compile (PR) | asym | 6.8 | 10.7 | 3.9 |
| compile (before) | sym | 1.6 | 9.5 | 4.2 |
| compile (before) | asym | 1.9 | 9.5 | 3.9 |
| not compiled (PR) | sym | 19.0 | 46.6 | 5.9 |
| not compiled (PR) | asym | 19.6 | 47.0 | 5.9 |
| not compiled (before) | sym | 9.2 | 37.0 | 5.4 |
| not compiled (before) | asym | 9.6 | 37.0 | 5.4 |

There is an overhead on the forward pass, but it is largely offset by using torch.compile.
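
For context, a hedged sketch of what "compiled" means here: wrapping the f32-casting forward in torch.compile so the extra casts get fused with the quantization kernels (the function name and scale value are illustrative):

```python
import torch


def fq_forward(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for the FQ forward: cast to f32, quantize, cast back.
    x32 = x.to(torch.float32)
    q = torch.round(x32 / scale).clamp(-8, 7) * scale  # 4-bit-style symmetric grid
    return q.to(x.dtype)


compiled_fq = torch.compile(fq_forward)  # fuses the casts into the surrounding kernels

x = torch.randn(2048, 1024, dtype=torch.bfloat16)
y = compiled_fq(x, torch.tensor(0.01))
```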

There's a 1-6% per-epoch overhead on RTX; on A100, depending on the setup, there can even be a 6% speedup or a 3% slowdown.
(images: total training time per epoch, before vs. with the PR, on RTX and A100)

Reason for changes

Minimize the disparity in precision between the Torch model and its exported OV equivalent.
Full alignment would be very inefficient, so this is a compromise: align accuracy with minimal overhead on the forward pass.

An e2e test on facebook/opt-125m shows that the outputs now match within the default absolute tolerance (1e-8) instead of the previous 1e-2:
https://github.com/openvinotoolkit/nncf/pull/3493/files#diff-7a4f90fe4f07d515df355d6fb618112d7d3fe88eb8ba777e502c695a7c715010R170
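
Conceptually, the check boils down to an elementwise comparison of the two backends' outputs; a hedged sketch (the tensors here are placeholders for the real model logits):

```python
import torch

# Placeholders for the real outputs: one from the Torch model
# (bf16 tuning, f32 math inside FQ + LoRA), one from the exported OV model (f32).
torch_out = torch.tensor([0.12345678, -1.0])
ov_out = torch.tensor([0.12345678, -1.0])

# torch.allclose's default absolute tolerance is atol=1e-8;
# before the PR, the comparison only passed with atol=1e-2.
assert torch.allclose(torch_out, ov_out, atol=1e-8)
```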

Previously, there were 3 problematic models with a significant difference in accuracy; now the accuracy is much better aligned:

(image: per-model accuracy comparison before vs. after the change)

Related tickets

166195

Tests

test examples - https://github.com/openvinotoolkit/nncf/actions/runs/15024278726/job/42221028011

@github-actions bot added the documentation and NNCF PT labels May 14, 2025
```diff
@@ -155,7 +170,7 @@ def test_sym_fq_to_decompressor(param: ParamSymFQ):

     scale_shape = (1, 1)
     scale = torch.tensor(SCALE_SAMPLE)
-    scale = scale.expand(scale_shape).to(torch.float16)
+    scale = scale.expand(scale_shape)
```
@ljaljushkin (Contributor, Author) commented:

Doesn't influence the result; it's just more aligned with the default precision in FQ (float32).
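
A quick illustration of the dtype defaults in question (the SCALE_SAMPLE value below is made up):

```python
import torch

SCALE_SAMPLE = 0.35  # made-up value, just to show the dtypes
scale = torch.tensor(SCALE_SAMPLE)                    # float32 by default
print(scale.expand((1, 1)).dtype)                     # torch.float32 - matches FQ's default
print(scale.expand((1, 1)).to(torch.float16).dtype)   # torch.float16 - the old test setup
```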

Comment on lines +91 to +92:

```python
ParamStripLora(CompressWeightsMode.INT8_SYM, INT8SymDQ, torch.bfloat16, 1e-2, 1e-3, torch.int8),
ParamStripLora(CompressWeightsMode.INT8_ASYM, INT8AsymDQ, torch.bfloat16, 1e-8, 1e-3, torch.uint8),
```
@ljaljushkin (Contributor, Author) commented:

Note: it's expected that the ov_tol value is higher than the torch_tol value, since the OV model executes in f32 while the torch model has activations in bf16 or f16 during tuning. Even though ov_tol isn't very small, in a few cases it was larger before this PR.
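
For intuition on the tolerance gap: bfloat16 keeps only about 8 significand bits, so a round-trip through bf16 already perturbs values noticeably. A small sketch:

```python
import torch

x = torch.randn(1_000)
roundtrip = x.to(torch.bfloat16).to(torch.float32)
# With ~8 significand bits, the max absolute error on values of magnitude ~1
# lands on the order of 1e-3 to 1e-2 - the same ballpark as the tolerances above.
print((x - roundtrip).abs().max())
```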

@ljaljushkin ljaljushkin marked this pull request as ready for review May 14, 2025 17:11
@ljaljushkin ljaljushkin requested a review from a team as a code owner May 14, 2025 17:11
@nikita-malininn (Collaborator) commented:

I don't understand the results in the performance/memory table. Either this PR slows down the reference and increases memory consumption, or the table was filled in incorrectly.

@ljaljushkin (Contributor, Author) commented:

> I don't understand the results in the performance/memory table. Either this PR slows down the reference and increases memory consumption, or the table was filled in incorrectly.

Yes, your understanding is correct. This is the price for the accuracy.
As I mentioned in the description, the increase is not that bad compared with the non-compiled version, and the backward pass is not strongly affected.

@nikita-malininn (Collaborator) commented:

> Yes, your understanding is correct. This is the price for the accuracy. As I mentioned in the description, the increase is not that bad compared with the non-compiled version, and the backward pass is not strongly affected.

A ~4x slowdown of the compiled forward pass in the symmetric case (~2x for the non-compiled one) and a 10% increase in memory consumption - is that "not that bad"? I don't think so.

@ljaljushkin (Contributor, Author) commented May 15, 2025:

> A ~4x slowdown of the compiled forward pass in the symmetric case (~2x for the non-compiled one) and a 10% increase in memory consumption - is that "not that bad"? I don't think so.

The compiled forward with this PR is still 1.4x faster than the non-compiled one before the PR, and 2.9x faster than the non-compiled one with the PR.
I don't see any way to improve it other than choosing between the compiled and non-compiled versions.

Do you have any suggestions for improvement?

@ljaljushkin (Contributor, Author) commented:

Updated the description with the total time required for 1 epoch before and with the PR.
There's a 1-6% per-epoch overhead on RTX; on A100, depending on the setup, there can even be a 6% speedup or a 3% slowdown.
