 import collections.abc as abc
 import logging
 import operator
+from typing import Callable
 
 import torch
 import torch.nn as nn
@@ -183,8 +184,12 @@ def flops(self, x):
     elif layer_type in ["AdaptiveAvgPool2d"]:
         in_h = x.size()[2]
         in_w = x.size()[3]
-        out_h = layer.output_size[0]
-        out_w = layer.output_size[1]
+        if isinstance(layer.output_size, int):
+            out_h, out_w = layer.output_size, layer.output_size
+        elif len(layer.output_size) == 1:
+            out_h, out_w = layer.output_size[0], layer.output_size[0]
+        else:
+            out_h, out_w = layer.output_size
         if out_h > in_h or out_w > in_w:
             raise NotImplementedError()
         batchsize_per_replica = x.size()[0]
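Note: `nn.AdaptiveAvgPool2d` accepts `output_size` as a plain int, a 1-tuple, or a 2-tuple, which is why indexing `[0]` and `[1]` in the removed lines broke for the int case. A standalone sketch of the same normalization (the helper name is illustrative, not part of this change):

```python
def _normalize_output_size(output_size):
    # Mirrors the new branch above: int, 1-tuple, and 2-tuple all
    # normalize to an (out_h, out_w) pair.
    if isinstance(output_size, int):
        return output_size, output_size
    if len(output_size) == 1:
        return output_size[0], output_size[0]
    return tuple(output_size)

assert _normalize_output_size(7) == (7, 7)
assert _normalize_output_size((7,)) == (7, 7)
assert _normalize_output_size((7, 5)) == (7, 5)
```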
@@ -295,6 +300,10 @@ def flops(self, x):
         for dim_size in x.size():
             flops *= dim_size
         return flops
+
+    elif layer_type == "Identity":
+        return 0
+
     elif hasattr(layer, "flops"):
         # If the module already defines a method to compute flops with the signature
         # below, we use it to compute flops
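`nn.Identity` performs no arithmetic, so it contributes zero FLOPs. Modules that need custom accounting can instead define the `flops(self, x)` hook referenced in the comment above; a hypothetical example (not part of this diff):

```python
import torch.nn as nn

class ScaledResidual(nn.Module):
    """Hypothetical module that reports its own FLOP count."""

    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        return x * self.scale + x

    def flops(self, x):
        # One multiply and one add per input element.
        return 2 * x.numel()
```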
@@ -312,8 +321,16 @@ def _layer_activations(layer, x, out):
     """
     Computes the number of activations produced by a single layer.
 
-    Activations are counted only for convolutional layers.
+    Activations are counted only for convolutional layers. To override this behavior, a
+    layer can define a method to compute activations with the signature below, which
+    will be used to compute the activations instead.
+
+    class MyModule(nn.Module):
+        def activations(self, x, out):
+            ...
     """
+    if hasattr(layer, "activations"):
+        return layer.activations(x, out)
     return out.numel() if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Conv3d)) else 0
 
 
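A hypothetical module using the new `activations(self, x, out)` override; any layer without the hook still falls through to the conv-only default:

```python
import torch.nn as nn

class GatedBlock(nn.Module):
    """Hypothetical module that overrides the default activation count."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, 2 * channels, kernel_size=3, padding=1)

    def forward(self, x):
        gate, value = self.conv(x).chunk(2, dim=1)
        return gate.sigmoid() * value

    def activations(self, x, out):
        # Count only the surviving gated output, not the intermediate conv.
        return out.numel()
```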
@@ -338,11 +355,25 @@ def summarize_profiler_info(prof):
     return str
 
 
-def _patched_computation_module(module, compute_list, compute_fn):
+class _ComplexityComputer:
+    def __init__(self, compute_fn: Callable, count_unique: bool):
+        self.compute_fn = compute_fn
+        self.count_unique = count_unique
+        self.count = 0
+        self.seen_modules = set()
+
+    def compute(self, layer, x, out, module_name):
+        if self.count_unique and module_name in self.seen_modules:
+            return
+        self.count += self.compute_fn(layer, x, out)
+        self.seen_modules.add(module_name)
+
+
+def _patched_computation_module(module, complexity_computer, module_name):
     """
     Patch the module to compute a module's parameters, like FLOPs.
 
-    Calls compute_fn and appends the results to compute_list.
+    Calls compute_fn and passes the results to the complexity computer.
     """
     ty = type(module)
     typestring = module.__repr__()
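With `count_unique=True`, a module that runs more than once per forward pass (e.g. a weight-shared block) is counted a single time, keyed by its dotted module name. A minimal sketch of the bookkeeping in isolation, using a stub `compute_fn`:

```python
computer = _ComplexityComputer(compute_fn=lambda layer, x, out: 10, count_unique=True)

# The same module invoked twice (weight sharing): counted once.
computer.compute(None, None, None, module_name=".shared.block")
computer.compute(None, None, None, module_name=".shared.block")
assert computer.count == 10

# With count_unique=False, both calls would accumulate to 20.
```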
@@ -355,7 +386,7 @@ def _original_forward(self, *args, **kwargs):
 
         def forward(self, *args, **kwargs):
             out = self._original_forward(*args, **kwargs)
-            compute_list.append(compute_fn(self, args[0], out))
+            complexity_computer.compute(self, args[0], out, module_name)
             return out
 
         def __repr__(self):
@@ -364,37 +395,58 @@ def __repr__(self):
     return ComputeModule
 
 
-def modify_forward(model, compute_list, compute_fn):
+def modify_forward(model, complexity_computer, prefix="", patch_attr=None):
     """
     Modify forward pass to measure a module's parameters, like FLOPs.
     """
-    if is_leaf(model) or hasattr(model, "flops"):
-        model.__class__ = _patched_computation_module(model, compute_list, compute_fn)
+    if is_leaf(model) or (patch_attr is not None and hasattr(model, patch_attr)):
+        model.__class__ = _patched_computation_module(
+            model, complexity_computer, prefix
+        )
 
     else:
-        for child in model.children():
-            modify_forward(child, compute_list, compute_fn)
+        for name, child in model.named_children():
+            modify_forward(
+                child,
+                complexity_computer,
+                prefix=f"{prefix}.{name}",
+                patch_attr=patch_attr,
+            )
 
     return model
 
 
-def restore_forward(model):
+def restore_forward(model, patch_attr=None):
     """
-    Restore original forward in model:
+    Restore original forward in model.
     """
-    if is_leaf(model) or hasattr(model, "flops"):
+    if is_leaf(model) or (patch_attr is not None and hasattr(model, patch_attr)):
         model.__class__ = model.orig_type
 
     else:
         for child in model.children():
-            restore_forward(child)
+            restore_forward(child, patch_attr=patch_attr)
 
     return model
 
 
-def compute_complexity(model, compute_fn, input_shape, input_key=None):
+def compute_complexity(
+    model,
+    compute_fn,
+    input_shape,
+    input_key=None,
+    patch_attr=None,
+    compute_unique=False,
+):
     """
     Compute the complexity of a forward pass.
+
+    Args:
+        compute_unique: If True, the complexity for a given module is only calculated
+            once. Otherwise, it is counted every time the module is called.
+
+    TODO(@mannatsingh): We have some assumptions about only modules which are leaves
+    or have patch_attr defined. This should be fixed and generalized if possible.
     """
     # assertions, input, and upvalue in which we will perform the count:
     assert isinstance(model, nn.Module)
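`modify_forward` now threads a dotted `prefix` through the recursion via `named_children`, so each patched leaf reports a stable name such as `.layer1.0.conv` for `_ComplexityComputer` to deduplicate on. A quick illustration with a throwaway model:

```python
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
for name, child in model.named_children():
    print(f".{name}", type(child).__name__)
# .0 Conv2d
# .1 ReLU
```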
@@ -404,50 +456,43 @@ def compute_complexity(model, compute_fn, input_shape, input_key=None):
     else:
         input = get_model_dummy_input(model, input_shape, input_key)
 
-    compute_list = []
+    complexity_computer = _ComplexityComputer(compute_fn, compute_unique)
 
     # measure FLOPs:
-    modify_forward(model, compute_list, compute_fn)
+    modify_forward(model, complexity_computer, patch_attr=patch_attr)
     try:
         # compute complexity in eval mode
         with eval_model(model), torch.no_grad():
             model.forward(input)
     except NotImplementedError as err:
         raise err
     finally:
-        restore_forward(model)
+        restore_forward(model, patch_attr=patch_attr)
 
-    return sum(compute_list)
+    return complexity_computer.count
 
 
 def compute_flops(model, input_shape=(3, 224, 224), input_key=None):
     """
     Compute the number of FLOPs needed for a forward pass.
     """
-    return compute_complexity(model, _layer_flops, input_shape, input_key)
+    return compute_complexity(
+        model, _layer_flops, input_shape, input_key, patch_attr="flops"
+    )
 
 
 def compute_activations(model, input_shape=(3, 224, 224), input_key=None):
     """
     Compute the number of activations created in a forward pass.
     """
-    return compute_complexity(model, _layer_activations, input_shape, input_key)
+    return compute_complexity(
+        model, _layer_activations, input_shape, input_key, patch_attr="activations"
+    )
 
 
 def count_params(model):
     """
     Count the number of parameters in a model.
     """
     assert isinstance(model, nn.Module)
-    count = 0
-    for child in model.children():
-        if is_leaf(child):
-            if hasattr(child, "_mask"):  # for masked modules (like LGC)
-                count += child._mask.long().sum().item()
-                # FIXME: BatchNorm parameters in LGC are not counted.
-            else:  # for regular modules
-                for p in child.parameters():
-                    count += p.nelement()
-        else:
-            count += count_params(child)
-    return count
+    return sum(parameter.nelement() for parameter in model.parameters())
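The public entry points are unchanged apart from the internally supplied `patch_attr`. A usage sketch (the import path is assumed; adjust it to wherever this profiler module lives):

```python
import torch.nn as nn

from profiler import compute_activations, compute_flops, count_params  # assumed path

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),  # int output_size is now handled by _layer_flops
)

flops = compute_flops(model, input_shape=(3, 224, 224))
activations = compute_activations(model, input_shape=(3, 224, 224))
params = count_params(model)
print(flops, activations, params)
```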