
Commit 3e424e0

Merge pull request #140 from sovrasov/update_backends
Update backends
2 parents: e2772ae + b767645

File tree: 6 files changed (+122, −46)

- CHANGELOG.md
- README.md
- ptflops/aten_engine.py
- ptflops/flops_counter.py
- ptflops/pytorch_engine.py
- tests/common_test.py


CHANGELOG.md

Lines changed: 5 additions & 0 deletions

```diff
@@ -1,5 +1,10 @@
 # ptflops versions log
 
+## v 0.7.4
+- Switch to aten by default.
+- Add ignore and custom modules for aten.
+- Add an option to disable counting of functional-style operations in pytorch backend.
+
 ## v 0.7.3
 - Add aten backend to collect the amount of flops on aten level.
 
```
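Since v0.7.4 makes `aten` the default, callers that relied on the old default need to opt back into the legacy backend explicitly. A minimal sketch of both calls; the toy network is illustrative, and the string form `backend='pytorch'` assumes the enum's string values, as the `Union[str, FLOPS_BACKEND]` signature below suggests:

```python
import torch.nn as nn

from ptflops import get_model_complexity_info

# Illustrative toy model, not part of the repository.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))

# As of v0.7.4 the aten backend is used when no backend is specified.
macs, params = get_model_complexity_info(net, (3, 224, 224),
                                         as_strings=False,
                                         print_per_layer_stat=False)

# Pre-0.7.4 behavior: request the legacy pytorch backend explicitly.
macs_pt, params_pt = get_model_complexity_info(net, (3, 224, 224),
                                               as_strings=False,
                                               print_per_layer_stat=False,
                                               backend='pytorch')
```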

README.md

Lines changed: 7 additions & 1 deletion

```diff
@@ -9,6 +9,7 @@ print per-layer computational cost of a given network.
 `ptflops` has two backends, `pytorch` and `aten`. `pytorch` backend is a legacy one; it considers `nn.Modules` only. However,
 it's still useful, since it provides better per-layer analytics for CNNs. In all other cases it's recommended to use
 `aten` backend, which considers aten operations, and therefore it covers more model architectures (including transformers).
+The default backend is `aten`. Please don't use the `pytorch` backend for transformer architectures.
 
 ## `aten` backend
 ### Operations considered:
@@ -19,6 +20,9 @@
 - Use `verbose=True` to see the operations which were not considered during complexity computation.
 - This backend prints per-module statistics only for modules directly nested into the root `nn.Module`.
 Deeper modules at the second level of nesting are not shown in the per-layer statistics.
+- `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
+for research purposes. For instance, one can drop all convolutions from the counting process
+by specifying `ignore_modules=[torch.ops.aten.convolution, torch.ops.aten._convolution]`.
 
 ## `pytorch` backend
 ### Supported layers:
@@ -41,7 +45,9 @@ Experimental support:
 
 - This backend doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. Therefore, unsupported operations do
 not contribute to the final complexity estimation. See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check supported ops.
-- `ptflops` launches a given model on a random tensor and estimates amount of computations during inference. Complicated models can have several inputs, some of them could be optional. To construct non-trivial input one can use the `input_constructor` argument of the `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with named input arguments of the model. Next this dict would be passed to the model as a keyword arguments.
+Sometimes counting functional-style operations conflicts with hooks for `nn.Module` (for instance, custom ones). In that case, counting these ops can be disabled by
+passing `backend_specific_config={"count_functional": False}`.
+- `ptflops` launches a given model on a random tensor and estimates the amount of computations during inference. Complicated models can have several inputs, some of which may be optional. To construct a non-trivial input one can use the `input_constructor` argument of `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with named input arguments of the model. Next, this dict will be passed to the model as keyword arguments.
 - The `verbose` parameter allows one to get information about modules that don't contribute to the final numbers.
 - `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
 for research purposes. For instance, one can drop all convolutions from the counting process
```
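To make the new `aten`-side `ignore_modules` option concrete, here is a hedged sketch of the usage the README bullet above describes; the toy network is illustrative, not from the repository:

```python
import torch
import torch.nn as nn

from ptflops import get_model_complexity_info

# Illustrative toy network.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))

# Drop all convolutions from the counting process; per the aten_engine diff
# below, the engine prints a warning for every ignored op it encounters.
macs, params = get_model_complexity_info(
    net, (3, 224, 224),
    backend='aten',
    as_strings=False,
    print_per_layer_stat=False,
    ignore_modules=[torch.ops.aten.convolution, torch.ops.aten._convolution])
```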

ptflops/aten_engine.py

Lines changed: 16 additions & 7 deletions

```diff
@@ -10,8 +10,9 @@
 import sys
 import traceback
 from collections import defaultdict
+from copy import deepcopy
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -23,12 +24,15 @@
 
 class FlopCounterMode(TorchDispatchMode):
     def __init__(self, module=None, verbose=False, print_per_layer_stat=False,
-                 output_params=None):
+                 output_params=None, custom_hooks={}, ignored_ops=[]):
         self.verbose = verbose
         if output_params is None:
             output_params = defaultdict(dict)
         self.output_params = output_params
         self.print_fn = partial(print, **self.output_params['print_params'])
+        self.all_ops = deepcopy(ATEN_OPS_MAPPING)
+        self.all_ops.update(custom_hooks)
+        self.ignored_ops = ignored_ops
 
         self.print_per_layer_stat = print_per_layer_stat
         self.flop_counts = defaultdict(lambda: defaultdict(int))
@@ -82,8 +86,11 @@ def normalize_tuple(x):
 
         out = func(*args, **kwargs)
         func_packet = func._overloadpacket
-        if func_packet in ATEN_OPS_MAPPING:
-            flop_count = ATEN_OPS_MAPPING[func_packet](args, normalize_tuple(out))
+
+        if func_packet in self.ignored_ops:
+            self.print_fn(f'Warning: {func_packet} operation is ignored')
+        elif func_packet in self.all_ops:
+            flop_count = self.all_ops[func_packet](args, normalize_tuple(out))
             for par in self.parents:
                 self.flop_counts[par][func_packet] += flop_count
         elif self.verbose:
@@ -99,8 +106,9 @@ def get_flops_aten(model, input_res,
                    custom_modules_hooks={},
                    output_precision=2,
                    flops_units: Optional[str] = 'GMac',
-                   param_units: Optional[str] = 'M') -> Tuple[Union[int, None],
-                                                              Union[int, None]]:
+                   param_units: Optional[str] = 'M',
+                   extra_config: Dict = {}) -> Tuple[Union[int, None],
+                                                     Union[int, None]]:
 
     params_sum = get_model_parameters_number(model)
     model.eval()
@@ -119,7 +127,8 @@ def get_flops_aten(model, input_res,
         batch = torch.ones(()).new_empty((1, *input_res))
 
     try:
-        counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params)
+        counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params,
+                                  custom_modules_hooks, ignore_modules)
         with counter:
             if isinstance(batch, dict):
                 _ = model(**batch)
```
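The hook contract implied by `self.all_ops[func_packet](args, normalize_tuple(out))` above: a hook receives the op's positional arguments and its normalized outputs and returns a flop count, and entries passed via `custom_hooks` override the defaults copied from `ATEN_OPS_MAPPING`. A sketch of a custom `aten.mm` hook; the shape arithmetic is my illustration, not the library's built-in rule:

```python
import torch

from ptflops import get_model_complexity_info


def mm_flops_hook(args, outputs):
    # aten.mm multiplies an (n, m) matrix by an (m, p) matrix: n * m * p MACs.
    a, b = args[0], args[1]
    n, m = a.shape
    p = b.shape[1]
    return n * m * p


class MatMulNet(torch.nn.Module):
    def forward(self, x):
        return x.matmul(x.t())


# The custom hook replaces the default aten.mm rule for this run.
macs, params = get_model_complexity_info(
    MatMulNet(), (10,),
    backend='aten',
    as_strings=False,
    print_per_layer_stat=False,
    custom_modules_hooks={torch.ops.aten.mm: mm_flops_hook})
```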

ptflops/flops_counter.py

Lines changed: 26 additions & 18 deletions

```diff
@@ -29,13 +29,15 @@ def get_model_complexity_info(model: nn.Module,
                               input_constructor: Optional[Callable[[Tuple], Dict]] = None,
                               ost: TextIO = sys.stdout,
                               verbose: bool = False,
-                              ignore_modules: List[nn.Module] = [],
-                              custom_modules_hooks: Dict[nn.Module, Any] = {},
-                              backend: Union[str, FLOPS_BACKEND] = FLOPS_BACKEND.PYTORCH,
+                              ignore_modules: List[Union[nn.Module, Any]] = [],
+                              custom_modules_hooks: Dict[Union[nn.Module, Any], Any] = {},
+                              backend: Union[str, FLOPS_BACKEND] = FLOPS_BACKEND.ATEN,
                               flops_units: Optional[str] = None,
                               param_units: Optional[str] = None,
-                              output_precision: int = 2) -> Tuple[Union[str, int, None],
-                                                                  Union[str, int, None]]:
+                              output_precision: int = 2,
+                              backend_specific_config: Dict = {}) -> Tuple[
+                                  Union[str, int, None],
+                                  Union[str, int, None]]:
     """
     Analyzes the input model and collects the amounts of parameters and MACs
     required to make a forward pass of the model.
@@ -61,10 +63,11 @@ def get_model_complexity_info(model: nn.Module,
     :type ost: TextIO
     :param verbose: Parameter to control printing of extra information and warnings.
     :type verbose: bool
-    :param ignore_modules: A list of torch.nn.Module modules to ignore.
-    :type ignore_modules: nn.Module
-    :param custom_modules_hooks: A dict that contains custom hooks on torch modules.
-    :type custom_modules_hooks: Dict[nn.Module, Any]
+    :param ignore_modules: A list of torch.nn.Module or torch.ops.aten modules to ignore.
+    :type ignore_modules: List[Union[nn.Module, Any]]
+    :param custom_modules_hooks: A dict that contains custom hooks for torch.nn.Module or
+        torch.ops.aten modules.
+    :type custom_modules_hooks: Dict[Union[nn.Module, Any], Any]
     :param backend: Backend that is used for evaluating model complexity.
     :type backend: FLOPS_BACKEND
     :param flops_units: Units for string representation of MACs (GMac, MMac or KMac).
@@ -74,6 +77,8 @@ def get_model_complexity_info(model: nn.Module,
     :param output_precision: Floating point precision for representing MACs/params in
         given units.
     :type output_precision: int
+    :param backend_specific_config: Extra configuration for a specific backend.
+    :type backend_specific_config: dict
 
     Returns:
         Tuple[Union[str, int, None], Union[str, int, None]]: Return value is a tuple
@@ -85,14 +90,16 @@ def get_model_complexity_info(model: nn.Module,
     assert isinstance(model, nn.Module)
 
     if FLOPS_BACKEND(backend) == FLOPS_BACKEND.PYTORCH:
-        flops_count, params_count = get_flops_pytorch(model, input_res,
-                                                      print_per_layer_stat,
-                                                      input_constructor, ost,
-                                                      verbose, ignore_modules,
-                                                      custom_modules_hooks,
-                                                      output_precision=output_precision,
-                                                      flops_units=flops_units,
-                                                      param_units=param_units)
+        flops_count, params_count = \
+            get_flops_pytorch(model, input_res,
+                              print_per_layer_stat,
+                              input_constructor, ost,
+                              verbose, ignore_modules,
+                              custom_modules_hooks,
+                              output_precision=output_precision,
+                              flops_units=flops_units,
+                              param_units=param_units,
+                              extra_config=backend_specific_config)
     elif FLOPS_BACKEND(backend) == FLOPS_BACKEND.ATEN:
         flops_count, params_count = get_flops_aten(model, input_res,
                                                    print_per_layer_stat,
@@ -101,7 +108,8 @@ def get_model_complexity_info(model: nn.Module,
                                                    custom_modules_hooks,
                                                    output_precision=output_precision,
                                                    flops_units=flops_units,
-                                                   param_units=param_units)
+                                                   param_units=param_units,
+                                                   extra_config=backend_specific_config)
     else:
         raise ValueError('Wrong backend name')
```
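Because the dispatch coerces with `FLOPS_BACKEND(backend)`, the backend can be given either as an enum member or, assuming the enum's string values, as a plain string; `backend_specific_config` is threaded through to the selected engine as `extra_config`. A brief sketch (the import path for `FLOPS_BACKEND` is an assumption based on this file):

```python
import torch.nn as nn

from ptflops import get_model_complexity_info
from ptflops.flops_counter import FLOPS_BACKEND  # assumed import path

net = nn.Linear(8, 8)  # illustrative model

# Enum and string selections go through the same FLOPS_BACKEND(backend) coercion.
macs_enum, _ = get_model_complexity_info(net, (8,), backend=FLOPS_BACKEND.ATEN,
                                         as_strings=False,
                                         print_per_layer_stat=False)
macs_str, _ = get_model_complexity_info(net, (8,), backend='aten',
                                        as_strings=False,
                                        print_per_layer_stat=False)

# backend_specific_config reaches the engine as extra_config; in this commit
# only the pytorch engine reads a key from it ('count_functional').
macs_pt, _ = get_model_complexity_info(net, (8,), backend='pytorch',
                                       as_strings=False,
                                       print_per_layer_stat=False,
                                       backend_specific_config={'count_functional': True})
```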

ptflops/pytorch_engine.py

Lines changed: 11 additions & 7 deletions

```diff
@@ -9,7 +9,7 @@
 import sys
 import traceback
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -27,8 +27,9 @@ def get_flops_pytorch(model, input_res,
                       custom_modules_hooks={},
                       output_precision=2,
                       flops_units: Optional[str] = 'GMac',
-                      param_units: Optional[str] = 'M') -> Tuple[Union[int, None],
-                                                                 Union[int, None]]:
+                      param_units: Optional[str] = 'M',
+                      extra_config: Dict = {}) -> Tuple[Union[int, None],
+                                                        Union[int, None]]:
     global CUSTOM_MODULES_MAPPING
     CUSTOM_MODULES_MAPPING = custom_modules_hooks
     flops_model = add_flops_counting_methods(model)
@@ -45,15 +46,18 @@ def get_flops_pytorch(model, input_res,
     except StopIteration:
         batch = torch.ones(()).new_empty((1, *input_res))
 
+    enable_func_ops_patching = extra_config.get('count_functional', True)
     torch_functional_flops = []
     torch_tensor_ops_flops = []
-    patch_functional(torch_functional_flops)
-    patch_tensor_ops(torch_tensor_ops_flops)
+    if enable_func_ops_patching:
+        patch_functional(torch_functional_flops)
+        patch_tensor_ops(torch_tensor_ops_flops)
 
     def reset_environment():
         flops_model.stop_flops_count()
-        unpatch_functional()
-        unpatch_tensor_ops()
+        if enable_func_ops_patching:
+            unpatch_functional()
+            unpatch_tensor_ops()
         global CUSTOM_MODULES_MAPPING
         CUSTOM_MODULES_MAPPING = {}
```
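With `count_functional` disabled, `patch_functional`/`patch_tensor_ops` are never installed, so `torch.nn.functional.*` and `tensor.*` calls contribute nothing and only `nn.Module` hooks are counted. A sketch mirroring the new `test_torch_ignore_func` test below:

```python
import torch.nn as nn

from ptflops import get_model_complexity_info


class MatMulNet(nn.Module):
    def forward(self, x):
        return x.matmul(x.t())


# The bare tensor matmul is invisible to the pytorch backend once
# functional/tensor-op patching is turned off, so the reported MACs drop to zero.
macs, params = get_model_complexity_info(
    MatMulNet(), (10,),
    backend='pytorch',
    as_strings=False,
    print_per_layer_stat=False,
    backend_specific_config={'count_functional': False})

assert macs == 0 and params == 0
```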

tests/common_test.py

Lines changed: 57 additions & 13 deletions

```diff
@@ -11,6 +11,17 @@ class TestOperations:
     def default_input_image_size(self):
         return (3, 224, 224)
 
+    @pytest.fixture
+    def simple_model_mm(self):
+        class CustomModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.matmul(x.t())
+
+        return CustomModel()
+
     @pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
     def test_conv(self, default_input_image_size, backend: FLOPS_BACKEND):
         net = nn.Sequential(nn.Conv2d(3, 2, 3, bias=True))
@@ -53,7 +64,8 @@ def input_constructor(input_res):
         macs, params = get_model_complexity_info(net, (3,),
                                                  input_constructor=input_constructor,
                                                  as_strings=False,
-                                                 print_per_layer_stat=False)
+                                                 print_per_layer_stat=False,
+                                                 backend=FLOPS_BACKEND.PYTORCH)
 
         assert (macs, params) == (8, 8)
 
@@ -73,7 +85,8 @@ def input_constructor(input_res):
             get_model_complexity_info(CustomLinear(), (3,),
                                       input_constructor=input_constructor,
                                       as_strings=False,
-                                      print_per_layer_stat=False)
+                                      print_per_layer_stat=False,
+                                      backend=FLOPS_BACKEND.PYTORCH)
 
         assert (macs, params) == (8, 8)
 
@@ -89,7 +102,8 @@ def forward(self, x):
         macs, params = \
             get_model_complexity_info(CustomModel(), (3, 10, 10),
                                       as_strings=False,
-                                      print_per_layer_stat=False)
+                                      print_per_layer_stat=False,
+                                      backend=FLOPS_BACKEND.PYTORCH)
         assert params == 0
         assert macs > 0
 
@@ -99,22 +113,52 @@ def forward(self, x):
         macs, params = \
             get_model_complexity_info(CustomModel(), (3, 10, 10),
                                       as_strings=False,
-                                      print_per_layer_stat=False)
+                                      print_per_layer_stat=False,
+                                      backend=FLOPS_BACKEND.PYTORCH)
         assert params == 0
         assert macs > 0
 
-    def test_ten_matmul(self):
-        class CustomModel(nn.Module):
-            def __init__(self):
-                super().__init__()
+    def test_ten_matmul(self, simple_model_mm):
+        macs, params = \
+            get_model_complexity_info(simple_model_mm, (10, ),
+                                      as_strings=False,
+                                      print_per_layer_stat=False,
+                                      backend=FLOPS_BACKEND.PYTORCH)
 
-            def forward(self, x):
-                return x.matmul(x.t())
+        assert params == 0
+        assert macs > 0
 
+    def test_aten_ignore(self, simple_model_mm):
+        ignored_list = [torch.ops.aten.matmul, torch.ops.aten.mm]
         macs, params = \
-            get_model_complexity_info(CustomModel(), (10, ),
+            get_model_complexity_info(simple_model_mm, (10, ), backend=FLOPS_BACKEND.ATEN,
                                       as_strings=False,
-                                      print_per_layer_stat=False)
+                                      print_per_layer_stat=False,
+                                      ignore_modules=ignored_list)
 
         assert params == 0
-        assert macs > 0
+        assert macs == 0
+
+    def test_aten_custom(self, simple_model_mm):
+        reference = 42
+        custom_hooks = {torch.ops.aten.mm: lambda inputs, outputs: reference}
+
+        macs, params = \
+            get_model_complexity_info(simple_model_mm, (10, ), backend=FLOPS_BACKEND.ATEN,
+                                      as_strings=False,
+                                      print_per_layer_stat=False,
+                                      custom_modules_hooks=custom_hooks)
+
+        assert params == 0
+        assert macs == reference
+
+    def test_torch_ignore_func(self, simple_model_mm):
+        macs, params = \
+            get_model_complexity_info(simple_model_mm, (10, ),
+                                      backend=FLOPS_BACKEND.PYTORCH,
+                                      as_strings=False,
+                                      print_per_layer_stat=False,
+                                      backend_specific_config={'count_functional': False})
+
+        assert params == 0
+        assert macs == 0
```
