MVX-Net low performance pedestrian AOS #3080

Open
@AinaraC

Prerequisite

Task

I'm using the official example scripts/configs for the officially supported tasks/models/datasets.

Branch

main branch https://github.com/open-mmlab/mmdetection3d

Environment

sys.platform: linux
Python: 3.10.12 (main, Nov 6 2024, 20:22:13) [GCC 11.4.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0: Orin
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.6, V12.6.77
GCC: aarch64-linux-gnu-gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
PyTorch: 2.5.0
PyTorch compiling details: PyTorch built with:

  • GCC 11.4
  • C++ Version: 201703
  • OpenMP 201511 (a.k.a. OpenMP 4.5)
  • LAPACK is enabled (usually provided by MKL)
  • NNPACK is enabled
  • CPU capability usage: NO AVX
  • CUDA Runtime 12.6
  • NVCC architecture flags: -gencode;arch=compute_87,code=sm_87
  • CuDNN 90.4
  • Build settings: BLAS_INFO=open, BUILD_TYPE=Release, CUDA_VERSION=12.6, CUDNN_VERSION=9.4.0, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=1 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=open, TORCH_VERSION=2.5.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=1, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=0, USE_NCCL=1, USE_NNPACK=1, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,

TorchVision: 0.20.0
OpenCV: 4.10.0-dev
MMEngine: 0.10.5
MMDetection: 3.2.0
MMDetection3D: 1.4.0+962f093
spconv2.0: False

Reproduces the problem - code sample

# This schedule is mainly used by models with dynamic voxelization
# optimizer
lr = 0.003 # max learning rate
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=lr, weight_decay=0.001, betas=(0.95, 0.99)),
clip_grad=dict(max_norm=10, norm_type=2),
)

param_scheduler = [
dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=1000),
dict(
type='CosineAnnealingLR',
begin=0,
T_max=60,
end=60,
by_epoch=True,
eta_min=1e-5)
]

# training schedule for 1x

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=60, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# Default setting for scaling LR automatically
#   - `enable` means enable scaling LR automatically
#       or not by default.
#   - `base_batch_size` = (8 GPUs) x (2 samples per GPU).

auto_scale_lr = dict(enable=False, base_batch_size=2)
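
For context on the `auto_scale_lr` setting above, here is a minimal sketch of the linear scaling rule that this kind of option applies when it is enabled. The helper name `linearly_scaled_lr` is hypothetical, and the reference batch size of 16 is taken from the "(8 GPUs) x (2 samples per GPU)" comment. Whether `lr = 0.003` was tuned for that batch size is an assumption, not something the config confirms.

```python
def linearly_scaled_lr(base_lr, base_batch_size, num_gpus, samples_per_gpu):
    """Scale the learning rate in proportion to the real total batch size.

    Hypothetical helper illustrating the linear scaling rule; it is not
    part of MMEngine or MMDetection3D.
    """
    real_batch_size = num_gpus * samples_per_gpu
    return base_lr * real_batch_size / base_batch_size


# With the settings posted in this issue (enable=False, base_batch_size=2),
# 1 GPU x 2 samples leaves the learning rate unchanged.
print(linearly_scaled_lr(0.003, base_batch_size=2, num_gpus=1, samples_per_gpu=2))

# Assumption: if 0.003 was tuned for 8 GPUs x 2 samples = 16, the rule
# would suggest a smaller rate for a total batch size of 2.
print(linearly_scaled_lr(0.003, base_batch_size=16, num_gpus=1, samples_per_gpu=2))
```

Under that assumption the second call gives one eighth of the configured rate, which may be worth trying when training on a single GPU.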

_base_ = ['../_base_/schedules/cosine.py', '../_base_/default_runtime.py']

# model settings

voxel_size = [0.05, 0.05, 0.1]
point_cloud_range = [0, -40, -3, 70.4, 40, 1]

model = dict(
type='DynamicMVXFasterRCNN',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
voxel_type='dynamic',
voxel_layer=dict(
max_num_points=-1,
point_cloud_range=point_cloud_range,
voxel_size=voxel_size,
max_voxels=(-1, -1)),
mean=[102.9801, 115.9465, 122.7717],
std=[1.0, 1.0, 1.0],
bgr_to_rgb=False,
pad_size_divisor=32),
img_backbone=dict(
type='mmdet.ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='caffe'),
img_neck=dict(
type='mmdet.FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
# make the image features more stable numerically to avoid loss nan
norm_cfg=dict(type='BN', requires_grad=False),
num_outs=5),
pts_voxel_encoder=dict(
type='DynamicVFE',
in_channels=4,
feat_channels=[64, 64],
with_distance=False,
voxel_size=voxel_size,
with_cluster_center=True,
with_voxel_center=True,
point_cloud_range=point_cloud_range,
fusion_layer=dict(
type='PointFusion',
img_channels=256,
pts_channels=64,
mid_channels=128,
out_channels=128,
img_levels=[0, 1, 2, 3, 4],
align_corners=False,
activate_out=True,
fuse_out=False)),
pts_middle_encoder=dict(
type='SparseEncoder',
in_channels=128,
sparse_shape=[41, 1600, 1408],
order=('conv', 'norm', 'act')),
pts_backbone=dict(
type='SECOND',
in_channels=256,
layer_nums=[5, 5],
layer_strides=[1, 2],
out_channels=[128, 256]),
pts_neck=dict(
type='SECONDFPN',
in_channels=[128, 256],
upsample_strides=[1, 2],
out_channels=[256, 256]),
pts_bbox_head=dict(
type='Anchor3DHead',
num_classes=3,
in_channels=512,
feat_channels=512,
use_direction_classifier=True,
anchor_generator=dict(
type='Anchor3DRangeGenerator',
ranges=[
[0, -40.0, -0.6, 70.4, 40.0, -0.6],
[0, -40.0, -0.6, 70.4, 40.0, -0.6],
[0, -40.0, -1.78, 70.4, 40.0, -1.78],
],
sizes=[[0.8, 0.6, 1.73], [1.76, 0.6, 1.73], [3.9, 1.6, 1.56]],
rotations=[0, 1.57],
reshape_out=False),
assigner_per_size=True,
diff_rad_by_sin=True,
assign_per_class=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_cls=dict(
type='mmdet.FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(
type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_dir=dict(
type='mmdet.CrossEntropyLoss', use_sigmoid=False,
loss_weight=0.2)),
# model training and testing settings
train_cfg=dict(
type='EpochBasedTrainLoop',
pts=dict(
assigner=[
dict( # for Pedestrian
type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.35,
neg_iou_thr=0.2,
min_pos_iou=0.2,
ignore_iof_thr=-1),
dict( # for Cyclist
type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.35,
neg_iou_thr=0.2,
min_pos_iou=0.2,
ignore_iof_thr=-1),
dict( # for Car
type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.6,
neg_iou_thr=0.45,
min_pos_iou=0.45,
ignore_iof_thr=-1),
],
allowed_border=0,
pos_weight=-1,
debug=False)),
test_cfg=dict(
pts=dict(
use_rotate_nms=True,
nms_across_levels=False,
nms_thr=0.01,
score_thr=0.1,
min_bbox_size=0,
nms_pre=100,
max_num=50)))

# dataset settings

dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
class_names = ['Pedestrian', 'Cyclist', 'Car']
metainfo = dict(classes=class_names)
input_modality = dict(use_lidar=True, use_camera=True)
backend_args = None
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=4,
backend_args=backend_args),
dict(type='LoadImageFromFile', backend_args=backend_args),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(
type='RandomResize', scale=[(640, 192), (2560, 768)], keep_ratio=True),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05],
translation_std=[0.2, 0.2, 0.2]),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
keys=[
'points', 'img', 'gt_bboxes_3d', 'gt_labels_3d', 'gt_bboxes',
'gt_labels'
])
]
test_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=4,
backend_args=backend_args),
dict(type='LoadImageFromFile', backend_args=backend_args),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1280, 384),
pts_scale_ratio=1,
flip=False,
transforms=[
# Temporary solution, fix this after refactor the augtest
dict(type='Resize', scale=0, keep_ratio=True),
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range),
]),
dict(type='Pack3DDetInputs', keys=['points', 'img'])
]
modality = dict(use_lidar=True, use_camera=True)
train_dataloader = dict(
batch_size=2,
num_workers=2,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='RepeatDataset',
times=2,
dataset=dict(
type=dataset_type,
data_root=data_root,
modality=modality,
ann_file='kitti_infos_train.pkl',
data_prefix=dict(
pts='training/velodyne_reduced', img='training/image_2'),
pipeline=train_pipeline,
filter_empty_gt=False,
metainfo=metainfo,
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR',
backend_args=backend_args)))

val_dataloader = dict(
batch_size=1,
num_workers=1,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
modality=modality,
ann_file='kitti_infos_val.pkl',
data_prefix=dict(
pts='training/velodyne_reduced', img='training/image_2'),
pipeline=test_pipeline,
metainfo=metainfo,
test_mode=True,
box_type_3d='LiDAR',
backend_args=backend_args))
test_dataloader = dict(
batch_size=1,
num_workers=1,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='kitti_infos_val.pkl',
modality=modality,
data_prefix=dict(
pts='training/velodyne_reduced', img='training/image_2'),
pipeline=test_pipeline,
metainfo=metainfo,
test_mode=True,
box_type_3d='LiDAR',
backend_args=backend_args))

# optim_wrapper from cosine.py is overridden here

optim_wrapper = dict(
optimizer=dict(weight_decay=0.01),
clip_grad=dict(max_norm=35, norm_type=2),
)
val_evaluator = dict(
type='KittiMetric', ann_file='data/kitti/kitti_infos_val.pkl')
test_evaluator = val_evaluator

# You may need to download the model first if the network is unstable

load_from = 'https://download.openmmlab.com/mmdetection3d/pretrain_models/mvx_faster_rcnn_detectron2-caffe_20e_coco-pretrain_gt-sample_kitti-3-class_moderate-79.3_20200207-a4a6a3c7.pth' # noqa
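
As a quick sanity check on the config above, the sketch below recomputes `sparse_shape` from `point_cloud_range` and `voxel_size`. It assumes the usual SECOND-style convention in which the grid is ordered (z, y, x) and the z dimension gets one extra slot. The helper name `expected_sparse_shape` is hypothetical.

```python
# Values copied from the config above.
voxel_size = [0.05, 0.05, 0.1]
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
sparse_shape = [41, 1600, 1408]


def expected_sparse_shape(pc_range, voxel):
    """Return the voxel grid size as [z + 1, y, x].

    Assumption: the sparse encoder expects (z, y, x) ordering with one
    extra cell along z, as is common for SECOND-style encoders.
    """
    nx, ny, nz = (
        round((pc_range[i + 3] - pc_range[i]) / voxel[i]) for i in range(3)
    )
    return [nz + 1, ny, nx]


computed = expected_sparse_shape(point_cloud_range, voxel_size)
print(computed, computed == sparse_shape)
```

If the two lists differ after editing the range or the voxel size, `sparse_shape` likely needs to be updated to match.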

Reproduces the problem - command or script

python3 tools/train.py work_dirs/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py

Reproduces the problem - error message

EPOCH 40

Image

EPOCH 49

Image

Additional information

First, I trained MVX-Net for 40 epochs using the original config file (I only changed the batch size to 2 for 1 GPU). The Pedestrian AOS AP results at the 40th epoch were quite low, below 50%, which I think is very poor. I then resumed training for 10 more epochs to try to improve the results, but they did not get better (on the contrary, they got worse).
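
To put the two checkpoints in context, here is a rough sketch of the learning rate that a cosine annealing schedule would give at each epoch. It uses the closed-form cosine formula with the values from the posted schedule (`lr = 0.003`, `eta_min = 1e-5`, `T_max = 60`), ignores the 1000-iteration linear warmup, and assumes that schedule was in effect for the whole run, which this issue does not confirm. The function name `cosine_lr` is hypothetical.

```python
import math


def cosine_lr(epoch, base_lr=0.003, eta_min=1e-5, t_max=60):
    """Closed-form cosine annealing value at a given epoch.

    Sketch only: no warmup, one update per epoch, defaults copied from
    the schedule posted in this issue.
    """
    cosine_factor = (1 + math.cos(math.pi * epoch / t_max)) / 2
    return eta_min + (base_lr - eta_min) * cosine_factor


for epoch in (0, 40, 49, 60):
    print(f"epoch {epoch:2d}: lr = {cosine_lr(epoch):.6f}")
```

Under these assumptions the rate at epoch 40 is still about a quarter of the maximum and only reaches `eta_min` at epoch 60, so neither checkpoint would correspond to a fully annealed schedule.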

I'm using the KITTI dataset with 3769 training and 3712 validation samples. I got the ImageSets from this model: VirConv ImageSets

I would like to know if anyone has achieved better AOS results for pedestrians with MVX-Net and, if so, what changes I should make to the config file. I don't have much experience with model training, so I welcome any advice and support. Thank you!
