diff --git a/.ci/cspell_dict.txt b/.ci/cspell_dict.txt
index 804d7608db0..b95dd5b5b00 100644
--- a/.ci/cspell_dict.txt
+++ b/.ci/cspell_dict.txt
@@ -318,6 +318,7 @@ ovhw
 ovlstm
 ovmvn
 ovroi
+ovselu
 pbar
 perchannel
 pertensor
@@ -339,6 +340,7 @@ pthw
 ptnncf
 ptprelu
 ptrelu
+ptselu
 ptsilu
 ptwc
 pymodules
diff --git a/examples/llm_compression/torch/qat_with_lora/main.py b/examples/llm_compression/torch/qat_with_lora/main.py
index 71300c50896..29ec03c7ff2 100644
--- a/examples/llm_compression/torch/qat_with_lora/main.py
+++ b/examples/llm_compression/torch/qat_with_lora/main.py
@@ -34,12 +34,12 @@
 import nncf.torch
 from nncf.common.logging.track_progress import track
 from nncf.data.dataset import Dataset
-from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
 from nncf.parameters import CompressionFormat
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import StripFormat
 from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
 from nncf.quantization.quantize_model import compress_weights
+from nncf.torch.function_hook.wrapper import get_hook_storage
 from nncf.torch.model_creation import load_from_config
 from nncf.torch.quantization.layers import AsymmetricLoraQuantizer
 from nncf.torch.quantization.layers import SymmetricLoraQuantizer
diff --git a/nncf/experimental/common/check_feature.py b/nncf/common/check_features.py
similarity index 86%
rename from nncf/experimental/common/check_feature.py
rename to nncf/common/check_features.py
index e28b58d0115..b91aee03656 100644
--- a/nncf/experimental/common/check_feature.py
+++ b/nncf/common/check_features.py
@@ -12,7 +12,7 @@
 import os
 
 
-def is_torch_tracing_by_torch_function_mode() -> bool:
+def is_torch_tracing_by_patching() -> bool:
     """
     Checks if legacy torch tracing is enabled by environment variable NNCF_TORCH_LEGACY_TRACING.
 
@@ -21,4 +21,4 @@ def is_torch_tracing_by_torch_function_mode() -> bool:
 
     :return: True if legacy torch tracing is enabled, False otherwise.
     """
-    return os.getenv("NNCF_TORCH_LEGACY_TRACING", "").lower() not in ["1", "on", "true"]
+    return os.getenv("NNCF_TORCH_LEGACY_TRACING", "").lower() in ["1", "on", "true"]
diff --git a/nncf/common/factory.py b/nncf/common/factory.py
index efa1e593613..9cecbb1dca5 100644
--- a/nncf/common/factory.py
+++ b/nncf/common/factory.py
@@ -12,6 +12,7 @@
 from typing import Any, TypeVar, cast
 
 import nncf
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.common.engine import Engine
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.model_transformer import ModelTransformer
@@ -20,7 +21,6 @@
 from nncf.common.utils.backend import BackendType
 from nncf.common.utils.backend import get_backend
 from nncf.data.dataset import Dataset
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
 
 TModel = TypeVar("TModel")
 
@@ -54,7 +54,7 @@ def create(model: TModel) -> NNCFGraph:
             return FXGraphConverter.create_nncf_graph(cast(GraphModule, model))
 
         if model_backend == BackendType.TORCH:
-            from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
+            from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
             from nncf.torch.nncf_network import NNCFNetwork
 
             if isinstance(model, GraphModelWrapper):
@@ -90,13 +90,13 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer[Any]:
             from nncf.openvino.graph.model_transformer import OVModelTransformer
 
             return OVModelTransformer(cast(Model, model), inplace=inplace)
-        if model_backend == BackendType.TORCH and is_torch_tracing_by_torch_function_mode():
-            from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
-            from nncf.experimental.torch2.model_transformer import PT2ModelTransformer
+        if model_backend == BackendType.TORCH and not is_torch_tracing_by_patching():
+            from nncf.torch.function_hook.model_transformer import PT2ModelTransformer
+            from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 
             return PT2ModelTransformer(cast(GraphModelWrapper, model))
-        if model_backend == BackendType.TORCH and not is_torch_tracing_by_torch_function_mode():
+        if model_backend == BackendType.TORCH and is_torch_tracing_by_patching():
             from nncf.torch.model_transformer import PTModelTransformer
             from nncf.torch.nncf_network import NNCFNetwork
 
@@ -137,8 +137,8 @@ def create(model: TModel) -> Engine:
         if model_backend in (BackendType.TORCH, BackendType.TORCH_FX):
             from torch.nn import Module
 
-            from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
             from nncf.torch.engine import PTEngine
+            from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 
             if isinstance(model, GraphModelWrapper):
                 pt_model = model.model
@@ -191,12 +191,12 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
             from nncf.openvino.statistics.aggregator import OVStatisticsAggregator
 
             return OVStatisticsAggregator(dataset)
-        if model_backend == BackendType.TORCH and not is_torch_tracing_by_torch_function_mode():
+        if model_backend == BackendType.TORCH and is_torch_tracing_by_patching():
            from nncf.torch.statistics.aggregator import PTStatisticsAggregator
 
            return PTStatisticsAggregator(dataset)
-        if model_backend == BackendType.TORCH and is_torch_tracing_by_torch_function_mode():
-            from nncf.experimental.torch2.statistics.aggregator import PT2StatisticsAggregator
+        if model_backend == BackendType.TORCH and not is_torch_tracing_by_patching():
+            from nncf.torch.function_hook.statistics.aggregator import PT2StatisticsAggregator
 
             return PT2StatisticsAggregator(dataset)
         if model_backend == BackendType.TORCH_FX:
diff --git a/nncf/common/utils/backend.py b/nncf/common/utils/backend.py
index d088d856eee..450bfdae9e0 100644
--- a/nncf/common/utils/backend.py
+++ b/nncf/common/utils/backend.py
@@ -15,7 +15,7 @@
 from packaging import version
 
 import nncf
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
+from nncf.common.check_features import is_torch_tracing_by_patching
 
 try:
     import openvino  # type: ignore # noqa: F401
@@ -56,12 +56,11 @@ def is_torch_model(model: Any) -> bool:
     import torch
     import torch.fx
 
-    from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
+    from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 
-    if is_torch_tracing_by_torch_function_mode():
-        return isinstance(model, (GraphModelWrapper, torch.nn.Module)) and not isinstance(model, torch.fx.GraphModule)
-
-    return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)
+    if is_torch_tracing_by_patching():
+        return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)
+    return isinstance(model, (GraphModelWrapper, torch.nn.Module)) and not isinstance(model, torch.fx.GraphModule)
 
 
 @result_verifier
diff --git a/nncf/experimental/torch/sparsify_activations/torch_backend.py b/nncf/experimental/torch/sparsify_activations/torch_backend.py
index e760bc18444..59861aa34c5 100644
--- a/nncf/experimental/torch/sparsify_activations/torch_backend.py
+++ b/nncf/experimental/torch/sparsify_activations/torch_backend.py
@@ -22,11 +22,11 @@
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.data import Dataset
 from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import SparsifyActivationsAlgoBackend
-from nncf.experimental.torch2.commands import PT2InsertionCommand
-from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
-from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
-from nncf.experimental.torch2.model_transformer import PT2ModelTransformer
 from nncf.tensor.functions.torch_numeric import quantile
+from nncf.torch.function_hook.commands import PT2InsertionCommand
+from nncf.torch.function_hook.model_transformer import PT2ModelTransformer
+from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
+from nncf.torch.function_hook.wrapper import get_hook_storage
 from nncf.torch.graph import operator_metatypes as om
 from nncf.torch.graph.transformations.commands import PTTargetPoint
 from nncf.torch.graph.transformations.layout import PTTransformationLayout
diff --git a/nncf/experimental/torch2/function_hook/__init__.py b/nncf/experimental/torch2/function_hook/__init__.py
deleted file mode 100644
index 7c3c801d807..00000000000
--- a/nncf/experimental/torch2/function_hook/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Copyright (c) 2025 Intel Corporation
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#      http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph as build_graph
-from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage as get_hook_storage
-from nncf.experimental.torch2.function_hook.wrapper import is_wrapped as is_wrapped
-from nncf.experimental.torch2.function_hook.wrapper import register_post_function_hook as register_post_function_hook
-from nncf.experimental.torch2.function_hook.wrapper import register_pre_function_hook as register_pre_function_hook
-from nncf.experimental.torch2.function_hook.wrapper import wrap_model as wrap_model
diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py
index 4ee15603465..f880c9dd36d 100644
--- a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py
+++ b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py
@@ -20,10 +20,10 @@
 from nncf.common.graph.model_transformer import ModelTransformer
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
-from nncf.experimental.torch2.function_hook.extractor import extract_model
-from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend
 from nncf.tensor import Tensor
+from nncf.torch.function_hook.extractor import extract_model
+from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.torch.graph.transformations.command_creation import create_bias_correction_command
 from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
 from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py
index 8b2974a2ed5..d5090ba5c93 100644
--- a/nncf/quantization/algorithms/min_max/torch_backend.py
+++ b/nncf/quantization/algorithms/min_max/torch_backend.py
@@ -15,6 +15,7 @@
 import nncf
 import nncf.torch.graph.operator_metatypes as om
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.common.graph.definitions import NNCFGraphNodeType
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.graph import NNCFNode
@@ -24,16 +25,15 @@
 from nncf.common.hardware.config import HWConfig
 from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
 from nncf.common.quantization.structs import QuantizerConfig
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
 from nncf.experimental.common.tensor_statistics.collectors import REDUCERS_MAP
 from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
-from nncf.experimental.torch2.commands import PT2InsertionCommand
 from nncf.parameters import ModelType
 from nncf.parameters import TargetDevice
 from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
 from nncf.quantization.fake_quantize import FakeConvertParameters
 from nncf.quantization.fake_quantize import FakeQuantizeParameters
 from nncf.quantization.range_estimator import StatisticsType
+from nncf.torch.function_hook.commands import PT2InsertionCommand
 from nncf.torch.graph.graph import PTNNCFGraph
 from nncf.torch.graph.graph import PTTargetPoint
 from nncf.torch.graph.operator_metatypes import ELEMENTWISE_OPERATIONS
@@ -153,10 +153,7 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
         input_port_id: Optional[int] = port_id
         if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION:
             input_port_id = None
-        if (
-            not is_torch_tracing_by_torch_function_mode()
-            and target_type in PTMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP
-        ):
+        if is_torch_tracing_by_patching() and target_type in PTMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP:
             target_type = PTMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type]
         return PTTargetPoint(target_type, target_node_name, input_port_id=input_port_id)
 
@@ -281,10 +278,9 @@ def create_quantizer_insertion_command(
         quantizer = PTMinMaxAlgoBackend._create_quantizer(
             quantizer_config, scale_shape, parameters, target_point.target_type
         )
-        if is_torch_tracing_by_torch_function_mode():
-            return PT2InsertionCommand(target_points=[target_point], hook_module=quantizer)
-
-        return create_quantizer_insertion_command(target_point, quantizer)
+        if is_torch_tracing_by_patching():
+            return create_quantizer_insertion_command(target_point, quantizer)
+        return PT2InsertionCommand(target_points=[target_point], hook_module=quantizer)
 
     @staticmethod
     def create_unified_scales_quantizers_insertion_commands(
@@ -300,9 +296,9 @@ def create_unified_scales_quantizers_insertion_commands(
         quantizer = PTMinMaxAlgoBackend._create_quantizer(
             quantizer_config, scale_shape, parameters, target_points[0].target_type
         )
-        if is_torch_tracing_by_torch_function_mode():
-            return [PT2InsertionCommand(target_points=target_points, hook_module=quantizer)]
-        return [create_shared_quantizer_insertion_command(target_points, quantizer)]
+        if is_torch_tracing_by_patching():
+            return [create_shared_quantizer_insertion_command(target_points, quantizer)]
+        return [PT2InsertionCommand(target_points=target_points, hook_module=quantizer)]
 
     @staticmethod
     def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]:
diff --git a/nncf/quantization/algorithms/smooth_quant/torch_backend.py b/nncf/quantization/algorithms/smooth_quant/torch_backend.py
index d8b5c949c45..e7972e3386d 100644
--- a/nncf/quantization/algorithms/smooth_quant/torch_backend.py
+++ b/nncf/quantization/algorithms/smooth_quant/torch_backend.py
@@ -14,21 +14,21 @@
 import torch
 
 import nncf.torch.graph.operator_metatypes as om
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
 from nncf.common.graph.operator_metatypes import OperatorMetatype
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
 from nncf.common.tensor_statistics.statistic_point import StatisticPoint
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
 from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
 from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
-from nncf.experimental.torch2.commands import PT2ConstUpdateCommand
-from nncf.experimental.torch2.commands import PT2InsertionCommand
-from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.quantization.algorithms.smooth_quant.backend import SmoothQuantAlgoBackend
 from nncf.tensor import Tensor
+from nncf.torch.function_hook.commands import PT2ConstUpdateCommand
+from nncf.torch.function_hook.commands import PT2InsertionCommand
+from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
@@ -136,10 +136,10 @@ def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork, nncf_graph:
     def weight_update_command(
         node_with_weight: NNCFNode, nncf_graph: NNCFGraph, weight_value: torch.Tensor
     ) -> PTWeightUpdateCommand:
-        if is_torch_tracing_by_torch_function_mode():
-            weight_node = get_const_node(node_with_weight, node_with_weight.metatype.weight_port_ids[0], nncf_graph)
-            return PT2ConstUpdateCommand(weight_node, weight_value)
-        return create_command_to_update_weight(node_with_weight, weight_value)
+        if is_torch_tracing_by_patching():
+            return create_command_to_update_weight(node_with_weight, weight_value)
+        weight_node = get_const_node(node_with_weight, node_with_weight.metatype.weight_port_ids[0], nncf_graph)
+        return PT2ConstUpdateCommand(weight_node, weight_value)
 
     @staticmethod
     def scale_insertion_command(
@@ -157,9 +157,9 @@ def scale_insertion_command(
         sq_multiply = SQMultiply(scale_value.shape)
         sq_multiply.scale = scale_value
 
-        if is_torch_tracing_by_torch_function_mode():
-            return PT2InsertionCommand(target_points=target_points, hook_module=sq_multiply)
-        return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)
+        if is_torch_tracing_by_patching():
+            return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)
+        return PT2InsertionCommand(target_points=target_points, hook_module=sq_multiply)
 
     @staticmethod
     def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
@@ -175,11 +175,12 @@ def get_weight_channel_axis(node: NNCFNode) -> int:
 
     @staticmethod
     def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
-        if is_torch_tracing_by_torch_function_mode():
-            weight_node = get_const_node(node, node.metatype.weight_port_ids[0], nncf_graph)
-            output_edges = nncf_graph.get_next_nodes(weight_node)
-            return len(output_edges) > 1
-        return node.is_shared()
+        if is_torch_tracing_by_patching():
+            return node.is_shared()
+
+        weight_node = get_const_node(node, node.metatype.weight_port_ids[0], nncf_graph)
+        output_edges = nncf_graph.get_next_nodes(weight_node)
+        return len(output_edges) > 1
 
     @staticmethod
     def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]:
diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py
index cf40e10b99e..61f397e7c1e 100644
--- a/nncf/quantization/algorithms/weight_compression/torch_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py
@@ -15,6 +15,7 @@
 import nncf
 import nncf.torch.graph.operator_metatypes as om
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.common.graph.definitions import NNCFGraphNodeType
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.graph import NNCFNode
@@ -26,7 +27,6 @@
 from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.common.quantization.structs import QuantizationScheme
 from nncf.common.tensor_statistics.statistic_point import StatisticPoint
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
 from nncf.experimental.common.tensor_statistics.collectors import MaxVarianceReducer
 from nncf.experimental.common.tensor_statistics.collectors import MeanAbsMaxReducer
 from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
@@ -39,9 +39,6 @@
 from nncf.experimental.common.tensor_statistics.statistics import MeanMagnitudeTensorStatistic
 from nncf.experimental.common.tensor_statistics.statistics import MeanVarianceTensorStatistic
 from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic
-from nncf.experimental.torch2.commands import PT2InsertionCommand
-from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
-from nncf.experimental.torch2.model_transformer import PT2ModelTransformer
 from nncf.parameters import CompressionFormat
 from nncf.parameters import CompressWeightsMode
 from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
@@ -57,6 +54,9 @@
 from nncf.quantization.algorithms.weight_compression.weight_lowering import compress_weight
 from nncf.tensor import Tensor
 from nncf.tensor.definitions import TensorDataType
+from nncf.torch.function_hook.commands import PT2InsertionCommand
+from nncf.torch.function_hook.model_transformer import PT2ModelTransformer
+from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.torch.graph.graph import PTTargetPoint
 from nncf.torch.graph.operator_metatypes import PTMulMetatype
 from nncf.torch.graph.pattern_operations import ATOMIC_ACTIVATIONS_OPERATIONS
@@ -233,14 +233,14 @@ def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNC
     def set_weight(
         self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph, weight: Tensor
     ):
-        if is_torch_tracing_by_torch_function_mode():
+        if is_torch_tracing_by_patching():
+            update_parameter(node_with_weight.node_name, "weight", weight.data, model)
+        else:
             weight_node = get_const_node(node_with_weight, weight_port_id, graph)
             module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
             module = get_module_by_name(module_name, model.model)
             weight_param = getattr(module, weight_attr_name)
             weight_param.data = weight.data
-        else:
-            update_parameter(node_with_weight.node_name, "weight", weight.data, model)
 
     def insert_adapters(
         self, wc_params: WeightCompressionParameters, lora_A: Tensor, lora_B: Tensor, int8_lora: bool
@@ -359,9 +359,7 @@ def get_fq_insertion_command(
         target_node_name = wc_params.weight_name
         target_point = PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=target_node_name)
 
-        if is_torch_tracing_by_torch_function_mode():
-            return PT2InsertionCommand([target_point], quantizer)
-        else:
+        if is_torch_tracing_by_patching():
             storage_key = "FQ_LORA_{}".format(target_node_name.replace(".", "_"))
             return PTSharedFnInsertionCommand(
                 target_points=[target_point],
@@ -370,6 +368,8 @@
                 compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
                 priority=TransformationPriority.QUANTIZATION_PRIORITY,
             )
+        else:
+            return PT2InsertionCommand([target_point], quantizer)
 
     @staticmethod
     def get_dq_insertion_command(
@@ -439,16 +439,7 @@ def get_dq_insertion_command(
         weight.requires_grad = False
         weight.data = packed_tensor
 
-        if is_torch_tracing_by_torch_function_mode():
-            return PT2InsertionCommand(
-                [
-                    PTTargetPoint(
-                        TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name.replace(".", ":")
-                    )
-                ],
-                decompressor,
-            )
-        else:
+        if is_torch_tracing_by_patching():
             # registry weight decompression module in the model
             decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"
 
@@ -458,6 +449,15 @@ def get_dq_insertion_command(
                 decompressor,
                 decompressor_name,
             )
+        else:
+            return PT2InsertionCommand(
+                [
+                    PTTargetPoint(
+                        TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name.replace(".", ":")
+                    )
+                ],
+                decompressor,
+            )
 
     def transform_model(
         self,
@@ -560,10 +560,10 @@ def scale_insertion_command(
         sq_multiply = SQMultiply(scale.shape)
         sq_multiply.scale = scale
 
-        if is_torch_tracing_by_torch_function_mode():
-            return PT2InsertionCommand(target_points, sq_multiply)
-        scale_node_name = f"{source_node.node_name}/awq_mul"
-        return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)
+        if is_torch_tracing_by_patching():
+            scale_node_name = f"{source_node.node_name}/awq_mul"
+            return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)
+        return PT2InsertionCommand(target_points, sq_multiply)
 
 
 class PTMixedPrecisionAlgoBackend(MixedPrecisionAlgoBackend, PTWeightCompressionAlgoBackend):
diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py
index 3eb38d79fb9..b78b797afd0 100644
--- a/nncf/quantization/quantize_model.py
+++ b/nncf/quantization/quantize_model.py
@@ -12,6 +12,7 @@
 import nncf
 from nncf.api.compression import TModel
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.common.deprecation import warning_deprecated
 from nncf.common.graph import NNCFGraph
 from nncf.common.graph.operator_metatypes import OperatorMetatype
@@ -21,7 +22,6 @@
 from nncf.common.utils.backend import BackendType
 from nncf.common.utils.backend import get_backend
 from nncf.data import Dataset
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
 from nncf.parameters import BackupMode
 from nncf.parameters import CompressionFormat
 from nncf.parameters import CompressWeightsMode
@@ -232,10 +232,10 @@ def quantize(
         )
 
     if backend == BackendType.TORCH:
-        if is_torch_tracing_by_torch_function_mode():
-            from nncf.experimental.torch2.quantization.quantize_model import quantize_impl
-        else:
+        if is_torch_tracing_by_patching():
             from nncf.torch.quantization.quantize_model import quantize_impl
+        else:
+            from nncf.torch.function_hook.quantization.quantize_model import quantize_impl
 
         return quantize_impl(  # type: ignore[no-any-return]
             model=model,
diff --git a/nncf/torch/dynamic_graph/context.py b/nncf/torch/dynamic_graph/context.py
index 0a06b625f22..fe6ef8fda47 100644
--- a/nncf/torch/dynamic_graph/context.py
+++ b/nncf/torch/dynamic_graph/context.py
@@ -19,14 +19,13 @@
 import torch
 
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.common.graph.layer_attributes import BaseLayerAttributes
 from nncf.common.hook_handle import HookHandle
 from nncf.common.hook_handle import add_op_to_registry
 from nncf.common.utils.api_marker import api
 from nncf.common.utils.debug import is_debug
 from nncf.common.utils.patcher import PATCHER
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
-from nncf.experimental.torch2.function_hook.hook_executor_mode import disable_function_hook_mode
 from nncf.torch.dynamic_graph.graph import DynamicGraph
 from nncf.torch.dynamic_graph.graph import DynamicGraphNode
 from nncf.torch.dynamic_graph.graph import DynamicGraphNodeParameters
@@ -37,6 +36,7 @@
 from nncf.torch.dynamic_graph.scope import ScopeElement
 from nncf.torch.dynamic_graph.trace_tensor import TensorMeta
 from nncf.torch.dynamic_graph.trace_tensor import TracedTensorMixin
+from nncf.torch.function_hook.hook_executor_mode import disable_function_hook_mode
 
 
 class ThreadLocalGlobalContext(threading.local):
@@ -506,15 +506,15 @@ def disable_tracing(method):
     Patch a method so that it will be executed within no_nncf_trace context
     :param method: A method to patch.
     """
-    if is_torch_tracing_by_torch_function_mode():
+    if is_torch_tracing_by_patching():
 
         def no_nncf_trace_wrapper(self, fn, *args, **kwargs):
-            with disable_function_hook_mode():
+            with no_nncf_trace():
                 return fn(*args, **kwargs)
 
     else:
 
         def no_nncf_trace_wrapper(self, fn, *args, **kwargs):
-            with no_nncf_trace():
+            with disable_function_hook_mode():
                 return fn(*args, **kwargs)
 
     PATCHER.patch(method, no_nncf_trace_wrapper)
diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py
index ea64183b525..9af4cb6e7ae 100644
--- a/nncf/torch/dynamic_graph/patch_pytorch.py
+++ b/nncf/torch/dynamic_graph/patch_pytorch.py
@@ -23,8 +23,8 @@
 import nncf
 from nncf import nncf_logger
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.common.utils.api_marker import api
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
 from nncf.torch.dynamic_graph.patch_pytorch_state import PATCHING_STATE
 from nncf.torch.dynamic_graph.structs import NamespaceTarget
 from nncf.torch.dynamic_graph.structs import PatchedOperatorInfo
@@ -181,18 +181,18 @@ class MagicFunctionsToPatch:
 
 @api(canonical_alias="nncf.torch.register_operator")
 def register_operator(name=None):
-    if is_torch_tracing_by_torch_function_mode():
-
-        def wrap(operator):
-            # Skip wrapping operator for tracing by TorchFunctionMode
-            return operator
-    else:
+    if is_torch_tracing_by_patching():
 
         def wrap(operator):
             op_name = name
             if op_name is None:
                 op_name = operator.__name__
             return wrap_operator(operator, PatchedOperatorInfo(op_name, NamespaceTarget.EXTERNAL))
+    else:
+
+        def wrap(operator):
+            # Skip wrapping operator for tracing by TorchFunctionMode
+            return operator
 
     return wrap
 
@@ -360,7 +360,7 @@ def remove_private_functions(names: List[str]) -> List[str]:
 
 
 def patch_torch_operators():
-    if is_torch_tracing_by_torch_function_mode():
+    if not is_torch_tracing_by_patching():
         return
 
     # Only patch torch.jit.script during first patch_torch_operators call
diff --git a/nncf/experimental/torch2/statistics/__init__.py b/nncf/torch/function_hook/__init__.py
similarity index 52%
rename from nncf/experimental/torch2/statistics/__init__.py
rename to nncf/torch/function_hook/__init__.py
index e5a42efc0ef..1ebbb36b10c 100644
--- a/nncf/experimental/torch2/statistics/__init__.py
+++ b/nncf/torch/function_hook/__init__.py
@@ -8,3 +8,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from nncf.torch.function_hook.graph.build_graph_mode import build_graph as build_graph
+from nncf.torch.function_hook.wrapper import get_hook_storage as get_hook_storage
+from nncf.torch.function_hook.wrapper import is_wrapped as is_wrapped
+from nncf.torch.function_hook.wrapper import register_post_function_hook as register_post_function_hook
+from nncf.torch.function_hook.wrapper import register_pre_function_hook as register_pre_function_hook
+from nncf.torch.function_hook.wrapper import wrap_model as wrap_model
diff --git a/nncf/experimental/torch2/commands.py b/nncf/torch/function_hook/commands.py
similarity index 96%
rename from nncf/experimental/torch2/commands.py
rename to nncf/torch/function_hook/commands.py
index 937c1f65a90..fd9bc79de10 100644
--- a/nncf/experimental/torch2/commands.py
+++ b/nncf/torch/function_hook/commands.py
@@ -17,7 +17,7 @@
 from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.transformations.commands import Command
 from nncf.common.graph.transformations.commands import TransformationType
-from nncf.experimental.torch2.function_hook.hook_storage import RemovableHookHandle
+from nncf.torch.function_hook.hook_storage import RemovableHookHandle
 from nncf.torch.graph.transformations.commands import PTTargetPoint
diff --git a/nncf/experimental/torch2/function_hook/extractor.py b/nncf/torch/function_hook/extractor.py
similarity index 97%
rename from nncf/experimental/torch2/function_hook/extractor.py
rename to nncf/torch/function_hook/extractor.py
index 8ea0c454aa0..b30d15e1ef8 100644
--- a/nncf/experimental/torch2/function_hook/extractor.py
+++ b/nncf/torch/function_hook/extractor.py
@@ -16,8 +16,8 @@
 import nncf
 from nncf import nncf_logger
 from nncf.common.graph.graph import NNCFNode
-from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes
-from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
+from nncf.torch.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes
+from nncf.torch.function_hook.wrapper import get_hook_storage
 from nncf.torch.graph import operator_metatypes as om
 from nncf.torch.graph.graph import PTNNCFGraph
 from nncf.torch.model_graph_manager import get_const_data
diff --git a/nncf/experimental/torch2/__init__.py b/nncf/torch/function_hook/graph/__init__.py
similarity index 100%
rename from nncf/experimental/torch2/__init__.py
rename to nncf/torch/function_hook/graph/__init__.py
diff --git a/nncf/experimental/torch2/function_hook/graph/build_graph_mode.py b/nncf/torch/function_hook/graph/build_graph_mode.py
similarity index 93%
rename from nncf/experimental/torch2/function_hook/graph/build_graph_mode.py
rename to nncf/torch/function_hook/graph/build_graph_mode.py
index 8f37a5b2038..fdd3d2ab4a7 100644
--- a/nncf/experimental/torch2/function_hook/graph/build_graph_mode.py
+++ b/nncf/torch/function_hook/graph/build_graph_mode.py
@@ -18,20 +18,20 @@
 from torch import nn
 
 from nncf.common.logging import nncf_logger as logger
-from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta
-from nncf.experimental.torch2.function_hook.graph.graph_utils import EdgeMeta
-from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta
-from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta
-from nncf.experimental.torch2.function_hook.graph.graph_utils import NodeType
-from nncf.experimental.torch2.function_hook.graph.graph_utils import TensorInfo
-from nncf.experimental.torch2.function_hook.graph.graph_utils import TensorMeta
-from nncf.experimental.torch2.function_hook.graph.graph_utils import TensorSource
-from nncf.experimental.torch2.function_hook.hook_executor_mode import FunctionHookMode
-from nncf.experimental.torch2.function_hook.hook_executor_mode import OpMeta
-from nncf.experimental.torch2.function_hook.hook_storage import HookStorage
-from nncf.experimental.torch2.function_hook.weak_map import WeakUnhashableKeyMap
-from nncf.experimental.torch2.function_hook.wrapper import ForwardWithHooks
-from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
+from nncf.torch.function_hook.graph.graph_utils import ConstMeta
+from nncf.torch.function_hook.graph.graph_utils import EdgeMeta
+from nncf.torch.function_hook.graph.graph_utils import FunctionMeta
+from nncf.torch.function_hook.graph.graph_utils import InOutMeta
+from nncf.torch.function_hook.graph.graph_utils import NodeType
+from nncf.torch.function_hook.graph.graph_utils import TensorInfo
+from nncf.torch.function_hook.graph.graph_utils import TensorMeta
+from nncf.torch.function_hook.graph.graph_utils import TensorSource
+from nncf.torch.function_hook.hook_executor_mode import FunctionHookMode
+from nncf.torch.function_hook.hook_executor_mode import OpMeta
+from nncf.torch.function_hook.hook_storage import HookStorage
+from nncf.torch.function_hook.weak_map import WeakUnhashableKeyMap
+from nncf.torch.function_hook.wrapper import ForwardWithHooks
+from nncf.torch.function_hook.wrapper import get_hook_storage
 from nncf.torch.utils import training_mode_switcher
diff --git a/nncf/experimental/torch2/function_hook/graph/graph_utils.py b/nncf/torch/function_hook/graph/graph_utils.py
similarity index 100%
rename from nncf/experimental/torch2/function_hook/graph/graph_utils.py
rename to nncf/torch/function_hook/graph/graph_utils.py
diff --git a/nncf/experimental/torch2/function_hook/graph/graph_visualization.py b/nncf/torch/function_hook/graph/graph_visualization.py
similarity index 95%
rename from nncf/experimental/torch2/function_hook/graph/graph_visualization.py
rename to nncf/torch/function_hook/graph/graph_visualization.py
index 96ada0022d4..49d7372e839 100644
--- a/nncf/experimental/torch2/function_hook/graph/graph_visualization.py
+++ b/nncf/torch/function_hook/graph/graph_visualization.py
@@ -16,10 +16,10 @@
 import networkx as nx  # type: ignore[import-untyped]
 import pydot  # type: ignore[import-untyped]
 
-from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta
-from nncf.experimental.torch2.function_hook.graph.graph_utils import EdgeMeta
-from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta
-from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta
+from nncf.torch.function_hook.graph.graph_utils import ConstMeta
+from nncf.torch.function_hook.graph.graph_utils import EdgeMeta
+from nncf.torch.function_hook.graph.graph_utils import FunctionMeta
+from nncf.torch.function_hook.graph.graph_utils import InOutMeta
 
 
 class PydotStyleTemplate(Enum):
diff --git a/nncf/experimental/torch2/function_hook/handle_inner_functions.py b/nncf/torch/function_hook/handle_inner_functions.py
similarity index 100%
rename from nncf/experimental/torch2/function_hook/handle_inner_functions.py
rename to nncf/torch/function_hook/handle_inner_functions.py
diff --git a/nncf/experimental/torch2/function_hook/hook_executor_mode.py b/nncf/torch/function_hook/hook_executor_mode.py
similarity index 98%
rename from nncf/experimental/torch2/function_hook/hook_executor_mode.py
rename to nncf/torch/function_hook/hook_executor_mode.py
index f1f18144913..5d59f967f39 100644
--- a/nncf/experimental/torch2/function_hook/hook_executor_mode.py
+++ b/nncf/torch/function_hook/hook_executor_mode.py
@@ -30,9 +30,9 @@
 from torch.overrides import TorchFunctionMode
 
 from nncf.common.logging import nncf_logger as logger
-from nncf.experimental.torch2.function_hook.handle_inner_functions import get_handle_inner_function
-from nncf.experimental.torch2.function_hook.hook_storage import HookStorage
-from nncf.experimental.torch2.function_hook.weak_map import WeakUnhashableKeyMap
+from nncf.torch.function_hook.handle_inner_functions import get_handle_inner_function
+from nncf.torch.function_hook.hook_storage import HookStorage
+from nncf.torch.function_hook.weak_map import WeakUnhashableKeyMap
 
 IGNORED_FN_NAMES = [
     "__repr__",
diff --git a/nncf/experimental/torch2/function_hook/hook_storage.py b/nncf/torch/function_hook/hook_storage.py
similarity index 100%
rename from nncf/experimental/torch2/function_hook/hook_storage.py
rename to nncf/torch/function_hook/hook_storage.py
diff --git a/nncf/experimental/torch2/model_transformer.py b/nncf/torch/function_hook/model_transformer.py
similarity index 93%
rename from nncf/experimental/torch2/model_transformer.py
rename to nncf/torch/function_hook/model_transformer.py
index 631e950a759..5d46c10ebca 100644
--- a/nncf/experimental/torch2/model_transformer.py
+++ b/nncf/torch/function_hook/model_transformer.py
@@ -18,12 +18,12 @@
 from nncf.common.graph.transformations.commands import Command
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.layout import TransformationLayout
-from nncf.experimental.torch2.commands import PT2ConstUpdateCommand
-from nncf.experimental.torch2.commands import PT2InsertionCommand
-from nncf.experimental.torch2.function_hook.hook_storage import RemovableHookHandle
-from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
-from nncf.experimental.torch2.function_hook.wrapper import register_post_function_hook
-from nncf.experimental.torch2.function_hook.wrapper import register_pre_function_hook
+from nncf.torch.function_hook.commands import PT2ConstUpdateCommand
+from nncf.torch.function_hook.commands import PT2InsertionCommand
+from nncf.torch.function_hook.hook_storage import RemovableHookHandle
+from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
+from nncf.torch.function_hook.wrapper import register_post_function_hook
+from nncf.torch.function_hook.wrapper import register_pre_function_hook
 from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
 from nncf.torch.model_graph_manager import set_const_data
diff --git a/nncf/experimental/torch2/function_hook/graph/__init__.py b/nncf/torch/function_hook/nncf_graph/__init__.py
similarity index 100%
rename from nncf/experimental/torch2/function_hook/graph/__init__.py
rename to nncf/torch/function_hook/nncf_graph/__init__.py
diff --git a/nncf/experimental/torch2/function_hook/nncf_graph/layer_attributes.py b/nncf/torch/function_hook/nncf_graph/layer_attributes.py
similarity index 100%
rename from nncf/experimental/torch2/function_hook/nncf_graph/layer_attributes.py
rename to nncf/torch/function_hook/nncf_graph/layer_attributes.py
diff --git a/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py b/nncf/torch/function_hook/nncf_graph/nncf_graph_builder.py
similarity index 94%
rename from nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py
rename to nncf/torch/function_hook/nncf_graph/nncf_graph_builder.py
index c815405741f..1f10debf401 100644
--- a/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py
+++ b/nncf/torch/function_hook/nncf_graph/nncf_graph_builder.py
@@ -23,13 +23,13 @@
 from nncf.common.graph.layer_attributes import BaseLayerAttributes
 from nncf.common.graph.layer_attributes import ConstantLayerAttributes
 from nncf.common.graph.layer_attributes import Dtype
-from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph
-from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta
-from nncf.experimental.torch2.function_hook.graph.graph_utils import EdgeMeta
-from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta
-from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta
-from nncf.experimental.torch2.function_hook.graph.graph_utils import NodeType
-from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes
+from nncf.torch.function_hook.graph.build_graph_mode import build_graph
+from nncf.torch.function_hook.graph.graph_utils import ConstMeta
+from nncf.torch.function_hook.graph.graph_utils import EdgeMeta
+from nncf.torch.function_hook.graph.graph_utils import FunctionMeta
+from nncf.torch.function_hook.graph.graph_utils import InOutMeta
+from nncf.torch.function_hook.graph.graph_utils import NodeType
+from nncf.torch.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes
 from nncf.torch.graph.graph import PTNNCFGraph
diff --git a/nncf/experimental/torch2/function_hook/nncf_graph/__init__.py b/nncf/torch/function_hook/quantization/__init__.py
similarity index 100%
rename from nncf/experimental/torch2/function_hook/nncf_graph/__init__.py
rename to nncf/torch/function_hook/quantization/__init__.py
diff --git a/nncf/experimental/torch2/quantization/quantize_model.py b/nncf/torch/function_hook/quantization/quantize_model.py
similarity index 94%
rename from nncf/experimental/torch2/quantization/quantize_model.py
rename to nncf/torch/function_hook/quantization/quantize_model.py
index efa17fc11c9..df65082e680 100644
--- a/nncf/experimental/torch2/quantization/quantize_model.py
+++ b/nncf/torch/function_hook/quantization/quantize_model.py
@@ -17,8 +17,6 @@
 import nncf
 from nncf.common.quantization.structs import QuantizationPreset
 from nncf.data import Dataset
-from nncf.experimental.torch2.function_hook import wrap_model
-from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.parameters import ModelType
 from nncf.parameters import QuantizationMode
 from nncf.parameters import TargetDevice
@@ -26,6 +24,8 @@
 from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
 from nncf.quantization.quantize_model import warning_model_no_batchwise_support
 from nncf.scopes import IgnoredScope
+from nncf.torch.function_hook import wrap_model
+from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.torch.graph.operator_metatypes import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
diff --git a/nncf/experimental/torch2/function_hook/serialization.py b/nncf/torch/function_hook/serialization.py
similarity index 97%
rename from nncf/experimental/torch2/function_hook/serialization.py
rename to nncf/torch/function_hook/serialization.py
index 7cd95ca753f..d60ba3ead6b 100644
--- a/nncf/experimental/torch2/function_hook/serialization.py
+++ b/nncf/torch/function_hook/serialization.py
@@ -16,8 +16,8 @@
 import nncf
 from nncf.common.logging import nncf_logger
-from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
-from nncf.experimental.torch2.function_hook.wrapper import wrap_model
+from nncf.torch.function_hook.wrapper import get_hook_storage
+from nncf.torch.function_hook.wrapper import wrap_model
 from nncf.torch.layer_utils import COMPRESSION_MODULES
 from nncf.torch.layer_utils import StatefulModuleInterface
 from nncf.torch.utils import get_model_device
diff --git a/nncf/experimental/torch2/quantization/__init__.py b/nncf/torch/function_hook/statistics/__init__.py
similarity index 100%
rename from nncf/experimental/torch2/quantization/__init__.py
rename to nncf/torch/function_hook/statistics/__init__.py
diff --git a/nncf/experimental/torch2/statistics/aggregator.py b/nncf/torch/function_hook/statistics/aggregator.py
similarity index 94%
rename from nncf/experimental/torch2/statistics/aggregator.py
rename to nncf/torch/function_hook/statistics/aggregator.py
index eadaf940be2..abb9b9f8cf0 100644
--- a/nncf/experimental/torch2/statistics/aggregator.py
+++ b/nncf/torch/function_hook/statistics/aggregator.py
@@ -24,10 +24,10 @@
 from nncf.data.dataset import Dataset
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic
-from nncf.experimental.torch2.commands import PT2InsertionCommand
-from nncf.experimental.torch2.function_hook.hook_storage import RemovableHookHandle
-from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.tensor import Tensor
+from nncf.torch.function_hook.commands import PT2InsertionCommand
+from nncf.torch.function_hook.hook_storage import RemovableHookHandle
+from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.torch.graph.transformations.commands import PTTargetPoint
diff --git a/nncf/experimental/torch2/function_hook/strip.py b/nncf/torch/function_hook/strip.py
similarity index 93%
rename from nncf/experimental/torch2/function_hook/strip.py
rename to nncf/torch/function_hook/strip.py
index 2f0c5f87cbc..43ca1e296a7 100644
--- a/nncf/experimental/torch2/function_hook/strip.py
+++ b/nncf/torch/function_hook/strip.py
@@ -17,15 +17,14 @@
 import nncf
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.layer_attributes import ConstantLayerAttributes
-from nncf.experimental.torch2.function_hook.hook_storage import decode_hook_name
-from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph
-from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
 from nncf.parameters import StripFormat
+from nncf.torch.function_hook.hook_storage import decode_hook_name
+from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph
+from nncf.torch.function_hook.wrapper import get_hook_storage
 from nncf.torch.model_graph_manager import get_const_data
 from nncf.torch.model_graph_manager import get_const_node
 from nncf.torch.model_graph_manager import get_module_by_name
 from nncf.torch.model_graph_manager import split_const_name
-from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.quantization.layers import AsymmetricQuantizer
 from nncf.torch.quantization.layers import BaseQuantizer
 from nncf.torch.quantization.layers import SymmetricQuantizer
@@ -36,9 +35,7 @@
 TModel = TypeVar("TModel", bound=nn.Module)
 
 
-def strip_quantized_model(
-    model: NNCFNetwork, example_input: Any, strip_format: StripFormat = StripFormat.NATIVE
-) -> NNCFNetwork:
+def strip_quantized_model(model: TModel, example_input: Any, strip_format: StripFormat = StripFormat.NATIVE) -> TModel:
     """
     Removes auxiliary layers and operations added during the quantization process, resulting in a clean
     quantized model ready for deployment. The functionality of the model object is still preserved
diff --git a/nncf/experimental/torch2/function_hook/weak_map.py b/nncf/torch/function_hook/weak_map.py
similarity index 100%
rename from nncf/experimental/torch2/function_hook/weak_map.py
rename to nncf/torch/function_hook/weak_map.py
diff --git a/nncf/experimental/torch2/function_hook/wrapper.py b/nncf/torch/function_hook/wrapper.py
similarity index 97%
rename from nncf/experimental/torch2/function_hook/wrapper.py
rename to nncf/torch/function_hook/wrapper.py
index 1a334bbb17f..97cb8468841 100644
--- a/nncf/experimental/torch2/function_hook/wrapper.py
+++ b/nncf/torch/function_hook/wrapper.py
@@ -19,9 +19,9 @@
 from torch import nn
 
 import nncf
-from nncf.experimental.torch2.function_hook.hook_executor_mode import FunctionHookMode
-from nncf.experimental.torch2.function_hook.hook_storage import HookStorage
-from nncf.experimental.torch2.function_hook.hook_storage import RemovableHookHandle
+from nncf.torch.function_hook.hook_executor_mode import FunctionHookMode
+from nncf.torch.function_hook.hook_storage import HookStorage
+from nncf.torch.function_hook.hook_storage import RemovableHookHandle
 
 ATR_HOOK_STORAGE = "__nncf_hooks"
diff --git a/nncf/torch/graph/graph.py b/nncf/torch/graph/graph.py
index 7a02e8b42d9..9e2f15f87df 100644
--- a/nncf/torch/graph/graph.py
+++ b/nncf/torch/graph/graph.py
@@ -13,14 +13,14 @@
 from typing import Dict, List, Tuple
 
 import nncf
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
 from nncf.common.graph import NNCFNodeName
 from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
-from nncf.experimental.torch2.function_hook.graph.graph_utils import TensorMeta
-from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes
 from nncf.torch.dynamic_graph.scope import Scope
+from nncf.torch.function_hook.graph.graph_utils import TensorMeta
+from nncf.torch.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes
 from nncf.torch.graph.transformations.commands import PTTargetPoint
@@ -106,21 +106,7 @@ def get_nodes_with_missed_input_edges(self) -> List[NNCFNode]:
         :return: List of NNCFNodes that are identified as disconnected.
         """
         input_nodes = set()
-        if is_torch_tracing_by_torch_function_mode():
-            # Check expected number of input edges by counting TensorMeta in op_args and op_kwargs.
-            for node in self.get_all_nodes():
-                input_edges = len(self.get_input_edges(node))
-                if not isinstance(node.layer_attributes, PT2OpLayerAttributes):
-                    continue
-                num_expected_input_edges = 0
-                for val in chain(node.layer_attributes.op_args, node.layer_attributes.op_kwargs.values()):
-                    if isinstance(val, TensorMeta):
-                        num_expected_input_edges += 1
-                    if isinstance(val, (list, tuple)):
-                        num_expected_input_edges += sum(isinstance(v, TensorMeta) for v in val)
-                if input_edges < num_expected_input_edges:
-                    input_nodes.add(node)
-        else:
+        if is_torch_tracing_by_patching():
             for node in self.get_all_nodes():
                 num_expected_input_edges = None
                 if hasattr(node.metatype, "num_expected_input_edges"):
@@ -135,4 +121,19 @@ def get_nodes_with_missed_input_edges(self) -> List[NNCFNode]:
                     # If node has missed input edges we assume this node is an input node
                    # that was disconnected from an activation input.
                     input_nodes.add(node)
+        else:
+            # Check expected number of input edges by counting TensorMeta in op_args and op_kwargs.
+            for node in self.get_all_nodes():
+                input_edges = len(self.get_input_edges(node))
+                if not isinstance(node.layer_attributes, PT2OpLayerAttributes):
+                    continue
+                num_expected_input_edges = 0
+                for val in chain(node.layer_attributes.op_args, node.layer_attributes.op_kwargs.values()):
+                    if isinstance(val, TensorMeta):
+                        num_expected_input_edges += 1
+                    if isinstance(val, (list, tuple)):
+                        num_expected_input_edges += sum(isinstance(v, TensorMeta) for v in val)
+                if input_edges < num_expected_input_edges:
+                    input_nodes.add(node)
+
         return list(input_nodes)
diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py
index 318689a1cee..ce74b62fdb9 100644
--- a/nncf/torch/graph/operator_metatypes.py
+++ b/nncf/torch/graph/operator_metatypes.py
@@ -12,6 +12,7 @@
 from typing import Dict, List, Optional, Type, TypeVar
 
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.common.graph.definitions import NNCFGraphNodeType
 from nncf.common.graph.layer_attributes import BaseLayerAttributes
 from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
@@ -22,7 +23,6 @@
 from nncf.common.graph.operator_metatypes import OperatorMetatype
 from nncf.common.graph.operator_metatypes import OperatorMetatypeRegistry
 from nncf.common.hardware.opset import HWConfigOpName
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
 from nncf.torch.dynamic_graph.graph import DynamicGraph
 from nncf.torch.dynamic_graph.structs import NamespaceTarget
@@ -729,14 +729,14 @@ class PTBatchNormMetatype(PTOperatorMetatype):
     }
     subtypes = [PTModuleBatchNormMetatype]
 
-    if is_torch_tracing_by_torch_function_mode():
-        # torch.batch_norm
-        weight_port_ids = [1]
-        bias_port_id = 2
-    else:
+    if is_torch_tracing_by_patching():
         # torch.nn.functional.batch_norm
         weight_port_ids = [3]
         bias_port_id = 4
+    else:
+        # torch.batch_norm
+        weight_port_ids = [1]
+        bias_port_id = 2
 
 
 @PT_OPERATOR_METATYPES.register()
diff --git a/nncf/torch/model_creation.py b/nncf/torch/model_creation.py
index 245a14e31c6..6bb429b97f4 100644
--- a/nncf/torch/model_creation.py
+++ b/nncf/torch/model_creation.py
@@ -18,6 +18,7 @@
 import nncf
 from nncf.api.compression import CompressionAlgorithmController
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.common.compression import BaseCompressionAlgorithmController as BaseController
 from nncf.common.deprecation import warning_deprecated
 from nncf.common.logging import nncf_logger
@@ -27,9 +28,6 @@
 from nncf.config.extractors import extract_algorithm_names
 from nncf.config.extractors import has_input_info_field
 from nncf.config.telemetry_extractors import CompressionStartedFromConfig
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
-from nncf.experimental.torch2.function_hook.serialization import get_config as pt2_get_config
-from nncf.experimental.torch2.function_hook.serialization import load_from_config as pt2_load_from_config
 from nncf.telemetry import tracked_function
 from nncf.telemetry.events import NNCF_PT_CATEGORY
 from nncf.telemetry.extractors import FunctionCallTelemetryExtractor
@@ -45,6 +43,8 @@
 from nncf.torch.dynamic_graph.io_handling import LoaderInputInfo
 from nncf.torch.dynamic_graph.io_handling import ModelInputInfo
 from nncf.torch.dynamic_graph.patch_pytorch_state import PATCHING_STATE
+from nncf.torch.function_hook.serialization import get_config as pt2_get_config
+from nncf.torch.function_hook.serialization import load_from_config as pt2_load_from_config
 from nncf.torch.graph.transformations.serialization import deserialize_transformations
 from nncf.torch.model_transformer import PTModelTransformer
 from nncf.torch.nncf_network import NNCFNetwork
@@ -361,35 +361,35 @@ def wrap_model(
     :param trace_parameters: Whether to trace model parameters. Default is False.
     :return: A model wrapped by NNCFNetwork or GraphModelWrapper if experimental PyTorch model tracing is enabled.
     """
-    if is_torch_tracing_by_torch_function_mode():
-        if not trace_parameters:
-            msg = "The 'trace_parameters=False' option is not supported in the experimental tracing mode."
-            raise nncf.InternalError(msg)
-        from nncf.experimental.torch2.function_hook import is_wrapped as pt2_is_wrapped
-        from nncf.experimental.torch2.function_hook import wrap_model as pt2_wrap_model
-        from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
+    if is_torch_tracing_by_patching():
+        if not isinstance(model, torch.nn.Module):
+            msg = (
+                f"The provided model type {type(model)} is incompatible. "
+                "Only models inheriting from torch.nn.Module are supported."
+            )
+            raise TypeError(msg)
 
-        if not pt2_is_wrapped(model):
-            model = pt2_wrap_model(model)
-        wrapped_model = GraphModelWrapper(model, example_input=example_input)
-        return wrapped_model
+        input_info = ExampleInputInfo.from_example_input(example_input)
 
-    if not isinstance(model, torch.nn.Module):
-        msg = (
-            f"The provided model type {type(model)} is incompatible. "
-            "Only models inheriting from torch.nn.Module are supported."
-        )
-        raise TypeError(msg)
+        with training_mode_switcher(model, is_training=False):
+            nncf_network = NNCFNetwork(
+                model, input_info=input_info, replace_modules=not trace_parameters, trace_parameters=trace_parameters
+            )
+            nncf_network.nncf.get_tracing_context().disable_trace_dynamic_graph()
 
-    input_info = ExampleInputInfo.from_example_input(example_input)
+        return nncf_network
 
-    with training_mode_switcher(model, is_training=False):
-        nncf_network = NNCFNetwork(
-            model, input_info=input_info, replace_modules=not trace_parameters, trace_parameters=trace_parameters
-        )
-        nncf_network.nncf.get_tracing_context().disable_trace_dynamic_graph()
+    if not trace_parameters:
+        msg = "The 'trace_parameters=False' option is not supported in the experimental tracing mode."
+        raise nncf.InternalError(msg)
+    from nncf.torch.function_hook import is_wrapped as pt2_is_wrapped
+    from nncf.torch.function_hook import wrap_model as pt2_wrap_model
+    from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 
-    return nncf_network
+    if not pt2_is_wrapped(model):
+        model = pt2_wrap_model(model)
+    wrapped_model = GraphModelWrapper(model, example_input=example_input)
+    return wrapped_model
 
 
 def is_wrapped_model(model: Any) -> bool:
@@ -399,7 +399,7 @@ def is_wrapped_model(model: Any) -> bool:
     :param model: A model.
     :return: True if the model is wrapped, False otherwise.
     """
-    from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
+    from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 
     return isinstance(model, (NNCFNetwork, GraphModelWrapper))
@@ -422,16 +422,16 @@ def load_from_config(model: Module, config: Dict[str, Any], example_input: Optio
         of keywords arguments. Required with enabled legacy tracing mode.
     :return: Wrapped model with additional modules recovered from given config.
     """
-    if is_torch_tracing_by_torch_function_mode():
-        return pt2_load_from_config(model, config)
+    if is_torch_tracing_by_patching():
+        if example_input is None:
+            msg = "The 'example_input' parameter must be specified."
+            raise nncf.InternalError(msg)
 
-    if example_input is None:
-        msg = "The 'example_input' parameter must be specified."
-        raise nncf.InternalError(msg)
+        nncf_network = wrap_model(model, example_input, trace_parameters=config[NNCFNetwork.TRACE_PARAMETERS_KEY])
+        transformation_layout = deserialize_transformations(config)
+        return PTModelTransformer(nncf_network).transform(transformation_layout)
 
-    nncf_network = wrap_model(model, example_input, trace_parameters=config[NNCFNetwork.TRACE_PARAMETERS_KEY])
-    transformation_layout = deserialize_transformations(config)
-    return PTModelTransformer(nncf_network).transform(transformation_layout)
+    return pt2_load_from_config(model, config)
 
 
 @tracked_function(
@@ -447,6 +447,6 @@ def get_config(model: Module) -> Dict[str, Any]:
     :param model: The compressed model.
     :return: The configuration object of the compressed model.
     """
-    if is_torch_tracing_by_torch_function_mode():
-        return pt2_get_config(model)
-    return model.nncf.get_config()
+    if is_torch_tracing_by_patching():
+        return model.nncf.get_config()
+    return pt2_get_config(model)
diff --git a/nncf/torch/quantization/quantize_model.py b/nncf/torch/quantization/quantize_model.py
index 1a5424d2ec2..46f1b70adc4 100644
--- a/nncf/torch/quantization/quantize_model.py
+++ b/nncf/torch/quantization/quantize_model.py
@@ -18,7 +18,6 @@
 from nncf.common.factory import NNCFGraphFactory
 from nncf.common.quantization.structs import QuantizationPreset
 from nncf.data import Dataset
-from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.parameters import BackupMode
 from nncf.parameters import CompressionFormat
 from nncf.parameters import CompressWeightsMode
@@ -32,6 +31,7 @@
 from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression
 from nncf.quantization.quantize_model import warning_model_no_batchwise_support
 from nncf.scopes import IgnoredScope
+from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.torch.graph.operator_metatypes import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
 from nncf.torch.model_creation import wrap_model
diff --git a/nncf/torch/strip.py b/nncf/torch/strip.py
index eadf12cfa22..cd14e492e30 100644
--- a/nncf/torch/strip.py
+++ b/nncf/torch/strip.py
@@ -11,20 +11,23 @@
 
 from copy import deepcopy
-from typing import Any, Optional
+from typing import Any, Optional, TypeVar
+
+from torch import nn
 
 import nncf
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.parameters import StripFormat
-from nncf.torch.nncf_network import NNCFNetwork
+
+TModel = TypeVar("TModel", bound=nn.Module)
 
 
 def strip(
-    model: NNCFNetwork,
+    model: TModel,
     do_copy: bool = True,
     strip_format: StripFormat = StripFormat.NATIVE,
     example_input: Optional[Any] = None,
-) -> NNCFNetwork:
+) -> TModel:
     """
     Removes auxiliary layers and operations added during the compression process, resulting in a clean
     model ready for deployment. The functionality of the model object is still preserved as a compressed model.
@@ -35,13 +38,13 @@ def strip(
     :param example_input: An example input tensor to be used for tracing the model.
     :return: The stripped model.
     """
-    if is_torch_tracing_by_torch_function_mode():
-        from nncf.experimental.torch2.function_hook.strip import strip_quantized_model
+    if is_torch_tracing_by_patching():
+        return model.nncf.strip(do_copy, strip_format)
 
-        if example_input is None:
-            msg = "Required example_input for strip model."
-            raise nncf.InternalError(msg)
-        model = deepcopy(model) if do_copy else model
-        return strip_quantized_model(model, example_input, strip_format)
+    from nncf.torch.function_hook.strip import strip_quantized_model
 
-    return model.nncf.strip(do_copy, strip_format)
+    if example_input is None:
+        msg = "Required example_input for strip model."
+ raise nncf.InternalError(msg) + model = deepcopy(model) if do_copy else model + return strip_quantized_model(model, example_input, strip_format) diff --git a/pyproject.toml b/pyproject.toml index 64db185b55f..b34f6e2dca5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,7 @@ files = [ "nncf/data", "nncf/common", "nncf/config", - "nncf/experimental/torch2", + "nncf/torch/function_hook", "nncf/quantization/*py", "nncf/telemetry/", "nncf/tensor/", @@ -102,7 +102,7 @@ disable_error_code = ["empty-body", "no-any-return"] line-length = 120 exclude = [ "nncf/tensorflow/__init__.py", - "nncf/experimental/torch2/function_hook/handle_inner_functions.py" + "nncf/torch/function_hook/handle_inner_functions.py" ] [tool.ruff.lint] diff --git a/tests/post_training/experimental/sparsify_activations/pipelines.py b/tests/post_training/experimental/sparsify_activations/pipelines.py index c4a7265c6f5..41f0e09e89e 100644 --- a/tests/post_training/experimental/sparsify_activations/pipelines.py +++ b/tests/post_training/experimental/sparsify_activations/pipelines.py @@ -29,7 +29,7 @@ from nncf.experimental.torch.sparsify_activations import sparsify_activations from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import SparsifyActivationsAlgoBackend from nncf.experimental.torch.sparsify_activations.torch_backend import PTSparsifyActivationsAlgoBackend -from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage +from nncf.torch.function_hook.wrapper import get_hook_storage from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor from tests.post_training.pipelines.base import PT_BACKENDS diff --git a/tests/torch2/function_hook/graph/test_build_graph_mode.py b/tests/torch2/function_hook/graph/test_build_graph_mode.py index 8ec4f7655bf..8be29d135e1 100644 --- a/tests/torch2/function_hook/graph/test_build_graph_mode.py +++ b/tests/torch2/function_hook/graph/test_build_graph_mode.py @@ -16,20 +16,20 @@ from pytest import FixtureRequest from torch import nn -from nncf.experimental.torch2.function_hook.graph.build_graph_mode import GraphBuilderMode -from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph -from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta -from nncf.experimental.torch2.function_hook.graph.graph_utils import EdgeMeta -from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta -from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta -from nncf.experimental.torch2.function_hook.graph.graph_utils import NodeType -from nncf.experimental.torch2.function_hook.graph.graph_utils import TensorInfo -from nncf.experimental.torch2.function_hook.graph.graph_utils import TensorMeta -from nncf.experimental.torch2.function_hook.graph.graph_utils import TensorSource -from nncf.experimental.torch2.function_hook.hook_executor_mode import OpMeta -from nncf.experimental.torch2.function_hook.hook_storage import HookStorage -from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage -from nncf.experimental.torch2.function_hook.wrapper import wrap_model +from nncf.torch.function_hook.graph.build_graph_mode import GraphBuilderMode +from nncf.torch.function_hook.graph.build_graph_mode import build_graph +from nncf.torch.function_hook.graph.graph_utils import ConstMeta +from nncf.torch.function_hook.graph.graph_utils import EdgeMeta +from 
nncf.torch.function_hook.graph.graph_utils import FunctionMeta +from nncf.torch.function_hook.graph.graph_utils import InOutMeta +from nncf.torch.function_hook.graph.graph_utils import NodeType +from nncf.torch.function_hook.graph.graph_utils import TensorInfo +from nncf.torch.function_hook.graph.graph_utils import TensorMeta +from nncf.torch.function_hook.graph.graph_utils import TensorSource +from nncf.torch.function_hook.hook_executor_mode import OpMeta +from nncf.torch.function_hook.hook_storage import HookStorage +from nncf.torch.function_hook.wrapper import get_hook_storage +from nncf.torch.function_hook.wrapper import wrap_model from tests.torch2.function_hook import helpers diff --git a/tests/torch2/function_hook/graph/test_graph_visualisation.py b/tests/torch2/function_hook/graph/test_graph_visualisation.py index ce819298778..25f451c2765 100644 --- a/tests/torch2/function_hook/graph/test_graph_visualisation.py +++ b/tests/torch2/function_hook/graph/test_graph_visualisation.py @@ -12,9 +12,9 @@ import pytest -from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph -from nncf.experimental.torch2.function_hook.graph.graph_visualization import PydotStyleTemplate -from nncf.experimental.torch2.function_hook.graph.graph_visualization import to_pydot +from nncf.torch.function_hook.graph.build_graph_mode import build_graph +from nncf.torch.function_hook.graph.graph_visualization import PydotStyleTemplate +from nncf.torch.function_hook.graph.graph_visualization import to_pydot from tests.cross_fw.shared.paths import TEST_ROOT from tests.torch2.function_hook import helpers from tests.torch2.utils import compare_with_reference_file diff --git a/tests/torch2/function_hook/helpers.py b/tests/torch2/function_hook/helpers.py index f21124d70de..a324cebf7c7 100644 --- a/tests/torch2/function_hook/helpers.py +++ b/tests/torch2/function_hook/helpers.py @@ -12,8 +12,8 @@ import torch from torch import nn -from nncf.experimental.torch2.function_hook.wrapper import register_post_function_hook -from nncf.experimental.torch2.function_hook.wrapper import wrap_model +from nncf.torch.function_hook.wrapper import register_post_function_hook +from nncf.torch.function_hook.wrapper import wrap_model from nncf.torch.layer_utils import COMPRESSION_MODULES from nncf.torch.layer_utils import StatefulModuleInterface diff --git a/tests/torch2/function_hook/nncf_graph/test_layer_attributes.py b/tests/torch2/function_hook/nncf_graph/test_layer_attributes.py index 43007a4264d..8d12c0e4422 100644 --- a/tests/torch2/function_hook/nncf_graph/test_layer_attributes.py +++ b/tests/torch2/function_hook/nncf_graph/test_layer_attributes.py @@ -15,10 +15,10 @@ import torch from torch import nn -from nncf.experimental.torch2.function_hook.graph.graph_utils import TensorMeta -from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph -from nncf.experimental.torch2.function_hook.wrapper import wrap_model +from nncf.torch.function_hook.graph.graph_utils import TensorMeta +from nncf.torch.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph +from nncf.torch.function_hook.wrapper import wrap_model from tests.torch2.function_hook.helpers import ConvModel from tests.torch2.function_hook.helpers import MatMulLeft from tests.torch2.function_hook.helpers import MatMulRight 
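The import rewrites in these test files (and in those that follow) are the mechanical side of relocating the hook API from nncf.experimental.torch2.function_hook to nncf.torch.function_hook; the module contents themselves are unchanged. Downstream code that must run against both pre-move and post-move NNCF releases can bridge the two layouts with a guarded import. The shim below is a hedged sketch for such projects, not something this diff adds:

# Prefer the new stable location; fall back to the old experimental path on
# NNCF releases that predate this move. Shown for get_hook_storage / wrap_model;
# the same pattern applies to the other relocated function_hook modules.
try:
    from nncf.torch.function_hook.wrapper import get_hook_storage, wrap_model
except ImportError:
    from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage, wrap_model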
diff --git a/tests/torch2/function_hook/nncf_graph/test_nncf_graph.py b/tests/torch2/function_hook/nncf_graph/test_nncf_graph.py index 26dfdda9ce6..30418001f24 100644 --- a/tests/torch2/function_hook/nncf_graph/test_nncf_graph.py +++ b/tests/torch2/function_hook/nncf_graph/test_nncf_graph.py @@ -20,20 +20,20 @@ import torchvision.models as models from nncf.common.graph.layer_attributes import Dtype -from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph -from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta -from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta -from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta -from nncf.experimental.torch2.function_hook.graph.graph_utils import NodeType -from nncf.experimental.torch2.function_hook.graph.graph_utils import TensorMeta -from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import convert_to_nncf_graph -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import get_dtype -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import get_name_of_node -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import get_node_type -from nncf.experimental.torch2.function_hook.wrapper import register_post_function_hook -from nncf.experimental.torch2.function_hook.wrapper import wrap_model +from nncf.torch.function_hook.graph.build_graph_mode import build_graph +from nncf.torch.function_hook.graph.graph_utils import ConstMeta +from nncf.torch.function_hook.graph.graph_utils import FunctionMeta +from nncf.torch.function_hook.graph.graph_utils import InOutMeta +from nncf.torch.function_hook.graph.graph_utils import NodeType +from nncf.torch.function_hook.graph.graph_utils import TensorMeta +from nncf.torch.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import convert_to_nncf_graph +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import get_dtype +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import get_name_of_node +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import get_node_type +from nncf.torch.function_hook.wrapper import register_post_function_hook +from nncf.torch.function_hook.wrapper import wrap_model from nncf.torch.graph.graph import PTNNCFGraph from nncf.torch.graph.operator_metatypes import PTCatMetatype from nncf.torch.graph.operator_metatypes import PTConv2dMetatype diff --git a/tests/torch2/function_hook/quantization/strip/test_strip_dequantize.py b/tests/torch2/function_hook/quantization/strip/test_strip_dequantize.py index df021c27317..2d07fe35e9e 100644 --- a/tests/torch2/function_hook/quantization/strip/test_strip_dequantize.py +++ b/tests/torch2/function_hook/quantization/strip/test_strip_dequantize.py @@ -20,9 +20,9 @@ import nncf import nncf.torch from nncf.common.quantization.structs import QuantizationScheme -from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage from nncf.parameters import CompressWeightsMode from nncf.parameters import StripFormat +from nncf.torch.function_hook.wrapper import get_hook_storage from 
nncf.torch.quantization.layers import AsymmetricLoraQuantizer from nncf.torch.quantization.layers import BaseQuantizer from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor as INT4AsymDQ diff --git a/tests/torch2/function_hook/quantization/strip/test_strip_native.py b/tests/torch2/function_hook/quantization/strip/test_strip_native.py index 39c96fc7dbe..6f3c21eb882 100644 --- a/tests/torch2/function_hook/quantization/strip/test_strip_native.py +++ b/tests/torch2/function_hook/quantization/strip/test_strip_native.py @@ -18,8 +18,8 @@ import nncf import nncf.torch -from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage from nncf.quantization.advanced_parameters import OverflowFix +from nncf.torch.function_hook.wrapper import get_hook_storage from tests.common.quantization.data_generators import generate_lazy_sweep_data from tests.torch.helpers import BasicConvTestModel diff --git a/tests/torch2/function_hook/quantization/test_calculation_quantizer_params.py b/tests/torch2/function_hook/quantization/test_calculation_quantizer_params.py index 66d6d8b228d..56abadd23e4 100644 --- a/tests/torch2/function_hook/quantization/test_calculation_quantizer_params.py +++ b/tests/torch2/function_hook/quantization/test_calculation_quantizer_params.py @@ -25,14 +25,14 @@ from nncf.common.quantization.structs import QuantizerConfig from nncf.common.quantization.structs import QuantizerGroup from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper -from nncf.experimental.torch2.statistics.aggregator import PT2StatisticsAggregator from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend from nncf.quantization.fake_quantize import FakeQuantizeParameters from nncf.quantization.fake_quantize import calculate_quantizer_parameters from nncf.tensor import Tensor from nncf.tensor import functions as fns +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper +from nncf.torch.function_hook.statistics.aggregator import PT2StatisticsAggregator from nncf.torch.model_creation import wrap_model from tests.cross_fw.test_templates.test_calculate_quantizer_parameters import TemplateTestFQParams from tests.torch.helpers import get_all_inputs_for_graph_node diff --git a/tests/torch2/function_hook/quantization/test_fast_bias_correction.py b/tests/torch2/function_hook/quantization/test_fast_bias_correction.py index 09c67004104..630f1a9a018 100644 --- a/tests/torch2/function_hook/quantization/test_fast_bias_correction.py +++ b/tests/torch2/function_hook/quantization/test_fast_bias_correction.py @@ -14,9 +14,9 @@ import pytest import torch -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper -from nncf.experimental.torch2.function_hook.wrapper import wrap_model from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper +from nncf.torch.function_hook.wrapper import wrap_model from nncf.torch.model_graph_manager import get_fused_bias_value from nncf.torch.model_graph_manager import is_node_with_fused_bias from tests.cross_fw.test_templates.test_fast_bias_correction import TemplateTestFBCAlgorithm diff --git 
a/tests/torch2/function_hook/quantization/test_fq_lora.py b/tests/torch2/function_hook/quantization/test_fq_lora.py index 15df332e0c6..5c42e8e6a31 100644 --- a/tests/torch2/function_hook/quantization/test_fq_lora.py +++ b/tests/torch2/function_hook/quantization/test_fq_lora.py @@ -24,13 +24,13 @@ import nncf from nncf.data.dataset import Dataset from nncf.errors import ValidationError -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph from nncf.parameters import CompressionFormat from nncf.parameters import CompressWeightsMode from nncf.parameters import StripFormat from nncf.quantization.advanced_parameters import AdvancedCompressionParameters from nncf.quantization.quantize_model import compress_weights from nncf.torch import load_from_config +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph from nncf.torch.model_creation import get_config from nncf.torch.model_creation import wrap_model from nncf.torch.quantization.layers import AsymmetricQuantizer as AQ diff --git a/tests/torch2/function_hook/quantization/test_ptq_params.py b/tests/torch2/function_hook/quantization/test_ptq_params.py index c9540e769cc..fe6c7c7d51c 100644 --- a/tests/torch2/function_hook/quantization/test_ptq_params.py +++ b/tests/torch2/function_hook/quantization/test_ptq_params.py @@ -18,10 +18,10 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.commands import TransformationType from nncf.common.utils.backend import BackendType -from nncf.experimental.torch2.commands import PT2InsertionCommand from nncf.parameters import TargetDevice from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend from nncf.scopes import IgnoredScope +from nncf.torch.function_hook.commands import PT2InsertionCommand from nncf.torch.graph.graph import PTNNCFGraph from nncf.torch.graph.graph import PTTargetPoint from nncf.torch.graph.operator_metatypes import PTCatMetatype diff --git a/tests/torch2/function_hook/quantization/test_quantized_graphs.py b/tests/torch2/function_hook/quantization/test_quantized_graphs.py index 81dafaf7648..99415fc2a96 100644 --- a/tests/torch2/function_hook/quantization/test_quantized_graphs.py +++ b/tests/torch2/function_hook/quantization/test_quantized_graphs.py @@ -17,8 +17,8 @@ from networkx.drawing.nx_pydot import to_pydot import nncf -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph from nncf.parameters import ModelType +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph from tests.cross_fw.shared.paths import TEST_ROOT from tests.cross_fw.test_templates.helpers import EmbeddingModel from tests.cross_fw.test_templates.helpers import RoPEModel diff --git a/tests/torch2/function_hook/quantization/test_smooth_quant.py b/tests/torch2/function_hook/quantization/test_smooth_quant.py index 350d70e5b12..c6d41efe130 100644 --- a/tests/torch2/function_hook/quantization/test_smooth_quant.py +++ b/tests/torch2/function_hook/quantization/test_smooth_quant.py @@ -16,9 +16,9 @@ import pytest import torch -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper -from nncf.experimental.torch2.function_hook.wrapper import wrap_model from nncf.quantization.algorithms.smooth_quant.torch_backend import PTSmoothQuantAlgoBackend +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper +from 
nncf.torch.function_hook.wrapper import wrap_model from nncf.torch.graph.operator_metatypes import PTConv2dMetatype from nncf.torch.graph.operator_metatypes import PTLinearMetatype from tests.cross_fw.test_templates.helpers import ConvTestModel diff --git a/tests/torch2/function_hook/quantization/test_weights_compression.py b/tests/torch2/function_hook/quantization/test_weights_compression.py index 5c983f44932..5daa434f4a2 100644 --- a/tests/torch2/function_hook/quantization/test_weights_compression.py +++ b/tests/torch2/function_hook/quantization/test_weights_compression.py @@ -20,15 +20,15 @@ from nncf import BackupMode from nncf import CompressWeightsMode from nncf import SensitivityMetric -from nncf.experimental.torch2.function_hook import get_hook_storage -from nncf.experimental.torch2.function_hook import wrap_model -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper from nncf.parameters import CompressionFormat from nncf.quantization import compress_weights from nncf.quantization.advanced_parameters import AdvancedCompressionParameters from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply from nncf.tensor import Tensor from nncf.tensor import TensorDataType +from nncf.torch.function_hook import get_hook_storage +from nncf.torch.function_hook import wrap_model +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor diff --git a/tests/torch2/function_hook/sparsify_activations/test_algo.py b/tests/torch2/function_hook/sparsify_activations/test_algo.py index 9f371f5e45f..2946316c416 100644 --- a/tests/torch2/function_hook/sparsify_activations/test_algo.py +++ b/tests/torch2/function_hook/sparsify_activations/test_algo.py @@ -25,10 +25,10 @@ from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import SparsifyActivationsAlgorithm from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import TargetScope from nncf.experimental.torch.sparsify_activations.torch_backend import ActivationsSparsifier -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph -from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage from nncf.scopes import IgnoredScope +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph +from nncf.torch.function_hook.wrapper import get_hook_storage from nncf.torch.model_creation import wrap_model from tests.cross_fw.shared.paths import TEST_ROOT from tests.torch.helpers import set_torch_seed diff --git a/tests/torch2/function_hook/test_disable_tracing.py b/tests/torch2/function_hook/test_disable_tracing.py index f4db97812c9..23d22f99cb0 100644 --- a/tests/torch2/function_hook/test_disable_tracing.py +++ b/tests/torch2/function_hook/test_disable_tracing.py @@ -13,13 +13,13 @@ from torch import nn from torch.overrides import _get_current_function_mode_stack -from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph -from nncf.experimental.torch2.function_hook.graph.graph_visualization import to_pydot -from 
nncf.experimental.torch2.function_hook.hook_executor_mode import FunctionHookMode -from nncf.experimental.torch2.function_hook.hook_executor_mode import disable_function_hook_mode -from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage -from nncf.experimental.torch2.function_hook.wrapper import wrap_model from nncf.torch import disable_tracing +from nncf.torch.function_hook.graph.build_graph_mode import build_graph +from nncf.torch.function_hook.graph.graph_visualization import to_pydot +from nncf.torch.function_hook.hook_executor_mode import FunctionHookMode +from nncf.torch.function_hook.hook_executor_mode import disable_function_hook_mode +from nncf.torch.function_hook.wrapper import get_hook_storage +from nncf.torch.function_hook.wrapper import wrap_model from tests.cross_fw.shared.paths import TEST_ROOT from tests.torch2.utils import compare_with_reference_file diff --git a/tests/torch2/function_hook/test_extractor.py b/tests/torch2/function_hook/test_extractor.py index e786896ec69..ac0a970cbde 100644 --- a/tests/torch2/function_hook/test_extractor.py +++ b/tests/torch2/function_hook/test_extractor.py @@ -14,10 +14,10 @@ from torch import nn import tests.cross_fw.test_templates.helpers as helpers -from nncf.experimental.torch2.function_hook.extractor import extract_model -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph -from nncf.experimental.torch2.function_hook.wrapper import register_pre_function_hook -from nncf.experimental.torch2.function_hook.wrapper import wrap_model +from nncf.torch.function_hook.extractor import extract_model +from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph +from nncf.torch.function_hook.wrapper import register_pre_function_hook +from nncf.torch.function_hook.wrapper import wrap_model from nncf.torch.quantization.layers import PTQuantizerSpec from nncf.torch.quantization.layers import QuantizationMode from nncf.torch.quantization.layers import SymmetricQuantizer diff --git a/tests/torch2/function_hook/test_function_hook_mode.py b/tests/torch2/function_hook/test_function_hook_mode.py index 4a0e0eea447..310d5c48995 100644 --- a/tests/torch2/function_hook/test_function_hook_mode.py +++ b/tests/torch2/function_hook/test_function_hook_mode.py @@ -18,13 +18,13 @@ from pytest import FixtureRequest from torch import nn -from nncf.experimental.torch2.function_hook.hook_executor_mode import FunctionHookMode -from nncf.experimental.torch2.function_hook.hook_executor_mode import OpMeta -from nncf.experimental.torch2.function_hook.hook_executor_mode import generate_normalized_op_name -from nncf.experimental.torch2.function_hook.hook_storage import HookStorage -from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage -from nncf.experimental.torch2.function_hook.wrapper import register_pre_function_hook -from nncf.experimental.torch2.function_hook.wrapper import wrap_model +from nncf.torch.function_hook.hook_executor_mode import FunctionHookMode +from nncf.torch.function_hook.hook_executor_mode import OpMeta +from nncf.torch.function_hook.hook_executor_mode import generate_normalized_op_name +from nncf.torch.function_hook.hook_storage import HookStorage +from nncf.torch.function_hook.wrapper import get_hook_storage +from nncf.torch.function_hook.wrapper import register_pre_function_hook +from nncf.torch.function_hook.wrapper import wrap_model from tests.torch2.function_hook import helpers from tests.torch2.function_hook.helpers import CallCount from 
tests.torch2.function_hook.helpers import CounterHook diff --git a/tests/torch2/function_hook/test_handle_inner_functions.py b/tests/torch2/function_hook/test_handle_inner_functions.py index a24d9f62f11..ec480346350 100644 --- a/tests/torch2/function_hook/test_handle_inner_functions.py +++ b/tests/torch2/function_hook/test_handle_inner_functions.py @@ -18,10 +18,10 @@ import torch.nn.functional as F from torch import nn -from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph -from nncf.experimental.torch2.function_hook.graph.graph_visualization import to_pydot -from nncf.experimental.torch2.function_hook.handle_inner_functions import MAP_HANDLER_TO_INNER_FUNCTION -from nncf.experimental.torch2.function_hook.wrapper import wrap_model +from nncf.torch.function_hook.graph.build_graph_mode import build_graph +from nncf.torch.function_hook.graph.graph_visualization import to_pydot +from nncf.torch.function_hook.handle_inner_functions import MAP_HANDLER_TO_INNER_FUNCTION +from nncf.torch.function_hook.wrapper import wrap_model from tests.cross_fw.shared.paths import TEST_ROOT from tests.torch2.utils import compare_with_reference_file diff --git a/tests/torch2/function_hook/test_hook_storage.py b/tests/torch2/function_hook/test_hook_storage.py index 3e3807a8be7..1cf144ddf1a 100644 --- a/tests/torch2/function_hook/test_hook_storage.py +++ b/tests/torch2/function_hook/test_hook_storage.py @@ -15,8 +15,8 @@ import torch from torch import nn -from nncf.experimental.torch2.function_hook.hook_storage import HookStorage -from nncf.experimental.torch2.function_hook.hook_storage import decode_hook_name +from nncf.torch.function_hook.hook_storage import HookStorage +from nncf.torch.function_hook.hook_storage import decode_hook_name from tests.torch2.function_hook.helpers import CallCount diff --git a/tests/torch2/function_hook/test_serialization.py b/tests/torch2/function_hook/test_serialization.py index 66df90b27a6..27f23242b13 100644 --- a/tests/torch2/function_hook/test_serialization.py +++ b/tests/torch2/function_hook/test_serialization.py @@ -15,12 +15,12 @@ from torch import nn import nncf -from nncf.experimental.torch2.function_hook import get_hook_storage -from nncf.experimental.torch2.function_hook import register_post_function_hook -from nncf.experimental.torch2.function_hook import register_pre_function_hook -from nncf.experimental.torch2.function_hook import wrap_model from nncf.torch import get_config from nncf.torch import load_from_config +from nncf.torch.function_hook import get_hook_storage +from nncf.torch.function_hook import register_post_function_hook +from nncf.torch.function_hook import register_pre_function_hook +from nncf.torch.function_hook import wrap_model from tests.torch2.function_hook.helpers import HookWithState from tests.torch2.function_hook.helpers import SimpleModel diff --git a/tests/torch2/function_hook/test_train.py b/tests/torch2/function_hook/test_train.py index a79766fc3f3..340501f16de 100644 --- a/tests/torch2/function_hook/test_train.py +++ b/tests/torch2/function_hook/test_train.py @@ -19,7 +19,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP import nncf -from nncf.experimental.torch2.function_hook.wrapper import wrap_model +from nncf.torch.function_hook.wrapper import wrap_model from tests.torch2.function_hook.helpers import ConvModel from tests.torch2.function_hook.helpers import get_wrapped_simple_model_with_hook diff --git a/tests/torch2/function_hook/test_weak_map.py 
b/tests/torch2/function_hook/test_weak_map.py index 0f8000c867d..185ac96dd92 100644 --- a/tests/torch2/function_hook/test_weak_map.py +++ b/tests/torch2/function_hook/test_weak_map.py @@ -13,7 +13,7 @@ import torch -from nncf.experimental.torch2.function_hook.weak_map import WeakUnhashableKeyMap +from nncf.torch.function_hook.weak_map import WeakUnhashableKeyMap def test_set_get(): diff --git a/tests/torch2/function_hook/test_wrapper.py b/tests/torch2/function_hook/test_wrapper.py index 8b70bbc3b50..dbf7bd18555 100644 --- a/tests/torch2/function_hook/test_wrapper.py +++ b/tests/torch2/function_hook/test_wrapper.py @@ -18,10 +18,10 @@ import pytest import torch -from nncf.experimental.torch2.function_hook.wrapper import is_wrapped -from nncf.experimental.torch2.function_hook.wrapper import register_post_function_hook -from nncf.experimental.torch2.function_hook.wrapper import register_pre_function_hook -from nncf.experimental.torch2.function_hook.wrapper import wrap_model +from nncf.torch.function_hook.wrapper import is_wrapped +from nncf.torch.function_hook.wrapper import register_post_function_hook +from nncf.torch.function_hook.wrapper import register_pre_function_hook +from nncf.torch.function_hook.wrapper import wrap_model from tests.torch2.function_hook import helpers ADD_VALUE = 2.0
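For orientation after all of the path updates above: the relocated wrapper API keeps its shape. wrap_model attaches the function-hook machinery to a plain torch.nn.Module, is_wrapped reports whether that has happened, get_hook_storage exposes the nn.Module container that holds registered hooks, and register_pre_function_hook / register_post_function_hook attach small hook modules around traced operations. The sketch below is a minimal illustration under those assumptions; the AddConst hook, the example model, and in particular the commented registration call's argument layout are illustrative guesses, not an API reference.

import torch
from torch import nn

from nncf.torch.function_hook.wrapper import get_hook_storage, is_wrapped, wrap_model

class AddConst(nn.Module):
    # Toy hook module: shifts whatever tensor flows through the hooked port.
    def __init__(self, value: float) -> None:
        super().__init__()
        self.value = value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.value

model = wrap_model(nn.Linear(4, 4))   # attach hook storage and a hook-aware forward
assert is_wrapped(model)
storage = get_hook_storage(model)     # container nn.Module holding all registered hooks
# Hooks are attached via register_pre_function_hook / register_post_function_hook,
# e.g. (argument layout assumed) register_post_function_hook(model, op_name, port_id, AddConst(2.0)).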