This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit e5b9873

mannatsingh authored and facebook-github-bot committed
Profiler bug-fixes and improvements (#482)
Summary: Pull Request resolved: #482

Made the following changes to the profiler code:
- The parameter calculation skipped parameters defined in modules and had unnecessary complexity - updated the logic
- `AdaptiveAvgPool2d` handles a single number `output_size` as well
- Added FLOPs for the `Identity` module
- Added support to specify an `activations` function to fetch activations from (similar to the `flops` function)
- Replaced the hacky list append logic with a class, `_ComplexityComputer`
- Implemented test cases which verify that the fixes work

Reviewed By: vreis

Differential Revision: D21009734

fbshipit-source-id: 926d93164c13c6c98eb88f9131d295b61d6acda7
1 parent 5643849 commit e5b9873
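
To illustrate the new `activations` hook mentioned in the summary (the counterpart of the existing `flops` hook), here is a minimal sketch. The `ScaledAdd` module and the input shape are made up for illustration; the hook signatures and the `compute_flops` / `compute_activations` entry points follow the diff below.

import torch.nn as nn

from classy_vision.generic.profiler import compute_activations, compute_flops


class ScaledAdd(nn.Module):
    # hypothetical module: the profiler cannot infer its cost, so it defines
    # flops() and activations() hooks that the profiler now picks up
    def forward(self, x):
        return 2 * x + 1

    def flops(self, x):
        # one multiply and one add per element
        return 2 * x.numel()

    def activations(self, x, out):
        # count every element of the output as an activation
        return out.numel()


model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), ScaledAdd())
flops = compute_flops(model, input_shape=(3, 32, 32))
activations = compute_activations(model, input_shape=(3, 32, 32))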

File tree (2 files changed: +153 -35 lines)

classy_vision/generic/profiler.py
test/generic_profiler_test.py


classy_vision/generic/profiler.py (+79 -34)

@@ -7,6 +7,7 @@
 import collections.abc as abc
 import logging
 import operator
+from typing import Callable

 import torch
 import torch.nn as nn
@@ -183,8 +184,12 @@ def flops(self, x):
     elif layer_type in ["AdaptiveAvgPool2d"]:
         in_h = x.size()[2]
         in_w = x.size()[3]
-        out_h = layer.output_size[0]
-        out_w = layer.output_size[1]
+        if isinstance(layer.output_size, int):
+            out_h, out_w = layer.output_size, layer.output_size
+        elif len(layer.output_size) == 1:
+            out_h, out_w = layer.output_size[0], layer.output_size[0]
+        else:
+            out_h, out_w = layer.output_size
         if out_h > in_h or out_w > in_w:
             raise NotImplementedError()
         batchsize_per_replica = x.size()[0]
@@ -295,6 +300,10 @@ def flops(self, x):
         for dim_size in x.size():
            flops *= dim_size
         return flops
+
+    elif layer_type == "Identity":
+        return 0
+
     elif hasattr(layer, "flops"):
         # If the module already defines a method to compute flops with the signature
         # below, we use it to compute flops
@@ -312,8 +321,16 @@ def _layer_activations(layer, x, out):
     """
     Computes the number of activations produced by a single layer.

-    Activations are counted only for convolutional layers.
+    Activations are counted only for convolutional layers. To override this behavior, a
+    layer can define a method to compute activations with the signature below, which
+    will be used to compute the activations instead.
+
+    Class MyModule(nn.Module):
+        def activations(self, x, out):
+            ...
     """
+    if hasattr(layer, "activations"):
+        return layer.activations(x, out)
     return out.numel() if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Conv3d)) else 0


@@ -338,11 +355,25 @@ def summarize_profiler_info(prof):
     return str


-def _patched_computation_module(module, compute_list, compute_fn):
+class _ComplexityComputer:
+    def __init__(self, compute_fn: Callable, count_unique: bool):
+        self.compute_fn = compute_fn
+        self.count_unique = count_unique
+        self.count = 0
+        self.seen_modules = set()
+
+    def compute(self, layer, x, out, module_name):
+        if self.count_unique and module_name in self.seen_modules:
+            return
+        self.count += self.compute_fn(layer, x, out)
+        self.seen_modules.add(module_name)
+
+
+def _patched_computation_module(module, complexity_computer, module_name):
     """
     Patch the module to compute a module's parameters, like FLOPs.

-    Calls compute_fn and appends the results to compute_list.
+    Calls compute_fn and passes the results to the complexity computer.
     """
     ty = type(module)
     typestring = module.__repr__()
@@ -355,7 +386,7 @@ def _original_forward(self, *args, **kwargs):

         def forward(self, *args, **kwargs):
             out = self._original_forward(*args, **kwargs)
-            compute_list.append(compute_fn(self, args[0], out))
+            complexity_computer.compute(self, args[0], out, module_name)
             return out

         def __repr__(self):
@@ -364,37 +395,58 @@ def __repr__(self):
     return ComputeModule


-def modify_forward(model, compute_list, compute_fn):
+def modify_forward(model, complexity_computer, prefix="", patch_attr=None):
     """
     Modify forward pass to measure a module's parameters, like FLOPs.
     """
-    if is_leaf(model) or hasattr(model, "flops"):
-        model.__class__ = _patched_computation_module(model, compute_list, compute_fn)
+    if is_leaf(model) or (patch_attr is not None and hasattr(model, patch_attr)):
+        model.__class__ = _patched_computation_module(
+            model, complexity_computer, prefix
+        )

     else:
-        for child in model.children():
-            modify_forward(child, compute_list, compute_fn)
+        for name, child in model.named_children():
+            modify_forward(
+                child,
+                complexity_computer,
+                prefix=f"{prefix}.{name}",
+                patch_attr=patch_attr,
+            )

     return model


-def restore_forward(model):
+def restore_forward(model, patch_attr=None):
     """
-    Restore original forward in model:
+    Restore original forward in model.
     """
-    if is_leaf(model) or hasattr(model, "flops"):
+    if is_leaf(model) or (patch_attr is not None and hasattr(model, patch_attr)):
         model.__class__ = model.orig_type

     else:
         for child in model.children():
-            restore_forward(child)
+            restore_forward(child, patch_attr=patch_attr)

     return model


-def compute_complexity(model, compute_fn, input_shape, input_key=None):
+def compute_complexity(
+    model,
+    compute_fn,
+    input_shape,
+    input_key=None,
+    patch_attr=None,
+    compute_unique=False,
+):
     """
     Compute the complexity of a forward pass.
+
+    Args:
+        compute_unique: If True, the compexity for a given module is only calculated
+            once. Otherwise, it is counted every time the module is called.
+
+    TODO(@mannatsingh): We have some assumptions about only modules which are leaves
+    or have patch_attr defined. This should be fixed and generalized if possible.
     """
     # assertions, input, and upvalue in which we will perform the count:
     assert isinstance(model, nn.Module)
@@ -404,50 +456,43 @@ def compute_complexity(model, compute_fn, input_shape, input_key=None):
     else:
         input = get_model_dummy_input(model, input_shape, input_key)

-    compute_list = []
+    complexity_computer = _ComplexityComputer(compute_fn, compute_unique)

     # measure FLOPs:
-    modify_forward(model, compute_list, compute_fn)
+    modify_forward(model, complexity_computer, patch_attr=patch_attr)
     try:
         # compute complexity in eval mode
         with eval_model(model), torch.no_grad():
             model.forward(input)
     except NotImplementedError as err:
         raise err
     finally:
-        restore_forward(model)
+        restore_forward(model, patch_attr=patch_attr)

-    return sum(compute_list)
+    return complexity_computer.count


 def compute_flops(model, input_shape=(3, 224, 224), input_key=None):
     """
     Compute the number of FLOPs needed for a forward pass.
     """
-    return compute_complexity(model, _layer_flops, input_shape, input_key)
+    return compute_complexity(
+        model, _layer_flops, input_shape, input_key, patch_attr="flops"
+    )


 def compute_activations(model, input_shape=(3, 224, 224), input_key=None):
     """
     Compute the number of activations created in a forward pass.
     """
-    return compute_complexity(model, _layer_activations, input_shape, input_key)
+    return compute_complexity(
+        model, _layer_activations, input_shape, input_key, patch_attr="activations"
+    )


 def count_params(model):
     """
     Count the number of parameters in a model.
     """
     assert isinstance(model, nn.Module)
-    count = 0
-    for child in model.children():
-        if is_leaf(child):
-            if hasattr(child, "_mask"):  # for masked modules (like LGC)
-                count += child._mask.long().sum().item()
-                # FIXME: BatchNorm parameters in LGC are not counted.
-            else:  # for regular modules
-                for p in child.parameters():
-                    count += p.nelement()
-        else:
-            count += count_params(child)
-    return count
+    return sum((parameter.nelement() for parameter in model.parameters()))
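
The new `compute_unique` flag is easiest to see with a module that is called more than once in a single forward pass: by default every call is counted, while `compute_unique=True` makes `_ComplexityComputer` count each named module only the first time it is seen. A small sketch, assuming the `compute_complexity` signature introduced above; `SharedBlockModel` is hypothetical:

import torch.nn as nn

from classy_vision.generic.profiler import _layer_flops, compute_complexity


class SharedBlockModel(nn.Module):
    # hypothetical model that runs the same block twice (weight sharing)
    def __init__(self):
        super().__init__()
        self.block = nn.Linear(16, 16, bias=False)

    def forward(self, x):
        return self.block(self.block(x))


model = SharedBlockModel()

# default behaviour: the shared block contributes FLOPs for both calls
flops_per_call = compute_complexity(
    model, _layer_flops, input_shape=(16,), patch_attr="flops"
)

# compute_unique=True: the block is only counted the first time it is seen
flops_unique = compute_complexity(
    model, _layer_flops, input_shape=(16,), patch_attr="flops", compute_unique=True
)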

test/generic_profiler_test.py (+74 -1)

@@ -7,6 +7,8 @@
 import unittest
 from test.generic.config_utils import get_test_model_configs

+import torch
+import torch.nn as nn
 from classy_vision.generic.profiler import (
     compute_activations,
     compute_flops,
@@ -15,8 +17,61 @@
 from classy_vision.models import build_model


+class TestModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # add parameters to the module to affect the parameter count
+        self.linear = nn.Linear(2, 3, bias=False)
+
+    def forward(self, x):
+        return x + 1
+
+    def flops(self, x):
+        # TODO: this should raise an exception if this function is not defined
+        # since the FLOPs are indeterminable
+
+        # need to define flops since this is an unknown class
+        return x.numel()
+
+
+class TestConvModule(nn.Conv2d):
+    def __init__(self):
+        super().__init__(2, 3, (4, 4), bias=False)
+        # add another (unused) layer for added complexity and to test parameters
+        self.linear = nn.Linear(4, 5, bias=False)
+
+    def forward(self, x):
+        return x
+
+    def activations(self, x, out):
+        # TODO: this should ideally work without this function being defined
+        return out.numel()
+
+    def flops(self, x):
+        # need to define flops since this is an unknown class
+        return 0
+
+
+class TestModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(300, 300, bias=False)
+        self.mod = TestModule()
+        self.conv = TestConvModule()
+        # we should be able to pick up user defined parameters as well
+        self.extra_params = nn.Parameter(torch.randn(10, 10))
+        # we shouldn't count flops for an unused layer
+        self.unused_linear = nn.Linear(2, 2, bias=False)
+
+    def forward(self, x):
+        out = self.conv(x)
+        out = out.view(out.shape[0], -1)
+        out = self.mod(out)
+        return self.linear(out)
+
+
 class TestProfilerFunctions(unittest.TestCase):
-    def test_complexity_calculation(self) -> None:
+    def test_complexity_calculation_resnext(self) -> None:
         model_configs = get_test_model_configs()
         # make sure there are three configs returned
         self.assertEqual(len(model_configs), 3)
@@ -34,3 +89,21 @@ def test_complexity_calculation(self) -> None:
         self.assertEqual(compute_activations(model) // 10 ** 6, m_activations)
         self.assertEqual(compute_flops(model) // 10 ** 6, m_flops)
         self.assertEqual(count_params(model) // 10 ** 6, m_params)
+
+    def test_complexity_calculation(self) -> None:
+        model = TestModel()
+        input_shape = (3, 10, 10)
+        num_elems = 3 * 10 * 10
+        self.assertEqual(compute_activations(model, input_shape=input_shape), num_elems)
+        self.assertEqual(
+            compute_flops(model, input_shape=input_shape),
+            num_elems
+            + 0
+            + (300 * 300),  # TestModule + TestConvModule + TestModel.linear;
+            # TestModel.unused_linear is unused and shouldn't be counted
+        )
+        self.assertEqual(
+            count_params(model),
+            (2 * 3) + (2 * 3 * 4 * 4) + (4 * 5) + (300 * 300) + (10 * 10) + (2 * 2),
+        )  # TestModule.linear + TestConvModule + TestConvModule.linear +
+        # TestModel.linear + TestModel.extra_params + TestModel.unused_linear
