Skip to content

Commit 33d073b

Browse files
authored
[Feature]: Support calculating FLOPs of detectors (#9777)
1 parent ffc2bb3 commit 33d073b

File tree

4 files changed

+421
-42
lines changed

4 files changed

+421
-42
lines changed

Diff for: .dev_scripts/benchmark_valid_flops.py

+295
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
import logging
2+
import re
3+
import tempfile
4+
from argparse import ArgumentParser
5+
from collections import OrderedDict
6+
from functools import partial
7+
from pathlib import Path
8+
9+
import numpy as np
10+
import pandas as pd
11+
import torch
12+
from mmengine import Config, DictAction
13+
from mmengine.analysis import get_model_complexity_info
14+
from mmengine.analysis.print_helper import _format_size
15+
from mmengine.fileio import FileClient
16+
from mmengine.logging import MMLogger
17+
from mmengine.model import revert_sync_batchnorm
18+
from mmengine.runner import Runner
19+
from modelindex.load_model_index import load
20+
from rich.console import Console
21+
from rich.table import Table
22+
from rich.text import Text
23+
from tqdm import tqdm
24+
25+
from mmdet.registry import MODELS
26+
from mmdet.utils import register_all_modules
27+
28+
console = Console()
29+
MMDET_ROOT = Path(__file__).absolute().parents[1]
30+
31+
32+
def parse_args():
33+
parser = ArgumentParser(description='Valid all models in model-index.yml')
34+
parser.add_argument(
35+
'--shape',
36+
type=int,
37+
nargs='+',
38+
default=[1280, 800],
39+
help='input image size')
40+
parser.add_argument(
41+
'--checkpoint_root',
42+
help='Checkpoint file root path. If set, load checkpoint before test.')
43+
parser.add_argument('--img', default='demo/demo.jpg', help='Image file')
44+
parser.add_argument('--models', nargs='+', help='models name to inference')
45+
parser.add_argument(
46+
'--batch-size',
47+
type=int,
48+
default=1,
49+
help='The batch size during the inference.')
50+
parser.add_argument(
51+
'--flops', action='store_true', help='Get Flops and Params of models')
52+
parser.add_argument(
53+
'--flops-str',
54+
action='store_true',
55+
help='Output FLOPs and params counts in a string form.')
56+
parser.add_argument(
57+
'--cfg-options',
58+
nargs='+',
59+
action=DictAction,
60+
help='override some settings in the used config, the key-value pair '
61+
'in xxx=yyy format will be merged into config file. If the value to '
62+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
63+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
64+
'Note that the quotation marks are necessary and that no white space '
65+
'is allowed.')
66+
parser.add_argument(
67+
'--size_divisor',
68+
type=int,
69+
default=32,
70+
help='Pad the input image, the minimum size that is divisible '
71+
'by size_divisor, -1 means do not pad the image.')
72+
args = parser.parse_args()
73+
return args
74+
75+
76+
def inference(config_file, checkpoint, work_dir, args, exp_name):
77+
logger = MMLogger.get_instance(name='MMLogger')
78+
logger.warning('if you want test flops, please make sure torch>=1.12')
79+
cfg = Config.fromfile(config_file)
80+
cfg.work_dir = work_dir
81+
cfg.load_from = checkpoint
82+
cfg.log_level = 'WARN'
83+
cfg.experiment_name = exp_name
84+
if args.cfg_options is not None:
85+
cfg.merge_from_dict(args.cfg_options)
86+
87+
# forward the model
88+
result = {'model': config_file.stem}
89+
90+
if args.flops:
91+
92+
if len(args.shape) == 1:
93+
h = w = args.shape[0]
94+
elif len(args.shape) == 2:
95+
h, w = args.shape
96+
else:
97+
raise ValueError('invalid input shape')
98+
divisor = args.size_divisor
99+
if divisor > 0:
100+
h = int(np.ceil(h / divisor)) * divisor
101+
w = int(np.ceil(w / divisor)) * divisor
102+
103+
input_shape = (3, h, w)
104+
result['resolution'] = input_shape
105+
106+
try:
107+
cfg = Config.fromfile(config_file)
108+
if hasattr(cfg, 'head_norm_cfg'):
109+
cfg['head_norm_cfg'] = dict(type='SyncBN', requires_grad=True)
110+
cfg['model']['roi_head']['bbox_head']['norm_cfg'] = dict(
111+
type='SyncBN', requires_grad=True)
112+
cfg['model']['roi_head']['mask_head']['norm_cfg'] = dict(
113+
type='SyncBN', requires_grad=True)
114+
115+
if args.cfg_options is not None:
116+
cfg.merge_from_dict(args.cfg_options)
117+
118+
model = MODELS.build(cfg.model)
119+
input = torch.rand(1, *input_shape)
120+
if torch.cuda.is_available():
121+
model.cuda()
122+
input = input.cuda()
123+
model = revert_sync_batchnorm(model)
124+
inputs = (input, )
125+
model.eval()
126+
outputs = get_model_complexity_info(
127+
model, input_shape, inputs, show_table=False, show_arch=False)
128+
flops = outputs['flops']
129+
params = outputs['params']
130+
activations = outputs['activations']
131+
result['Get Types'] = 'direct'
132+
except: # noqa 772
133+
logger = MMLogger.get_instance(name='MMLogger')
134+
logger.warning(
135+
'Direct get flops failed, try to get flops with data')
136+
cfg = Config.fromfile(config_file)
137+
if hasattr(cfg, 'head_norm_cfg'):
138+
cfg['head_norm_cfg'] = dict(type='SyncBN', requires_grad=True)
139+
cfg['model']['roi_head']['bbox_head']['norm_cfg'] = dict(
140+
type='SyncBN', requires_grad=True)
141+
cfg['model']['roi_head']['mask_head']['norm_cfg'] = dict(
142+
type='SyncBN', requires_grad=True)
143+
data_loader = Runner.build_dataloader(cfg.val_dataloader)
144+
data_batch = next(iter(data_loader))
145+
model = MODELS.build(cfg.model)
146+
if torch.cuda.is_available():
147+
model = model.cuda()
148+
model = revert_sync_batchnorm(model)
149+
model.eval()
150+
_forward = model.forward
151+
data = model.data_preprocessor(data_batch)
152+
del data_loader
153+
model.forward = partial(
154+
_forward, data_samples=data['data_samples'])
155+
outputs = get_model_complexity_info(
156+
model,
157+
input_shape,
158+
data['inputs'],
159+
show_table=False,
160+
show_arch=False)
161+
flops = outputs['flops']
162+
params = outputs['params']
163+
activations = outputs['activations']
164+
result['Get Types'] = 'dataloader'
165+
166+
if args.flops_str:
167+
flops = _format_size(flops)
168+
params = _format_size(params)
169+
activations = _format_size(activations)
170+
171+
result['flops'] = flops
172+
result['params'] = params
173+
174+
return result
175+
176+
177+
def show_summary(summary_data, args):
178+
table = Table(title='Validation Benchmark Regression Summary')
179+
table.add_column('Model')
180+
table.add_column('Validation')
181+
table.add_column('Resolution (c, h, w)')
182+
if args.flops:
183+
table.add_column('Flops', justify='right', width=11)
184+
table.add_column('Params', justify='right')
185+
186+
for model_name, summary in summary_data.items():
187+
row = [model_name]
188+
valid = summary['valid']
189+
color = 'green' if valid == 'PASS' else 'red'
190+
row.append(f'[{color}]{valid}[/{color}]')
191+
if valid == 'PASS':
192+
row.append(str(summary['resolution']))
193+
if args.flops:
194+
row.append(str(summary['flops']))
195+
row.append(str(summary['params']))
196+
table.add_row(*row)
197+
198+
console.print(table)
199+
table_data = {
200+
x.header: [Text.from_markup(y).plain for y in x.cells]
201+
for x in table.columns
202+
}
203+
table_pd = pd.DataFrame(table_data)
204+
table_pd.to_csv('./mmdetection_flops.csv')
205+
206+
207+
# Sample test whether the inference code is correct
208+
def main(args):
209+
register_all_modules()
210+
model_index_file = MMDET_ROOT / 'model-index.yml'
211+
model_index = load(str(model_index_file))
212+
model_index.build_models_with_collections()
213+
models = OrderedDict({model.name: model for model in model_index.models})
214+
215+
logger = MMLogger(
216+
'validation',
217+
logger_name='validation',
218+
log_file='benchmark_test_image.log',
219+
log_level=logging.INFO)
220+
221+
if args.models:
222+
patterns = [
223+
re.compile(pattern.replace('+', '_')) for pattern in args.models
224+
]
225+
filter_models = {}
226+
for k, v in models.items():
227+
k = k.replace('+', '_')
228+
if any([re.match(pattern, k) for pattern in patterns]):
229+
filter_models[k] = v
230+
if len(filter_models) == 0:
231+
print('No model found, please specify models in:')
232+
print('\n'.join(models.keys()))
233+
return
234+
models = filter_models
235+
236+
summary_data = {}
237+
tmpdir = tempfile.TemporaryDirectory()
238+
for model_name, model_info in tqdm(models.items()):
239+
240+
if model_info.config is None:
241+
continue
242+
243+
model_info.config = model_info.config.replace('%2B', '+')
244+
config = Path(model_info.config)
245+
246+
try:
247+
config.exists()
248+
except: # noqa 722
249+
logger.error(f'{model_name}: {config} not found.')
250+
continue
251+
252+
logger.info(f'Processing: {model_name}')
253+
254+
http_prefix = 'https://download.openmmlab.com/mmdetection/'
255+
if args.checkpoint_root is not None:
256+
root = args.checkpoint_root
257+
if 's3://' in args.checkpoint_root:
258+
from petrel_client.common.exception import AccessDeniedError
259+
file_client = FileClient.infer_client(uri=root)
260+
checkpoint = file_client.join_path(
261+
root, model_info.weights[len(http_prefix):])
262+
try:
263+
exists = file_client.exists(checkpoint)
264+
except AccessDeniedError:
265+
exists = False
266+
else:
267+
checkpoint = Path(root) / model_info.weights[len(http_prefix):]
268+
exists = checkpoint.exists()
269+
if exists:
270+
checkpoint = str(checkpoint)
271+
else:
272+
print(f'WARNING: {model_name}: {checkpoint} not found.')
273+
checkpoint = None
274+
else:
275+
checkpoint = None
276+
277+
try:
278+
# build the model from a config file and a checkpoint file
279+
result = inference(MMDET_ROOT / config, checkpoint, tmpdir.name,
280+
args, model_name)
281+
result['valid'] = 'PASS'
282+
except Exception: # noqa 722
283+
import traceback
284+
logger.error(f'"{config}" :\n{traceback.format_exc()}')
285+
result = {'valid': 'FAIL'}
286+
287+
summary_data[model_name] = result
288+
289+
tmpdir.cleanup()
290+
show_summary(summary_data, args)
291+
292+
293+
if __name__ == '__main__':
294+
args = parse_args()
295+
main(args)

Diff for: configs/simple_copy_paste/metafile.yml

+4-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ Collections:
2525
Models:
2626
- Name: mask-rcnn_r50_fpn_syncbn-all_rpn-2conv_ssj_32x2_270k_coco
2727
In Collection: SimpleCopyPaste
28-
Config: configs/simplecopypaste/mask-rcnn_r50_fpn_rpn-2conv_4conv1fc_syncbn-all_32xb2-ssj-270k_coco.py
28+
Config: configs/simple_copy_paste/mask-rcnn_r50_fpn_rpn-2conv_4conv1fc_syncbn-all_32xb2-ssj-270k_coco.py
2929
Metadata:
3030
Training Memory (GB): 7.2
3131
Iterations: 270000
@@ -42,7 +42,7 @@ Models:
4242

4343
- Name: mask-rcnn_r50_fpn_syncbn-all_rpn-2conv_ssj_32x2_90k_coco
4444
In Collection: SimpleCopyPaste
45-
Config: configs/simplecopypaste/mask-rcnn_r50_fpn_rpn-2conv_4conv1fc_syncbn-all_32xb2-ssj-90k_coco.py
45+
Config: configs/simple_copy_paste/mask-rcnn_r50_fpn_rpn-2conv_4conv1fc_syncbn-all_32xb2-ssj-90k_coco.py
4646
Metadata:
4747
Training Memory (GB): 7.2
4848
Iterations: 90000
@@ -59,7 +59,7 @@ Models:
5959

6060
- Name: mask-rcnn_r50_fpn_syncbn-all_rpn-2conv_ssj_scp_32x2_270k_coco
6161
In Collection: SimpleCopyPaste
62-
Config: configs/simplecopypaste/mask-rcnn_r50_fpn_rpn-2conv_4conv1fc_syncbn-all_32xb2-ssj-scp-270k_coco.py
62+
Config: configs/simple_copy_paste/mask-rcnn_r50_fpn_rpn-2conv_4conv1fc_syncbn-all_32xb2-ssj-scp-270k_coco.py
6363
Metadata:
6464
Training Memory (GB): 7.2
6565
Iterations: 270000
@@ -76,7 +76,7 @@ Models:
7676

7777
- Name: mask-rcnn_r50_fpn_syncbn-all_rpn-2conv_ssj_scp_32x2_90k_coco
7878
In Collection: SimpleCopyPaste
79-
Config: configs/simplecopypaste/mask-rcnn_r50_fpn_rpn-2conv_4conv1fc_syncbn-all_32xb2-ssj-scp-90k_coco.py
79+
Config: configs/simple_copy_paste/mask-rcnn_r50_fpn_rpn-2conv_4conv1fc_syncbn-all_32xb2-ssj-scp-90k_coco.py
8080
Metadata:
8181
Training Memory (GB): 7.2
8282
Iterations: 90000

Diff for: mmdet/models/roi_heads/sparse_roi_head.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,11 @@ def forward(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
578578
batch_img_metas=batch_img_metas,
579579
batch_gt_instances=batch_gt_instances)
580580
bbox_results.pop('loss_bbox')
581-
all_stage_bbox_results.append((bbox_results, ))
581+
# torch.jit does not support obj:SamplingResult
582+
bbox_results.pop('results_list')
583+
bbox_res = bbox_results.copy()
584+
bbox_res.pop('sampling_results')
585+
all_stage_bbox_results.append((bbox_res, ))
582586

583587
if self.with_mask:
584588
attn_feats = bbox_results['attn_feats']

0 commit comments

Comments
 (0)