Move function hook mode from experimental to torch #3437

Merged: 5 commits, Apr 21, 2025
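
In short: the function-hook tracing code moves out of the experimental nncf.experimental.torch2 namespace into nncf.torch.function_hook, function-hook mode becomes the default for the Torch backend, and the feature check is renamed from is_torch_tracing_by_torch_function_mode to is_torch_tracing_by_patching, with inverted polarity (it now returns True only when legacy patching-based tracing is explicitly requested). A minimal before/after sketch of the import migration for downstream code, using only paths that appear in this diff:

# Before this PR (experimental namespace):
# from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
# from nncf.experimental.torch2.model_transformer import PT2ModelTransformer

# After this PR:
from nncf.torch.function_hook.wrapper import get_hook_storage
from nncf.torch.function_hook.model_transformer import PT2ModelTransformer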
2 changes: 2 additions & 0 deletions .ci/cspell_dict.txt
@@ -318,6 +318,7 @@ ovhw
 ovlstm
 ovmvn
 ovroi
+ovselu
 pbar
 perchannel
 pertensor
@@ -339,6 +340,7 @@ pthw
 ptnncf
 ptprelu
 ptrelu
+ptselu
 ptsilu
 ptwc
 pymodules
2 changes: 1 addition & 1 deletion examples/llm_compression/torch/qat_with_lora/main.py
@@ -34,12 +34,12 @@
 import nncf.torch
 from nncf.common.logging.track_progress import track
 from nncf.data.dataset import Dataset
-from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
 from nncf.parameters import CompressionFormat
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import StripFormat
 from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
 from nncf.quantization.quantize_model import compress_weights
+from nncf.torch.function_hook.wrapper import get_hook_storage
 from nncf.torch.model_creation import load_from_config
 from nncf.torch.quantization.layers import AsymmetricLoraQuantizer
 from nncf.torch.quantization.layers import SymmetricLoraQuantizer
nncf/common/check_features.py (moved from nncf/experimental/common/check_feature.py)
@@ -12,7 +12,7 @@
 import os


-def is_torch_tracing_by_torch_function_mode() -> bool:
+def is_torch_tracing_by_patching() -> bool:
     """
     Checks if legacy torch tracing is enabled by environment variable NNCF_TORCH_LEGACY_TRACING.

@@ -21,4 +21,4 @@ def is_torch_tracing_by_torch_function_mode() -> bool:

     :return: True if legacy torch tracing is enabled, False otherwise.
     """
-    return os.getenv("NNCF_TORCH_LEGACY_TRACING", "").lower() not in ["1", "on", "true"]
+    return os.getenv("NNCF_TORCH_LEGACY_TRACING", "").lower() in ["1", "on", "true"]
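
To make the renamed check concrete, a small usage sketch (assuming NNCF_TORCH_LEGACY_TRACING is not already set in the calling environment):

import os

from nncf.common.check_features import is_torch_tracing_by_patching

# Default: the variable is unset, so legacy patching is off and the
# function-hook (torch_function) tracing mode is used.
assert not is_torch_tracing_by_patching()

# Opt back in to legacy patching-based tracing; "on" and "true" also work.
os.environ["NNCF_TORCH_LEGACY_TRACING"] = "1"
assert is_torch_tracing_by_patching()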
20 changes: 10 additions & 10 deletions nncf/common/factory.py
@@ -12,6 +12,7 @@
 from typing import Any, TypeVar, cast

 import nncf
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.common.engine import Engine
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.model_transformer import ModelTransformer
@@ -20,7 +21,6 @@
 from nncf.common.utils.backend import BackendType
 from nncf.common.utils.backend import get_backend
 from nncf.data.dataset import Dataset
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode

 TModel = TypeVar("TModel")

@@ -54,7 +54,7 @@ def create(model: TModel) -> NNCFGraph:

             return FXGraphConverter.create_nncf_graph(cast(GraphModule, model))
         if model_backend == BackendType.TORCH:
-            from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
+            from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
             from nncf.torch.nncf_network import NNCFNetwork

             if isinstance(model, GraphModelWrapper):
@@ -90,13 +90,13 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer[Any]:
             from nncf.openvino.graph.model_transformer import OVModelTransformer

             return OVModelTransformer(cast(Model, model), inplace=inplace)
-        if model_backend == BackendType.TORCH and is_torch_tracing_by_torch_function_mode():
-            from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
-            from nncf.experimental.torch2.model_transformer import PT2ModelTransformer
+        if model_backend == BackendType.TORCH and not is_torch_tracing_by_patching():
+            from nncf.torch.function_hook.model_transformer import PT2ModelTransformer
+            from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

             return PT2ModelTransformer(cast(GraphModelWrapper, model))

-        if model_backend == BackendType.TORCH and not is_torch_tracing_by_torch_function_mode():
+        if model_backend == BackendType.TORCH and is_torch_tracing_by_patching():
             from nncf.torch.model_transformer import PTModelTransformer
             from nncf.torch.nncf_network import NNCFNetwork

@@ -137,8 +137,8 @@ def create(model: TModel) -> Engine:
         if model_backend in (BackendType.TORCH, BackendType.TORCH_FX):
             from torch.nn import Module

-            from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
             from nncf.torch.engine import PTEngine
+            from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

             if isinstance(model, GraphModelWrapper):
                 pt_model = model.model
@@ -191,12 +191,12 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
             from nncf.openvino.statistics.aggregator import OVStatisticsAggregator

             return OVStatisticsAggregator(dataset)
-        if model_backend == BackendType.TORCH and not is_torch_tracing_by_torch_function_mode():
+        if model_backend == BackendType.TORCH and is_torch_tracing_by_patching():
             from nncf.torch.statistics.aggregator import PTStatisticsAggregator

             return PTStatisticsAggregator(dataset)
-        if model_backend == BackendType.TORCH and is_torch_tracing_by_torch_function_mode():
-            from nncf.experimental.torch2.statistics.aggregator import PT2StatisticsAggregator
+        if model_backend == BackendType.TORCH and not is_torch_tracing_by_patching():
+            from nncf.torch.function_hook.statistics.aggregator import PT2StatisticsAggregator

             return PT2StatisticsAggregator(dataset)
         if model_backend == BackendType.TORCH_FX:
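
Every create method in this file now branches on the same predicate, with the function-hook path as the default. A condensed, illustrative sketch of the pattern (the helper name select_torch_model_transformer is hypothetical; the imported names are taken from the diff):

from nncf.common.check_features import is_torch_tracing_by_patching


def select_torch_model_transformer(model):
    # Hypothetical condensation of the two TORCH branches shown above.
    if is_torch_tracing_by_patching():
        from nncf.torch.model_transformer import PTModelTransformer  # legacy patching path

        return PTModelTransformer(model)
    from nncf.torch.function_hook.model_transformer import PT2ModelTransformer  # default path

    return PT2ModelTransformer(model)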
11 changes: 5 additions & 6 deletions nncf/common/utils/backend.py
@@ -15,7 +15,7 @@
 from packaging import version

 import nncf
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
+from nncf.common.check_features import is_torch_tracing_by_patching

 try:
     import openvino  # type: ignore # noqa: F401
@@ -56,12 +56,11 @@ def is_torch_model(model: Any) -> bool:
     import torch
     import torch.fx

-    from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
+    from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

-    if is_torch_tracing_by_torch_function_mode():
-        return isinstance(model, (GraphModelWrapper, torch.nn.Module)) and not isinstance(model, torch.fx.GraphModule)
-
-    return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)
+    if is_torch_tracing_by_patching():
+        return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)
+    return isinstance(model, (GraphModelWrapper, torch.nn.Module)) and not isinstance(model, torch.fx.GraphModule)


 @result_verifier
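
Spelled out, the swapped branches mean: under legacy patching a Torch model is any nn.Module that is not an FX GraphModule, while in the default function-hook mode a GraphModelWrapper also qualifies. A hedged illustration of the two invariants that hold in both modes:

import torch
import torch.fx

from nncf.common.utils.backend import is_torch_model

# A plain nn.Module qualifies in both tracing modes.
assert is_torch_model(torch.nn.Linear(8, 8))

# An FX GraphModule is excluded in both modes.
graph_module = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.ReLU()))
assert not is_torch_model(graph_module)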
@@ -22,11 +22,11 @@
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.data import Dataset
 from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import SparsifyActivationsAlgoBackend
-from nncf.experimental.torch2.commands import PT2InsertionCommand
-from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
-from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
-from nncf.experimental.torch2.model_transformer import PT2ModelTransformer
 from nncf.tensor.functions.torch_numeric import quantile
+from nncf.torch.function_hook.commands import PT2InsertionCommand
+from nncf.torch.function_hook.model_transformer import PT2ModelTransformer
+from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
+from nncf.torch.function_hook.wrapper import get_hook_storage
 from nncf.torch.graph import operator_metatypes as om
 from nncf.torch.graph.transformations.commands import PTTargetPoint
 from nncf.torch.graph.transformations.layout import PTTransformationLayout
17 changes: 0 additions & 17 deletions nncf/experimental/torch2/function_hook/__init__.py

This file was deleted.

@@ -20,10 +20,10 @@
 from nncf.common.graph.model_transformer import ModelTransformer
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
-from nncf.experimental.torch2.function_hook.extractor import extract_model
-from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend
 from nncf.tensor import Tensor
+from nncf.torch.function_hook.extractor import extract_model
+from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.torch.graph.transformations.command_creation import create_bias_correction_command
 from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
 from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
22 changes: 9 additions & 13 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -15,6 +15,7 @@

 import nncf
 import nncf.torch.graph.operator_metatypes as om
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.common.graph.definitions import NNCFGraphNodeType
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.graph import NNCFNode
@@ -24,16 +25,15 @@
 from nncf.common.hardware.config import HWConfig
 from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
 from nncf.common.quantization.structs import QuantizerConfig
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
 from nncf.experimental.common.tensor_statistics.collectors import REDUCERS_MAP
 from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
-from nncf.experimental.torch2.commands import PT2InsertionCommand
 from nncf.parameters import ModelType
 from nncf.parameters import TargetDevice
 from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
 from nncf.quantization.fake_quantize import FakeConvertParameters
 from nncf.quantization.fake_quantize import FakeQuantizeParameters
 from nncf.quantization.range_estimator import StatisticsType
+from nncf.torch.function_hook.commands import PT2InsertionCommand
 from nncf.torch.graph.graph import PTNNCFGraph
 from nncf.torch.graph.graph import PTTargetPoint
 from nncf.torch.graph.operator_metatypes import ELEMENTWISE_OPERATIONS
@@ -153,10 +153,7 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
         input_port_id: Optional[int] = port_id
         if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION:
             input_port_id = None
-        if (
-            not is_torch_tracing_by_torch_function_mode()
-            and target_type in PTMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP
-        ):
+        if is_torch_tracing_by_patching() and target_type in PTMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP:
             target_type = PTMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type]
         return PTTargetPoint(target_type, target_node_name, input_port_id=input_port_id)

@@ -281,10 +278,9 @@ def create_quantizer_insertion_command(
         quantizer = PTMinMaxAlgoBackend._create_quantizer(
             quantizer_config, scale_shape, parameters, target_point.target_type
         )
-        if is_torch_tracing_by_torch_function_mode():
-            return PT2InsertionCommand(target_points=[target_point], hook_module=quantizer)
-
-        return create_quantizer_insertion_command(target_point, quantizer)
+        if is_torch_tracing_by_patching():
+            return create_quantizer_insertion_command(target_point, quantizer)
+        return PT2InsertionCommand(target_points=[target_point], hook_module=quantizer)

     @staticmethod
     def create_unified_scales_quantizers_insertion_commands(
@@ -300,9 +296,9 @@ def create_unified_scales_quantizers_insertion_commands(
         quantizer = PTMinMaxAlgoBackend._create_quantizer(
             quantizer_config, scale_shape, parameters, target_points[0].target_type
         )
-        if is_torch_tracing_by_torch_function_mode():
-            return [PT2InsertionCommand(target_points=target_points, hook_module=quantizer)]
-        return [create_shared_quantizer_insertion_command(target_points, quantizer)]
+        if is_torch_tracing_by_patching():
+            return [create_shared_quantizer_insertion_command(target_points, quantizer)]
+        return [PT2InsertionCommand(target_points=target_points, hook_module=quantizer)]

     @staticmethod
     def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]:
33 changes: 17 additions & 16 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
@@ -14,21 +14,21 @@
 import torch

 import nncf.torch.graph.operator_metatypes as om
+from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
 from nncf.common.graph.operator_metatypes import OperatorMetatype
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
 from nncf.common.tensor_statistics.statistic_point import StatisticPoint
-from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
 from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
 from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
-from nncf.experimental.torch2.commands import PT2ConstUpdateCommand
-from nncf.experimental.torch2.commands import PT2InsertionCommand
-from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.quantization.algorithms.smooth_quant.backend import SmoothQuantAlgoBackend
 from nncf.tensor import Tensor
+from nncf.torch.function_hook.commands import PT2ConstUpdateCommand
+from nncf.torch.function_hook.commands import PT2InsertionCommand
+from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
@@ -136,10 +136,10 @@ def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork, nncf_graph:
     def weight_update_command(
         node_with_weight: NNCFNode, nncf_graph: NNCFGraph, weight_value: torch.Tensor
     ) -> PTWeightUpdateCommand:
-        if is_torch_tracing_by_torch_function_mode():
-            weight_node = get_const_node(node_with_weight, node_with_weight.metatype.weight_port_ids[0], nncf_graph)
-            return PT2ConstUpdateCommand(weight_node, weight_value)
-        return create_command_to_update_weight(node_with_weight, weight_value)
+        if is_torch_tracing_by_patching():
+            return create_command_to_update_weight(node_with_weight, weight_value)
+        weight_node = get_const_node(node_with_weight, node_with_weight.metatype.weight_port_ids[0], nncf_graph)
+        return PT2ConstUpdateCommand(weight_node, weight_value)

     @staticmethod
     def scale_insertion_command(
@@ -157,9 +157,9 @@ def scale_insertion_command(
         sq_multiply = SQMultiply(scale_value.shape)
         sq_multiply.scale = scale_value

-        if is_torch_tracing_by_torch_function_mode():
-            return PT2InsertionCommand(target_points=target_points, hook_module=sq_multiply)
-        return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)
+        if is_torch_tracing_by_patching():
+            return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)
+        return PT2InsertionCommand(target_points=target_points, hook_module=sq_multiply)

     @staticmethod
     def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
@@ -175,11 +175,12 @@ def get_weight_channel_axis(node: NNCFNode) -> int:

     @staticmethod
     def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
-        if is_torch_tracing_by_torch_function_mode():
-            weight_node = get_const_node(node, node.metatype.weight_port_ids[0], nncf_graph)
-            output_edges = nncf_graph.get_next_nodes(weight_node)
-            return len(output_edges) > 1
-        return node.is_shared()
+        if is_torch_tracing_by_patching():
+            return node.is_shared()
+
+        weight_node = get_const_node(node, node.metatype.weight_port_ids[0], nncf_graph)
+        output_edges = nncf_graph.get_next_nodes(weight_node)
+        return len(output_edges) > 1

     @staticmethod
     def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]: