Commit 744e746

Forward printing parameters to the aten backend

1 parent fec84af commit 744e746

File tree

2 files changed: +32 -12 lines changed

ptflops/aten_engine.py

Lines changed: 32 additions & 8 deletions
@@ -7,6 +7,7 @@
 '''
 
 
+from functools import partial
 import sys
 import traceback
 from collections import defaultdict
@@ -16,17 +17,20 @@
 from torch.utils._python_dispatch import TorchDispatchMode
 
 from ptflops.pytorch_engine import get_model_parameters_number
+from ptflops.utils import flops_to_string
 from .aten_ops import ATEN_OPS_MAPPING
 
 
-def normalize_tuple(x):
-    if not isinstance(x, tuple):
-        return (x,)
-    return x
-
-
 class FlopCounterMode(TorchDispatchMode):
-    def __init__(self, module=None):
+    def __init__(self, module=None, verbose=False, print_per_layer_stat=False,
+                 output_params=None):
+        self.verbose = verbose
+        if output_params is None:
+            output_params = defaultdict(dict)
+        self.output_params = output_params
+        self.print_fn = partial(print, **self.output_params['print_params'])
+
+        self.print_per_layer_stat = print_per_layer_stat
         self.flop_counts = defaultdict(lambda: defaultdict(int))
         self.parents = ['Global']
         self._total_complexity = None
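(Editorial note: the `print_fn` introduced above is `functools.partial` binding the keyword arguments in `output_params['print_params']`, such as a target stream, to the built-in `print`, so every later report call writes to the right destination. A minimal standalone sketch of the pattern; the stream and unit values here are chosen arbitrarily for illustration:)

import sys
from functools import partial

# Bind print() to an output stream once, then reuse the bound function
# for every report line.
output_params = {'print_params': {'file': sys.stderr},
                 'serialize_params': {'units': 'GMac', 'precision': 2}}
print_fn = partial(print, **output_params['print_params'])
print_fn('reports go to stderr')  # same as print('...', file=sys.stderr)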
@@ -56,9 +60,24 @@ def __enter__(self):
 
     def __exit__(self, *args):
         self._total_complexity = sum(self.flop_counts['Global'].values())
+        if self.print_per_layer_stat:
+            self.print_fn('Total:' +
+                          flops_to_string(self._total_complexity,
+                                          **self.output_params['serialize_params']))
+            for mod in self.flop_counts.keys():
+                self.print_fn("Module: ", mod)
+                for k, v in self.flop_counts[mod].items():
+                    self.print_fn(
+                        f'{k}: ' +
+                        flops_to_string(v, **self.output_params['serialize_params']))
+                self.print_fn()
         super().__exit__(*args)
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        def normalize_tuple(x):
+            if not isinstance(x, tuple):
+                return (x,)
+            return x
         kwargs = kwargs if kwargs else {}
 
         out = func(*args, **kwargs)
@@ -67,6 +86,8 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
             flop_count = ATEN_OPS_MAPPING[func_packet](args, normalize_tuple(out))
             for par in self.parents:
                 self.flop_counts[par][func_packet] += flop_count
+        elif self.verbose:
+            self.print_fn(f'Warning: {func_packet} operation is treated as a zero-op')
 
         return out
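(Editorial note: with these changes the mode prints its per-layer report from `__exit__`, so it can be used directly as a context manager. A hedged sketch of such direct use; the model and parameter values are invented for illustration, and `output_params` must carry the two keys built in `get_flops_aten` below:)

import sys
import torch
import torch.nn as nn
from ptflops.aten_engine import FlopCounterMode

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(),
                      nn.Linear(8 * 30 * 30, 10))
output_params = {'serialize_params': {'units': None, 'precision': 2},
                 'print_params': {'file': sys.stdout}}
# __exit__ prints the per-layer report; verbose=True additionally warns
# about ATen ops that are counted as zero-ops.
with FlopCounterMode(model, verbose=True, print_per_layer_stat=True,
                     output_params=output_params):
    model(torch.randn(1, 3, 32, 32))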

@@ -82,6 +103,9 @@ def get_flops_aten(model, input_res,
                    Union[int, None]]:
 
     params_sum = get_model_parameters_number(model)
+    output_params = {'serialize_params':
+                     {'units': flops_units, 'precision': output_precision},
+                     'print_params': {'file': ost}}
 
     if input_constructor:
         batch = input_constructor(input_res)
@@ -94,7 +118,7 @@ def get_flops_aten(model, input_res,
         batch = torch.ones(()).new_empty((1, *input_res))
 
     try:
-        counter = FlopCounterMode(model)
+        counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params)
         with counter:
             if isinstance(batch, dict):
                 _ = model(**batch)
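(Editorial note: the forwarded arguments originate at the public entry point. A minimal end-to-end sketch, assuming a ptflops version in which `get_model_complexity_info` accepts `backend='aten'`:)

import torchvision.models as models
from ptflops import get_model_complexity_info

net = models.resnet18()
macs, params = get_model_complexity_info(
    net, (3, 224, 224), backend='aten',
    print_per_layer_stat=True,  # forwarded to FlopCounterMode above
    verbose=True,               # report ops treated as zero-ops
    output_precision=3)
print(f'{macs} MACs, {params} parameters')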

ptflops/aten_ops.py

Lines changed: 0 additions & 4 deletions
@@ -104,10 +104,6 @@ def conv_flop(inputs: List[Any], outputs: List[Any]):
     return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
 
 
-def transpose_shape(shape):
-    return [shape[1], shape[0]] + list(shape[2:])
-
-
 ATEN_OPS_MAPPING = {
     aten.mm: matmul_flop,
     aten.matmul: matmul_flop,
