Skip to content

Commit b767645

Browse files
committed
Add an option to ignore funtionals in torch backend
1 parent b9340a3 commit b767645

File tree

6 files changed

+56
-23
lines changed

6 files changed

+56
-23
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# ptflops versions log
22

3+
## v 0.7.4
4+
- Switch to aten by default.
5+
- Add ignore and custom modules for aten.
6+
- Add an option to disable counting of functional-style operations in pytorch backend.
7+
38
## v 0.7.3
49
- Add aten backend to collect the amount of flops on aten level.
510

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ print per-layer computational cost of a given network.
99
`ptflops` has two backends, `pytorch` and `aten`. `pytorch` backend is a legacy one, it considers `nn.Modules` only. However,
1010
it's still useful, since it provides a better par-layer analytics for CNNs. In all other cases it's recommended to use
1111
`aten` backend, which considers aten operations, and therefore it covers more model architectures (including transformers).
12+
The default backend is `aten`. Please, don't use `pytorch` backend for transformer architectures.
1213

1314
## `aten` backend
1415
### Operations considered:
@@ -19,6 +20,9 @@ it's still useful, since it provides a better par-layer analytics for CNNs. In a
1920
- Use `verbose=True` to see the operations which were not considered during complexity computation.
2021
- This backend prints per-module statistics only for modules directly nested into the root `nn.Module`.
2122
Deeper modules at the second level of nesting are not shown in the per-layer statistics.
23+
- `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
24+
for research purposes. For instance, one can drop all convolutions from the counting process
25+
specifying `ignore_modules=[torch.ops.aten.convolution, torch.ops.aten._convolution]`.
2226

2327
## `pytorch` backend
2428
### Supported layers:
@@ -41,7 +45,9 @@ Experimental support:
4145

4246
- This backend doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. Therefore unsupported operations are
4347
not contributing to the final complexity estimation. See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check supported ops.
44-
- `ptflops` launches a given model on a random tensor and estimates amount of computations during inference. Complicated models can have several inputs, some of them could be optional. To construct non-trivial input one can use the `input_constructor` argument of the `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with named input arguments of the model. Next this dict would be passed to the model as a keyword arguments.
48+
Sometimes considering functional style conflicts with hooks for `nn.Module` (for instance, custom ones). In that case, counting with these ops can be disabled by
49+
passing `backend_specific_config={"count_functional" : False}`.
50+
- `ptflops` launches a given model on a random tensor and estimates amount of computations during inference. Complicated models can have several inputs, some of them could be optional. To construct non-trivial input one can use the `input_constructor` argument of the `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with named input arguments of the model. Next, this dict would be passed to the model as a keyword arguments.
4551
- `verbose` parameter allows to get information about modules that don't contribute to the final numbers.
4652
- `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
4753
for research purposes. For instance, one can drop all convolutions from the counting process

ptflops/aten_engine.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from collections import defaultdict
1313
from copy import deepcopy
1414
from functools import partial
15-
from typing import Optional, Tuple, Union
15+
from typing import Dict, Optional, Tuple, Union
1616

1717
import torch
1818
from torch.utils._python_dispatch import TorchDispatchMode
@@ -106,8 +106,9 @@ def get_flops_aten(model, input_res,
106106
custom_modules_hooks={},
107107
output_precision=2,
108108
flops_units: Optional[str] = 'GMac',
109-
param_units: Optional[str] = 'M') -> Tuple[Union[int, None],
110-
Union[int, None]]:
109+
param_units: Optional[str] = 'M',
110+
extra_config: Dict = {}) -> Tuple[Union[int, None],
111+
Union[int, None]]:
111112

112113
params_sum = get_model_parameters_number(model)
113114
model.eval()

ptflops/flops_counter.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@ def get_model_complexity_info(model: nn.Module,
3434
backend: Union[str, FLOPS_BACKEND] = FLOPS_BACKEND.ATEN,
3535
flops_units: Optional[str] = None,
3636
param_units: Optional[str] = None,
37-
output_precision: int = 2) -> Tuple[Union[str, int, None],
38-
Union[str, int, None]]:
37+
output_precision: int = 2,
38+
backend_specific_config: Dict = {}) -> Tuple[
39+
Union[str, int, None],
40+
Union[str, int, None]]:
3941
"""
4042
Analyzes the input model and collects the amounts of parameters and MACs
4143
required to make a forward pass of the model.
@@ -75,6 +77,8 @@ def get_model_complexity_info(model: nn.Module,
7577
:param output_precision: Floating point precision for representing MACs/params in
7678
given units.
7779
:type output_precision: int
80+
:param backend_specific_config: Extra configuration for a specific backend.
81+
:type backend_specific_config: dict
7882
7983
Returns:
8084
Tuple[Union[str, int, None], Union[str, int, None]]: Return value is a tuple
@@ -86,14 +90,16 @@ def get_model_complexity_info(model: nn.Module,
8690
assert isinstance(model, nn.Module)
8791

8892
if FLOPS_BACKEND(backend) == FLOPS_BACKEND.PYTORCH:
89-
flops_count, params_count = get_flops_pytorch(model, input_res,
90-
print_per_layer_stat,
91-
input_constructor, ost,
92-
verbose, ignore_modules,
93-
custom_modules_hooks,
94-
output_precision=output_precision,
95-
flops_units=flops_units,
96-
param_units=param_units)
93+
flops_count, params_count = \
94+
get_flops_pytorch(model, input_res,
95+
print_per_layer_stat,
96+
input_constructor, ost,
97+
verbose, ignore_modules,
98+
custom_modules_hooks,
99+
output_precision=output_precision,
100+
flops_units=flops_units,
101+
param_units=param_units,
102+
extra_config=backend_specific_config)
97103
elif FLOPS_BACKEND(backend) == FLOPS_BACKEND.ATEN:
98104
flops_count, params_count = get_flops_aten(model, input_res,
99105
print_per_layer_stat,
@@ -102,7 +108,8 @@ def get_model_complexity_info(model: nn.Module,
102108
custom_modules_hooks,
103109
output_precision=output_precision,
104110
flops_units=flops_units,
105-
param_units=param_units)
111+
param_units=param_units,
112+
extra_config=backend_specific_config)
106113
else:
107114
raise ValueError('Wrong backend name')
108115

ptflops/pytorch_engine.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import sys
1010
import traceback
1111
from functools import partial
12-
from typing import Optional, Tuple, Union
12+
from typing import Dict, Optional, Tuple, Union
1313

1414
import torch
1515
import torch.nn as nn
@@ -27,8 +27,9 @@ def get_flops_pytorch(model, input_res,
2727
custom_modules_hooks={},
2828
output_precision=2,
2929
flops_units: Optional[str] = 'GMac',
30-
param_units: Optional[str] = 'M') -> Tuple[Union[int, None],
31-
Union[int, None]]:
30+
param_units: Optional[str] = 'M',
31+
extra_config: Dict = {}) -> Tuple[Union[int, None],
32+
Union[int, None]]:
3233
global CUSTOM_MODULES_MAPPING
3334
CUSTOM_MODULES_MAPPING = custom_modules_hooks
3435
flops_model = add_flops_counting_methods(model)
@@ -45,15 +46,18 @@ def get_flops_pytorch(model, input_res,
4546
except StopIteration:
4647
batch = torch.ones(()).new_empty((1, *input_res))
4748

49+
enable_func_ops_patching = extra_config.get('count_functional', True)
4850
torch_functional_flops = []
4951
torch_tensor_ops_flops = []
50-
patch_functional(torch_functional_flops)
51-
patch_tensor_ops(torch_tensor_ops_flops)
52+
if enable_func_ops_patching:
53+
patch_functional(torch_functional_flops)
54+
patch_tensor_ops(torch_tensor_ops_flops)
5255

5356
def reset_environment():
5457
flops_model.stop_flops_count()
55-
unpatch_functional()
56-
unpatch_tensor_ops()
58+
if enable_func_ops_patching:
59+
unpatch_functional()
60+
unpatch_tensor_ops()
5761
global CUSTOM_MODULES_MAPPING
5862
CUSTOM_MODULES_MAPPING = {}
5963

tests/common_test.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def forward(self, x):
2222

2323
return CustomModel()
2424

25-
2625
@pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
2726
def test_conv(self, default_input_image_size, backend: FLOPS_BACKEND):
2827
net = nn.Sequential(nn.Conv2d(3, 2, 3, bias=True))
@@ -152,3 +151,14 @@ def test_aten_custom(self, simple_model_mm):
152151

153152
assert params == 0
154153
assert macs == reference
154+
155+
def test_torch_ignore_func(self, simple_model_mm):
156+
macs, params = \
157+
get_model_complexity_info(simple_model_mm, (10, ),
158+
backend=FLOPS_BACKEND.PYTORCH,
159+
as_strings=False,
160+
print_per_layer_stat=False,
161+
backend_specific_config={'count_functional': False})
162+
163+
assert params == 0
164+
assert macs == 0

0 commit comments

Comments
 (0)