
Refactor detr 3.x conditional detr and group detr #9248

Open · wants to merge 46 commits into base: refactor-detr

Changes from 34 commits

Commits (46)
- `97e1dd9` [Refactor]: Refactor DETR and Deformable DETR (#8763) (Li-Qingyun, Oct 20, 2022)
- `9ac71ab` delete Transformer in ci test (LYM-fire, Oct 20, 2022)
- `436de72` refine Transformer in ci test (LYM-fire, Oct 20, 2022)
- `db09116` add test_detr.py (LYM-fire, Oct 20, 2022)
- `8487eea` add test_detr.py (LYM-fire, Oct 20, 2022)
- `2f66ad3` support deformable detr (LYM-fire, Oct 20, 2022)
- `c64213b` fix _forward bug (LYM-fire, Oct 20, 2022)
- `7b08916` delete test_deformable_detr_head.py (LYM-fire, Oct 20, 2022)
- `be1202e` delete test_maskformer.py (LYM-fire, Oct 20, 2022)
- `d024714` todo (LYM-fire, Oct 20, 2022)
- `5edb02b` use relativte config unitest (LYM-fire, Oct 21, 2022)
- `43a3af1` use relativte config unitest (LYM-fire, Oct 21, 2022)
- `b97ac76` support conditional detr (LYM-fire, Oct 23, 2022)
- `72f679b` add doc (LYM-fire, Oct 23, 2022)
- `05a6348` add gropdetr decoder self attention (LYM-fire, Oct 23, 2022)
- `b257a29` Add unitests for detr 3.x (#9089) (LYMDLUT, Oct 24, 2022)
- `2b2ac06` add Tuple (LYM-fire, Oct 26, 2022)
- `dae222c` add test_detr.py (LYM-fire, Oct 20, 2022)
- `29b1d42` add test_detr.py (LYM-fire, Oct 20, 2022)
- `4f667f1` support conditional detr (LYM-fire, Oct 23, 2022)
- `621da7c` add doc (LYM-fire, Oct 23, 2022)
- `af115a6` add gropdetr decoder self attention (LYM-fire, Oct 23, 2022)
- `cf19bb3` add Tuple (LYM-fire, Oct 26, 2022)
- `4477dba` Merge remote-tracking branch 'origin/refactor-detr-3.x-rebase-groupde… (LYM-fire, Oct 26, 2022)
- `c520d85` move con decoder to layers (LYM-fire, Nov 3, 2022)
- `01aadf5` add conditional config (LYM-fire, Nov 3, 2022)
- `2916268` support group detr (LYM-fire, Nov 5, 2022)
- `2bfb32f` rename some variable in head (LYM-fire, Nov 6, 2022)
- `cc26f32` move gen_sine_embed_for_ref to utils (LYM-fire, Nov 6, 2022)
- `8d01c79` rename the detector (LYM-fire, Nov 6, 2022)
- `57b8767` refine util (LYM-fire, Nov 6, 2022)
- `b9a6318` refine precommit (LYM-fire, Nov 7, 2022)
- `22ee0d2` refine precommit (LYM-fire, Nov 7, 2022)
- `b3a002c` fix bug of 50e stop (LYM-fire, Nov 23, 2022)
- `36ff00b` support 91class (LYM-fire, Nov 28, 2022)
- `2534694` condetr 91class config (LYM-fire, Nov 28, 2022)
- `ceb34df` change num_query to num_queries (LYM-fire, Nov 28, 2022)
- `8be12b5` refine dropout and attn_mask (LYM-fire, Nov 28, 2022)
- `987fcbb` add typehint (LYM-fire, Nov 28, 2022)
- `41f1895` delete batch_first (LYM-fire, Nov 28, 2022)
- `d54cfe7` delete batch_first (LYM-fire, Nov 28, 2022)
- `552239a` refine head doc (LYM-fire, Nov 28, 2022)
- `3d26369` refine condetr doc (LYM-fire, Nov 28, 2022)
- `4707ac2` refine condetr_transformer doc (LYM-fire, Nov 28, 2022)
- `782fca0` fix the bug of class91 (LYM-fire, Nov 29, 2022)
- `72544f8` fix the bug of class91 (LYM-fire, Nov 29, 2022)
39 changes: 39 additions & 0 deletions configs/conditional_detr/README.md
@@ -0,0 +1,39 @@
# Conditional DETR

> [Conditional DETR for Fast Training Convergence](https://arxiv.org/abs/2108.06152)

<!-- [ALGORITHM] -->

## Abstract

The DETR approach applies the transformer encoder and decoder architecture to object detection and achieves promising performance. In this paper, we handle the critical issue, slow training convergence, and present a conditional cross-attention mechanism for fast DETR training. Our approach is motivated by that the cross-attention in DETR relies highly on the content embeddings and that the spatial embeddings make minor contributions, increasing the need for high-quality content embeddings and thus increasing the training difficulty.

<div align=center>
<img src="https://github.com/Atten4Vis/ConditionalDETR/blob/main/.github/attention-maps.png?raw=true"/>
</div>

Our conditional DETR learns a conditional spatial query from the decoder embedding for decoder multi-head cross-attention. The benefit is that through the conditional spatial query, each cross-attention head is able to attend to a band containing a distinct region, e.g., one object extremity or a region inside the object box (Figure 1). This narrows down the spatial range for localizing the distinct regions for object classification and box regression, thus relaxing the dependence on the content embeddings and easing the training. Empirical results show that conditional DETR converges 6.7x faster for the backbones R50 and R101 and 10x faster for stronger backbones DC5-R50 and DC5-R101.

<div align=center>
<img src="https://github.com/Atten4Vis/ConditionalDETR/raw/main/.github/conditional-detr.png" width="48%"/>
<img src="https://github.com/Atten4Vis/ConditionalDETR/raw/main/.github/convergence-curve.png" width="48%"/>
</div>
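To make the mechanism concrete, the sketch below illustrates how a conditional spatial query can be formed: a sinusoidal embedding of the reference point is modulated by a vector predicted from the decoder content embedding and then concatenated with the content query, so each cross-attention head can match by content and by location separately. This is a simplified illustration under assumed names (`ConditionalSpatialQuery`, `sine_embed`), not the refactored MMDetection modules in this PR.

```python
import math

import torch
import torch.nn as nn


def sine_embed(ref_points: torch.Tensor, num_feats: int = 128) -> torch.Tensor:
    """Sinusoidal embedding of normalized 2-d reference points (bs, nq, 2)."""
    scale = 2 * math.pi
    dim_t = torch.arange(num_feats, device=ref_points.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / num_feats)
    pos = ref_points * scale
    pos = pos[..., None] / dim_t  # (bs, nq, 2, num_feats)
    pos = torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=-1)
    return pos.flatten(-3)  # (bs, nq, 2 * num_feats)


class ConditionalSpatialQuery(nn.Module):
    """Builds the conditional query [content ; lambda * sine(ref_point)]."""

    def __init__(self, embed_dims: int = 256):
        super().__init__()
        # Predicts the modulation vector lambda from the content embedding.
        self.scale_mlp = nn.Sequential(
            nn.Linear(embed_dims, embed_dims), nn.ReLU(inplace=True),
            nn.Linear(embed_dims, embed_dims))

    def forward(self, content_query: torch.Tensor,
                ref_points: torch.Tensor) -> torch.Tensor:
        spatial = sine_embed(ref_points, num_feats=content_query.size(-1) // 2)
        spatial = spatial * self.scale_mlp(content_query)  # conditional part
        # Cross-attention then splits this per head into a content half and a
        # spatial half, narrowing each head to a distinct spatial band.
        return torch.cat([content_query, spatial], dim=-1)


# Toy shapes: batch of 2, 300 queries, 256-d embeddings, 2-d reference points.
q = ConditionalSpatialQuery(256)(torch.randn(2, 300, 256), torch.rand(2, 300, 2))
print(q.shape)  # torch.Size([2, 300, 512])
```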

## Results and Models

We provide the config files and models for Conditional DETR: [Conditional DETR for Fast Training Convergence](https://arxiv.org/abs/2108.06152).

| Backbone | Model | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download |
| :------: | :--------------: | :-----: | :------: | :------------: | :----: | :------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| R-50 | Conditional DETR | 50e | 7.9 | | 41.0 | [config](./conditional_detr_r50_8xb2-50e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/detr/detr_r50_8x2_150e_coco/detr_r50_8x2_150e_coco_20201130_194835-2c4b8974.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/detr/detr_r50_8x2_150e_coco/detr_r50_8x2_150e_coco_20201130_194835.log.json) |

## Citation

```latex
@inproceedings{meng2021-CondDETR,
  title = {Conditional DETR for Fast Training Convergence},
  author = {Meng, Depu and Chen, Xiaokang and Fan, Zejia and Zeng, Gang and Li, Houqiang and Yuan, Yuhui and Sun, Lei and Wang, Jingdong},
  booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
  year = {2021}
}
```
34 changes: 34 additions & 0 deletions configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py
@@ -0,0 +1,34 @@
_base_ = ['../detr/detr_r50_8xb2-150e_coco.py']
model = dict(
    type='ConditionalDETR',
    num_query=300,
    decoder=dict(
        num_layers=6,
        layer_cfg=dict(
            self_attn_cfg=dict(
                embed_dims=256, num_heads=8, dropout=0.1, cross_attn=False),
            cross_attn_cfg=dict(
                embed_dims=256, num_heads=8, dropout=0.1, cross_attn=True))),
    bbox_head=dict(
        type='ConditionalDETRHead',
        loss_cls=dict(
            _delete_=True,
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            match_costs=[
                dict(type='FocalLossCost', weight=2.0),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0)
            ])))

# learning policy
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=50, val_interval=1)

param_scheduler = [dict(type='MultiStepLR', end=50, milestones=[40])]
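For reference, a minimal way to exercise this config through MMEngine in MMDetection 3.x might look like the sketch below. It assumes this branch is installed and the COCO paths from the base config resolve; the work directory name is arbitrary.

```python
from mmengine.config import Config
from mmengine.runner import Runner

# Load the Conditional DETR config added in this PR and launch training.
cfg = Config.fromfile(
    'configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py')
cfg.work_dir = './work_dirs/conditional_detr_r50'  # Runner requires a work_dir
runner = Runner.from_cfg(cfg)
runner.train()
```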
65 changes: 25 additions & 40 deletions configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py
@@ -3,6 +3,10 @@
]
model = dict(
type='DeformableDETR',
num_query=300,
num_feature_levels=4,
with_box_refine=False,
as_two_stage=False,
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
@@ -27,50 +31,31 @@
act_cfg=None,
norm_cfg=dict(type='GN', num_groups=32),
num_outs=4),
encoder=dict( # DeformableDetrTransformerEncoder
num_layers=6,
layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
self_attn_cfg=dict( # MultiScaleDeformableAttention
embed_dims=256),
ffn_cfg=dict(
embed_dims=256, feedforward_channels=1024, ffn_drop=0.1))),
decoder=dict( # DeformableDetrTransformerDecoder
num_layers=6,
return_intermediate=True,
layer_cfg=dict( # DeformableDetrTransformerDecoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.1),
cross_attn_cfg=dict( # MultiScaleDeformableAttention
embed_dims=256),
ffn_cfg=dict(
embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)),
post_norm_cfg=None),
positional_encoding_cfg=dict(num_feats=128, normalize=True, offset=-0.5),
bbox_head=dict(
type='DeformableDETRHead',
num_query=300,
num_classes=80,
in_channels=2048,
sync_cls_avg_factor=True,
as_two_stage=False,
transformer=dict(
type='DeformableDetrTransformer',
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention', embed_dims=256),
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
decoder=dict(
type='DeformableDetrTransformerDecoder',
num_layers=6,
return_intermediate=True,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1),
dict(
type='MultiScaleDeformableAttention',
embed_dims=256)
],
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm')))),
positional_encoding=dict(
type='SinePositionalEncoding',
num_feats=128,
normalize=True,
offset=-0.5),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
@@ -1,2 +1,2 @@
_base_ = 'deformable-detr_r50_16xb2-50e_coco.py'
model = dict(bbox_head=dict(with_box_refine=True))
model = dict(with_box_refine=True)
@@ -1,2 +1,2 @@
_base_ = 'deformable-detr_refine_r50_16xb2-50e_coco.py'
model = dict(bbox_head=dict(as_two_stage=True))
model = dict(as_two_stage=True)
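Since `with_box_refine` and `as_two_stage` are now plain detector-level flags rather than `bbox_head` options, a user config enabling both could be as small as the sketch below (the derived file itself is illustrative, not part of this PR):

```python
# Hypothetical derived config enabling iterative box refinement and the
# two-stage variant on top of the refactored Deformable DETR base config.
_base_ = 'deformable-detr_r50_16xb2-50e_coco.py'
model = dict(with_box_refine=True, as_two_stage=True)
```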
2 changes: 1 addition & 1 deletion configs/detr/detr_r18_8xb2-500e_coco.py
@@ -4,4 +4,4 @@
backbone=dict(
depth=18,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18')),
neck=dict(in_channels=[64, 128, 256, 512]))
neck=dict(in_channels=[512]))
78 changes: 42 additions & 36 deletions configs/detr/detr_r50_8xb2-150e_coco.py
@@ -3,6 +3,7 @@
]
model = dict(
type='DETR',
num_query=100,
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
@@ -19,45 +20,50 @@
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='ChannelMapper',
in_channels=[2048],
kernel_size=1,
out_channels=256,
act_cfg=None,
norm_cfg=None,
num_outs=1),
encoder=dict( # DetrTransformerEncoder
num_layers=6,
layer_cfg=dict( # DetrTransformerEncoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.1),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True)))),
decoder=dict( # DetrTransformerDecoder
num_layers=6,
layer_cfg=dict( # DetrTransformerDecoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.1),
cross_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.1),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True))),
return_intermediate=True),
positional_encoding_cfg=dict(num_feats=128, normalize=True),
bbox_head=dict(
type='DETRHead',
num_classes=80,
in_channels=2048,
transformer=dict(
type='Transformer',
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1)
],
feedforward_channels=2048,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
decoder=dict(
type='DetrTransformerDecoder',
return_intermediate=True,
num_layers=6,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1),
feedforward_channels=2048,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm')),
)),
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True),
embed_dims=256,
loss_cls=dict(
type='CrossEntropyLoss',
bg_cls_weight=0.1,
41 changes: 41 additions & 0 deletions configs/group_detr/group_detr_r50_8xb2-50e_coco.py
@@ -0,0 +1,41 @@
_base_ = ['../detr/detr_r50_8xb2-150e_coco.py']
group_detr = 11
model = dict(
    type='ConditionalDETR',
    num_query=300,
    group_detr=group_detr,
    decoder=dict(
        num_layers=6,
        layer_cfg=dict(
            self_attn_cfg=dict(
                embed_dims=256,
                num_heads=8,
                dropout=0.1,
                cross_attn=False,
                group_detr=group_detr),
            cross_attn_cfg=dict(
                embed_dims=256, num_heads=8, dropout=0.1, cross_attn=True))),
    bbox_head=dict(
        type='ConditionalDETRHead',
        loss_cls=dict(
            _delete_=True,
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='GHungarianAssigner',
            match_costs=[
                dict(type='FocalLossCost', weight=2.0),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0)
            ],
            group_detr=group_detr)))

# learning policy
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=50, val_interval=1)

param_scheduler = [dict(type='MultiStepLR', end=50, milestones=[40])]
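The `group_detr=11` setting follows the Group DETR recipe: during training the object queries are split into independent groups, each matched against the ground truth by the group-wise assigner, and decoder self-attention is restricted to queries of the same group; at inference only a single group is kept. Below is a hedged sketch of the block-diagonal self-attention mask; the helper name `make_group_attn_mask` and the per-group query count are assumptions for illustration, not this PR's API.

```python
import torch


def make_group_attn_mask(num_queries: int, group_detr: int) -> torch.Tensor:
    """Block-diagonal mask where True marks positions that may NOT be attended,
    matching the attn_mask convention of torch.nn.MultiheadAttention."""
    assert num_queries % group_detr == 0
    group_size = num_queries // group_detr
    mask = torch.ones(num_queries, num_queries, dtype=torch.bool)
    for g in range(group_detr):
        start, end = g * group_size, (g + 1) * group_size
        mask[start:end, start:end] = False  # attend only within the same group
    return mask


# Assuming 300 queries per group and 11 groups at train time:
attn_mask = make_group_attn_mask(num_queries=300 * 11, group_detr=11)
print(attn_mask.shape)  # torch.Size([3300, 3300])
```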
4 changes: 3 additions & 1 deletion mmdet/models/dense_heads/__init__.py
@@ -7,6 +7,7 @@
from .centernet_head import CenterNetHead
from .centernet_update_head import CenterNetUpdateHead
from .centripetal_head import CentripetalHead
from .conditional_detr_head import ConditionalDETRHead
from .corner_head import CornerHead
from .ddod_head import DDODHead
from .deformable_detr_head import DeformableDETRHead
@@ -56,5 +57,6 @@
'DeformableDETRHead', 'CenterNetHead', 'YOLOXHead', 'SOLOHead',
'DecoupledSOLOHead', 'DecoupledSOLOLightHead', 'SOLOV2Head', 'LADHead',
'TOODHead', 'MaskFormerHead', 'Mask2FormerHead', 'DDODHead',
'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead'
'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead',
'ConditionalDETRHead'
]
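A quick sanity-check sketch for this export change: after the PR, the new head should be importable from `mmdet.models.dense_heads` and resolvable through the MODELS registry under the string used in the configs above (assuming the class is decorated with `@MODELS.register_module()`).

```python
from mmdet.models.dense_heads import ConditionalDETRHead
from mmdet.registry import MODELS

# The registry lookup should return the same class that the package exports.
assert MODELS.get('ConditionalDETRHead') is ConditionalDETRHead
```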