
Commit fec84af

Add experimental aten backend
1 parent d8d871c commit fec84af

4 files changed, +254 -5 lines changed

ptflops/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -1,17 +1,18 @@
 '''
-Copyright (C) 2019-2023 Sovrasov V. - All Rights Reserved
+Copyright (C) 2019-2024 Sovrasov V. - All Rights Reserved
 * You may use, distribute and modify this code under the
 * terms of the MIT license.
 * You should have received a copy of the MIT license with
 * this file. If not visit https://opensource.org/licenses/MIT
 '''


-from .flops_counter import get_model_complexity_info
+from .flops_counter import FLOPS_BACKEND, get_model_complexity_info
 from .utils import flops_to_string, params_to_string

 __all__ = [
     "get_model_complexity_info",
     "flops_to_string",
     "params_to_string",
+    "FLOPS_BACKEND",
 ]
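
With this change the backend selector is importable from the package root alongside the main entry point. A minimal sanity check (illustrative snippet, not part of the commit):

from ptflops import FLOPS_BACKEND

# Enum lookup by value maps plain strings onto the enum members.
assert FLOPS_BACKEND('aten') == FLOPS_BACKEND.ATEN
assert FLOPS_BACKEND('pytorch') == FLOPS_BACKEND.PYTORCH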

ptflops/aten_engine.py

Lines changed: 112 additions & 0 deletions (new file)

'''
Copyright (C) 2024 Sovrasov V. - All Rights Reserved
* You may use, distribute and modify this code under the
* terms of the MIT license.
* You should have received a copy of the MIT license with
* this file. If not visit https://opensource.org/licenses/MIT
'''


import sys
import traceback
from collections import defaultdict
from typing import Optional, Tuple, Union

import torch
from torch.utils._python_dispatch import TorchDispatchMode

from ptflops.pytorch_engine import get_model_parameters_number
from .aten_ops import ATEN_OPS_MAPPING


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


class FlopCounterMode(TorchDispatchMode):
    '''Counts flops of aten ops executed under the mode, attributing them
    to the modules of the traced model via forward hooks.'''

    def __init__(self, module=None):
        self.flop_counts = defaultdict(lambda: defaultdict(int))
        self.parents = ['Global']
        self._total_complexity = None
        if module is not None:
            for name, mod in dict(module.named_children()).items():
                mod.register_forward_pre_hook(self.enter_module(name))
                mod.register_forward_hook(self.exit_module(name))

    @property
    def complexity(self):
        return self._total_complexity

    def enter_module(self, name):
        def f(*args):
            self.parents.append(name)
        return f

    def exit_module(self, name):
        def f(*args):
            assert self.parents[-1] == name
            self.parents.pop()
        return f

    def __enter__(self):
        self.flop_counts.clear()
        super().__enter__()

    def __exit__(self, *args):
        self._total_complexity = sum(self.flop_counts['Global'].values())
        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        # Execute the op, then attribute its cost to every module
        # on the current stack, including the 'Global' root.
        out = func(*args, **kwargs)
        func_packet = func._overloadpacket
        if func_packet in ATEN_OPS_MAPPING:
            flop_count = ATEN_OPS_MAPPING[func_packet](args, normalize_tuple(out))
            for par in self.parents:
                self.flop_counts[par][func_packet] += flop_count

        return out


def get_flops_aten(model, input_res,
                   print_per_layer_stat=True,
                   input_constructor=None, ost=sys.stdout,
                   verbose=False, ignore_modules=[],
                   custom_modules_hooks={},
                   output_precision=2,
                   flops_units: Optional[str] = 'GMac',
                   param_units: Optional[str] = 'M') -> Tuple[Union[int, None],
                                                              Union[int, None]]:

    params_sum = get_model_parameters_number(model)

    if input_constructor:
        batch = input_constructor(input_res)
    else:
        try:
            batch = torch.ones(()).new_empty((1, *input_res),
                                             dtype=next(model.parameters()).dtype,
                                             device=next(model.parameters()).device)
        except StopIteration:
            batch = torch.ones(()).new_empty((1, *input_res))

    try:
        counter = FlopCounterMode(model)
        with counter:
            if isinstance(batch, dict):
                _ = model(**batch)
            else:
                _ = model(batch)
        macs_count = counter.complexity

    except Exception as e:
        print("Flops estimation was not finished successfully because of "
              f"the following exception:\n{type(e)} : {e}")
        traceback.print_exc()

        return None, None

    return macs_count, params_sum
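
FlopCounterMode can also be exercised on its own, outside of get_flops_aten. A minimal sketch; the toy model and input shape below are arbitrary examples, not part of this commit:

import torch
import torch.nn as nn

from ptflops.aten_engine import FlopCounterMode

# Conv and linear layers dispatch to aten.convolution / aten.addmm,
# both of which are present in ATEN_OPS_MAPPING.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                      nn.AdaptiveAvgPool2d(1),
                      nn.Flatten(),
                      nn.Linear(8, 10))

counter = FlopCounterMode(model)
with counter:  # __enter__ does not return self, so bind the instance first
    model(torch.randn(1, 3, 32, 32))

print(counter.complexity)  # total flops accumulated under the 'Global' key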

ptflops/aten_ops.py

Lines changed: 118 additions & 0 deletions (new file)

'''
Copyright (C) 2023 Sovrasov V. - All Rights Reserved
* You may use, distribute and modify this code under the
* terms of the MIT license.
* You should have received a copy of the MIT license with
* this file. If not visit https://opensource.org/licenses/MIT
'''

from typing import Any, List

import torch

aten = torch.ops.aten


def get_shape(i):
    return i.shape


def prod(x):
    res = 1
    for i in x:
        res *= i
    return res


def matmul_flop(inputs: List[Any], outputs: List[Any]) -> int:
    """
    Count flops for matmul.
    """
    # inputs should be a list of length 2 containing the two matrices.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    flop = prod(input_shapes[0]) * input_shapes[-1][-1]
    return flop


def addmm_flop(inputs: List[Any], outputs: List[Any]) -> int:
    """
    Count flops for fully connected layers.
    """
    # Count flop for nn.Linear
    # inputs is a list of length 3: [bias, mat1, mat2] (aten.addmm signature).
    input_shapes = [get_shape(v) for v in inputs[1:3]]
    # input_shapes[0]: [batch size, input feature dimension]
    # input_shapes[1]: [input feature dimension, output feature dimension]
    assert len(input_shapes[0]) == 2, input_shapes[0]
    assert len(input_shapes[1]) == 2, input_shapes[1]
    batch_size, input_dim = input_shapes[0]
    output_dim = input_shapes[1][1]
    flops = batch_size * input_dim * output_dim
    return flops


def bmm_flop(inputs: List[Any], outputs: List[Any]) -> int:
    """
    Count flops for the bmm operation.
    """
    # inputs should be a list of length 2 containing the two batched tensors.
    assert len(inputs) == 2, len(inputs)
    input_shapes = [get_shape(v) for v in inputs]
    n, c, t = input_shapes[0]
    d = input_shapes[-1][-1]
    flop = n * c * t * d
    return flop


def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> int:
    """
    Count flops for convolution. Note only multiplication is
    counted. Computation for addition and bias is ignored.
    Flops for a transposed convolution are calculated as
    flops = batch_size * prod(w_shape) * prod(x_shape[2:]).
    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    flop = batch_size * prod(w_shape) * prod(conv_shape)
    return flop


def conv_flop(inputs: List[Any], outputs: List[Any]) -> int:
    """
    Count flops for convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
    # In the aten.convolution signature the transposed flag is argument 6:
    # (input, weight, bias, stride, padding, dilation, transposed, ...).
    transposed = inputs[6]

    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


def transpose_shape(shape):
    return [shape[1], shape[0]] + list(shape[2:])


ATEN_OPS_MAPPING = {
    aten.mm: matmul_flop,
    aten.matmul: matmul_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
}
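
As a sanity check on these formulas (one multiply-accumulate per counted flop), here is a small illustrative test; the shapes are arbitrary:

import torch

from ptflops.aten_ops import addmm_flop, bmm_flop, conv_flop_count, matmul_flop

# (4, 8) @ (8, 16) -> 4 * 8 * 16 multiplications
a, b = torch.empty(4, 8), torch.empty(8, 16)
assert matmul_flop([a, b], [torch.empty(4, 16)]) == 4 * 8 * 16

# addmm(bias, mat1, mat2): only the matrix product is counted, not the bias add
bias, x, w = torch.empty(16), torch.empty(4, 8), torch.empty(8, 16)
assert addmm_flop([bias, x, w], [torch.empty(4, 16)]) == 4 * 8 * 16

# batched matmul: a batch of 2 independent (4, 8) @ (8, 16) products
p, q = torch.empty(2, 4, 8), torch.empty(2, 8, 16)
assert bmm_flop([p, q], [torch.empty(2, 4, 16)]) == 2 * 4 * 8 * 16

# 3x3 conv, 3 -> 8 channels, 32x32 output: prod(w_shape) multiplies per output pixel
assert conv_flop_count([1, 3, 32, 32], [8, 3, 3, 3],
                       [1, 8, 32, 32]) == 8 * 3 * 3 * 3 * 32 * 32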

ptflops/flops_counter.py

Lines changed: 21 additions & 3 deletions
@@ -1,20 +1,27 @@
 '''
-Copyright (C) 2019-2023 Sovrasov V. - All Rights Reserved
+Copyright (C) 2019-2024 Sovrasov V. - All Rights Reserved
 * You may use, distribute and modify this code under the
 * terms of the MIT license.
 * You should have received a copy of the MIT license with
 * this file. If not visit https://opensource.org/licenses/MIT
 '''

 import sys
+from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, TextIO, Tuple, Union

 import torch.nn as nn

+from .aten_engine import get_flops_aten
 from .pytorch_engine import get_flops_pytorch
 from .utils import flops_to_string, params_to_string


+class FLOPS_BACKEND(Enum):
+    PYTORCH = 'pytorch'
+    ATEN = 'aten'
+
+
 def get_model_complexity_info(model: nn.Module,
                               input_res: Tuple[int, ...],
                               print_per_layer_stat: bool = True,
@@ -24,7 +31,7 @@ def get_model_complexity_info(model: nn.Module,
                               verbose: bool = False,
                               ignore_modules: List[nn.Module] = [],
                               custom_modules_hooks: Dict[nn.Module, Any] = {},
-                              backend: str = 'pytorch',
+                              backend: Union[str, FLOPS_BACKEND] = FLOPS_BACKEND.PYTORCH,
                               flops_units: Optional[str] = None,
                               param_units: Optional[str] = None,
                               output_precision: int = 2) -> Tuple[Union[str, int, None],
@@ -58,6 +65,8 @@ def get_model_complexity_info(model: nn.Module,
     :type ignore_modules: nn.Module
     :param custom_modules_hooks: A dict that contains custom hooks on torch modules.
     :type custom_modules_hooks: Dict[nn.Module, Any]
+    :param backend: Backend used for evaluating model complexity.
+    :type backend: FLOPS_BACKEND
     :param flops_units: Units for string representation of MACs (GMac, MMac or KMac).
     :type flops_units: Optional[str]
     :param param_units: Units for string representation of params (M, K or B).
@@ -75,7 +84,7 @@ def get_model_complexity_info(model: nn.Module,
     assert len(input_res) >= 1
     assert isinstance(model, nn.Module)

-    if backend == 'pytorch':
+    if FLOPS_BACKEND(backend) == FLOPS_BACKEND.PYTORCH:
         flops_count, params_count = get_flops_pytorch(model, input_res,
                                                       print_per_layer_stat,
                                                       input_constructor, ost,
@@ -84,6 +93,15 @@ def get_model_complexity_info(model: nn.Module,
                                                       output_precision=output_precision,
                                                       flops_units=flops_units,
                                                       param_units=param_units)
+    elif FLOPS_BACKEND(backend) == FLOPS_BACKEND.ATEN:
+        flops_count, params_count = get_flops_aten(model, input_res,
+                                                   print_per_layer_stat,
+                                                   input_constructor, ost,
+                                                   verbose, ignore_modules,
+                                                   custom_modules_hooks,
+                                                   output_precision=output_precision,
+                                                   flops_units=flops_units,
+                                                   param_units=param_units)
     else:
         raise ValueError('Wrong backend name')
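
End to end, the new backend is selected through the existing public entry point. A usage sketch; resnet18 is just an example model, and torchvision is not a dependency of this commit:

import torchvision.models as models

from ptflops import FLOPS_BACKEND, get_model_complexity_info

net = models.resnet18()
# backend also accepts the plain string 'aten', since the dispatch
# normalizes it via FLOPS_BACKEND(backend).
macs, params = get_model_complexity_info(net, (3, 224, 224),
                                         print_per_layer_stat=False,
                                         backend=FLOPS_BACKEND.ATEN)
print(macs, params)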
