Description
Prerequisite
- I have searched Issues and Discussions but cannot get the expected help.
- I have read the FAQ documentation but cannot get the expected help.
- The bug has not been fixed in the latest version (dev-1.x) or the latest release (dev-1.0).
Task
I have modified the scripts/configs, or I'm working on my own tasks/models/datasets.
Branch
main branch https://github.com/open-mmlab/mmdetection3d
Environment
sys.platform: linux
Python: 3.8.20 (default, Oct 3 2024, 15:24:27) [GCC 11.2.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0: NVIDIA GeForce RTX 4070
CUDA_HOME: /usr
NVCC: Cuda compilation tools, release 12.1, V12.1.66
GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
PyTorch: 2.1.0+cu121
PyTorch compiling details: PyTorch built with:
- GCC 9.3
- C++ Version: 201703
- Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: AVX2
- CUDA Runtime 12.1
- NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
- CuDNN 8.9.2
- Magma 2.6.1
- Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-invalid-partial-specialization -Wno-unused-private-field -Wno-aligned-allocation-unavailable -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,
TorchVision: 0.16.0+cu121
OpenCV: 4.11.0
MMEngine: 0.10.6
MMDetection: 3.3.0
MMDetection3D: 1.4.0+
spconv2.0: False
Configs
```python
_base_ = [
    '../../../configs/_base_/schedules/cosine.py',
    '../../../configs/_base_/default_runtime.py'
]
custom_imports = dict(
    imports=['projects.fusion'], allow_failed_imports=False)

voxel_size = [0.05, 0.05, 0.1]
point_cloud_range = [0, -40, -3, 70.4, 40, 1]

model = dict(
    type='BAFusion',
    data_preprocessor=dict(
        type='Det3DDataPreprocessor',
        voxel=True,
        voxel_type='hard',
        voxel_layer=dict(
            max_num_points=10,
            point_cloud_range=point_cloud_range,
            voxel_size=voxel_size,
            max_voxels=[12000, 16000],
        ),
        mean=[102.9801, 115.9465, 122.7717],
        std=[1.0, 1.0, 1.0],
        bgr_to_rgb=False,
        pad_size_divisor=32),
    img_backbone=dict(
        type='mmdet.ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='caffe'),
    img_neck=dict(
        type='mmdet.FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        norm_cfg=dict(type='BN', requires_grad=False),
        num_outs=5),
    pts_voxel_encoder=dict(
        type='HardSimpleVFE',
        num_features=5),
    pts_middle_encoder=dict(
        type='SparseEncoder',
        in_channels=128,
        sparse_shape=[41, 1600, 1408],
        order=('conv', 'norm', 'act')),
    pts_backbone=dict(
        type='SECOND',
        in_channels=256,
        layer_nums=[5, 5],
        layer_strides=[1, 2],
        out_channels=[128, 256]),
    pts_neck=dict(
        type='SECONDFPN',
        in_channels=[128, 256],
        upsample_strides=[1, 2],
        out_channels=[256, 256]),
    view_transform=dict(
        type='DepthLSSTransform',
        in_channels=256,
        out_channels=80,
        image_size=[256, 704],
        feature_size=[32, 88],
        xbound=[0.0, 70.4, 0.3],
        ybound=[-40.0, 40.0, 0.3],
        zbound=[-10.0, 10.0, 20.0],
        dbound=[1.0, 60.0, 0.5],
        downsample=2),
    fusion_layer=dict(
        type='ConvFuser', in_channels=512, out_channels=256),
    pts_bbox_head=dict(
        type='Anchor3DHead',
        num_classes=3,
        in_channels=512,
        feat_channels=512,
        use_direction_classifier=True,
        anchor_generator=dict(
            type='Anchor3DRangeGenerator',
            ranges=[
                [0, -40.0, -0.6, 70.4, 40.0, -0.6],
                [0, -40.0, -0.6, 70.4, 40.0, -0.6],
                [0, -40.0, -1.78, 70.4, 40.0, -1.78],
            ],
            sizes=[[0.8, 0.6, 1.73],
                   [1.76, 0.6, 1.73],
                   [3.9, 1.6, 1.56]],
            rotations=[0, 1.57],
            reshape_out=False,
            scales=[1, 1]),
        assigner_per_size=True,
        diff_rad_by_sin=True,
        assign_per_class=True,
        bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
        loss_cls=dict(
            type='mmdet.FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(
            type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
        loss_dir=dict(
            type='mmdet.CrossEntropyLoss', use_sigmoid=False,
            loss_weight=0.2)),
    train_cfg=dict(
        pts=dict(
            assigner=[
                dict(
                    type='Max3DIoUAssigner',
                    iou_calculator=dict(type='BboxOverlapsNearest3D'),
                    pos_iou_thr=0.35,
                    neg_iou_thr=0.2,
                    min_pos_iou=0.2,
                    ignore_iof_thr=-1),
                dict(
                    type='Max3DIoUAssigner',
                    iou_calculator=dict(type='BboxOverlapsNearest3D'),
                    pos_iou_thr=0.35,
                    neg_iou_thr=0.2,
                    min_pos_iou=0.2,
                    ignore_iof_thr=-1),
                dict(
                    type='Max3DIoUAssigner',
                    iou_calculator=dict(type='BboxOverlapsNearest3D'),
                    pos_iou_thr=0.6,
                    neg_iou_thr=0.45,
                    min_pos_iou=0.45,
                    ignore_iof_thr=-1),
            ],
            allowed_border=0,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        pts=dict(
            use_rotate_nms=True,
            nms_across_levels=False,
            nms_thr=0.01,
            score_thr=0.1,
            min_bbox_size=0,
            nms_pre=100,
            max_num=50)))

dataset_type = 'KittiDataset'
data_root = '/home/knight/learn/data'
class_names = ['Pedestrian', 'Cyclist', 'Car']
metainfo = dict(classes=class_names)
input_modality = dict(use_lidar=True, use_camera=True)
backend_args = None

train_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=4,
        use_dim=4,
        backend_args=backend_args),
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
    dict(
        type='RandomResize', scale=[(640, 192), (2560, 768)],
        keep_ratio=True),
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.78539816, 0.78539816],
        scale_ratio_range=[0.95, 1.05],
        translation_std=[0.2, 0.2, 0.2]),
    dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='PointShuffle'),
    dict(
        type='Pack3DDetInputs',
        keys=[
            'points', 'img', 'gt_bboxes_3d', 'gt_labels_3d', 'gt_bboxes',
            'gt_labels'
        ])
]
test_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=4,
        use_dim=4,
        backend_args=backend_args),
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1280, 384),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(type='Resize', scale=0, keep_ratio=True),
            dict(
                type='GlobalRotScaleTrans',
                rot_range=[0, 0],
                scale_ratio_range=[1., 1.],
                translation_std=[0, 0, 0]),
            dict(type='RandomFlip3D'),
            dict(
                type='PointsRangeFilter',
                point_cloud_range=point_cloud_range),
        ]),
    dict(type='Pack3DDetInputs', keys=['points', 'img'])
]
modality = dict(use_lidar=True, use_camera=True)

train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='RepeatDataset',
        times=2,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            modality=modality,
            ann_file='kitti_infos_train.pkl',
            data_prefix=dict(
                pts='training/velodyne', img='training/image_2'),
            pipeline=train_pipeline,
            filter_empty_gt=False,
            metainfo=metainfo,
            box_type_3d='LiDAR',
            backend_args=backend_args)))
val_dataloader = dict(
    batch_size=1,
    num_workers=1,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        modality=modality,
        ann_file='kitti_infos_val.pkl',
        data_prefix=dict(
            pts='training/velodyne', img='training/image_2'),
        pipeline=test_pipeline,
        metainfo=metainfo,
        test_mode=True,
        box_type_3d='LiDAR',
        backend_args=backend_args))
test_dataloader = dict(
    batch_size=1,
    num_workers=1,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='kitti_infos_val.pkl',
        modality=modality,
        data_prefix=dict(
            pts='training/velodyne', img='training/image_2'),
        pipeline=test_pipeline,
        metainfo=metainfo,
        test_mode=True,
        box_type_3d='LiDAR',
        backend_args=backend_args))

optim_wrapper = dict(
    optimizer=dict(weight_decay=0.01),
    clip_grad=dict(max_norm=35, norm_type=2),
)
val_evaluator = dict(
    type='KittiMetric', ann_file='data/kitti/kitti_infos_val.pkl')
test_evaluator = val_evaluator

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

load_from = 'https://download.openmmlab.com/mmdetection3d/pretrain_models/mvx_faster_rcnn_detectron2-caffe_20e_coco-pretrain_gt-sample_kitti-3-class_moderate-79.3_20200207-a4a6a3c7.pth'  # noqa
```
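For reference, the config itself parses (training reaches the loss computation before failing). A minimal sanity-check sketch, using the config path from the training command below and printing only the fields involved in the error:

```python
from mmengine.config import Config

# Parse the config; this resolves the _base_ inheritance chain.
cfg = Config.fromfile(
    '/home/knight/mmdetection3d-main/projects/BAFusion/config/'
    'BAFusion_lidar-cam_kitti_3d.py')

# Inspect the head that later raises the permute error.
print(cfg.model.pts_bbox_head.type)              # 'Anchor3DHead'
print(cfg.model.pts_bbox_head.anchor_generator)  # ranges, sizes, scales
```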
Reproduces the problem - code sample
```python
# mmdet3d/models/dense_heads/anchor3d_head.py
def _loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
                         dir_cls_pred: Tensor, labels: Tensor,
                         label_weights: Tensor, bbox_targets: Tensor,
                         bbox_weights: Tensor, dir_targets: Tensor,
                         dir_weights: Tensor, num_total_samples: int):
    """Calculate loss of Single-level results.

    Args:
        cls_score (Tensor): Class score in single-level.
        bbox_pred (Tensor): Bbox prediction in single-level.
        dir_cls_pred (Tensor): Predictions of direction class
            in single-level.
        labels (Tensor): Labels of class.
        label_weights (Tensor): Weights of class loss.
        bbox_targets (Tensor): Targets of bbox predictions.
        bbox_weights (Tensor): Weights of bbox loss.
        dir_targets (Tensor): Targets of direction predictions.
        dir_weights (Tensor): Weights of direction loss.
        num_total_samples (int): The number of valid samples.

    Returns:
        tuple[torch.Tensor]: Losses of class, bbox
            and direction, respectively.
    """
    # classification loss
    if num_total_samples is None:
        num_total_samples = int(cls_score.shape[0])
    labels = labels.reshape(-1)
    label_weights = label_weights.reshape(-1)
    cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
    assert labels.max().item() <= self.num_classes
    loss_cls = self.loss_cls(
        cls_score, labels, label_weights, avg_factor=num_total_samples)

    # regression loss
    bbox_pred = bbox_pred.permute(0, 2, 3,
                                  1).reshape(-1, self.box_code_size)
    bbox_targets = bbox_targets.reshape(-1, self.box_code_size)
    bbox_weights = bbox_weights.reshape(-1, self.box_code_size)

    bg_class_ind = self.num_classes
    pos_inds = ((labels >= 0)
                & (labels < bg_class_ind)).nonzero(
                    as_tuple=False).reshape(-1)
    num_pos = len(pos_inds)

    pos_bbox_pred = bbox_pred[pos_inds]
    pos_bbox_targets = bbox_targets[pos_inds]
    pos_bbox_weights = bbox_weights[pos_inds]

    # dir loss
    if self.use_direction_classifier:
        dir_cls_pred = dir_cls_pred.permute(0, 2, 3, 1).reshape(-1, 2)
        dir_targets = dir_targets.reshape(-1)
        dir_weights = dir_weights.reshape(-1)
        pos_dir_cls_pred = dir_cls_pred[pos_inds]
        pos_dir_targets = dir_targets[pos_inds]
        pos_dir_weights = dir_weights[pos_inds]

    if num_pos > 0:
        code_weight = self.train_cfg.get('code_weight', None)
        if code_weight:
            pos_bbox_weights = pos_bbox_weights * bbox_weights.new_tensor(
                code_weight)
        if self.diff_rad_by_sin:
            pos_bbox_pred, pos_bbox_targets = self.add_sin_difference(
                pos_bbox_pred, pos_bbox_targets)
        loss_bbox = self.loss_bbox(
            pos_bbox_pred,
            pos_bbox_targets,
            pos_bbox_weights,
            avg_factor=num_total_samples)

        # direction classification loss
        loss_dir = None
        if self.use_direction_classifier:
            loss_dir = self.loss_dir(
                pos_dir_cls_pred,
                pos_dir_targets,
                pos_dir_weights,
                avg_factor=num_total_samples)
    else:
        loss_bbox = pos_bbox_pred.sum()
        if self.use_direction_classifier:
            loss_dir = pos_dir_cls_pred.sum()

    return loss_cls, loss_bbox, loss_dir
```
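For context, the permute/reshape at the top of this function assumes each per-level `cls_score` is a 4-D `(N, C, H, W)` map. A standalone sketch of the flattening it performs (all sizes here are illustrative, not taken from my run):

```python
import torch

# Shapes the loss code assumes (illustrative numbers):
# N samples, A anchors per location, C classes, an H x W BEV feature map.
N, A, C, H, W = 2, 6, 3, 100, 88
cls_score = torch.randn(N, A * C, H, W)  # 4-D map, as produced per level

# The permute/reshape in _loss_by_feat_single yields one row of class
# scores per anchor location:
flat = cls_score.permute(0, 2, 3, 1).reshape(-1, C)
print(flat.shape)  # torch.Size([105600, 3]) == (N * H * W * A, C)
```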
Reproduces the problem - command or script
```shell
python tools/train.py /home/knight/mmdetection3d-main/projects/BAFusion/config/BAFusion_lidar-cam_kitti_3d.py
```
Reproduces the problem - error message
```
  File "/home/knight/mmdetection3d-main/mmdet3d/models/dense_heads/anchor3d_head.py", line 277, in _loss_by_feat_single
    cls_score = cls_score.permute(0, 2, 3, 0).reshape(-1, self.num_classes)
RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 3 is not equal to len(dims) = 4
```
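The dimension mismatch can be reproduced outside mmdet3d with a plain tensor of the reported shape (a minimal sketch; I am not certain the error text is identical for a dense tensor, but the dimension-count check is the same):

```python
import torch

# cls_score arrives as a 3-D tensor instead of the expected (N, C, H, W).
cls_score = torch.randn(18, 100, 88)

# permute() needs exactly one index per tensor dimension, so a 4-way
# ordering on a 3-D tensor fails: input.dim() = 3 vs len(dims) = 4.
cls_score.permute(0, 2, 3, 1)  # raises RuntimeError
```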
Additional information
The shape of `cls_score` is `[18, 100, 88]`, so the tensor has 3 dimensions, but the line below requires 4; `bbox_pred` has the same problem. I am using the KITTI dataset. Do I need to modify this code?

```python
cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
```
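For completeness, a hypothetical guard (my own sketch, assuming the missing axis is the batch dimension; the real cause is presumably upstream, wherever the 3-D feature map is produced):

```python
# Hypothetical guard inside _loss_by_feat_single; NOT a proper fix.
# It assumes the lost axis is the batch dimension, which may be wrong.
if cls_score.dim() == 3:
    cls_score = cls_score.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)
cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
```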