Skip to content

Commit 555dfc9

Browse files
[TorchFX] Use DuplicateDQPass instead of the manual duplication
1 parent b1e6231 commit 555dfc9

File tree

23 files changed

+12498
-12798
lines changed

23 files changed

+12498
-12798
lines changed

Diff for: nncf/experimental/torch/fx/transformations.py

+13-19
Original file line number | Diff line number | Diff line change
@@ -383,27 +383,21 @@ def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, qua
383383

384384
# use the same qparams from quantize op
385385
dq_inputs = [quantized_node] + quantize_op_inputs[1:]
386-
if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
387-
user_dq_nodes = []
388-
with graph.inserting_after(quantized_node):
389-
for user in target_node.users:
386+
387+
with graph.inserting_after(quantized_node):
388+
dq_node = graph.call_function(dequantize_op, tuple(dq_inputs), {})
389+
dq_node.meta["val"] = copy(meta_val)
390+
if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
391+
for user in list(target_node.users):
390392
if user is quantized_node:
391393
continue
392-
dq_node = graph.call_function(dequantize_op, tuple(dq_inputs), {})
393-
dq_node.meta["val"] = copy(meta_val)
394-
user_dq_nodes.append((user, dq_node))
395-
396-
for user, dq_node in user_dq_nodes:
397-
user.replace_input_with(target_node, dq_node)
398-
elif target_point.target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
399-
with graph.inserting_after(quantized_node):
400-
dq_node = graph.call_function(dequantize_op, tuple(dq_inputs), {})
401-
dq_node.meta["val"] = copy(meta_val)
402-
403-
target_node.replace_input_with(input_node, dq_node)
404-
else:
405-
msg = f"Unexpected target type: {target_point.target_type}"
406-
raise nncf.InternalError(msg)
394+
user.replace_input_with(target_node, dq_node)
395+
396+
elif target_point.target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
397+
target_node.replace_input_with(input_node, dq_node)
398+
else:
399+
msg = f"Unexpected target type: {target_point.target_type}"
400+
raise nncf.InternalError(msg)
407401

408402

409403
def _insert_call_module(

Diff for: tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/mobilenet_v3_small.dot

+1,072-1,102
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/resnet18.dot

+472-486
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/unet.dot

+497-505
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/dynamic_shapes/quantized/mobilenet_v3_small.dot

+1,148-1,178
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/dynamic_shapes/quantized/resnet18.dot

+520-534
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/dynamic_shapes/quantized/unet.dot

+547-555
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/experimental/OpenVINOQuantizer/mobilenet_v3_small.dot

+1,072-1,102
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/experimental/OpenVINOQuantizer/resnet18.dot

+472-486
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/experimental/OpenVINOQuantizer/unet.dot

+497-505
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/experimental/X86InductorQuantizer/mobilenet_v3_small.dot

+1,053-1,065
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/experimental/X86InductorQuantizer/resnet18.dot

+414-430
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/experimental/X86InductorQuantizer/unet.dot

+377-377
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/post_quantization_compressed/mobilenet_v3_small.dot

+1,072-1,102
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/post_quantization_compressed/resnet18.dot

+472-486
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/post_quantization_compressed/unet.dot

+497-505
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/quantized/mobilenet_v3_small.dot

+1,148-1,178
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/quantized/resnet18.dot

+520-534
Large diffs are not rendered by default.

Diff for: tests/torch/data/reference_graphs/fx/quantized/unet.dot

+547-555
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change
@@ -1,52 +1,48 @@
1-
strict digraph {
2-
"0 const" [id=0, type=get_attr];
3-
"1 conv_a_weight" [id=1, type=get_attr];
4-
"2 conv_a_bias" [id=2, type=get_attr];
5-
"3 conv_b_weight" [id=3, type=get_attr];
6-
"4 conv_b_bias" [id=4, type=get_attr];
7-
"5 conv_c_weight" [id=5, type=get_attr];
8-
"6 conv_c_bias" [id=6, type=get_attr];
9-
"7 bias" [id=7, type=get_attr];
1+
strict digraph {
2+
"0 const" [id=0, type="get_attr"];
3+
"1 conv_a_weight" [id=1, type="get_attr"];
4+
"2 conv_a_bias" [id=2, type="get_attr"];
5+
"3 conv_b_weight" [id=3, type="get_attr"];
6+
"4 conv_b_bias" [id=4, type="get_attr"];
7+
"5 conv_c_weight" [id=5, type="get_attr"];
8+
"6 conv_c_bias" [id=6, type="get_attr"];
9+
"7 bias" [id=7, type="get_attr"];
1010
"8 x" [id=8, type=input];
11-
"9 conv2d_scale_0" [id=9, type=get_attr];
12-
"10 conv2d_zero_point_0" [id=10, type=get_attr];
11+
"9 conv2d_scale_0" [id=9, type="get_attr"];
12+
"10 conv2d_zero_point_0" [id=10, type="get_attr"];
1313
"11 conv2d" [id=11, type=conv2d];
14-
"12 quantize_per_channel_default" [id=12, type=quantize_per_channel];
15-
"13 dequantize_per_channel_default_1" [id=13, type=dequantize_per_channel];
16-
"14 dequantize_per_channel_default" [id=14, type=dequantize_per_channel];
17-
"15 conv2d_1" [id=15, type=conv2d];
18-
"16 add_" [id=16, type=add_];
19-
"17 add__1" [id=17, type=add_];
20-
"18 cat" [id=18, type=cat];
21-
"19 conv2d_2" [id=19, type=conv2d];
22-
"20 add" [id=20, type=add];
23-
"21 output" [id=21, type=output];
24-
"0 const" -> "18 cat" [label="(1, 3, 3, 3)", style=solid];
25-
"1 conv_a_weight" -> "11 conv2d" [label="(3, 3, 1, 1)", style=solid];
26-
"2 conv_a_bias" -> "11 conv2d" [label="(3,)", style=solid];
27-
"3 conv_b_weight" -> "15 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
28-
"4 conv_b_bias" -> "15 conv2d_1" [label="(3,)", style=solid];
29-
"5 conv_c_weight" -> "19 conv2d_2" [label="(3, 9, 1, 1)", style=solid];
30-
"6 conv_c_bias" -> "19 conv2d_2" [label="(3,)", style=solid];
31-
"7 bias" -> "16 add_" [label="(1,)", style=solid];
32-
"7 bias" -> "17 add__1" [label="(1,)", style=solid];
33-
"7 bias" -> "20 add" [label="(1,)", style=solid];
34-
"8 x" -> "11 conv2d" [label="(1, 3, 3, 3)", style=solid];
35-
"9 conv2d_scale_0" -> "12 quantize_per_channel_default" [label="(1,)", style=solid];
36-
"9 conv2d_scale_0" -> "13 dequantize_per_channel_default_1" [label="(1,)", style=solid];
37-
"9 conv2d_scale_0" -> "14 dequantize_per_channel_default" [label="(1,)", style=solid];
38-
"10 conv2d_zero_point_0" -> "12 quantize_per_channel_default" [label="(1,)", style=solid];
39-
"10 conv2d_zero_point_0" -> "13 dequantize_per_channel_default_1" [label="(1,)", style=solid];
40-
"10 conv2d_zero_point_0" -> "14 dequantize_per_channel_default" [label="(1,)", style=solid];
41-
"11 conv2d" -> "12 quantize_per_channel_default" [label="(1, 3, 3, 3)", style=solid];
42-
"12 quantize_per_channel_default" -> "13 dequantize_per_channel_default_1" [label="(1, 3, 3, 3)", style=solid];
43-
"12 quantize_per_channel_default" -> "14 dequantize_per_channel_default" [label="(1, 3, 3, 3)", style=solid];
44-
"13 dequantize_per_channel_default_1" -> "16 add_" [label="(1, 3, 3, 3)", style=solid];
45-
"14 dequantize_per_channel_default" -> "15 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
46-
"15 conv2d_1" -> "17 add__1" [label="(1, 3, 3, 3)", style=solid];
47-
"16 add_" -> "18 cat" [label="(1, 3, 3, 3)", style=solid];
48-
"17 add__1" -> "18 cat" [label="(1, 3, 3, 3)", style=solid];
49-
"18 cat" -> "19 conv2d_2" [label="(1, 9, 3, 3)", style=solid];
50-
"19 conv2d_2" -> "20 add" [label="(1, 3, 3, 3)", style=solid];
51-
"20 add" -> "21 output" [label="(1, 3, 3, 3)", style=solid];
14+
"12 quantize_per_channel_default" [id=12, type="quantize_per_channel"];
15+
"13 dequantize_per_channel_default" [id=13, type="dequantize_per_channel"];
16+
"14 conv2d_1" [id=14, type=conv2d];
17+
"15 add_" [id=15, type="add_"];
18+
"16 add__1" [id=16, type="add_"];
19+
"17 cat" [id=17, type=cat];
20+
"18 conv2d_2" [id=18, type=conv2d];
21+
"19 add" [id=19, type=add];
22+
"20 output" [id=20, type=output];
23+
"0 const" -> "17 cat" [style=solid, label="(1, 3, 3, 3)"];
24+
"1 conv_a_weight" -> "11 conv2d" [style=solid, label="(3, 3, 1, 1)"];
25+
"2 conv_a_bias" -> "11 conv2d" [style=solid, label="(3,)"];
26+
"3 conv_b_weight" -> "14 conv2d_1" [style=solid, label="(3, 3, 1, 1)"];
27+
"4 conv_b_bias" -> "14 conv2d_1" [style=solid, label="(3,)"];
28+
"5 conv_c_weight" -> "18 conv2d_2" [style=solid, label="(3, 9, 1, 1)"];
29+
"6 conv_c_bias" -> "18 conv2d_2" [style=solid, label="(3,)"];
30+
"7 bias" -> "15 add_" [style=solid, label="(1,)"];
31+
"7 bias" -> "16 add__1" [style=solid, label="(1,)"];
32+
"7 bias" -> "19 add" [style=solid, label="(1,)"];
33+
"8 x" -> "11 conv2d" [style=solid, label="(1, 3, 3, 3)"];
34+
"9 conv2d_scale_0" -> "12 quantize_per_channel_default" [style=solid, label="(1,)"];
35+
"9 conv2d_scale_0" -> "13 dequantize_per_channel_default" [style=solid, label="(1,)"];
36+
"10 conv2d_zero_point_0" -> "12 quantize_per_channel_default" [style=solid, label="(1,)"];
37+
"10 conv2d_zero_point_0" -> "13 dequantize_per_channel_default" [style=solid, label="(1,)"];
38+
"11 conv2d" -> "12 quantize_per_channel_default" [style=solid, label="(1, 3, 3, 3)"];
39+
"12 quantize_per_channel_default" -> "13 dequantize_per_channel_default" [style=solid, label="(1, 3, 3, 3)"];
40+
"13 dequantize_per_channel_default" -> "14 conv2d_1" [style=solid, label="(1, 3, 3, 3)"];
41+
"13 dequantize_per_channel_default" -> "15 add_" [style=solid, label="(1, 3, 3, 3)"];
42+
"14 conv2d_1" -> "16 add__1" [style=solid, label="(1, 3, 3, 3)"];
43+
"15 add_" -> "17 cat" [style=solid, label="(1, 3, 3, 3)"];
44+
"16 add__1" -> "17 cat" [style=solid, label="(1, 3, 3, 3)"];
45+
"17 cat" -> "18 conv2d_2" [style=solid, label="(1, 9, 3, 3)"];
46+
"18 conv2d_2" -> "19 add" [style=solid, label="(1, 3, 3, 3)"];
47+
"19 add" -> "20 output" [style=solid, label="(1, 3, 3, 3)"];
5248
}
Original file line number | Diff line number | Diff line change
@@ -1,44 +1,42 @@
1-
strict digraph {
2-
"0 const" [id=0, type=get_attr];
3-
"1 conv_a_weight" [id=1, type=get_attr];
4-
"2 conv_a_bias" [id=2, type=get_attr];
5-
"3 conv_b_weight" [id=3, type=get_attr];
6-
"4 conv_b_bias" [id=4, type=get_attr];
7-
"5 conv_c_weight" [id=5, type=get_attr];
8-
"6 conv_c_bias" [id=6, type=get_attr];
9-
"7 bias" [id=7, type=get_attr];
1+
strict digraph {
2+
"0 const" [id=0, type="get_attr"];
3+
"1 conv_a_weight" [id=1, type="get_attr"];
4+
"2 conv_a_bias" [id=2, type="get_attr"];
5+
"3 conv_b_weight" [id=3, type="get_attr"];
6+
"4 conv_b_bias" [id=4, type="get_attr"];
7+
"5 conv_c_weight" [id=5, type="get_attr"];
8+
"6 conv_c_bias" [id=6, type="get_attr"];
9+
"7 bias" [id=7, type="get_attr"];
1010
"8 x" [id=8, type=input];
1111
"9 conv2d" [id=9, type=conv2d];
12-
"10 quantize_per_tensor_default" [id=10, type=quantize_per_tensor];
13-
"11 dequantize_per_tensor_default_1" [id=11, type=dequantize_per_tensor];
14-
"12 dequantize_per_tensor_default" [id=12, type=dequantize_per_tensor];
15-
"13 conv2d_1" [id=13, type=conv2d];
16-
"14 add_" [id=14, type=add_];
17-
"15 add__1" [id=15, type=add_];
18-
"16 cat" [id=16, type=cat];
19-
"17 conv2d_2" [id=17, type=conv2d];
20-
"18 add" [id=18, type=add];
21-
"19 output" [id=19, type=output];
22-
"0 const" -> "16 cat" [label="(1, 3, 3, 3)", style=solid];
23-
"1 conv_a_weight" -> "9 conv2d" [label="(3, 3, 1, 1)", style=solid];
24-
"2 conv_a_bias" -> "9 conv2d" [label="(3,)", style=solid];
25-
"3 conv_b_weight" -> "13 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
26-
"4 conv_b_bias" -> "13 conv2d_1" [label="(3,)", style=solid];
27-
"5 conv_c_weight" -> "17 conv2d_2" [label="(3, 9, 1, 1)", style=solid];
28-
"6 conv_c_bias" -> "17 conv2d_2" [label="(3,)", style=solid];
29-
"7 bias" -> "14 add_" [label="(1,)", style=solid];
30-
"7 bias" -> "15 add__1" [label="(1,)", style=solid];
31-
"7 bias" -> "18 add" [label="(1,)", style=solid];
32-
"8 x" -> "9 conv2d" [label="(1, 3, 3, 3)", style=solid];
33-
"9 conv2d" -> "10 quantize_per_tensor_default" [label="(1, 3, 3, 3)", style=solid];
34-
"10 quantize_per_tensor_default" -> "11 dequantize_per_tensor_default_1" [label="(1, 3, 3, 3)", style=solid];
35-
"10 quantize_per_tensor_default" -> "12 dequantize_per_tensor_default" [label="(1, 3, 3, 3)", style=solid];
36-
"11 dequantize_per_tensor_default_1" -> "14 add_" [label="(1, 3, 3, 3)", style=solid];
37-
"12 dequantize_per_tensor_default" -> "13 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
38-
"13 conv2d_1" -> "15 add__1" [label="(1, 3, 3, 3)", style=solid];
39-
"14 add_" -> "16 cat" [label="(1, 3, 3, 3)", style=solid];
40-
"15 add__1" -> "16 cat" [label="(1, 3, 3, 3)", style=solid];
41-
"16 cat" -> "17 conv2d_2" [label="(1, 9, 3, 3)", style=solid];
42-
"17 conv2d_2" -> "18 add" [label="(1, 3, 3, 3)", style=solid];
43-
"18 add" -> "19 output" [label="(1, 3, 3, 3)", style=solid];
12+
"10 quantize_per_tensor_default" [id=10, type="quantize_per_tensor"];
13+
"11 dequantize_per_tensor_default" [id=11, type="dequantize_per_tensor"];
14+
"12 conv2d_1" [id=12, type=conv2d];
15+
"13 add_" [id=13, type="add_"];
16+
"14 add__1" [id=14, type="add_"];
17+
"15 cat" [id=15, type=cat];
18+
"16 conv2d_2" [id=16, type=conv2d];
19+
"17 add" [id=17, type=add];
20+
"18 output" [id=18, type=output];
21+
"0 const" -> "15 cat" [style=solid, label="(1, 3, 3, 3)"];
22+
"1 conv_a_weight" -> "9 conv2d" [style=solid, label="(3, 3, 1, 1)"];
23+
"2 conv_a_bias" -> "9 conv2d" [style=solid, label="(3,)"];
24+
"3 conv_b_weight" -> "12 conv2d_1" [style=solid, label="(3, 3, 1, 1)"];
25+
"4 conv_b_bias" -> "12 conv2d_1" [style=solid, label="(3,)"];
26+
"5 conv_c_weight" -> "16 conv2d_2" [style=solid, label="(3, 9, 1, 1)"];
27+
"6 conv_c_bias" -> "16 conv2d_2" [style=solid, label="(3,)"];
28+
"7 bias" -> "13 add_" [style=solid, label="(1,)"];
29+
"7 bias" -> "14 add__1" [style=solid, label="(1,)"];
30+
"7 bias" -> "17 add" [style=solid, label="(1,)"];
31+
"8 x" -> "9 conv2d" [style=solid, label="(1, 3, 3, 3)"];
32+
"9 conv2d" -> "10 quantize_per_tensor_default" [style=solid, label="(1, 3, 3, 3)"];
33+
"10 quantize_per_tensor_default" -> "11 dequantize_per_tensor_default" [style=solid, label="(1, 3, 3, 3)"];
34+
"11 dequantize_per_tensor_default" -> "12 conv2d_1" [style=solid, label="(1, 3, 3, 3)"];
35+
"11 dequantize_per_tensor_default" -> "13 add_" [style=solid, label="(1, 3, 3, 3)"];
36+
"12 conv2d_1" -> "14 add__1" [style=solid, label="(1, 3, 3, 3)"];
37+
"13 add_" -> "15 cat" [style=solid, label="(1, 3, 3, 3)"];
38+
"14 add__1" -> "15 cat" [style=solid, label="(1, 3, 3, 3)"];
39+
"15 cat" -> "16 conv2d_2" [style=solid, label="(1, 9, 3, 3)"];
40+
"16 conv2d_2" -> "17 add" [style=solid, label="(1, 3, 3, 3)"];
41+
"17 add" -> "18 output" [style=solid, label="(1, 3, 3, 3)"];
4442
}

Diff for: tests/torch/fx/test_models.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -137,19 +137,19 @@ def test_model(test_case: ModelCase):
137137
(
138138
ModelCase(test_models.UNet, "unet", [1, 3, 224, 224]),
139139
{},
140-
[(46, 50), (23, 27)],
140+
[(46, 46), (23, 23)],
141141
[Dim.AUTO, Dim.STATIC, Dim.STATIC, Dim.STATIC], # This Unet Model is not eligible for dynamic shape capability
142142
),
143143
(
144144
torchvision_model_case("resnet18", (1, 3, 224, 224)),
145145
{},
146-
[(51, 58), (30, 37)],
146+
[(51, 51), (30, 30)],
147147
[Dim.AUTO, Dim.STATIC, Dim.AUTO, Dim.AUTO],
148148
),
149149
(
150150
torchvision_model_case("mobilenet_v3_small", (1, 3, 224, 224)),
151151
{},
152-
[(97, 112), (61, 76)],
152+
[(97, 97), (61, 61)],
153153
[Dim.AUTO, Dim.STATIC, Dim.AUTO, Dim.AUTO],
154154
),
155155
(

Diff for: tests/torch/fx/test_sanity.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ class SanitySampleCase:
5353
"https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/302_resnet18_fp32_v1.pth",
5454
55.30,
5555
30,
56-
37,
56+
30,
5757
),
5858
)
5959

0 commit comments

Comments
 (0)