
Commit 37c07eb

Merge pull request #145 from sovrasov/fix_interpolate

Fix functional.interpolate hook

2 parents 307e6c3 + 0482b31

4 files changed, 28 insertions(+), 11 deletions(-)

CHANGELOG.md
Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 # ptflops versions log

 ## v 0.7.4
+- Fix hook for nn.functional.interpolate.
 - Switch to aten by default.
 - Add ignore and custom modules for aten.
 - Add an option to disable counting of functional-style operations in pytorch backend.

README.md
Lines changed: 0 additions & 1 deletion

@@ -1,6 +1,5 @@
 # Flops counting tool for neural networks in pytorch framework
 [![Pypi version](https://img.shields.io/pypi/v/ptflops.svg)](https://pypi.org/project/ptflops/)
-[![Build Status](https://travis-ci.com/sovrasov/flops-counter.pytorch.svg?branch=master)](https://travis-ci.com/sovrasov/flops-counter.pytorch)

 This tool is designed to compute the theoretical amount of multiply-add operations
 in neural networks. It can also compute the number of parameters and

ptflops/pytorch_ops.py
Lines changed: 9 additions & 5 deletions

@@ -347,15 +347,19 @@ def _interpolate_functional_flops_hook(*args, **kwargs):
     if input is None and len(args) > 0:
         input = args[0]

+    assert input.dim() - 2 > 0, "Input of interpolate should have NC... layout"
+
     size = kwargs.get('size', None)
     if size is None and len(args) > 1:
         size = args[1]

     if size is not None:
         if isinstance(size, tuple) or isinstance(size, list):
-            return int(np.prod(size, dtype=np.int64))
+            return int(np.prod(size, dtype=np.int64)) * \
+                np.prod(input.shape[:2], dtype=np.int64)
         else:
-            return int(size)
+            return int(size) ** (input.dim() - 2) * \
+                np.prod(input.shape[:2], dtype=np.int64)

     scale_factor = kwargs.get('scale_factor', None)
     if scale_factor is None and len(args) > 2:
@@ -364,10 +368,10 @@ def _interpolate_functional_flops_hook(*args, **kwargs):
         "should be passes to interpolate"

     flops = input.numel()
-    if isinstance(scale_factor, tuple) and len(scale_factor) == len(input):
+    if isinstance(scale_factor, tuple) and len(scale_factor) == len(input.shape) - 2:
         flops *= int(np.prod(scale_factor, dtype=np.int64))
-    else:
-        flops *= scale_factor**len(input)
+    else:  # NC... layout is assumed, see interpolate docs
+        flops *= scale_factor ** (input.dim() - 2)

     return flops
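For intuition, the corrected hook charges one multiply-add per output element under the NC... layout that nn.functional.interpolate assumes. A minimal standalone sketch of the same arithmetic (expected_interpolate_macs is an illustrative helper, not part of ptflops):

import numpy as np
import torch

def expected_interpolate_macs(input, size=None, scale_factor=None):
    # Illustrative re-derivation of the fixed hook: one MAC per output element.
    spatial_dims = input.dim() - 2           # NC... layout: N and C come first
    batch_channels = int(np.prod(input.shape[:2], dtype=np.int64))
    if size is not None:
        if isinstance(size, (tuple, list)):  # explicit size per spatial dim
            return int(np.prod(size, dtype=np.int64)) * batch_channels
        return int(size) ** spatial_dims * batch_channels
    if isinstance(scale_factor, tuple):      # one factor per spatial dim
        return input.numel() * int(np.prod(scale_factor, dtype=np.int64))
    return input.numel() * scale_factor ** spatial_dims

x = torch.zeros(1, 3, 10, 10)
assert expected_interpolate_macs(x, size=(20, 20)) == 1200   # 1 * 3 * 20 * 20
assert expected_interpolate_macs(x, size=20) == 1200         # 1 * 3 * 20 ** 2
assert expected_interpolate_macs(x, scale_factor=2) == 1200  # 300 * 2 ** 2

These are exactly the values the updated test below asserts.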
tests/common_test.py
Lines changed: 18 additions & 5 deletions

@@ -90,24 +90,26 @@ def input_constructor(input_res):

         assert (macs, params) == (8, 8)

-    def test_func_interpolate_args(self):
+    @pytest.mark.parametrize("out_size", [(20, 20), 20])
+    def test_func_interpolate_args(self, out_size):
         class CustomModel(nn.Module):
             def __init__(self):
                 super().__init__()

             def forward(self, x):
-                return nn.functional.interpolate(input=x, size=(20, 20),
+                return nn.functional.interpolate(input=x, size=out_size,
                                                  mode='bilinear', align_corners=False)

         macs, params = \
             get_model_complexity_info(CustomModel(), (3, 10, 10),
                                       as_strings=False,
                                       print_per_layer_stat=False,
                                       backend=FLOPS_BACKEND.PYTORCH)
+
         assert params == 0
-        assert macs > 0
+        assert macs == 1200

-        CustomModel.forward = lambda self, x: nn.functional.interpolate(x, size=(20, 20),
+        CustomModel.forward = lambda self, x: nn.functional.interpolate(x, out_size,
                                                                         mode='bilinear')

         macs, params = \
@@ -116,7 +118,18 @@ def forward(self, x):
                                       print_per_layer_stat=False,
                                       backend=FLOPS_BACKEND.PYTORCH)
         assert params == 0
-        assert macs > 0
+        assert macs == 1200
+
+        CustomModel.forward = lambda self, x: nn.functional.interpolate(x, scale_factor=2,
+                                                                        mode='bilinear')
+
+        macs, params = \
+            get_model_complexity_info(CustomModel(), (3, 10, 10),
+                                      as_strings=False,
+                                      print_per_layer_stat=False,
+                                      backend=FLOPS_BACKEND.PYTORCH)
+        assert params == 0
+        assert macs == 1200

     def test_ten_matmul(self, simple_model_mm):
         macs, params = \
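The constant 1200 in every parametrized case is the output element count of shape (1, 3, 20, 20); before this fix, the size branch returned only the spatial product (400), ignoring the batch and channel dimensions. A minimal sketch to reproduce the figure outside the test suite (the top-level import of FLOPS_BACKEND and the wrapper module Interp are assumptions for illustration):

import torch.nn as nn

from ptflops import FLOPS_BACKEND, get_model_complexity_info  # import path assumed


class Interp(nn.Module):
    # Wrapper so the functional-style hook, not nn.Upsample's, is exercised.
    def forward(self, x):
        return nn.functional.interpolate(x, size=(20, 20), mode='bilinear',
                                         align_corners=False)


macs, params = get_model_complexity_info(Interp(), (3, 10, 10),
                                         as_strings=False,
                                         print_per_layer_stat=False,
                                         backend=FLOPS_BACKEND.PYTORCH)
print(macs, params)  # expected: 1200 0 (one MAC per output element)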