diff --git a/configs/dal/dal-base.py b/configs/dal/dal-base.py index 9d298826..15f07ddb 100644 --- a/configs/dal/dal-base.py +++ b/configs/dal/dal-base.py @@ -59,7 +59,7 @@ use_grid_mask=True, # camera img_backbone=dict( - pretrained='torchvision://resnet18', + pretrained='./ckpts/resnet18', type='ResNet', depth=18, num_stages=4, @@ -379,4 +379,4 @@ optimizer = dict(type='AdamW', lr=2e-4, weight_decay=0.01) # for 64 total batch size two_stage = True runner = dict(type='TwoStageRunner', max_epochs=20) -num_proposals_test = 300 \ No newline at end of file +num_proposals_test = 300 diff --git a/configs/dal/dal-occ-base.py b/configs/dal/dal-occ-base.py new file mode 100644 index 00000000..448fcf71 --- /dev/null +++ b/configs/dal/dal-occ-base.py @@ -0,0 +1,27 @@ +_base_ = ['./dal-base.py'] + +model = dict( + type='DALOcc', + pts_bbox_head=dict( + type='DALOccHead', + occ_enabled=True, + occ_num_classes=2, + occ_z_bins=4, + occ_topk_ratio=0.1, + occ_prop_threshold=0.3, + occ_use_gt_mask=True, + occ_feedback='cls', + occ_prop_weight=1.0, + occ_detach_feedback=False, + loss_occ_proposal=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + reduction='mean', + loss_weight=1.0), + loss_occ=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + reduction='mean', + loss_weight=1.0))) diff --git a/configs/dal/dal-occ-joint-base.py b/configs/dal/dal-occ-joint-base.py new file mode 100644 index 00000000..c846d29d --- /dev/null +++ b/configs/dal/dal-occ-joint-base.py @@ -0,0 +1,228 @@ +_base_ = ['./dal-occ-base.py'] + +# Joint DAL + sparse occupancy training config. +# This file keeps DAL detection pipeline and adds occupancy supervision keys. 
+ +class_names = [ + 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', + 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'] +point_cloud_range = [-54.0, -54.0, -3.0, 54.0, 54.0, 5.0] +grid_config = { + 'x': [-54.0, 54.0, 0.6], + 'y': [-54.0, 54.0, 0.6], + 'z': [-3, 5, 8], + 'depth': [1.0, 60.0, 0.5], +} +data_config = { + 'cams': ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', + 'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT'], + 'Ncams': 5, + 'input_size': (256, 704), + 'src_size': (900, 1600), + 'resize': (-0.06, 0.44), + 'rot': (-5.4, 5.4), + 'flip': True, + 'crop_h': (0.0, 0.0), + 'random_crop_height': True, + 'vflip': True, + 'resize_test': 0.04, + 'pmd': dict( + brightness_delta=32, + contrast_lower=0.5, + contrast_upper=1.5, + saturation_lower=0.5, + saturation_upper=1.5, + hue_delta=18, + rate=0.5) +} + +dataset_type = 'NuScenesDatasetOccpancy' +data_root = 'data/nuscenes/' +file_client_args = dict(backend='disk') +input_modality = dict( + use_lidar=True, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False) +bda_aug_conf = dict( + rot_lim=(-22.5 * 2, 22.5 * 2), + scale_lim=(0.9, 1.1), + flip_dx_ratio=0.5, + flip_dy_ratio=0.5, + tran_lim=[0.5, 0.5, 0.5] +) +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'bevdetv3-nuscenes_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict( + car=5, + truck=5, + bus=5, + trailer=5, + construction_vehicle=5, + traffic_cone=5, + barrier=5, + motorcycle=5, + bicycle=5, + pedestrian=5)), + classes=class_names, + sample_groups=dict( + car=2, + truck=3, + construction_vehicle=7, + bus=4, + trailer=6, + barrier=2, + motorcycle=6, + bicycle=6, + pedestrian=2, + traffic_cone=2), + points_loader=dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=5, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args)) + +train_pipeline = [ + dict( + type='PrepareImageInputs', + is_train=True, + 
opencv_pp=True, + data_config=data_config), + dict(type='LoadOccGTFromFile'), + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=5, + use_dim=5, + file_client_args=file_client_args), + dict( + type='LoadPointsFromMultiSweeps', + sweeps_num=10, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args, + pad_empty_sweeps=True, + remove_close=True), + dict(type='ToEgo'), + dict(type='LoadAnnotations'), + dict(type='ObjectSample', db_sampler=db_sampler), + dict(type='VelocityAug'), + dict( + type='BEVAug', + bda_aug_conf=bda_aug_conf, + classes=class_names), + dict(type='PointToMultiViewDepthFusion', downsample=1, + grid_config=grid_config), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectNameFilter', classes=class_names), + dict(type='PointShuffle'), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict( + type='Collect3D', + keys=[ + 'points', 'gt_bboxes_3d', 'gt_labels_3d', + 'img_inputs', 'gt_depth', 'gt_bboxes_ignore', + 'voxel_semantics', 'mask_camera', 'mask_lidar' + ]) +] + +# Keep test pipeline aligned with DAL detection evaluation first. 
+test_pipeline = [ + dict( + type='PrepareImageInputs', + is_train=False, + opencv_pp=True, + data_config=data_config), + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=5, + use_dim=5, + file_client_args=file_client_args), + dict( + type='LoadPointsFromMultiSweeps', + sweeps_num=10, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args, + pad_empty_sweeps=True, + remove_close=True), + dict(type='ToEgo'), + dict(type='LoadAnnotations'), + dict( + type='BEVAug', + bda_aug_conf=bda_aug_conf, + classes=class_names, + is_train=False), + dict( + type='PointToMultiViewDepthFusion', + downsample=1, + grid_config=grid_config), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D'), + dict( + type='PointsRangeFilter', + point_cloud_range=point_cloud_range), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points', 'img_inputs', 'gt_depth']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=6, + train=dict( + type='CBGSDataset', + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'bevdetv3-nuscenes_infos_train.pkl', + pipeline=train_pipeline, + classes=class_names, + test_mode=False, + use_valid_flag=True, + modality=input_modality, + img_info_prototype='bevdet', + box_type_3d='LiDAR')), + val=dict( + type=dataset_type, + data_root=data_root, + pipeline=test_pipeline, + classes=class_names, + modality=input_modality, + ann_file=data_root + 'bevdetv3-nuscenes_infos_val.pkl', + img_info_prototype='bevdet', + box_type_3d='LiDAR'), + test=dict( + type=dataset_type, + data_root=data_root, + pipeline=test_pipeline, + classes=class_names, + modality=input_modality, + ann_file=data_root + 'bevdetv3-nuscenes_infos_val.pkl', + 
img_info_prototype='bevdet', + box_type_3d='LiDAR')) + +# Disable DAL two-stage pipeline mutation for the joint config because the +# occupancy loader introduces extra pipeline steps and breaks the hard-coded +# index assertion in tools/train.py. +two_stage = False +runner = dict(type='EpochBasedRunner', max_epochs=20) diff --git a/configs/dal/dal-occ-joint-f1-bev.py b/configs/dal/dal-occ-joint-f1-bev.py new file mode 100644 index 00000000..ea45cd65 --- /dev/null +++ b/configs/dal/dal-occ-joint-f1-bev.py @@ -0,0 +1,7 @@ +_base_ = ['./dal-occ-joint-base.py'] + +model = dict( + pts_bbox_head=dict( + occ_feedback='bev', + )) + diff --git a/configs/dal/dal-occ-joint-f2-heatmap.py b/configs/dal/dal-occ-joint-f2-heatmap.py new file mode 100644 index 00000000..2b7d4459 --- /dev/null +++ b/configs/dal/dal-occ-joint-f2-heatmap.py @@ -0,0 +1,7 @@ +_base_ = ['./dal-occ-joint-base.py'] + +model = dict( + pts_bbox_head=dict( + occ_feedback='heatmap', + )) + diff --git a/configs/dal/dal-occ-joint-f3-cls.py b/configs/dal/dal-occ-joint-f3-cls.py new file mode 100644 index 00000000..6e943d32 --- /dev/null +++ b/configs/dal/dal-occ-joint-f3-cls.py @@ -0,0 +1,7 @@ +_base_ = ['./dal-occ-joint-base.py'] + +model = dict( + pts_bbox_head=dict( + occ_feedback='cls', + )) + diff --git a/configs/dal/dal-occ-joint-nofb.py b/configs/dal/dal-occ-joint-nofb.py new file mode 100644 index 00000000..3fb2d24e --- /dev/null +++ b/configs/dal/dal-occ-joint-nofb.py @@ -0,0 +1,7 @@ +_base_ = ['./dal-occ-joint-base.py'] + +model = dict( + pts_bbox_head=dict( + occ_feedback='none', + )) + diff --git a/mmdet3d/models/dense_heads/__init__.py b/mmdet3d/models/dense_heads/__init__.py index dc1a34c0..a2fe47c0 100644 --- a/mmdet3d/models/dense_heads/__init__.py +++ b/mmdet3d/models/dense_heads/__init__.py @@ -5,6 +5,7 @@ from .base_mono3d_dense_head import BaseMono3DDenseHead from .centerpoint_head import CenterHead from .dal_head import DALHead +from .dal_occ_head import DALOccHead from .fcaf3d_head import 
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule, kaiming_init
from mmcv.runner import force_fp32
from torch import nn

from mmdet3d.models.builder import HEADS, build_loss
from .dal_head import DALHead


@HEADS.register_module()
class DALOccHead(DALHead):
    """First-pass DAL head with sparse pillar occupancy branch.

    Design constraints:
    1. box regression remains LiDAR-only;
    2. occupancy only enhances shared BEV / dense heatmap / classification;
    3. no occupancy feature is injected into regression.

    Args:
        occ_enabled (bool): Build and run the occupancy branch. When False
            no occupancy modules are created and ``loss`` returns only the
            parent losses.
        occ_num_classes (int): Output channels of the occupancy decoder.
            Class index 1 is read as the "occupied" probability for the
            feedback maps, so this must be at least 2.
        occ_z_bins (int): Number of height bins predicted per BEV cell.
        occ_topk_ratio (float): Fraction of BEV cells selected as pillars.
        occ_prop_threshold (float): Sigmoid-score threshold used to build
            ``occ_valid_mask``.
        occ_use_gt_mask (bool): Multiply the occupancy loss weights by the
            pooled ``mask_camera`` when one is given.
        occ_feedback (str): Where occupancy is fed back into detection:
            ``'bev'``, ``'heatmap'`` or ``'cls'``. Any other value (e.g.
            ``'none'``) disables feedback.
        occ_prop_weight (float): Extra multiplier on the proposal loss.
        occ_detach_feedback (bool): Detach the feedback map so detection
            gradients do not flow into the occupancy branch.
        loss_occ_proposal (dict): Config passed to ``build_loss`` for the
            BEV proposal loss.
        loss_occ (dict): Config passed to ``build_loss`` for the per-bin
            occupancy loss.
        **kwargs: Forwarded to ``DALHead.__init__``. Must contain
            ``hidden_channel`` as a keyword when ``occ_enabled`` is True.
    """

    def __init__(self,
                 occ_enabled=True,
                 occ_num_classes=2,
                 occ_z_bins=4,
                 occ_topk_ratio=0.1,
                 occ_prop_threshold=0.3,
                 occ_use_gt_mask=True,
                 occ_feedback='none',
                 occ_prop_weight=1.0,
                 occ_detach_feedback=False,
                 loss_occ_proposal=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     reduction='mean',
                     loss_weight=1.0),
                 loss_occ=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     reduction='mean',
                     loss_weight=1.0),
                 **kwargs):
        super().__init__(**kwargs)
        self.occ_enabled = occ_enabled
        self.occ_num_classes = occ_num_classes
        self.occ_z_bins = occ_z_bins
        self.occ_topk_ratio = occ_topk_ratio
        self.occ_prop_threshold = occ_prop_threshold
        self.occ_use_gt_mask = occ_use_gt_mask
        self.occ_feedback = occ_feedback
        self.occ_prop_weight = occ_prop_weight
        self.occ_detach_feedback = occ_detach_feedback
        # Losses are built even when the branch is disabled.
        self.loss_occ_proposal = build_loss(loss_occ_proposal)
        self.loss_occ = build_loss(loss_occ)

        if not self.occ_enabled:
            return

        # Read from kwargs because the parent consumes it; a missing key
        # raises KeyError here.
        hidden_channel = kwargs['hidden_channel']

        # 1-channel BEV map scoring which cells should become pillars.
        self.occ_bev_proposal_head = nn.Sequential(
            ConvModule(
                hidden_channel,
                hidden_channel,
                kernel_size=3,
                padding=1,
                conv_cfg=dict(type='Conv2d'),
                norm_cfg=dict(type='BN2d')),
            nn.Conv2d(hidden_channel, 1, kernel_size=1))
        # One fusion layer per feedback mode; all three are always built,
        # only the one matching `occ_feedback` is used in forward_single.
        self.occ_bev_feedback = ConvModule(
            hidden_channel + self.occ_z_bins,
            hidden_channel,
            kernel_size=1,
            conv_cfg=dict(type='Conv2d'),
            norm_cfg=dict(type='BN2d'))
        self.occ_heatmap_feedback = ConvModule(
            self.num_classes + self.occ_z_bins,
            self.num_classes,
            kernel_size=1,
            conv_cfg=dict(type='Conv2d'),
            norm_cfg=dict(type='BN2d'))
        self.occ_cls_feedback = ConvModule(
            hidden_channel + self.occ_z_bins,
            hidden_channel,
            kernel_size=1,
            conv_cfg=dict(type='Conv1d'),
            norm_cfg=dict(type='BN1d'))
        # One learned vector per height bin, added to each pillar feature.
        self.occ_z_embedding = nn.Embedding(self.occ_z_bins, hidden_channel)
        self.occ_decoder = nn.Sequential(
            ConvModule(
                hidden_channel,
                hidden_channel,
                kernel_size=1,
                conv_cfg=dict(type='Conv1d'),
                norm_cfg=dict(type='BN1d')),
            nn.Conv1d(hidden_channel, occ_num_classes, kernel_size=1))

        # NOTE(review): only Conv2d layers are re-initialised here; the
        # Conv1d layers (occ_cls_feedback, occ_decoder) keep their default
        # init. Confirm this is intended.
        for module in [
                self.occ_bev_proposal_head,
                self.occ_bev_feedback,
                self.occ_heatmap_feedback,
                self.occ_cls_feedback]:
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)

    def _get_num_occ_cells(self, total_cells):
        """Return how many BEV cells to select, clamped to [1, total]."""
        num_cells = max(1, int(total_cells * self.occ_topk_ratio))
        return min(total_cells, num_cells)

    def _build_occ_targets(self, voxel_semantics, out_h, out_w, device):
        """Pool voxel labels into proposal and per-bin occupancy targets.

        Args:
            voxel_semantics (Tensor): Integer voxel labels. The code
                requires a 4-D tensor whose last axis is the one pooled
                into ``occ_z_bins``.
            out_h (int): Target BEV height.
            out_w (int): Target BEV width.
            device (torch.device): Device to build the targets on.

        Returns:
            tuple[Tensor, Tensor]: ``occ_prop_target`` of shape
            ``(B, 1, out_h, out_w)`` (float, 1 where any voxel of the cell
            is occupied) and ``occ_bin_target`` of shape
            ``(B, out_h, out_w, occ_z_bins)`` (long, 0/1).
        """
        semantics = voxel_semantics.long().to(device)
        # NOTE(review): every label > 0 counts as occupied. Confirm this
        # matches the dataset convention (some occupancy benchmarks use a
        # non-zero id for "free").
        occupied = (semantics > 0).float()
        occupied_bev = occupied.max(dim=-1).values
        occ_prop_target = F.adaptive_max_pool2d(occupied_bev.unsqueeze(1),
                                                (out_h, out_w))

        # Move the last axis in front so it is pooled as the depth axis.
        occ_volume = occupied.permute(0, 3, 1, 2).contiguous().float()
        occ_bin_target = F.adaptive_max_pool3d(
            occ_volume.unsqueeze(1),
            (self.occ_z_bins, out_h, out_w)).squeeze(1)
        occ_bin_target = occ_bin_target.permute(0, 2, 3, 1).contiguous()
        return occ_prop_target, occ_bin_target.long()

    def _pool_mask_to_occ(self, mask_camera, out_h, out_w, device):
        """Max-pool a visibility mask to ``(B, out_h, out_w, occ_z_bins)``.

        Returns None when ``mask_camera`` is None. The mask is permuted and
        pooled the same way as the voxel labels in ``_build_occ_targets``.
        """
        if mask_camera is None:
            return None
        mask = mask_camera.to(device=device, dtype=torch.float32)
        mask = mask.permute(0, 3, 1, 2).contiguous()
        mask = F.adaptive_max_pool3d(
            mask.unsqueeze(1), (self.occ_z_bins, out_h, out_w)).squeeze(1)
        mask = mask.permute(0, 2, 3, 1).contiguous()
        return mask

    def _sparse_pillar_forward(self, bev_feat_lidar, occ_prop_logits):
        """Decode occupancy for the top-k proposed BEV cells only.

        Args:
            bev_feat_lidar (Tensor): LiDAR BEV feature, ``(B, C, H, W)``.
            occ_prop_logits (Tensor): Proposal logits, ``(B, 1, H, W)``.

        Returns:
            tuple:
                - dense_occ_logits (Tensor): ``(B, occ_num_classes,
                  occ_z_bins, H, W)``; zero at cells that were not selected.
                - occ_feedback (Tensor): ``(B, occ_z_bins, H, W)`` softmax
                  probability of class index 1; zero at unselected cells.
                - topk_index (Tensor): ``(B, K)`` flat indices of the
                  selected cells.
                - valid_mask (Tensor): ``(B, K)`` bool, True where the
                  proposal score exceeds ``occ_prop_threshold`` (all True
                  if nothing exceeds it).
        """
        b, _, h, w = occ_prop_logits.shape
        proposal_score = occ_prop_logits.sigmoid().view(b, -1)
        num_select = self._get_num_occ_cells(proposal_score.shape[-1])
        topk_score, topk_index = proposal_score.topk(num_select, dim=-1)
        # NOTE(review): valid_mask is only returned; it does not gate the
        # scatter below and is not consumed by `loss` in this class.
        valid_mask = topk_score > self.occ_prop_threshold
        if valid_mask.sum() == 0:
            valid_mask = torch.ones_like(valid_mask, dtype=torch.bool)

        # Gather the C-dim feature of every selected cell: (B, C, K).
        bev_feat_flat = bev_feat_lidar.view(b, bev_feat_lidar.shape[1], -1)
        pillar_feat = bev_feat_flat.gather(
            dim=-1,
            index=topk_index[:, None, :].expand(
                -1, bev_feat_flat.shape[1], -1))

        # Embedding weight is (Z, C); shape it to (1, Z, C, 1) and
        # broadcast against (B, 1, C, K) to get (B, Z, C, K). The previous
        # `unsqueeze(2)` put the singleton on the channel axis, which fails
        # to broadcast unless C == Z.
        z_embed = self.occ_z_embedding.weight[None, :, :, None]
        pillar_feat = pillar_feat.unsqueeze(1) + z_embed
        # Fold the z-bin axis into the batch for the Conv1d decoder.
        pillar_feat = pillar_feat.reshape(b * self.occ_z_bins, -1,
                                          pillar_feat.shape[-1])

        sparse_occ_logits = self.occ_decoder(pillar_feat)
        sparse_occ_logits = sparse_occ_logits.view(b, self.occ_z_bins,
                                                   self.occ_num_classes, -1)
        # -> (B, num_classes, Z, K)
        sparse_occ_logits = sparse_occ_logits.permute(0, 2, 1,
                                                      3).contiguous()

        # Scatter the sparse predictions back onto the dense BEV grid.
        dense_occ_logits = sparse_occ_logits.new_zeros(
            (b, self.occ_num_classes, self.occ_z_bins, h * w))
        scatter_index = topk_index[:, None, None, :].expand(
            -1, self.occ_num_classes, self.occ_z_bins, -1)
        dense_occ_logits.scatter_(dim=-1, index=scatter_index,
                                  src=sparse_occ_logits)
        dense_occ_logits = dense_occ_logits.view(b, self.occ_num_classes,
                                                 self.occ_z_bins, h, w)

        # Feedback map: probability of class index 1 per z-bin.
        sparse_occ_prob = sparse_occ_logits.softmax(dim=1)[:, 1]
        occ_feedback = sparse_occ_prob.new_zeros((b, self.occ_z_bins, h * w))
        occ_feedback.scatter_(
            dim=-1,
            index=topk_index[:, None, :].expand(-1, self.occ_z_bins, -1),
            src=sparse_occ_prob)
        occ_feedback = occ_feedback.view(b, self.occ_z_bins, h, w)
        return dense_occ_logits, occ_feedback, topk_index, valid_mask

    def forward_single(self, inputs, img_inputs, bev_feat_img=None):
        """Run detection plus the optional occupancy branch.

        Args:
            inputs (Tensor): LiDAR BEV feature fed to ``shared_conv``.
            img_inputs: Image inputs passed through to
                ``extract_instance_img_feat``.
            bev_feat_img (Tensor): Image BEV feature. Although it defaults
                to None it is concatenated unconditionally, so a tensor is
                required.

        Returns:
            list[dict]: One-element list. The dict holds the regression
            outputs (``height``, ``center``, ``dim``, ``rot``, ``vel``),
            ``heatmap``, ``query_heatmap_score``, ``dense_heatmap`` and,
            when the occupancy branch is enabled, ``occ_prop_logits``,
            ``occ_logits``, ``occ_topk_index`` and ``occ_valid_mask``.
        """
        batch_size = inputs.shape[0]
        bev_feat_lidar = self.shared_conv(inputs)
        bev_feat_lidar_flatten = bev_feat_lidar.view(
            batch_size, bev_feat_lidar.shape[1], -1)
        bev_pos = self.bev_pos.repeat(batch_size, 1,
                                      1).to(bev_feat_lidar.device)

        dense_fuse_feat = torch.cat([bev_feat_lidar, bev_feat_img], dim=1)
        dense_fuse_feat = self.dense_heatmap_fuse_convs(dense_fuse_feat)[0]

        occ_feedback_for_det = None
        occ_outputs = dict()
        if self.occ_enabled:
            # Occupancy is predicted from the LiDAR BEV feature only.
            occ_prop_logits = self.occ_bev_proposal_head(bev_feat_lidar)
            occ_logits, occ_feedback, occ_topk_index, occ_valid_mask = \
                self._sparse_pillar_forward(bev_feat_lidar, occ_prop_logits)
            occ_feedback_for_det = occ_feedback.detach() \
                if self.occ_detach_feedback else occ_feedback
            occ_outputs.update(
                occ_prop_logits=occ_prop_logits,
                occ_logits=occ_logits,
                occ_topk_index=occ_topk_index,
                occ_valid_mask=occ_valid_mask)
            if self.occ_feedback == 'bev':
                dense_fuse_feat = self.occ_bev_feedback(
                    torch.cat([dense_fuse_feat, occ_feedback_for_det],
                              dim=1))

        dense_heatmap = self.heatmap_head(dense_fuse_feat)
        if self.occ_enabled and self.occ_feedback == 'heatmap':
            dense_heatmap = self.occ_heatmap_feedback(
                torch.cat([dense_heatmap, occ_feedback_for_det], dim=1))
        heatmap = dense_heatmap.detach().sigmoid()

        top_proposals_class, top_proposals_index = \
            self.extract_proposal(heatmap)
        self.query_labels = top_proposals_class

        # LiDAR features at the proposal locations.
        index = top_proposals_index.expand(
            -1, bev_feat_lidar_flatten.shape[1], -1)
        query_feat_lidar = bev_feat_lidar_flatten.gather(index=index, dim=-1)

        one_hot = F.one_hot(top_proposals_class,
                            num_classes=self.num_classes).permute(0, 2, 1)
        query_cat_encoding = self.class_encoding(one_hot.float())
        query_feat_lidar += query_cat_encoding

        query_pos_index = top_proposals_index.permute(0, 2, 1)
        query_pos_index = query_pos_index.expand(-1, -1, bev_pos.shape[-1])
        query_pos = bev_pos.gather(index=query_pos_index, dim=1)

        # Regression uses LiDAR query features only (no occupancy input).
        res = dict()
        for task in ['height', 'center', 'dim', 'rot', 'vel']:
            res[task] = self.prediction_heads[0].__getattr__(task)(
                query_feat_lidar)
        res['center'] += query_pos.permute(0, 2, 1)

        query_feat_img = self.extract_instance_img_feat(res, img_inputs)

        bev_feat_img = bev_feat_img.view(batch_size, bev_feat_img.shape[1],
                                         -1)
        index = top_proposals_index.expand(-1, bev_feat_img.shape[1], -1)
        query_feat_img_bev = bev_feat_img.gather(index=index, dim=-1)

        query_feat_fuse = torch.cat(
            [query_feat_lidar, query_feat_img, query_feat_img_bev], dim=1)
        query_feat_fuse = self.fuse_convs(query_feat_fuse)
        if self.occ_enabled and self.occ_feedback == 'cls':
            # Sample the per-bin occupancy at each query location and fuse
            # it into the classification feature.
            occ_query_feat = occ_feedback_for_det.view(
                batch_size, self.occ_z_bins, -1)
            occ_query_feat = occ_query_feat.gather(
                dim=-1,
                index=top_proposals_index.expand(-1, self.occ_z_bins, -1))
            query_feat_fuse = self.occ_cls_feedback(
                torch.cat([query_feat_fuse, occ_query_feat], dim=1))

        res['heatmap'] = self.prediction_heads[0].__getattr__('heatmap')(
            query_feat_fuse)
        heatmap = heatmap.view(batch_size, heatmap.shape[1], -1)
        res['query_heatmap_score'] = heatmap.gather(
            index=top_proposals_index.expand(-1, self.num_classes, -1),
            dim=-1)
        res['dense_heatmap'] = dense_heatmap
        res.update(occ_outputs)
        return [res]

    @force_fp32(apply_to=('preds_dicts',))
    def loss(self,
             gt_bboxes_3d,
             gt_labels_3d,
             preds_dicts,
             img_metas=None,
             voxel_semantics=None,
             mask_camera=None,
             **kwargs):
        """Compute the parent detection losses plus the occupancy losses.

        Args:
            gt_bboxes_3d: Ground-truth boxes, forwarded to the parent loss.
            gt_labels_3d: Ground-truth labels, forwarded to the parent loss.
            preds_dicts: Head outputs; ``preds_dicts[0][0]`` must be the
                dict produced by ``forward_single``.
            img_metas: Forwarded to the parent loss.
            voxel_semantics (Tensor, optional): Voxel labels. When None the
                occupancy losses are skipped.
            mask_camera (Tensor, optional): Visibility mask with the same
                layout as ``voxel_semantics``.
            **kwargs: Forwarded to the parent loss.

        Returns:
            dict: Parent losses, plus ``loss_occ_prop`` and ``loss_occ``
            when the branch is enabled and labels are available.
        """
        loss_dict = super().loss(
            gt_bboxes_3d, gt_labels_3d, preds_dicts, img_metas=img_metas,
            **kwargs)
        if not self.occ_enabled or voxel_semantics is None:
            return loss_dict

        preds_dict = preds_dicts[0][0]
        occ_prop_logits = preds_dict['occ_prop_logits']
        occ_logits = preds_dict['occ_logits']
        _, _, _, out_h, out_w = occ_logits.shape

        occ_prop_target, occ_bin_target = self._build_occ_targets(
            voxel_semantics, out_h, out_w, occ_prop_logits.device)
        occ_mask = self._pool_mask_to_occ(mask_camera, out_h, out_w,
                                          occ_prop_logits.device)

        # Normalise by the number of positive cells (at least 1).
        loss_occ_prop = self.loss_occ_proposal(
            occ_prop_logits.reshape(-1, 1),
            occ_prop_target.reshape(-1, 1),
            avg_factor=max(float(occ_prop_target.sum().item()), 1.0))

        # Flatten logits, targets and weights in the same
        # (B, H, W, Z) order.
        occ_target = occ_bin_target.reshape(-1)
        occ_logits = occ_logits.permute(0, 3, 4, 2, 1).reshape(
            -1, self.occ_num_classes)
        # Only cells occupied in the ground truth contribute.
        occ_weight = occ_prop_target.expand(
            -1, self.occ_z_bins, -1, -1).permute(0, 2, 3, 1).reshape(-1)
        if self.occ_use_gt_mask and occ_mask is not None:
            occ_weight = occ_weight * occ_mask.reshape(-1)

        loss_dict['loss_occ_prop'] = self.occ_prop_weight * loss_occ_prop
        loss_dict['loss_occ'] = self.loss_occ(
            occ_logits,
            occ_target,
            occ_weight,
            avg_factor=max(float(occ_weight.sum().item()), 1.0))
        return loss_dict
def forward_pts_train(self,
                      pts_feats,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      img_metas,
                      gt_bboxes_ignore=None,
                      voxel_semantics=None,
                      mask_camera=None):
    """Run the point/BEV head and compute its losses (DALOcc method).

    Args:
        pts_feats: Features passed unchanged to ``self.pts_bbox_head``.
        gt_bboxes_3d: Ground-truth 3D boxes.
        gt_labels_3d: Ground-truth labels.
        img_metas: Per-sample meta information.
        gt_bboxes_ignore: Forwarded to the head loss as a keyword.
        voxel_semantics: Occupancy labels forwarded to the head loss.
        mask_camera: Occupancy visibility mask forwarded to the head loss.

    Returns:
        Whatever ``self.pts_bbox_head.loss`` returns (a dict of losses
        for DALOccHead).
    """
    outs = self.pts_bbox_head(pts_feats)
    # The head's loss is declared as
    # loss(gt_bboxes_3d, gt_labels_3d, preds_dicts, img_metas=None, ...),
    # so the ground truth comes first and predictions third. Unpacking
    # `outs + (gt_bboxes_3d, gt_labels_3d, img_metas)` put the
    # predictions in the `gt_bboxes_3d` slot.
    losses = self.pts_bbox_head.loss(
        gt_bboxes_3d,
        gt_labels_3d,
        outs,
        img_metas=img_metas,
        gt_bboxes_ignore=gt_bboxes_ignore,
        voxel_semantics=voxel_semantics,
        mask_camera=mask_camera)
    return losses


def forward_train(self,
                  points=None,
                  img_metas=None,
                  gt_bboxes_3d=None,
                  gt_labels_3d=None,
                  gt_labels=None,
                  gt_bboxes=None,
                  img_inputs=None,
                  proposals=None,
                  gt_bboxes_ignore=None,
                  voxel_semantics=None,
                  mask_camera=None,
                  **kwargs):
    """Training forward pass with occupancy supervision (DALOcc method).

    ``voxel_semantics`` and ``mask_camera`` are forwarded only to
    ``forward_pts_train``. ``gt_labels``, ``gt_bboxes`` and
    ``proposals`` are accepted but not used here.

    Args:
        points: Point clouds passed to ``extract_feat``.
        img_metas: Per-sample meta information.
        gt_bboxes_3d: Ground-truth 3D boxes.
        gt_labels_3d: Ground-truth labels.
        img_inputs: Image inputs; elements ``1:7`` are appended to the
            image features for the view transformer.
        gt_bboxes_ignore: Forwarded to both loss branches.
        voxel_semantics: Occupancy labels.
        mask_camera: Occupancy visibility mask.
        **kwargs: Must contain ``gt_depth``; also forwarded to
            ``forward_img_auxiliary_train``.

    Returns:
        dict: Losses of the point/BEV head merged with the image
        auxiliary losses.
    """
    img_feats, pts_feats = self.extract_feat(
        points, img=img_inputs, img_metas=img_metas)
    # Raises KeyError when the pipeline did not collect `gt_depth`.
    img_feats_bev = self.img_view_transformer(
        img_feats + img_inputs[1:7], depth_from_lidar=kwargs['gt_depth'])

    losses = dict()
    losses_pts = self.forward_pts_train(
        [img_feats, pts_feats, img_feats_bev],
        gt_bboxes_3d,
        gt_labels_3d,
        img_metas,
        gt_bboxes_ignore,
        voxel_semantics=voxel_semantics,
        mask_camera=mask_camera)
    losses.update(losses_pts)
    losses_img_auxiliary = self.forward_img_auxiliary_train(
        img_feats,
        img_metas,
        gt_bboxes_3d,
        gt_labels_3d,
        gt_bboxes_ignore,
        **kwargs)
    losses.update(losses_img_auxiliary)
    return losses