diff --git a/nncf/quantization/advanced_parameters.py b/nncf/quantization/advanced_parameters.py
index 37658878b81..10f18b34eae 100644
--- a/nncf/quantization/advanced_parameters.py
+++ b/nncf/quantization/advanced_parameters.py
@@ -276,6 +276,9 @@ class AdvancedAWQParameters:
     :type alpha_max: float
     :param steps: The number of the steps in grid search.
     :type steps: int
+    :param prefer_data_aware_scaling: Determines whether to use activations to calculate scales if
+        activations are present.
+    :type prefer_data_aware_scaling: bool
     """

     subset_size: int = 32
@@ -283,6 +286,7 @@ class AdvancedAWQParameters:
     alpha_min: float = 0.0
     alpha_max: float = 1.0
     steps: int = 100
+    prefer_data_aware_scaling: bool = True


 @api()
diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py
index 79363722934..9ac54e144b9 100644
--- a/nncf/quantization/algorithms/weight_compression/algorithm.py
+++ b/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -302,6 +302,7 @@ def __init__(
                 awq_params.alpha_min,
                 awq_params.alpha_max,
                 awq_params.steps,
+                awq_params.prefer_data_aware_scaling,
             )
         if self._gptq:
             gptq_params = self._advanced_parameters.gptq_params
@@ -323,7 +324,12 @@ def __init__(
         self._data_aware_mixed_precision = (
             self._sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR and self._ratio != 1.0
         )
-        self._data_aware_compression = self._awq or self._scale_estimation or self._lora_correction or self._gptq
+        self._data_aware_compression = (
+            (self._awq and self._advanced_parameters.awq_params.prefer_data_aware_scaling)
+            or self._scale_estimation
+            or self._lora_correction
+            or self._gptq
+        )

     @property
     def available_backends(self) -> list[BackendType]:
@@ -546,7 +552,7 @@ def apply(
         nodes_to_compress = self.get_nodes_to_compress(graph)

         statistics = None
-        if self._data_aware_mixed_precision or self._data_aware_compression:
+        if (self._data_aware_mixed_precision or self._data_aware_compression) and dataset:
             matmul_nodes_to_compress = [
                 node for node in nodes_to_compress if node.metatype in self._backend_entity.matmul_metatypes
             ]
diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py
index 3c614d228d9..1e3d09706f9 100644
--- a/nncf/quantization/algorithms/weight_compression/awq.py
+++ b/nncf/quantization/algorithms/weight_compression/awq.py
@@ -66,6 +66,7 @@ def __init__(
         alpha_min: float = 0.0,
         alpha_max: float = 1.0,
         steps: int = 100,
+        prefer_data_aware_scaling: bool = True,
     ):
         """
         :param subset_size: The number of samples for AWQ.
@@ -73,6 +74,7 @@ def __init__(
         :param percent_to_apply: The percent of outliers for correction.
         :param alpha_min: Minimum value of smoothness parameter for grid search.
         :param alpha_max: Maximal value of smoothness parameter for grid search.
         :param steps: The number of the steps in grid search.
+        :param prefer_data_aware_scaling: Determines whether to use activations to calculate scales.
""" super().__init__() self._subset_size = subset_size @@ -80,6 +82,7 @@ def __init__( self._alpha_min = alpha_min self._alpha_max = alpha_max self._steps = steps + self._prefer_data_aware_scaling = prefer_data_aware_scaling self._backend_entity = None self._patterns = None self._scale_per_target_node = {} @@ -121,7 +124,7 @@ def apply( graph: NNCFGraph, all_weight_params: list[WeightCompressionParameters], nodes_to_compress: list[NNCFNode], - statistics: dict[str, WCTensorStatistic], + statistics: Optional[dict[str, WCTensorStatistic]] = None, wc_backend_entity: Optional[WeightCompressionAlgoBackend] = None, ) -> TModel: """ @@ -135,156 +138,41 @@ def apply( :return: A resulting model. """ self._set_backend_entity(model, wc_backend_entity) - matches = [] - inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], [], []) - nx_graph = inference_nncf_graph.get_nx_graph_copy() - for pattern_graph in self._patterns.values(): - matches.extend(find_subgraphs_matching_pattern(nx_graph, pattern_graph(), strict=False)) - - if len(matches) == 0: - nncf_logger.info("No matching patterns were found for applying AWQ algorithm, it will be skipped.") + awq_data = self._get_awq_data(graph, all_weight_params, nodes_to_compress) + if len(awq_data) == 0: return model transformation_layout = TransformationLayout() model_transformer = ModelTransformerFactory.create(model, inplace=True) - awq_data = {} - name_mapping = {wp.weight_name: idx for idx, wp in enumerate(all_weight_params)} - - for match in matches: - nncf_node = graph.get_node_by_key(match[-1]) - if not self._backend_entity.is_node_with_weights(nncf_node, graph): - continue - - target_node_names = [] - for weight_op_friendly_name, _ in self._backend_entity.get_weight_names_and_port_ids(nncf_node, graph): - target_node_names.append(weight_op_friendly_name) - - # skip node if it is in IgnoredScope or should not be compressed - if target_node_names[-1] not in name_mapping: - continue - - weight_params = all_weight_params[name_mapping[target_node_names[-1]]] - - if weight_params.compression_config.num_bits != 4: - continue - target_node = nodes_to_compress[name_mapping[target_node_names[-1]]] - - # avoid matching different patterns for the same node - if target_node.node_name in awq_data: - continue + is_data_free = statistics is None or not self._prefer_data_aware_scaling - nncf_node = graph.get_node_by_key(match[0]) + description = "Applying data-free AWQ" if is_data_free else "Applying data-aware AWQ" - if self._backend_entity.is_node_with_weights(nncf_node, graph): # pattern MatMul->Multiply->MatMul - merge_node_names = [] - for weight_op_friendly_name, _ in self._backend_entity.get_weight_names_and_port_ids(nncf_node, graph): - merge_node_names.append(weight_op_friendly_name) - merge_node = nodes_to_compress[name_mapping[merge_node_names[-1]]] - else: # pattern Act->MatMul or Act->Multiply->MatMul - merge_node = nncf_node - - awq_data[target_node.node_name] = AWQCompressionInfo(weight_params, target_node, merge_node) - - alpha_step = (self._alpha_max - self._alpha_min) / self._steps - - for k, awq_data_item in track(awq_data.items(), description="Applying AWQ"): + for k, awq_data_item in track(awq_data.items(), description=description): wp = awq_data_item.weight_params - target_node = awq_data_item.target_node merge_node = awq_data_item.merge_node weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph) if len(weight_data) != 1: # not supported by the algorithm continue - 
nncf_logger.debug(f"Apply AWQ for: {wp.node_with_weight.node_name}") + nncf_logger.debug(f"{description} for: {wp.node_with_weight.node_name}") _, weight_port_id = weight_data[0] - - config = wp.compression_config - - s, X = process_stats(statistics[k], self._subset_size) - s = s.astype(TensorDataType.float32) - X = X.astype(TensorDataType.float32) - - top_k = max(int(s.shape[0] * self._percent_to_apply), 1) - topk_idxs = fns.argsort(-s)[:top_k] - - group_size = config.group_size - if group_size == -1: - group_size = s.shape[0] - - groups_to_correct = set() - for idx in topk_idxs: - groups_to_correct.add(idx.data // group_size) - - groups_to_correct = list(groups_to_correct) - weight = self._backend_entity.get_weight( wp.node_with_weight, weight_port_id, model, graph ) # get_const_value(wp.weight_node) weight_dtype = weight.dtype weight = weight.astype(TensorDataType.float32) - assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1 - reduction_axis = wp.reduction_axes[0] - - if reduction_axis == 0: - weight = fns.transpose(weight) - reduction_axis = 1 - - shape_vector = fns.mean(X, axis=1) - scale = fns.ones_like(shape_vector) - - awq_config = deepcopy(config) - awq_config.group_size = -1 - - for gi in groups_to_correct: - offset = gi * group_size - gscale = s[offset : offset + group_size] - - a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32) - a_max = 1e2 - gscale = fns.clip(gscale, a_min=a_min, a_max=a_max) - - gweight = weight[:, offset : offset + group_size] - gacts = X[offset : offset + group_size, :] - - fp32_out = fns.matmul(gweight, gacts) - min_diff = fns.max(fns.abs(fp32_out)) - best_scale = None - - alpha = self._alpha_min - for _ in range(self._steps): - cur_scale = gscale**alpha - weights_to_fake_quantize = gweight * cur_scale - if config.mode == CompressWeightsMode.NF4: - g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis) - g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale) - g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale) - else: - g_decompressed_weighs = quantize_dequantize_weight( - weights_to_fake_quantize, awq_config, reduction_axis - ) - sacts = gacts / fns.unsqueeze(cur_scale, 1) - - cur_out = fns.matmul(g_decompressed_weighs, sacts) - cur_diff = fns.mean(fns.abs(cur_out - fp32_out)) - if cur_diff < min_diff: - min_diff = cur_diff - best_scale = cur_scale - alpha += alpha_step - - if best_scale is not None: - scale.data[offset : offset + group_size] = best_scale.data - - a_scale = scale - w_scale = scale - if wp.reduction_axes[0] == 0: - w_scale = fns.unsqueeze(w_scale, 1) - a_scale = fns.unsqueeze(1.0 / a_scale, 0) + + if is_data_free: + scale = self._data_free_step(weight) else: - w_scale = fns.unsqueeze(w_scale, 0) - a_scale = fns.unsqueeze(1.0 / a_scale, 1) + scale = self._data_aware_step(wp, weight, statistics[k]) + + w_scale = fns.unsqueeze(scale, 1 - wp.reduction_axes[0]) + a_scale = fns.unsqueeze(1.0 / scale, wp.reduction_axes[0]) scaled_weight = (weight * w_scale).astype(weight_dtype) self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight) @@ -312,7 +200,148 @@ def apply( return transformed_model + def _data_aware_step(self, wp, weight, statistics): + alpha_step = (self._alpha_max - self._alpha_min) / self._steps + config = wp.compression_config + s, X = process_stats(statistics, self._subset_size) + s = s.astype(TensorDataType.float32) + X = X.astype(TensorDataType.float32) + + top_k = 
max(int(s.shape[0] * self._percent_to_apply), 1) + topk_idxs = fns.argsort(-s)[:top_k] + + group_size = config.group_size + if group_size == -1: + group_size = s.shape[0] + + groups_to_correct = set() + for idx in topk_idxs: + groups_to_correct.add(idx.data // group_size) + + groups_to_correct = list(groups_to_correct) + + assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1 + reduction_axis = wp.reduction_axes[0] + + if reduction_axis == 0: + weight = fns.transpose(weight) + reduction_axis = 1 + + shape_vector = fns.mean(X, axis=1) + scale = fns.ones_like(shape_vector) + + awq_config = deepcopy(config) + awq_config.group_size = -1 + + for gi in groups_to_correct: + offset = gi * group_size + gscale = s[offset : offset + group_size] + + a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32) + a_max = 1e2 + gscale = fns.clip(gscale, a_min=a_min, a_max=a_max) + + gweight = weight[:, offset : offset + group_size] + gacts = X[offset : offset + group_size, :] + + fp32_out = fns.matmul(gweight, gacts) + min_diff = fns.max(fns.abs(fp32_out)) + best_scale = None + + alpha = self._alpha_min + for _ in range(self._steps): + cur_scale = gscale**alpha + weights_to_fake_quantize = gweight * cur_scale + if config.mode == CompressWeightsMode.NF4: + g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis) + g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale) + g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale) + else: + g_decompressed_weighs = quantize_dequantize_weight( + weights_to_fake_quantize, awq_config, reduction_axis + ) + sacts = gacts / fns.unsqueeze(cur_scale, 1) + + cur_out = fns.matmul(g_decompressed_weighs, sacts) + cur_diff = fns.mean(fns.abs(cur_out - fp32_out)) + if cur_diff < min_diff: + min_diff = cur_diff + best_scale = cur_scale + alpha += alpha_step + + if best_scale is not None: + scale.data[offset : offset + group_size] = best_scale.data + + return scale + + def _data_free_step(self, weight): + eps = fns.finfo(weight).eps + scale = fns.maximum(fns.mean(fns.abs(weight), axis=0), eps) + return 1 / scale + + def _get_awq_data( + self, graph: NNCFGraph, all_weight_params: list[WeightCompressionParameters], nodes_to_compress: list[NNCFNode] + ) -> dict[str, AWQCompressionInfo]: + """ + Finds awq patterns in graph and returns it. + :param graph: Model graph. + :param all_weight_params: list of all weight parameters. + :param nodes_to_compress: list of nodes for processing. + :return: A dict with node names and matched AWQ patterns. 
+ """ + matches = [] + inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], [], []) + nx_graph = inference_nncf_graph.get_nx_graph_copy() + for pattern_graph in self._patterns.values(): + matches.extend(find_subgraphs_matching_pattern(nx_graph, pattern_graph(), strict=False)) + + if len(matches) == 0: + nncf_logger.info("No matching patterns were found for applying AWQ algorithm, it will be skipped.") + return {} + + awq_data = {} + name_mapping = {wp.weight_name: idx for idx, wp in enumerate(all_weight_params)} + + for match in matches: + nncf_node = graph.get_node_by_key(match[-1]) + if not self._backend_entity.is_node_with_weights(nncf_node, graph): + continue + + target_node_names = [] + for weight_op_friendly_name, _ in self._backend_entity.get_weight_names_and_port_ids(nncf_node, graph): + target_node_names.append(weight_op_friendly_name) + + # skip node if it is in IgnoredScope or should not be compressed + if target_node_names[-1] not in name_mapping: + continue + + weight_params = all_weight_params[name_mapping[target_node_names[-1]]] + + if weight_params.compression_config.num_bits != 4: + continue + target_node = nodes_to_compress[name_mapping[target_node_names[-1]]] + + # avoid matching different patterns for the same node + if target_node.node_name in awq_data: + continue + + nncf_node = graph.get_node_by_key(match[0]) + + if self._backend_entity.is_node_with_weights(nncf_node, graph): # pattern MatMul->Multiply->MatMul + merge_node_names = [] + for weight_op_friendly_name, _ in self._backend_entity.get_weight_names_and_port_ids(nncf_node, graph): + merge_node_names.append(weight_op_friendly_name) + merge_node = nodes_to_compress[name_mapping[merge_node_names[-1]]] + else: # pattern Act->MatMul or Act->Multiply->MatMul + merge_node = nncf_node + + awq_data[target_node.node_name] = AWQCompressionInfo(weight_params, target_node, merge_node) + return awq_data + def update_statistics(self, statistics): + if not statistics: + return statistics + # Multiply activations by the computed scales for node_name, scale in self._scale_per_target_node.items(): for mean_stat in statistics[node_name].mean_values: diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 41a90b212cd..c8921a07063 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -593,13 +593,12 @@ def compress_weights( elif backend == BackendType.OPENVINO: from nncf.openvino.quantization.quantize_model import compress_weights_impl as ov_compress_weights_impl - if any((awq, scale_estimation, gptq, lora_correction)) and ( - dataset is None or mode == CompressWeightsMode.E2M1 - ): - msg = ( - "Scale estimation, AWQ, GPTQ or Lora Correction algorithm is defined, " - "but dataset is None or mode is E2M1." - ) + if any((scale_estimation, gptq, lora_correction)) and dataset is None: + msg = "Scale estimation, GPTQ or Lora Correction algorithm is defined, but dataset is None." + raise nncf.ParameterNotSupportedError(msg) + + if any((awq, scale_estimation, gptq, lora_correction)) and mode == CompressWeightsMode.E2M1: + msg = "AWQ, Scale estimation, GPTQ or Lora Correction algorithm is defined, but mode is E2M1." 
raise nncf.ParameterNotSupportedError(msg) if gptq and lora_correction: diff --git a/tests/cross_fw/test_templates/template_test_weights_compression.py b/tests/cross_fw/test_templates/template_test_weights_compression.py index 75830005654..75f880f0fba 100644 --- a/tests/cross_fw/test_templates/template_test_weights_compression.py +++ b/tests/cross_fw/test_templates/template_test_weights_compression.py @@ -23,6 +23,8 @@ from nncf.data.dataset import Dataset from nncf.errors import InvalidGroupSizeError from nncf.quantization import compress_weights +from nncf.quantization.advanced_parameters import AdvancedAWQParameters as AWQParams +from nncf.quantization.advanced_parameters import AdvancedCompressionParameters as CompressionParams from nncf.quantization.algorithms.weight_compression.awq import AWQ from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.mixed_precision import MIXED_PRECISION_CRITERIA @@ -124,6 +126,13 @@ def get_not_supported_algorithms() -> list[str]: Returns a list of not supported weight compression algorithms. """ + @staticmethod + @abstractmethod + def wrap_model(model, data) -> CompressionParams: + """ + Returns model wrapped with backend specific graph. + """ + @pytest.mark.parametrize( ("mode", "all_layers", "ratio", "ref_ids"), ( @@ -359,3 +368,38 @@ def test_error_message_for_invalid_group_size(self, algorithm): name_list = [name.strip('"') for name in names[0].split(",")] compress_weights(**kwargs, ignored_scope=IgnoredScope(names=name_list)) + + @pytest.mark.parametrize("dataset", [None, np.ones([1, 8, 8], dtype=np.float32)]) + @pytest.mark.parametrize("prefer_data_aware_scaling", [True, False]) + def test_data_free_awq(self, dataset, prefer_data_aware_scaling, mocker): + input_data = np.ones([1, 8, 8], dtype=np.float32) + + n_layers = 8 + n_awq_target = n_layers - 1 # first MatMul is always int8 + model = self.get_awq_act_model(True, n_layers) + model = self.wrap_model(model, input_data) + + if dataset is not None: + dataset = Dataset([self.to_tensor(dataset)]) + + fn_name = "_data_free_step" if dataset is None or not prefer_data_aware_scaling else "_data_aware_step" + + collect_spy = mocker.spy(AWQ, fn_name) + + compressed_model = compress_weights( + model, + mode=CompressWeightsMode.INT4_ASYM, + ratio=1.0, + group_size=-1, + dataset=dataset, + awq=True, + advanced_parameters=CompressionParams( + awq_params=AWQParams( + prefer_data_aware_scaling=prefer_data_aware_scaling, + ) + ), + ) + + n_awq = self.get_num_multiply_from_awq(compressed_model) + assert n_awq == n_awq_target + assert collect_spy.call_count == n_awq, f"Statistics should be collected {n_awq_target} times." 
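For context, this is how the new option is exercised from the public API once the patch is applied. A minimal usage sketch, not part of the diff itself; `model` is assumed to be an already-loaded OpenVINO or Torch model supported by `compress_weights()`:

```python
import nncf
from nncf import CompressWeightsMode
from nncf.quantization.advanced_parameters import AdvancedAWQParameters, AdvancedCompressionParameters

# With this patch, awq=True no longer requires a calibration dataset: when no
# statistics are available (dataset=None) or prefer_data_aware_scaling=False,
# AWQ falls back to the weight-based (data-free) scaling rule.
compressed = nncf.compress_weights(
    model,  # placeholder for a real model object
    mode=CompressWeightsMode.INT4_SYM,
    ratio=0.8,
    group_size=64,
    awq=True,
    dataset=None,
    advanced_parameters=AdvancedCompressionParameters(
        awq_params=AdvancedAWQParameters(prefer_data_aware_scaling=False)
    ),
)
```

The same `prefer_data_aware_scaling=False` combination is what the `tinyllama_data_free_awq` conformance entry in `model_scope.py` below configures.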
diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py
index 6122f273114..5935b2265b7 100644
--- a/tests/openvino/native/quantization/test_weights_compression.py
+++ b/tests/openvino/native/quantization/test_weights_compression.py
@@ -701,7 +701,6 @@ def test_raise_error_with_unsupported_params_for_e2m1(algo):
     "algo",
     (
         "lora_correction",
-        "awq",
         "scale_estimation",
         "gptq",
     ),
@@ -1568,6 +1567,10 @@ def check_weights(model: ov.Model, ref_ids: list[int]) -> None:
     def get_not_supported_algorithms() -> list[str]:
         return []

+    @staticmethod
+    def wrap_model(model, data):
+        return model
+
     @staticmethod
     def get_scale_estimation_ref():
         return np.array(
diff --git a/tests/post_training/data/wc_reference_data.yaml b/tests/post_training/data/wc_reference_data.yaml
index 055eedb2327..8824f966fe1 100644
--- a/tests/post_training/data/wc_reference_data.yaml
+++ b/tests/post_training/data/wc_reference_data.yaml
@@ -102,3 +102,11 @@ tinyllama_data_aware_awq_scale_estimation_backend_FX_TORCH:
     type: "FileNotFoundError"
     error_message: "Openvino Model Files Not Found!"
     message: "Issue-165013"
+tinyllama_data_free_awq_backend_OV:
+  metric_value: 0.85466
+  num_int4: 94
+  num_int8: 124
+tinyllama_data_free_awq_backend_TORCH:
+  metric_value: 0.85466
+  num_int4: 94
+  num_int8: 124
diff --git a/tests/post_training/data/wc_test_durations.json b/tests/post_training/data/wc_test_durations.json
index f411ecaed25..ecf1abd6645 100644
--- a/tests/post_training/data/wc_test_durations.json
+++ b/tests/post_training/data/wc_test_durations.json
@@ -11,5 +11,7 @@
     "tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_int4_data_free_backend_TORCH]": 133,
     "tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_int8_data_free_backend_TORCH]": 154,
     "tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV]": 256,
-    "tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_scale_estimation_per_channel_backend_OV]": 258
+    "tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_scale_estimation_per_channel_backend_OV]": 258,
+    "tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_awq_backend_OV]": 200,
+    "tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_awq_backend_TORCH]": 200
 }
diff --git a/tests/post_training/model_scope.py b/tests/post_training/model_scope.py
index e25912c18e0..9ed937bcefd 100644
--- a/tests/post_training/model_scope.py
+++ b/tests/post_training/model_scope.py
@@ -17,6 +17,7 @@
 from nncf.parameters import BackupMode
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import SensitivityMetric
+from nncf.quantization.advanced_parameters import AdvancedAWQParameters
 from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
 from nncf.quantization.advanced_parameters import AdvancedLoraCorrectionParameters
 from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
@@ -551,6 +552,22 @@
         },
         "backends": [BackendType.OV],
     },
+    {
+        "reported_name": "tinyllama_data_free_awq",
+        "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b",
+        "pipeline_cls": LMWeightCompression,
+        "compression_params": {
+            "group_size": 64,
+            "ratio": 0.8,
+            "mode": CompressWeightsMode.INT4_SYM,
+            "awq": True,
+            "advanced_parameters": AdvancedCompressionParameters(
+                awq_params=AdvancedAWQParameters(prefer_data_aware_scaling=False)
+            ),
+        },
+        # TODO: (andreyanufr) add torch.fx backend
+        "backends": [BackendType.OV, BackendType.TORCH],
+    },
 ]
diff --git a/tests/torch2/function_hook/quantization/test_weights_compression.py b/tests/torch2/function_hook/quantization/test_weights_compression.py
index 014275d4c71..9fdd8c9fa87 100644
--- a/tests/torch2/function_hook/quantization/test_weights_compression.py
+++ b/tests/torch2/function_hook/quantization/test_weights_compression.py
@@ -489,6 +489,13 @@ def check_weights(model: torch.nn.Module, ref_ids: list[int]) -> None:
     def get_not_supported_algorithms() -> list[str]:
         return ["lora_correction", "gptq"]

+    @staticmethod
+    def wrap_model(model, data):
+        model = wrap_model(model)
+        data = torch.tensor(data)
+        wrapped_model = GraphModelWrapper(model, example_input=data)
+        return wrapped_model
+
     @staticmethod
     def get_scale_estimation_ref():
         return torch.tensor(
diff --git a/tests/torch2/fx/test_compress_weights.py b/tests/torch2/fx/test_compress_weights.py
index 28a799eb262..80dc2437d14 100644
--- a/tests/torch2/fx/test_compress_weights.py
+++ b/tests/torch2/fx/test_compress_weights.py
@@ -369,6 +369,13 @@ def check_weights(model: torch.fx.GraphModule, ref_ids: list[int]) -> None:
     def get_not_supported_algorithms() -> list[str]:
         return ["lora_correction", "gptq"]

+    @staticmethod
+    def wrap_model(model, data):
+        if isinstance(model, torch.fx.GraphModule):
+            return model
+        data = torch.tensor(data)
+        return get_torch_fx_model(model, data)
+
     @staticmethod
     def get_scale_estimation_ref():
         return torch.tensor(