From 5a287f861376d6910a84c58bca6c29b77beaf29b Mon Sep 17 00:00:00 2001
From: anzr299
Date: Fri, 9 Jan 2026 13:29:43 +0400
Subject: [PATCH 1/6] init

---
 .../algorithms/weight_compression/gptq.py     | 62 +++++++++++++------
 1 file changed, 44 insertions(+), 18 deletions(-)

diff --git a/src/nncf/quantization/algorithms/weight_compression/gptq.py b/src/nncf/quantization/algorithms/weight_compression/gptq.py
index aeb32adede1..f18f7c951e9 100644
--- a/src/nncf/quantization/algorithms/weight_compression/gptq.py
+++ b/src/nncf/quantization/algorithms/weight_compression/gptq.py
@@ -130,8 +130,32 @@ def apply(
                 raise nncf.UnsupportedModelError(msg)
 
             _, input_tensors = next(iter(inputs.items()))
-            hessian = self._calculate_hessian(node, input_tensors)
-            scale, zero_point = self._quantize_weights(model, graph, wc_params, hessian, input_tensors)
+            weight_tensor = self._backend_entity.get_weight(
+                wc_params.node_with_weight, wc_params.weight_port_id, model, graph
+            )
+            weight_tensor = fns.astype(weight_tensor, TensorDataType.float32)
+
+            is_3d_weight = len(weight_tensor.shape) == 3
+
+            node = wc_params.node_with_weight
+            hessian = self._calculate_hessian(node, input_tensors, is_3d_weight)
+            weight_tensor = fns.unsqueeze(weight_tensor, 0) if not is_3d_weight else weight_tensor
+            scales = []
+            zero_points = []
+            weights = []
+            for batch_idx in range(hessian.shape[0]):
+                batch_hessian = hessian[batch_idx]
+                batch_weight = weight_tensor[batch_idx]
+                batch_quantized_weight, batch_scale, batch_zero_point = self._quantize_weights(
+                    wc_params, batch_hessian, batch_weight, input_tensors
+                )
+                weights.append(batch_quantized_weight)
+                scales.append(batch_scale)
+                zero_points.append(batch_zero_point)
+            scale = fns.stack(scales, axis=0) if is_3d_weight else scales[0]
+            zero_point = fns.stack(zero_points, axis=0) if is_3d_weight else zero_points[0]
+            weight = fns.stack(weights, axis=0) if is_3d_weight else weights[0]
+            self._backend_entity.set_weight(wc_params.node_with_weight, wc_params.weight_port_id, model, graph, weight)
             res[wc_params.weight_name] = CompressedWeight(None, scale, zero_point, None)
 
         return model, res
@@ -163,7 +187,7 @@ def get_statistic_points(
 
         return self._layerwise_engine.get_statistic_points(model, graph, filtered_nodes)
 
-    def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor:
+    def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor], is_3d_weight: bool) -> Tensor:
         """
         Calculates the Hessian matrix for the given node and inputs.
 
@@ -184,25 +208,34 @@ def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor:
             (inputs[0].shape[-1], inputs[0].shape[-1]), backend=inputs[0].backend, dtype=TensorDataType.float32
         )
 
+        # Make hessian 3D, such that for 2D weights it has only 1 batch and can be squeezed later.
+        # For 3D weights this dimension matches the weight's batch dimension.
+        hessian = fns.unsqueeze(hessian, 0)
+
         for inp in inputs:
-            batch_size = 1 if len(inp.shape) == 2 else inp.shape[0]
+            is_3d_act = len(inp.shape) == 3
+            # For the 3D weights case, batch size will always be 1: each "batch"/expert of the activation is treated
+            # as a single 2D matmul
+            batch_size = 1 if not is_3d_act and not is_3d_weight else inp.shape[0]
             if node.metatype in self._backend_entity.matmul_metatypes:
-                if len(inp.shape) == 3:
+                # For 3D act + 2D weight case we should reshape activation to 2D to match weight
+                # For 3D act + 3D weight it should remain in 3D and the last 2 dimensions should be activation
+                # per batch/0-th dimension
+                if is_3d_act and not is_3d_weight:
                     inp = inp.reshape((-1, inp.shape[-1]))
-                inp = fns.transpose(inp)
+                inp = fns.moveaxis(inp, -1, -2)
             hessian *= nsamples / (nsamples + batch_size)
             nsamples += batch_size
            inp = fns.astype(inp, TensorDataType.float32) * math.sqrt(2 / nsamples)
-            hessian += fns.matmul(inp, fns.transpose(inp))
+            hessian += fns.matmul(inp, fns.moveaxis(inp, -1, -2))
 
         return hessian
 
     def _quantize_weights(
         self,
-        model: TModel,
-        graph: NNCFGraph,
         wc_params: WeightCompressionParameters,
         hessian: Tensor,
+        weight_tensor: Tensor,
         inputs: list[Tensor],
     ):
         """
@@ -221,11 +254,6 @@ def _quantize_weights(
             msg = "Transpose is not supported"
             raise RuntimeError(msg)
 
-        weight_tensor = self._backend_entity.get_weight(
-            wc_params.node_with_weight, wc_params.weight_port_id, model, graph
-        )
-        weight_tensor = fns.astype(weight_tensor, TensorDataType.float32)
-
         dead_indices = fns.diag(hessian) == 0
         hessian[dead_indices, dead_indices] = 1
         weight_tensor[:, dead_indices] = 0
@@ -278,6 +306,7 @@ def _quantize_weights(
                 else:
                     if self._scale_estimation and block_compression_config.num_bits == 4:
                         activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
+                        # TODO(anazir): Make it work for 3D weights
                         wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
                         scale, zero_point = ScaleEstimation.calculate_quantization_params(
                             wc_statistics,
@@ -323,9 +352,6 @@ def _quantize_weights(
             weight_tensor[:, i2:] -= fns.matmul(error_block, hessian_inv[i1:i2, i2:])
 
         quantized_tensor = quantized_tensor.reshape(weight_tensor.shape).astype(weight_tensor.dtype)
-        self._backend_entity.set_weight(
-            wc_params.node_with_weight, wc_params.weight_port_id, model, graph, quantized_tensor
-        )
 
         scales = fns.stack(scales, axis=1)
         if wc_params.compression_config.group_size == -1:
@@ -339,4 +365,4 @@ def _quantize_weights(
                 zero_points = fns.squeeze(zero_points, axis=-1)
         else:
             zero_points = None
-        return scales, zero_points
+        return weight_tensor, scales, zero_points

From 4076d08d58515cbbf8a875ba8b493ddb3e05d844 Mon Sep 17 00:00:00 2001
From: anzr299
Date: Fri, 9 Jan 2026 13:32:41 +0400
Subject: [PATCH 2/6] fix

---
 src/nncf/quantization/algorithms/weight_compression/gptq.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/nncf/quantization/algorithms/weight_compression/gptq.py b/src/nncf/quantization/algorithms/weight_compression/gptq.py
index f18f7c951e9..48b7c28cc93 100644
--- a/src/nncf/quantization/algorithms/weight_compression/gptq.py
+++ b/src/nncf/quantization/algorithms/weight_compression/gptq.py
@@ -219,8 +219,8 @@ def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor], is_3d_weight:
             batch_size = 1 if not is_3d_act and not is_3d_weight else inp.shape[0]
             if node.metatype in self._backend_entity.matmul_metatypes:
                 # For 3D act + 2D weight case we should reshape activation to 2D to match weight
-                # For 3D act + 3D weight it should remain in 3D and the last 2 dimensions should be activation
-                # per batch/0-th dimension
+                # For 3D act + 3D weight it should remain in 3D and the last 2 dimensions should be activation per
+                # batch/0-th dimension
                 if is_3d_act and not is_3d_weight:
                     inp = inp.reshape((-1, inp.shape[-1]))
                 inp = fns.moveaxis(inp, -1, -2)

From 26e8403fdc6eef96547f844f6f3e428ec9aee9b0 Mon Sep 17 00:00:00 2001
From: anzr299
Date: Mon, 12 Jan 2026 12:23:31 +0400
Subject: [PATCH 3/6] make gptq work

---
 .../algorithms/weight_compression/gptq.py     | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/nncf/quantization/algorithms/weight_compression/gptq.py b/src/nncf/quantization/algorithms/weight_compression/gptq.py
index 48b7c28cc93..e4e511b0faf 100644
--- a/src/nncf/quantization/algorithms/weight_compression/gptq.py
+++ b/src/nncf/quantization/algorithms/weight_compression/gptq.py
@@ -12,6 +12,8 @@
 import math
 from typing import Optional, TypeVar
 
+import numpy as np
+
 import nncf
 from nncf import Dataset
 from nncf.common.graph import NNCFGraph
@@ -146,9 +148,13 @@ def apply(
             for batch_idx in range(hessian.shape[0]):
                 batch_hessian = hessian[batch_idx]
                 batch_weight = weight_tensor[batch_idx]
+                reduction_axes = wc_params.reduction_axes
+                assert len(reduction_axes) == 1, "2D reduction axes are not currently supported in GPTQ"
+                wc_params.reduction_axes = (reduction_axes[0] - 1,) if is_3d_weight else reduction_axes
                 batch_quantized_weight, batch_scale, batch_zero_point = self._quantize_weights(
                     wc_params, batch_hessian, batch_weight, input_tensors
                 )
+                wc_params.reduction_axes = reduction_axes
                 weights.append(batch_quantized_weight)
                 scales.append(batch_scale)
                 zero_points.append(batch_zero_point)
@@ -203,14 +209,14 @@ def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor], is_3d_weight:
             if node.layer_attributes.input_attributes["transpose"]:
                 msg = "Transposed input is not supported"
                 raise nncf.UnsupportedModelError(msg)
-
-        hessian = fns.zeros(
-            (inputs[0].shape[-1], inputs[0].shape[-1]), backend=inputs[0].backend, dtype=TensorDataType.float32
-        )
-
         # Make hessian 3D, such that for 2D weights it has only 1 batch and can be squeezed later.
         # For 3D weights this dimension matches the weight's batch dimension.
-        hessian = fns.unsqueeze(hessian, 0)
+        hessian_batch = 1 if not is_3d_weight else np.multiply.reduce(inputs[0].shape[:-2])
+        hessian = fns.zeros(
+            (hessian_batch, inputs[0].shape[-1], inputs[0].shape[-1]),
+            backend=inputs[0].backend,
+            dtype=TensorDataType.float32,
+        )
 
         for inp in inputs:
             is_3d_act = len(inp.shape) == 3

From 2c8dd510d75bb8d1b27d1d8be8ab8666e046c1cf Mon Sep 17 00:00:00 2001
From: anzr299
Date: Mon, 12 Jan 2026 13:01:21 +0400
Subject: [PATCH 4/6] set default value for is_3d_weights arg in calculate_hessian

---
 src/nncf/quantization/algorithms/weight_compression/gptq.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/nncf/quantization/algorithms/weight_compression/gptq.py b/src/nncf/quantization/algorithms/weight_compression/gptq.py
index e4e511b0faf..cf2cf06aa63 100644
--- a/src/nncf/quantization/algorithms/weight_compression/gptq.py
+++ b/src/nncf/quantization/algorithms/weight_compression/gptq.py
@@ -193,7 +193,7 @@ def get_statistic_points(
 
         return self._layerwise_engine.get_statistic_points(model, graph, filtered_nodes)
 
-    def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor], is_3d_weight: bool) -> Tensor:
+    def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor], is_3d_weight: bool = False) -> Tensor:
         """
         Calculates the Hessian matrix for the given node and inputs.
 
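The running update implemented by `_calculate_hessian` across patches 1-4 is the standard GPTQ estimate H ~= (2/n) * sum_i x_i x_i^T over the calibration activations, extended with a leading batch axis so that each expert of a 3D (MoE) weight accumulates its own Hessian. Below is a minimal plain-NumPy sketch of the same accumulation; `accumulate_hessian` and its argument layout are illustrative stand-ins, not the NNCF API:

import numpy as np

def accumulate_hessian(activations: list) -> np.ndarray:
    """Running GPTQ Hessian estimate H ~= (2 / n) * sum_i x_i @ x_i.T.

    Each element of `activations` is (tokens, hidden) for a 2D weight or
    (experts, tokens, hidden) for a 3D MoE weight; the result keeps a
    leading batch axis of size 1 or `experts`, matching the patched code.
    """
    experts = 1 if activations[0].ndim == 2 else activations[0].shape[0]
    hidden = activations[0].shape[-1]
    hessian = np.zeros((experts, hidden, hidden), dtype=np.float32)
    nsamples = 0
    for inp in activations:
        batch_size = 1 if inp.ndim == 2 else inp.shape[0]
        x = np.moveaxis(inp.astype(np.float32), -1, -2)  # (..., hidden, tokens)
        hessian *= nsamples / (nsamples + batch_size)  # downweight earlier samples
        nsamples += batch_size
        x = x * np.sqrt(2.0 / nsamples)
        hessian += x @ np.moveaxis(x, -1, -2)  # batched x @ x.T, one slice per expert
    return hessian

Note the rescale-then-add pattern: scaling the existing estimate by n/(n+b) before adding the new outer products keeps the Hessian a properly weighted average as calibration samples stream in.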
From 3d481854481083439e60bb253d0228b6c31c3889 Mon Sep 17 00:00:00 2001
From: anzr299
Date: Mon, 12 Jan 2026 13:39:15 +0400
Subject: [PATCH 5/6] fix gptq test

---
 .../quantization/algorithms/weight_compression/gptq.py | 7 +++++++
 tests/openvino/native/quantization/test_gptq.py        | 3 ++-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/nncf/quantization/algorithms/weight_compression/gptq.py b/src/nncf/quantization/algorithms/weight_compression/gptq.py
index cf2cf06aa63..4aaee65c718 100644
--- a/src/nncf/quantization/algorithms/weight_compression/gptq.py
+++ b/src/nncf/quantization/algorithms/weight_compression/gptq.py
@@ -18,6 +18,7 @@
 from nncf import Dataset
 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
+from nncf.common.logging import nncf_logger
 from nncf.common.logging.track_progress import track
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
 from nncf.common.utils.backend import BackendType
@@ -260,6 +261,12 @@ def _quantize_weights(
             msg = "Transpose is not supported"
             raise RuntimeError(msg)
 
+        if len(hessian.shape) == 3 and hessian.shape[0] == 1:
+            hessian = fns.squeeze(hessian)
+            msg = "The hessian passed to quantize_weights is 3D. It should be 2D"
+            nncf_logger.warning(msg=msg)
+        assert len(hessian.shape) == 2, "Hessian should be 2D"
+
         dead_indices = fns.diag(hessian) == 0
         hessian[dead_indices, dead_indices] = 1
         weight_tensor[:, dead_indices] = 0
diff --git a/tests/openvino/native/quantization/test_gptq.py b/tests/openvino/native/quantization/test_gptq.py
index daf36a29aec..e01ac3b53ca 100644
--- a/tests/openvino/native/quantization/test_gptq.py
+++ b/tests/openvino/native/quantization/test_gptq.py
@@ -365,7 +365,8 @@ def test_calculate_scale_linear():
     )
     wc_params.compression_config = WeightCompressionConfig(mode=CompressWeightsMode.INT4_SYM, group_size=16)
 
-    scale, _ = gptq._quantize_weights(ov_model, graph, wc_params, H, wrapped_inputs)
+    nncf_weight = Tensor(weights)
+    _, scale, _ = gptq._quantize_weights(wc_params, H, nncf_weight, wrapped_inputs)
     ref_scale = ref_scale.numpy()
     scale = scale.reshape(ref_scale.shape)
     assert np.all(np.isclose(ref_scale, scale.data))

From 8b46a1a05dab976783e5cc77633af44db01da056 Mon Sep 17 00:00:00 2001
From: anzr299
Date: Mon, 12 Jan 2026 17:35:24 +0400
Subject: [PATCH 6/6] pass only specific batch of inputs for MoE case

---
 src/nncf/quantization/algorithms/weight_compression/gptq.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/nncf/quantization/algorithms/weight_compression/gptq.py b/src/nncf/quantization/algorithms/weight_compression/gptq.py
index 4aaee65c718..f1bd63e184a 100644
--- a/src/nncf/quantization/algorithms/weight_compression/gptq.py
+++ b/src/nncf/quantization/algorithms/weight_compression/gptq.py
@@ -152,8 +152,9 @@ def apply(
                 reduction_axes = wc_params.reduction_axes
                 assert len(reduction_axes) == 1, "2D reduction axes are not currently supported in GPTQ"
                 wc_params.reduction_axes = (reduction_axes[0] - 1,) if is_3d_weight else reduction_axes
+                input_tensor = input_tensors[batch_idx] if is_3d_weight else input_tensors
                 batch_quantized_weight, batch_scale, batch_zero_point = self._quantize_weights(
-                    wc_params, batch_hessian, batch_weight, input_tensors
+                    wc_params, batch_hessian, batch_weight, input_tensor
                 )
                 wc_params.reduction_axes = reduction_axes
                 weights.append(batch_quantized_weight)
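Taken together, the series routes 3D (MoE) weights through the existing 2D GPTQ path one expert at a time: `apply` slices the weight and its Hessian along the leading axis, quantizes each slice, and stacks the per-expert quantized weights, scales, and zero points back up, with patch 6 also narrowing the activations to the matching expert. For a 2D weight a singleton leading axis is inserted and the loop runs once, so both layouts share one code path. A compact sketch of that driver shape, with `quantize_per_expert` and `quantize_2d` as hypothetical stand-ins for the patched `apply` loop and the per-expert `_quantize_weights` call:

from typing import Callable

import numpy as np

def quantize_per_expert(
    weight: np.ndarray,  # (experts, out_dim, in_dim)
    hessian: np.ndarray,  # (experts, in_dim, in_dim)
    quantize_2d: Callable,  # 2D path: (weight_2d, hessian_2d) -> (q, scale, zero_point)
):
    """Per-expert driver mirroring the patched `apply` loop (illustrative only)."""
    results = [quantize_2d(weight[e], hessian[e]) for e in range(weight.shape[0])]
    quantized, scales, zero_points = (np.stack(group) for group in zip(*results))
    return quantized, scales, zero_points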