Skip to content

Commit aa3cc95

Browse files
AlexanderDokuchaevAlexander Dokuchaev
and
Alexander Dokuchaev
authored
[PT2] FBC (#3258)
### Changes - Impanated FBC for experimental tracing - Save graph to GraphModelWrapper, to using graph like for NNCFNetwork - Add ConstantLayerAttributes to constant node - Check is_experimental_torch_tracing_enabled inside patch_torch_operators, to support [optimum-intel ](https://github.com/huggingface/optimum-intel/blob/f601b8b1fda4477ed9f9e4293cf3c9d5cec4ad1b/optimum/intel/openvino/__init__.py#L49) ### Related tickets 152996 --------- Co-authored-by: Alexander Dokuchaev <[email protected]>
1 parent 77ef40c commit aa3cc95

File tree

15 files changed

+528
-76
lines changed

15 files changed

+528
-76
lines changed

nncf/common/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def create(model: TModel) -> NNCFGraph:
5858
from nncf.torch.nncf_network import NNCFNetwork
5959

6060
if isinstance(model, GraphModelWrapper):
61-
return model.build_graph()
61+
return model.get_graph()
6262
if isinstance(model, NNCFNetwork):
6363
return model.nncf.get_graph()
6464
msg = f"Unexpected type of model {type(model)} for TORCH backend"
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
12+
13+
import torch
14+
from torch import nn
15+
16+
import nncf
17+
from nncf import nncf_logger
18+
from nncf.common.graph.graph import NNCFNode
19+
from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes
20+
from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
21+
from nncf.torch.graph import operator_metatypes as om
22+
from nncf.torch.graph.graph import PTNNCFGraph
23+
from nncf.torch.model_graph_manager import get_const_data
24+
from nncf.torch.model_graph_manager import get_const_data_on_port
25+
from nncf.torch.model_graph_manager import get_const_node
26+
27+
CONV_METATYPES = (
28+
om.PTConv1dMetatype,
29+
om.PTConv2dMetatype,
30+
om.PTConv3dMetatype,
31+
om.PTDepthwiseConv1dSubtype,
32+
om.PTDepthwiseConv2dSubtype,
33+
om.PTDepthwiseConv3dSubtype,
34+
)
35+
36+
37+
class ExtractedFunc(nn.Module):
38+
"""
39+
Module to execute function with kwargs.
40+
Support function only with one input.
41+
42+
:param fn: Function to execute.
43+
:param kwargs: Function arguments.
44+
"""
45+
46+
def __init__(self, fn: Callable[..., torch.Tensor], kwargs: Dict[str, Any]) -> None:
47+
super().__init__()
48+
self.fn = fn
49+
self.kwargs = kwargs
50+
51+
def forward(self, x: torch.Tensor) -> torch.Tensor:
52+
return self.fn(x, **self.kwargs)
53+
54+
55+
def apply_args_to_kwargs(
56+
args: Sequence[Any], kwargs: Dict[str, Any], indexed_args: List[Tuple[int, str]]
57+
) -> Dict[str, Any]:
58+
"""
59+
Applies the given arguments and keyword arguments to a dictionary of keyword arguments.
60+
61+
:param args: The positional arguments.
62+
:param kwargs: The keyword arguments.
63+
:param indexed_args: The list of pairs of indexes and names.
64+
:return: A dictionary of keyword arguments with the applied arguments and keyword arguments.
65+
"""
66+
args_dict: Dict[str, Any] = dict()
67+
for idx, arg_name in indexed_args:
68+
if idx < len(args):
69+
args_dict[arg_name] = args[idx]
70+
elif arg_name in kwargs:
71+
args_dict[arg_name] = kwargs[arg_name]
72+
73+
return args_dict
74+
75+
76+
def extract_bn(model: nn.Module, graph: PTNNCFGraph, node: NNCFNode) -> ExtractedFunc:
77+
"""
78+
Extract batch_norm operation.
79+
80+
:param model: Source model.
81+
:param graph: Graph of source model.
82+
:param node: Target batch_norm node.
83+
:return: BatchNorm module with same attributes and parameters from source module or None.
84+
"""
85+
layer_attr = node.layer_attributes
86+
if not isinstance(layer_attr, PT2OpLayerAttributes):
87+
msg = f"Expected PT2OpLayerAttributes for input_node.layer_attributes, actual: {type(layer_attr)}"
88+
raise nncf.InternalError(msg)
89+
90+
# torch.batch_norm(
91+
# 0 - input: Tensor,
92+
# 1 - weight: Optional[Tensor]
93+
# 2 - bias: Optional[Tensor]
94+
# 3 - running_mean: Optional[Tensor]
95+
# 4 - running_var: Optional[Tensor]
96+
# 5 - training: _bool
97+
# 6 - momentum: _float
98+
# 7 - eps: _float
99+
# 8 - cudnn_enabled: _bool
100+
# ) -> Tensor: ...
101+
102+
weight = get_const_data_on_port(model, graph, node, 1)
103+
bias = get_const_data_on_port(model, graph, node, 2)
104+
running_mean = get_const_data_on_port(model, graph, node, 3)
105+
running_var = get_const_data_on_port(model, graph, node, 4)
106+
107+
bn_kwargs = apply_args_to_kwargs(
108+
layer_attr.op_args,
109+
layer_attr.op_kwargs,
110+
[(6, "momentum"), (7, "eps"), (8, "cudnn_enabled")],
111+
)
112+
bn_kwargs["weight"] = weight
113+
bn_kwargs["bias"] = bias
114+
bn_kwargs["running_mean"] = running_mean
115+
bn_kwargs["running_var"] = running_var
116+
bn_kwargs["training"] = False
117+
118+
return ExtractedFunc(layer_attr.func, bn_kwargs)
119+
120+
121+
def extract_conv(
122+
model: nn.Module,
123+
graph: PTNNCFGraph,
124+
input_node: NNCFNode,
125+
output_node: NNCFNode,
126+
) -> nn.Module:
127+
"""
128+
Extracts a convolutional layer from an NNCF graph and constructs an ExtractedFunc module.
129+
130+
:param model: The NNCF network containing the layer.
131+
:param graph: The NNCF graph.
132+
:param input_nodes: The name of input node.
133+
:param output_nodes: The name of output node.
134+
:return: The extracted convolutional layer as an ExtractedFunc module.
135+
"""
136+
137+
# torch.conv*d(
138+
# 0 - input: Tensor
139+
# 1 - weight: Tensor
140+
# 2 - bias: Optional[Tensor]
141+
# 3 - stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]]
142+
# 4 - padding: Union[Union[_int, SymInt] | str
143+
# 5 - dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]]
144+
# 6 - groups: Union[_int, SymInt]
145+
# ) -> Tensor: ...
146+
147+
weight_node = get_const_node(input_node, 1, graph)
148+
if weight_node is None:
149+
msg = "Weight node not found for {input_node}"
150+
raise nncf.InternalError(msg)
151+
weight = get_const_data(weight_node, model)
152+
153+
hook_storage = get_hook_storage(model)
154+
with torch.no_grad():
155+
# Calculate weight after execution all hook fro weight data
156+
weight = hook_storage.execute_post_function_hooks(weight_node.node_name, 0, weight)
157+
weight = hook_storage.execute_pre_function_hooks(input_node.node_name, 1, weight)
158+
159+
bias_node = get_const_node(input_node, 2, graph)
160+
bias = get_const_data(bias_node, model) if bias_node is not None else None
161+
162+
layer_attrs = input_node.layer_attributes
163+
164+
if not isinstance(layer_attrs, PT2OpLayerAttributes):
165+
msg = f"Expected PT2OpLayerAttributes for input_node.layer_attributes, actual: {type(layer_attrs)}"
166+
raise nncf.InternalError(msg)
167+
168+
conv_kwargs = apply_args_to_kwargs(
169+
layer_attrs.op_args,
170+
layer_attrs.op_kwargs,
171+
[(3, "stride"), (4, "padding"), (5, "dilation"), (6, "groups")],
172+
)
173+
conv_kwargs["weight"] = weight
174+
conv_kwargs["bias"] = bias
175+
conv_module = ExtractedFunc(layer_attrs.func, conv_kwargs)
176+
177+
if input_node == output_node:
178+
return conv_module
179+
180+
if output_node.metatype is not om.PTBatchNormMetatype:
181+
msg = f"Support only PTBatchNormMetatype as output node, actual: {output_node.metatype}"
182+
raise nncf.InternalError(msg)
183+
184+
next_nodes = graph.get_next_nodes(input_node)
185+
if output_node not in next_nodes:
186+
msg = f"Output node {output_node} not found after {input_node}"
187+
raise nncf.InternalError(msg)
188+
189+
bn_module = extract_bn(model, graph, output_node)
190+
return nn.Sequential(conv_module, bn_module)
191+
192+
193+
def extract_model(
194+
model: nn.Module, graph: PTNNCFGraph, input_nodes: List[str], output_nodes: List[str]
195+
) -> Optional[nn.Module]:
196+
"""
197+
Extracts a submodule from a given NNCF network containing only the nodes from the input to the output node.
198+
199+
Supported subgraph:
200+
- Conv
201+
- Conv + BatchNorm
202+
203+
:param model: The NNCF network to extract the submodule from.
204+
:param input_nodes: List containing names of the input nodes for the submodule.
205+
:param output_nodes: List containing names of the output nodes for the submodule.
206+
:return: An nn.Module containing the extracted submodel, or None if extraction is not supported.
207+
"""
208+
209+
if len(input_nodes) != 1 or len(output_nodes) != 1:
210+
msg = "input_nodes and output_nodes should contain only one node."
211+
raise nncf.InternalError(msg)
212+
213+
input_node = graph.get_node_by_name(input_nodes[0])
214+
output_node = graph.get_node_by_name(output_nodes[0])
215+
216+
if input_node.metatype in CONV_METATYPES:
217+
return extract_conv(model, graph, input_node, output_node)
218+
219+
nncf_logger.debug(f"Can`t extract module for {input_node.node_name}")
220+
return None

nncf/experimental/torch2/function_hook/graph/build_graph_mode.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from nncf.experimental.torch2.function_hook.weak_map import WeakUnhashableKeyMap
3333
from nncf.experimental.torch2.function_hook.wrapper import ForwardWithHooks
3434
from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
35+
from nncf.torch.utils import training_mode_switcher
3536

3637

3738
class GraphBuilderMode(FunctionHookMode):
@@ -358,12 +359,12 @@ def build_graph(model: nn.Module, *args: Any, **kwargs: Any) -> nx.MultiDiGraph:
358359
:param model: The PyTorch model for which the computational graph will be built.
359360
:return: A nx.MultiDiGraph where nodes represent operations of model.
360361
"""
361-
362-
with torch.enable_grad(): # type: ignore
363-
# Gradient use to get information about __get__ functions to detect tensor.(T, mT) attributes
364-
with GraphBuilderMode(model=model, hook_storage=get_hook_storage(model)) as ctx:
365-
args, kwargs = ctx.process_model_inputs(args, kwargs)
366-
wrapped_forward = cast(ForwardWithHooks, model.forward)
367-
outputs = wrapped_forward._func(*args, **kwargs)
368-
outputs = ctx.process_model_outputs(outputs)
362+
with training_mode_switcher(model, is_training=False):
363+
with torch.enable_grad(): # type: ignore
364+
# Gradient use to get information about __get__ functions to detect tensor.(T, mT) attributes
365+
with GraphBuilderMode(model=model, hook_storage=get_hook_storage(model)) as ctx:
366+
args, kwargs = ctx.process_model_inputs(args, kwargs)
367+
wrapped_forward = cast(ForwardWithHooks, model.forward)
368+
outputs = wrapped_forward._func(*args, **kwargs)
369+
outputs = ctx.process_model_outputs(outputs)
369370
return ctx.graph

nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import nncf.torch.graph.operator_metatypes as om
2222
from nncf.common.graph.graph import NNCFNode
2323
from nncf.common.graph.layer_attributes import BaseLayerAttributes
24+
from nncf.common.graph.layer_attributes import ConstantLayerAttributes
2425
from nncf.common.graph.layer_attributes import Dtype
2526
from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph
2627
from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta
@@ -157,7 +158,8 @@ def get_layer_attributes(
157158
if isinstance(meta, FunctionMeta):
158159
constant_port_ids = get_constant_port_ids(nx_graph, node)
159160
return PT2OpLayerAttributes(meta.func, meta.args, meta.kwargs, constant_port_ids)
160-
161+
if isinstance(meta, ConstMeta):
162+
return ConstantLayerAttributes(meta.name_in_model, list(meta.shape))
161163
return None
162164

163165

@@ -228,17 +230,16 @@ class GraphModelWrapper:
228230
"""
229231
A class that wraps a PyTorch model with examples inputs and provides an interface
230232
to build a computational graph of the model.
231-
232-
:param model: The PyTorch model to be wrapped.
233-
:param example_input: A tuple of example input for the model.
234233
"""
235234

236235
def __init__(self, model: nn.Module, example_input: Any) -> None:
237236
"""
238-
Initialize the GraphModelWrapper.
237+
:param model: The PyTorch model to be wrapped.
238+
:param example_input: A tuple of example input for the model.
239239
"""
240240
self.model = model
241241
self.example_input = example_input
242+
self.graph: Optional[PTNNCFGraph] = None
242243

243244
def build_graph(self) -> PTNNCFGraph:
244245
"""
@@ -254,3 +255,19 @@ def build_graph(self) -> PTNNCFGraph:
254255
if isinstance(self.example_input, tuple):
255256
return build_nncf_graph(self.model, *self.example_input)
256257
return build_nncf_graph(self.model, self.example_input)
258+
259+
def get_graph(self) -> PTNNCFGraph:
260+
"""
261+
Returns the computational graph of the model.
262+
263+
:return: The PTNNCFGraph representing the model.
264+
"""
265+
if self.graph is None:
266+
self.graph = self.build_graph()
267+
return self.graph
268+
269+
def reset_graph(self) -> None:
270+
"""
271+
Resets the computational graph of the model.
272+
"""
273+
self.graph = None

0 commit comments

Comments
 (0)