import time


-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Sequence, Union
from numbers import Number

Handle = Callable[[List[Any], List[Any]], Union[typing.Counter[str], Number]]


-def get_shape(val: object) -> typing.List[int]:
+def get_shape(val: Any) -> typing.List[int]:
    """
    Get the shapes from a jit value object.
    Args:
-        val (torch._C.Value): jit value object.
+        val: jit value object.
    Returns:
-        list(int): return a list of ints.
+        return a list of ints.
    """
    if val.isCompleteTensor():  # pyre-ignore
        r = val.type().sizes()  # pyre-ignore
@@ -64,17 +64,17 @@ def get_shape(val: object) -> typing.List[int]:


def addmm_flop_jit(
-    inputs: typing.List[object], outputs: typing.List[object]
+    inputs: typing.List[Any], outputs: typing.List[Any]
) -> typing.Counter[str]:
    """
    This method counts the flops for fully connected layers with torch script.
    Args:
-        inputs (list(torch._C.Value)): The input shape in the form of a list of
+        inputs: The input shape in the form of a list of
            jit object.
-        outputs (list(torch._C.Value)): The output shape in the form of a list
+        outputs: The output shape in the form of a list
            of jit object.
    Returns:
-        Counter: A Counter dictionary that records the number of flops for each
+        A Counter dictionary that records the number of flops for each
            operation.
    """
    # Count flop for nn.Linear
@@ -91,7 +91,7 @@ def addmm_flop_jit(
    return flop_counter


-def bmm_flop_jit(inputs, outputs):
+def bmm_flop_jit(inputs: typing.List[Any], outputs: typing.List[Any]) -> Counter[str]:
    # Count flop for nn.Linear
    # inputs is a list of length 3.
    input_shapes = [get_shape(v) for v in inputs]
@@ -106,7 +106,7 @@ def bmm_flop_jit(inputs, outputs):
    return flop_counter


-def basic_binary_op_flop_jit(inputs, outputs, name):
+def basic_binary_op_flop_jit(inputs: typing.List[Any], outputs: typing.List[Any], name: str) -> Counter[str]:
    input_shapes = [get_shape(v) for v in inputs]
    # for broadcasting
    input_shapes = [s[::-1] for s in input_shapes]
@@ -116,29 +116,34 @@ def basic_binary_op_flop_jit(inputs, outputs, name):
    return flop_counter


-def rsqrt_flop_jit(inputs, outputs):
+def rsqrt_flop_jit(inputs: typing.List[Any], outputs: typing.List[Any]) -> Counter[str]:
    input_shapes = [get_shape(v) for v in inputs]
    flop = prod(input_shapes[0]) * 2
    flop_counter = Counter({"rsqrt": flop})
    return flop_counter


-def dropout_flop_jit(inputs, outputs):
+def dropout_flop_jit(inputs: typing.List[Any], outputs: typing.List[Any]) -> Counter[str]:
    input_shapes = [get_shape(v) for v in inputs[:1]]
    flop = prod(input_shapes[0])
    flop_counter = Counter({"dropout": flop})
    return flop_counter


-def softmax_flop_jit(inputs, outputs):
+def softmax_flop_jit(inputs: typing.List[Any], outputs: typing.List[Any]) -> Counter[str]:
    # from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/profiler/internal/flops_registry.py
    input_shapes = [get_shape(v) for v in inputs[:1]]
    flop = prod(input_shapes[0]) * 5
    flop_counter = Counter({"softmax": flop})
    return flop_counter


-def _reduction_op_flop_jit(inputs, outputs, reduce_flops=1, finalize_flops=0):
+def _reduction_op_flop_jit(
+    inputs: typing.List[Any],
+    outputs: typing.List[Any],
+    reduce_flops: int = 1,
+    finalize_flops: int = 0,
+) -> int:
    input_shapes = [get_shape(v) for v in inputs]
    output_shapes = [get_shape(v) for v in outputs]

@@ -161,11 +166,11 @@ def conv_flop_count(
    This method counts the flops for convolution. Note only multiplication is
    counted. Computation for addition and bias is ignored.
    Args:
-        x_shape (list(int)): The input shape before convolution.
-        w_shape (list(int)): The filter shape.
-        out_shape (list(int)): The output shape after convolution.
+        x_shape: The input shape before convolution.
+        w_shape: The filter shape.
+        out_shape: The output shape after convolution.
    Returns:
-        Counter: A Counter dictionary that records the number of flops for each
+        A Counter dictionary that records the number of flops for each
            operation.
    """
    batch_size, Cin_dim, Cout_dim = x_shape[0], w_shape[1], out_shape[1]
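This hunk ends before the flop arithmetic itself, so as a reading aid here is a rough, hypothetical sanity check of what any multiplication-only convolution count has to scale with: each output element costs Cin * kh * kw multiplies. The shapes below are made up for illustration.

# Hypothetical sanity check, not code from this commit: multiplication count for a
# 3x3 conv taking a 1x16x32x32 input to a 1x32x32x32 output ("same" padding assumed).
x_shape, w_shape, out_shape = [1, 16, 32, 32], [32, 16, 3, 3], [1, 32, 32, 32]
batch, cin, cout = x_shape[0], w_shape[1], out_shape[1]
out_hw = out_shape[2] * out_shape[3]
kernel_hw = w_shape[2] * w_shape[3]
print(batch * cout * out_hw * cin * kernel_hw)  # 4718592 multiplies, ~4.7e-3 GFlops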
@@ -177,17 +182,17 @@ def conv_flop_count(


def conv_flop_jit(
-    inputs: typing.List[object], outputs: typing.List[object]
+    inputs: typing.List[Any], outputs: typing.List[Any]
) -> typing.Counter[str]:
    """
    This method counts the flops for convolution using torch script.
    Args:
-        inputs (list(torch._C.Value)): The input shape in the form of a list of
+        inputs: The input shape in the form of a list of
            jit object before convolution.
-        outputs (list(torch._C.Value)): The output shape in the form of a list
+        outputs: The output shape in the form of a list
            of jit object after convolution.
    Returns:
-        Counter: A Counter dictionary that records the number of flops for each
+        A Counter dictionary that records the number of flops for each
            operation.
    """
    # Inputs of Convolution should be a list of length 12. They represent:
@@ -206,18 +211,18 @@ def conv_flop_jit(


def einsum_flop_jit(
-    inputs: typing.List[object], outputs: typing.List[object]
+    inputs: typing.List[Any], outputs: typing.List[Any]
) -> typing.Counter[str]:
    """
    This method counts the flops for the einsum operation. We currently support
    two einsum operations: "nct,ncp->ntp" and "ntg,ncg->nct".
    Args:
-        inputs (list(torch._C.Value)): The input shape in the form of a list of
+        inputs: The input shape in the form of a list of
            jit object before einsum.
-        outputs (list(torch._C.Value)): The output shape in the form of a list
+        outputs: The output shape in the form of a list
            of jit object after einsum.
    Returns:
-        Counter: A Counter dictionary that records the number of flops for each
+        A Counter dictionary that records the number of flops for each
            operation.
    """
    # Inputs of einsum should be a list of length 2.
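As a reading aid for the two supported equations (the handle body falls outside this hunk): "nct,ncp->ntp" produces n*t*p output elements, each summing c products, so under the multiplication-only convention used throughout this file it costs n*t*p*c flops; for example n=1, c=64, t=p=196 gives 1 * 196 * 196 * 64 = 2,458,624, about 2.46 MFlops. "ntg,ncg->nct" is the analogous contraction over g, i.e. n*c*t*g flops.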
@@ -254,17 +259,17 @@ def einsum_flop_jit(


def matmul_flop_jit(
-    inputs: typing.List[object], outputs: typing.List[object]
+    inputs: typing.List[Any], outputs: typing.List[Any]
) -> typing.Counter[str]:
    """
    This method counts the flops for matmul.
    Args:
-        inputs (list(torch._C.Value)): The input shape in the form of a list of
+        inputs: The input shape in the form of a list of
            jit object before matmul.
-        outputs (list(torch._C.Value)): The output shape in the form of a list
+        outputs: The output shape in the form of a list
            of jit object after matmul.
    Returns:
-        Counter: A Counter dictionary that records the number of flops for each
+        A Counter dictionary that records the number of flops for each
            operation.
    """

@@ -287,17 +292,17 @@ def matmul_flop_jit(


def batchnorm_flop_jit(
-    inputs: typing.List[object], outputs: typing.List[object]
+    inputs: typing.List[Any], outputs: typing.List[Any]
) -> typing.Counter[str]:
    """
    This method counts the flops for batch norm.
    Args:
-        inputs (list(torch._C.Value)): The input shape in the form of a list of
+        inputs: The input shape in the form of a list of
            jit object before batch norm.
-        outputs (list(torch._C.Value)): The output shape in the form of a list
+        outputs: The output shape in the form of a list
            of jit object after batch norm.
    Returns:
-        Counter: A Counter dictionary that records the number of flops for each
+        A Counter dictionary that records the number of flops for each
            operation.
    """
    # Inputs[0] contains the shape of the input.
@@ -470,26 +475,26 @@ def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number:

def flop_count(
    model: nn.Module,
-    inputs: typing.Tuple[object, ...],
-    whitelist: typing.Union[typing.List[str], None] = None,
-    customized_ops: typing.Union[typing.Dict[str, typing.Callable], None] = None,
+    inputs: typing.Tuple[Any, ...],
+    whitelist: typing.Optional[typing.List[str]] = None,
+    customized_ops: typing.Optional[typing.Dict[str, typing.Callable]] = None,
) -> typing.DefaultDict[str, float]:
    """
    Given a model and an input to the model, compute the Gflops of the given
    model. Note the input should have a batch size of 1.
    Args:
-        model (nn.Module): The model to compute flop counts.
-        inputs (tuple): Inputs that are passed to `model` to count flops.
+        model: The model to compute flop counts.
+        inputs: Inputs that are passed to `model` to count flops.
            Inputs need to be in a tuple.
-        whitelist (list(str)): Whitelist of operations that will be counted. It
+        whitelist: Whitelist of operations that will be counted. It
            needs to be a subset of _SUPPORTED_OPS. By default, the function
            computes flops for all supported operations.
-        customized_ops (dict(str,Callable)): A dictionary contains customized
+        customized_ops: A dictionary contains customized
            operations and their flop handles. If customized_ops contains an
            operation in _SUPPORTED_OPS, then the default handle in
            _SUPPORTED_OPS will be overwritten.
    Returns:
-        defaultdict: A dictionary that records the number of gflops for each
+        A dictionary that records the number of gflops for each
            operation.
    """
    # Copy _SUPPORTED_OPS to flop_count_ops.
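For readers skimming the diff, a minimal, hypothetical usage sketch of the annotated signature above. The input tensor, the op-name key string, and the custom handle are illustrative only (the real key format comes from _SUPPORTED_OPS, which is outside this hunk); a custom handle just has to satisfy the Handle alias defined near the top of the file.

# Hypothetical usage sketch, not part of this commit.
def my_gelu_flop(inputs, outputs):
    # assume one flop per output element; shapes are read off the jit graph
    return Counter({"gelu": prod(get_shape(outputs[0]))})

gflop_dict = flop_count(
    model,                                        # any nn.Module
    (torch.rand(1, 3, 224, 224),),                # inputs go in as a tuple, batch size 1
    customized_ops={"aten::gelu": my_gelu_flop},  # key string is illustrative
)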
@@ -557,13 +562,13 @@ def flop_count(
    return final_count


-def warmup(model, inputs, N=10):
+def warmup(model: torch.nn.Module, inputs: Any, N: int = 10) -> None:
    for i in range(N):
        out = model(inputs)
    torch.cuda.synchronize()


-def measure_time(model, inputs, N=10):
+def measure_time(model: torch.nn.Module, inputs: Any, N: int = 10) -> float:
    warmup(model, inputs)
    s = time.time()
    for i in range(N):
@@ -573,7 +578,7 @@ def measure_time(model, inputs, N=10):
    return t


-def fmt_res(data):
+def fmt_res(data: np.ndarray) -> Dict[str, float]:
    # return data.mean(), data.std(), data.min(), data.max()
    return {
        "mean": data.mean(),
@@ -583,7 +588,7 @@ def fmt_res(data):
    }


-def benchmark(model, dataset, output_dir):
+def benchmark(model: torch.nn.Module, dataset: Sequence[Any], output_dir: Any) -> Dict[str, Any]:
    print("Get model size, FLOPs, and FPS")
    # import pdb; pdb.set_trace()
    _outputs = {}
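Finally, a minimal, hypothetical sketch of how the timing helpers above compose, assuming the file imports numpy as np (the fmt_res annotation suggests it does) and that each dataset item is an input the model accepts directly.

# Hypothetical usage sketch, not part of this commit.
times = np.array([measure_time(model, sample) for sample in dataset[:20]])
print(fmt_res(times))  # summary statistics of the measured wall-clock times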