From 5dd9645d4e1b16eb4b44aad51dd6344eb3d411ea Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Thu, 6 Nov 2025 18:14:44 +0100
Subject: [PATCH 1/4] [Rebase] Support transposed input for data-aware Weights Compression

---
 src/nncf/openvino/graph/nncf_graph_builder.py |  2 +-
 .../weight_compression/algorithm.py           | 30 ++++++++++------
 .../algorithms/weight_compression/backend.py  | 11 ++++++
 .../algorithms/weight_compression/gptq.py     | 35 +++++++++++--------
 .../weight_compression/mixed_precision.py     |  2 +-
 .../weight_compression/onnx_backend.py        |  5 +++
 .../weight_compression/openvino_backend.py    | 18 +++++++---
 .../weight_compression/scale_estimation.py    |  5 ++-
 .../weight_compression/torch_backend.py       |  4 +++
 .../weight_compression/torch_fx_backend.py    |  4 +++
 .../openvino/native/quantization/test_gptq.py |  7 ++--
 .../quantization/test_weights_compression.py  | 28 ++++++++++++---
 12 files changed, 111 insertions(+), 40 deletions(-)

diff --git a/src/nncf/openvino/graph/nncf_graph_builder.py b/src/nncf/openvino/graph/nncf_graph_builder.py
index edc6e4018eb..1eb9a3540fa 100644
--- a/src/nncf/openvino/graph/nncf_graph_builder.py
+++ b/src/nncf/openvino/graph/nncf_graph_builder.py
@@ -101,7 +101,7 @@ def _add_edges_to_nncf_graph(model: ov.Model, graph: NNCFGraph) -> None:
         in_node_id = graph.get_node_by_name(op.get_friendly_name()).node_id
         for output_port_id, out in enumerate(op.outputs()):
             node_vs_target_inputs = defaultdict(list)
-            for inp in out.get_target_inputs():
+            for inp in sorted(out.get_target_inputs(), key=lambda inp: inp.get_node().get_friendly_name()):
                 node_vs_target_inputs[inp.get_node()].append(inp)
 
             for out_node, inputs in node_vs_target_inputs.items():
diff --git a/src/nncf/quantization/algorithms/weight_compression/algorithm.py b/src/nncf/quantization/algorithms/weight_compression/algorithm.py
index b126374741a..421b0207697 100644
--- a/src/nncf/quantization/algorithms/weight_compression/algorithm.py
+++ b/src/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -1028,19 +1028,22 @@ def apply(
         )
         return transformed_model
 
-    def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -> tuple[NNCFNode, int]:
+    def _get_activation_node_port_and_channel(self, node: NNCFNode, nncf_graph: NNCFGraph) -> tuple[NNCFNode, int, int]:
         """
-        This method returns the activation layer and corresponding port id for the node.
+        This method returns the activation layer, corresponding port id and channel axis for the given node.
 
         :param node: NNCFGraph node for which the activation is sought.
         :param nncf_graph: NNCFGraph instance with the node.
-        :return: Tuple with the activation node and port id.
+        :return: Tuple with the activation node, port id and channel axis.
""" activation_port = self._backend_entity.get_activation_port_id(node, nncf_graph) activation_edge = nncf_graph.get_input_edge_by_port_id(node, activation_port) activation_node = activation_edge.from_node port_id = activation_edge.output_port_id - return activation_node, port_id + activation_channel_axis = self._backend_entity.get_activation_channel_axis( + node, port_id, activation_edge.tensor_shape + ) + return activation_node, port_id, activation_channel_axis def get_matmul_input_to_output_nodes_map( self, matmul_nodes: list[NNCFNode], graph: NNCFGraph @@ -1061,8 +1064,8 @@ def get_matmul_input_to_output_nodes_map( """ matmul_input_to_output_nodes_map = defaultdict(list) for node in matmul_nodes: - act_node, output_port_id = self._get_activation_node_and_port(node, graph) - matmul_input_to_output_nodes_map[(act_node, output_port_id)].append(node) + act_node, output_port_id, act_channel_axis = self._get_activation_node_port_and_channel(node, graph) + matmul_input_to_output_nodes_map[(act_node, output_port_id, act_channel_axis)].append(node) return matmul_input_to_output_nodes_map def get_compression_nodes_info( @@ -1130,7 +1133,11 @@ def get_statistic_points( # Statistics for data aware algorithms if self._data_aware_compression: - for (node, output_port_id), node_with_weights in matmul_input_to_output_nodes_map.items(): + for ( + node, + output_port_id, + input_channel_axis, + ), node_with_weights in matmul_input_to_output_nodes_map.items(): statistic_point = self._backend_entity.target_point( TargetType.POST_LAYER_OPERATION, node.node_name, port_id=output_port_id ) @@ -1145,10 +1152,11 @@ def get_statistic_points( ] all_weight_dims.extend(weight_dims) - # by default, reduce activations across all but the last dimension. The last dimension is - # assumed to be the hidden size dimension. + # Reduce activations across all but the hidden dimension. n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape) - reduction_axes = tuple(range(n_dims - 1)) + # negative axis (e.g. -1 for the last axis) is converted into corresponding positive value + input_channel_axis = input_channel_axis % n_dims + reduction_axes = tuple(i for i in range(n_dims) if i != input_channel_axis) # For 3D weights, hidden dimension is the second dimension. Reduce by all other dimensions reduction_axes = (1,) if any(weight_dim == 3 for weight_dim in all_weight_dims) else reduction_axes @@ -1191,7 +1199,7 @@ def _get_statistics_for_weights_compression( # Where mean_value is a 1D tensor representing an activation reduced over batch and sequence length dimensions, # shape is an original shape of an activation before reduction, n is the size of the dataset (or subset_size). statistics = {} - for (act_node, output_port_id), matmul_nodes in matmul_input_to_output_nodes_map.items(): + for (act_node, output_port_id, _), matmul_nodes in matmul_input_to_output_nodes_map.items(): tensor_collectors = list( statistic_points.get_algo_statistics_for_node( act_node.node_name, diff --git a/src/nncf/quantization/algorithms/weight_compression/backend.py b/src/nncf/quantization/algorithms/weight_compression/backend.py index d429b127152..aa9cd5131b5 100644 --- a/src/nncf/quantization/algorithms/weight_compression/backend.py +++ b/src/nncf/quantization/algorithms/weight_compression/backend.py @@ -274,6 +274,17 @@ def get_ignored_patterns() -> GraphPattern: :return: backend-specific ignored patterns. 
""" + @staticmethod + @abstractmethod + def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int: + """ + Returns axis number of the activation tensor which correspond to it channel. + :param node: NNCFNode instance. + :param port_id: Port ID for input. + :param input_shape: Shape of the input. + :return: Channel axis number. + """ + class AWQAlgoBackend(WeightCompressionAlgoBackend): @staticmethod diff --git a/src/nncf/quantization/algorithms/weight_compression/gptq.py b/src/nncf/quantization/algorithms/weight_compression/gptq.py index b90f2e0574b..d7e2919c04f 100644 --- a/src/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/src/nncf/quantization/algorithms/weight_compression/gptq.py @@ -125,10 +125,14 @@ def apply( ]: continue _, input_tensors = next(iter(inputs.items())) - hessian = self._calculate_hessian(node, input_tensors) - scale, zero_point = self._quantize_weights(model, graph, wc_params, hessian, input_tensors) + input_channel_axis = self._backend_entity.get_activation_channel_axis( + node, self._backend_entity.get_activation_port_id(node, graph), input_tensors[0].shape + ) + hessian = self._calculate_hessian(node, input_tensors, input_channel_axis) + scale, zero_point = self._quantize_weights( + model, graph, wc_params, hessian, input_tensors, input_channel_axis + ) res[wc_params.weight_name] = CompressedWeight(None, scale, zero_point, None) - return model, res def get_statistic_points( @@ -158,7 +162,7 @@ def get_statistic_points( return self._layerwise_engine.get_statistic_points(model, graph, filtered_nodes) - def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor: + def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor], input_channel_axis: int) -> Tensor: """ Calculates the Hessian matrix for the given node and inputs. @@ -171,19 +175,18 @@ def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor: if node.metatype in self._backend_entity.convolution_metatypes: msg = "Convolution metatypes are not supported" raise nncf.UnsupportedModelError(msg) - if node.layer_attributes.input_attributes["transpose"]: - msg = "Transposed input is not supported" - raise nncf.UnsupportedModelError(msg) hessian = fns.zeros( - (inputs[0].shape[-1], inputs[0].shape[-1]), backend=inputs[0].backend, dtype=TensorDataType.float32 + (inputs[0].shape[input_channel_axis], inputs[0].shape[input_channel_axis]), + backend=inputs[0].backend, + dtype=TensorDataType.float32, ) for inp in inputs: batch_size = 1 if len(inp.shape) == 2 else inp.shape[0] if node.metatype in self._backend_entity.matmul_metatypes: if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.reshape((-1, inp.shape[input_channel_axis])) inp = fns.transpose(inp) hessian *= nsamples / (nsamples + batch_size) nsamples += batch_size @@ -199,6 +202,7 @@ def _quantize_weights( wc_params: WeightCompressionParameters, hessian: Tensor, inputs: list[Tensor], + input_channel_axis: int, ): """ Quantizes the weights of the model based on the calculated Hessian matrix. 
@@ -211,10 +215,7 @@ def _quantize_weights( """ if wc_params.node_with_weight.metatype in self._backend_entity.convolution_metatypes: msg = "Convolution metatypes are not supported" - raise RuntimeError(msg) - if not wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id]["transpose"]: - msg = "Transpose is not supported" - raise RuntimeError(msg) + raise nncf.UnsupportedModelError(msg) weight_tensor = self._backend_entity.get_weight( wc_params.node_with_weight, wc_params.weight_port_id, model, graph @@ -272,8 +273,12 @@ def _quantize_weights( scales.append(scale) else: if self._scale_estimation and block_compression_config.num_bits == 4: - activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs] - wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations) + slicing_along_axis = [slice(None)] * len(inputs[0].shape) + slicing_along_axis[input_channel_axis] = slice(i1 + i, i1 + i + group_size) + activations = [inp[tuple(slicing_along_axis)] for inp in inputs] + wc_statistics = ScaleEstimation.activations_to_wc_statistics( + activations, input_channel_axis + ) scale, zero_point = ScaleEstimation.calculate_quantization_params( wc_statistics, weight_tensor[:, (i1 + i) : (i1 + i + group_size)], diff --git a/src/nncf/quantization/algorithms/weight_compression/mixed_precision.py b/src/nncf/quantization/algorithms/weight_compression/mixed_precision.py index fdb51b8d69c..c4b3223fc3e 100644 --- a/src/nncf/quantization/algorithms/weight_compression/mixed_precision.py +++ b/src/nncf/quantization/algorithms/weight_compression/mixed_precision.py @@ -281,7 +281,7 @@ def get_statistic_points( self._set_backend_entity(model) statistic_container = StatisticPointsContainer() - for act_node, output_port_id in nodes_and_port_ids: + for act_node, output_port_id, _ in nodes_and_port_ids: n_dims = len(graph.get_output_edges_by_port_id(act_node, output_port_id)[0].tensor_shape) if n_dims < 2: msg = ( diff --git a/src/nncf/quantization/algorithms/weight_compression/onnx_backend.py b/src/nncf/quantization/algorithms/weight_compression/onnx_backend.py index 735ba9a2a3e..49f9c9afdfb 100644 --- a/src/nncf/quantization/algorithms/weight_compression/onnx_backend.py +++ b/src/nncf/quantization/algorithms/weight_compression/onnx_backend.py @@ -38,6 +38,7 @@ from nncf.onnx.graph.model_transformer import remove_initializer from nncf.onnx.graph.model_transformer import remove_node from nncf.onnx.graph.model_transformer import set_initializer +from nncf.onnx.graph.node_utils import get_act_quantization_axis from nncf.onnx.graph.node_utils import get_weight_quantization_axis from nncf.onnx.graph.onnx_helper import ONNX_DTYPE_TO_NNCF_DTYPE from nncf.onnx.graph.onnx_helper import get_name_to_node_map @@ -301,6 +302,10 @@ def filter_func(point: StatisticPoint) -> bool: return filter_func + @staticmethod + def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int: + return get_act_quantization_axis(node, port_id) + def insert_adapters( self, wc_params: WeightCompressionParameters, lora_A: Tensor, lora_B: Tensor, int8_lora: bool ) -> None: diff --git a/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py index 3ec241b36c6..ea3bb55724c 100644 --- a/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -35,6 +35,7 @@ from 
nncf.openvino.graph.node_utils import convert_op from nncf.openvino.graph.node_utils import create_ov_codebook_subgraph from nncf.openvino.graph.node_utils import create_ov_const_from_tensor +from nncf.openvino.graph.node_utils import get_activation_channel_axis from nncf.openvino.graph.node_utils import get_const_value_as_numpy_tensor from nncf.openvino.graph.node_utils import get_const_value_as_ov_tensor from nncf.openvino.graph.node_utils import get_weight_channel_axes @@ -118,9 +119,6 @@ def mean_statistic_collector( @staticmethod def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int: - if node.layer_attributes.input_attributes["transpose"]: - msg = "Transposed input is not supported" - raise nncf.UnsupportedModelError(msg) constant_ports = node.layer_attributes.get_const_port_ids() activation_ports = [ e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in constant_ports @@ -137,6 +135,9 @@ def get_weight_names_and_port_ids(node: NNCFNode, graph: NNCFGraph) -> list[tupl return result def get_weight(self, node_with_weight: NNCFNode, weight_port_id: int, model: ov.Model, graph: NNCFGraph) -> Tensor: + if not node_with_weight.layer_attributes.constant_attributes[weight_port_id]["transpose"]: + msg = "Only transposed weights are supported" + raise nncf.UnsupportedModelError(msg) weight_name = node_with_weight.layer_attributes.constant_attributes[weight_port_id]["name"] weight_node = self.name_to_node_mapping[weight_name] weight_tensor = get_const_value_as_numpy_tensor(weight_node) @@ -203,7 +204,12 @@ def insert_adapters( A_W = opset.constant(lora_A.data) B_W = opset.constant(lora_B.data) - A_MM = opset.matmul(input_node, A_W, transpose_a=False, transpose_b=True) + A_MM = opset.matmul( + input_node, + A_W, + transpose_a=wc_params.node_with_weight.layer_attributes.input_attributes["transpose"], + transpose_b=True, + ) B_MM = opset.matmul(A_MM, B_W, transpose_a=False, transpose_b=True) node_output_port = mm_node.output(0) @@ -399,6 +405,10 @@ def get_ignored_patterns() -> GraphPattern: pattern.add_pattern_alternative(create_sam_pe()) return pattern + @staticmethod + def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int: + return get_activation_channel_axis(node, port_id, input_shape) + class OVTensorWeightCompressionAlgoBackend(OVWeightCompressionAlgoBackend): """ diff --git a/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py index 4ad557b9868..b5fd8b4ef70 100644 --- a/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -382,7 +382,7 @@ def calculate_quantization_params( return result_scale, zp @staticmethod - def activations_to_wc_statistics(activations: list[Tensor]) -> WCTensorStatistic: + def activations_to_wc_statistics(activations: list[Tensor], input_channel_axis: int) -> WCTensorStatistic: """ Mimic the activation reducing logic from WeightCompression.get_statistic_points. @@ -393,6 +393,9 @@ def activations_to_wc_statistics(activations: list[Tensor]) -> WCTensorStatistic shapes = [] for act in activations: shapes.append(act.shape) + # negative axis (e.g. 
-1 for the last axis) is converted into corresponding positive value + input_channel_axis = input_channel_axis % len(act.shape) + reduction_shape = tuple(i for i in range(len(act.shape)) if i != input_channel_axis) reduction_shape = tuple(range(act.ndim - 1)) mean_values.append(fns.mean(act, axis=reduction_shape)) wc_statistics = WCTensorStatistic(mean_values, shapes) diff --git a/src/nncf/quantization/algorithms/weight_compression/torch_backend.py b/src/nncf/quantization/algorithms/weight_compression/torch_backend.py index a153e3e85f1..d3da1c02639 100644 --- a/src/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/src/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -486,6 +486,10 @@ def get_ignored_patterns() -> GraphPattern: pattern.add_pattern_alternative(create_sam_pe()) return pattern + @staticmethod + def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int: + return node.metatype.output_channel_axis + class PTAWQAlgoAlgoBackend(AWQAlgoBackend, PTWeightCompressionAlgoBackend): @staticmethod diff --git a/src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py b/src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py index f1c1a49a269..2be7099d670 100644 --- a/src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py +++ b/src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py @@ -262,6 +262,10 @@ def get_ignored_patterns() -> GraphPattern: pattern.add_pattern_alternative(create_sam_pe()) return pattern + @staticmethod + def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int: + return node.metatype.output_channel_axis + class FXMixedPrecisionAlgoBackend(MixedPrecisionAlgoBackend, FXWeightCompressionAlgoBackend): pass diff --git a/tests/openvino/native/quantization/test_gptq.py b/tests/openvino/native/quantization/test_gptq.py index 28439cc276e..42b43dc5bea 100644 --- a/tests/openvino/native/quantization/test_gptq.py +++ b/tests/openvino/native/quantization/test_gptq.py @@ -350,7 +350,10 @@ def test_calculate_scale_linear(): nodes = graph.get_all_nodes() wrapped_inputs = [Tensor(inp) for inp in inputs] - H = gptq._calculate_hessian(nodes[1], wrapped_inputs) + input_channel_axis = gptq._backend_entity.get_activation_channel_axis( + nodes[1], gptq._backend_entity.get_activation_port_id(nodes[1], graph), wrapped_inputs[0].shape + ) + H = gptq._calculate_hessian(nodes[1], wrapped_inputs, input_channel_axis) ref_H = ref_gptq.H.numpy() assert np.all(np.isclose(ref_H, H.data)) @@ -365,7 +368,7 @@ def test_calculate_scale_linear(): ) wc_params.compression_config = WeightCompressionConfig(mode=CompressWeightsMode.INT4_SYM, group_size=16) - scale, _ = gptq._quantize_weights(ov_model, graph, wc_params, H, wrapped_inputs) + scale, _ = gptq._quantize_weights(ov_model, graph, wc_params, H, wrapped_inputs, input_channel_axis) ref_scale = ref_scale.numpy() scale = scale.reshape(ref_scale.shape) assert np.all(np.isclose(ref_scale, scale.data)) diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 02c113f722f..4c7dd1c9ef5 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -12,7 +12,8 @@ import inspect import os from collections import defaultdict -from typing import Callable +from contextlib import nullcontext +from typing import Callable, 
Optional from unittest.mock import patch import numpy as np @@ -104,7 +105,9 @@ class LMLinearModel(OVReferenceModel): HIDDEN_DIM = 16 INPUT_SHAPE = [1, 24, HIDDEN_DIM] # [B, SeqLen, HiddenDim] - def _create_ov_model(self, transpose_b: bool = True, transpose_a=False, input_shape=None): + def _create_ov_model( + self, transpose_b: bool = True, transpose_a: bool = False, input_shape: Optional[list[int]] = None + ): self._input_shape = self.INPUT_SHAPE if input_shape is None else input_shape hdim_axis = -2 if transpose_a else -1 self._hidden_dim = self._input_shape[hdim_axis] @@ -1940,6 +1943,16 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs): ) +@pytest.mark.parametrize( + ("transpose_a", "transpose_b", "raises_error"), + [ + (False, True, False), + (True, True, False), + (False, False, True), + (True, False, True), + ], + ids=["tb_nota", "ta_tb", "nota_notb", "ta_notb"], +) @pytest.mark.parametrize( "kwargs", [ @@ -1952,14 +1965,19 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs): advanced_parameters=CompressionParams(gptq_params=GPTQParams(subset_size=2)), ), ], + ids=["se", "lora", "gptq_se_awq"], ) -def test_compression_with_transposed_activations(kwargs): +def test_compression_with_transpose(transpose_a, transpose_b, raises_error, kwargs): dataset_size = 4 - model = LMLinearModel(transpose_a=True, transpose_b=False).ov_model + model = LMLinearModel(transpose_a=transpose_a, transpose_b=transpose_b).ov_model input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size dataset = Dataset(input_data) - with pytest.raises(nncf.UnsupportedModelError): + with ( + pytest.raises(nncf.UnsupportedModelError) + if raises_error and not kwargs.get("lora_correction", False) + else nullcontext() + ): compress_weights( model, mode=CompressWeightsMode.INT4_SYM, From 075b9f7a51d1c48ba7bb6eae882238eed546c7d3 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 7 Nov 2025 13:28:16 +0100 Subject: [PATCH 2/4] Fix rebase --- .../algorithms/weight_compression/scale_estimation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py index b5fd8b4ef70..e107597f181 100644 --- a/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -396,7 +396,6 @@ def activations_to_wc_statistics(activations: list[Tensor], input_channel_axis: # negative axis (e.g. 
-1 for the last axis) is converted into corresponding positive value input_channel_axis = input_channel_axis % len(act.shape) reduction_shape = tuple(i for i in range(len(act.shape)) if i != input_channel_axis) - reduction_shape = tuple(range(act.ndim - 1)) mean_values.append(fns.mean(act, axis=reduction_shape)) wc_statistics = WCTensorStatistic(mean_values, shapes) return wc_statistics From 353f50f0bbffc989a35a9216f9da653e4ee73ef5 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 7 Nov 2025 14:23:31 +0100 Subject: [PATCH 3/4] Comments --- tests/openvino/native/models.py | 36 +++++++++++++++++++ .../quantization/test_weights_compression.py | 13 +++++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/tests/openvino/native/models.py b/tests/openvino/native/models.py index 0e0ef99b40a..51d2ae039c0 100644 --- a/tests/openvino/native/models.py +++ b/tests/openvino/native/models.py @@ -1084,6 +1084,42 @@ def _create_ov_model(self, is_int8=False, with_multiply=False, n_layers=8): return model +class AWQModel(OVReferenceModel): + OUTPUT_DIM = 32 + HIDDEN_DIM = 16 + INPUT_SHAPE = [1, 24, HIDDEN_DIM] # [B, SeqLen, HiddenDim] + + def _create_ov_model( + self, + transpose_b: bool = True, + transpose_a: bool = False, + input_shape: Optional[list[int]] = None, + is_int8=False, + ): + self._input_shape = self.INPUT_SHAPE if input_shape is None else input_shape + hdim_axis = -2 if transpose_a else -1 + self._hidden_dim = self._input_shape[hdim_axis] + input_1 = opset.parameter(self._input_shape, name="Input") + weight_shape = self.get_weight_shape(transpose_b) + data = self._rng.random(weight_shape).astype(np.float32) + + weights = AWQMatmulModel.get_weights(data, is_int8=is_int8, name="weights_1") + + matmul = opset.matmul(input_1, weights, transpose_a=transpose_a, transpose_b=transpose_b, name="MatMul") + + result = opset.result(matmul, name="Result") + result.get_output_tensor(0).set_names(set(["Result"])) + model = ov.Model([result], [input_1]) + return model + + @property + def hidden_dim(self): + return self._hidden_dim + + def get_weight_shape(self, transpose_b: bool = True): + return [self.OUTPUT_DIM, self.hidden_dim] if transpose_b else [self.hidden_dim, self.OUTPUT_DIM] + + class AWQModel_fp16_overlow(OVReferenceModel): """ Model for testing AWQ algorithm with fp16 overflow fix. 
diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 4c7dd1c9ef5..bd1063fa5aa 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -65,6 +65,7 @@ from tests.openvino.native.common import get_actual_reference_for_current_openvino from tests.openvino.native.models import AWQActMatmulModel from tests.openvino.native.models import AWQMatmulModel +from tests.openvino.native.models import AWQModel from tests.openvino.native.models import AWQModel_fp16_overlow from tests.openvino.native.models import DifferentChannelSizeMatmulModel from tests.openvino.native.models import GatherAndMatmulShareData @@ -1943,6 +1944,14 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs): ) +@pytest.mark.parametrize( + "model_cls", + [ + (LMLinearModel), + (AWQModel), + ], + ids=["lm_linear", "awq_model"], +) @pytest.mark.parametrize( ("transpose_a", "transpose_b", "raises_error"), [ @@ -1967,9 +1976,9 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs): ], ids=["se", "lora", "gptq_se_awq"], ) -def test_compression_with_transpose(transpose_a, transpose_b, raises_error, kwargs): +def test_compression_with_transpose(model_cls, transpose_a, transpose_b, raises_error, kwargs): dataset_size = 4 - model = LMLinearModel(transpose_a=transpose_a, transpose_b=transpose_b).ov_model + model = model_cls(transpose_a=transpose_a, transpose_b=transpose_b).ov_model input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size dataset = Dataset(input_data) From c1f35a2d5a6722a3a2ef8f204982887e4943d227 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Mon, 24 Nov 2025 16:37:35 +0100 Subject: [PATCH 4/4] get_activation_channel_axis for torch backend --- .../weight_compression/torch_backend.py | 3 +- .../weight_compression/torch_fx_backend.py | 3 +- src/nncf/torch/node_utils.py | 42 +++++++++++++++++++ tests/torch/test_node_utils.py | 42 +++++++++++++++++++ 4 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 src/nncf/torch/node_utils.py create mode 100644 tests/torch/test_node_utils.py diff --git a/src/nncf/quantization/algorithms/weight_compression/torch_backend.py b/src/nncf/quantization/algorithms/weight_compression/torch_backend.py index d3da1c02639..73ed2d0b475 100644 --- a/src/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/src/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -61,6 +61,7 @@ from nncf.torch.model_graph_manager import split_const_name from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.node_utils import get_activation_channel_axis as get_activation_channel_axis_util from nncf.torch.quantization.ignored_patterns import create_rope from nncf.torch.quantization.ignored_patterns import create_sam_pe from nncf.torch.quantization.layers import QUANTIZATION_MODULES @@ -488,7 +489,7 @@ def get_ignored_patterns() -> GraphPattern: @staticmethod def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int: - return node.metatype.output_channel_axis + return get_activation_channel_axis_util(node, port_id) class PTAWQAlgoAlgoBackend(AWQAlgoBackend, PTWeightCompressionAlgoBackend): diff --git a/src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py 
b/src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py
index 2be7099d670..7b9ba9706fd 100644
--- a/src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py
+++ b/src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py
@@ -56,6 +56,7 @@
 from nncf.torch.model_graph_manager import get_const_node
 from nncf.torch.model_graph_manager import get_weight_compression_reduction_axes
 from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
+from nncf.torch.node_utils import get_activation_channel_axis as get_activation_channel_axis_util
 from nncf.torch.quantization.ignored_patterns import create_rope
 from nncf.torch.quantization.ignored_patterns import create_sam_pe
 from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
@@ -264,7 +265,7 @@ def get_ignored_patterns() -> GraphPattern:
 
     @staticmethod
     def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
-        return node.metatype.output_channel_axis
+        return get_activation_channel_axis_util(node, port_id)
 
 
 class FXMixedPrecisionAlgoBackend(MixedPrecisionAlgoBackend, FXWeightCompressionAlgoBackend):
diff --git a/src/nncf/torch/node_utils.py b/src/nncf/torch/node_utils.py
new file mode 100644
index 00000000000..3b139248d7c
--- /dev/null
+++ b/src/nncf/torch/node_utils.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import nncf
+import nncf.torch.graph.operator_metatypes as op
+from nncf.common.graph import NNCFNode
+from nncf.torch.graph.operator_metatypes import PTAddmmMetatype
+from nncf.torch.graph.operator_metatypes import PTMatMulMetatype
+
+
+def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
+    """
+    Returns the axis of the activation tensor which corresponds to its channel.
+
+    :param node: NNCFNode instance.
+    :param port_id: Port ID for input.
+    :return: Channel axis number.
+    """
+    if node.metatype not in op.CONVOLUTION_METATYPES + op.MATMUL_METATYPES + op.UNIFICATION_PRODUCING_METATYPES:
+        msg = f"Activation channel axis retrieval from node with metatype {node.metatype} is not supported"
+        raise nncf.InternalError(msg)
+
+    if node.metatype not in [PTMatMulMetatype, PTAddmmMetatype]:
+        return node.metatype.output_channel_axis
+
+    if port_id == 0:
+        # X(port:0) * W(port:1): [..., C_IN] * [..., C_IN, C_OUT]
+        return -1
+    if port_id == 1:
+        # W(port:0) * X(port:1): [..., C_OUT, C_IN] * [..., C_IN, ...]
+        return -2
+
+    msg = f"Port id for a {node.metatype} operation is expected to be in [0, 1], {port_id} received"
+    raise nncf.InternalError(msg)
diff --git a/tests/torch/test_node_utils.py b/tests/torch/test_node_utils.py
new file mode 100644
index 00000000000..f878c3e22a2
--- /dev/null
+++ b/tests/torch/test_node_utils.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import nncf +import nncf.torch.graph.operator_metatypes as op +from nncf.common.graph import NNCFNode +from nncf.torch.node_utils import get_activation_channel_axis + + +@pytest.mark.parametrize( + "metatype,port_id,ref_out", + ( + (op.PTLinearMetatype, 0, -1), + (op.PTModuleLinearMetatype, 0, -1), + (op.PTConv2dMetatype, 0, 1), + (op.PTDepthwiseConv2dSubtype, 0, 1), + (op.PTConvTranspose2dMetatype, 0, 1), + (op.PTMatMulMetatype, 0, -1), + (op.PTMatMulMetatype, 1, -2), + (op.PTAddmmMetatype, 0, -1), + (op.PTAddmmMetatype, 1, -2), + (op.PTMatMulMetatype, 2, "error"), + (op.PTAddMetatype, 0, "error"), + ), +) +def test_get_activation_channel_axis(metatype, port_id, ref_out): + node = NNCFNode({"metatype": metatype}) + if ref_out == "error": + with pytest.raises(nncf.InternalError): + get_activation_channel_axis(node, port_id) + else: + assert get_activation_channel_axis(node, port_id) == ref_out
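
Note (illustrative, not part of the patches): the reduction-axis computation that patch 1 introduces in algorithm.py generalizes the previous "reduce over all but the last dimension" behavior to an arbitrary activation channel axis. Below is a minimal standalone sketch of that logic, assuming the 3-D [B, SeqLen, Hidden] layouts used by LMLinearModel; the helper name reduction_axes_for is hypothetical and exists only for this example.

    def reduction_axes_for(shape: tuple[int, ...], channel_axis: int) -> tuple[int, ...]:
        # Normalize a possibly negative channel axis (e.g. -1 -> n_dims - 1),
        # then reduce over every other dimension.
        n_dims = len(shape)
        channel_axis = channel_axis % n_dims
        return tuple(i for i in range(n_dims) if i != channel_axis)

    # transpose_a=False: hidden dim is last -> matches the old tuple(range(n_dims - 1))
    assert reduction_axes_for((1, 24, 16), -1) == (0, 1)
    # transpose_a=True: hidden dim is second-to-last -> batch and sequence axes remain
    assert reduction_axes_for((1, 16, 24), -2) == (0, 2)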