Aligned bf16 tuning vs f32 inference for 4bit compression #3493
Conversation
```diff
@@ -155,7 +170,7 @@ def test_sym_fq_to_decompressor(param: ParamSymFQ):

     scale_shape = (1, 1)
     scale = torch.tensor(SCALE_SAMPLE)
-    scale = scale.expand(scale_shape).to(torch.float16)
+    scale = scale.expand(scale_shape)
```
Doesn't influence the result, just more aligned with default precision in FQ (float32).
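For reference, a tiny standalone check (0.5 stands in for SCALE_SAMPLE, which is not shown here) confirming the point: `torch.tensor` already produces float32, the default precision used by FQ, so dropping the cast keeps the scale in that default precision.

```python
import torch

scale = torch.tensor(0.5)          # arbitrary value standing in for SCALE_SAMPLE
print(scale.dtype)                 # torch.float32 - PyTorch's default floating dtype
print(scale.expand((1, 1)).dtype)  # expand() keeps the dtype unchanged
```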
```python
ParamStripLora(CompressWeightsMode.INT8_SYM, INT8SymDQ, torch.bfloat16, 1e-2, 1e-3, torch.int8),
ParamStripLora(CompressWeightsMode.INT8_ASYM, INT8AsymDQ, torch.bfloat16, 1e-8, 1e-3, torch.uint8),
```
Note: it's expected that the ov_tol value is higher than the torch_tol value, since the OV model executes in f32 while the torch model has activations in bf16 or f16 during tuning. Even though ov_tol isn't very small, in a few cases it was larger before this PR.
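For context, a standalone illustration (not taken from the PR) of why the tolerance against the f32 OV model has to stay looser when tuning runs in bf16: a single bf16 round-trip of unit-scale values already introduces errors on the order of 1e-3 to 1e-2.

```python
import torch

x = torch.randn(10_000)                # f32 reference values
x_bf16 = x.to(torch.bfloat16).float()  # simulate computing/storing them in bf16
print((x - x_bf16).abs().max())        # roughly 1e-2 at the tails: bf16 has only a 7-bit mantissa
```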
I don't get the results from the performance/memory table. Either this PR slows down the reference and increases memory consumption, or the table was filled in incorrectly.
Yes, your understanding is correct. This is the price for accuracy.
About a ~4x slowdown for the compiled forward version in the symmetric case (or a ~2x slowdown for the non-compiled one) and a 10% increase in memory consumption - is that acceptable?
It's still 1.4x faster than the non-compiled version before the PR and 2.9x faster than the non-compiled version with the PR. Do you have any suggestions for improvement?
Updated the description with the total time required for 1 epoch before and with the PR.
Changes
Always cast input to float32 inside FQ + LoRA.
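A minimal sketch of the idea, using a hypothetical module (the names and the exact quantization formula are illustrative, not the NNCF implementation): regardless of the bf16/f16 dtype of the tuned model, the input is cast to f32, the fake-quantize and LoRA math runs in f32, and the result is cast back to the activation dtype, mirroring the f32 execution of the exported OV model.

```python
import torch


class FQLoraF32Linear(torch.nn.Module):
    """Sketch: linear layer whose fake-quantize + LoRA path always runs in float32."""

    def __init__(self, weight: torch.Tensor, scale: torch.Tensor, rank: int = 8):
        super().__init__()
        self.weight = torch.nn.Parameter(weight.float())
        self.scale = torch.nn.Parameter(scale.float())
        out_features, in_features = weight.shape
        self.lora_a = torch.nn.Parameter(torch.zeros(rank, in_features))
        self.lora_b = torch.nn.Parameter(torch.zeros(out_features, rank))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x32 = x.to(torch.float32)                         # always cast the input to f32
        # symmetric 4-bit fake quantization of the weight, computed in f32
        q = torch.clamp(torch.round(self.weight / self.scale), -8, 7)
        w_fq = q * self.scale
        lora = self.lora_b @ self.lora_a                  # low-rank correction, also in f32
        y = torch.nn.functional.linear(x32, w_fq + lora)
        return y.to(input_dtype)                          # back to bf16/f16 for the rest of the model


# usage: bf16 activations in, bf16 activations out, but all FQ + LoRA math is f32
w = torch.randn(32, 16)
layer = FQLoraF32Linear(w, scale=w.abs().amax(dim=1, keepdim=True) / 7)
out = layer(torch.randn(2, 16, dtype=torch.bfloat16))
print(out.dtype)  # torch.bfloat16
```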
Benchmark results with the new schema were collected on https://github.com/ljaljushkin/nncf_pytorch/tree/nl/ref_benchmark, with small modifications from @nikita-malininn's branch https://github.com/nikita-malininn/nncf/tree/nm/ref_benchmark:
There's an overhead on forward, but it's offset by using torch.compile.
Per epoch, there's a 1-6% overhead on RTX; on A100, depending on the setup, there can even be a 6% speedup or a 3% slowdown.
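As a rough illustration of that mitigation (the model and shapes below are arbitrary, not the benchmark setup), torch.compile lets the compiler backend fuse pointwise ops, including the extra f32 casts, into the surrounding kernels:

```python
import torch

# bf16 model as used during tuning; compiling it hides most of the cast overhead
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU()).bfloat16()
compiled = torch.compile(model)  # PyTorch 2.x
x = torch.randn(8, 1024, dtype=torch.bfloat16)
y = compiled(x)
```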


Reason for changes
Minimize the disparity in precision between the Torch model and its exported OV equivalent.
Full alignment would be very inefficient, so this is a compromise: align accuracy with minimal overhead on the forward pass.
The e2e test on facebook/opt-125m proves that the output is now the same within the default absolute tolerance (1e-8) instead of 1e-2: https://github.com/openvinotoolkit/nncf/pull/3493/files#diff-7a4f90fe4f07d515df355d6fb618112d7d3fe88eb8ba777e502c695a7c715010R170
Previously, there were 3 problematic models with a significant difference in accuracy; now they are much better aligned:

Related tickets
166195
Tests
test examples - https://github.com/openvinotoolkit/nncf/actions/runs/15024278726/job/42221028011