Updated RQ/CPU extension signatures #3441

Open · wants to merge 11 commits into develop
6 changes: 2 additions & 4 deletions nncf/torch/quantization/quantize_functions.py
@@ -155,9 +155,8 @@ def backward(ctx, grad_output):
         orig_shape = grad_output.shape
         grad_output = grad_output.reshape(input_shape)
 
-        output = RQ.Quantize_forward(input_, input_low, input_range, levels)
         grad_input, _, grad_scale = RQ.Quantize_backward(
-            grad_output, input_, input_low, input_range, output, level_low, level_high
+            grad_output, input_, input_low, input_range, levels, level_low, level_high
         )
 
         grad_input = grad_input.reshape(orig_shape)
@@ -197,9 +196,8 @@ def backward(ctx, grad_output):
         orig_shape = grad_output.shape
         grad_output = grad_output.reshape(input_shape)
 
-        output = RQ.Quantize_forward(input_, input_low, input_range, levels)
         grad_input, grad_low, grad_range = RQ.Quantize_backward(
-            grad_output, input_, input_low, input_range, output, level_low, level_high
+            grad_output, input_, input_low, input_range, levels, level_low, level_high
         )
 
         grad_input = grad_input.reshape(orig_shape)
3 changes: 2 additions & 1 deletion nncf/torch/quantization/reference.py
@@ -70,7 +70,7 @@ def backward(
         input_: GeneralizedTensor,
         input_low: GeneralizedTensor,
         input_range: GeneralizedTensor,
-        output: GeneralizedTensor,
+        levels: int,
         level_low: int,
         level_high: int,
         is_asymmetric: bool = False,
@@ -83,6 +83,7 @@

         mask_in = 1 - mask_hi - mask_lo
         range_sign = self._sign(input_range)
+        output = self.forward(input_, input_low, input_range, levels)
         err = (output - input_) * self._reciprocal(input_range * range_sign)
         grad_range = grad_output * (err * mask_in + range_sign * (level_low / level_high) * mask_lo + mask_hi)
         grad_range = sum_like(grad_range, input_range)
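
The hunk above moves the quantized-output computation inside backward: instead of receiving a precomputed output tensor, the reference backward now takes levels and re-runs the forward pass itself. The following self-contained sketch illustrates that pattern; the mask and forward formulas are simplified assumptions for illustration, and only the err/grad_range lines mirror the reference code shown above.

import torch


def fake_quantize(input_, input_low, input_range, levels):
    # Simplified stand-in for the reference forward (assumed formula): clamp to
    # the quantization range and round onto a uniform grid with `levels` steps.
    scale = (levels - 1) / input_range
    clamped = torch.maximum(torch.minimum(input_, input_low + input_range), input_low)
    return torch.round((clamped - input_low) * scale) / scale + input_low


def backward_sketch(grad_output, input_, input_low, input_range, levels, level_low, level_high):
    # Assumed masks: inputs below / above the quantization range.
    mask_lo = (input_ < input_low).to(input_.dtype)
    mask_hi = (input_ > input_low + input_range).to(input_.dtype)
    mask_in = 1 - mask_hi - mask_lo
    range_sign = torch.sign(input_range)
    # The quantized output is recomputed here from `levels` instead of being
    # passed in by the caller, which is the signature change in this PR.
    output = fake_quantize(input_, input_low, input_range, levels)
    err = (output - input_) * (1.0 / (input_range * range_sign))
    grad_input = grad_output * mask_in  # straight-through inside the range (assumed)
    grad_range = grad_output * (err * mask_in + range_sign * (level_low / level_high) * mask_lo + mask_hi)
    # In the real code, grad_range is further reduced to input_range's shape via sum_like.
    return grad_input, grad_range
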
132 changes: 130 additions & 2 deletions tests/torch/quantization/test_functions.py
@@ -17,11 +17,14 @@
 from torch.distributions.uniform import Uniform
 
 from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
+from nncf.torch.quantization.extensions import QuantizedFunctionsCPU
+from nncf.torch.quantization.extensions import QuantizedFunctionsCUDA
 from nncf.torch.quantization.quantize_functions import asymmetric_quantize
 from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high
 from nncf.torch.quantization.quantize_functions import symmetric_quantize
 from nncf.torch.quantization.reference import ReferenceBackendType
 from nncf.torch.quantization.reference import ReferenceQuantize
+from nncf.torch.quantization.reference import ReferenceQuantizedFunctions
 from tests.torch.helpers import PTTensorListComparator
 from tests.torch.helpers import get_grads
 
@@ -381,7 +384,7 @@ def test_quantize_symmetric_backward(

         mock_prev_output_grads = np.ones(input_size, dtype=np.float16 if is_fp16 else np.float32)
         ref_grads = RQ.backward(
-            mock_prev_output_grads, ref_input, ref_input_low, ref_input_range, ref_output, level_low, level_high
+            mock_prev_output_grads, ref_input, ref_input_low, ref_input_range, levels, level_low, level_high
         )
         del ref_grads[1]
         test_value = symmetric_quantize(test_input, levels, level_low, level_high, test_scale, EPS)
@@ -585,7 +588,7 @@ def test_quantize_asymmetric_backward(self, _seed, input_size, bits, use_cuda, i

         mock_prev_output_grads = np.ones(input_size, dtype=np.float16 if is_fp16 else np.float32)
         ref_grads = RQ.backward(
-            mock_prev_output_grads, ref_input, ref_input_low, ref_input_range, ref_output, level_low, level_high
+            mock_prev_output_grads, ref_input, ref_input_low, ref_input_range, levels, level_low, level_high
         )
 
         test_value = asymmetric_quantize(
@@ -669,3 +672,128 @@ def test_get_scale_zp_from_input_low_input_high(
)
assert zero_point == ref_zero_point, f"{zero_point} != {ref_zero_point}"
assert np.isclose(scale, ref_scale), f"{scale:.10f} != {ref_scale}"


class CompatibilityTestDesc:
    def __init__(self, levels, level_low, level_high, is_asymmetric=False):
        self.input_ = torch.tensor([[-0.5, 0.5]])
        self.input_low = torch.tensor([-0.5])
        self.input_range = torch.tensor([1.0])
        self.grad_output = torch.tensor([[-0.5, 0.5]])
        self.levels = levels
        self.level_low = level_low
        self.level_high = level_high
        self.is_asymmetric = is_asymmetric


@pytest.mark.parametrize(
    "desc",
    [
        CompatibilityTestDesc(
            levels=256,
            level_low=-128,
            level_high=127,
            is_asymmetric=True,
        ),
        CompatibilityTestDesc(
            levels=256,
            level_low=-128,
            level_high=127,
            is_asymmetric=False,
        ),
        CompatibilityTestDesc(
            levels=16,
            level_low=0,
            level_high=15,
            is_asymmetric=True,
        ),
        CompatibilityTestDesc(
            levels=16,
            level_low=0,
            level_high=15,
            is_asymmetric=False,
        ),
    ],
)
def test_cpu_extension_reference_compatibility(desc):
    fwd_args = [desc.input_, desc.input_low, desc.input_range, desc.levels]
    bwd_args = [
        desc.grad_output,
        desc.input_,
        desc.input_low,
        desc.input_range,
        desc.levels,
        desc.level_low,
        desc.level_high,
        desc.is_asymmetric,
    ]

    ext_fwd_output = QuantizedFunctionsCPU.get("Quantize_forward")(*fwd_args)
    ref_fwd_output = ReferenceQuantizedFunctions.Quantize_forward(*fwd_args)

    assert torch.allclose(ext_fwd_output, ref_fwd_output)

    bwd_grad_input, bwd_grad_low, bwd_grad_range = QuantizedFunctionsCPU.get("Quantize_backward")(*bwd_args)
    ref_grad_input, ref_grad_low, ref_grad_range = ReferenceQuantizedFunctions.Quantize_backward(*bwd_args)
Inline review thread:

Contributor: You can ignore the gradient for input low in the test.

Collaborator Author: In this case, differences between the extensions and the references wouldn't be noticed. You've worried for years about the absence of tests for this case, and now you suggest testing it only partially.

Contributor: Again, the gradient for input low is not used anywhere, and that is intended. IMHO, nothing is wrong here and it's not an error. It comes from torch.autograd's requirement that backward return the same number of outputs as forward has inputs. The symmetric case has only one learnable parameter, the scale (represented by input_range), which is why only one gradient is used.

My ask was about testing the functionality of the extensions against the reference implementation. The gradient for input low doesn't influence that functionality. Hope that makes things clear.
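
The torch.autograd contract described above can be shown with a minimal, self-contained sketch; the class below and its toy quantizer are purely illustrative and are not NNCF's actual implementation.

import torch


class FakeQuantizeSymmetric(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, scale, levels):
        ctx.save_for_backward(input_, scale)
        step = 2 * scale / (levels - 1)
        quantized = torch.round(input_ / step) * step
        return torch.maximum(torch.minimum(quantized, scale), -scale)

    @staticmethod
    def backward(ctx, grad_output):
        input_, scale = ctx.saved_tensors
        inside = (input_.abs() <= scale).to(grad_output.dtype)
        grad_input = grad_output * inside  # straight-through estimate inside the range
        grad_scale = (grad_output * (1 - inside)).sum().reshape_as(scale)  # toy gradient for the scale
        # backward must return one value per forward input (input_, scale, levels);
        # `levels` is not learnable, so its gradient slot is simply None.
        return grad_input, grad_scale, None


# Usage: only the scale gradient is meaningful, mirroring the symmetric quantizer.
# x = torch.randn(4, requires_grad=True); s = torch.tensor([1.0], requires_grad=True)
# FakeQuantizeSymmetric.apply(x, s, 256).sum().backward()
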

Collaborator Author: As you wish.

Collaborator Author (nikita-malininn, Apr 17, 2025): However, the method signatures currently differ between the CUDA and CPU extensions.

Furthermore, grad_range does not match between CUDA and the reference execution (see the failed CUDA tests). We can continue to sweep these issues under the rug and only check grad_input, if you want?

Contributor: Please stop misinterpreting my words. I didn't say that we should check only grad_input and say nothing about the different signatures.

My point was only about the gradient for input_low. Feel free to align the signatures and to fix and check the gradient for input_range; no objection to that.

    assert torch.allclose(bwd_grad_input, ref_grad_input)
    if desc.is_asymmetric:
        assert torch.allclose(bwd_grad_low, ref_grad_low)
    assert torch.allclose(bwd_grad_range, ref_grad_range)


@pytest.mark.cuda
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skipping for CPU-only setups")
@pytest.mark.parametrize(
    "desc",
    [
        CompatibilityTestDesc(
            levels=256,
            level_low=-128,
            level_high=127,
        ),
        CompatibilityTestDesc(
            levels=256,
            level_low=-128,
            level_high=127,
        ),
        CompatibilityTestDesc(
            levels=16,
            level_low=0,
            level_high=15,
        ),
        CompatibilityTestDesc(
            levels=16,
            level_low=0,
            level_high=15,
        ),
    ],
)
def test_cuda_extension_reference_compatibility(desc):
    device = torch.device("cuda")
    input_low = desc.input_low.to(device)
    input_range = desc.input_range.to(device)
    input_ = desc.input_.to(device)

    fwd_args = [input_, input_low, input_range, desc.levels]
    bwd_args = [
        desc.grad_output.to(device),
        input_,
        input_low,
        input_range,
        desc.levels,
        desc.level_low,
        desc.level_high,
    ]

    ext_fwd_output = QuantizedFunctionsCUDA.get("Quantize_forward")(*fwd_args)
    ref_fwd_output = ReferenceQuantizedFunctions.Quantize_forward(*fwd_args)

    assert torch.allclose(ext_fwd_output, ref_fwd_output)

    bwd_grad_input, bwd_grad_low, bwd_grad_range = QuantizedFunctionsCUDA.get("Quantize_backward")(*bwd_args)
    ref_grad_input, ref_grad_low, ref_grad_range = ReferenceQuantizedFunctions.Quantize_backward(*bwd_args)

    assert torch.allclose(bwd_grad_input, ref_grad_input)
    assert torch.allclose(bwd_grad_low, ref_grad_low)
    assert torch.allclose(bwd_grad_range, ref_grad_range)
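
For a quick manual check of the updated signature outside pytest, something along these lines should work on a CPU-only setup with this branch installed and the CPU extensions built (the values mirror CompatibilityTestDesc above; this is a sketch, not part of the PR):

import torch

from nncf.torch.quantization.extensions import QuantizedFunctionsCPU
from nncf.torch.quantization.reference import ReferenceQuantizedFunctions

# Fixture values copied from CompatibilityTestDesc.
input_ = torch.tensor([[-0.5, 0.5]])
input_low = torch.tensor([-0.5])
input_range = torch.tensor([1.0])
grad_output = torch.tensor([[-0.5, 0.5]])
levels, level_low, level_high = 256, -128, 127

# Forward: extension vs. reference.
ext_out = QuantizedFunctionsCPU.get("Quantize_forward")(input_, input_low, input_range, levels)
ref_out = ReferenceQuantizedFunctions.Quantize_forward(input_, input_low, input_range, levels)
assert torch.allclose(ext_out, ref_out)

# Backward with the new signature: `levels` replaces the precomputed `output`.
ext_grads = QuantizedFunctionsCPU.get("Quantize_backward")(
    grad_output, input_, input_low, input_range, levels, level_low, level_high, True
)
ref_grads = ReferenceQuantizedFunctions.Quantize_backward(
    grad_output, input_, input_low, input_range, levels, level_low, level_high, True
)
for ext_grad, ref_grad in zip(ext_grads, ref_grads):
    assert torch.allclose(ext_grad, ref_grad)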