Refactor DAB-DETR in MMDetection 3.x #9252

Merged · 49 commits · Dec 22, 2022

Commits
0697135
resolve refactor conflict w.o. pre-commit hooks
KeiChiTse Nov 6, 2022
a493aa8
fixed error and finished alignment
KeiChiTse Nov 9, 2022
f232f3d
supprot 91 cls and remove batch_first
KeiChiTse Nov 27, 2022
6b1be57
delete iter_update, keep return intermediate
KeiChiTse Nov 27, 2022
d087892
delete posHW
KeiChiTse Nov 28, 2022
e460dfa
substitute 'num_query' to 'num_queries'
KeiChiTse Nov 28, 2022
b5d1358
change 'gen_sineembed_for_position' to 'convert_coordinate_to_encoding'
KeiChiTse Nov 28, 2022
5244c24
resolve extra comments
KeiChiTse Nov 28, 2022
ebe71cd
fix error
KeiChiTse Nov 28, 2022
88b15a8
fix error
KeiChiTse Nov 28, 2022
8d08eeb
fix data path
KeiChiTse Nov 28, 2022
9fa9d5b
support 91 cls temporarily
KeiChiTse Nov 29, 2022
67d38cf
resolve extra comments
KeiChiTse Nov 29, 2022
f534413
fix num_keys, num_feats
KeiChiTse Dec 2, 2022
4587906
delete reg_branches in decoder_inputs_dict
KeiChiTse Dec 2, 2022
b24e2ec
fix docstring
KeiChiTse Dec 2, 2022
bbeb862
fix docstring
KeiChiTse Dec 2, 2022
1b232c5
commit modification in pr of DINO
Li-Qingyun Dec 7, 2022
a36607d
fix data format from nbc to bnc in detr and deformable-detr
KeiChiTse Dec 9, 2022
4166e8f
fix 'gen_encoder_output_proposals' for two-stage deformable-detr
KeiChiTse Dec 9, 2022
8c7dcb5
fix 'gen_encoder_output_proposals' for two-stage deformable-detr
KeiChiTse Dec 9, 2022
47348d9
set 'batch_first' to True in deformable attention
KeiChiTse Dec 9, 2022
cfafbf6
fix error
KeiChiTse Dec 9, 2022
ec4b951
fix ut
Li-Qingyun Dec 9, 2022
4b8b8d5
add assert for batch_first
KeiChiTse Dec 9, 2022
c2f3da0
remove 91 cls
KeiChiTse Dec 9, 2022
88ee92c
modify pre_decoder of DeformableDETR
Li-Qingyun Dec 10, 2022
039d14e
delete useless comments
Li-Qingyun Dec 10, 2022
2db01cf
bnc data flow w.o. merge detr and def-detr
KeiChiTse Dec 10, 2022
a9564f9
merge detr, def-detr
KeiChiTse Dec 11, 2022
7d03a03
assert batch first flag in conditional attention, fix error
KeiChiTse Dec 11, 2022
044c3d3
add unit test for dab-detr
KeiChiTse Dec 11, 2022
decec74
fix doc
KeiChiTse Dec 11, 2022
e5982d5
disable yapf hook
KeiChiTse Dec 11, 2022
519895d
move conditional attention to trm/layers
KeiChiTse Dec 13, 2022
276ba74
fix name and add doc
KeiChiTse Dec 13, 2022
a7c148a
fix doc
KeiChiTse Dec 13, 2022
f6e0005
add loss_and_predict for head
KeiChiTse Dec 13, 2022
5326c47
fix doc and typehint
KeiChiTse Dec 13, 2022
bddf68e
fix doc and typehint
KeiChiTse Dec 13, 2022
d5d2c1d
Merge branch 'refactor-detr' into refactor-dab-detr-3.x
KeiChiTse Dec 13, 2022
6972509
modify batch first assert for attention
KeiChiTse Dec 14, 2022
483fe0c
merge refactor-detr branch
KeiChiTse Dec 18, 2022
d11d684
merge refactor-detr
KeiChiTse Dec 19, 2022
f1efdb4
change Dab to DAB
KeiChiTse Dec 19, 2022
a384378
rename file and function
KeiChiTse Dec 19, 2022
12e2cff
make dab-detr head inherit conditional detr head
KeiChiTse Dec 19, 2022
7c64c88
fix doc
KeiChiTse Dec 20, 2022
f8cc093
fix doc
KeiChiTse Dec 21, 2022
162 changes: 162 additions & 0 deletions configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py
@@ -0,0 +1,162 @@
_base_ = [
'../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
]
model = dict(
type='DABDETR',
num_queries=300,
with_random_refpoints=False,
num_patterns=0,
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=1),
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(3, ),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='ChannelMapper',
in_channels=[2048],
kernel_size=1,
out_channels=256,
act_cfg=None,
norm_cfg=None,
num_outs=1),
encoder=dict(
num_layers=6,
layer_cfg=dict(
self_attn_cfg=dict(
embed_dims=256, num_heads=8, dropout=0., batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.,
act_cfg=dict(type='PReLU')))),
decoder=dict(
num_layers=6,
query_dim=4,
query_scale_type='cond_elewise',
with_modulated_hw_attn=True,
layer_cfg=dict(
self_attn_cfg=dict(
embed_dims=256,
num_heads=8,
attn_drop=0.,
proj_drop=0.,
cross_attn=False),
cross_attn_cfg=dict(
embed_dims=256,
num_heads=8,
attn_drop=0.,
proj_drop=0.,
cross_attn=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.,
act_cfg=dict(type='PReLU'))),
return_intermediate=True),
positional_encoding_cfg=dict(
num_feats=128, temperature=20, normalize=True),
bbox_head=dict(
type='DABDETRHead',
num_classes=80,
embed_dims=256,
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=5.0),
loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
# training and testing settings
train_cfg=dict(
assigner=dict(
type='HungarianAssigner',
match_costs=[
dict(type='FocalLossCost', weight=2., eps=1e-8),
dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
dict(type='IoUCost', iou_mode='giou', weight=2.0)
])),
test_cfg=dict(max_per_img=300))

# train_pipeline, NOTE the img_scale and the Pad's size_divisor are different
# from the default setting in mmdet.
train_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args={{_base_.file_client_args}}),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomChoice',
transforms=[[
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
],
[
dict(
type='RandomChoiceResize',
scales=[(400, 1333), (500, 1333), (600, 1333)],
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333),
(576, 1333), (608, 1333), (640, 1333),
(672, 1333), (704, 1333), (736, 1333),
(768, 1333), (800, 1333)],
keep_ratio=True)
]]),
dict(type='PackDetInputs')
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))

# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0001),
clip_grad=dict(max_norm=0.1, norm_type=2),
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))

# learning policy
max_epochs = 50
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[40],
gamma=0.1)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16, enable=False)
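
For context, here is a minimal sketch (not part of the PR) of building the detector from this new config with the MMDetection 3.x registry APIs; it assumes the script runs from an MMDetection checkout so the `_base_` paths resolve.

from mmengine.config import Config

from mmdet.registry import MODELS
from mmdet.utils import register_all_modules

# Register the mmdet 3.x modules so 'DABDETR', 'DABDETRHead', etc. resolve.
register_all_modules()

# Path of the config added in this PR, relative to the repository root.
cfg = Config.fromfile('configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py')
detector = MODELS.build(cfg.model)
print(type(detector).__name__)  # expected: DABDETR
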
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/__init__.py
@@ -8,6 +8,7 @@
from .centernet_update_head import CenterNetUpdateHead
from .centripetal_head import CentripetalHead
from .corner_head import CornerHead
from .dab_detr_head import DABDETRHead
from .ddod_head import DDODHead
from .deformable_detr_head import DeformableDETRHead
from .detr_head import DETRHead
@@ -56,5 +57,5 @@
'DeformableDETRHead', 'CenterNetHead', 'YOLOXHead', 'SOLOHead',
'DecoupledSOLOHead', 'DecoupledSOLOLightHead', 'SOLOV2Head', 'LADHead',
'TOODHead', 'MaskFormerHead', 'Mask2FormerHead', 'DDODHead',
'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead'
'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead', 'DABDETRHead'
]
168 changes: 168 additions & 0 deletions mmdet/models/dense_heads/dab_detr_head.py
@@ -0,0 +1,168 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch.nn as nn
from mmcv.cnn import Linear
from mmengine.model import bias_init_with_prob, constant_init
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList
from ..layers import MLP, inverse_sigmoid
from .detr_head import DETRHead


@MODELS.register_module()
class DABDETRHead(DETRHead):
"""Head of DAB-DETR. DAB-DETR: Dynamic Anchor Boxes are Better Queries for
DETR.

More details can be found in the `paper
<https://arxiv.org/abs/2201.12329>`_ .
"""

def _init_layers(self) -> None:
"""Initialize layers of the transformer head."""
# cls branch
self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
# reg branch
self.fc_reg = MLP(self.embed_dims, self.embed_dims, 4, 3)

def init_weights(self) -> None:
"""initialize weights."""
if self.loss_cls.use_sigmoid:
bias_init = bias_init_with_prob(0.01)
nn.init.constant_(self.fc_cls.bias, bias_init)
constant_init(self.fc_reg.layers[-1], 0., bias=0.)

def forward(self, hidden_states: Tensor,
references: Tensor) -> Tuple[Tensor, Tensor]:
""""Forward function.

        Args:
            hidden_states (Tensor): Features from the transformer decoder. If
                `return_intermediate` of the decoder is True, it has shape
                (num_decoder_layers, bs, num_queries, dim); otherwise it has
                shape (1, bs, num_queries, dim) and only contains the outputs
                of the last decoder layer.
            references (Tensor): References from the transformer decoder. If
                `return_intermediate` of the decoder is True, it has shape
                (num_decoder_layers, bs, num_queries, 2/4); otherwise it has
                shape (1, bs, num_queries, 2/4) and only contains the
                references of the last decoder layer.

        Returns:
            tuple[Tensor]: Results of the head, containing the following
            tensors.

- layers_cls_scores (Tensor): Outputs from the classification head,
shape (num_decoder_layers, bs, num_queries, cls_out_channels).
Note cls_out_channels should include background.
- layers_bbox_preds (Tensor): Sigmoid outputs from the regression
head with normalized coordinate format (cx, cy, w, h), has shape
(num_decoder_layers, bs, num_queries, 4).
"""
layers_cls_scores = self.fc_cls(hidden_states)
references_before_sigmoid = inverse_sigmoid(references, eps=1e-3)
tmp_reg_preds = self.fc_reg(hidden_states)
tmp_reg_preds[..., :references_before_sigmoid.
size(-1)] += references_before_sigmoid
layers_bbox_preds = tmp_reg_preds.sigmoid()
return layers_cls_scores, layers_bbox_preds

def loss(self, hidden_states: Tensor, references: Tensor,
batch_data_samples: SampleList) -> dict:
"""Perform forward propagation and loss calculation of the detection
head on the features of the upstream network.

Args:
hidden_states (Tensor): Feature from the transformer decoder, has
shape (num_decoder_layers, bs, num_queries, dim).
references (Tensor): references from the transformer decoder, has
shape (num_decoder_layers, bs, num_queries, 2/4).
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

Returns:
dict: A dictionary of loss components.
"""
batch_gt_instances = []
batch_img_metas = []
for data_sample in batch_data_samples:
batch_img_metas.append(data_sample.metainfo)
batch_gt_instances.append(data_sample.gt_instances)

outs = self(hidden_states, references)
loss_inputs = outs + (batch_gt_instances, batch_img_metas)
losses = self.loss_by_feat(*loss_inputs)
return losses

def predict(self,
hidden_states: Tensor,
references: Tensor,
batch_data_samples: SampleList,
rescale: bool = True) -> InstanceList:
"""Perform forward propagation of the detection head and predict
detection results on the features of the upstream network. Over-write
because img_metas are needed as inputs for bbox_head.

Args:
hidden_states (Tensor): Feature from the transformer decoder, has
shape (num_decoder_layers, bs, num_queries, dim).
references (Tensor): references from the transformer decoder, has
shape (num_decoder_layers, bs, num_queries, 2/4).
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
rescale (bool, optional): Whether to rescale the results.
Defaults to True.

Returns:
list[obj:`InstanceData`]: Detection results of each image
after the post process.
"""
batch_img_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]

outs = self(hidden_states, references)

predictions = self.predict_by_feat(
*outs, batch_img_metas=batch_img_metas, rescale=rescale)
return predictions

    def loss_and_predict(
        self, hidden_states: Tensor, references: Tensor,
        batch_data_samples: SampleList) -> Tuple[dict, InstanceList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples. Overridden because
        `img_metas` are needed as inputs for the bbox head.

Args:
hidden_states (Tensor): Feature from the transformer decoder, has
shape (num_decoder_layers, bs, num_queries, dim).
references (Tensor): references from the transformer decoder, has
shape (num_decoder_layers, bs, num_queries, 2/4).
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

Returns:
            tuple: The return value is a tuple containing:

- losses: (dict[str, Tensor]): A dictionary of loss components.
- predictions (list[:obj:`InstanceData`]): Detection
results of each image after the post process.
"""
batch_gt_instances = []
batch_img_metas = []
for data_sample in batch_data_samples:
batch_img_metas.append(data_sample.metainfo)
batch_gt_instances.append(data_sample.gt_instances)

outs = self(hidden_states, references)
loss_inputs = outs + (batch_gt_instances, batch_img_metas)
losses = self.loss_by_feat(*loss_inputs)

predictions = self.predict_by_feat(
*outs, batch_img_metas=batch_img_metas)
return losses, predictions
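
To illustrate the box refinement in `forward` above, here is a minimal sketch with dummy tensors; the shapes are made up, and `MLP` and `inverse_sigmoid` are the same helpers the head imports from `mmdet.models.layers`.

import torch

from mmdet.models.layers import MLP, inverse_sigmoid

# Hypothetical shapes: 6 decoder layers, batch size 2, 300 queries, 256 dims.
hidden_states = torch.rand(6, 2, 300, 256)
references = torch.rand(6, 2, 300, 4)  # normalized (cx, cy, w, h) anchors

# Same regression branch as in DABDETRHead._init_layers above.
fc_reg = MLP(256, 256, 4, 3)
tmp_reg_preds = fc_reg(hidden_states)
tmp_reg_preds[..., :4] += inverse_sigmoid(references, eps=1e-3)
bbox_preds = tmp_reg_preds.sigmoid()  # (6, 2, 300, 4), still normalized
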
3 changes: 2 additions & 1 deletion mmdet/models/detectors/__init__.py
@@ -6,6 +6,7 @@
from .cascade_rcnn import CascadeRCNN
from .centernet import CenterNet
from .cornernet import CornerNet
from .dab_detr import DABDETR
from .ddod import DDOD
from .deformable_detr import DeformableDETR
from .detr import DETR
@@ -59,5 +60,5 @@
'SOLOv2', 'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX',
'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD',
'MaskFormer', 'DDOD', 'Mask2Former', 'SemiBaseDetector', 'SoftTeacher',
'DetectionTransformer', 'RTMDet'
'DetectionTransformer', 'DABDETR', 'RTMDet'
]