
Fix functional.interpolate hook #145


Merged: 5 commits, merged on Sep 27, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,6 +1,7 @@
 # ptflops versions log
 
 ## v 0.7.4
+- Fix hook for nn.functional.interpolate.
 - Switch to aten by default.
 - Add ignore and custom modules for aten.
 - Add an option to disable counting of functional-style operations in pytorch backend.
1 change: 0 additions & 1 deletion README.md
@@ -1,6 +1,5 @@
 # Flops counting tool for neural networks in pytorch framework
 [![Pypi version](https://img.shields.io/pypi/v/ptflops.svg)](https://pypi.org/project/ptflops/)
-[![Build Status](https://travis-ci.com/sovrasov/flops-counter.pytorch.svg?branch=master)](https://travis-ci.com/sovrasov/flops-counter.pytorch)
 
 This tool is designed to compute the theoretical amount of multiply-add operations
 in neural networks. It can also compute the number of parameters and
14 changes: 9 additions & 5 deletions ptflops/pytorch_ops.py
@@ -347,15 +347,19 @@ def _interpolate_functional_flops_hook(*args, **kwargs):
     if input is None and len(args) > 0:
         input = args[0]
 
+    assert input.dim() - 2 > 0, "Input of interpolate should have NC... layout"
+
     size = kwargs.get('size', None)
     if size is None and len(args) > 1:
         size = args[1]
 
     if size is not None:
         if isinstance(size, tuple) or isinstance(size, list):
-            return int(np.prod(size, dtype=np.int64))
+            return int(np.prod(size, dtype=np.int64)) * \
+                np.prod(input.shape[:2], dtype=np.int64)
         else:
-            return int(size)
+            return int(size) ** (input.dim() - 2) * \
+                np.prod(input.shape[:2], dtype=np.int64)
 
     scale_factor = kwargs.get('scale_factor', None)
     if scale_factor is None and len(args) > 2:
@@ -364,10 +368,10 @@ def _interpolate_functional_flops_hook(*args, **kwargs):
         "should be passes to interpolate"
 
     flops = input.numel()
-    if isinstance(scale_factor, tuple) and len(scale_factor) == len(input):
+    if isinstance(scale_factor, tuple) and len(scale_factor) == len(input.shape) - 2:
         flops *= int(np.prod(scale_factor, dtype=np.int64))
-    else:
-        flops *= scale_factor**len(input)
+    else:  # NC... layout is assumed, see interpolate docs
+        flops *= scale_factor ** (input.dim() - 2)
 
     return flops

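To make the effect of the fix concrete, here is a minimal standalone sketch (not part of the PR) that mirrors the corrected hook's arithmetic; the function name `interpolate_flops` and the example tensor are illustrative only. With the fix, the estimate covers the batch and channel dimensions of the NC... input instead of only the spatial output size.

```python
# Illustration only: a standalone re-implementation of the fixed hook's arithmetic.
import numpy as np
import torch


def interpolate_flops(input, size=None, scale_factor=None):
    # NC... layout is assumed, as in the hook above.
    assert input.dim() - 2 > 0, "Input of interpolate should have NC... layout"
    if size is not None:
        if isinstance(size, (tuple, list)):
            spatial = int(np.prod(size, dtype=np.int64))
        else:
            spatial = int(size) ** (input.dim() - 2)
        # The fix multiplies the spatial output size by batch * channels.
        return spatial * int(np.prod(input.shape[:2], dtype=np.int64))
    if isinstance(scale_factor, tuple) and len(scale_factor) == input.dim() - 2:
        return input.numel() * int(np.prod(scale_factor, dtype=np.int64))
    # A scalar scale_factor is applied once per spatial dimension.
    return int(input.numel() * scale_factor ** (input.dim() - 2))


x = torch.zeros(1, 3, 10, 10)
print(interpolate_flops(x, size=(20, 20)))   # 1 * 3 * 20 * 20 = 1200
print(interpolate_flops(x, size=20))         # 1 * 3 * 20 ** 2 = 1200
print(interpolate_flops(x, scale_factor=2))  # 300 * 2 ** 2    = 1200
```

All three call styles now agree on the same count for a 2x upsample of a (1, 3, 10, 10) tensor, which is what the updated tests below assert.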
23 changes: 18 additions & 5 deletions tests/common_test.py
@@ -90,24 +90,26 @@ def input_constructor(input_res):
 
         assert (macs, params) == (8, 8)
 
-    def test_func_interpolate_args(self):
+    @pytest.mark.parametrize("out_size", [(20, 20), 20])
+    def test_func_interpolate_args(self, out_size):
         class CustomModel(nn.Module):
             def __init__(self):
                 super().__init__()
 
             def forward(self, x):
-                return nn.functional.interpolate(input=x, size=(20, 20),
+                return nn.functional.interpolate(input=x, size=out_size,
                                                  mode='bilinear', align_corners=False)
 
         macs, params = \
             get_model_complexity_info(CustomModel(), (3, 10, 10),
                                       as_strings=False,
                                       print_per_layer_stat=False,
                                       backend=FLOPS_BACKEND.PYTORCH)
 
         assert params == 0
-        assert macs > 0
+        assert macs == 1200
 
-        CustomModel.forward = lambda self, x: nn.functional.interpolate(x, size=(20, 20),
+        CustomModel.forward = lambda self, x: nn.functional.interpolate(x, out_size,
                                                                         mode='bilinear')
 
         macs, params = \
@@ -116,7 +118,18 @@ def forward(self, x):
                                       print_per_layer_stat=False,
                                       backend=FLOPS_BACKEND.PYTORCH)
         assert params == 0
-        assert macs > 0
+        assert macs == 1200
+
+        CustomModel.forward = lambda self, x: nn.functional.interpolate(x, scale_factor=2,
+                                                                        mode='bilinear')
+
+        macs, params = \
+            get_model_complexity_info(CustomModel(), (3, 10, 10),
+                                      as_strings=False,
+                                      print_per_layer_stat=False,
+                                      backend=FLOPS_BACKEND.PYTORCH)
+        assert params == 0
+        assert macs == 1200
 
     def test_ten_matmul(self, simple_model_mm):
         macs, params = \
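For context on the new assertions: the tool builds a batch of one from the (3, 10, 10) input resolution, and upsampling to 20x20 (whether via size=(20, 20), size=20, or scale_factor=2) is counted as one operation per output element, so 1 * 3 * 20 * 20 = 1200. A rough usage sketch follows, assuming get_model_complexity_info and FLOPS_BACKEND are importable from ptflops as the test suite suggests; the module name Upsample2x is illustrative.

```python
# Usage sketch reproducing the test scenario with the pytorch backend.
import torch.nn as nn
from ptflops import FLOPS_BACKEND, get_model_complexity_info


class Upsample2x(nn.Module):
    def forward(self, x):
        return nn.functional.interpolate(x, scale_factor=2, mode='bilinear')


macs, params = get_model_complexity_info(Upsample2x(), (3, 10, 10),
                                         as_strings=False,
                                         print_per_layer_stat=False,
                                         backend=FLOPS_BACKEND.PYTORCH)
print(macs, params)  # expected 1200 0 with the fixed hook
```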