Commit 7a2a141

Merge pull request #133 from sovrasov/aten_backend
Aten backend
2 parents c2e1af7 + 81ace22 commit 7a2a141

9 files changed, +367 -51 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
 # ptflops versions log
 
+## v 0.7.3
+- Add aten backend to collect the amount of flops on aten level.
+
 ## v 0.7.2.2
 - Switch from setup.py to pyproject
 

README.md

Lines changed: 57 additions & 32 deletions
@@ -2,11 +2,26 @@
 [![Pypi version](https://img.shields.io/pypi/v/ptflops.svg)](https://pypi.org/project/ptflops/)
 [![Build Status](https://travis-ci.com/sovrasov/flops-counter.pytorch.svg?branch=master)](https://travis-ci.com/sovrasov/flops-counter.pytorch)
 
-This script is designed to compute the theoretical amount of multiply-add operations
-in convolutional neural networks. It can also compute the number of parameters and
+This tool is designed to compute the theoretical amount of multiply-add operations
+in neural networks. It can also compute the number of parameters and
 print per-layer computational cost of a given network.
 
-Supported layers:
+`ptflops` has two backends, `pytorch` and `aten`. The `pytorch` backend is the legacy one; it considers `nn.Modules` only. However,
+it's still useful, since it provides better per-layer analytics for CNNs. In all other cases it's recommended to use
+the `aten` backend, which considers aten operations and therefore covers more model architectures (including transformers).
+
+## `aten` backend
+### Operations considered:
+- aten.mm, aten.matmul, aten.addmm, aten.bmm
+- aten.convolution
+
+### Usage tips
+- Use `verbose=True` to see the operations which were not considered during complexity computation.
+- This backend prints per-module statistics only for modules directly nested into the root `nn.Module`.
+Modules nested more deeply are not shown in the per-layer statistics.
+
+## `pytorch` backend
+### Supported layers:
 - Conv1d/2d/3d (including grouping)
 - ConvTranspose1d/2d/3d (including grouping)
 - BatchNorm1d/2d/3d, GroupNorm, InstanceNorm1d/2d/3d, LayerNorm
@@ -22,20 +37,20 @@ Experimental support:
 - torchvision.ops.DeformConv2d
 - visual transformers from [timm](https://github.com/huggingface/pytorch-image-models)
 
-Requirements: Pytorch >= 1.1, torchvision >= 0.3
-
-Thanks to @warmspringwinds for the initial version of script.
-
-## Usage tips
+### Usage tips
 
-- This tool doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. Therefore unsupported operations are
+- This backend doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. Therefore unsupported operations are
 not contributing to the final complexity estimation. See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check supported ops.
 - `ptflops` launches a given model on a random tensor and estimates amount of computations during inference. Complicated models can have several inputs, some of them could be optional. To construct non-trivial input one can use the `input_constructor` argument of the `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with named input arguments of the model. This dict is then passed to the model as keyword arguments.
 - The `verbose` parameter reports the modules that don't contribute to the final numbers.
 - `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
 for research purposes. For instance, one can drop all convolutions from the counting process
 specifying `ignore_modules=[torch.nn.Conv2d]`.
 
+Requirements: Pytorch >= 1.1, torchvision >= 0.3
+
+Thanks to @warmspringwinds and Horace He for the initial version of the script.
+
 ## Install the latest version
 From PyPI:
 ```bash
@@ -55,7 +70,12 @@ from ptflops import get_model_complexity_info
 
 with torch.cuda.device(0):
   net = models.densenet161()
-  macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
+  macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, backend='pytorch',
+                                           print_per_layer_stat=True, verbose=True)
+  print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
+  print('{:<30}  {:<8}'.format('Number of parameters: ', params))
+
+  macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, backend='aten',
                                            print_per_layer_stat=True, verbose=True)
   print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
   print('{:<30}  {:<8}'.format('Number of parameters: ', params))
@@ -67,7 +87,7 @@ If ptflops was useful for your paper or tech report, please cite me:
 @online{ptflops,
     author = {Vladislav Sovrasov},
     title = {ptflops: a flops counting tool for neural networks in pytorch framework},
-    year = 2018-2023,
+    year = 2018-2024,
     url = {https://github.com/sovrasov/flops-counter.pytorch},
 }
 ```
@@ -76,25 +96,30 @@ If ptflops was useful for your paper or tech report, please cite me:
 
 ### [torchvision](https://pytorch.org/vision/0.16/models.html)
 
+Model | Input Resolution | Params(M) | MACs(G) (`pytorch`) | MACs(G) (`aten`)
+--- |--- |--- |--- |---
+alexnet | 224x224 | 61.10 | 0.72 | 0.71
+convnext_base | 224x224 | 88.59 | 15.43 | 15.38
+densenet121 | 224x224 | 7.98 | 2.90 |
+efficientnet_b0 | 224x224 | 5.29 | 0.41 |
+efficientnet_v2_m | 224x224 | 54.14 | 5.43 |
+googlenet | 224x224 | 13.00 | 1.51 |
+inception_v3 | 224x224 | 27.16 | 5.75 | 5.71
+maxvit_t | 224x224 | 30.92 | 5.48 |
+mnasnet1_0 | 224x224 | 4.38 | 0.33 |
+mobilenet_v2 | 224x224 | 3.50 | 0.32 |
+mobilenet_v3_large | 224x224 | 5.48 | 0.23 |
+regnet_y_1_6gf | 224x224 | 11.20 | 1.65 |
+resnet18 | 224x224 | 11.69 | 1.83 | 1.81
+resnet50 | 224x224 | 25.56 | 4.13 | 4.09
+resnext50_32x4d | 224x224 | 25.03 | 4.29 |
+shufflenet_v2_x1_0 | 224x224 | 2.28 | 0.15 |
+squeezenet1_0 | 224x224 | 1.25 | 0.84 | 0.82
+vgg16 | 224x224 | 138.36 | 15.52 | 15.48
+vit_b_16 | 224x224 | 86.57 | 17.61 (wrong) | 16.86
+wide_resnet50_2 | 224x224 | 68.88 | 11.45 |
+
+
+### [timm](https://github.com/huggingface/pytorch-image-models)
+
 Model | Input Resolution | Params(M) | MACs(G)
---- |--- |--- |---
-alexnet | 224x224 | 61.10 | 0.72
-convnext_base | 224x224 | 88.59 | 15.43
-densenet121 | 224x224 | 7.98 | 2.90
-efficientnet_b0 | 224x224 | 5.29 | 0.41
-efficientnet_v2_m | 224x224 | 54.14 | 5.43
-googlenet | 224x224 | 13.00 | 1.51
-inception_v3 | 224x224 | 27.16 | 2.86
-maxvit_t | 224x224 | 30.92 | 5.48
-mnasnet1_0 | 224x224 | 4.38 | 0.33
-mobilenet_v2 | 224x224 | 3.50 | 0.32
-mobilenet_v3_large | 224x224 | 5.48 | 0.23
-regnet_y_1_6gf | 224x224 | 11.20 | 1.65
-resnet18 | 224x224 | 11.69 | 1.83
-resnet50 | 224x224 | 25.56 | 4.13
-resnext50_32x4d | 224x224 | 25.03 | 4.29
-shufflenet_v2_x1_0 | 224x224 | 2.28 | 0.15
-squeezenet1_0 | 224x224 | 1.25 | 0.84
-vgg16 | 224x224 | 138.36 | 15.52
-vit_b_16 | 224x224 | 86.57 | 17.60
-wide_resnet50_2 | 224x224 | 68.88 | 11.45
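
Note on the README change above: the `aten` backend is described as counting `aten.mm`, `aten.matmul`, `aten.addmm`, `aten.bmm` and `aten.convolution`. The formulas themselves live in `ptflops/aten_ops.py`, which is not part of this excerpt, so the following is only a minimal sketch of how multiply-accumulate counts for these two operation families are commonly derived from tensor shapes. The function names `matmul_macs` and `conv_macs` are hypothetical and are not claimed to match the library's code.

```python
from math import prod


def matmul_macs(a_shape, b_shape):
    """MACs for a 2-D product (m, k) @ (k, n): one MAC per (m, n, k) triple."""
    m, k = a_shape
    k2, n = b_shape
    if k != k2:
        raise ValueError("inner dimensions must match")
    return m * n * k


def conv_macs(out_shape, in_channels, kernel_size, groups=1):
    """MACs for a convolution, ignoring the bias term.

    out_shape is (batch, out_channels, *spatial). Each output element needs
    (in_channels / groups) * prod(kernel_size) multiply-accumulates.
    """
    return prod(out_shape) * (in_channels // groups) * prod(kernel_size)


# Stem convolution of a ResNet-style network on a 3x224x224 image:
# 64 output channels, 7x7 kernel, 112x112 output map.
stem = conv_macs((1, 64, 112, 112), in_channels=3, kernel_size=(7, 7))
print(f"{stem / 1e9:.3f} GMac")  # prints 0.118 GMac
```

Operations outside such a mapping contribute nothing to the total, which is consistent with the README tip to run with `verbose=True` to list the operations that were not considered.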

ptflops/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -1,17 +1,18 @@
 '''
-Copyright (C) 2019-2023 Sovrasov V. - All Rights Reserved
+Copyright (C) 2019-2024 Sovrasov V. - All Rights Reserved
  * You may use, distribute and modify this code under the
  * terms of the MIT license.
  * You should have received a copy of the MIT license with
  * this file. If not visit https://opensource.org/licenses/MIT
 '''
 
 
-from .flops_counter import get_model_complexity_info
+from .flops_counter import FLOPS_BACKEND, get_model_complexity_info
 from .utils import flops_to_string, params_to_string
 
 __all__ = [
     "get_model_complexity_info",
     "flops_to_string",
     "params_to_string",
+    "FLOPS_BACKEND",
 ]

ptflops/aten_engine.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
+'''
+Copyright (C) 2024 Sovrasov V. - All Rights Reserved
+ * You may use, distribute and modify this code under the
+ * terms of the MIT license.
+ * You should have received a copy of the MIT license with
+ * this file. If not visit https://opensource.org/licenses/MIT
+'''
+
+
+import sys
+import traceback
+from collections import defaultdict
+from functools import partial
+from typing import Optional, Tuple, Union
+
+import torch
+from torch.utils._python_dispatch import TorchDispatchMode
+
+from ptflops.pytorch_engine import get_model_parameters_number
+from ptflops.utils import flops_to_string
+from .aten_ops import ATEN_OPS_MAPPING
+
+
+class FlopCounterMode(TorchDispatchMode):
+    def __init__(self, module=None, verbose=False, print_per_layer_stat=False,
+                 output_params=None):
+        self.verbose = verbose
+        if output_params is None:
+            output_params = defaultdict(dict)
+        self.output_params = output_params
+        self.print_fn = partial(print, **self.output_params['print_params'])
+
+        self.print_per_layer_stat = print_per_layer_stat
+        self.flop_counts = defaultdict(lambda: defaultdict(int))
+        self.parents = ['Global']
+        self._total_complexity = None
+        if module is not None:
+            for name, mod in dict(module.named_children()).items():
+                mod.register_forward_pre_hook(self.enter_module(name))
+                mod.register_forward_hook(self.exit_module(name))
+
+    @property
+    def complexity(self):
+        return self._total_complexity
+
+    def enter_module(self, name):
+        def f(*args):
+            self.parents.append(name)
+        return f
+
+    def exit_module(self, name):
+        def f(*args):
+            assert(self.parents[-1] == name)
+            self.parents.pop()
+        return f
+
+    def __enter__(self):
+        self.flop_counts.clear()
+        super().__enter__()
+
+    def __exit__(self, *args):
+        self._total_complexity = sum(self.flop_counts['Global'].values())
+        if self.print_per_layer_stat:
+            self.print_fn('Total:' +
+                          flops_to_string(self._total_complexity,
+                                          **self.output_params['serialize_params']))
+            for mod in self.flop_counts.keys():
+                self.print_fn("Module: ", mod)
+                for k, v in self.flop_counts[mod].items():
+                    self.print_fn(
+                        f'{k}: ' +
+                        flops_to_string(v, **self.output_params['serialize_params']))
+                self.print_fn()
+        super().__exit__(*args)
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        def normalize_tuple(x):
+            if not isinstance(x, tuple):
+                return (x,)
+            return x
+        kwargs = kwargs if kwargs else {}
+
+        out = func(*args, **kwargs)
+        func_packet = func._overloadpacket
+        if func_packet in ATEN_OPS_MAPPING:
+            flop_count = ATEN_OPS_MAPPING[func_packet](args, normalize_tuple(out))
+            for par in self.parents:
+                self.flop_counts[par][func_packet] += flop_count
+        elif self.verbose:
+            self.print_fn(f'Warning: {func_packet} operation is treated as a zero-op')
+
+        return out
+
+
+def get_flops_aten(model, input_res,
+                   print_per_layer_stat=True,
+                   input_constructor=None, ost=sys.stdout,
+                   verbose=False, ignore_modules=[],
+                   custom_modules_hooks={},
+                   output_precision=2,
+                   flops_units: Optional[str] = 'GMac',
+                   param_units: Optional[str] = 'M') -> Tuple[Union[int, None],
+                                                              Union[int, None]]:
+
+    params_sum = get_model_parameters_number(model)
+    model.eval()
+    output_params = {'serialize_params':
+                     {'units': flops_units, 'precision': output_precision},
+                     'print_params': {'file': ost}}
+
+    if input_constructor:
+        batch = input_constructor(input_res)
+    else:
+        try:
+            batch = torch.ones(()).new_empty((1, *input_res),
+                                             dtype=next(model.parameters()).dtype,
+                                             device=next(model.parameters()).device)
+        except StopIteration:
+            batch = torch.ones(()).new_empty((1, *input_res))
+
+    try:
+        counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params)
+        with counter:
+            if isinstance(batch, dict):
+                _ = model(**batch)
+            else:
+                _ = model(batch)
+        macs_count = counter.complexity
+
+    except Exception as e:
+        print("Flops estimation was not finished successfully because of"
+              f" the following exception:\n{type(e)} : {e}")
+        traceback.print_exc()
+
+        return None, None
+
+    return macs_count, params_sum
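
Note on `FlopCounterMode` above: each counted operation is credited to every name currently on the `parents` stack, and hooks are registered only on `module.named_children()`. That is why the total is read from the `'Global'` entry and why only direct children of the root module get their own statistics. The sketch below reproduces just that bookkeeping in plain Python, without torch. The class `ParentStackCounter` and its method names are hypothetical stand-ins, not part of ptflops.

```python
from collections import defaultdict


class ParentStackCounter:
    """Toy stand-in for the attribution logic; not the ptflops class."""

    def __init__(self):
        # 'Global' is always on the stack, so it accumulates every operation.
        self.parents = ['Global']
        self.counts = defaultdict(lambda: defaultdict(int))

    def enter(self, name):
        # Plays the role of the forward pre-hook.
        self.parents.append(name)

    def exit(self, name):
        # Plays the role of the forward hook; modules must exit in LIFO order.
        assert self.parents[-1] == name
        self.parents.pop()

    def record(self, op, macs):
        # Credit the operation to every module currently being executed.
        for parent in self.parents:
            self.counts[parent][op] += macs


counter = ParentStackCounter()
counter.enter('features')
counter.record('convolution', 100)
counter.exit('features')
counter.enter('classifier')
counter.record('addmm', 10)
counter.exit('classifier')
counter.record('mm', 1)  # executed directly in the root forward()

total = sum(counter.counts['Global'].values())
print(total)  # prints 111
```

In this sketch an operation executed inside a grandchild module would still be credited to its top-level ancestor and to `'Global'`, but the grandchild would never appear as its own entry, because nothing pushes its name onto the stack.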
