Updated RQ signature #3441

Open · wants to merge 12 commits into base: develop
22 changes: 8 additions & 14 deletions nncf/torch/extensions/src/quantization/cpu/functions_cpu.cpp
@@ -29,8 +29,7 @@ std::vector<at::Tensor> q_cpu_backward(
        at::Tensor input_range,
        scalar_t levels,
        scalar_t levels_low,
-       scalar_t levels_high,
-       bool is_asymmetric) {
+       scalar_t levels_high) {
    auto output = q_cpu_forward<scalar_t>(input, input_low, input_range, levels);
    auto reverted_range = 1 / input_range;
    scalar_t alpha = levels_low / levels_high;
@@ -50,16 +49,12 @@ std::vector<at::Tensor> q_cpu_backward(
    auto outside_mask = mask_hi.add_(mask_lo);
    grad_input = grad_input.masked_fill_(outside_mask, 0);

-   if (is_asymmetric) {
-       auto grad_input_low = grad_output.clone();
-       auto all_ones = torch::ones_like(outside_mask);
-       grad_input_low = grad_input_low.masked_fill_(at::__xor__(all_ones, outside_mask), 0);
+   auto grad_input_low = grad_output.clone();
+   auto all_ones = torch::ones_like(outside_mask);
+   grad_input_low = grad_input_low.masked_fill_(at::__xor__(all_ones, outside_mask), 0);

-       sum_like(grad_input_low, input_low);
-       return {grad_input, grad_input_low, grad_input_range};
-   }
-   auto dummy_variable = torch::autograd::make_variable(at::empty(input_low.sizes()), true);
-   return {grad_input, dummy_variable, grad_input_range};
+   sum_like(grad_input_low, input_low);
+   return {grad_input, grad_input_low, grad_input_range};
}

#define CHECK_INPUT(x) CHECK_CPU(x)
@@ -95,16 +90,15 @@ std::vector<at::Tensor> q_backward(
        at::Tensor input_range,
        int levels,
        int level_low,
-       int level_high,
-       bool is_asymmetric) {
+       int level_high) {
    CHECK_INPUT(grad_output);
    CHECK_INPUT(input);
    CHECK_INPUT(input_low);
    CHECK_INPUT(input_range);

    std::vector<at::Tensor> results;
    DISPATCH_TENSOR_DATA_TYPES(input.scalar_type(), "q_cpu_backward", ([&] {
-       results = q_cpu_backward<scalar_t>(grad_output, input, input_low, input_range, levels, level_low, level_high, is_asymmetric);
+       results = q_cpu_backward<scalar_t>(grad_output, input, input_low, input_range, levels, level_low, level_high);
    }));

    return results;
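With `is_asymmetric` gone, the CPU kernel always returns all three gradients and the symmetric caller simply discards `grad_input_low` (see the `quantize_functions.py` hunks below). A minimal sketch of the updated call; the op name and argument order come from this PR, while the tensor values are purely illustrative:

```python
# Minimal sketch: illustrative values only; op name and argument order are taken from this PR.
import torch

from nncf.torch.quantization.extensions import QuantizedFunctionsCPU

grad_output = torch.tensor([[-0.5, 0.5]])
input_ = torch.tensor([[-0.5, 0.5]])
input_low = torch.tensor([-0.5])
input_range = torch.tensor([1.0])
levels, level_low, level_high = 256, -128, 127

# The kernel now always returns (grad_input, grad_input_low, grad_input_range);
# a symmetric caller can ignore grad_input_low instead of passing a flag.
grad_input, grad_input_low, grad_input_range = QuantizedFunctionsCPU.get("Quantize_backward")(
    grad_output, input_, input_low, input_range, levels, level_low, level_high
)
```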
10 changes: 4 additions & 6 deletions nncf/torch/quantization/quantize_functions.py
@@ -67,7 +67,7 @@ def backward(ctx: Any, *grad_outputs: Any) -> Any:
            )
        else:
            grad_input, _, grad_scale = QuantizedFunctionsCPU.get("Quantize_backward")(
-               grad_output, input_, input_low, input_range, levels, level_low, level_high, False
+               grad_output, input_, input_low, input_range, levels, level_low, level_high
            )

        return grad_input, grad_scale, None, None, None
@@ -114,7 +114,7 @@ def backward(ctx: Any, *grad_outputs: Any) -> Any:
            )
        else:
            grad_input, grad_input_low, grad_input_range = QuantizedFunctionsCPU.get("Quantize_backward")(
-               grad_output, input_, input_low, input_range, levels, level_low, level_high, True
+               grad_output, input_, input_low, input_range, levels, level_low, level_high
            )
        return grad_input, grad_input_low, grad_input_range, None, None, None

@@ -155,9 +155,8 @@ def backward(ctx, grad_output):
        orig_shape = grad_output.shape
        grad_output = grad_output.reshape(input_shape)

-       output = RQ.Quantize_forward(input_, input_low, input_range, levels)
        grad_input, _, grad_scale = RQ.Quantize_backward(
-           grad_output, input_, input_low, input_range, output, level_low, level_high
+           grad_output, input_, input_low, input_range, levels, level_low, level_high
        )

        grad_input = grad_input.reshape(orig_shape)
@@ -197,9 +196,8 @@ def backward(ctx, grad_output):
        orig_shape = grad_output.shape
        grad_output = grad_output.reshape(input_shape)

-       output = RQ.Quantize_forward(input_, input_low, input_range, levels)
        grad_input, grad_low, grad_range = RQ.Quantize_backward(
-           grad_output, input_, input_low, input_range, output, level_low, level_high
+           grad_output, input_, input_low, input_range, levels, level_low, level_high
        )

        grad_input = grad_input.reshape(orig_shape)
5 changes: 2 additions & 3 deletions nncf/torch/quantization/reference.py
@@ -70,19 +70,18 @@ def backward(
        input_: GeneralizedTensor,
        input_low: GeneralizedTensor,
        input_range: GeneralizedTensor,
-       output: GeneralizedTensor,
+       levels: int,
        level_low: int,
        level_high: int,
-       is_asymmetric: bool = False,
    ) -> List[GeneralizedTensor]:
-       # is_asymmetric is unused, present only to correspond to the CPU signature of calling "backward"
        mask_hi = input_ > (input_low + input_range)
        mask_hi = self._astype(mask_hi, input_.dtype)
        mask_lo = input_ < input_low
        mask_lo = self._astype(mask_lo, input_.dtype)

        mask_in = 1 - mask_hi - mask_lo
        range_sign = self._sign(input_range)
+       output = self.forward(input_, input_low, input_range, levels)
        err = (output - input_) * self._reciprocal(input_range * range_sign)
        grad_range = grad_output * (err * mask_in + range_sign * (level_low / level_high) * mask_lo + mask_hi)
        grad_range = sum_like(grad_range, input_range)
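The reference backward now mirrors the extension signature: it receives `levels` and recomputes `output` via `forward` instead of taking it as an argument, and the unused `is_asymmetric` flag is dropped. For orientation, here is an illustrative NumPy re-statement of the range-gradient math visible in this hunk; the helper names are hypothetical and `quantize_forward` assumes the standard fake-quantize formula, so this is not the NNCF API itself:

```python
import numpy as np


def quantize_forward(x, input_low, input_range, levels):
    # Assumed fake-quantize: clamp into [input_low, input_low + input_range], then snap to a grid of `levels` points.
    scale = (levels - 1) / input_range
    clipped = np.clip(x, input_low, input_low + input_range)
    return input_low + np.round((clipped - input_low) * scale) / scale


def quantize_backward_range(grad_output, x, input_low, input_range, levels, level_low, level_high):
    # Same masks and error term as the reference backward above, written for plain NumPy arrays.
    mask_hi = (x > input_low + input_range).astype(x.dtype)
    mask_lo = (x < input_low).astype(x.dtype)
    mask_in = 1 - mask_hi - mask_lo
    range_sign = np.sign(input_range)
    output = quantize_forward(x, input_low, input_range, levels)  # recomputed, no longer an argument
    err = (output - x) / (input_range * range_sign)
    grad_range = grad_output * (err * mask_in + range_sign * (level_low / level_high) * mask_lo + mask_hi)
    return grad_range.sum()  # stands in for sum_like(..., input_range) with a scalar range
```

A call such as `quantize_backward_range(np.ones((1, 2)), np.array([[-0.5, 0.5]]), np.array([-0.5]), np.array([1.0]), 256, -128, 127)` exercises the same argument order the updated `RQ.Quantize_backward` expects.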
103 changes: 101 additions & 2 deletions tests/torch/quantization/test_functions.py
@@ -17,11 +17,14 @@
from torch.distributions.uniform import Uniform

from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
+from nncf.torch.quantization.extensions import QuantizedFunctionsCPU
+from nncf.torch.quantization.extensions import QuantizedFunctionsCUDA
from nncf.torch.quantization.quantize_functions import asymmetric_quantize
from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high
from nncf.torch.quantization.quantize_functions import symmetric_quantize
from nncf.torch.quantization.reference import ReferenceBackendType
from nncf.torch.quantization.reference import ReferenceQuantize
+from nncf.torch.quantization.reference import ReferenceQuantizedFunctions
from tests.torch.helpers import PTTensorListComparator
from tests.torch.helpers import get_grads

@@ -381,7 +384,7 @@ def test_quantize_symmetric_backward(

        mock_prev_output_grads = np.ones(input_size, dtype=np.float16 if is_fp16 else np.float32)
        ref_grads = RQ.backward(
-           mock_prev_output_grads, ref_input, ref_input_low, ref_input_range, ref_output, level_low, level_high
+           mock_prev_output_grads, ref_input, ref_input_low, ref_input_range, levels, level_low, level_high
        )
        del ref_grads[1]
        test_value = symmetric_quantize(test_input, levels, level_low, level_high, test_scale, EPS)
@@ -585,7 +588,7 @@ def test_quantize_asymmetric_backward(self, _seed, input_size, bits, use_cuda, i

        mock_prev_output_grads = np.ones(input_size, dtype=np.float16 if is_fp16 else np.float32)
        ref_grads = RQ.backward(
-           mock_prev_output_grads, ref_input, ref_input_low, ref_input_range, ref_output, level_low, level_high
+           mock_prev_output_grads, ref_input, ref_input_low, ref_input_range, levels, level_low, level_high
        )

        test_value = asymmetric_quantize(
@@ -669,3 +672,99 @@ def test_get_scale_zp_from_input_low_input_high(
    )
    assert zero_point == ref_zero_point, f"{zero_point} != {ref_zero_point}"
    assert np.isclose(scale, ref_scale), f"{scale:.10f} != {ref_scale}"
+
+
+class CompatibilityTestDesc:
+    def __init__(self, levels, level_low, level_high, is_asymmetric=False):
+        self.input_ = torch.tensor([[-0.5, 0.5]])
+        self.input_low = torch.tensor([-0.5])
+        self.input_high = torch.tensor([0.5])
+        self.grad_output = torch.tensor([[-0.5, 0.5]])
+        self.levels = levels
+        self.level_low = level_low
+        self.level_high = level_high
+        self.is_asymmetric = is_asymmetric
+
+
+@pytest.mark.parametrize(
+    "desc",
+    [
+        CompatibilityTestDesc(
+            levels=256,
+            level_low=-128,
+            level_high=127,
+        ),
+        CompatibilityTestDesc(
+            levels=16,
+            level_low=0,
+            level_high=15,
+        ),
+    ],
+)
+def test_cpu_extension_reference_compatibility(desc):
+    input_range = desc.input_high - desc.input_low
+    fwd_args = [desc.input_, desc.input_low, input_range, desc.levels]
+    bwd_args = [
+        desc.grad_output,
+        desc.input_,
+        desc.input_low,
+        input_range,
+        desc.levels,
+        desc.level_low,
+        desc.level_high,
+    ]
+
+    ext_fwd_output = QuantizedFunctionsCPU.get("Quantize_forward")(*fwd_args)
+    ref_fwd_output = ReferenceQuantizedFunctions.Quantize_forward(*fwd_args)
+
+    assert torch.allclose(ext_fwd_output, ref_fwd_output)
+
+    bwd_grad_input, bwd_grad_low, bwd_grad_range = QuantizedFunctionsCPU.get("Quantize_backward")(*bwd_args)
+    ref_grad_input, ref_grad_low, ref_grad_range = ReferenceQuantizedFunctions.Quantize_backward(*bwd_args)
+
+    assert torch.allclose(bwd_grad_input, ref_grad_input)
+    assert torch.allclose(bwd_grad_low, ref_grad_low)
+    assert torch.allclose(bwd_grad_range, ref_grad_range)
+
+
+@pytest.mark.cuda
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skipping for CPU-only setups")
+@pytest.mark.parametrize(
+    "desc",
+    [
+        CompatibilityTestDesc(
+            levels=256,
+            level_low=-128,
+            level_high=127,
+        ),
+        CompatibilityTestDesc(
+            levels=16,
+            level_low=0,
+            level_high=15,
+        ),
+    ],
+)
+def test_cuda_extension_reference_compatibility(desc):
+    input_range = desc.input_high - desc.input_low
+    fwd_args = [desc.input_, desc.input_low, input_range, desc.levels]
+    bwd_args = [
+        desc.grad_output,
+        desc.input_,
+        desc.input_low,
+        input_range,
+        desc.levels,
+        desc.level_low,
+        desc.level_high,
+    ]
+
+    ext_fwd_output = QuantizedFunctionsCUDA.get("Quantize_forward")(*fwd_args)
+    ref_fwd_output = ReferenceQuantizedFunctions.Quantize_forward(*fwd_args)
+
+    assert torch.allclose(ext_fwd_output, ref_fwd_output)
+
+    bwd_grad_input, bwd_grad_low, bwd_grad_range = QuantizedFunctionsCUDA.get("Quantize_backward")(*bwd_args)
+    ref_grad_input, ref_grad_low, ref_grad_range = ReferenceQuantizedFunctions.Quantize_backward(*bwd_args)
+
+    assert torch.allclose(bwd_grad_input, ref_grad_input)
+    assert torch.allclose(bwd_grad_low, ref_grad_low)
+    assert torch.allclose(bwd_grad_range, ref_grad_range)
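These compatibility checks reuse the same `fwd_args`/`bwd_args` lists for the extension and reference paths, which only works now that all signatures agree. Assuming a development install of NNCF with the extensions built, they can be selected with, e.g., `pytest tests/torch/quantization/test_functions.py -k reference_compatibility`; the CUDA variant skips itself on CPU-only machines.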