78 changes: 59 additions & 19 deletions src/nncf/quantization/algorithms/weight_compression/gptq.py
@@ -12,10 +12,13 @@
import math
from typing import Optional, TypeVar

import numpy as np

import nncf
from nncf import Dataset
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
@@ -130,8 +133,37 @@ def apply(
raise nncf.UnsupportedModelError(msg)

_, input_tensors = next(iter(inputs.items()))
hessian = self._calculate_hessian(node, input_tensors)
scale, zero_point = self._quantize_weights(model, graph, wc_params, hessian, input_tensors)
weight_tensor = self._backend_entity.get_weight(
wc_params.node_with_weight, wc_params.weight_port_id, model, graph
)
weight_tensor = fns.astype(weight_tensor, TensorDataType.float32)

is_3d_weight = len(weight_tensor.shape) == 3

node = wc_params.node_with_weight
hessian = self._calculate_hessian(node, input_tensors, is_3d_weight)
weight_tensor = fns.unsqueeze(weight_tensor, 0) if not is_3d_weight else weight_tensor
scales = []
zero_points = []
weights = []
for batch_idx in range(hessian.shape[0]):
batch_hessian = hessian[batch_idx]
batch_weight = weight_tensor[batch_idx]
reduction_axes = wc_params.reduction_axes
assert len(reduction_axes) == 1, "2D reduction axes are not currently supported in GPTQ"
wc_params.reduction_axes = (reduction_axes[0] - 1,) if is_3d_weight else reduction_axes
input_tensor = input_tensors[batch_idx] if is_3d_weight else input_tensors
batch_quantized_weight, batch_scale, batch_zero_point = self._quantize_weights(
wc_params, batch_hessian, batch_weight, input_tensor
)
wc_params.reduction_axes = reduction_axes
weights.append(batch_quantized_weight)
scales.append(batch_scale)
zero_points.append(batch_zero_point)
scale = fns.stack(scales, axis=0) if is_3d_weight else scales[0]
zero_point = fns.stack(zero_points, axis=0) if is_3d_weight else zero_points[0]
weight = fns.stack(weights, axis=0) if is_3d_weight else weights[0]
self._backend_entity.set_weight(wc_params.node_with_weight, wc_params.weight_port_id, model, graph, weight)
res[wc_params.weight_name] = CompressedWeight(None, scale, zero_point, None)

return model, res
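
The `apply` hunk above treats every weight as a stack of 2D matrices: a 2D weight gets a singleton leading batch, while a 3D weight (e.g. stacked MoE experts) is quantized batch by batch and the per-batch results are stacked back together. A minimal NumPy sketch of that loop, where `quantize_2d` is a hypothetical stand-in for the 2D `_quantize_weights` path:

```python
import numpy as np

def quantize_3d(weight, hessian, quantize_2d):
    """weight: (B, out, in), hessian: (B, in, in); returns stacked 3D results."""
    weights, scales, zero_points = [], [], []
    for b in range(hessian.shape[0]):
        w_q, s, zp = quantize_2d(weight[b], hessian[b])
        weights.append(w_q)
        scales.append(s)
        zero_points.append(zp)
    # Mirrors fns.stack(..., axis=0) in the hunk above.
    return np.stack(weights), np.stack(scales), np.stack(zero_points)
```
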
@@ -163,7 +195,7 @@ def get_statistic_points(

return self._layerwise_engine.get_statistic_points(model, graph, filtered_nodes)

def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor:
def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor], is_3d_weight: bool = False) -> Tensor:
"""
Calculates the Hessian matrix for the given node and inputs.

@@ -179,30 +211,39 @@ def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor:
if node.layer_attributes.input_attributes["transpose"]:
msg = "Transposed input is not supported"
raise nncf.UnsupportedModelError(msg)

# Make the Hessian 3D, so that for 2D weights there is a single batch that can be squeezed later.
# For 3D weights this dimension matches the weight's batch dimension.
hessian_batch = 1 if not is_3d_weight else np.multiply.reduce(inputs[0].shape[:-2])
hessian = fns.zeros(
(inputs[0].shape[-1], inputs[0].shape[-1]), backend=inputs[0].backend, dtype=TensorDataType.float32
(hessian_batch, inputs[0].shape[-1], inputs[0].shape[-1]),
backend=inputs[0].backend,
dtype=TensorDataType.float32,
)

for inp in inputs:
batch_size = 1 if len(inp.shape) == 2 else inp.shape[0]
is_3d_act = len(inp.shape) == 3
# For the 3D-weight case the batch size is always 1: each "batch"/expert of the
# activation is treated as a single 2D matmul.
batch_size = 1 if not is_3d_act and not is_3d_weight else inp.shape[0]
if node.metatype in self._backend_entity.matmul_metatypes:
if len(inp.shape) == 3:
# For the 3D-activation + 2D-weight case, reshape the activation to 2D to match the weight.
# For the 3D-activation + 3D-weight case, keep it 3D: the last two dimensions hold the
# activation for each batch (0-th dimension).
if is_3d_act and not is_3d_weight:
inp = inp.reshape((-1, inp.shape[-1]))
inp = fns.transpose(inp)
inp = fns.moveaxis(inp, -1, -2)
hessian *= nsamples / (nsamples + batch_size)
nsamples += batch_size
inp = fns.astype(inp, TensorDataType.float32) * math.sqrt(2 / nsamples)
hessian += fns.matmul(inp, fns.transpose(inp))
hessian += fns.matmul(inp, fns.moveaxis(inp, -1, -2))

return hessian
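
`_calculate_hessian` maintains a running estimate of `H = (2/n) * Σ XᵀX` rather than summing and dividing at the end: the current estimate is down-weighted by `n/(n+b)` before each new batch, whose contribution is pre-scaled by `sqrt(2/n)`. A small NumPy sketch of one streaming update, assuming a 2D activation of shape `(tokens, in_features)`:

```python
import math
import numpy as np

def update_hessian(hessian, nsamples, inp, batch_size=1):
    # Re-weight the running average before folding in the new batch.
    hessian = hessian * (nsamples / (nsamples + batch_size))
    nsamples += batch_size
    # (in_features, tokens), pre-scaled so the product adds (2/n) * X.T @ X.
    x = inp.astype(np.float32).T * math.sqrt(2.0 / nsamples)
    return hessian + x @ x.T, nsamples
```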

def _quantize_weights(
self,
model: TModel,
graph: NNCFGraph,
wc_params: WeightCompressionParameters,
hessian: Tensor,
weight_tensor: Tensor,
inputs: list[Tensor],
):
"""
@@ -221,10 +262,11 @@ def _quantize_weights(
msg = "Transpose is not supported"
raise RuntimeError(msg)

weight_tensor = self._backend_entity.get_weight(
wc_params.node_with_weight, wc_params.weight_port_id, model, graph
)
weight_tensor = fns.astype(weight_tensor, TensorDataType.float32)
if len(hessian.shape) == 3 and hessian.shape[0] == 1:
hessian = fns.squeeze(hessian)
msg = "The Hessian passed to _quantize_weights is 3D; it should be 2D"
nncf_logger.warning(msg=msg)
assert len(hessian.shape) == 2, "Hessian should be 2D"

dead_indices = fns.diag(hessian) == 0
hessian[dead_indices, dead_indices] = 1
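
A zero on the Hessian diagonal marks an input channel that never activates in the calibration data. Pinning those diagonal entries to 1 keeps the Cholesky-based inversion that follows well-posed; the reference GPTQ implementation also zeroes the matching weight columns, shown here for illustration:

```python
import numpy as np

H = np.array([[2.0, 0.0],
              [0.0, 0.0]])  # second input channel never activates
W = np.array([[0.5, 0.3]])

dead = np.diag(H) == 0
H[dead, dead] = 1.0          # keep the factorization non-singular
W[:, dead] = 0.0             # reference GPTQ also drops dead weight columns
```
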
@@ -278,6 +320,7 @@ def _quantize_weights(
else:
if self._scale_estimation and block_compression_config.num_bits == 4:
activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
# TODO(anazir): Make it work for 3D weights
wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
scale, zero_point = ScaleEstimation.calculate_quantization_params(
wc_statistics,
@@ -323,9 +366,6 @@ def _quantize_weights(
weight_tensor[:, i2:] -= fns.matmul(error_block, hessian_inv[i1:i2, i2:])

quantized_tensor = quantized_tensor.reshape(weight_tensor.shape).astype(weight_tensor.dtype)
self._backend_entity.set_weight(
wc_params.node_with_weight, wc_params.weight_port_id, model, graph, quantized_tensor
)

scales = fns.stack(scales, axis=1)
if wc_params.compression_config.group_size == -1:
@@ -339,4 +379,4 @@
zero_points = fns.squeeze(zero_points, axis=-1)
else:
zero_points = None
return scales, zero_points
return weight_tensor, scales, zero_points
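
With this change `_quantize_weights` no longer writes into the model itself: it returns the quantized weight alongside the scales and zero points, and `apply` stacks the per-batch pieces before calling `set_weight` once. A sketch of the updated calling convention (variable names are hypothetical, mirroring the test change below):

```python
# H is a 2D Hessian and nncf_weight a 2D weight Tensor, as the
# assert in _quantize_weights now enforces.
quantized_weight, scales, zero_points = gptq._quantize_weights(
    wc_params, H, nncf_weight, wrapped_inputs
)
```
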
3 changes: 2 additions & 1 deletion tests/openvino/native/quantization/test_gptq.py
@@ -365,7 +365,8 @@ def test_calculate_scale_linear():
)
wc_params.compression_config = WeightCompressionConfig(mode=CompressWeightsMode.INT4_SYM, group_size=16)

scale, _ = gptq._quantize_weights(ov_model, graph, wc_params, H, wrapped_inputs)
nncf_weight = Tensor(weights)
_, scale, _ = gptq._quantize_weights(wc_params, H, nncf_weight, wrapped_inputs)
ref_scale = ref_scale.numpy()
scale = scale.reshape(ref_scale.shape)
assert np.all(np.isclose(ref_scale, scale.data))