
[Bug] Endless loop on benchmark.py #3088

Open
@AinaraC

Description


Prerequisite

Task

I'm using the official example scripts/configs for the officially supported tasks/models/datasets.

Branch

main branch https://github.com/open-mmlab/mmdetection3d

Environment

Not necessary.

Reproduces the problem - code sample

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import time

import torch
from mmengine import Config
from mmengine.device import get_device
from mmengine.registry import init_default_scope
from mmengine.runner import Runner, autocast, load_checkpoint

from mmdet3d.registry import MODELS
from tools.misc.fuse_conv_bn import fuse_module


def parse_args():
    parser = argparse.ArgumentParser(description='MMDet benchmark a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--samples', default=2000, help='samples to benchmark')
    parser.add_argument(
        '--log-interval', default=50, help='interval of logging')
    parser.add_argument(
        '--amp',
        action='store_true',
        help='Whether to use automatic mixed precision inference')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    init_default_scope('mmdet3d')

    # build config and set cudnn_benchmark
    cfg = Config.fromfile(args.config)

    if cfg.env_cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # build dataloader
    dataloader = Runner.build_dataloader(cfg.test_dataloader)

    # build model and load checkpoint
    model = MODELS.build(cfg.model)
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_module(model)
    model.to(get_device())
    model.eval()

    # the first several iterations may be very slow so skip them
    num_warmup = 5
    pure_inf_time = 0

    # benchmark with several samples and take the average
    for i, data in enumerate(dataloader):
        torch.cuda.synchronize()
        start_time = time.perf_counter()

        with autocast(enabled=args.amp):
            model.test_step(data)

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time
        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % args.log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(f'Done sample [{i + 1:<3}/ {args.samples}], '
                      f'fps: {fps:.1f} sample / s')
  
        if (i + 1) == args.samples:
            pure_inf_time += elapsed
            fps = (i + 1 - num_warmup) / pure_inf_time
            print(f'Overall fps: {fps:.1f} sample / s')
            break


if __name__ == '__main__':
    main()

Reproduces the problem - command or script

python3 tools/analysis_tools/benchmark.py configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py work_dirs/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class/epoch_40.pth --samples 200

Reproduces the problem - error message

Done sample [50 / 200], fps: 2.7 sample / s
Done sample [100/ 200], fps: 2.7 sample / s
Done sample [150/ 200], fps: 2.7 sample / s
Done sample [200/ 200], fps: 2.7 sample / s
Done sample [250/ 200], fps: 2.7 sample / s
Done sample [300/ 200], fps: 2.7 sample / s
Done sample [350/ 200], fps: 2.7 sample / s
Done sample [400/ 200], fps: 2.7 sample / s
Done sample [450/ 200], fps: 2.7 sample / s

Additional information

When I pass --samples on the command line, argparse stores the value as a string, so the condition (i + 1) == args.samples is never true and the loop runs past the requested number of samples. When I pass --log-interval, I get a TypeError because (i + 1) % args.log_interval can't take a string operand.
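
For illustration, a minimal standalone sketch of the behaviour (not part of the benchmark script): without type=, argparse stores command-line values as strings, while defaults keep whatever type they were given.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--samples', default=2000)  # no type=, as in benchmark.py

args = parser.parse_args(['--samples', '200'])
print(repr(args.samples))   # '200' -- a str, not an int
print(args.samples == 200)  # False: '200' == 200 compares str to int

# argparse only converts values that come from the command line;
# the default stays an int, so running without --samples hides the bug
print(repr(parser.parse_args([]).samples))  # 2000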

Solution:
Declare both arguments with type=int:

parser.add_argument('--samples', type=int, default=2000, help='samples to benchmark')
parser.add_argument(
        '--log-interval', type=int, default=50, help='interval of logging')
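
With type=int, argparse converts the command-line string (e.g. '200') to an int before storing it, so both the modulo in the log-interval check and the equality test against args.samples behave as intended; runs that rely on the defaults are unaffected because the defaults were already ints.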

Fix the error, please.
