Skip to content

Commit e5ad47b

Browse files
[pt] Trace parameters in list (#3483)
### Changes Trace parameters in list, like `foo([self.w1, self.w2])`. ### Reason for changes Ignoring parameters in list or tuple. Fail on build_graph.
1 parent 8b49993 commit e5ad47b

File tree

4 files changed

+54
-15
lines changed

4 files changed

+54
-15
lines changed

nncf/torch/function_hook/hook_executor_mode.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,10 +361,20 @@ def process_parameters(self, args: list[Any], kwargs: dict[str, Any]) -> tuple[l
361361
:param kwargs: The keyword arguments to the function.
362362
:return: The modified arguments and keyword arguments after pre-hooks.
363363
"""
364+
365+
def _execute_hooks_for_parameter(value: Any) -> Any:
366+
if not isinstance(value, list):
367+
return self.execute_hooks_for_parameter(value)
368+
for list_idx, tensor in enumerate(value):
369+
value[list_idx] = self.execute_hooks_for_parameter(tensor)
370+
return value
371+
364372
for idx, value in enumerate(args):
365-
args[idx] = self.execute_hooks_for_parameter(value)
373+
args[idx] = _execute_hooks_for_parameter(value)
374+
366375
for kw_name, value in kwargs.items():
367-
kwargs[kw_name] = self.execute_hooks_for_parameter(value)
376+
kwargs[kw_name] = _execute_hooks_for_parameter(value)
377+
368378
return args, kwargs
369379

370380
def execute_pre_hooks(
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
strict digraph {
2+
x [id=0, type="nncf_model_input", metatype=PTInputNoopMetatype];
3+
"gru/zeros/0" [id=1, type=zeros, metatype=UnknownMetatype];
4+
"gru.weight_ih_l0" [id=2, type="nncf_model_const", metatype=PTConstNoopMetatype];
5+
"gru.weight_hh_l0" [id=3, type="nncf_model_const", metatype=PTConstNoopMetatype];
6+
"gru.bias_ih_l0" [id=4, type="nncf_model_const", metatype=PTConstNoopMetatype];
7+
"gru.bias_hh_l0" [id=5, type="nncf_model_const", metatype=PTConstNoopMetatype];
8+
"gru/gru/0" [id=6, type=gru, metatype=UnknownMetatype];
9+
output_0 [id=7, type="nncf_model_output", metatype=PTOutputNoopMetatype];
10+
output_1 [id=8, type="nncf_model_output", metatype=PTOutputNoopMetatype];
11+
x -> "gru/gru/0" [dtype=float, shape="(1, 3, 3)", out_port_id=0, in_port_id=0];
12+
"gru/zeros/0" -> "gru/gru/0" [dtype=float, shape="(1, 1, 4)", out_port_id=0, in_port_id=1];
13+
"gru.weight_ih_l0" -> "gru/gru/0" [dtype=float, shape="(12, 3)", out_port_id=0, in_port_id=2];
14+
"gru.weight_hh_l0" -> "gru/gru/0" [dtype=float, shape="(12, 4)", out_port_id=0, in_port_id=3];
15+
"gru.bias_ih_l0" -> "gru/gru/0" [dtype=float, shape="(12,)", out_port_id=0, in_port_id=4];
16+
"gru.bias_hh_l0" -> "gru/gru/0" [dtype=float, shape="(12,)", out_port_id=0, in_port_id=5];
17+
"gru/gru/0" -> output_0 [dtype=float, shape="(1, 3, 4)", out_port_id=0, in_port_id=0];
18+
"gru/gru/0" -> output_1 [dtype=float, shape="(1, 1, 4)", out_port_id=1, in_port_id=0];
19+
}

tests/torch2/function_hook/helpers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,12 @@ def get_config(self):
180180
@classmethod
181181
def from_config(cls, state: str):
182182
return cls(state)
183+
184+
185+
class ModelGRU(nn.Module):
186+
def __init__(self):
187+
super().__init__()
188+
self.gru = torch.nn.GRU(3, 4, batch_first=True)
189+
190+
def forward(self, x):
191+
return self.gru(x)

tests/torch2/function_hook/nncf_graph/test_nncf_graph.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -125,24 +125,25 @@ def __str__(self):
125125

126126

127127
TEST_MODELS_DESC = [
128-
ModelDesc("convnext_small", models.convnext_small, [1, 3, 64, 64]),
129-
ModelDesc("densenet121", models.densenet121, [1, 3, 64, 64]),
130-
ModelDesc("efficientnet_b0", models.efficientnet_b0, [1, 3, 64, 64]),
131-
ModelDesc("inception_v3", partial(models.inception_v3, init_weights=False), [1, 3, 300, 300]),
132-
ModelDesc("mobilenet_v2", models.mobilenet_v2, [1, 3, 64, 64]),
133-
ModelDesc("mobilenet_v3_small", models.mobilenet_v3_small, [1, 3, 64, 64]),
134-
ModelDesc("resnet18", models.resnet18, [1, 3, 64, 64]),
135-
ModelDesc("resnext50_32x4d", models.resnext50_32x4d, [1, 3, 64, 64]),
136-
ModelDesc("shufflenet_v2_x0_5", models.shufflenet_v2_x0_5, [1, 3, 224, 224]),
137-
ModelDesc("squeezenet1_0", models.squeezenet1_0, [1, 3, 64, 64]),
138-
ModelDesc("swin_v2_b", models.swin_v2_b, [1, 3, 64, 64]),
139-
ModelDesc("vgg16", models.vgg16, [1, 3, 32, 32]),
128+
ModelDesc("convnext_small", partial(models.convnext_small, weights=None), [1, 3, 64, 64]),
129+
ModelDesc("densenet121", partial(models.densenet121, weights=None), [1, 3, 64, 64]),
130+
ModelDesc("efficientnet_b0", partial(models.efficientnet_b0, weights=None), [1, 3, 64, 64]),
131+
ModelDesc("inception_v3", partial(models.inception_v3, init_weights=False, weights=None), [1, 3, 300, 300]),
132+
ModelDesc("mobilenet_v2", partial(models.mobilenet_v2, weights=None), [1, 3, 64, 64]),
133+
ModelDesc("mobilenet_v3_small", partial(models.mobilenet_v3_small, weights=None), [1, 3, 64, 64]),
134+
ModelDesc("resnet18", partial(models.resnet18, weights=None), [1, 3, 64, 64]),
135+
ModelDesc("resnext50_32x4d", partial(models.resnext50_32x4d, weights=None), [1, 3, 64, 64]),
136+
ModelDesc("shufflenet_v2_x0_5", partial(models.shufflenet_v2_x0_5, weights=None), [1, 3, 224, 224]),
137+
ModelDesc("squeezenet1_0", partial(models.squeezenet1_0, weights=None), [1, 3, 64, 64]),
138+
ModelDesc("swin_v2_b", partial(models.swin_v2_b, weights=None), [1, 3, 64, 64]),
139+
ModelDesc("vgg16", partial(models.vgg16, weights=None), [1, 3, 32, 32]),
140+
ModelDesc("gru", helpers.ModelGRU, [1, 3, 3]),
140141
]
141142

142143

143144
@pytest.mark.parametrize("desc", TEST_MODELS_DESC, ids=str)
144145
def test_model_graph(desc: ModelDesc, regen_ref_data: bool):
145-
model: torch.nn.Module = desc.model_builder(weights=None)
146+
model: torch.nn.Module = desc.model_builder()
146147
model = model.eval()
147148
model = wrap_model(model)
148149
nncf_graph = build_nncf_graph(model, torch.randn(desc.inputs_info))

0 commit comments

Comments
 (0)