From 6a3baac1a2de42dec9ebb3bf9fe82764ae32fb4a Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Tue, 13 Jan 2026 19:43:31 +0100
Subject: [PATCH 1/2] [WC] Scale Estimation transpose_a support

---
 .../algorithms/weight_compression/gptq.py     | 20 ++++++++++-
 .../weight_compression/scale_estimation.py    | 33 +++++++------------
 .../template_test_weights_compression.py      | 22 ++++++++-----
 tests/onnx/common.py                          |  9 +++++
 .../quantization/test_weights_compression.py  | 16 +++++++--
 tests/openvino/native/models.py               | 16 ++++++---
 .../quantization/test_weights_compression.py  | 17 ++++++----
 .../quantization/test_weights_compression.py  |  4 +--
 tests/torch/fx/test_weights_compression.py    |  4 +--
 9 files changed, 92 insertions(+), 49 deletions(-)

diff --git a/src/nncf/quantization/algorithms/weight_compression/gptq.py b/src/nncf/quantization/algorithms/weight_compression/gptq.py
index aeb32adede1..c9a0397ead0 100644
--- a/src/nncf/quantization/algorithms/weight_compression/gptq.py
+++ b/src/nncf/quantization/algorithms/weight_compression/gptq.py
@@ -18,6 +18,7 @@
 from nncf.common.graph import NNCFNode
 from nncf.common.logging.track_progress import track
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
+from nncf.common.tensor_statistics.statistics import WCTensorStatistic
 from nncf.common.utils.backend import BackendType
 from nncf.common.utils.backend import get_backend
 from nncf.parameters import CompressWeightsMode
@@ -278,7 +279,7 @@ def _quantize_weights(
             else:
                 if self._scale_estimation and block_compression_config.num_bits == 4:
                     activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
-                    wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
+                    wc_statistics = self.activations_to_wc_statistics(activations)
                     scale, zero_point = ScaleEstimation.calculate_quantization_params(
                         wc_statistics,
                         weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
@@ -340,3 +341,20 @@ def _quantize_weights(
         else:
             zero_points = None
         return scales, zero_points
+
+    @staticmethod
+    def activations_to_wc_statistics(activations: list[Tensor]) -> WCTensorStatistic:
+        """
+        Mimic the activation reducing logic from WeightCompression.get_statistic_points.
+
+        :param activations: List of raw activations.
+        :return: Instance of WCTensorStatistic class containing reduced activations and shapes.
+        """
+        mean_values = []
+        shapes = []
+        for act in activations:
+            shapes.append(act.shape)
+            reduction_shape = tuple(range(act.ndim - 1))
+            mean_values.append(fns.mean(act, axis=reduction_shape))
+        wc_statistics = WCTensorStatistic(mean_values, shapes)
+        return wc_statistics
diff --git a/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py
index d953a284c06..1c0cc457316 100644
--- a/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py
+++ b/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py
@@ -139,17 +139,21 @@ def apply(
                 continue
             _, weight_port_id = weight_data[0]
 
-            if self._backend_entity.matmul_has_transposed_activations(wp.node_with_weight, graph):
-                msg = "Transposed activations are not supported yet for the Scale Estimation algorithm"
-                raise nncf.UnsupportedModelError(msg)
-
             weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph)
+
+            activation_port_id = self._backend_entity.get_activation_port_id(wp.node_with_weight, graph)
+            act_shape = graph.get_input_edge_by_port_id(wp.node_with_weight, activation_port_id).tensor_shape
+            act_ch_axis = self._backend_entity.get_activation_channel_axis(
+                wp.node_with_weight, activation_port_id, act_shape
+            )
+            act_ch_axis %= len(act_shape)
+
             scale, zero_point = self.calculate_quantization_params(
                 stats,
                 weight,
                 wp.reduction_axes,
                 config,
+                act_ch_axis,
                 self._subset_size,
                 self._initial_steps,
                 self._scale_steps,
@@ -165,6 +169,7 @@ def calculate_quantization_params(
         weight: Tensor,
         reduction_axes: tuple[int, ...],
         config: WeightCompressionConfig,
+        act_ch_axis: int = -1,
         subset_size: int = 32,
         initial_steps: int = 5,
         scale_steps: int = 10,
@@ -185,6 +190,7 @@ def calculate_quantization_params(
         :param weight: The weight tensor that is being quantized.
         :param reduction_axes: Tuple specifying the axes along which the reduction is performed for quantization.
         :param config: Configuration parameters for the weight compression, including quantization settings.
+        :param act_ch_axis: The activation channel axis.
         :param subset_size: The number of samples to use for scale estimation. Defaults to 32.
         :param initial_steps: The number of steps for initial scale rectification using activation statistics.
             Defaults to 5.
@@ -195,7 +201,7 @@ def calculate_quantization_params(
         """
         reduction_axis = reduction_axes[0]
 
-        s, X = process_stats(statistics, subset_size)
+        s, X = process_stats(statistics, subset_size, act_ch_axis=act_ch_axis)
         X = X.astype(TensorDataType.float32)
         weight = weight.astype(TensorDataType.float32)
 
@@ -382,23 +388,6 @@ def calculate_quantization_params(
 
         return result_scale, zp
 
-    @staticmethod
-    def activations_to_wc_statistics(activations: list[Tensor]) -> WCTensorStatistic:
-        """
-        Mimic the activation reducing logic from WeightCompression.get_statistic_points.
-
-        :param activations: List of raw activations.
-        :return: Instance of WCTensorStatistic class containing reduced activations and shapes.
-        """
-        mean_values = []
-        shapes = []
-        for act in activations:
-            shapes.append(act.shape)
-            reduction_shape = tuple(range(act.ndim - 1))
-            mean_values.append(fns.mean(act, axis=reduction_shape))
-        wc_statistics = WCTensorStatistic(mean_values, shapes)
-        return wc_statistics
-
 
 def get_target_zero_mask(compressed_weights: Tensor, zp: Optional[Tensor] = None) -> tuple[Tensor, Tensor]:
     """
diff --git a/tests/cross_fw/test_templates/template_test_weights_compression.py b/tests/cross_fw/test_templates/template_test_weights_compression.py
index 290e74dd46b..d565267a3f6 100644
--- a/tests/cross_fw/test_templates/template_test_weights_compression.py
+++ b/tests/cross_fw/test_templates/template_test_weights_compression.py
@@ -236,14 +236,14 @@ def test_mixed_precision(self, mode, all_layers, ratio, ref_ids, transpose_a, tr
 
     @staticmethod
     @abstractmethod
-    def get_model_for_test_scale_estimation() -> TModel:
+    def get_model_for_test_scale_estimation(transpose_a: bool) -> TModel:
         """
         Returns a backend model for test_scale_estimation.
         """
 
     @staticmethod
     @abstractmethod
-    def get_moe_model_for_test_scale_estimation() -> TModel:
+    def get_moe_model_for_test_scale_estimation(transpose_a: bool) -> TModel:
         """
         Returns a backend MoE model for test_scale_estimation with 3D weights.
         """
@@ -266,17 +266,24 @@ def get_scale_estimation_ref(check_sampling_activation_stats_flow: bool) -> TTen
         Returns the reference output of calculate_quantization_params of ScaleEstimation.
         """
 
+    @pytest.mark.parametrize("transpose_a", [False, True])
     @pytest.mark.parametrize("is_moe", [False, True])
     @pytest.mark.parametrize("check_sampling_activation_stats_flow", [False, True])
-    def test_scale_estimation(self, mocker, is_moe, check_sampling_activation_stats_flow):
+    def test_scale_estimation(
+        self, mocker, transpose_a, is_moe, check_sampling_activation_stats_flow, transpose_a_supported
+    ):
         """Checks that scales match the reference."""
+        if transpose_a and not transpose_a_supported:
+            msg = "transpose_a is not supported for the current backend"
+            pytest.skip(msg)
+
         calc_q_params_spy = mocker.spy(ScaleEstimation, "calculate_quantization_params")
         if is_moe:
-            model = self.get_moe_model_for_test_scale_estimation()
+            model = self.get_moe_model_for_test_scale_estimation(transpose_a=transpose_a)
             input = np.arange(0, 2 * 4 * 8, dtype=np.float32).reshape(2, 4, 8)
         else:
-            model = self.get_model_for_test_scale_estimation()
+            model = self.get_model_for_test_scale_estimation(transpose_a=transpose_a)
             input = np.arange(0, 4 * 8, dtype=np.float32).reshape(1, 4, 8)
 
         # prepare dataset of size subset_size with input tensors
@@ -325,7 +332,7 @@ def get_decompressed_weight(compressed_model: TModel, input: TTensor) -> Tensor:
     def test_scale_estimation_outlier_channel_has_lowest_error(self, mocker):
         """Checks that outlier channel has a lowest error after quantization."""
         OUTLIER_CHANNEL = 4
-        model = self.get_model_for_test_scale_estimation()
+        model = self.get_model_for_test_scale_estimation(transpose_a=False)
         original_weight = self.get_orig_weight(model)
 
         # prepare dataset with one input tensor
@@ -801,7 +808,6 @@ def get_transposable_awq_model(
     @pytest.mark.parametrize(
         "kwargs",
         [
-            dict(scale_estimation=True),
             dict(lora_correction=True),
             dict(
                 gptq=True,
@@ -812,8 +818,6 @@ def get_transposable_awq_model(
     def test_compression_skipped_with_transposed_activations(self, transpose_a_supported, kwargs):
         if not transpose_a_supported:
             pytest.skip("transpose_a is not supported for the current backend")
-        if kwargs.get("scale_estimation", False) and "scale_estimation" in self.get_not_supported_algorithms():
-            pytest.skip("Scale estimation is not supported")
         if kwargs.get("gptq", False) and "gptq" in self.get_not_supported_algorithms():
             pytest.skip("GPTQ is not supported")
         if kwargs.get("lora_correction", False) and "lora_correction" in self.get_not_supported_algorithms():
diff --git a/tests/onnx/common.py b/tests/onnx/common.py
index 7584c6afa70..dfeeb0bbe45 100644
--- a/tests/onnx/common.py
+++ b/tests/onnx/common.py
@@ -223,6 +223,15 @@ def add_constant(self, data: np.ndarray, output: Optional[str] = None) -> str:
 
         return output
 
+    def add_squeeze(self, input: str, output: Optional[str] = None) -> str:
+        i = len(self._nodes)
+
+        output = f"Squeeze_{i}_output" if output is None else output
+        self._nodes.append(
+            onnx.helper.make_node(op_type="Squeeze", inputs=[input], outputs=[output], name=f"Squeeze_{i}")
+        )
+        return output
+
     def add_unsqueeze(self, input: str, axes: tuple[int, ...], output: Optional[str] = None) -> str:
         i = len(self._nodes)
 
diff --git a/tests/onnx/quantization/test_weights_compression.py b/tests/onnx/quantization/test_weights_compression.py
index 5f8cf66966b..fa3e51121d9 100644
--- a/tests/onnx/quantization/test_weights_compression.py
+++ b/tests/onnx/quantization/test_weights_compression.py
@@ -493,23 +493,33 @@ def wrap_model(model: onnx.ModelProto, data: Any) -> onnx.ModelProto:
         return model
 
     @staticmethod
-    def get_model_for_test_scale_estimation() -> onnx.ModelProto:
+    def get_model_for_test_scale_estimation(transpose_a) -> onnx.ModelProto:
         """
         Builds a model to be used in the following tests:
         - TemplateWeightCompression.test_scale_estimation()
         - TemplateWeightCompression.test_scale_estimation_outlier_channel_has_lowest_error() tests.
         """
+
         mb = ModelBuilder()
         x = mb.add_input("input", (1, 4, 8))
         output = mb.add_output("output", (1, 4, 16))
 
         weights = np.arange(0, 16 * 8, dtype=np.float32).reshape(16, 8).T
-        mb.add_matmul(x, shape=(8, 16), output=output, data=weights)
+        if transpose_a:
+            squeeze = mb.add_squeeze(x)
+            transpose = mb.add_transpose(squeeze, (1, 0))
+            mb.add_gemm(transpose, shape=(8, 16), output=output, weight_data=weights, trans_a=1)
+        else:
+            mb.add_matmul(x, shape=(8, 16), output=output, data=weights)
 
         return mb.build(opset_version=21)
 
     @staticmethod
-    def get_moe_model_for_test_scale_estimation() -> onnx.ModelProto:
+    def get_moe_model_for_test_scale_estimation(transpose_a: bool) -> onnx.ModelProto:
+        if transpose_a:
+            msg = "ONNX does not support transpose_a + MoE"
+            pytest.skip(msg)
+
         num_experts = 2
         hidden_dim = 8
         out_dim = 16
diff --git a/tests/openvino/native/models.py b/tests/openvino/native/models.py
index 14e30180274..2c2e910e19c 100644
--- a/tests/openvino/native/models.py
+++ b/tests/openvino/native/models.py
@@ -78,14 +78,18 @@ def _create_ov_model(self, input_shape=None, reshape_shape=None, matmul_w_shape=
 
 
 class SimpleMoEModel(OVReferenceModel):
-    def _create_ov_model(self, num_experts=2, hidden_dim=8, out_dim=16, seq_len=4):
+    def _create_ov_model(self, num_experts=2, hidden_dim=8, out_dim=16, seq_len=4, tranpsose_a: bool = False):
         input_shape = [num_experts, seq_len, hidden_dim]
         input_1 = opset.parameter(input_shape, name="Input")
 
         weight_data = np.arange(0, num_experts * hidden_dim * out_dim, dtype=np.float32)
         weight_data = weight_data.reshape(num_experts, hidden_dim, out_dim)
-        matmul = opset.matmul(input_1, weight_data, transpose_a=False, transpose_b=False, name="MoE_MatMul")
+        if tranpsose_a:
+            transpose = opset.transpose(input_1, (0, 2, 1))
+        else:
+            transpose = input_1
+        matmul = opset.matmul(transpose, weight_data, transpose_a=False, transpose_b=False, name="MoE_MatMul")
 
         result = opset.result(matmul, name="Result")
         result.get_output_tensor(0).set_names(set(["Result"]))
@@ -1366,13 +1370,17 @@ def _create_ov_model(self):
 
 
 class MatMul(OVReferenceModel):
-    def _create_ov_model(self):
+    def _create_ov_model(self, transpose_a: bool = False):
         input_node = opset.parameter([1, 4, 8], name="Input")
 
         weights_data = np.arange(0, 16 * 8, dtype=np.float32).reshape(16, 8)
         weights_node = opset.constant(weights_data, dtype=np.float32, name="Weights")
-        matmul_node = opset.matmul(input_node, weights_node, transpose_a=False, transpose_b=True, name="MatMul")
+        if transpose_a:
+            transpose = opset.transpose(input_node, (0, 2, 1))
+        else:
+            transpose = input_node
+        matmul_node = opset.matmul(transpose, weights_node, transpose_a=transpose_a, transpose_b=True, name="MatMul")
 
         result_node = opset.result(matmul_node, name="Result")
 
diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py
index 9d2ae47c892..cc3bfe4e01e 100644
--- a/tests/openvino/native/quantization/test_weights_compression.py
+++ b/tests/openvino/native/quantization/test_weights_compression.py
@@ -2165,12 +2165,12 @@ def get_sequential_matmul_model(transpose_a: bool) -> ov.Model:
         return SequentialMatmulModel(transpose_a=transpose_a).ov_model
 
     @staticmethod
-    def get_model_for_test_scale_estimation():
-        return MatMul().ov_model
+    def get_model_for_test_scale_estimation(transpose_a: bool):
+        return MatMul(transpose_a=transpose_a).ov_model
 
     @staticmethod
-    def get_moe_model_for_test_scale_estimation():
-        return SimpleMoEModel().ov_model
+    def get_moe_model_for_test_scale_estimation(transpose_a: bool):
+        return SimpleMoEModel(transpose_a=transpose_a).ov_model
 
     @staticmethod
     def get_awq_model(non_mergable_pattern: bool, is_3d_weights: bool) -> ov.Model:
@@ -2365,10 +2365,15 @@ def get_moe_scale_estimation_ref(check_sampling_activation_stats_flow):
             ),
         )[check_sampling_activation_stats_flow]
 
+    @pytest.mark.parametrize("transpose_a", [False, True])
     @pytest.mark.parametrize("is_moe", [False, pytest.param(True, marks=pytest.mark.xfail(reason="Ticket - 176465"))])
     @pytest.mark.parametrize("check_sampling_activation_stats_flow", [False, True])
-    def test_scale_estimation(self, mocker, is_moe, check_sampling_activation_stats_flow):
-        return super().test_scale_estimation(mocker, is_moe, check_sampling_activation_stats_flow)
+    def test_scale_estimation(
+        self, mocker, transpose_a, is_moe, check_sampling_activation_stats_flow, transpose_a_supported
+    ):
+        return super().test_scale_estimation(
+            mocker, transpose_a, is_moe, check_sampling_activation_stats_flow, transpose_a_supported
+        )
 
     @pytest.mark.parametrize(
         "is_3d_weights", [False, pytest.param(True, marks=pytest.mark.xfail(reason="Ticket - 176465"))]
diff --git a/tests/torch/function_hook/quantization/test_weights_compression.py b/tests/torch/function_hook/quantization/test_weights_compression.py
index d66cfaf13d9..efbddfd5cc5 100644
--- a/tests/torch/function_hook/quantization/test_weights_compression.py
+++ b/tests/torch/function_hook/quantization/test_weights_compression.py
@@ -583,11 +583,11 @@ def get_sequential_matmul_model(transpose_a: bool) -> torch.nn.Module:
         return SequentialMatmulModel()
 
     @staticmethod
-    def get_model_for_test_scale_estimation():
+    def get_model_for_test_scale_estimation(transpose_a: bool):
         return LinearModel(torch.arange(0, 8 * 16, dtype=torch.float32).reshape(16, 8))
 
     @staticmethod
-    def get_moe_model_for_test_scale_estimation():
+    def get_moe_model_for_test_scale_estimation(transpose_a: bool):
         num_experts = 2
         hidden_dim = 8
         out_dim = 16
diff --git a/tests/torch/fx/test_weights_compression.py b/tests/torch/fx/test_weights_compression.py
index ac5b96756b0..0278501a189 100644
--- a/tests/torch/fx/test_weights_compression.py
+++ b/tests/torch/fx/test_weights_compression.py
@@ -347,14 +347,14 @@ def get_sequential_matmul_model(transpose_a: bool) -> torch.fx.GraphModule:
         return exported_model
 
     @staticmethod
-    def get_model_for_test_scale_estimation():
+    def get_model_for_test_scale_estimation(transpose_a: bool):
         model = LinearModel(torch.arange(0, 8 * 16, dtype=torch.float32).reshape(16, 8))
         ex_input = torch.ones([1, 4, 8], dtype=torch.float32)
         exported_model = get_torch_fx_model(model, ex_input)
         return exported_model
 
     @staticmethod
-    def get_moe_model_for_test_scale_estimation():
+    def get_moe_model_for_test_scale_estimation(transpose_a: bool):
         num_experts = 2
         hidden_dim = 8
         out_dim = 16

From 1f70ea19aa3aadb34475fdd7cbbee3b86c409079 Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Thu, 15 Jan 2026 14:18:48 +0100
Subject: [PATCH 2/2] Comments

---
 .../algorithms/weight_compression/awq.py      | 14 +++-----------
 .../algorithms/weight_compression/backend.py  | 17 +++++++++++++++++
 .../weight_compression/scale_estimation.py    |  8 +-------
 tests/openvino/native/models.py               |  6 +++---
 4 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/src/nncf/quantization/algorithms/weight_compression/awq.py b/src/nncf/quantization/algorithms/weight_compression/awq.py
index a31011cdb5f..926d3c91509 100644
--- a/src/nncf/quantization/algorithms/weight_compression/awq.py
+++ b/src/nncf/quantization/algorithms/weight_compression/awq.py
@@ -167,7 +167,9 @@ def apply(
             weight_dtype = weight.dtype
             weight = weight.astype(TensorDataType.float32)
 
-            act_ch_axis, act_shape = self._get_activation_channel_axis_and_shape(graph, wp)
+            act_ch_axis, act_shape = self._backend_entity.get_activation_channel_axis_and_shape(
+                graph, wp.node_with_weight
+            )
 
             is_mergeable = False
             if self._backend_entity.is_node_with_weights(merge_node, graph):
@@ -356,16 +358,6 @@ def _data_aware_step(self, wp, weight, statistics, act_ch_axis, prev_weight=None
 
         return scale
 
-    def _get_activation_channel_axis_and_shape(
-        self, graph: NNCFGraph, wp: WeightCompressionParameters
-    ) -> tuple[int, tuple[int, ...]]:
-        activation_port_id = self._backend_entity.get_activation_port_id(wp.node_with_weight, graph)
-        act_shape = graph.get_input_edge_by_port_id(wp.node_with_weight, activation_port_id).tensor_shape
-        act_ch_axis = self._backend_entity.get_activation_channel_axis(
-            wp.node_with_weight, activation_port_id, act_shape
-        )
-        return act_ch_axis % len(act_shape), act_shape
-
     @staticmethod
     def _clamp_scale(magnitudes, threshold, scale, clamped_scale):
         return fns.where(magnitudes < threshold, scale, clamped_scale)
diff --git a/src/nncf/quantization/algorithms/weight_compression/backend.py b/src/nncf/quantization/algorithms/weight_compression/backend.py
index d1ddf8f99dc..877d3641cb1 100644
--- a/src/nncf/quantization/algorithms/weight_compression/backend.py
+++ b/src/nncf/quantization/algorithms/weight_compression/backend.py
@@ -296,6 +296,23 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple
         :return: Channel axis number.
""" + def get_activation_channel_axis_and_shape( + self, + graph: NNCFGraph, + node: NNCFNode, + ) -> tuple[int, tuple[int, ...]]: + """ + Returns the activation channel axis and activation tensor shape for a node with weights. + + :param graph: NNCF graph containing the model structure and tensor metadata. + :param node: NNCF node with weights instance. + :return: A tuple consisting of activation channel axis and activation tensor shape. + """ + activation_port_id = self.get_activation_port_id(node, graph) + act_shape = graph.get_input_edge_by_port_id(node, activation_port_id).tensor_shape + act_ch_axis = self.get_activation_channel_axis(node, activation_port_id, act_shape) + return act_ch_axis % len(act_shape), act_shape + class AWQAlgoBackend(WeightCompressionAlgoBackend): @staticmethod diff --git a/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py index 1c0cc457316..0c64d4770ac 100644 --- a/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -140,13 +140,7 @@ def apply( _, weight_port_id = weight_data[0] weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph) - - activation_port_id = self._backend_entity.get_activation_port_id(wp.node_with_weight, graph) - act_shape = graph.get_input_edge_by_port_id(wp.node_with_weight, activation_port_id).tensor_shape - act_ch_axis = self._backend_entity.get_activation_channel_axis( - wp.node_with_weight, activation_port_id, act_shape - ) - act_ch_axis %= len(act_shape) + act_ch_axis, _ = self._backend_entity.get_activation_channel_axis_and_shape(graph, wp.node_with_weight) scale, zero_point = self.calculate_quantization_params( stats, diff --git a/tests/openvino/native/models.py b/tests/openvino/native/models.py index 2c2e910e19c..e607f39f281 100644 --- a/tests/openvino/native/models.py +++ b/tests/openvino/native/models.py @@ -78,18 +78,18 @@ def _create_ov_model(self, input_shape=None, reshape_shape=None, matmul_w_shape= class SimpleMoEModel(OVReferenceModel): - def _create_ov_model(self, num_experts=2, hidden_dim=8, out_dim=16, seq_len=4, tranpsose_a: bool = False): + def _create_ov_model(self, num_experts=2, hidden_dim=8, out_dim=16, seq_len=4, transpose_a: bool = False): input_shape = [num_experts, seq_len, hidden_dim] input_1 = opset.parameter(input_shape, name="Input") weight_data = np.arange(0, num_experts * hidden_dim * out_dim, dtype=np.float32) weight_data = weight_data.reshape(num_experts, hidden_dim, out_dim) - if tranpsose_a: + if transpose_a: transpose = opset.transpose(input_1, (0, 2, 1)) else: transpose = input_1 - matmul = opset.matmul(transpose, weight_data, transpose_a=False, transpose_b=False, name="MoE_MatMul") + matmul = opset.matmul(transpose, weight_data, transpose_a=transpose_a, transpose_b=False, name="MoE_MatMul") result = opset.result(matmul, name="Result") result.get_output_tensor(0).set_names(set(["Result"]))