77'''
88
99
10+ from functools import partial
1011import sys
1112import traceback
1213from collections import defaultdict
1617from torch .utils ._python_dispatch import TorchDispatchMode
1718
1819from ptflops .pytorch_engine import get_model_parameters_number
20+ from ptflops .utils import flops_to_string
1921from .aten_ops import ATEN_OPS_MAPPING
2022
2123
22- def normalize_tuple (x ):
23- if not isinstance (x , tuple ):
24- return (x ,)
25- return x
26-
27-
2824class FlopCounterMode (TorchDispatchMode ):
29- def __init__ (self , module = None ):
25+ def __init__ (self , module = None , verbose = False , print_per_layer_stat = False ,
26+ output_params = None ):
27+ self .verbose = verbose
28+ if output_params is None :
29+ output_params = defaultdict (dict )
30+ self .output_params = output_params
31+ self .print_fn = partial (print , ** self .output_params ['print_params' ])
32+
33+ self .print_per_layer_stat = print_per_layer_stat
3034 self .flop_counts = defaultdict (lambda : defaultdict (int ))
3135 self .parents = ['Global' ]
3236 self ._total_complexity = None
@@ -56,9 +60,24 @@ def __enter__(self):
5660
def __exit__(self, *args):
    """Leave the dispatch-mode context: freeze the FLOPs total and optionally report it.

    Caches the grand total (accumulated under the synthetic 'Global' parent)
    in ``self._total_complexity`` and, when ``print_per_layer_stat`` is set,
    prints the total followed by a per-module, per-operator breakdown via
    ``self.print_fn``. Exception arguments are forwarded unchanged to the
    base ``TorchDispatchMode.__exit__``.
    """
    self._total_complexity = sum(self.flop_counts['Global'].values())
    if self.print_per_layer_stat:
        # Hoist the loop-invariant serialization options out of the nested loop.
        serialize_params = self.output_params['serialize_params']
        self.print_fn('Total:' +
                      flops_to_string(self._total_complexity,
                                      **serialize_params))
        # Iterate items() directly instead of keys() + a second dict lookup.
        for mod, op_counts in self.flop_counts.items():
            self.print_fn("Module: ", mod)
            for op, count in op_counts.items():
                self.print_fn(f'{op}: ' +
                              flops_to_string(count, **serialize_params))
            self.print_fn()  # blank separator line between modules
    super().__exit__(*args)
6075
6176 def __torch_dispatch__ (self , func , types , args = (), kwargs = None ):
77+ def normalize_tuple (x ):
78+ if not isinstance (x , tuple ):
79+ return (x ,)
80+ return x
6281 kwargs = kwargs if kwargs else {}
6382
6483 out = func (* args , ** kwargs )
@@ -67,6 +86,8 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
6786 flop_count = ATEN_OPS_MAPPING [func_packet ](args , normalize_tuple (out ))
6887 for par in self .parents :
6988 self .flop_counts [par ][func_packet ] += flop_count
89+ elif self .verbose :
90+ self .print_fn (f'Warning: { func_packet } operation is treated as a zero-op' )
7091
7192 return out
7293
@@ -82,6 +103,9 @@ def get_flops_aten(model, input_res,
82103 Union [int , None ]]:
83104
84105 params_sum = get_model_parameters_number (model )
106+ output_params = {'serialize_params' :
107+ {'units' : flops_units , 'precision' : output_precision },
108+ 'print_params' : {'file' : ost }}
85109
86110 if input_constructor :
87111 batch = input_constructor (input_res )
@@ -94,7 +118,7 @@ def get_flops_aten(model, input_res,
94118 batch = torch .ones (()).new_empty ((1 , * input_res ))
95119
96120 try :
97- counter = FlopCounterMode (model )
121+ counter = FlopCounterMode (model , verbose , print_per_layer_stat , output_params )
98122 with counter :
99123 if isinstance (batch , dict ):
100124 _ = model (** batch )
0 commit comments