diff --git a/configs/conditional_detr/README.md b/configs/conditional_detr/README.md
new file mode 100644
index 00000000000..14c3e9bec93
--- /dev/null
+++ b/configs/conditional_detr/README.md
@@ -0,0 +1,39 @@
+# Conditional DETR
+
+> [Conditional DETR for Fast Training Convergence](https://arxiv.org/abs/2108.06152)
+
+
+
+## Abstract
+
+The DETR approach applies the transformer encoder and decoder architecture to object detection and achieves promising performance. In this paper, we handle the critical issue, slow training convergence, and present a conditional cross-attention mechanism for fast DETR training. Our approach is motivated by that the cross-attention in DETR relies highly on the content embeddings and that the spatial embeddings make minor contributions, increasing the need for high-quality content embeddings and thus increasing the training difficulty.
+
+
+

+
+
+Our conditional DETR learns a conditional spatial query from the decoder embedding for decoder multi-head cross-attention. The benefit is that through the conditional spatial query, each cross-attention head is able to attend to a band containing a distinct region, e.g., one object extremity or a region inside the object box (Figure 1). This narrows down the spatial range for localizing the distinct regions for object classification and box regression, thus relaxing the dependence on the content embeddings and easing the training. Empirical results show that conditional DETR converges 6.7x faster for the backbones R50 and R101 and 10x faster for stronger backbones DC5-R50 and DC5-R101.
+
+
+

+

+
+
+## Results and Models
+
+We provide the config files and models for Conditional DETR: [Conditional DETR for Fast Training Convergence](https://arxiv.org/abs/2108.06152).
+
+| Backbone | Model | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download |
+| :------: | :--------------: | :-----: | :------: | :------------: | :----: | :------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| R-50 | Conditional DETR | 50e | 7.9 | | 41.0 | [config](./detr_r50_8xb2-150e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/detr/detr_r50_8x2_150e_coco/detr_r50_8x2_150e_coco_20201130_194835-2c4b8974.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/detr/detr_r50_8x2_150e_coco/detr_r50_8x2_150e_coco_20201130_194835.log.json) |
+
+## Citation
+
+```latex
+@inproceedings{meng2021-CondDETR,
+ title = {Conditional DETR for Fast Training Convergence},
+ author = {Meng, Depu and Chen, Xiaokang and Fan, Zejia and Zeng, Gang and Li, Houqiang and Yuan, Yuhui and Sun, Lei and Wang, Jingdong},
+ booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
+ year = {2021}
+}
+```
diff --git a/configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py b/configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py
new file mode 100644
index 00000000000..a21476448d0
--- /dev/null
+++ b/configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py
@@ -0,0 +1,42 @@
+_base_ = ['../detr/detr_r50_8xb2-150e_coco.py']
+model = dict(
+ type='ConditionalDETR',
+ num_queries=300,
+ decoder=dict(
+ num_layers=6,
+ layer_cfg=dict(
+ self_attn_cfg=dict(
+ _delete_=True,
+ embed_dims=256,
+ num_heads=8,
+ attn_drop=0.1,
+ cross_attn=False),
+ cross_attn_cfg=dict(
+ _delete_=True,
+ embed_dims=256,
+ num_heads=8,
+ attn_drop=0.1,
+ cross_attn=True))),
+ bbox_head=dict(
+ type='ConditionalDETRHead',
+ loss_cls=dict(
+ _delete_=True,
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0)),
+ # training and testing settings
+ train_cfg=dict(
+ assigner=dict(
+ type='HungarianAssigner',
+ match_costs=[
+ dict(type='FocalLossCost', weight=2.0),
+ dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
+ dict(type='IoUCost', iou_mode='giou', weight=2.0)
+ ])))
+
+# learning policy
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=50, val_interval=1)
+
+param_scheduler = [dict(type='MultiStepLR', end=50, milestones=[40])]
diff --git a/configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco_class91.py b/configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco_class91.py
new file mode 100644
index 00000000000..eacf881085b
--- /dev/null
+++ b/configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco_class91.py
@@ -0,0 +1,48 @@
+_base_ = ['./conditional_detr_r50_8xb2-50e_coco.py']
+
+model = dict(bbox_head=dict(num_classes=91))
+
+metainfo = dict(
+ CLASSES=(None, 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', None,
+ 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+ 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+ None, 'backpack', 'umbrella', None, None, 'handbag', 'tie',
+ 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
+ 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
+ 'tennis racket', 'bottle', None, 'wine glass', 'cup', 'fork',
+ 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
+ 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
+ 'chair', 'couch', 'potted plant', 'bed', None, 'dining table',
+ None, None, 'toilet', None, 'tv', 'laptop', 'mouse', 'remote',
+ 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
+ 'refrigerator', None, 'book', 'clock', 'vase', 'scissors',
+ 'teddy bear', 'hair drier', 'toothbrush'),
+ PALETTE=[
+ None, (220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
+ (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192),
+ (250, 170, 30), (100, 170, 30), None, (220, 220, 0), (175, 116, 175),
+ (250, 0, 30), (165, 42, 42), (255, 77, 255), (0, 226, 252),
+ (182, 182, 255), (0, 82, 0), (120, 166, 157), (110, 76, 0),
+ (174, 57, 255), (199, 100, 0), (72, 0, 118), None, (255, 179, 240),
+ (0, 125, 92), None, None, (209, 0, 151),
+ (188, 208, 182), (0, 220, 176), (255, 99, 164), (92, 0, 73),
+ (133, 129, 255), (78, 180, 255), (0, 228, 0), (174, 255, 243),
+ (45, 89, 255), (134, 134, 103), (145, 148, 174), (255, 208, 186),
+ (197, 226, 255), None, (171, 134, 1), (109, 63, 54), (207, 138, 255),
+ (151, 0, 95), (9, 80, 61), (84, 105, 51),
+ (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65),
+ (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
+ (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
+ (163, 255, 0), (119, 0, 170), None, (0, 182, 199), None, None,
+ (0, 165, 120), None, (183, 130, 88), (95, 32, 0), (130, 114, 135),
+ (110, 129, 133), (166, 74, 118), (219, 142, 185), (79, 210, 114),
+ (178, 90, 62), (65, 70, 15), (127, 167, 115), (59, 105, 106), None,
+ (142, 108, 45), (196, 172, 0), (95, 54, 80), (128, 76, 255),
+ (201, 57, 1), (246, 0, 122), (191, 162, 208)
+ ] # Used for visualization.
+)
+
+train_dataloader = dict(dataset=dict(metainfo=metainfo))
+val_dataloader = dict(dataset=dict(metainfo=metainfo))
+test_dataloader = dict(dataset=dict(metainfo=metainfo))
diff --git a/configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py b/configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py
index 8f8dfb8ef03..51789b1f353 100644
--- a/configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py
+++ b/configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py
@@ -3,6 +3,10 @@
]
model = dict(
type='DeformableDETR',
+ num_queries=300,
+ num_feature_levels=4,
+ with_box_refine=False,
+ as_two_stage=False,
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
@@ -27,50 +31,31 @@
act_cfg=None,
norm_cfg=dict(type='GN', num_groups=32),
num_outs=4),
+ encoder=dict( # DeformableDetrTransformerEncoder
+ num_layers=6,
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
+ embed_dims=256),
+ ffn_cfg=dict(
+ embed_dims=256, feedforward_channels=1024, ffn_drop=0.1))),
+ decoder=dict( # DeformableDetrTransformerDecoder
+ num_layers=6,
+ return_intermediate=True,
+ layer_cfg=dict( # DeformableDetrTransformerDecoderLayer
+ self_attn_cfg=dict( # MultiheadAttention
+ embed_dims=256,
+ num_heads=8,
+ dropout=0.1),
+ cross_attn_cfg=dict( # MultiScaleDeformableAttention
+ embed_dims=256),
+ ffn_cfg=dict(
+ embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)),
+ post_norm_cfg=None),
+ positional_encoding_cfg=dict(num_feats=128, normalize=True, offset=-0.5),
bbox_head=dict(
type='DeformableDETRHead',
- num_query=300,
num_classes=80,
- in_channels=2048,
sync_cls_avg_factor=True,
- as_two_stage=False,
- transformer=dict(
- type='DeformableDetrTransformer',
- encoder=dict(
- type='DetrTransformerEncoder',
- num_layers=6,
- transformerlayers=dict(
- type='BaseTransformerLayer',
- attn_cfgs=dict(
- type='MultiScaleDeformableAttention', embed_dims=256),
- feedforward_channels=1024,
- ffn_dropout=0.1,
- operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
- decoder=dict(
- type='DeformableDetrTransformerDecoder',
- num_layers=6,
- return_intermediate=True,
- transformerlayers=dict(
- type='DetrTransformerDecoderLayer',
- attn_cfgs=[
- dict(
- type='MultiheadAttention',
- embed_dims=256,
- num_heads=8,
- dropout=0.1),
- dict(
- type='MultiScaleDeformableAttention',
- embed_dims=256)
- ],
- feedforward_channels=1024,
- ffn_dropout=0.1,
- operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
- 'ffn', 'norm')))),
- positional_encoding=dict(
- type='SinePositionalEncoding',
- num_feats=128,
- normalize=True,
- offset=-0.5),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
diff --git a/configs/deformable_detr/deformable-detr_refine_r50_16xb2-50e_coco.py b/configs/deformable_detr/deformable-detr_refine_r50_16xb2-50e_coco.py
index 8c31edb65cd..b968674f4a9 100644
--- a/configs/deformable_detr/deformable-detr_refine_r50_16xb2-50e_coco.py
+++ b/configs/deformable_detr/deformable-detr_refine_r50_16xb2-50e_coco.py
@@ -1,2 +1,2 @@
_base_ = 'deformable-detr_r50_16xb2-50e_coco.py'
-model = dict(bbox_head=dict(with_box_refine=True))
+model = dict(with_box_refine=True)
diff --git a/configs/deformable_detr/deformable-detr_refine_twostage_r50_16xb2-50e_coco.py b/configs/deformable_detr/deformable-detr_refine_twostage_r50_16xb2-50e_coco.py
index 466e8d5c0f5..8286189d4b9 100644
--- a/configs/deformable_detr/deformable-detr_refine_twostage_r50_16xb2-50e_coco.py
+++ b/configs/deformable_detr/deformable-detr_refine_twostage_r50_16xb2-50e_coco.py
@@ -1,2 +1,2 @@
_base_ = 'deformable-detr_refine_r50_16xb2-50e_coco.py'
-model = dict(bbox_head=dict(as_two_stage=True))
+model = dict(as_two_stage=True)
diff --git a/configs/detr/detr_r18_8xb2-500e_coco.py b/configs/detr/detr_r18_8xb2-500e_coco.py
index 44caf93545d..305b9d6fee8 100644
--- a/configs/detr/detr_r18_8xb2-500e_coco.py
+++ b/configs/detr/detr_r18_8xb2-500e_coco.py
@@ -4,4 +4,4 @@
backbone=dict(
depth=18,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18')),
- neck=dict(in_channels=[64, 128, 256, 512]))
+ neck=dict(in_channels=[512]))
diff --git a/configs/detr/detr_r50_8xb2-150e_coco.py b/configs/detr/detr_r50_8xb2-150e_coco.py
index 8c2ad57568a..9346eafa1ec 100644
--- a/configs/detr/detr_r50_8xb2-150e_coco.py
+++ b/configs/detr/detr_r50_8xb2-150e_coco.py
@@ -3,6 +3,7 @@
]
model = dict(
type='DETR',
+ num_queries=100,
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
@@ -19,45 +20,50 @@
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+ neck=dict(
+ type='ChannelMapper',
+ in_channels=[2048],
+ kernel_size=1,
+ out_channels=256,
+ act_cfg=None,
+ norm_cfg=None,
+ num_outs=1),
+ encoder=dict( # DetrTransformerEncoder
+ num_layers=6,
+ layer_cfg=dict( # DetrTransformerEncoderLayer
+ self_attn_cfg=dict( # MultiheadAttention
+ embed_dims=256,
+ num_heads=8,
+ dropout=0.1),
+ ffn_cfg=dict(
+ embed_dims=256,
+ feedforward_channels=2048,
+ num_fcs=2,
+ ffn_drop=0.1,
+ act_cfg=dict(type='ReLU', inplace=True)))),
+ decoder=dict( # DetrTransformerDecoder
+ num_layers=6,
+ layer_cfg=dict( # DetrTransformerDecoderLayer
+ self_attn_cfg=dict( # MultiheadAttention
+ embed_dims=256,
+ num_heads=8,
+ dropout=0.1),
+ cross_attn_cfg=dict( # MultiheadAttention
+ embed_dims=256,
+ num_heads=8,
+ dropout=0.1),
+ ffn_cfg=dict(
+ embed_dims=256,
+ feedforward_channels=2048,
+ num_fcs=2,
+ ffn_drop=0.1,
+ act_cfg=dict(type='ReLU', inplace=True))),
+ return_intermediate=True),
+ positional_encoding_cfg=dict(num_feats=128, normalize=True),
bbox_head=dict(
type='DETRHead',
num_classes=80,
- in_channels=2048,
- transformer=dict(
- type='Transformer',
- encoder=dict(
- type='DetrTransformerEncoder',
- num_layers=6,
- transformerlayers=dict(
- type='BaseTransformerLayer',
- attn_cfgs=[
- dict(
- type='MultiheadAttention',
- embed_dims=256,
- num_heads=8,
- dropout=0.1)
- ],
- feedforward_channels=2048,
- ffn_dropout=0.1,
- operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
- decoder=dict(
- type='DetrTransformerDecoder',
- return_intermediate=True,
- num_layers=6,
- transformerlayers=dict(
- type='DetrTransformerDecoderLayer',
- attn_cfgs=dict(
- type='MultiheadAttention',
- embed_dims=256,
- num_heads=8,
- dropout=0.1),
- feedforward_channels=2048,
- ffn_dropout=0.1,
- operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
- 'ffn', 'norm')),
- )),
- positional_encoding=dict(
- type='SinePositionalEncoding', num_feats=128, normalize=True),
+ embed_dims=256,
loss_cls=dict(
type='CrossEntropyLoss',
bg_cls_weight=0.1,
diff --git a/configs/group_detr/group_detr_r50_8xb2-50e_coco.py b/configs/group_detr/group_detr_r50_8xb2-50e_coco.py
new file mode 100644
index 00000000000..6e689270cbe
--- /dev/null
+++ b/configs/group_detr/group_detr_r50_8xb2-50e_coco.py
@@ -0,0 +1,41 @@
+_base_ = ['../detr/detr_r50_8xb2-150e_coco.py']
+group_detr = 11
+model = dict(
+ type='ConditionalDETR',
+ num_queries=300,
+ group_detr=group_detr,
+ decoder=dict(
+ num_layers=6,
+ layer_cfg=dict(
+ self_attn_cfg=dict(
+ embed_dims=256,
+ num_heads=8,
+ dropout=0.1,
+ cross_attn=False,
+ group_detr=group_detr),
+ cross_attn_cfg=dict(
+ embed_dims=256, num_heads=8, dropout=0.1, cross_attn=True))),
+ bbox_head=dict(
+ type='ConditionalDETRHead',
+ loss_cls=dict(
+ _delete_=True,
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0)),
+ # training and testing settings
+ train_cfg=dict(
+ assigner=dict(
+ type='GHungarianAssigner',
+ match_costs=[
+ dict(type='FocalLossCost', weight=2.0),
+ dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
+ dict(type='IoUCost', iou_mode='giou', weight=2.0)
+ ],
+ group_detr=group_detr)))
+
+# learning policy
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=50, val_interval=1)
+
+param_scheduler = [dict(type='MultiStepLR', end=50, milestones=[40])]
diff --git a/mmdet/datasets/api_wrappers/coco_api.py b/mmdet/datasets/api_wrappers/coco_api.py
index 40f7f2c9b93..fe6430d1198 100644
--- a/mmdet/datasets/api_wrappers/coco_api.py
+++ b/mmdet/datasets/api_wrappers/coco_api.py
@@ -30,7 +30,12 @@ def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None):
return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd)
def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]):
- return self.getCatIds(cat_names, sup_names, cat_ids)
+ cat_ids_coco = self.getCatIds(cat_names, sup_names, cat_ids)
+ index = [i for i, v in enumerate(cat_names) if v is not None]
+ cat_ids = list(range(len(cat_names)))
+ for i in range(len(index)):
+ cat_ids[index[i]] = cat_ids_coco[i]
+ return cat_ids
def get_img_ids(self, img_ids=[], cat_ids=[]):
return self.getImgIds(img_ids, cat_ids)
diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py
index 4aa95fc3ac1..03873fe6efb 100644
--- a/mmdet/models/dense_heads/__init__.py
+++ b/mmdet/models/dense_heads/__init__.py
@@ -7,6 +7,7 @@
from .centernet_head import CenterNetHead
from .centernet_update_head import CenterNetUpdateHead
from .centripetal_head import CentripetalHead
+from .conditional_detr_head import ConditionalDETRHead
from .corner_head import CornerHead
from .ddod_head import DDODHead
from .deformable_detr_head import DeformableDETRHead
@@ -56,5 +57,6 @@
'DeformableDETRHead', 'CenterNetHead', 'YOLOXHead', 'SOLOHead',
'DecoupledSOLOHead', 'DecoupledSOLOLightHead', 'SOLOV2Head', 'LADHead',
'TOODHead', 'MaskFormerHead', 'Mask2FormerHead', 'DDODHead',
- 'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead'
+ 'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead',
+ 'ConditionalDETRHead'
]
diff --git a/mmdet/models/dense_heads/conditional_detr_head.py b/mmdet/models/dense_heads/conditional_detr_head.py
new file mode 100644
index 00000000000..1914ee22366
--- /dev/null
+++ b/mmdet/models/dense_heads/conditional_detr_head.py
@@ -0,0 +1,168 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from mmengine.model import bias_init_with_prob
+from torch import Tensor
+
+from mmdet.models.layers.transformer import inverse_sigmoid
+from mmdet.registry import MODELS
+from mmdet.structures import SampleList
+from mmdet.utils import InstanceList
+from .detr_head import DETRHead
+
+
+@MODELS.register_module()
+class ConditionalDETRHead(DETRHead):
+ """Head of Conditional DETR. Conditional DETR: Conditional DETR for Fast
+ Training Convergence. More details can be found in the `paper.
+
+ `_ .
+ """
+
+ def init_weights(self):
+ """Initialize weights of the transformer head."""
+ super().init_weights()
+ # The initialization below for transformer head is very
+ # important as we use Focal_loss for loss_cls
+ if self.loss_cls.use_sigmoid:
+ bias_init = bias_init_with_prob(0.01)
+ nn.init.constant_(self.fc_cls.bias, bias_init)
+
+ def forward(self, hidden_states: Tensor,
+ references: Tensor) -> Tuple[Tensor, Tensor]:
+ """"Forward function.
+
+ Args:
+ hidden_states (Tensor): Features from transformer decoder. If
+ `return_intermediate_dec` in detr.py is True output has shape
+ (num_hidden_states, bs, num_queries, dim), else has shape (1,
+ bs, num_queries, dim) which only contains the last layer
+ outputs.
+ references (Tensor): References from transformer decoder,has
+ shape (1, bs, num_query, 2).
+ Returns:
+ tuple[Tensor]: results of head containing the following tensor.
+
+ - layers_cls_scores (Tensor): Outputs from the classification head,
+ shape (num_hidden_states, bs, num_queries, cls_out_channels).
+ Note cls_out_channels should include background.
+ - layers_bbox_preds (Tensor): Sigmoid outputs from the regression
+ head with normalized coordinate format (cx, cy, w, h), has shape
+ (num_hidden_states, bs, num_queries, 4).
+ """
+
+ references_unsigmoid = inverse_sigmoid(references)
+ layers_outputs_coords = []
+ for layer_id in range(hidden_states.shape[0]):
+ tmp_reg_preds = self.fc_reg(
+ self.activate(self.reg_ffn(hidden_states[layer_id])))
+ tmp_reg_preds[..., :2] += references_unsigmoid
+ outputs_coord = tmp_reg_preds.sigmoid()
+ layers_outputs_coords.append(outputs_coord)
+ layers_outputs_coords = torch.stack(layers_outputs_coords)
+
+ layers_cls_scores = self.fc_cls(hidden_states)
+ return layers_cls_scores, layers_outputs_coords
+
+ def loss(self, hidden_states: Tensor, references: Tensor,
+ batch_data_samples: SampleList) -> dict:
+ """Perform forward propagation and loss calculation of the detection
+ head on the features of the upstream network.
+
+ Args:
+ hidden_states (Tensor): Feature from the transformer decoder, has
+ shape (num_decoder_layers, bs, num_queries, dim)
+ references (Tensor): references from the transformer decoder, has
+ shape (num_decoder_layers, bs, num_queries, 2).
+ batch_data_samples (List[:obj:`DetDataSample`]): The Data
+ Samples. It usually includes information such as
+ `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+
+ Returns:
+ dict: A dictionary of loss components.
+ """
+ batch_gt_instances = []
+ batch_img_metas = []
+ for data_sample in batch_data_samples:
+ batch_img_metas.append(data_sample.metainfo)
+ batch_gt_instances.append(data_sample.gt_instances)
+
+ outs = self(hidden_states, references)
+ loss_inputs = outs + (batch_gt_instances, batch_img_metas)
+ losses = self.loss_by_feat(*loss_inputs)
+ return losses
+
+ def loss_and_predict(
+ self, hidden_states: Tensor, references: Tensor,
+ batch_data_samples: SampleList) -> Tuple[dict, InstanceList]:
+ """Perform forward propagation of the head, then calculate loss and
+ predictions from the features and data samples. Over-write because
+ img_metas are needed as inputs for bbox_head.
+
+ Args:
+ hidden_states (Tensor): Feature from the transformer decoder, has
+ shape (num_decoder_layers, bs, num_queries, dim).
+ references (Tensor): references from the transformer decoder, has
+ shape (num_decoder_layers, bs, num_queries, 2).
+ batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
+ the meta information of each image and corresponding
+ annotations.
+
+ Returns:
+ tuple: the return value is a tuple contains:
+
+ - losses: (dict[str, Tensor]): A dictionary of loss components.
+ - predictions (list[:obj:`InstanceData`]): Detection
+ results of each image after the post process.
+ """
+ batch_gt_instances = []
+ batch_img_metas = []
+ for data_sample in batch_data_samples:
+ batch_img_metas.append(data_sample.metainfo)
+ batch_gt_instances.append(data_sample.gt_instances)
+
+ outs = self(hidden_states, references)
+ loss_inputs = outs + (batch_gt_instances, batch_img_metas)
+ losses = self.loss_by_feat(*loss_inputs)
+
+ predictions = self.predict_by_feat(
+ *outs, batch_img_metas=batch_img_metas)
+ return losses, predictions
+
+ def predict(self,
+ hidden_states: Tensor,
+ references: Tensor,
+ batch_data_samples: SampleList,
+ rescale: bool = True) -> InstanceList:
+ """Perform forward propagation of the detection head and predict
+ detection results on the features of the upstream network. Over-write
+ because img_metas are needed as inputs for bbox_head.
+
+ Args:
+ hidden_states (Tensor): Feature from the transformer decoder, has
+ shape (num_decoder_layers, bs, num_queries, dim).
+ references (Tensor): references from the transformer decoder, has
+ shape (num_decoder_layers, bs, num_queries, 2).
+ batch_data_samples (List[:obj:`DetDataSample`]): The Data
+ Samples. It usually includes information such as
+ `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to True.
+
+ Returns:
+ list[obj:`InstanceData`]: Detection results of each image
+ after the post process.
+ """
+ batch_img_metas = [
+ data_samples.metainfo for data_samples in batch_data_samples
+ ]
+
+ last_layer_hidden_state = hidden_states[-1].unsqueeze(0)
+ outs = self(last_layer_hidden_state, references)
+
+ predictions = self.predict_by_feat(
+ *outs, batch_img_metas=batch_img_metas, rescale=rescale)
+
+ return predictions
diff --git a/mmdet/models/dense_heads/deformable_detr_head.py b/mmdet/models/dense_heads/deformable_detr_head.py
index 641a2839170..86821211118 100644
--- a/mmdet/models/dense_heads/deformable_detr_head.py
+++ b/mmdet/models/dense_heads/deformable_detr_head.py
@@ -4,22 +4,21 @@
import torch
import torch.nn as nn
-import torch.nn.functional as F
from mmcv.cnn import Linear
from mmengine.model import bias_init_with_prob, constant_init
from torch import Tensor
from mmdet.registry import MODELS
-from mmdet.utils import InstanceList, OptConfigType, OptInstanceList
+from mmdet.structures import SampleList
+from mmdet.utils import InstanceList, OptInstanceList
from ..layers import inverse_sigmoid
-from ..utils import multi_apply
from .detr_head import DETRHead
@MODELS.register_module()
class DeformableDETRHead(DETRHead):
- """Head of DeformDETR: Deformable DETR: Deformable Transformers for End-to-
- End Object Detection.
+ r"""Head of DeformDETR: Deformable DETR: Deformable Transformers for
+ End-to-End Object Detection.
Code is modified from the `official github repo
`_.
@@ -28,30 +27,28 @@ class DeformableDETRHead(DETRHead):
`_ .
Args:
- with_box_refine (bool): Whether to refine the reference points
- in the decoder. Defaults to False.
- as_two_stage (bool) : Whether to generate the proposal from
- the outputs of encoder.
- transformer (obj:`ConfigDict`): ConfigDict is used for building
- the Encoder and Decoder.
+ share_pred_layer (bool): Whether to share parameters for all the
+ prediction layers. Defaults to `False`.
+ num_pred_layer (int): The number of the prediction layers.
+ Defaults to 6.
+ as_two_stage (bool, optional): Whether to generate the proposal
+ from the outputs of encoder. Defaults to `False`.
"""
def __init__(self,
*args,
- with_box_refine: bool = False,
+ share_pred_layer: bool = False,
+ num_pred_layer: int = 6,
as_two_stage: bool = False,
- transformer: OptConfigType = None,
**kwargs) -> None:
- self.with_box_refine = with_box_refine
+ self.share_pred_layer = share_pred_layer
+ self.num_pred_layer = num_pred_layer
self.as_two_stage = as_two_stage
- if self.as_two_stage:
- transformer['as_two_stage'] = self.as_two_stage
- super().__init__(*args, transformer=transformer, **kwargs)
+ super().__init__(*args, **kwargs)
def _init_layers(self) -> None:
"""Initialize classification branch and regression branch of head."""
-
fc_cls = Linear(self.embed_dims, self.cls_out_channels)
reg_branch = []
for _ in range(self.num_reg_fcs):
@@ -60,31 +57,20 @@ def _init_layers(self) -> None:
reg_branch.append(Linear(self.embed_dims, 4))
reg_branch = nn.Sequential(*reg_branch)
- def _get_clones(module, N):
- return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
-
- # last reg_branch is used to generate proposal from
- # encode feature map when as_two_stage is True.
- num_pred = (self.transformer.decoder.num_layers + 1) if \
- self.as_two_stage else self.transformer.decoder.num_layers
-
- if self.with_box_refine:
- self.cls_branches = _get_clones(fc_cls, num_pred)
- self.reg_branches = _get_clones(reg_branch, num_pred)
- else:
-
+ if self.share_pred_layer:
self.cls_branches = nn.ModuleList(
- [fc_cls for _ in range(num_pred)])
+ [fc_cls for _ in range(self.num_pred_layer)])
self.reg_branches = nn.ModuleList(
- [reg_branch for _ in range(num_pred)])
-
- if not self.as_two_stage:
- self.query_embedding = nn.Embedding(self.num_query,
- self.embed_dims * 2)
+ [reg_branch for _ in range(self.num_pred_layer)])
+ else:
+ self.cls_branches = nn.ModuleList(
+ [copy.deepcopy(fc_cls) for _ in range(self.num_pred_layer)])
+ self.reg_branches = nn.ModuleList([
+ copy.deepcopy(reg_branch) for _ in range(self.num_pred_layer)
+ ])
def init_weights(self) -> None:
- """Initialize weights of the DeformDETR head."""
- self.transformer.init_weights()
+ """Initialize weights of the Deformable DETR head."""
if self.loss_cls.use_sigmoid:
bias_init = bias_init_with_prob(0.01)
for m in self.cls_branches:
@@ -96,121 +82,135 @@ def init_weights(self) -> None:
for m in self.reg_branches:
nn.init.constant_(m[-1].bias.data[2:], 0.0)
- def forward(self, x: Tuple[Tensor],
- batch_img_metas: List[dict]) -> Tuple[Tensor, ...]:
+ def forward(self, hidden_states: Tensor,
+ references: List[Tensor]) -> Tuple[Tensor]:
"""Forward function.
Args:
- x (tuple[Tensor]): Features from the upstream network, each is
- a 4D-tensor.
- batch_img_metas (list[dict]): Meta information of each image, e.g.,
- image size, scaling factor, etc.
+ hidden_states (Tensor): Hidden states output from each decoder
+ layer, has shape (num_decoder_layers, num_queries, bs, dim).
+ references (list[Tensor]): List of the reference from the decoder.
+ The first reference is the `init_reference` (initial) and the
+ other num_decoder_layers(6) references are `inter_references`
+ (intermediate). The `init_reference` has shape
+ (bs, num_queries, 4) when `as_two_stage` of the detector is
+ `True`, otherwise (bs, num_queries, 2). Each `inter_reference`
+ has shape (bs, num_queries, 4) when `with_box_refine` of the
+ detector is `True`, otherwise (bs, num_queries, 2).
Returns:
- tuple[Tensor]:
-
- - all_cls_scores (Tensor): Outputs from the classification head,
- shape [nb_dec, bs, num_query, cls_out_channels].
- - cls_out_channels should includes background.
- - all_bbox_preds (Tensor): Sigmoid outputs from the regression
- head with normalized coordinate format (cx, cy, w, h).
- Shape [nb_dec, bs, num_query, 4].
- - enc_outputs_class (Tensor): The score of each point on encode
- feature map, has shape (N, h*w, num_class). Only when
- as_two_stage is True it would be returned, otherwise `None`
- would be returned.
- - enc_outputs_coord (Tensor): The proposal generate from the
- encode feature map, has shape (N, h*w, 4). Only when
- as_two_stage is True it would be returned, otherwise `None`
- would be returned.
+ tuple[Tensor]: results of head containing the following tensor.
+
+ - all_layers_outputs_classes (Tensor): Outputs from the
+ classification head, has shape (num_decoder_layers, bs,
+ num_queries, cls_out_channels).
+ - all_layers_outputs_coords (Tensor): Sigmoid outputs from the
+ regression head with normalized coordinate format (cx, cy, w,
+ h), has shape (num_decoder_layers, bs, num_queries, 4).
"""
-
- batch_size = x[0].size(0)
- input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape']
- img_masks = x[0].new_ones((batch_size, input_img_h, input_img_w))
- for img_id in range(batch_size):
- img_h, img_w = batch_img_metas[img_id]['img_shape']
- img_masks[img_id, :img_h, :img_w] = 0
-
- mlvl_masks = []
- mlvl_positional_encodings = []
- for feat in x:
- mlvl_masks.append(
- F.interpolate(img_masks[None],
- size=feat.shape[-2:]).to(torch.bool).squeeze(0))
- mlvl_positional_encodings.append(
- self.positional_encoding(mlvl_masks[-1]))
-
- query_embeds = None
- if not self.as_two_stage:
- query_embeds = self.query_embedding.weight
- hs, init_reference, inter_references, \
- enc_outputs_class, enc_outputs_coord = self.transformer(
- x,
- mlvl_masks,
- query_embeds,
- mlvl_positional_encodings,
- reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501
- cls_branches=self.cls_branches if self.as_two_stage else None # noqa:E501
- )
- hs = hs.permute(0, 2, 1, 3)
- outputs_classes = []
- outputs_coords = []
-
- for lvl in range(hs.shape[0]):
- if lvl == 0:
- reference = init_reference
- else:
- reference = inter_references[lvl - 1]
- reference = inverse_sigmoid(reference)
- outputs_class = self.cls_branches[lvl](hs[lvl])
- tmp = self.reg_branches[lvl](hs[lvl])
+ # (num_decoder_layers, bs, num_queries, dim)
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
+ all_layers_outputs_classes = []
+ all_layers_outputs_coords = []
+
+ for layer_id in range(hidden_states.shape[0]):
+ reference = inverse_sigmoid(references[layer_id])
+ # NOTE The last reference will not be used.
+ hidden_state = hidden_states[layer_id]
+ outputs_class = self.cls_branches[layer_id](hidden_state)
+ tmp_reg_preds = self.reg_branches[layer_id](hidden_state)
if reference.shape[-1] == 4:
- tmp += reference
+ # When `layer` is 0 and `as_two_stage` of the detector
+ # is `True`, or when `layer` is greater than 0 and
+ # `with_box_refine` of the detector is `True`.
+ tmp_reg_preds += reference
else:
+ # When `layer` is 0 and `as_two_stage` of the detector
+ # is `False`, or when `layer` is greater than 0 and
+ # `with_box_refine` of the detector is `False`.
assert reference.shape[-1] == 2
- tmp[..., :2] += reference
- outputs_coord = tmp.sigmoid()
- outputs_classes.append(outputs_class)
- outputs_coords.append(outputs_coord)
+ tmp_reg_preds[..., :2] += reference
+ outputs_coord = tmp_reg_preds.sigmoid()
+ all_layers_outputs_classes.append(outputs_class)
+ all_layers_outputs_coords.append(outputs_coord)
- outputs_classes = torch.stack(outputs_classes)
- outputs_coords = torch.stack(outputs_coords)
- if self.as_two_stage:
- return outputs_classes, outputs_coords, \
- enc_outputs_class, \
- enc_outputs_coord.sigmoid()
- else:
- return outputs_classes, outputs_coords, \
- None, None
+ all_layers_outputs_classes = torch.stack(all_layers_outputs_classes)
+ all_layers_outputs_coords = torch.stack(all_layers_outputs_coords)
+
+ return all_layers_outputs_classes, all_layers_outputs_coords
+
+ def loss(self, hidden_states: Tensor, references: List[Tensor],
+ enc_outputs_class: Tensor, enc_outputs_coord: Tensor,
+ batch_data_samples: SampleList) -> dict:
+ """Perform forward propagation and loss calculation of the detection
+ head on the queries of the upstream network.
+
+ Args:
+ hidden_states (Tensor): Hidden states output from each decoder
+ layer, has shape (num_decoder_layers, num_queries, bs, dim).
+ references (list[Tensor]): List of the reference from the decoder.
+ The first reference is the `init_reference` (initial) and the
+ other num_decoder_layers(6) references are `inter_references`
+ (intermediate). The `init_reference` has shape
+ (bs, num_queries, 4) when `as_two_stage` of the detector is
+ `True`, otherwise (bs, num_queries, 2). Each `inter_reference`
+ has shape (bs, num_queries, 4) when `with_box_refine` of the
+ detector is `True`, otherwise (bs, num_queries, 2).
+ enc_outputs_class (Tensor): The score of each point on encode
+ feature map, has shape (bs, num_feat, cls_out_channels).
+ Only when `as_two_stage` is `True` it would be returned,
+ otherwise `None` would be returned.
+ enc_outputs_coord (Tensor): The proposal generate from the
+ encode feature map, has shape (bs, num_feat, 4). Only when
+ `as_two_stage` is `True` it would be returned, otherwise
+ `None` would be returned.
+ batch_data_samples (list[:obj:`DetDataSample`]): The Data
+ Samples. It usually includes information such as
+ `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+
+ Returns:
+ dict: A dictionary of loss components.
+ """
+ batch_gt_instances = []
+ batch_img_metas = []
+ for data_sample in batch_data_samples:
+ batch_img_metas.append(data_sample.metainfo)
+ batch_gt_instances.append(data_sample.gt_instances)
+
+ outs = self(hidden_states, references)
+ loss_inputs = outs + (enc_outputs_class, enc_outputs_coord,
+ batch_gt_instances, batch_img_metas)
+ losses = self.loss_by_feat(*loss_inputs)
+ return losses
def loss_by_feat(
self,
- all_cls_scores: Tensor,
- all_bbox_preds: Tensor,
+ all_layers_cls_scores: Tensor,
+ all_layers_bbox_preds: Tensor,
enc_cls_scores: Tensor,
enc_bbox_preds: Tensor,
batch_gt_instances: InstanceList,
batch_img_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None
) -> Dict[str, Tensor]:
- """"Loss function.
+ """Loss function.
Args:
- all_cls_scores (Tensor): Classification score of all
- decoder layers, has shape
- [nb_dec, bs, num_query, cls_out_channels].
- all_bbox_preds (Tensor): Sigmoid regression
- outputs of all decode layers. Each is a 4D-tensor with
- normalized coordinate format (cx, cy, w, h) and shape
- [nb_dec, bs, num_query, 4].
- enc_cls_scores (Tensor): Classification scores of
- points on encode feature map , has shape
- (N, h*w, num_classes). Only be passed when as_two_stage is
- True, otherwise is None.
- enc_bbox_preds (Tensor): Regression results of each points
- on the encode feature map, has shape (N, h*w, 4). Only be
- passed when as_two_stage is True, otherwise is None.
+ all_layers_cls_scores (Tensor): Classification scores of all
+ decoder layers, has shape (num_decoder_layers, bs, num_queries,
+ cls_out_channels).
+ all_layers_bbox_preds (Tensor): Regression outputs of all decode
+ layers. Each is a 4D-tensor with normalized coordinate format
+ (cx, cy, w, h) and has shape (num_decoder_layers, bs,
+ num_queries, 4).
+ enc_cls_scores (Tensor): The score of each point on encode
+ feature map, has shape (bs, num_feat, cls_out_channels).
+ Only when `as_two_stage` is `True` it would be returned,
+ otherwise `None` would be returned.
+ enc_bbox_preds (Tensor): The proposal generate from the
+ encode feature map, has shape (bs, num_feat, 4). Only when
+ `as_two_stage` is `True` it would be returned, otherwise
+ `None` would be returned.
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
@@ -224,87 +224,94 @@ def loss_by_feat(
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
- assert batch_gt_instances_ignore is None, \
- f'{self.__class__.__name__} only supports ' \
- f'for gt_bboxes_ignore setting to None.'
+ loss_dict = super().loss_by_feat(all_layers_cls_scores,
+ all_layers_bbox_preds,
+ batch_gt_instances, batch_img_metas,
+ batch_gt_instances_ignore)
- num_dec_layers = len(all_cls_scores)
- batch_gt_instances_list = [
- batch_gt_instances for _ in range(num_dec_layers)
- ]
- batch_img_metas_list = [batch_img_metas for _ in range(num_dec_layers)]
-
- losses_cls, losses_bbox, losses_iou = multi_apply(
- self.loss_by_feat_single, all_cls_scores, all_bbox_preds,
- batch_gt_instances_list, batch_img_metas_list)
-
- loss_dict = dict()
# loss of proposal generated from encode feature map.
if enc_cls_scores is not None:
- for i in range(len(batch_img_metas)):
- batch_gt_instances[i].labels = torch.zeros_like(
- batch_gt_instances[i].labels)
+ proposal_gt_instances = copy.deepcopy(batch_gt_instances)
+ for i in range(len(proposal_gt_instances)):
+ proposal_gt_instances[i].labels = torch.zeros_like(
+ proposal_gt_instances[i].labels)
enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
- self.loss_single(enc_cls_scores, enc_bbox_preds,
- batch_gt_instances, batch_img_metas)
+ self.loss_by_feat_single(
+ enc_cls_scores, enc_bbox_preds,
+ batch_gt_instances=proposal_gt_instances,
+ batch_img_metas=batch_img_metas)
loss_dict['enc_loss_cls'] = enc_loss_cls
loss_dict['enc_loss_bbox'] = enc_losses_bbox
loss_dict['enc_loss_iou'] = enc_losses_iou
-
- # loss from the last decoder layer
- loss_dict['loss_cls'] = losses_cls[-1]
- loss_dict['loss_bbox'] = losses_bbox[-1]
- loss_dict['loss_iou'] = losses_iou[-1]
- # loss from other decoder layers
- num_dec_layer = 0
- for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
- losses_bbox[:-1],
- losses_iou[:-1]):
- loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
- loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
- loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
- num_dec_layer += 1
return loss_dict
+ def predict(self,
+ hidden_states: Tensor,
+ references: List[Tensor],
+ batch_data_samples: SampleList,
+ rescale: bool = True,
+ **kwargs) -> InstanceList:
+ """Perform forward propagation and loss calculation of the detection
+ head on the queries of the upstream network.
+
+ Args:
+ hidden_states (Tensor): Hidden states output from each decoder
+ layer, has shape (num_decoder_layers, num_queries, bs, dim).
+ references (list[Tensor]): List of the reference from the decoder.
+ The first reference is the `init_reference` (initial) and the
+ other num_decoder_layers(6) references are `inter_references`
+ (intermediate). The `init_reference` has shape
+ (bs, num_queries, 4) when `as_two_stage` of the detector
+ is `True`, otherwise (bs, num_queries, 2). Each
+ `inter_reference` has shape (bs, num_queries, 4) when
+ `with_box_refine` of the detector is `True`, otherwise
+ (bs, num_queries, 2).
+ batch_data_samples (list[:obj:`DetDataSample`]): The Data
+ Samples. It usually includes information such as
+ `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+ rescale (bool, optional): If `True`, return boxes in original
+ image space. Defaults to `True`.
+
+ Returns:
+ list[obj:`InstanceData`]: Detection results of each image
+ after the post process.
+ """
+ batch_img_metas = [
+ data_samples.metainfo for data_samples in batch_data_samples
+ ]
+
+ outs = self(hidden_states, references)
+
+ predictions = self.predict_by_feat(
+ *outs, batch_img_metas=batch_img_metas, rescale=rescale)
+ return predictions
+
def predict_by_feat(self,
- all_cls_scores: Tensor,
- all_bbox_preds: Tensor,
- enc_cls_scores: Tensor,
- enc_bbox_preds: Tensor,
+ all_layers_cls_scores: Tensor,
+ all_layers_bbox_preds: Tensor,
batch_img_metas: List[Dict],
rescale: bool = False) -> InstanceList:
"""Transform a batch of output features extracted from the head into
bbox results.
Args:
- all_cls_scores (Tensor): Classification score of all
- decoder layers, has shape
- [nb_dec, bs, num_query, cls_out_channels].
- all_bbox_preds (Tensor): Sigmoid regression
- outputs of all decode layers. Each is a 4D-tensor with
- normalized coordinate format (cx, cy, w, h) and shape
- [nb_dec, bs, num_query, 4].
- enc_cls_scores (Tensor): Classification scores of
- points on encode feature map , has shape
- (N, h*w, num_classes). Only be passed when as_two_stage is
- True, otherwise is None.
- enc_bbox_preds (Tensor): Regression results of each points
- on the encode feature map, has shape (N, h*w, 4). Only be
- passed when as_two_stage is True, otherwise is None.
+ all_layers_cls_scores (Tensor): Classification scores of all
+ decoder layers, has shape (num_decoder_layers, bs, num_queries,
+ cls_out_channels).
+ all_layers_bbox_preds (Tensor): Regression outputs of all decode
+ layers. Each is a 4D-tensor with normalized coordinate format
+ (cx, cy, w, h) and shape (num_decoder_layers, bs,
+ num_queries, 4).
batch_img_metas (list[dict]): Meta information of each image.
- rescale (bool, optional): If True, return boxes in original
- image space. Default False.
+ rescale (bool, optional): If `True`, return boxes in original
+ image space. Default `False`.
Returns:
- list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. \
- The first item is an (n, 5) tensor, where the first 4 columns \
- are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
- 5-th column is a score between 0 and 1. The second item is a \
- (n,) tensor where each item is the predicted class label of \
- the corresponding box.
+ list[obj:`InstanceData`]: Detection results of each image
+ after the post process.
"""
- cls_scores = all_cls_scores[-1]
- bbox_preds = all_bbox_preds[-1]
+ cls_scores = all_layers_cls_scores[-1]
+ bbox_preds = all_layers_bbox_preds[-1]
result_list = []
for img_id in range(len(batch_img_metas)):
diff --git a/mmdet/models/dense_heads/detr_head.py b/mmdet/models/dense_heads/detr_head.py
index acf82a4c7bd..f42e45d2ed0 100644
--- a/mmdet/models/dense_heads/detr_head.py
+++ b/mmdet/models/dense_heads/detr_head.py
@@ -1,55 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
-from mmcv.cnn import Conv2d, Linear, build_activation_layer
-from mmcv.cnn.bricks.transformer import FFN, build_positional_encoding
+from mmcv.cnn import Linear
+from mmcv.cnn.bricks.transformer import FFN
+from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
-from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
- OptInstanceList, OptMultiConfig, reduce_mean)
+from mmdet.utils import (ConfigType, InstanceList, OptInstanceList,
+ OptMultiConfig, reduce_mean)
from ..utils import multi_apply
-from .anchor_free_head import AnchorFreeHead
@MODELS.register_module()
-class DETRHead(AnchorFreeHead):
- """Implements the DETR transformer head.
+class DETRHead(BaseModule):
+ r"""Head of DETR. DETR:End-to-End Object Detection with Transformers.
- See `paper: End-to-End Object Detection with Transformers
- `_ for details.
+ More details can be found in the `paper
+ `_ .
Args:
num_classes (int): Number of categories excluding the background.
- in_channels (int): Number of channels in the input feature map.
- num_query (int): Number of query in Transformer. Defaults to 100.
- num_reg_fcs (int): Number of fully-connected layers used in
- `FFN`, which is then used for the regression head.
- Defaults to 2.
- transformer (:obj:`ConfigDict` or dict, optional): Config for
- transformer. Defaults to None.
- sync_cls_avg_factor (bool): Whether to sync the avg_factor of all
- ranks. Defaults to False.
- positional_encoding (:obj:`ConfigDict` or dict): Config for position
- encoding.
+ embed_dims (int): The dims of Transformer embedding.
+ num_reg_fcs (int): Number of fully-connected layers used in `FFN`,
+ which is then used for the regression head. Defaults to 2.
+ sync_cls_avg_factor (bool): Whether to sync the `avg_factor` of
+ all ranks. Default to `False`.
loss_cls (:obj:`ConfigDict` or dict): Config of the classification
loss. Defaults to `CrossEntropyLoss`.
- loss_bbox (:obj:`ConfigDict` or dict): Config of the regression loss.
- Defaults to `L1Loss`.
+ loss_bbox (:obj:`ConfigDict` or dict): Config of the regression bbox
+ loss. Defaults to `L1Loss`.
loss_iou (:obj:`ConfigDict` or dict): Config of the regression iou
loss. Defaults to `GIoULoss`.
- tran_cfg (:obj:`ConfigDict` or dict): Training config of transformer
+ train_cfg (:obj:`ConfigDict` or dict): Training config of transformer
head.
test_cfg (:obj:`ConfigDict` or dict): Testing config of transformer
head.
- init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
- dict], optional): Initialization config dict. Defaults to None.
+ init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
+ the initialization. Defaults to None.
"""
_version = 2
@@ -57,13 +51,9 @@ class DETRHead(AnchorFreeHead):
def __init__(
self,
num_classes: int,
- in_channels: int,
- num_query: int = 100,
+ embed_dims: int = 256,
num_reg_fcs: int = 2,
- transformer: OptConfigType = None,
sync_cls_avg_factor: bool = False,
- positional_encoding: ConfigType = dict(
- type='SinePositionalEncoding', num_feats=128, normalize=True),
loss_cls: ConfigType = dict(
type='CrossEntropyLoss',
bg_cls_weight=0.1,
@@ -82,10 +72,7 @@ def __init__(
])),
test_cfg: ConfigType = dict(max_per_img=100),
init_cfg: OptMultiConfig = None) -> None:
- # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
- # since it brings inconvenience when the initialization of
- # `AnchorFreeHead` is called.
- super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
+ super().__init__(init_cfg=init_cfg)
self.bg_cls_weight = 0
self.sync_cls_avg_factor = sync_cls_avg_factor
class_weight = loss_cls.get('class_weight', None)
@@ -108,15 +95,14 @@ def __init__(
self.bg_cls_weight = bg_cls_weight
if train_cfg:
- assert 'assigner' in train_cfg, 'assigner should be provided '\
- 'when train_cfg is set.'
+ assert 'assigner' in train_cfg, 'assigner should be provided ' \
+ 'when train_cfg is set.'
assigner = train_cfg['assigner']
self.assigner = TASK_UTILS.build(assigner)
if train_cfg.get('sampler', None) is not None:
raise RuntimeError('DETR do not build sampler.')
- self.num_query = num_query
self.num_classes = num_classes
- self.in_channels = in_channels
+ self.embed_dims = embed_dims
self.num_reg_fcs = num_reg_fcs
self.train_cfg = train_cfg
self.test_cfg = test_cfg
@@ -128,153 +114,85 @@ def __init__(
self.cls_out_channels = num_classes
else:
self.cls_out_channels = num_classes + 1
- self.act_cfg = transformer.get('act_cfg',
- dict(type='ReLU', inplace=True))
- self.activate = build_activation_layer(self.act_cfg)
- self.positional_encoding = build_positional_encoding(
- positional_encoding)
- self.transformer = MODELS.build(transformer)
- self.embed_dims = self.transformer.embed_dims
- assert 'num_feats' in positional_encoding
- num_feats = positional_encoding['num_feats']
- assert num_feats * 2 == self.embed_dims, 'embed_dims should' \
- f' be exactly 2 times of num_feats. Found {self.embed_dims}' \
- f' and {num_feats}.'
+
self._init_layers()
def _init_layers(self) -> None:
"""Initialize layers of the transformer head."""
- self.input_proj = Conv2d(
- self.in_channels, self.embed_dims, kernel_size=1)
+ # cls branch
self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
+ # reg branch
+ self.activate = nn.ReLU()
self.reg_ffn = FFN(
self.embed_dims,
self.embed_dims,
self.num_reg_fcs,
- self.act_cfg,
+ dict(type='ReLU', inplace=True),
dropout=0.0,
add_residual=False)
+ # NOTE the activations of reg_branch here is the same as
+ # those in transformer, but they are actually different
+ # in DAB DETR (prelu in transformer and relu in reg_branch)
self.fc_reg = Linear(self.embed_dims, 4)
- self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)
-
- def init_weights(self) -> None:
- """Initialize weights of the transformer head."""
- # The initialization for transformer is important
- self.transformer.init_weights()
-
- def _load_from_state_dict(self, state_dict: dict, prefix: str,
- local_metadata: dict, strict: bool,
- missing_keys: Union[List[str], str],
- unexpected_keys: Union[List[str], str],
- error_msgs: Union[List[str], str]) -> None:
- """load checkpoints."""
- # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
- # since `AnchorFreeHead._load_from_state_dict` should not be
- # called here. Invoking the default `Module._load_from_state_dict`
- # is enough.
-
- # Names of some parameters in has been changed.
- version = local_metadata.get('version', None)
- if (version is None or version < 2) and self.__class__ is DETRHead:
- convert_dict = {
- '.self_attn.': '.attentions.0.',
- '.ffn.': '.ffns.0.',
- '.multihead_attn.': '.attentions.1.',
- '.decoder.norm.': '.decoder.post_norm.'
- }
- state_dict_keys = list(state_dict.keys())
- for k in state_dict_keys:
- for ori_key, convert_key in convert_dict.items():
- if ori_key in k:
- convert_key = k.replace(ori_key, convert_key)
- state_dict[convert_key] = state_dict[k]
- del state_dict[k]
-
- super(AnchorFreeHead, self)._load_from_state_dict(
- state_dict=state_dict,
- prefix=prefix,
- local_metadata=local_metadata,
- strict=strict,
- missing_keys=missing_keys,
- unexpected_keys=unexpected_keys,
- error_msgs=error_msgs)
-
- def forward(
- self, x: Tuple[Tensor],
- batch_img_metas: List[dict]) -> Tuple[List[Tensor], List[Tensor]]:
- """Forward function.
- Args:
- x (tuple[Tensor]): Features from the upstream network, each is
- a 4D-tensor.
- batch_img_metas (list[dict]): Meta information of each image, e.g.,
- image size, scaling factor, etc.
+ # Note function _load_from_state_dict is deleted without
+ # supporting refactor-DETR in mmdetection2.0
+ def forward(self, hidden_states: Tensor) -> Tuple[Tensor]:
+ """"Forward function.
+
+ Args:
+ hidden_states (Tensor): Features from transformer decoder. If
+ `return_intermediate_dec` in detr.py is True output has shape
+ (num_hidden_states, bs, num_queries, dim), else has shape (1,
+ bs, num_queries, dim) which only contains the last layer
+ outputs.
Returns:
- tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.
-
- - all_cls_scores_list (list[Tensor]): Classification scores \
- for each scale level. Each is a 4D-tensor with shape \
- [nb_dec, bs, num_query, cls_out_channels]. Note \
- `cls_out_channels` should includes background.
- - all_bbox_preds_list (list[Tensor]): Sigmoid regression \
- outputs for each scale level. Each is a 4D-tensor with \
- normalized coordinate format (cx, cy, w, h) and shape \
- [nb_dec, bs, num_query, 4].
+ tuple[Tensor]: results of head containing the following tensor.
+
+ - layers_cls_scores (Tensor): Outputs from the classification head,
+ shape (num_hidden_states, bs, num_queries, cls_out_channels).
+ Note cls_out_channels should include background.
+ - layers_bbox_preds (Tensor): Sigmoid outputs from the regression
+ head with normalized coordinate format (cx, cy, w, h), has shape
+ (num_hidden_states, bs, num_queries, 4).
"""
- num_levels = len(x)
- batch_img_metas_list = [batch_img_metas for _ in range(num_levels)]
- return multi_apply(self.forward_single, x, batch_img_metas_list)
+ layers_cls_scores = self.fc_cls(hidden_states)
+ layers_bbox_preds = self.fc_reg(
+ self.activate(self.reg_ffn(hidden_states))).sigmoid()
+ return layers_cls_scores, layers_bbox_preds
- def forward_single(self, x: Tensor,
- batch_img_metas: List[dict]) -> Tuple[Tensor, Tensor]:
- """"Forward function for a single feature level.
+ def loss(self, hidden_states: Tensor,
+ batch_data_samples: SampleList) -> dict:
+ """Perform forward propagation and loss calculation of the detection
+ head on the features of the upstream network.
Args:
- x (Tensor): Input feature from backbone's single stage, shape
- [bs, c, h, w].
- batch_img_metas (list[dict]): Meta information of each image, e.g.,
- image size, scaling factor, etc.
+ hidden_states (Tensor): Feature from the transformer decoder, has
+ shape (num_decoder_layers, bs, num_queries, cls_out_channels)
+ or (num_decoder_layers, num_queries, bs, cls_out_channels).
+ batch_data_samples (List[:obj:`DetDataSample`]): The Data
+ Samples. It usually includes information such as
+ `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
Returns:
- tuple[Tensor]:
-
- - all_cls_scores (Tensor): Outputs from the classification head, \
- shape [nb_dec, bs, num_query, cls_out_channels]. Note \
- cls_out_channels should includes background.
- - all_bbox_preds (Tensor): Sigmoid outputs from the regression \
- head with normalized coordinate format (cx, cy, w, h). \
- Shape [nb_dec, bs, num_query, 4].
+ dict: A dictionary of loss components.
"""
- # construct binary masks which used for the transformer.
- # NOTE following the official DETR repo, non-zero values representing
- # ignored positions, while zero values means valid positions.
- batch_size = x.size(0)
- input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape']
- masks = x.new_ones((batch_size, input_img_h, input_img_w))
- for img_id in range(batch_size):
- img_h, img_w, = batch_img_metas[img_id]['img_shape']
- masks[img_id, :img_h, :img_w] = 0
-
- x = self.input_proj(x)
- # interpolate masks to have the same spatial shape with x
- masks = F.interpolate(
- masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1)
- # position encoding
- pos_embed = self.positional_encoding(masks) # [bs, embed_dim, h, w]
- # outs_dec: [nb_dec, bs, num_query, embed_dim]
- outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
- pos_embed)
-
- all_cls_scores = self.fc_cls(outs_dec)
- all_bbox_preds = self.fc_reg(self.activate(
- self.reg_ffn(outs_dec))).sigmoid()
- return all_cls_scores, all_bbox_preds
+ batch_gt_instances = []
+ batch_img_metas = []
+ for data_sample in batch_data_samples:
+ batch_img_metas.append(data_sample.metainfo)
+ batch_gt_instances.append(data_sample.gt_instances)
+
+ outs = self(hidden_states)
+ loss_inputs = outs + (batch_gt_instances, batch_img_metas)
+ losses = self.loss_by_feat(*loss_inputs)
+ return losses
def loss_by_feat(
self,
- all_cls_scores_list: List[Tensor],
- all_bbox_preds_list: List[Tensor],
+ all_layers_cls_scores: Tensor,
+ all_layers_bbox_preds: Tensor,
batch_gt_instances: InstanceList,
batch_img_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None
@@ -285,13 +203,13 @@ def loss_by_feat(
losses by default.
Args:
- all_cls_scores_list (list[Tensor]): Classification outputs
- for each feature level. Each is a 4D-tensor with shape
- [nb_dec, bs, num_query, cls_out_channels].
- all_bbox_preds_list (list[Tensor]): Sigmoid regression
- outputs for each feature level. Each is a 4D-tensor with
+ all_layers_cls_scores (Tensor): Classification outputs
+ of each decoder layers. Each is a 4D-tensor, has shape
+ (num_decoder_layers, bs, num_queries, cls_out_channels).
+ all_layers_bbox_preds (Tensor): Sigmoid regression
+ outputs of each decoder layers. Each is a 4D-tensor with
normalized coordinate format (cx, cy, w, h) and shape
- [nb_dec, bs, num_query, 4].
+ (num_decoder_layers, bs, num_queries, 4).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
@@ -305,21 +223,16 @@ def loss_by_feat(
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
- # NOTE defaultly only the outputs from the last feature scale is used.
- all_cls_scores = all_cls_scores_list[-1]
- all_bbox_preds = all_bbox_preds_list[-1]
assert batch_gt_instances_ignore is None, \
- 'Only supports for batch_gt_instances_ignore setting to None.'
-
- num_dec_layers = len(all_cls_scores)
- batch_gt_instances_list = [
- batch_gt_instances for _ in range(num_dec_layers)
- ]
- batch_img_metas_list = [batch_img_metas for _ in range(num_dec_layers)]
+ f'{self.__class__.__name__} only supports ' \
+ 'for batch_gt_instances_ignore setting to None.'
losses_cls, losses_bbox, losses_iou = multi_apply(
- self.loss_by_feat_single, all_cls_scores, all_bbox_preds,
- batch_gt_instances_list, batch_img_metas_list)
+ self.loss_by_feat_single,
+ all_layers_cls_scores,
+ all_layers_bbox_preds,
+ batch_gt_instances=batch_gt_instances,
+ batch_img_metas=batch_img_metas)
loss_dict = dict()
# loss from the last decoder layer
@@ -328,9 +241,8 @@ def loss_by_feat(
loss_dict['loss_iou'] = losses_iou[-1]
# loss from other decoder layers
num_dec_layer = 0
- for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
- losses_bbox[:-1],
- losses_iou[:-1]):
+ for loss_cls_i, loss_bbox_i, loss_iou_i in \
+ zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1]):
loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
@@ -345,10 +257,10 @@ def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor,
Args:
cls_scores (Tensor): Box score logits from a single decoder layer
- for all images. Shape [bs, num_query, cls_out_channels].
+ for all images, has shape (bs, num_queries, cls_out_channels).
bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
for all images, with normalized coordinate (cx, cy, w, h) and
- shape [bs, num_query, 4].
+ shape (bs, num_queries, 4).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
@@ -356,8 +268,8 @@ def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor,
image size, scaling factor, etc.
Returns:
- Tupe[Tensor]: A tuple includes loss_cls, loss_box and
- loss_iou.
+ Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and
+ `loss_iou`.
"""
num_imgs = cls_scores.size(0)
cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
@@ -425,11 +337,11 @@ def get_targets(self, cls_scores_list: List[Tensor],
Args:
cls_scores_list (list[Tensor]): Box score logits from a single
- decoder layer for each image with shape [num_query,
+ decoder layer for each image, has shape [num_queries,
cls_out_channels].
bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
decoder layer for each image, with normalized coordinate
- (cx, cy, w, h) and shape [num_query, 4].
+ (cx, cy, w, h) and shape [num_queries, 4].
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
@@ -465,10 +377,10 @@ def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor,
Args:
cls_score (Tensor): Box score logits from a single decoder layer
- for one image. Shape [num_query, cls_out_channels].
+ for one image. Shape [num_queries, cls_out_channels].
bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
for one image, with normalized coordinate (cx, cy, w, h) and
- shape [num_query, 4].
+ shape [num_queries, 4].
gt_instances (:obj:`InstanceData`): Ground truth of instance
annotations. It should includes ``bboxes`` and ``labels``
attributes.
@@ -529,49 +441,18 @@ def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor,
return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
neg_inds)
- # over-write because img_metas are needed as inputs for bbox_head.
- def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
- """Perform forward propagation and loss calculation of the detection
- head on the features of the upstream network.
-
- Args:
- x (tuple[Tensor]): Features from the upstream network, each is
- a 4D-tensor.
- batch_data_samples (List[:obj:`DetDataSample`]): The Data
- Samples. It usually includes information such as
- `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
-
- Returns:
- dict: A dictionary of loss components.
- """
- batch_gt_instances = []
- batch_img_metas = []
- for data_sample in batch_data_samples:
- batch_img_metas.append(data_sample.metainfo)
- batch_gt_instances.append(data_sample.gt_instances)
-
- outs = self(x, batch_img_metas)
- loss_inputs = outs + (batch_gt_instances, batch_img_metas)
- losses = self.loss_by_feat(*loss_inputs)
- return losses
-
- def loss_and_predict(self,
- x: Tuple[Tensor],
- batch_data_samples: SampleList,
- proposal_cfg: Optional[ConfigType] = None) \
- -> Tuple[dict, InstanceList]:
+ def loss_and_predict(
+ self, hidden_states: Tuple[Tensor],
+ batch_data_samples: SampleList) -> Tuple[dict, InstanceList]:
"""Perform forward propagation of the head, then calculate loss and
predictions from the features and data samples. Over-write because
img_metas are needed as inputs for bbox_head.
Args:
- x (tuple[Tensor]): Features from FPN.
+ hidden_states (tuple[Tensor]): Features from FPN.
batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
the meta information of each image and corresponding
annotations.
- proposal_cfg (ConfigDict, optional): Test / postprocessing
- configuration, if None, test_cfg would be used.
- Defaults to None.
Returns:
tuple: the return value is a tuple contains:
@@ -586,7 +467,7 @@ def loss_and_predict(self,
batch_img_metas.append(data_sample.metainfo)
batch_gt_instances.append(data_sample.gt_instances)
- outs = self(x, batch_img_metas)
+ outs = self(hidden_states)
loss_inputs = outs + (batch_gt_instances, batch_img_metas)
losses = self.loss_by_feat(*loss_inputs)
@@ -595,7 +476,7 @@ def loss_and_predict(self,
return losses, predictions
def predict(self,
- x: Tuple[Tensor],
+ hidden_states: Tuple[Tensor],
batch_data_samples: SampleList,
rescale: bool = True) -> InstanceList:
"""Perform forward propagation of the detection head and predict
@@ -603,7 +484,7 @@ def predict(self,
because img_metas are needed as inputs for bbox_head.
Args:
- x (tuple[Tensor]): Multi-level features from the
+ hidden_states (tuple[Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
@@ -619,30 +500,32 @@ def predict(self,
data_samples.metainfo for data_samples in batch_data_samples
]
- outs = self(x, batch_img_metas)
+ last_layer_hidden_state = hidden_states[-1].unsqueeze(0)
+ outs = self(last_layer_hidden_state)
predictions = self.predict_by_feat(
*outs, batch_img_metas=batch_img_metas, rescale=rescale)
+
return predictions
def predict_by_feat(self,
- all_cls_scores_list: List[Tensor],
- all_bbox_preds_list: List[Tensor],
+ layer_cls_scores: Tensor,
+ layer_bbox_preds: Tensor,
batch_img_metas: List[dict],
rescale: bool = True) -> InstanceList:
"""Transform network outputs for a batch into bbox predictions.
Args:
- all_cls_scores_list (list[Tensor]): Classification outputs
- for each feature level. Each is a 4D-tensor with shape
- [nb_dec, bs, num_query, cls_out_channels].
- all_bbox_preds_list (list[Tensor]): Sigmoid regression
- outputs for each feature level. Each is a 4D-tensor with
- normalized coordinate format (cx, cy, w, h) and shape
- [nb_dec, bs, num_query, 4].
+ layer_cls_scores (Tensor): Classification outputs of the last or
+ all decoder layer. Each is a 4D-tensor, has shape
+ (num_decoder_layers, bs, num_queries, cls_out_channels).
+ layer_bbox_preds (Tensor): Sigmoid regression outputs of the last
+ or all decoder layer. Each is a 4D-tensor with normalized
+ coordinate format (cx, cy, w, h) and shape
+ (num_decoder_layers, bs, num_queries, 4).
batch_img_metas (list[dict]): Meta information of each image.
- rescale (bool, optional): If True, return boxes in original
- image space. Defaults to True.
+ rescale (bool, optional): If `True`, return boxes in original
+ image space. Defaults to `True`.
Returns:
list[:obj:`InstanceData`]: Object detection results of each image
@@ -655,10 +538,10 @@ def predict_by_feat(self,
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
- # NOTE defaultly only using outputs from the last feature level,
+ # NOTE only using outputs from the last feature level,
# and only the outputs from the last decoder layer is used.
- cls_scores = all_cls_scores_list[-1][-1]
- bbox_preds = all_bbox_preds_list[-1][-1]
+ cls_scores = layer_cls_scores[-1]
+ bbox_preds = layer_bbox_preds[-1]
result_list = []
for img_id in range(len(batch_img_metas)):
@@ -680,10 +563,10 @@ def _predict_by_feat_single(self,
Args:
cls_score (Tensor): Box score logits from the last decoder layer
- for each image. Shape [num_query, cls_out_channels].
+ for each image. Shape [num_queries, cls_out_channels].
bbox_pred (Tensor): Sigmoid outputs from the last decoder layer
for each image, with coordinate format (cx, cy, w, h) and
- shape [num_query, 4].
+ shape [num_queries, 4].
img_meta (dict): Image meta info.
rescale (bool): If True, return boxes in original image
space. Default True.
@@ -700,8 +583,8 @@ def _predict_by_feat_single(self,
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
- assert len(cls_score) == len(bbox_pred)
- max_per_img = self.test_cfg.get('max_per_img', self.num_query)
+ assert len(cls_score) == len(bbox_pred) # num_queries
+ max_per_img = self.test_cfg.get('max_per_img', len(cls_score))
img_shape = img_meta['img_shape']
# exclude background
if self.loss_cls.use_sigmoid:
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index bf97944d9cf..f5a977ad6cb 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -2,8 +2,10 @@
from .atss import ATSS
from .autoassign import AutoAssign
from .base import BaseDetector
+from .base_detr import DetectionTransformer
from .cascade_rcnn import CascadeRCNN
from .centernet import CenterNet
+from .conditional_detr import ConditionalDETR
from .cornernet import CornerNet
from .ddod import DDOD
from .deformable_detr import DeformableDETR
@@ -58,5 +60,5 @@
'SOLOv2', 'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX',
'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD',
'MaskFormer', 'DDOD', 'Mask2Former', 'SemiBaseDetector', 'SoftTeacher',
- 'RTMDet'
+ 'DetectionTransformer', 'RTMDet', 'ConditionalDETR'
]
diff --git a/mmdet/models/detectors/base_detr.py b/mmdet/models/detectors/base_detr.py
new file mode 100644
index 00000000000..47311aafefa
--- /dev/null
+++ b/mmdet/models/detectors/base_detr.py
@@ -0,0 +1,329 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+from typing import Dict, List, Tuple, Union
+
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.structures import OptSampleList, SampleList
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from .base import BaseDetector
+
+
+@MODELS.register_module()
+class DetectionTransformer(BaseDetector, metaclass=ABCMeta):
+ r"""Base class for Detection Transformer.
+
+ Detection Transformer uses an encoder to process output features of neck,
+ then several queries interactive with the output features of encoder and
+ do the regression and classification with bounding box head.
+
+ Args:
+ backbone (:obj:`ConfigDict` or dict): Config of the backbone.
+ neck (:obj:`ConfigDict` or dict, optional): Config of the neck.
+ Defaults to None.
+ encoder (:obj:`ConfigDict` or dict, optional): Config of the
+ Transformer encoder. Defaults to None.
+ decoder (:obj:`ConfigDict` or dict, optional): Config of the
+ Transformer decoder. Defaults to None.
+ positional_encoding_cfg (:obj:`ConfigDict` or dict, optional): Config
+ of the positional encoding module. Defaults to None.
+ bbox_head (:obj:`ConfigDict` or dict, optional): Config for the
+ bounding box head module. Defaults to None.
+ num_queries (int, optional): Number of decoder query in Transformer.
+ Defaults to 100.
+ train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
+ the bounding box head module. Defaults to None.
+ test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
+ the bounding box head module. Defaults to None.
+ data_preprocessor (dict or ConfigDict, optional): The pre-process
+ config of :class:`BaseDataPreprocessor`. it usually includes,
+ ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
+ Defaults to None.
+ init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
+ the initialization. Defaults to None.
+ """
+
+ def __init__(self,
+ backbone: ConfigType,
+ neck: OptConfigType = None,
+ encoder: OptConfigType = None,
+ decoder: OptConfigType = None,
+ positional_encoding_cfg: OptConfigType = None,
+ bbox_head: OptConfigType = None,
+ num_queries: int = 100,
+ train_cfg: OptConfigType = None,
+ test_cfg: OptConfigType = None,
+ data_preprocessor: OptConfigType = None,
+ init_cfg: OptMultiConfig = None) -> None:
+ super().__init__(
+ data_preprocessor=data_preprocessor, init_cfg=init_cfg)
+ # process args
+ bbox_head.update(train_cfg=train_cfg)
+ bbox_head.update(test_cfg=test_cfg)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.encoder = encoder
+ self.decoder = decoder
+ self.positional_encoding_cfg = positional_encoding_cfg
+ self.num_queries = num_queries
+
+ # init model layers
+ self.backbone = MODELS.build(backbone)
+ if neck is not None:
+ self.neck = MODELS.build(neck)
+ self.bbox_head = MODELS.build(bbox_head)
+ self._init_layers()
+
+ @abstractmethod
+ def _init_layers(self) -> None:
+ """Initialize layers except for backbone, neck and bbox_head."""
+ pass
+
+ def loss(self, batch_inputs: Tensor,
+ batch_data_samples: SampleList) -> Union[dict, list]:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ batch_inputs (Tensor): Input images of shape (bs, dim, H, W).
+ These should usually be mean centered and std scaled.
+ batch_data_samples (List[:obj:`DetDataSample`]): The batch
+ data samples. It usually includes information such
+ as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+
+ Returns:
+ dict: A dictionary of loss components
+ """
+ img_feats = self.extract_feat(batch_inputs)
+ head_inputs_dict = self.forward_transformer(img_feats,
+ batch_data_samples)
+ losses = self.bbox_head.loss(
+ **head_inputs_dict, batch_data_samples=batch_data_samples)
+ return losses
+
+ def predict(self,
+ batch_inputs: Tensor,
+ batch_data_samples: SampleList,
+ rescale: bool = True) -> SampleList:
+ """Predict results from a batch of inputs and data samples with post-
+ processing.
+
+ Args:
+ batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
+ batch_data_samples (List[:obj:`DetDataSample`]): The batch
+ data samples. It usually includes information such
+ as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+ rescale (bool): Whether to rescale the results.
+ Defaults to True.
+
+ Returns:
+ list[:obj:`DetDataSample`]: Detection results of the input images.
+ Each DetDataSample usually contain 'pred_instances'. And the
+ `pred_instances` usually contains following keys.
+
+ - scores (Tensor): Classification scores, has a shape
+ (num_instance, )
+ - labels (Tensor): Labels of bboxes, has a shape
+ (num_instances, ).
+ - bboxes (Tensor): Has a shape (num_instances, 4),
+ the last dimension 4 arrange as (x1, y1, x2, y2).
+ """
+ img_feats = self.extract_feat(batch_inputs)
+ head_inputs_dict = self.forward_transformer(img_feats,
+ batch_data_samples)
+ results_list = self.bbox_head.predict(
+ **head_inputs_dict,
+ rescale=rescale,
+ batch_data_samples=batch_data_samples)
+ batch_data_samples = self.add_pred_to_datasample(
+ batch_data_samples, results_list)
+ return batch_data_samples
+
+ def _forward(
+ self,
+ batch_inputs: Tensor,
+ batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
+ """Network forward process. Usually includes backbone, neck and head
+ forward without any post-processing.
+
+ Args:
+ batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
+ batch_data_samples (List[:obj:`DetDataSample`], optional): The
+ batch data samples. It usually includes information such
+ as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+ Defaults to None.
+
+ Returns:
+ tuple[Tensor]: A tuple of features from ``bbox_head`` forward.
+ """
+ img_feats = self.extract_feat(batch_inputs)
+ head_inputs_dict = self.forward_transformer(img_feats,
+ batch_data_samples)
+ results = self.bbox_head.forward(**head_inputs_dict)
+ return results
+
+ def forward_transformer(self,
+ img_feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList = None) -> Dict:
+ """Forward process of Transformer, which includes four steps:
+ 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'. We
+ summarized the parameters flow of the existing DETR-like detector,
+ which can be illustrated as follow:
+
+ .. code:: text
+
+ img_feats & batch_data_samples
+ |
+ V
+ +-----------------+
+ | pre_transformer |
+ +-----------------+
+ | |
+ | V
+ | +-----------------+
+ | | forward_encoder |
+ | +-----------------+
+ | |
+ | V
+ | +---------------+
+ | | pre_decoder |
+ | +---------------+
+ | | |
+ V V |
+ +-----------------+ |
+ | forward_decoder | |
+ +-----------------+ |
+ | |
+ V V
+ head_inputs_dict
+
+ Args:
+ img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
+ feature map has shape (bs, dim, H, W).
+ batch_data_samples (list[:obj:`DetDataSample`], optional): The
+ batch data samples. It usually includes information such
+ as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+ Defaults to None.
+
+ Returns:
+ dict: The dictionary of bbox_head function inputs, which always
+ includes the `hidden_states` of the decoder output and may contain
+ `references` including the initial and intermediate references.
+ """
+ encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
+ img_feats, batch_data_samples)
+
+ encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)
+
+ tmp_dec_in, head_inputs_dict = self.pre_decoder(**encoder_outputs_dict)
+ decoder_inputs_dict.update(tmp_dec_in)
+
+ decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
+ head_inputs_dict.update(decoder_outputs_dict)
+ return head_inputs_dict
+
+ def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
+ """Extract features.
+
+ Args:
+ batch_inputs (Tensor): Image tensor, has shape (bs, dim, H, W).
+
+ Returns:
+ tuple[Tensor]: Tuple of feature maps from neck. Each feature map
+ has shape (bs, dim, H, W).
+ """
+ x = self.backbone(batch_inputs)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ @abstractmethod
+ def pre_transformer(
+ self,
+ img_feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList = None) -> Tuple[Dict, Dict]:
+ """Process image features before feeding them to the transformer.
+
+ Args:
+ img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
+ feature map has shape (bs, dim, H, W).
+ batch_data_samples (list[:obj:`DetDataSample`], optional): The
+ batch data samples. It usually includes information such
+ as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+ Defaults to None.
+
+ Returns:
+ tuple[dict, dict]: The first dict contains the inputs of encoder
+ and the second dict contains the inputs of decoder.
+
+ - encoder_inputs_dict (dict): The keyword args dictionary of
+ `self.forward_encoder()`, which includes 'feat', 'feat_mask',
+ 'feat_pos', and other algorithm-specific arguments.
+ - decoder_inputs_dict (dict): The keyword args dictionary of
+ `self.forward_decoder()`, which includes 'memory_mask', and
+ other algorithm-specific arguments.
+ """
+ pass
+
+ @abstractmethod
+ def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
+ feat_pos: Tensor, **kwargs) -> Dict:
+ """Forward with Transformer encoder.
+
+ Args:
+ feat (Tensor): Sequential features, has shape (num_feat, bs, dim).
+ feat_mask (Tensor): ByteTensor, the padding mask of the features,
+ has shape (num_feat, bs).
+ feat_pos (Tensor): The positional embeddings of the features, has
+ shape (num_feat, bs, dim).
+
+ Returns:
+ dict: The dictionary of encoder outputs, which includes the
+ `memory` of the encoder output and other algorithm-specific
+ arguments.
+ """
+ pass
+
+ @abstractmethod
+ def pre_decoder(self, memory: Tensor, **kwargs) -> Tuple[Dict, Dict]:
+ """Prepare intermediate variables before entering Transformer decoder,
+ such as `query`, `query_pos`, and `reference_points`.
+
+ Args:
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (num_feat, bs, dim).
+
+ Returns:
+ tuple[dict, dict]: The first dict contains the inputs of decoder
+ and the second dict contains the inputs of the bbox_head function.
+
+ - decoder_inputs_dict (dict): The keyword dictionary args of
+ `self.forward_decoder()`, which includes 'query', 'query_pos',
+ 'memory', and other algorithm-specific arguments.
+ - head_inputs_dict (dict): The keyword dictionary args of the
+ bbox_head functions, which is usually empty, or includes
+ `enc_outputs_class` and `enc_outputs_class` when the detector
+ support 'two stage' or 'query selection' strategies.
+ """
+ pass
+
+ @abstractmethod
+ def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
+ **kwargs) -> Dict:
+ """Forward with Transformer decoder.
+
+ Args:
+ query (Tensor): The queries of decoder inputs, has shape
+ (num_queries, bs, dim).
+ query_pos (Tensor): The positional queries of decoder inputs,
+ has shape (num_queries, bs, dim).
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (num_feat, bs, dim).
+
+ Returns:
+ dict: The dictionary of decoder outputs, which includes the
+ `hidden_states` of the decoder output, `references` including
+ the initial and intermediate reference_points, and other
+ algorithm-specific arguments.
+ """
+ pass
diff --git a/mmdet/models/detectors/conditional_detr.py b/mmdet/models/detectors/conditional_detr.py
new file mode 100644
index 00000000000..d6283b46e72
--- /dev/null
+++ b/mmdet/models/detectors/conditional_detr.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Tuple
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from ..layers import (ConditionalDetrTransformerDecoder,
+ DetrTransformerEncoder, SinePositionalEncoding)
+from .detr import DETR
+
+
+@MODELS.register_module()
+class ConditionalDETR(DETR):
+ r"""Implementation of `Conditional DETR for Fast Training Convergence.
+
+ `_.
+
+ Code is modified from the `official github repo
+ `_.
+ """
+
+ def __init__(self, *arg, group_detr=1, **kwargs) -> None:
+ self.group_detr = group_detr
+ super().__init__(*arg, **kwargs)
+
+ def _init_layers(self) -> None:
+ """Initialize layers except for backbone, neck and bbox_head."""
+ self.positional_encoding = SinePositionalEncoding(
+ **self.positional_encoding_cfg)
+ self.encoder = DetrTransformerEncoder(**self.encoder)
+ self.decoder = ConditionalDetrTransformerDecoder(**self.decoder)
+ self.embed_dims = self.encoder.embed_dims
+ # NOTE The embed_dims is typically passed from the inside out.
+ # For example in DETR, The embed_dims is passed as
+ # self_attn -> the first encoder layer -> encoder -> detector.
+ self.query_embedding = nn.Embedding(self.num_queries * self.group_detr,
+ self.embed_dims)
+
+ num_feats = self.positional_encoding.num_feats
+ assert num_feats * 2 == self.embed_dims, \
+ f'embed_dims should be exactly 2 times of num_feats. ' \
+ f'Found {self.embed_dims} and {num_feats}.'
+
+ def pre_decoder(self, memory: Tensor) -> Tuple[Dict, Dict]:
+ """Prepare intermediate variables before entering Transformer decoder,
+ such as `query`, `query_pos`.
+
+ The forward procedure of the transformer is defined as:
+ 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
+ More details can be found at `TransformerDetector.forward_transformer`
+ in `mmdet/detector/base_detr.py`.
+
+ Args:
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (num_feat, bs, dim).
+
+ Returns:
+ tuple[dict, dict]: The first dict contains the inputs of decoder
+ and the second dict contains the inputs of the bbox_head function.
+
+ - decoder_inputs_dict (dict): The keyword args dictionary of
+ `self.forward_decoder()`, which includes 'query', 'query_pos',
+ 'memory'.
+ - head_inputs_dict (dict): The keyword args dictionary of the
+ bbox_head functions, which is usually empty, or includes
+ `enc_outputs_class` and `enc_outputs_class` when the detector
+ support 'two stage' or 'query selection' strategies.
+ """
+
+ batch_size = memory.size(1)
+ if self.training:
+ query_pos = self.query_embedding.weight
+ else:
+ query_pos = self.query_embedding.weight[:self.num_queries]
+ # (num_queries, dim) -> (num_queries, bs, dim)
+ query_pos = query_pos.unsqueeze(1).repeat(1, batch_size, 1)
+ query = torch.zeros_like(query_pos)
+
+ decoder_inputs_dict = dict(
+ query_pos=query_pos, query=query, memory=memory)
+ head_inputs_dict = dict()
+ return decoder_inputs_dict, head_inputs_dict
+
+ def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
+ memory_mask: Tensor, memory_pos: Tensor) -> Dict:
+ """Forward with Transformer decoder.
+
+ Args:
+ query (Tensor): The queries of decoder inputs, has shape
+ (num_queries, bs, dim).
+ query_pos (Tensor): The positional queries of decoder inputs,
+ has shape (num_queries, bs, dim).
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (num_feat, bs, dim).
+ memory_mask (Tensor): ByteTensor, the padding mask of the memory,
+ has shape (bs, num_feat).
+ memory_pos (Tensor): The positional embeddings of memory, has
+ shape (num_feat, bs, dim).
+
+ Returns:
+ dict: The dictionary of decoder outputs, which includes the
+ `hidden_states` and `references` of the decoder output.
+ """
+ # (num_decoder_layers, num_queries, bs, dim)
+ hidden_states, references = self.decoder(
+ query=query,
+ key=memory,
+ value=memory,
+ query_pos=query_pos,
+ key_pos=memory_pos,
+ key_padding_mask=memory_mask)
+ hidden_states = hidden_states.transpose(1, 2)
+ head_inputs_dict = dict(
+ hidden_states=hidden_states, references=references)
+ return head_inputs_dict
diff --git a/mmdet/models/detectors/deformable_detr.py b/mmdet/models/detectors/deformable_detr.py
index 7fbbbb86ad7..88773906962 100644
--- a/mmdet/models/detectors/deformable_detr.py
+++ b/mmdet/models/detectors/deformable_detr.py
@@ -1,12 +1,567 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Dict, List, Tuple
+
+import torch
+import torch.nn.functional as F
+from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
+from mmengine.model import xavier_init
+from torch import Tensor, nn
+from torch.nn.init import normal_
+
from mmdet.registry import MODELS
-from .detr import DETR
+from mmdet.structures import OptSampleList
+from mmdet.utils import OptConfigType
+from ..layers import (DeformableDetrTransformerDecoder,
+ DeformableDetrTransformerEncoder, SinePositionalEncoding)
+from .base_detr import DetectionTransformer
@MODELS.register_module()
-class DeformableDETR(DETR):
+class DeformableDETR(DetectionTransformer):
r"""Implementation of `Deformable DETR: Deformable Transformers for
- End-to-End Object Detection `_"""
+ End-to-End Object Detection `_
+
+ Code is modified from the `official github repo
+ `_.
+
+ Args:
+ decoder (:obj:`ConfigDict` or dict, optional): Config of the
+ Transformer decoder. Defaults to None.
+ bbox_head (:obj:`ConfigDict` or dict, optional): Config for the
+ bounding box head module. Defaults to None.
+ with_box_refine (bool, optional): Whether to refine the references
+ in the decoder. Defaults to `False`.
+ as_two_stage (bool, optional): Whether to generate the proposal
+ from the outputs of encoder. Defaults to `False`.
+ num_feature_levels (int, optional): Number of feature levels.
+ Defaults to 4.
+ """
+
+ def __init__(self,
+ *args,
+ decoder: OptConfigType = None,
+ bbox_head: OptConfigType = None,
+ with_box_refine: bool = False,
+ as_two_stage: bool = False,
+ num_feature_levels: int = 4,
+ **kwargs) -> None:
+ self.with_box_refine = with_box_refine
+ self.as_two_stage = as_two_stage
+ self.num_feature_levels = num_feature_levels
+
+ if bbox_head is not None:
+ assert 'share_pred_layer' not in bbox_head and \
+ 'num_pred_layer' not in bbox_head and \
+ 'as_two_stage' not in bbox_head, \
+ 'The two keyword args `share_pred_layer`, `num_pred_layer`, ' \
+ 'and `as_two_stage are set in `detector.__init__()`, users ' \
+ 'should not set them in `bbox_head` config.'
+ # The last prediction layer is used to generate proposal
+ # from encode feature map when `as_two_stage` is `True`.
+ # And all the prediction layers should share parameters
+ # when `with_box_refine` is `True`.
+ bbox_head['share_pred_layer'] = not with_box_refine
+ bbox_head['num_pred_layer'] = (decoder['num_layers'] + 1) \
+ if self.as_two_stage else decoder['num_layers']
+ bbox_head['as_two_stage'] = as_two_stage
+
+ super().__init__(*args, decoder=decoder, bbox_head=bbox_head, **kwargs)
+
+ def _init_layers(self) -> None:
+ """Initialize layers except for backbone, neck and bbox_head."""
+ self.positional_encoding = SinePositionalEncoding(
+ **self.positional_encoding_cfg)
+ self.encoder = DeformableDetrTransformerEncoder(**self.encoder)
+ self.decoder = DeformableDetrTransformerDecoder(**self.decoder)
+ self.embed_dims = self.encoder.embed_dims
+ if not self.as_two_stage:
+ self.query_embedding = nn.Embedding(self.num_queries,
+ self.embed_dims * 2)
+ # NOTE The query_embedding will be split into query and query_pos
+ # in self.pre_decoder, hence, the embed_dims are doubled.
+
+ num_feats = self.positional_encoding.num_feats
+ assert num_feats * 2 == self.embed_dims, \
+ 'embed_dims should be exactly 2 times of num_feats. ' \
+ f'Found {self.embed_dims} and {num_feats}.'
+
+ self.level_embed = nn.Parameter(
+ torch.Tensor(self.num_feature_levels, self.embed_dims))
+
+ if self.as_two_stage:
+ self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
+ self.memory_trans_norm = nn.LayerNorm(self.embed_dims)
+ self.pos_trans_fc = nn.Linear(self.embed_dims * 2,
+ self.embed_dims * 2)
+ self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
+ else:
+ self.reference_points_fc = nn.Linear(self.embed_dims, 2)
+
+ def init_weights(self) -> None:
+ """Initialize weights for Transformer and other components."""
+ super().init_weights()
+ for coder in self.encoder, self.decoder:
+ for p in coder.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MultiScaleDeformableAttention):
+ m.init_weights()
+ if self.as_two_stage:
+ nn.init.xavier_uniform_(self.memory_trans_fc.weight)
+ nn.init.xavier_uniform_(self.pos_trans_fc.weight)
+ else:
+ xavier_init(
+ self.reference_points_fc, distribution='uniform', bias=0.)
+ normal_(self.level_embed)
+
+ def pre_transformer(
+ self,
+ mlvl_feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList = None) -> Tuple[Dict]:
+ """Process image features before feeding them to the transformer.
+
+ The forward procedure of the transformer is defined as:
+ 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
+ More details can be found at `TransformerDetector.forward_transformer`
+ in `mmdet/detector/base_detr.py`.
+
+ Args:
+ mlvl_feats (tuple[Tensor]): Multi-level features that may have
+ different resolutions, output from neck. Each feature has
+ shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'.
+ batch_data_samples (list[:obj:`DetDataSample`], optional): The
+ batch data samples. It usually includes information such
+ as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+ Defaults to None.
+
+ Returns:
+ tuple[dict]: The first dict contains the inputs of encoder and the
+ second dict contains the inputs of decoder.
+
+ - encoder_inputs_dict (dict): The keyword args dictionary of
+ `self.forward_encoder()`, which includes 'feat', 'feat_mask',
+ and 'feat_pos'.
+ - decoder_inputs_dict (dict): The keyword args dictionary of
+ `self.forward_decoder()`, which includes 'memory_mask'.
+ """
+ batch_size = mlvl_feats[0].size(0)
+
+ # construct binary masks for the transformer.
+ assert batch_data_samples is not None
+ batch_input_shape = batch_data_samples[0].batch_input_shape
+ img_shape_list = [sample.img_shape for sample in batch_data_samples]
+ input_img_h, input_img_w = batch_input_shape
+ masks = mlvl_feats[0].new_ones((batch_size, input_img_h, input_img_w))
+ for img_id in range(batch_size):
+ img_h, img_w = img_shape_list[img_id]
+ masks[img_id, :img_h, :img_w] = 0
+ # NOTE following the official DETR repo, non-zero values representing
+ # ignored positions, while zero values means valid positions.
+
+ mlvl_masks = []
+ mlvl_pos_embeds = []
+ for feat in mlvl_feats:
+ mlvl_masks.append(
+ F.interpolate(masks[None],
+ size=feat.shape[-2:]).to(torch.bool).squeeze(0))
+ mlvl_pos_embeds.append(self.positional_encoding(mlvl_masks[-1]))
+
+ feat_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes = []
+ for lvl, (feat, mask, pos_embed) in enumerate(
+ zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
+ batch_size, c, h, w = feat.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+ feat = feat.flatten(2).transpose(1, 2) # (bs, h_lvl*w_lvl, dim)
+ pos_embed = pos_embed.flatten(2).transpose(1, 2) # as above
+ mask = mask.flatten(1) # (bs, h_lvl*w_lvl)
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ feat_flatten.append(feat)
+ mask_flatten.append(mask)
+
+ # (bs, num_feat), where num_feat = sum_lvl(h_lvl*w_lvl)
+ mask_flatten = torch.cat(mask_flatten, 1)
+ # (bs, num_feat, dim)
+ feat_flatten = torch.cat(feat_flatten, 1)
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+ # (num_feat, bs, dim)
+ feat_flatten = feat_flatten.permute(1, 0, 2)
+ lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(1, 0, 2)
+
+ spatial_shapes = torch.as_tensor( # (num_level, 2)
+ spatial_shapes,
+ dtype=torch.long,
+ device=feat_flatten.device)
+ level_start_index = torch.cat((
+ spatial_shapes.new_zeros( # (num_level)
+ (1, )),
+ spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack( # (bs, num_level, 2)
+ [self.get_valid_ratio(m) for m in mlvl_masks], 1)
+
+ encoder_inputs_dict = dict(
+ feat=feat_flatten,
+ feat_mask=mask_flatten,
+ feat_pos=lvl_pos_embed_flatten,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios)
+ decoder_inputs_dict = dict(
+ memory_mask=mask_flatten,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios)
+ return encoder_inputs_dict, decoder_inputs_dict
+
+ def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
+ feat_pos: Tensor, spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ valid_ratios: Tensor) -> Dict:
+ """Forward with Transformer encoder.
+
+ The forward procedure of the transformer is defined as:
+ 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
+ More details can be found at `TransformerDetector.forward_transformer`
+ in `mmdet/detector/base_detr.py`.
+
+ Args:
+ feat (Tensor): Sequential features, has shape (num_feat, bs, dim).
+ feat_mask (Tensor): ByteTensor, the padding mask of the features,
+ has shape (num_feat, bs).
+ feat_pos (Tensor): The positional embeddings of the features, has
+ shape (num_feat, bs, dim).
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape (num_levels, ) and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ valid_ratios (Tensor): The ratios of the valid width and the valid
+ height relative to the width and the height of features in all
+ levels, has shape (bs, num_levels, 2).
+
+ Returns:
+ dict: The dictionary of encoder outputs, which includes the
+ `memory` of the encoder output.
+ """
+ memory = self.encoder(
+ query=feat,
+ query_pos=feat_pos,
+ key_padding_mask=feat_mask, # for self_attn
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios) # (num_feat, bs, dim)
+ memory = memory.permute(1, 0, 2) # (bs, num_feat, dim)
+ encoder_outputs_dict = dict(
+ memory=memory,
+ memory_mask=feat_mask,
+ spatial_shapes=spatial_shapes)
+ return encoder_outputs_dict
+
+ def pre_decoder(self, memory: Tensor, memory_mask: Tensor,
+ spatial_shapes: Tensor) -> Tuple[Dict, Dict]:
+ """Prepare intermediate variables before entering Transformer decoder,
+ such as `query`, `query_pos`, and `reference_points`.
+
+ The forward procedure of the transformer is defined as:
+ 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
+ More details can be found at `TransformerDetector.forward_transformer`
+ in `mmdet/detector/base_detr.py`.
+
+ Args:
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (bs, num_feat, dim).
+ memory_mask (Tensor): ByteTensor, the padding mask of the memory,
+ has shape (bs, num_feat). It will only be used when
+ `as_two_stage` is `True`.
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ It will only be used when `as_two_stage` is `True`.
+
+ Returns:
+ tuple[dict, dict]: The decoder_inputs_dict and head_inputs_dict.
+
+ - decoder_inputs_dict (dict): The keyword dictionary args of
+ `self.forward_decoder()`, which includes 'query', 'query_pos',
+ 'memory', and `reference_points`. The reference_points of
+ decoder input here are 4D boxes when `as_two_stage` is `True`,
+ otherwise 2D points, although it has `points` in its name.
+ The reference_points in encoder is always 2D points.
+ - head_inputs_dict (dict): The keyword dictionary args of the
+ bbox_head functions, which includes `enc_outputs_class` and
+ `enc_outputs_class`. They are both `None` when 'as_two_stage'
+ is `False`.
+ """
+ batch_size, _, c = memory.shape
+ if self.as_two_stage:
+ output_memory, output_proposals = \
+ self.gen_encoder_output_proposals(
+ memory, memory_mask, spatial_shapes)
+ enc_outputs_class = self.bbox_head.cls_branches[
+ self.decoder.num_layers](
+ output_memory)
+ enc_outputs_coord_unact = self.bbox_head.reg_branches[
+ self.decoder.num_layers](output_memory) + output_proposals
+ # We only use the first channel in enc_outputs_class as foreground,
+ # the other (num_classes - 1) channels are actually not used.
+ # Its targets are set to be 0s, which indicates the first
+ # class (foreground) because we use [0, num_classes - 1] to
+ # indicate class labels, background class is indicated by
+ # num_classes (similar convention in RPN).
+ # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa
+ # This follows the official implementation of Deformable DETR.
+ topk_proposals = torch.topk(
+ enc_outputs_class[..., 0], self.num_queries, dim=1)[1]
+ topk_coords_unact = torch.gather(
+ enc_outputs_coord_unact, 1,
+ topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
+ topk_coords_unact = topk_coords_unact.detach()
+ reference_points = topk_coords_unact.sigmoid()
+ pos_trans_out = self.pos_trans_fc(
+ self.get_proposal_pos_embed(topk_coords_unact))
+ pos_trans_out = self.pos_trans_norm(pos_trans_out)
+ query_pos, query = torch.split(pos_trans_out, c, dim=2)
+ else:
+ query_embed = self.query_embedding.weight
+ query_pos, query = torch.split(query_embed, c, dim=1)
+ query_pos = query_pos.unsqueeze(0).expand(batch_size, -1, -1)
+ query = query.unsqueeze(0).expand(batch_size, -1, -1)
+ reference_points = self.reference_points_fc(query_pos).sigmoid()
+
+ query = query.permute(1, 0, 2) # (num_queries, bs, dim)
+ memory = memory.permute(1, 0, 2) # (num_feat, bs, dim)
+ query_pos = query_pos.permute(1, 0, 2) # (num_queries, bs, dim)
+
+ decoder_inputs_dict = dict(
+ query=query,
+ query_pos=query_pos,
+ memory=memory,
+ reference_points=reference_points)
+ head_inputs_dict = dict(
+ enc_outputs_class=enc_outputs_class if self.as_two_stage else None,
+ enc_outputs_coord=enc_outputs_coord_unact.sigmoid()
+ if self.as_two_stage else None)
+ return decoder_inputs_dict, head_inputs_dict
+
+ def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
+ memory_mask: Tensor, reference_points: Tensor,
+ spatial_shapes: Tensor, level_start_index: Tensor,
+ valid_ratios: Tensor) -> Dict:
+ """Forward with Transformer decoder.
+
+ The forward procedure of the transformer is defined as:
+ 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
+ More details can be found at `TransformerDetector.forward_transformer`
+ in `mmdet/detector/base_detr.py`.
+
+ Args:
+ query (Tensor): The queries of decoder inputs, has shape
+ (num_queries, bs, dim).
+ query_pos (Tensor): The positional queries of decoder inputs,
+ has shape (num_queries, bs, dim).
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (num_feat, bs, dim).
+ memory_mask (Tensor): ByteTensor, the padding mask of the memory,
+ has shape (bs, num_feat).
+ reference_points (Tensor): The initial reference, has shape
+ (bs, num_queries, 4) when `as_two_stage` is `True`,
+ otherwise has shape (bs, num_queries, 2).
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape (num_levels, ) and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ valid_ratios (Tensor): The ratios of the valid width and the valid
+ height relative to the width and the height of features in all
+ levels, has shape (bs, num_levels, 2).
+
+ Returns:
+ dict: The dictionary of decoder outputs, which includes the
+ `hidden_states` of the decoder output and `references` including
+ the initial and intermediate reference_points.
+ """
+ inter_states, inter_references = self.decoder(
+ query=query,
+ value=memory,
+ query_pos=query_pos,
+ key_padding_mask=memory_mask, # for cross_attn
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ reg_branches=self.bbox_head.reg_branches
+ if self.with_box_refine else None)
+ references = [reference_points, *inter_references]
+ decoder_outputs_dict = dict(
+ hidden_states=inter_states, references=references)
+ return decoder_outputs_dict
+
+ @staticmethod
+ def get_valid_ratio(mask: Tensor) -> Tensor:
+ """Get the valid radios of feature map in a level.
+
+ .. code:: text
+
+ |---> valid_H <---|
+ ---+-----------------+-----+---
+ A | | | A
+ | | | | |
+ | | | | |
+ valid_W | | | |
+ | | | | W
+ | | | | |
+ V | | | |
+ ---+-----------------+ | |
+ | | V
+ +-----------------------+---
+ |---------> H <---------|
+
+ The valid_ratios are defined as:
+ r_h = valid_H / H, r_w = valid_W / W
+ They are the factors to re-normalize the relative coordinates of the
+ image to the relative coordinates of the current level feature map.
+
+ Args:
+ mask (Tensor): Binary mask of a feature map, has shape (bs, H, W).
+
+ Returns:
+ Tensor: valid ratios [r_w, r_h] of a feature map, has shape (1, 2).
+ """
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ def gen_encoder_output_proposals(
+ self, memory: Tensor, memory_mask: Tensor,
+ spatial_shapes: Tensor) -> Tuple[Tensor, Tensor]:
+ """Generate proposals from encoded memory. The function will only be
+ used when `as_two_stage` is `True`.
+
+ Args:
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (num_feat, bs, dim).
+ memory_mask (Tensor): ByteTensor, the padding mask of the memory,
+ has shape (bs, num_feat).
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+
+ Returns:
+ tuple: A tuple of transformed memory and proposals.
+
+ - output_memory (Tensor): The transformed memory for obtaining
+ top-k proposals, has shape (bs, num_feat, dim).
+ - output_proposals (Tensor): The inverse-normalized proposal, has
+ shape (batch_size, num_keys, 4).
+ """
+
+ num_feat = memory.size(0)
+ proposals = []
+ _cur = 0 # start index in the sequence of the current level
+ for lvl, (H, W) in enumerate(spatial_shapes):
+ mask_flatten_ = memory_mask[:, _cur:(_cur + H * W)].view(
+ num_feat, H, W, 1)
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1).unsqueeze(-1)
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1).unsqueeze(-1)
+
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(
+ 0, H - 1, H, dtype=torch.float32, device=memory.device),
+ torch.linspace(
+ 0, W - 1, W, dtype=torch.float32, device=memory.device))
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+ scale = torch.cat([valid_W, valid_H], 1).view(num_feat, 1, 1, 2)
+ grid = (grid.unsqueeze(0).expand(num_feat, -1, -1, -1) +
+ 0.5) / scale
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
+ proposal = torch.cat((grid, wh), -1).view(num_feat, -1, 4)
+ proposals.append(proposal)
+ _cur += (H * W)
+ output_proposals = torch.cat(proposals, 1)
+ output_proposals_valid = ((output_proposals > 0.01) &
+ (output_proposals < 0.99)).all(
+ -1, keepdim=True)
+ # inverse_sigmoid
+ output_proposals = torch.log(output_proposals / (1 - output_proposals))
+ output_proposals = output_proposals.masked_fill(
+ memory_mask.unsqueeze(-1), float('inf'))
+ output_proposals = output_proposals.masked_fill(
+ ~output_proposals_valid, float('inf'))
+
+ output_memory = memory
+ output_memory = output_memory.masked_fill(
+ memory_mask.unsqueeze(-1), float(0))
+ output_memory = output_memory.masked_fill(~output_proposals_valid,
+ float(0))
+ output_memory = self.memory_trans_fc(output_memory)
+ output_memory = self.memory_trans_norm(output_memory)
+ # [bs, sum(hw), 2]
+ return output_memory, output_proposals
+
+ @staticmethod
+ def get_proposal_pos_embed(proposals: Tensor,
+ num_pos_feats: int = 128,
+ temperature: int = 10000) -> Tensor:
+ """Get the position embedding of the proposal.
+
+ Args:
+ proposals (Tensor): Not normalized proposals, has shape
+ (bs, num_queries, 4).
+ num_pos_feats (int, optional): The feature dimension for each
+ position along x, y, w, and h-axis. Note the final returned
+ dimension for each position is 4 times of num_pos_feats.
+ Default to 128.
+ temperature (int, optional): The temperature used for scaling the
+ position embedding. Defaults to 10000.
+
+ Returns:
+ Tensor: The position embedding of proposal, has shape
+ (bs, num_queries, num_pos_feats * 4)
+ """
+ scale = 2 * math.pi
+ dim_t = torch.arange(
+ num_pos_feats, dtype=torch.float32, device=proposals.device)
+ dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
+ # N, L, 4
+ proposals = proposals.sigmoid() * scale
+ # N, L, 4, 128
+ pos = proposals[:, :, :, None] / dim_t
+ # N, L, 4, 64, 2
+ pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
+ dim=4).flatten(2)
+ return pos
+
+ def _forward(
+ self,
+ batch_inputs: Tensor,
+ batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
+ """Network forward process. Usually includes backbone, neck and head
+ forward without any post-processing.Overwrite to pop
+ 'enc_outputs_class' and 'enc_outputs_coord'.
+
+ Args:
+ batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
+ batch_data_samples (List[:obj:`DetDataSample`], optional): The
+ batch data samples. It usually includes information such
+ as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+ Defaults to None.
- def __init__(self, *args, **kwargs) -> None:
- super(DETR, self).__init__(*args, **kwargs)
+ Returns:
+ tuple[Tensor]: A tuple of features from ``bbox_head`` forward.
+ """
+ img_feats = self.extract_feat(batch_inputs)
+ head_inputs_dict = self.forward_transformer(img_feats,
+ batch_data_samples)
+ head_inputs_dict.pop('enc_outputs_class')
+ head_inputs_dict.pop('enc_outputs_coord')
+ results = self.bbox_head.forward(**head_inputs_dict)
+ return results
diff --git a/mmdet/models/detectors/detr.py b/mmdet/models/detectors/detr.py
index 2ba68aa46c3..e7ff317df51 100644
--- a/mmdet/models/detectors/detr.py
+++ b/mmdet/models/detectors/detr.py
@@ -1,26 +1,212 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
from mmdet.registry import MODELS
-from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
-from .single_stage import SingleStageDetector
+from mmdet.structures import OptSampleList
+from ..layers import (DetrTransformerDecoder, DetrTransformerEncoder,
+ SinePositionalEncoding)
+from .base_detr import DetectionTransformer
@MODELS.register_module()
-class DETR(SingleStageDetector):
- r"""Implementation of `DETR: End-to-End Object Detection with
- Transformers `_"""
-
- def __init__(self,
- backbone: ConfigType,
- bbox_head: ConfigType,
- train_cfg: OptConfigType = None,
- test_cfg: OptConfigType = None,
- data_preprocessor: OptConfigType = None,
- init_cfg: OptMultiConfig = None) -> None:
- super().__init__(
- backbone=backbone,
- neck=None,
- bbox_head=bbox_head,
- train_cfg=train_cfg,
- test_cfg=test_cfg,
- data_preprocessor=data_preprocessor,
- init_cfg=init_cfg)
+class DETR(DetectionTransformer):
+ r"""Implementation of `DETR: End-to-End Object Detection with Transformers.
+
+ `_.
+
+ Code is modified from the `official github repo
+ `_.
+ """
+
+ def _init_layers(self) -> None:
+ """Initialize layers except for backbone, neck and bbox_head."""
+ self.positional_encoding = SinePositionalEncoding(
+ **self.positional_encoding_cfg)
+ self.encoder = DetrTransformerEncoder(**self.encoder)
+ self.decoder = DetrTransformerDecoder(**self.decoder)
+ self.embed_dims = self.encoder.embed_dims
+ # NOTE The embed_dims is typically passed from the inside out.
+ # For example in DETR, The embed_dims is passed as
+ # self_attn -> the first encoder layer -> encoder -> detector.
+ self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)
+
+ num_feats = self.positional_encoding.num_feats
+ assert num_feats * 2 == self.embed_dims, \
+ 'embed_dims should be exactly 2 times of num_feats. ' \
+ f'Found {self.embed_dims} and {num_feats}.'
+
+ def init_weights(self) -> None:
+ """Initialize weights for Transformer and other components."""
+ super().init_weights()
+ for coder in self.encoder, self.decoder:
+ for p in coder.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def pre_transformer(
+ self,
+ img_feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList = None) -> Tuple[Dict, Dict]:
+ """Prepare the inputs of the Transformer.
+
+ The forward procedure of the transformer is defined as:
+ 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
+ More details can be found at `TransformerDetector.forward_transformer`
+ in `mmdet/detector/base_detr.py`.
+
+ Args:
+ img_feats (Tuple[Tensor]): Tuple of features output from the neck,
+ has shape (bs, c, h, w).
+ batch_data_samples (List[:obj:`DetDataSample`]): The batch
+ data samples. It usually includes information such as
+ `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+ Defaults to None.
+
+ Returns:
+ tuple[dict, dict]: The first dict contains the inputs of encoder
+ and the second dict contains the inputs of decoder.
+
+ - encoder_inputs_dict (dict): The keyword args dictionary of
+ `self.forward_encoder()`, which includes 'feat', 'feat_mask',
+ and 'feat_pos'.
+ - decoder_inputs_dict (dict): The keyword args dictionary of
+ `self.forward_decoder()`, which includes 'memory_mask',
+ and 'memory_pos'.
+ """
+
+ feat = img_feats[-1] # NOTE img_feats contains only one feature.
+ batch_size, feat_dim, _, _ = feat.shape
+ # construct binary masks which for the transformer.
+ assert batch_data_samples is not None
+ batch_input_shape = batch_data_samples[0].batch_input_shape
+ img_shape_list = [sample.img_shape for sample in batch_data_samples]
+
+ input_img_h, input_img_w = batch_input_shape
+ masks = feat.new_ones((batch_size, input_img_h, input_img_w))
+ for img_id in range(batch_size):
+ img_h, img_w = img_shape_list[img_id]
+ masks[img_id, :img_h, :img_w] = 0
+ # NOTE following the official DETR repo, non-zero values represent
+ # ignored positions, while zero values mean valid positions.
+
+ masks = F.interpolate(
+ masks.unsqueeze(1), size=feat.shape[-2:]).to(torch.bool).squeeze(1)
+ # [batch_size, embed_dim, h, w]
+ pos_embed = self.positional_encoding(masks)
+
+ # use `view` instead of `flatten` for dynamically exporting to ONNX
+ # [bs, c, h, w] -> [h*w, bs, c]
+ feat = feat.view(batch_size, feat_dim, -1).permute(2, 0, 1)
+ pos_embed = pos_embed.view(batch_size, feat_dim, -1).permute(2, 0, 1)
+ # [bs, h, w] -> [bs, h*w]
+ masks = masks.view(batch_size, -1)
+
+ # prepare transformer_inputs_dict
+ encoder_inputs_dict = dict(
+ feat=feat, feat_mask=masks, feat_pos=pos_embed)
+ decoder_inputs_dict = dict(memory_mask=masks, memory_pos=pos_embed)
+ return encoder_inputs_dict, decoder_inputs_dict
+
+ def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
+ feat_pos: Tensor) -> Dict:
+ """Forward with Transformer encoder.
+
+ The forward procedure of the transformer is defined as:
+ 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
+ More details can be found at `TransformerDetector.forward_transformer`
+ in `mmdet/detector/base_detr.py`.
+
+ Args:
+ feat (Tensor): Sequential features, has shape (num_feat, bs, dim).
+ feat_mask (Tensor): ByteTensor, the padding mask of the features,
+ has shape (num_feat, bs).
+ feat_pos (Tensor): The positional embeddings of the features, has
+ shape (num_feat, bs, dim).
+
+ Returns:
+ dict: The dictionary of encoder outputs, which includes the
+ `memory` of the encoder output.
+ """
+ memory = self.encoder(
+ query=feat, query_pos=feat_pos,
+ key_padding_mask=feat_mask) # for self_attn
+ encoder_outputs_dict = dict(memory=memory)
+ return encoder_outputs_dict
+
+ def pre_decoder(self, memory: Tensor) -> Tuple[Dict, Dict]:
+ """Prepare intermediate variables before entering Transformer decoder,
+ such as `query`, `query_pos`.
+
+ The forward procedure of the transformer is defined as:
+ 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
+ More details can be found at `TransformerDetector.forward_transformer`
+ in `mmdet/detector/base_detr.py`.
+
+ Args:
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (num_feat, bs, dim).
+
+ Returns:
+ tuple[dict, dict]: The first dict contains the inputs of decoder
+ and the second dict contains the inputs of the bbox_head function.
+
+ - decoder_inputs_dict (dict): The keyword args dictionary of
+ `self.forward_decoder()`, which includes 'query', 'query_pos',
+ 'memory'.
+ - head_inputs_dict (dict): The keyword args dictionary of the
+ bbox_head functions, which is usually empty, or includes
+ `enc_outputs_class` and `enc_outputs_class` when the detector
+ support 'two stage' or 'query selection' strategies.
+ """
+
+ batch_size = memory.size(1)
+ query_pos = self.query_embedding.weight
+ # (num_queries, dim) -> (num_queries, bs, dim)
+ query_pos = query_pos.unsqueeze(1).repeat(1, batch_size, 1)
+ query = torch.zeros_like(query_pos)
+
+ decoder_inputs_dict = dict(
+ query_pos=query_pos, query=query, memory=memory)
+ head_inputs_dict = dict()
+ return decoder_inputs_dict, head_inputs_dict
+
+ def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
+ memory_mask: Tensor, memory_pos: Tensor) -> Dict:
+ """Forward with Transformer decoder.
+
+ The forward procedure of the transformer is defined as:
+ 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
+ More details can be found at `TransformerDetector.forward_transformer`
+ in `mmdet/detector/base_detr.py`.
+
+ Args:
+ query (Tensor): The queries of decoder inputs, has shape
+ (num_queries, bs, dim).
+ query_pos (Tensor): The positional queries of decoder inputs,
+ has shape (num_queries, bs, dim).
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (num_feat, bs, dim).
+ memory_mask (Tensor): ByteTensor, the padding mask of the memory,
+ has shape (bs, num_feat).
+ memory_pos (Tensor): The positional embeddings of memory, has
+ shape (num_feat, bs, dim).
+
+ Returns:
+ dict: The dictionary of decoder outputs, which includes the
+ `hidden_states` of the decoder output.
+ """
+ # (num_decoder_layers, num_queries, bs, dim)
+ hidden_states = self.decoder(
+ query=query,
+ key=memory,
+ value=memory,
+ query_pos=query_pos,
+ key_pos=memory_pos,
+ key_padding_mask=memory_mask) # for cross_attn
+ hidden_states = hidden_states.transpose(1, 2)
+ head_inputs_dict = dict(hidden_states=hidden_states)
+ return head_inputs_dict
diff --git a/mmdet/models/layers/__init__.py b/mmdet/models/layers/__init__.py
index 98f8e843075..384518c3195 100644
--- a/mmdet/models/layers/__init__.py
+++ b/mmdet/models/layers/__init__.py
@@ -15,18 +15,31 @@
SinePositionalEncoding)
from .res_layer import ResLayer, SimplifiedBasicBlock
from .se_layer import ChannelAttention, DyReLU, SELayer
-from .transformer import (DetrTransformerDecoder, DetrTransformerDecoderLayer,
- DynamicConv, PatchEmbed, PatchMerging, Transformer,
- inverse_sigmoid, nchw_to_nlc, nlc_to_nchw)
+from .transformer import (MLP, DeformableDetrTransformerDecoder,
+ DeformableDetrTransformerDecoderLayer,
+ DeformableDetrTransformerEncoder,
+ DeformableDetrTransformerEncoderLayer,
+ DetrTransformerDecoder, DetrTransformerDecoderLayer,
+ DetrTransformerEncoder, DetrTransformerEncoderLayer,
+ DynamicConv, PatchEmbed, PatchMerging,
+ inverse_sigmoid, nchw_to_nlc, nlc_to_nchw,
+ ConditionalDetrTransformerDecoder,
+ ConditionalDetrTransformerDecoderLayer)
__all__ = [
'fast_nms', 'multiclass_nms', 'mask_matrix_nms', 'DropBlock',
'PixelDecoder', 'TransformerEncoderPixelDecoder',
- 'MSDeformAttnPixelDecoder', 'ResLayer', 'DetrTransformerDecoderLayer',
- 'DetrTransformerDecoder', 'Transformer', 'PatchMerging',
+ 'MSDeformAttnPixelDecoder', 'ResLayer', 'PatchMerging',
'SinePositionalEncoding', 'LearnedPositionalEncoding', 'DynamicConv',
'SimplifiedBasicBlock', 'NormedLinear', 'NormedConv2d', 'InvertedResidual',
'SELayer', 'ConvUpsample', 'CSPLayer', 'adaptive_avg_pool2d',
'AdaptiveAvgPool2d', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'DyReLU',
- 'ExpMomentumEMA', 'inverse_sigmoid', 'ChannelAttention', 'SiLU'
+ 'ExpMomentumEMA', 'inverse_sigmoid', 'ChannelAttention', 'SiLU', 'MLP',
+ 'DetrTransformerEncoderLayer', 'DetrTransformerDecoderLayer',
+ 'DetrTransformerEncoder', 'DetrTransformerDecoder',
+ 'DeformableDetrTransformerEncoder', 'DeformableDetrTransformerDecoder',
+ 'DeformableDetrTransformerEncoderLayer',
+ 'DeformableDetrTransformerDecoderLayer',
+ 'ConditionalDetrTransformerDecoder',
+ 'ConditionalDetrTransformerDecoderLayer'
]
diff --git a/mmdet/models/layers/msdeformattn_pixel_decoder.py b/mmdet/models/layers/msdeformattn_pixel_decoder.py
index 953f873f400..12ea14d7efc 100644
--- a/mmdet/models/layers/msdeformattn_pixel_decoder.py
+++ b/mmdet/models/layers/msdeformattn_pixel_decoder.py
@@ -5,7 +5,8 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, ConvModule
-from mmcv.cnn.bricks.transformer import (build_positional_encoding,
+from mmcv.cnn.bricks.transformer import (MultiScaleDeformableAttention,
+ build_positional_encoding,
build_transformer_layer_sequence)
from mmengine.model import (BaseModule, ModuleList, caffe2_xavier_init,
normal_init, xavier_init)
@@ -14,7 +15,6 @@
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptMultiConfig
from ..task_modules.prior_generators import MlvlPointGenerator
-from .transformer import MultiScaleDeformableAttention
@MODELS.register_module()
diff --git a/mmdet/models/layers/transformer.py b/mmdet/models/layers/transformer.py
deleted file mode 100644
index 19c3e62f289..00000000000
--- a/mmdet/models/layers/transformer.py
+++ /dev/null
@@ -1,1164 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import warnings
-from typing import Sequence
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
-from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
- TransformerLayerSequence,
- build_transformer_layer_sequence)
-from mmengine.model import BaseModule, xavier_init
-from mmengine.utils import to_2tuple
-from torch.nn.init import normal_
-
-from mmdet.registry import MODELS
-
-try:
- from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
-
-except ImportError:
- warnings.warn(
- '`MultiScaleDeformableAttention` in MMCV has been moved to '
- '`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
- from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
-
-
-def nlc_to_nchw(x, hw_shape):
- """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
-
- Args:
- x (Tensor): The input tensor of shape [N, L, C] before conversion.
- hw_shape (Sequence[int]): The height and width of output feature map.
-
- Returns:
- Tensor: The output tensor of shape [N, C, H, W] after conversion.
- """
- H, W = hw_shape
- assert len(x.shape) == 3
- B, L, C = x.shape
- assert L == H * W, 'The seq_len does not match H, W'
- return x.transpose(1, 2).reshape(B, C, H, W).contiguous()
-
-
-def nchw_to_nlc(x):
- """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
-
- Args:
- x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
-
- Returns:
- Tensor: The output tensor of shape [N, L, C] after conversion.
- """
- assert len(x.shape) == 4
- return x.flatten(2).transpose(1, 2).contiguous()
-
-
-class AdaptivePadding(nn.Module):
- """Applies padding to input (if needed) so that input can get fully covered
- by filter you specified. It support two modes "same" and "corner". The
- "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
- input. The "corner" mode would pad zero to bottom right.
-
- Args:
- kernel_size (int | tuple): Size of the kernel:
- stride (int | tuple): Stride of the filter. Default: 1:
- dilation (int | tuple): Spacing between kernel elements.
- Default: 1
- padding (str): Support "same" and "corner", "corner" mode
- would pad zero to bottom right, and "same" mode would
- pad zero around input. Default: "corner".
- Example:
- >>> kernel_size = 16
- >>> stride = 16
- >>> dilation = 1
- >>> input = torch.rand(1, 1, 15, 17)
- >>> adap_pad = AdaptivePadding(
- >>> kernel_size=kernel_size,
- >>> stride=stride,
- >>> dilation=dilation,
- >>> padding="corner")
- >>> out = adap_pad(input)
- >>> assert (out.shape[2], out.shape[3]) == (16, 32)
- >>> input = torch.rand(1, 1, 16, 17)
- >>> out = adap_pad(input)
- >>> assert (out.shape[2], out.shape[3]) == (16, 32)
- """
-
- def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
-
- super(AdaptivePadding, self).__init__()
-
- assert padding in ('same', 'corner')
-
- kernel_size = to_2tuple(kernel_size)
- stride = to_2tuple(stride)
- padding = to_2tuple(padding)
- dilation = to_2tuple(dilation)
-
- self.padding = padding
- self.kernel_size = kernel_size
- self.stride = stride
- self.dilation = dilation
-
- def get_pad_shape(self, input_shape):
- input_h, input_w = input_shape
- kernel_h, kernel_w = self.kernel_size
- stride_h, stride_w = self.stride
- output_h = math.ceil(input_h / stride_h)
- output_w = math.ceil(input_w / stride_w)
- pad_h = max((output_h - 1) * stride_h +
- (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
- pad_w = max((output_w - 1) * stride_w +
- (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
- return pad_h, pad_w
-
- def forward(self, x):
- pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
- if pad_h > 0 or pad_w > 0:
- if self.padding == 'corner':
- x = F.pad(x, [0, pad_w, 0, pad_h])
- elif self.padding == 'same':
- x = F.pad(x, [
- pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
- pad_h - pad_h // 2
- ])
- return x
-
-
-class PatchEmbed(BaseModule):
- """Image to Patch Embedding.
-
- We use a conv layer to implement PatchEmbed.
-
- Args:
- in_channels (int): The num of input channels. Default: 3
- embed_dims (int): The dimensions of embedding. Default: 768
- conv_type (str): The config dict for embedding
- conv layer type selection. Default: "Conv2d.
- kernel_size (int): The kernel_size of embedding conv. Default: 16.
- stride (int): The slide stride of embedding conv.
- Default: None (Would be set as `kernel_size`).
- padding (int | tuple | string ): The padding length of
- embedding conv. When it is a string, it means the mode
- of adaptive padding, support "same" and "corner" now.
- Default: "corner".
- dilation (int): The dilation rate of embedding conv. Default: 1.
- bias (bool): Bias of embed conv. Default: True.
- norm_cfg (dict, optional): Config dict for normalization layer.
- Default: None.
- input_size (int | tuple | None): The size of input, which will be
- used to calculate the out size. Only work when `dynamic_size`
- is False. Default: None.
- init_cfg (`mmengine.ConfigDict`, optional): The Config for
- initialization. Default: None.
- """
-
- def __init__(
- self,
- in_channels=3,
- embed_dims=768,
- conv_type='Conv2d',
- kernel_size=16,
- stride=16,
- padding='corner',
- dilation=1,
- bias=True,
- norm_cfg=None,
- input_size=None,
- init_cfg=None,
- ):
- super(PatchEmbed, self).__init__(init_cfg=init_cfg)
-
- self.embed_dims = embed_dims
- if stride is None:
- stride = kernel_size
-
- kernel_size = to_2tuple(kernel_size)
- stride = to_2tuple(stride)
- dilation = to_2tuple(dilation)
-
- if isinstance(padding, str):
- self.adap_padding = AdaptivePadding(
- kernel_size=kernel_size,
- stride=stride,
- dilation=dilation,
- padding=padding)
- # disable the padding of conv
- padding = 0
- else:
- self.adap_padding = None
- padding = to_2tuple(padding)
-
- self.projection = build_conv_layer(
- dict(type=conv_type),
- in_channels=in_channels,
- out_channels=embed_dims,
- kernel_size=kernel_size,
- stride=stride,
- padding=padding,
- dilation=dilation,
- bias=bias)
-
- if norm_cfg is not None:
- self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
- else:
- self.norm = None
-
- if input_size:
- input_size = to_2tuple(input_size)
- # `init_out_size` would be used outside to
- # calculate the num_patches
- # when `use_abs_pos_embed` outside
- self.init_input_size = input_size
- if self.adap_padding:
- pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
- input_h, input_w = input_size
- input_h = input_h + pad_h
- input_w = input_w + pad_w
- input_size = (input_h, input_w)
-
- # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
- h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
- (kernel_size[0] - 1) - 1) // stride[0] + 1
- w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
- (kernel_size[1] - 1) - 1) // stride[1] + 1
- self.init_out_size = (h_out, w_out)
- else:
- self.init_input_size = None
- self.init_out_size = None
-
- def forward(self, x):
- """
- Args:
- x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
-
- Returns:
- tuple: Contains merged results and its spatial shape.
-
- - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
- - out_size (tuple[int]): Spatial shape of x, arrange as
- (out_h, out_w).
- """
-
- if self.adap_padding:
- x = self.adap_padding(x)
-
- x = self.projection(x)
- out_size = (x.shape[2], x.shape[3])
- x = x.flatten(2).transpose(1, 2)
- if self.norm is not None:
- x = self.norm(x)
- return x, out_size
-
-
-class PatchMerging(BaseModule):
- """Merge patch feature map.
-
- This layer groups feature map by kernel_size, and applies norm and linear
- layers to the grouped feature map. Our implementation uses `nn.Unfold` to
- merge patch, which is about 25% faster than original implementation.
- Instead, we need to modify pretrained models for compatibility.
-
- Args:
- in_channels (int): The num of input channels.
- to gets fully covered by filter and stride you specified..
- Default: True.
- out_channels (int): The num of output channels.
- kernel_size (int | tuple, optional): the kernel size in the unfold
- layer. Defaults to 2.
- stride (int | tuple, optional): the stride of the sliding blocks in the
- unfold layer. Default: None. (Would be set as `kernel_size`)
- padding (int | tuple | string ): The padding length of
- embedding conv. When it is a string, it means the mode
- of adaptive padding, support "same" and "corner" now.
- Default: "corner".
- dilation (int | tuple, optional): dilation parameter in the unfold
- layer. Default: 1.
- bias (bool, optional): Whether to add bias in linear layer or not.
- Defaults: False.
- norm_cfg (dict, optional): Config dict for normalization layer.
- Default: dict(type='LN').
- init_cfg (dict, optional): The extra config for initialization.
- Default: None.
- """
-
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size=2,
- stride=None,
- padding='corner',
- dilation=1,
- bias=False,
- norm_cfg=dict(type='LN'),
- init_cfg=None):
- super().__init__(init_cfg=init_cfg)
- self.in_channels = in_channels
- self.out_channels = out_channels
- if stride:
- stride = stride
- else:
- stride = kernel_size
-
- kernel_size = to_2tuple(kernel_size)
- stride = to_2tuple(stride)
- dilation = to_2tuple(dilation)
-
- if isinstance(padding, str):
- self.adap_padding = AdaptivePadding(
- kernel_size=kernel_size,
- stride=stride,
- dilation=dilation,
- padding=padding)
- # disable the padding of unfold
- padding = 0
- else:
- self.adap_padding = None
-
- padding = to_2tuple(padding)
- self.sampler = nn.Unfold(
- kernel_size=kernel_size,
- dilation=dilation,
- padding=padding,
- stride=stride)
-
- sample_dim = kernel_size[0] * kernel_size[1] * in_channels
-
- if norm_cfg is not None:
- self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
- else:
- self.norm = None
-
- self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
-
- def forward(self, x, input_size):
- """
- Args:
- x (Tensor): Has shape (B, H*W, C_in).
- input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
- Default: None.
-
- Returns:
- tuple: Contains merged results and its spatial shape.
-
- - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
- - out_size (tuple[int]): Spatial shape of x, arrange as
- (Merged_H, Merged_W).
- """
- B, L, C = x.shape
- assert isinstance(input_size, Sequence), f'Expect ' \
- f'input_size is ' \
- f'`Sequence` ' \
- f'but get {input_size}'
-
- H, W = input_size
- assert L == H * W, 'input feature has wrong size'
-
- x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
- # Use nn.Unfold to merge patch. About 25% faster than original method,
- # but need to modify pretrained model for compatibility
-
- if self.adap_padding:
- x = self.adap_padding(x)
- H, W = x.shape[-2:]
-
- x = self.sampler(x)
- # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)
-
- out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
- (self.sampler.kernel_size[0] - 1) -
- 1) // self.sampler.stride[0] + 1
- out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
- (self.sampler.kernel_size[1] - 1) -
- 1) // self.sampler.stride[1] + 1
-
- output_size = (out_h, out_w)
- x = x.transpose(1, 2) # B, H/2*W/2, 4*C
- x = self.norm(x) if self.norm else x
- x = self.reduction(x)
- return x, output_size
-
-
-def inverse_sigmoid(x, eps=1e-5):
- """Inverse function of sigmoid.
-
- Args:
- x (Tensor): The tensor to do the
- inverse.
- eps (float): EPS avoid numerical
- overflow. Defaults 1e-5.
- Returns:
- Tensor: The x has passed the inverse
- function of sigmoid, has same
- shape with input.
- """
- x = x.clamp(min=0, max=1)
- x1 = x.clamp(min=eps)
- x2 = (1 - x).clamp(min=eps)
- return torch.log(x1 / x2)
-
-
-@MODELS.register_module()
-class DetrTransformerDecoderLayer(BaseTransformerLayer):
- """Implements decoder layer in DETR transformer.
-
- Args:
- attn_cfgs (list[`mmengine.ConfigDict`] | list[dict] | dict )):
- Configs for self_attention or cross_attention, the order
- should be consistent with it in `operation_order`. If it is
- a dict, it would be expand to the number of attention in
- `operation_order`.
- feedforward_channels (int): The hidden dimension for FFNs.
- ffn_dropout (float): Probability of an element to be zeroed
- in ffn. Default 0.0.
- operation_order (tuple[str]): The execution order of operation
- in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
- Default:None
- act_cfg (dict): The activation config for FFNs. Default: `LN`
- norm_cfg (dict): Config dict for normalization layer.
- Default: `LN`.
- ffn_num_fcs (int): The number of fully-connected layers in FFNs.
- Default:2.
- """
-
- def __init__(self,
- attn_cfgs,
- feedforward_channels,
- ffn_dropout=0.0,
- operation_order=None,
- act_cfg=dict(type='ReLU', inplace=True),
- norm_cfg=dict(type='LN'),
- ffn_num_fcs=2,
- **kwargs):
- super(DetrTransformerDecoderLayer, self).__init__(
- attn_cfgs=attn_cfgs,
- feedforward_channels=feedforward_channels,
- ffn_dropout=ffn_dropout,
- operation_order=operation_order,
- act_cfg=act_cfg,
- norm_cfg=norm_cfg,
- ffn_num_fcs=ffn_num_fcs,
- **kwargs)
- assert len(operation_order) == 6
- assert set(operation_order) == set(
- ['self_attn', 'norm', 'cross_attn', 'ffn'])
-
-
-@MODELS.register_module()
-class DetrTransformerEncoder(TransformerLayerSequence):
- """TransformerEncoder of DETR.
-
- Args:
- post_norm_cfg (dict): Config of last normalization layer. Default:
- `LN`. Only used when `self.pre_norm` is `True`
- """
-
- def __init__(self, *args, post_norm_cfg=dict(type='LN'), **kwargs):
- super(DetrTransformerEncoder, self).__init__(*args, **kwargs)
- if post_norm_cfg is not None:
- self.post_norm = build_norm_layer(
- post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None
- else:
- assert not self.pre_norm, f'Use prenorm in ' \
- f'{self.__class__.__name__},' \
- f'Please specify post_norm_cfg'
- self.post_norm = None
-
- def forward(self, *args, **kwargs):
- """Forward function for `TransformerCoder`.
-
- Returns:
- Tensor: forwarded results with shape [num_query, bs, embed_dims].
- """
- x = super(DetrTransformerEncoder, self).forward(*args, **kwargs)
- if self.post_norm is not None:
- x = self.post_norm(x)
- return x
-
-
-@MODELS.register_module()
-class DetrTransformerDecoder(TransformerLayerSequence):
- """Implements the decoder in DETR transformer.
-
- Args:
- return_intermediate (bool): Whether to return intermediate outputs.
- post_norm_cfg (dict): Config of last normalization layer. Default:
- `LN`.
- """
-
- def __init__(self,
- *args,
- post_norm_cfg=dict(type='LN'),
- return_intermediate=False,
- **kwargs):
-
- super(DetrTransformerDecoder, self).__init__(*args, **kwargs)
- self.return_intermediate = return_intermediate
- if post_norm_cfg is not None:
- self.post_norm = build_norm_layer(post_norm_cfg,
- self.embed_dims)[1]
- else:
- self.post_norm = None
-
- def forward(self, query, *args, **kwargs):
- """Forward function for `TransformerDecoder`.
-
- Args:
- query (Tensor): Input query with shape
- `(num_query, bs, embed_dims)`.
-
- Returns:
- Tensor: Results with shape [1, num_query, bs, embed_dims] when
- return_intermediate is `False`, otherwise it has shape
- [num_layers, num_query, bs, embed_dims].
- """
- if not self.return_intermediate:
- x = super().forward(query, *args, **kwargs)
- if self.post_norm:
- x = self.post_norm(x)[None]
- return x
-
- intermediate = []
- for layer in self.layers:
- query = layer(query, *args, **kwargs)
- if self.return_intermediate:
- if self.post_norm is not None:
- intermediate.append(self.post_norm(query))
- else:
- intermediate.append(query)
- return torch.stack(intermediate)
-
-
-@MODELS.register_module()
-class Transformer(BaseModule):
- """Implements the DETR transformer.
-
- Following the official DETR implementation, this module copy-paste
- from torch.nn.Transformer with modifications:
-
- * positional encodings are passed in MultiheadAttention
- * extra LN at the end of encoder is removed
- * decoder returns a stack of activations from all decoding layers
-
- See `paper: End-to-End Object Detection with Transformers
- `_ for details.
-
- Args:
- encoder (`mmengine.ConfigDict` | Dict): Config of
- TransformerEncoder. Defaults to None.
- decoder ((`mmengine.ConfigDict` | Dict)): Config of
- TransformerDecoder. Defaults to None
- init_cfg (obj:`mmegine.ConfigDict`): The Config for initialization.
- Defaults to None.
- """
-
- def __init__(self, encoder=None, decoder=None, init_cfg=None):
- super(Transformer, self).__init__(init_cfg=init_cfg)
- self.encoder = build_transformer_layer_sequence(encoder)
- self.decoder = build_transformer_layer_sequence(decoder)
- self.embed_dims = self.encoder.embed_dims
-
- def init_weights(self):
- # follow the official DETR to init parameters
- for m in self.modules():
- if hasattr(m, 'weight') and m.weight.dim() > 1:
- xavier_init(m, distribution='uniform')
- self._is_init = True
-
- def forward(self, x, mask, query_embed, pos_embed):
- """Forward function for `Transformer`.
-
- Args:
- x (Tensor): Input query with shape [bs, c, h, w] where
- c = embed_dims.
- mask (Tensor): The key_padding_mask used for encoder and decoder,
- with shape [bs, h, w].
- query_embed (Tensor): The query embedding for decoder, with shape
- [num_query, c].
- pos_embed (Tensor): The positional encoding for encoder and
- decoder, with the same shape as `x`.
-
- Returns:
- tuple[Tensor]: results of decoder containing the following tensor.
-
- - out_dec: Output from decoder. If return_intermediate_dec \
- is True output has shape [num_dec_layers, bs,
- num_query, embed_dims], else has shape [1, bs, \
- num_query, embed_dims].
- - memory: Output results from encoder, with shape \
- [bs, embed_dims, h, w].
- """
- bs, c, h, w = x.shape
- # use `view` instead of `flatten` for dynamically exporting to ONNX
- x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c]
- pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)
- query_embed = query_embed.unsqueeze(1).repeat(
- 1, bs, 1) # [num_query, dim] -> [num_query, bs, dim]
- mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w]
- memory = self.encoder(
- query=x,
- key=None,
- value=None,
- query_pos=pos_embed,
- query_key_padding_mask=mask)
- target = torch.zeros_like(query_embed)
- # out_dec: [num_layers, num_query, bs, dim]
- out_dec = self.decoder(
- query=target,
- key=memory,
- value=memory,
- key_pos=pos_embed,
- query_pos=query_embed,
- key_padding_mask=mask)
- out_dec = out_dec.transpose(1, 2)
- memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
- return out_dec, memory
-
-
-@MODELS.register_module()
-class DeformableDetrTransformerDecoder(TransformerLayerSequence):
- """Implements the decoder in DETR transformer.
-
- Args:
- return_intermediate (bool): Whether to return intermediate outputs.
- coder_norm_cfg (dict): Config of last normalization layer. Default:
- `LN`.
- """
-
- def __init__(self, *args, return_intermediate=False, **kwargs):
-
- super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs)
- self.return_intermediate = return_intermediate
-
- def forward(self,
- query,
- *args,
- reference_points=None,
- valid_ratios=None,
- reg_branches=None,
- **kwargs):
- """Forward function for `TransformerDecoder`.
-
- Args:
- query (Tensor): Input query with shape
- `(num_query, bs, embed_dims)`.
- reference_points (Tensor): The reference
- points of offset. has shape
- (bs, num_query, 4) when as_two_stage,
- otherwise has shape ((bs, num_query, 2).
- valid_ratios (Tensor): The radios of valid
- points on the feature map, has shape
- (bs, num_levels, 2)
- reg_branch: (obj:`nn.ModuleList`): Used for
- refining the regression results. Only would
- be passed when with_box_refine is True,
- otherwise would be passed a `None`.
-
- Returns:
- Tensor: Results with shape [1, num_query, bs, embed_dims] when
- return_intermediate is `False`, otherwise it has shape
- [num_layers, num_query, bs, embed_dims].
- """
- output = query
- intermediate = []
- intermediate_reference_points = []
- for lid, layer in enumerate(self.layers):
- if reference_points.shape[-1] == 4:
- reference_points_input = reference_points[:, :, None] * \
- torch.cat([valid_ratios, valid_ratios], -1)[:, None]
- else:
- assert reference_points.shape[-1] == 2
- reference_points_input = reference_points[:, :, None] * \
- valid_ratios[:, None]
- output = layer(
- output,
- *args,
- reference_points=reference_points_input,
- **kwargs)
- output = output.permute(1, 0, 2)
-
- if reg_branches is not None:
- tmp = reg_branches[lid](output)
- if reference_points.shape[-1] == 4:
- new_reference_points = tmp + inverse_sigmoid(
- reference_points)
- new_reference_points = new_reference_points.sigmoid()
- else:
- assert reference_points.shape[-1] == 2
- new_reference_points = tmp
- new_reference_points[..., :2] = tmp[
- ..., :2] + inverse_sigmoid(reference_points)
- new_reference_points = new_reference_points.sigmoid()
- reference_points = new_reference_points.detach()
-
- output = output.permute(1, 0, 2)
- if self.return_intermediate:
- intermediate.append(output)
- intermediate_reference_points.append(reference_points)
-
- if self.return_intermediate:
- return torch.stack(intermediate), torch.stack(
- intermediate_reference_points)
-
- return output, reference_points
-
-
-@MODELS.register_module()
-class DeformableDetrTransformer(Transformer):
- """Implements the DeformableDETR transformer.
-
- Args:
- as_two_stage (bool): Generate query from encoder features.
- Default: False.
- num_feature_levels (int): Number of feature maps from FPN:
- Default: 4.
- two_stage_num_proposals (int): Number of proposals when set
- `as_two_stage` as True. Default: 300.
- """
-
- def __init__(self,
- as_two_stage=False,
- num_feature_levels=4,
- two_stage_num_proposals=300,
- **kwargs):
- super(DeformableDetrTransformer, self).__init__(**kwargs)
- self.as_two_stage = as_two_stage
- self.num_feature_levels = num_feature_levels
- self.two_stage_num_proposals = two_stage_num_proposals
- self.embed_dims = self.encoder.embed_dims
- self.init_layers()
-
- def init_layers(self):
- """Initialize layers of the DeformableDetrTransformer."""
- self.level_embeds = nn.Parameter(
- torch.Tensor(self.num_feature_levels, self.embed_dims))
-
- if self.as_two_stage:
- self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
- self.enc_output_norm = nn.LayerNorm(self.embed_dims)
- self.pos_trans = nn.Linear(self.embed_dims * 2,
- self.embed_dims * 2)
- self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
- else:
- self.reference_points = nn.Linear(self.embed_dims, 2)
-
- def init_weights(self):
- """Initialize the transformer weights."""
- for p in self.parameters():
- if p.dim() > 1:
- nn.init.xavier_uniform_(p)
- for m in self.modules():
- if isinstance(m, MultiScaleDeformableAttention):
- m.init_weights()
- if not self.as_two_stage:
- xavier_init(self.reference_points, distribution='uniform', bias=0.)
- normal_(self.level_embeds)
-
- def gen_encoder_output_proposals(self, memory, memory_padding_mask,
- spatial_shapes):
- """Generate proposals from encoded memory.
-
- Args:
- memory (Tensor) : The output of encoder,
- has shape (bs, num_key, embed_dim). num_key is
- equal the number of points on feature map from
- all level.
- memory_padding_mask (Tensor): Padding mask for memory.
- has shape (bs, num_key).
- spatial_shapes (Tensor): The shape of all feature maps.
- has shape (num_level, 2).
-
- Returns:
- tuple: A tuple of feature map and bbox prediction.
-
- - output_memory (Tensor): The input of decoder, \
- has shape (bs, num_key, embed_dim). num_key is \
- equal the number of points on feature map from \
- all levels.
- - output_proposals (Tensor): The normalized proposal \
- after a inverse sigmoid, has shape \
- (bs, num_keys, 4).
- """
-
- N, S, C = memory.shape
- proposals = []
- _cur = 0
- for lvl, (H, W) in enumerate(spatial_shapes):
- mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].view(
- N, H, W, 1)
- valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
- valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
-
- grid_y, grid_x = torch.meshgrid(
- torch.linspace(
- 0, H - 1, H, dtype=torch.float32, device=memory.device),
- torch.linspace(
- 0, W - 1, W, dtype=torch.float32, device=memory.device))
- grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
-
- scale = torch.cat([valid_W.unsqueeze(-1),
- valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
- grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
- wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
- proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
- proposals.append(proposal)
- _cur += (H * W)
- output_proposals = torch.cat(proposals, 1)
- output_proposals_valid = ((output_proposals > 0.01) &
- (output_proposals < 0.99)).all(
- -1, keepdim=True)
- output_proposals = torch.log(output_proposals / (1 - output_proposals))
- output_proposals = output_proposals.masked_fill(
- memory_padding_mask.unsqueeze(-1), float('inf'))
- output_proposals = output_proposals.masked_fill(
- ~output_proposals_valid, float('inf'))
-
- output_memory = memory
- output_memory = output_memory.masked_fill(
- memory_padding_mask.unsqueeze(-1), float(0))
- output_memory = output_memory.masked_fill(~output_proposals_valid,
- float(0))
- output_memory = self.enc_output_norm(self.enc_output(output_memory))
- return output_memory, output_proposals
-
- @staticmethod
- def get_reference_points(spatial_shapes, valid_ratios, device):
- """Get the reference points used in decoder.
-
- Args:
- spatial_shapes (Tensor): The shape of all
- feature maps, has shape (num_level, 2).
- valid_ratios (Tensor): The radios of valid
- points on the feature map, has shape
- (bs, num_levels, 2)
- device (obj:`device`): The device where
- reference_points should be.
-
- Returns:
- Tensor: reference points used in decoder, has \
- shape (bs, num_keys, num_levels, 2).
- """
- reference_points_list = []
- for lvl, (H, W) in enumerate(spatial_shapes):
- # TODO check this 0.5
- ref_y, ref_x = torch.meshgrid(
- torch.linspace(
- 0.5, H - 0.5, H, dtype=torch.float32, device=device),
- torch.linspace(
- 0.5, W - 0.5, W, dtype=torch.float32, device=device))
- ref_y = ref_y.reshape(-1)[None] / (
- valid_ratios[:, None, lvl, 1] * H)
- ref_x = ref_x.reshape(-1)[None] / (
- valid_ratios[:, None, lvl, 0] * W)
- ref = torch.stack((ref_x, ref_y), -1)
- reference_points_list.append(ref)
- reference_points = torch.cat(reference_points_list, 1)
- reference_points = reference_points[:, :, None] * valid_ratios[:, None]
- return reference_points
-
- def get_valid_ratio(self, mask):
- """Get the valid radios of feature maps of all level."""
- _, H, W = mask.shape
- valid_H = torch.sum(~mask[:, :, 0], 1)
- valid_W = torch.sum(~mask[:, 0, :], 1)
- valid_ratio_h = valid_H.float() / H
- valid_ratio_w = valid_W.float() / W
- valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
- return valid_ratio
-
- def get_proposal_pos_embed(self,
- proposals,
- num_pos_feats=128,
- temperature=10000):
- """Get the position embedding of proposal."""
- scale = 2 * math.pi
- dim_t = torch.arange(
- num_pos_feats, dtype=torch.float32, device=proposals.device)
- dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
- # N, L, 4
- proposals = proposals.sigmoid() * scale
- # N, L, 4, 128
- pos = proposals[:, :, :, None] / dim_t
- # N, L, 4, 64, 2
- pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
- dim=4).flatten(2)
- return pos
-
- def forward(self,
- mlvl_feats,
- mlvl_masks,
- query_embed,
- mlvl_pos_embeds,
- reg_branches=None,
- cls_branches=None,
- **kwargs):
- """Forward function for `Transformer`.
-
- Args:
- mlvl_feats (list(Tensor)): Input queries from
- different level. Each element has shape
- [bs, embed_dims, h, w].
- mlvl_masks (list(Tensor)): The key_padding_mask from
- different level used for encoder and decoder,
- each element has shape [bs, h, w].
- query_embed (Tensor): The query embedding for decoder,
- with shape [num_query, c].
- mlvl_pos_embeds (list(Tensor)): The positional encoding
- of feats from different level, has the shape
- [bs, embed_dims, h, w].
- reg_branches (obj:`nn.ModuleList`): Regression heads for
- feature maps from each decoder layer. Only would
- be passed when
- `with_box_refine` is True. Default to None.
- cls_branches (obj:`nn.ModuleList`): Classification heads
- for feature maps from each decoder layer. Only would
- be passed when `as_two_stage`
- is True. Default to None.
-
-
- Returns:
- tuple[Tensor]: results of decoder containing the following tensor.
-
- - inter_states: Outputs from decoder. If
- return_intermediate_dec is True output has shape \
- (num_dec_layers, bs, num_query, embed_dims), else has \
- shape (1, bs, num_query, embed_dims).
- - init_reference_out: The initial value of reference \
- points, has shape (bs, num_queries, 4).
- - inter_references_out: The internal value of reference \
- points in decoder, has shape \
- (num_dec_layers, bs,num_query, embed_dims)
- - enc_outputs_class: The classification score of \
- proposals generated from \
- encoder's feature maps, has shape \
- (batch, h*w, num_classes). \
- Only would be returned when `as_two_stage` is True, \
- otherwise None.
- - enc_outputs_coord_unact: The regression results \
- generated from encoder's feature maps., has shape \
- (batch, h*w, 4). Only would \
- be returned when `as_two_stage` is True, \
- otherwise None.
- """
- assert self.as_two_stage or query_embed is not None
-
- feat_flatten = []
- mask_flatten = []
- lvl_pos_embed_flatten = []
- spatial_shapes = []
- for lvl, (feat, mask, pos_embed) in enumerate(
- zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
- bs, c, h, w = feat.shape
- spatial_shape = (h, w)
- spatial_shapes.append(spatial_shape)
- feat = feat.flatten(2).transpose(1, 2)
- mask = mask.flatten(1)
- pos_embed = pos_embed.flatten(2).transpose(1, 2)
- lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
- lvl_pos_embed_flatten.append(lvl_pos_embed)
- feat_flatten.append(feat)
- mask_flatten.append(mask)
- feat_flatten = torch.cat(feat_flatten, 1)
- mask_flatten = torch.cat(mask_flatten, 1)
- lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
- spatial_shapes = torch.as_tensor(
- spatial_shapes, dtype=torch.long, device=feat_flatten.device)
- level_start_index = torch.cat((spatial_shapes.new_zeros(
- (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
- valid_ratios = torch.stack(
- [self.get_valid_ratio(m) for m in mlvl_masks], 1)
-
- reference_points = \
- self.get_reference_points(spatial_shapes,
- valid_ratios,
- device=feat.device)
-
- feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims)
- lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(
- 1, 0, 2) # (H*W, bs, embed_dims)
- memory = self.encoder(
- query=feat_flatten,
- key=None,
- value=None,
- query_pos=lvl_pos_embed_flatten,
- query_key_padding_mask=mask_flatten,
- spatial_shapes=spatial_shapes,
- reference_points=reference_points,
- level_start_index=level_start_index,
- valid_ratios=valid_ratios,
- **kwargs)
-
- memory = memory.permute(1, 0, 2)
- bs, _, c = memory.shape
- if self.as_two_stage:
- output_memory, output_proposals = \
- self.gen_encoder_output_proposals(
- memory, mask_flatten, spatial_shapes)
- enc_outputs_class = cls_branches[self.decoder.num_layers](
- output_memory)
- enc_outputs_coord_unact = \
- reg_branches[
- self.decoder.num_layers](output_memory) + output_proposals
-
- topk = self.two_stage_num_proposals
- # We only use the first channel in enc_outputs_class as foreground,
- # the other (num_classes - 1) channels are actually not used.
- # Its targets are set to be 0s, which indicates the first
- # class (foreground) because we use [0, num_classes - 1] to
- # indicate class labels, background class is indicated by
- # num_classes (similar convention in RPN).
- # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa
- # This follows the official implementation of Deformable DETR.
- topk_proposals = torch.topk(
- enc_outputs_class[..., 0], topk, dim=1)[1]
- topk_coords_unact = torch.gather(
- enc_outputs_coord_unact, 1,
- topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
- topk_coords_unact = topk_coords_unact.detach()
- reference_points = topk_coords_unact.sigmoid()
- init_reference_out = reference_points
- pos_trans_out = self.pos_trans_norm(
- self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
- query_pos, query = torch.split(pos_trans_out, c, dim=2)
- else:
- query_pos, query = torch.split(query_embed, c, dim=1)
- query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
- query = query.unsqueeze(0).expand(bs, -1, -1)
- reference_points = self.reference_points(query_pos).sigmoid()
- init_reference_out = reference_points
-
- # decoder
- query = query.permute(1, 0, 2)
- memory = memory.permute(1, 0, 2)
- query_pos = query_pos.permute(1, 0, 2)
- inter_states, inter_references = self.decoder(
- query=query,
- key=None,
- value=memory,
- query_pos=query_pos,
- key_padding_mask=mask_flatten,
- reference_points=reference_points,
- spatial_shapes=spatial_shapes,
- level_start_index=level_start_index,
- valid_ratios=valid_ratios,
- reg_branches=reg_branches,
- **kwargs)
-
- inter_references_out = inter_references
- if self.as_two_stage:
- return inter_states, init_reference_out,\
- inter_references_out, enc_outputs_class,\
- enc_outputs_coord_unact
- return inter_states, init_reference_out, \
- inter_references_out, None, None
-
-
-@MODELS.register_module()
-class DynamicConv(BaseModule):
- """Implements Dynamic Convolution.
-
- This module generate parameters for each sample and
- use bmm to implement 1*1 convolution. Code is modified
- from the `official github repo `_ .
-
- Args:
- in_channels (int): The input feature channel.
- Defaults to 256.
- feat_channels (int): The inner feature channel.
- Defaults to 64.
- out_channels (int, optional): The output feature channel.
- When not specified, it will be set to `in_channels`
- by default
- input_feat_shape (int): The shape of input feature.
- Defaults to 7.
- with_proj (bool): Project two-dimentional feature to
- one-dimentional feature. Default to True.
- act_cfg (dict): The activation config for DynamicConv.
- norm_cfg (dict): Config dict for normalization layer. Default
- layer normalization.
- init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
- Default: None.
- """
-
- def __init__(self,
- in_channels=256,
- feat_channels=64,
- out_channels=None,
- input_feat_shape=7,
- with_proj=True,
- act_cfg=dict(type='ReLU', inplace=True),
- norm_cfg=dict(type='LN'),
- init_cfg=None):
- super(DynamicConv, self).__init__(init_cfg)
- self.in_channels = in_channels
- self.feat_channels = feat_channels
- self.out_channels_raw = out_channels
- self.input_feat_shape = input_feat_shape
- self.with_proj = with_proj
- self.act_cfg = act_cfg
- self.norm_cfg = norm_cfg
- self.out_channels = out_channels if out_channels else in_channels
-
- self.num_params_in = self.in_channels * self.feat_channels
- self.num_params_out = self.out_channels * self.feat_channels
- self.dynamic_layer = nn.Linear(
- self.in_channels, self.num_params_in + self.num_params_out)
-
- self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
- self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]
-
- self.activation = build_activation_layer(act_cfg)
-
- num_output = self.out_channels * input_feat_shape**2
- if self.with_proj:
- self.fc_layer = nn.Linear(num_output, self.out_channels)
- self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
-
- def forward(self, param_feature, input_feature):
- """Forward function for `DynamicConv`.
-
- Args:
- param_feature (Tensor): The feature can be used
- to generate the parameter, has shape
- (num_all_proposals, in_channels).
- input_feature (Tensor): Feature that
- interact with parameters, has shape
- (num_all_proposals, in_channels, H, W).
-
- Returns:
- Tensor: The output feature has shape
- (num_all_proposals, out_channels).
- """
- input_feature = input_feature.flatten(2).permute(2, 0, 1)
-
- input_feature = input_feature.permute(1, 0, 2)
- parameters = self.dynamic_layer(param_feature)
-
- param_in = parameters[:, :self.num_params_in].view(
- -1, self.in_channels, self.feat_channels)
- param_out = parameters[:, -self.num_params_out:].view(
- -1, self.feat_channels, self.out_channels)
-
- # input_feature has shape (num_all_proposals, H*W, in_channels)
- # param_in has shape (num_all_proposals, in_channels, feat_channels)
- # feature has shape (num_all_proposals, H*W, feat_channels)
- features = torch.bmm(input_feature, param_in)
- features = self.norm_in(features)
- features = self.activation(features)
-
- # param_out has shape (batch_size, feat_channels, out_channels)
- features = torch.bmm(features, param_out)
- features = self.norm_out(features)
- features = self.activation(features)
-
- if self.with_proj:
- features = features.flatten(1)
- features = self.fc_layer(features)
- features = self.fc_norm(features)
- features = self.activation(features)
-
- return features
diff --git a/mmdet/models/layers/transformer/__init__.py b/mmdet/models/layers/transformer/__init__.py
new file mode 100644
index 00000000000..09e1c904936
--- /dev/null
+++ b/mmdet/models/layers/transformer/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .conditional_detr_transformer import (
+ ConditionalDetrTransformerDecoder, ConditionalDetrTransformerDecoderLayer)
+from .deformable_detr_transformer import (
+ DeformableDetrTransformerDecoder, DeformableDetrTransformerDecoderLayer,
+ DeformableDetrTransformerEncoder, DeformableDetrTransformerEncoderLayer)
+from .detr_transformer import (DetrTransformerDecoder,
+ DetrTransformerDecoderLayer,
+ DetrTransformerEncoder,
+ DetrTransformerEncoderLayer)
+from .utils import (MLP, AdaptivePadding, DynamicConv, PatchEmbed,
+ PatchMerging, inverse_sigmoid, nchw_to_nlc, nlc_to_nchw)
+
+__all__ = [
+ 'nlc_to_nchw', 'nchw_to_nlc', 'AdaptivePadding', 'PatchEmbed',
+ 'PatchMerging', 'inverse_sigmoid', 'DynamicConv', 'MLP',
+ 'DetrTransformerEncoder', 'DetrTransformerDecoder',
+ 'DetrTransformerEncoderLayer', 'DetrTransformerDecoderLayer',
+ 'DeformableDetrTransformerEncoder', 'DeformableDetrTransformerDecoder',
+ 'DeformableDetrTransformerEncoderLayer',
+ 'DeformableDetrTransformerDecoderLayer',
+ 'ConditionalDetrTransformerDecoder',
+ 'ConditionalDetrTransformerDecoderLayer'
+]
diff --git a/mmdet/models/layers/transformer/conditional_detr_transformer.py b/mmdet/models/layers/transformer/conditional_detr_transformer.py
new file mode 100644
index 00000000000..d7508d36199
--- /dev/null
+++ b/mmdet/models/layers/transformer/conditional_detr_transformer.py
@@ -0,0 +1,418 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Linear, build_norm_layer
+from mmcv.cnn.bricks.drop import Dropout
+from mmcv.cnn.bricks.transformer import FFN
+from mmengine.model import BaseModule
+from torch import Tensor
+from torch.nn import ModuleList
+
+from mmdet.utils import OptMultiConfig
+from .detr_transformer import (DetrTransformerDecoder,
+ DetrTransformerDecoderLayer)
+from .utils import MLP, gen_sine_embed_for_ref
+
+
+class ConditionalDetrTransformerDecoder(DetrTransformerDecoder):
+ """Decoder of Conditional DETR."""
+
+ def _init_layers(self) -> None:
+ """Initialize decoder layers."""
+ self.layers = ModuleList([
+ ConditionalDetrTransformerDecoderLayer(**self.layer_cfg)
+ for _ in range(self.num_layers)
+ ])
+ self.embed_dims = self.layers[0].embed_dims
+ self.post_norm = build_norm_layer(self.post_norm_cfg,
+ self.embed_dims)[1]
+ # conditional detr affline
+ self.query_scale = MLP(self.embed_dims, self.embed_dims,
+ self.embed_dims, 2)
+ self.ref_point_head = MLP(self.embed_dims, self.embed_dims, 2, 2)
+ for layer_id in range(self.num_layers - 1):
+ self.layers[layer_id + 1].cross_attn.qpos_proj = None
+
+ def forward(self, query: Tensor, key: Tensor, value: Tensor,
+ query_pos: Tensor, key_pos: Tensor, key_padding_mask: Tensor):
+ """Forward function of decoder.
+
+ Args:
+ query (Tensor): The input query with shape
+ (num_queries, bs, dim).
+ key (Tensor): The input key with shape (num_key, bs, dim) If
+ `None`, the `query` will be used. Defaults to `None`.
+ value (Tensor): The input value with the same shape as
+ `key`. If `None`, the `key` will be used. Defaults to `None`.
+ query_pos (Tensor): The positional encoding for `query`, with the
+ same shape as `query`. If not `None`, it will be added to
+ `query` before forward function. Defaults to `None`.
+ reg_branches (nn.Module): The regression branch for dynamically
+ updating references in each layer.
+ key_pos (Tensor): The positional encoding for `key`, with the
+ same shape as `key`.
+ key_padding_mask (Tensor): ByteTensor with shape (bs, num_key).
+ Returns:
+ List[Tensor]: forwarded results with shape (num_decoder_layers,
+ bs, num_queries, dim) if `return_intermediate` is True, otherwise
+ with shape (1, bs, num_queries, dim). references with shape
+ (num_decoder_layers, bs, num_queries, 2).
+ """
+ reference_unsigmoid = self.ref_point_head(
+ query_pos) # [num_queries, batch_size, 2]
+ reference = reference_unsigmoid.sigmoid().transpose(0, 1)
+ reference_xy = reference[..., :2].transpose(0, 1)
+ intermediate = []
+ for layer_id, layer in enumerate(self.layers):
+ if layer_id == 0:
+ pos_transformation = 1
+ else:
+ pos_transformation = self.query_scale(query)
+ # get sine embedding for the query vector
+ ref_sine_embed = gen_sine_embed_for_ref(reference_xy)
+ # apply transformation
+ ref_sine_embed = ref_sine_embed * pos_transformation
+ query = layer(
+ query,
+ key=key,
+ value=value,
+ query_pos=query_pos,
+ key_pos=key_pos,
+ key_padding_mask=key_padding_mask,
+ ref_sine_embed=ref_sine_embed,
+ is_first=(layer_id == 0))
+ if self.return_intermediate:
+ intermediate.append(self.post_norm(query))
+
+ if self.return_intermediate:
+ return torch.stack(intermediate), reference
+
+ return query, reference
+
+
+class ConditionalDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
+ """Implements decoder layer in Conditional DETR transformer."""
+
+ def _init_layers(self):
+ """Initialize self-attention, cross-attention, FFN, and
+ normalization."""
+ self.self_attn = ConditionalAttention(**self.self_attn_cfg)
+ self.cross_attn = ConditionalAttention(**self.cross_attn_cfg)
+ self.embed_dims = self.self_attn.embed_dims # TODO
+ self.ffn = FFN(**self.ffn_cfg)
+ norms_list = [
+ build_norm_layer(self.norm_cfg, self.embed_dims)[1]
+ for _ in range(3)
+ ]
+ self.norms = ModuleList(norms_list)
+
+ def forward(self,
+ query: Tensor,
+ key: Tensor = None,
+ value: Tensor = None,
+ query_pos: Tensor = None,
+ key_pos: Tensor = None,
+ self_attn_masks: Tensor = None,
+ cross_attn_masks: Tensor = None,
+ key_padding_mask: Tensor = None,
+ ref_sine_embed: Tensor = None,
+ is_first=None):
+ """
+ Args:
+ query (Tensor): The input query, has shape (num_queries, bs, dim)
+ key (Tensor, optional): The input key, has shape (num_key, bs, dim)
+ If `None`, the `query` will be used. Defaults to `None`.
+ value (Tensor, optional): The input value, has the same shape as
+ `key`, as in `nn.MultiheadAttention.forward`. If `None`, the
+ `key` will be used. Defaults to `None`.
+ query_pos (Tensor, optional): The positional encoding for `query`,
+ has the same shape as `query`. If not `None`, it will be
+ added to `query` before forward function. Defaults to `None`.
+ key_pos (Tensor, optional): The positional encoding for `key`, has
+ the same shape as `key`.
+ self_attn_masks (Tensor, optional): ByteTensor mask, has shape
+ (num_queries, num_key), Same in `nn.MultiheadAttention.
+ forward`. Defaults to None.
+ cross_attn_masks (Tensor, optional): ByteTensor mask, has shape
+ (num_queries, num_key), Same in `nn.MultiheadAttention.
+ forward`. Defaults to None.
+ key_padding_mask (Tensor, optional): The `key_padding_mask` of
+ `cross_attn` input. ByteTensor, has shape (bs, num_key).
+ is_first (bool): A indicator to tell whether the current layer
+ is the first layer of the decoder.
+ Defaults to False.
+
+ Returns:
+ Tensor: forwarded results, has shape (num_queries, bs, dim).
+ """
+ query = self.self_attn(
+ query=query,
+ key=query,
+ query_pos=query_pos,
+ key_pos=query_pos,
+ attn_mask=self_attn_masks)
+ query = self.norms[0](query)
+ query = self.cross_attn(
+ query=query,
+ key=key,
+ query_pos=query_pos,
+ key_pos=key_pos,
+ attn_mask=cross_attn_masks,
+ key_padding_mask=key_padding_mask,
+ ref_sine_embed=ref_sine_embed,
+ is_first=is_first)
+ query = self.norms[1](query)
+ query = self.ffn(query)
+ query = self.norms[2](query)
+
+ return query
+
+
+class ConditionalAttention(BaseModule):
+ """A wrapper of conditional attention, dropout and residual connection."""
+
+ def __init__(self,
+ embed_dims: int,
+ num_heads: int,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ cross_attn: bool = False,
+ keep_query_pos: bool = False,
+ init_cfg: OptMultiConfig = None,
+ group_detr=1):
+ super().__init__(init_cfg)
+ self.cross_attn = cross_attn
+ self.keep_query_pos = keep_query_pos
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.attn_drop = Dropout(attn_drop)
+ self.proj_drop = Dropout(proj_drop)
+
+ self._init_proj()
+ self.group_detr = group_detr
+
+ def _init_proj(self):
+ embed_dims = self.embed_dims
+ self.qcontent_proj = Linear(embed_dims, embed_dims)
+ self.qpos_proj = Linear(embed_dims, embed_dims)
+ self.kcontent_proj = Linear(embed_dims, embed_dims)
+ self.kpos_proj = Linear(embed_dims, embed_dims)
+ self.v_proj = Linear(embed_dims, embed_dims)
+ if self.cross_attn:
+ self.qpos_sine_proj = Linear(embed_dims, embed_dims)
+ self.out_proj = Linear(embed_dims, embed_dims)
+
+ nn.init.constant_(self.out_proj.bias, 0.) # init out_proj
+
+ def forward_attn(self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ attn_mask: Tensor,
+ key_padding_mask: Tensor = None) -> Tuple[Tensor]:
+ assert key.size(0) == value.size(0), \
+ f'{"key, value must have the same sequence length"}'
+ assert query.size(1) == key.size(1) == value.size(1), \
+ f'{"batch size must be equal for query, key, value"}'
+ assert query.size(2) == key.size(2), \
+ f'{"q_dims, k_dims must be equal"}'
+ assert value.size(2) == self.embed_dims, \
+ f'{"v_dims must be equal to embed_dims"}'
+
+ tgt_len, bs, hidden_dims = query.size()
+ head_dims = hidden_dims // self.num_heads
+ v_head_dims = self.embed_dims // self.num_heads
+ assert head_dims * self.num_heads == hidden_dims, \
+ f'{"hidden_dims must be divisible by num_heads"}'
+ scaling = float(head_dims)**-0.5
+
+ q = query * scaling
+ k = key
+ v = value
+
+ if attn_mask is not None:
+ assert attn_mask.dtype == torch.float32 or \
+ attn_mask.dtype == torch.float64 or \
+ attn_mask.dtype == torch.float16 or \
+ attn_mask.dtype == torch.uint8 or \
+ attn_mask.dtype == torch.bool, \
+ 'Only float, byte, and bool types are supported for \
+ attn_mask'
+
+ if attn_mask.dtype == torch.uint8:
+ warnings.warn('Byte tensor for attn_mask is deprecated.\
+ Use bool tensor instead.')
+ attn_mask = attn_mask.to(torch.bool)
+ if attn_mask.dim() == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+ raise RuntimeError(
+ 'The size of the 2D attn_mask is not correct.')
+ elif attn_mask.dim() == 3:
+ if list(attn_mask.size()) != [
+ bs * self.num_heads,
+ query.size(0),
+ key.size(0)
+ ]:
+ raise RuntimeError(
+ 'The size of the 3D attn_mask is not correct.')
+ else:
+ raise RuntimeError(
+ "attn_mask's dimension {} is not supported".format(
+ attn_mask.dim()))
+ # attn_mask's dim is 3 now.
+
+ if key_padding_mask is not None and key_padding_mask.dtype == int:
+ key_padding_mask = key_padding_mask.to(torch.bool)
+
+ q = q.contiguous().view(tgt_len, bs * self.num_heads,
+ head_dims).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bs * self.num_heads,
+ head_dims).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bs * self.num_heads,
+ v_head_dims).transpose(0, 1)
+
+ src_len = k.size(1)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bs
+ assert key_padding_mask.size(1) == src_len
+
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
+ assert list(attn_output_weights.size()) == [
+ bs * self.num_heads, tgt_len, src_len
+ ]
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_output_weights.masked_fill_(attn_mask, float('-inf'))
+ else:
+ attn_output_weights += attn_mask
+
+ if key_padding_mask is not None:
+ attn_output_weights = attn_output_weights.view(
+ bs, self.num_heads, tgt_len, src_len)
+ attn_output_weights = attn_output_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ float('-inf'),
+ )
+ attn_output_weights = attn_output_weights.view(
+ bs * self.num_heads, tgt_len, src_len)
+
+ attn_output_weights = F.softmax(attn_output_weights, dim=-1)
+ attn_output_weights = self.attn_drop(attn_output_weights)
+
+ attn_output = torch.bmm(attn_output_weights, v)
+ assert list(
+ attn_output.size()) == [bs * self.num_heads, tgt_len, v_head_dims]
+ attn_output = attn_output.transpose(0, 1).contiguous().view(
+ tgt_len, bs, self.embed_dims)
+ attn_output = self.out_proj(attn_output)
+
+ # average attention weights over heads
+ attn_output_weights = attn_output_weights.view(bs, self.num_heads,
+ tgt_len, src_len)
+ return attn_output, attn_output_weights.sum(dim=1) / self.num_heads
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ query_pos: Tensor = None,
+ ref_sine_embed: Tensor = None,
+ key_pos: Tensor = None, # pos
+ attn_mask: Tensor = None,
+ key_padding_mask: Tensor = None,
+ is_first: bool = False) -> Tensor:
+ """Forward function for `ConditionalAttention`.
+ Args:
+ query (Tensor): The input query with shape [num_queries, bs,
+ embed_dims] if self.batch_first is False, else
+ [bs, num_queries embed_dims].
+ key (Tensor): The key tensor with shape [num_keys, bs,
+ embed_dims] if self.batch_first is False, else
+ [bs, num_keys, embed_dims] .
+ If None, the ``query`` will be used. Defaults to None.
+ query_pos (Tensor): The positional encoding for query in self
+ attention, with the same shape as `x`. If not None, it will
+ be added to `x` before forward function.
+ Defaults to None.
+ query_sine_embed (Tensor): The positional encoding for query in
+ cross attention, with the same shape as `x`. If not None, it
+ will be added to `x` before forward function.
+ Defaults to None.
+ key_pos (Tensor): The positional encoding for `key`, with the
+ same shape as `key`. Defaults to None. If not None, it will
+ be added to `key` before forward function. If None, and
+ `query_pos` has the same shape as `key`, then `query_pos`
+ will be used for `key_pos`. Defaults to None.
+ attn_mask (Tensor): ByteTensor mask with shape [num_queries,
+ num_keys]. Same in `nn.MultiheadAttention.forward`.
+ Defaults to None.
+ key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
+ Defaults to None.
+ is_first (bool): A indicator to tell whether the current layer
+ is the first layer of the decoder.
+ Defaults to False.
+ Returns:
+ Tensor: forwarded results with shape [num_queries, bs, embed_dims].
+ """
+ if self.cross_attn:
+ q_content = self.qcontent_proj(query)
+ k_content = self.kcontent_proj(key)
+ v = self.v_proj(key)
+
+ nq, bs, c = q_content.size()
+ hw, _, _ = k_content.size()
+
+ k_pos = self.kpos_proj(key_pos)
+ if is_first or self.keep_query_pos:
+ q_pos = self.qpos_proj(query_pos)
+ q = q_content + q_pos
+ k = k_content + k_pos
+ else:
+ q = q_content
+ k = k_content
+ q = q.view(nq, bs, self.num_heads, c // self.num_heads)
+ query_sine_embed = self.qpos_sine_proj(ref_sine_embed)
+ query_sine_embed = query_sine_embed.view(nq, bs, self.num_heads,
+ c // self.num_heads)
+ q = torch.cat([q, query_sine_embed], dim=3).view(nq, bs, 2 * c)
+ k = k.view(hw, bs, self.num_heads, c // self.num_heads)
+ k_pos = k_pos.view(hw, bs, self.num_heads, c // self.num_heads)
+ k = torch.cat([k, k_pos], dim=3).view(hw, bs, 2 * c)
+ ca_output = self.forward_attn(
+ query=q,
+ key=k,
+ value=v,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask)[0]
+ query = query + self.proj_drop(ca_output)
+ else:
+ q_content = self.qcontent_proj(query)
+ q_pos = self.qpos_proj(query_pos)
+ k_content = self.kcontent_proj(query)
+ k_pos = self.kpos_proj(query_pos)
+ v = self.v_proj(query)
+ num_queries, bs, _ = q_content.shape
+ q = q_content if q_pos is None else q_content + q_pos
+ k = k_content if k_pos is None else k_content + k_pos
+ if self.training:
+ q = torch.cat(
+ q.split(num_queries // self.group_detr, dim=0), dim=1)
+ k = torch.cat(
+ k.split(num_queries // self.group_detr, dim=0), dim=1)
+ v = torch.cat(
+ v.split(num_queries // self.group_detr, dim=0), dim=1)
+ sa_output = self.forward_attn(
+ query=q, key=k, value=v, attn_mask=attn_mask)[0]
+ if self.training:
+ sa_output = torch.cat(sa_output.split(bs, dim=1), dim=0)
+ query = query + self.proj_drop(sa_output)
+ return query
diff --git a/mmdet/models/layers/transformer/deformable_detr_transformer.py b/mmdet/models/layers/transformer/deformable_detr_transformer.py
new file mode 100644
index 00000000000..311d519f7d9
--- /dev/null
+++ b/mmdet/models/layers/transformer/deformable_detr_transformer.py
@@ -0,0 +1,251 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple, Union
+
+import torch
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmcv.ops import MultiScaleDeformableAttention
+from mmengine.model import ModuleList
+from torch import Tensor, nn
+
+from .detr_transformer import (DetrTransformerDecoder,
+ DetrTransformerDecoderLayer,
+ DetrTransformerEncoder,
+ DetrTransformerEncoderLayer)
+from .utils import inverse_sigmoid
+
+
+class DeformableDetrTransformerEncoder(DetrTransformerEncoder):
+ """Transformer encoder of Deformable DETR."""
+
+ def _init_layers(self) -> None:
+ """Initialize encoder layers."""
+ self.layers = ModuleList([
+ DeformableDetrTransformerEncoderLayer(**self.layer_cfg)
+ for _ in range(self.num_layers)
+ ])
+ self.embed_dims = self.layers[0].embed_dims
+
+ def forward(self, query: Tensor, query_pos: Tensor,
+ key_padding_mask: Tensor, spatial_shapes: Tensor,
+ level_start_index: Tensor, valid_ratios: Tensor,
+ **kwargs) -> Tensor:
+ """Forward function of Transformer encoder.
+
+ Args:
+ query (Tensor): The input query, has shape (num_queries, bs, dim).
+ query_pos (Tensor): The positional encoding for query, has shape
+ (num_queries, bs, dim). If not None, it will be added to the
+ `query` before forward function. Defaults to None.
+ key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
+ input. ByteTensor, has shape (num_queries, bs).
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape (num_levels, ) and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ valid_ratios (Tensor): The ratios of the valid width and the valid
+ height relative to the width and the height of features in all
+ levels, has shape (bs, num_levels, 2).
+
+ Returns:
+ Tensor: Output queries of Transformer encoder, which is also
+ called 'encoder output embeddings' or 'memory', has shape
+ (num_queries, bs, dim)
+ """
+ reference_points = self.get_encoder_reference_points(
+ spatial_shapes, valid_ratios, device=query.device)
+ for layer in self.layers:
+ query = layer(
+ query=query,
+ query_pos=query_pos,
+ key_padding_mask=key_padding_mask,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ reference_points=reference_points,
+ **kwargs)
+ return query
+
+ @staticmethod
+ def get_encoder_reference_points(
+ spatial_shapes: Tensor, valid_ratios: Tensor,
+ device: Union[torch.device, str]) -> Tensor:
+ """Get the reference points used in encoder.
+
+ Args:
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ valid_ratios (Tensor): The ratios of the valid width and the valid
+ height relative to the width and the height of features in all
+ levels, has shape (bs, num_levels, 2).
+ device (obj:`device` or str): The device acquired by the
+ `reference_points`.
+
+ Returns:
+ Tensor: Reference points used in decoder, has shape (bs, length,
+ num_levels, 2).
+ """
+
+ reference_points_list = []
+ for lvl, (H, W) in enumerate(spatial_shapes):
+ ref_y, ref_x = torch.meshgrid(
+ torch.linspace(
+ 0.5, H - 0.5, H, dtype=torch.float32, device=device),
+ torch.linspace(
+ 0.5, W - 0.5, W, dtype=torch.float32, device=device))
+ ref_y = ref_y.reshape(-1)[None] / (
+ valid_ratios[:, None, lvl, 1] * H)
+ ref_x = ref_x.reshape(-1)[None] / (
+ valid_ratios[:, None, lvl, 0] * W)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ # [bs, sum(hw), num_level, 2]
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+
+class DeformableDetrTransformerDecoder(DetrTransformerDecoder):
+ """Transformer Decoder of Deformable DETR."""
+
+ def _init_layers(self) -> None:
+ """Initialize decoder layers."""
+ self.layers = ModuleList([
+ DeformableDetrTransformerDecoderLayer(**self.layer_cfg)
+ for _ in range(self.num_layers)
+ ])
+ if self.post_norm_cfg is not None:
+ raise ValueError('There is not post_norm in '
+ 'DeformableDetrTransformerDecoder')
+
+ def forward(self,
+ query: Tensor,
+ query_pos: Tensor,
+ value: Tensor,
+ key_padding_mask: Tensor,
+ reference_points: Tensor,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ valid_ratios: Tensor,
+ reg_branches: Optional[nn.Module] = None,
+ **kwargs) -> Tuple[Tensor]:
+ """Forward function of Transformer decoder.
+
+ Args:
+ query (Tensor): The input queries, has shape
+ (num_queries, bs, dim).
+ query_pos (Tensor): The input positional query, has shape
+ (num_queries, bs, dim). It will be added to `query` before
+ forward function.
+ value (Tensor): The input values, has shape (num_value, bs, dim).
+ key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
+ input. ByteTensor, has shape (num_value, bs).
+ reference_points (Tensor): The initial reference, has shape
+ (bs, num_queries, 4) when `as_two_stage` is `True`,
+ otherwise has shape (bs, num_queries, 2).
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape (num_levels, ) and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ valid_ratios (Tensor): The ratios of the valid width and the valid
+ height relative to the width and the height of features in all
+ levels, has shape (bs, num_levels, 2).
+ reg_branches: (obj:`nn.ModuleList`, optional): Used for refining
+ the regression results. Only would be passed when
+ `with_box_refine` is `True`, otherwise would be `None`.
+
+ Returns:
+ tuple[Tensor]: Outputs of Deformable Transformer Decoder.
+
+ - output (Tensor): Output embeddings of the last decoder, has
+ shape (num_queries, bs, embed_dims) when `return_intermediate`
+ is `False`. Otherwise, Intermediate output embeddings of all
+ decoder layers, has shape (num_decoder_layers, num_queries, bs,
+ embed_dims).
+ - reference_points (Tensor): The reference of the last decoder
+ layer, has shape (bs, num_queries, 4) when `return_intermediate`
+ is `False`. Otherwise, Intermediate references of all decoder
+ layers, has shape (num_decoder_layers, bs, num_queries, 4).
+ """
+ output = query
+ intermediate = []
+ intermediate_reference_points = []
+ for layer_id, layer in enumerate(self.layers):
+ if reference_points.shape[-1] == 4:
+ reference_points_input = \
+ reference_points[:, :, None] * \
+ torch.cat([valid_ratios, valid_ratios], -1)[:, None]
+ else:
+ assert reference_points.shape[-1] == 2
+ reference_points_input = \
+ reference_points[:, :, None] * \
+ valid_ratios[:, None]
+ output = layer(
+ output,
+ query_pos=query_pos,
+ value=value,
+ key_padding_mask=key_padding_mask,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ reference_points=reference_points_input,
+ **kwargs)
+ output = output.permute(1, 0, 2)
+
+ if reg_branches is not None:
+ tmp = reg_branches[layer_id](output)
+ if reference_points.shape[-1] == 4:
+ new_reference_points = tmp + inverse_sigmoid(
+ reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ else:
+ assert reference_points.shape[-1] == 2
+ new_reference_points = tmp
+ new_reference_points[..., :2] = tmp[
+ ..., :2] + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ reference_points = new_reference_points.detach()
+
+ output = output.permute(1, 0, 2)
+ if self.return_intermediate:
+ intermediate.append(output)
+ intermediate_reference_points.append(reference_points)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate), torch.stack(
+ intermediate_reference_points)
+
+ return output, reference_points
+
+
+class DeformableDetrTransformerEncoderLayer(DetrTransformerEncoderLayer):
+ """Encoder layer of Deformable DETR."""
+
+ def _init_layers(self) -> None:
+ """Initialize self_attn, ffn, and norms."""
+ self.self_attn = MultiScaleDeformableAttention(**self.self_attn_cfg)
+ self.embed_dims = self.self_attn.embed_dims
+ self.ffn = FFN(**self.ffn_cfg)
+ norms_list = [
+ build_norm_layer(self.norm_cfg, self.embed_dims)[1]
+ for _ in range(2)
+ ]
+ self.norms = ModuleList(norms_list)
+
+
+class DeformableDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
+ """Decoder layer of Deformable DETR."""
+
+ def _init_layers(self) -> None:
+ """Initialize self_attn, cross-attn, ffn, and norms."""
+ self.self_attn = MultiheadAttention(**self.self_attn_cfg)
+ self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg)
+ self.embed_dims = self.self_attn.embed_dims
+ self.ffn = FFN(**self.ffn_cfg)
+ norms_list = [
+ build_norm_layer(self.norm_cfg, self.embed_dims)[1]
+ for _ in range(3)
+ ]
+ self.norms = ModuleList(norms_list)
diff --git a/mmdet/models/layers/transformer/detr_transformer.py b/mmdet/models/layers/transformer/detr_transformer.py
new file mode 100644
index 00000000000..01dc885e306
--- /dev/null
+++ b/mmdet/models/layers/transformer/detr_transformer.py
@@ -0,0 +1,330 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
+import torch
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmengine import ConfigDict
+from mmengine.model import BaseModule, ModuleList
+from torch import Tensor
+
+from mmdet.utils import ConfigType, OptConfigType
+
+
+class DetrTransformerEncoder(BaseModule):
+ """Encoder of DETR.
+
+ Args:
+ num_layers (int): Number of encoder layers.
+ layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder
+ layer. All the layers will share the same config.
+ init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
+ the initialization. Defaults to None.
+ """
+
+ def __init__(self,
+ num_layers: int,
+ layer_cfg: ConfigType,
+ init_cfg: OptConfigType = None) -> None:
+
+ super().__init__(init_cfg=init_cfg)
+ self.num_layers = num_layers
+ self.layer_cfg = layer_cfg
+ self._init_layers()
+
+ def _init_layers(self) -> None:
+ """Initialize encoder layers."""
+ self.layers = ModuleList([
+ DetrTransformerEncoderLayer(**self.layer_cfg)
+ for _ in range(self.num_layers)
+ ])
+ self.embed_dims = self.layers[0].embed_dims
+
+ def forward(self, query: Tensor, query_pos: Tensor,
+ key_padding_mask: Tensor, **kwargs) -> Tensor:
+ """Forward function of encoder.
+
+ Args:
+ query (Tensor): Input queries of encoder, has shape
+ (num_queries, bs, dim).
+ query_pos (Tensor): The positional embeddings of the queries, has
+ shape (num_queries, bs, dim).
+ key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
+ input. ByteTensor, has shape (num_queries, bs).
+
+ Returns:
+ Tensor: Has shape (num_queries, bs, dim).
+ """
+ for layer in self.layers:
+ query = layer(query, query_pos, key_padding_mask, **kwargs)
+ return query
+
+
+class DetrTransformerDecoder(BaseModule):
+ """Decoder of DETR.
+
+ Args:
+ num_layers (int): Number of decoder layers.
+ layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder
+ layer. All the layers will share the same config.
+ post_norm_cfg (:obj:`ConfigDict` or dict, optional): Config of the
+ post normalization layer. Defaults to `LN`.
+ return_intermediate (bool, optional): Whether to return outputs of
+ intermediate layers. Defaults to `True`,
+ init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
+ the initialization. Defaults to None.
+ """
+
+ def __init__(self,
+ num_layers: int,
+ layer_cfg: ConfigType,
+ post_norm_cfg: OptConfigType = dict(type='LN'),
+ return_intermediate: bool = True,
+ init_cfg: Union[dict, ConfigDict] = None) -> None:
+ super().__init__(init_cfg=init_cfg)
+ self.layer_cfg = layer_cfg
+ self.num_layers = num_layers
+ self.post_norm_cfg = post_norm_cfg
+ self.return_intermediate = return_intermediate
+ self._init_layers()
+
+ def _init_layers(self) -> None:
+ """Initialize decoder layers."""
+ self.layers = ModuleList([
+ DetrTransformerDecoderLayer(**self.layer_cfg)
+ for _ in range(self.num_layers)
+ ])
+ self.embed_dims = self.layers[0].embed_dims
+ self.post_norm = build_norm_layer(self.post_norm_cfg,
+ self.embed_dims)[1]
+
+ def forward(self, query: Tensor, key: Tensor, value: Tensor,
+ query_pos: Tensor, key_pos: Tensor, key_padding_mask: Tensor,
+ **kwargs) -> Tensor:
+ """Forward function of decoder
+ Args:
+ query (Tensor): The input query, has shape (num_queries, bs, dim).
+ key (Tensor): The input key, has shape (num_key, bs, dim) if. If
+ `None`, the `query` will be used. Defaults to `None`.
+ value (Tensor): The input value with the same shape as `key`.
+ If `None`, the `key` will be used. Defaults to `None`.
+ query_pos (Tensor): The positional encoding for `query`, with the
+ same shape as `query`. If not `None`, it will be added to
+ `query` before forward function. Defaults to `None`.
+ key_pos (Tensor): The positional encoding for `key`, with the
+ same shape as `key`. If not `None`, it will be added to
+ `key` before forward function. If `None`, and `query_pos`
+ has the same shape as `key`, then `query_pos` will be used
+ as `key_pos`. Defaults to `None`.
+ key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
+ input. ByteTensor, has shape (num_value, bs).
+
+ Returns:
+ Tensor: The forwarded results will have shape (num_decoder_layers,
+ num_queries, bs, dim) if `return_intermediate` is `True` else
+ (num_queries, bs, dim).
+ """
+ intermediate = []
+ for layer in self.layers:
+ query = layer(
+ query,
+ key=key,
+ value=value,
+ query_pos=query_pos,
+ key_pos=key_pos,
+ key_padding_mask=key_padding_mask,
+ **kwargs)
+ if self.return_intermediate:
+ intermediate.append(self.post_norm(query))
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return query
+
+
+class DetrTransformerEncoderLayer(BaseModule):
+ """Implements encoder layer in DETR transformer.
+
+ Args:
+ self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
+ attention.
+ ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
+ norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
+ normalization layers. All the layers will share the same
+ config. Defaults to `LN`.
+ init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
+ the initialization. Defaults to None.
+ """
+
+ def __init__(self,
+ self_attn_cfg: OptConfigType = dict(
+ embed_dims=256, num_heads=8, dropout=0.0),
+ ffn_cfg: OptConfigType = dict(
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ ffn_drop=0.,
+ act_cfg=dict(type='ReLU', inplace=True)),
+ norm_cfg: OptConfigType = dict(type='LN'),
+ init_cfg: OptConfigType = None) -> None:
+
+ super().__init__(init_cfg=init_cfg)
+ self.self_attn_cfg = self_attn_cfg
+ self.ffn_cfg = ffn_cfg
+ self.norm_cfg = norm_cfg
+ self._init_layers()
+
+ def _init_layers(self) -> None:
+ """Initialize self-attention, FFN, and normalization."""
+ self.self_attn = MultiheadAttention(**self.self_attn_cfg)
+ self.embed_dims = self.self_attn.embed_dims
+ self.ffn = FFN(**self.ffn_cfg)
+ norms_list = [
+ build_norm_layer(self.norm_cfg, self.embed_dims)[1]
+ for _ in range(2)
+ ]
+ self.norms = ModuleList(norms_list)
+
+ def forward(self, query: Tensor, query_pos: Tensor,
+ key_padding_mask: Tensor, **kwargs) -> Tensor:
+ """Forward function of an encoder layer.
+
+ Args:
+ query (Tensor): The input query, has shape (num_queries, bs, dim).
+ query_pos (Tensor): The positional encoding for query, with
+ the same shape as `query`. If not None, it will
+ be added to `query` before forward function. Defaults to None.
+ key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
+ input. ByteTensor. has shape (num_queries, bs).
+ Returns:
+ Tensor: forwarded results, has shape (num_queries, bs, dim).
+ """
+ query = self.self_attn(
+ query=query,
+ key=query,
+ value=query,
+ query_pos=query_pos,
+ key_pos=query_pos,
+ key_padding_mask=key_padding_mask,
+ **kwargs)
+ query = self.norms[0](query)
+ query = self.ffn(query)
+ query = self.norms[1](query)
+
+ return query
+
+
+class DetrTransformerDecoderLayer(BaseModule):
+ """Implements decoder layer in DETR transformer.
+
+ Args:
+ self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
+ attention.
+ cross_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for cross
+ attention.
+ ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
+ norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
+ normalization layers. All the layers will share the same
+ config. Defaults to `LN`.
+ init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
+ the initialization. Defaults to None.
+ """
+
+ def __init__(self,
+ self_attn_cfg: OptConfigType = dict(
+ embed_dims=256, num_heads=8, dropout=0.0),
+ cross_attn_cfg: OptConfigType = dict(
+ embed_dims=256, num_heads=8, dropout=0.0),
+ ffn_cfg: OptConfigType = dict(
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ ffn_drop=0.,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ),
+ norm_cfg: OptConfigType = dict(type='LN'),
+ init_cfg: OptConfigType = None) -> None:
+
+ super().__init__(init_cfg=init_cfg)
+ self.self_attn_cfg = self_attn_cfg
+ self.cross_attn_cfg = cross_attn_cfg
+ self.ffn_cfg = ffn_cfg
+ self.norm_cfg = norm_cfg
+ self._init_layers()
+
+ def _init_layers(self) -> None:
+ """Initialize self-attention, FFN, and normalization."""
+ self.self_attn = MultiheadAttention(**self.self_attn_cfg)
+ self.cross_attn = MultiheadAttention(**self.cross_attn_cfg)
+ self.embed_dims = self.self_attn.embed_dims
+ self.ffn = FFN(**self.ffn_cfg)
+ norms_list = [
+ build_norm_layer(self.norm_cfg, self.embed_dims)[1]
+ for _ in range(3)
+ ]
+ self.norms = ModuleList(norms_list)
+
+ def forward(self,
+ query: Tensor,
+ key: Tensor = None,
+ value: Tensor = None,
+ query_pos: Tensor = None,
+ key_pos: Tensor = None,
+ self_attn_masks: Tensor = None,
+ cross_attn_masks: Tensor = None,
+ key_padding_mask: Tensor = None,
+ **kwargs) -> Tensor:
+ """
+ Args:
+ query (Tensor): The input query, has shape (num_queries, bs, dim).
+ key (Tensor, optional): The input key, has shape (num_key, bs,
+ dim). If `None`, the `query` will be used. Defaults to `None`.
+ value (Tensor, optional): The input value, has the same shape as
+ `key`, as in `nn.MultiheadAttention.forward`. If `None`, the
+ `key` will be used. Defaults to `None`.
+ query_pos (Tensor, optional): The positional encoding for `query`,
+ has the same shape as `query`. If not `None`, it will be added
+ to `query` before forward function. Defaults to `None`.
+ key_pos (Tensor, optional): The positional encoding for `key`, has
+ the same shape as `key`. If not `None`, it will be added to
+ `key` before forward function. If None, and `query_pos` has the
+ same shape as `key`, then `query_pos` will be used for
+ `key_pos`. Defaults to None.
+ self_attn_masks (Tensor, optional): ByteTensor mask, has shape
+ (num_queries, num_key), as in `nn.MultiheadAttention.forward`.
+ Defaults to None.
+ cross_attn_masks (Tensor, optional): ByteTensor mask, has shape
+ (num_queries, num_key), as in `nn.MultiheadAttention.forward`.
+ Defaults to None.
+ key_padding_mask (Tensor, optional): The `key_padding_mask` of
+ `self_attn` input. ByteTensor, has shape (num_value, bs).
+ Defaults to None.
+
+ Returns:
+ Tensor: forwarded results, has shape (num_queries, bs, dim).
+ """
+
+ query = self.self_attn(
+ query=query,
+ key=query,
+ value=query,
+ query_pos=query_pos,
+ key_pos=query_pos,
+ attn_mask=self_attn_masks,
+ **kwargs)
+ query = self.norms[0](query)
+ query = self.cross_attn(
+ query=query,
+ key=key,
+ value=value,
+ query_pos=query_pos,
+ key_pos=key_pos,
+ attn_mask=cross_attn_masks,
+ key_padding_mask=key_padding_mask,
+ **kwargs)
+ query = self.norms[1](query)
+ query = self.ffn(query)
+ query = self.norms[2](query)
+
+ return query
diff --git a/mmdet/models/layers/transformer/utils.py b/mmdet/models/layers/transformer/utils.py
new file mode 100644
index 00000000000..287220cebea
--- /dev/null
+++ b/mmdet/models/layers/transformer/utils.py
@@ -0,0 +1,561 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
+ build_norm_layer)
+from mmengine.model import BaseModule, ModuleList
+from mmengine.utils import to_2tuple
+from torch import Tensor, nn
+
+from mmdet.registry import MODELS
+from mmdet.utils import OptConfigType
+
+
+def nlc_to_nchw(x: Tensor, hw_shape: Sequence[int]) -> Tensor:
+ """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
+
+ Args:
+ x (Tensor): The input tensor of shape [N, L, C] before conversion.
+ hw_shape (Sequence[int]): The height and width of output feature map.
+
+ Returns:
+ Tensor: The output tensor of shape [N, C, H, W] after conversion.
+ """
+ H, W = hw_shape
+ assert len(x.shape) == 3
+ B, L, C = x.shape
+ assert L == H * W, 'The seq_len does not match H, W'
+ return x.transpose(1, 2).reshape(B, C, H, W).contiguous()
+
+
+def nchw_to_nlc(x):
+ """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
+
+ Args:
+ x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
+
+ Returns:
+ Tensor: The output tensor of shape [N, L, C] after conversion.
+ """
+ assert len(x.shape) == 4
+ return x.flatten(2).transpose(1, 2).contiguous()
+
+
+class AdaptivePadding(nn.Module):
+ """Applies padding to input (if needed) so that input can get fully covered
+ by filter you specified. It support two modes "same" and "corner". The
+ "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
+ input. The "corner" mode would pad zero to bottom right.
+
+ Args:
+ kernel_size (int | tuple): Size of the kernel:
+ stride (int | tuple): Stride of the filter. Default: 1:
+ dilation (int | tuple): Spacing between kernel elements.
+ Default: 1
+ padding (str): Support "same" and "corner", "corner" mode
+ would pad zero to bottom right, and "same" mode would
+ pad zero around input. Default: "corner".
+ Example:
+ >>> kernel_size = 16
+ >>> stride = 16
+ >>> dilation = 1
+ >>> input = torch.rand(1, 1, 15, 17)
+ >>> adap_pad = AdaptivePadding(
+ >>> kernel_size=kernel_size,
+ >>> stride=stride,
+ >>> dilation=dilation,
+ >>> padding="corner")
+ >>> out = adap_pad(input)
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+ >>> input = torch.rand(1, 1, 16, 17)
+ >>> out = adap_pad(input)
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+ """
+
+ def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
+
+ super(AdaptivePadding, self).__init__()
+
+ assert padding in ('same', 'corner')
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ padding = to_2tuple(padding)
+ dilation = to_2tuple(dilation)
+
+ self.padding = padding
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.dilation = dilation
+
+ def get_pad_shape(self, input_shape):
+ input_h, input_w = input_shape
+ kernel_h, kernel_w = self.kernel_size
+ stride_h, stride_w = self.stride
+ output_h = math.ceil(input_h / stride_h)
+ output_w = math.ceil(input_w / stride_w)
+ pad_h = max((output_h - 1) * stride_h +
+ (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
+ pad_w = max((output_w - 1) * stride_w +
+ (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
+ return pad_h, pad_w
+
+ def forward(self, x):
+ pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
+ if pad_h > 0 or pad_w > 0:
+ if self.padding == 'corner':
+ x = F.pad(x, [0, pad_w, 0, pad_h])
+ elif self.padding == 'same':
+ x = F.pad(x, [
+ pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
+ pad_h - pad_h // 2
+ ])
+ return x
+
+
+class PatchEmbed(BaseModule):
+ """Image to Patch Embedding.
+
+ We use a conv layer to implement PatchEmbed.
+
+ Args:
+ in_channels (int): The num of input channels. Default: 3
+ embed_dims (int): The dimensions of embedding. Default: 768
+ conv_type (str): The config dict for embedding
+ conv layer type selection. Default: "Conv2d.
+ kernel_size (int): The kernel_size of embedding conv. Default: 16.
+ stride (int): The slide stride of embedding conv.
+ Default: None (Would be set as `kernel_size`).
+ padding (int | tuple | string ): The padding length of
+ embedding conv. When it is a string, it means the mode
+ of adaptive padding, support "same" and "corner" now.
+ Default: "corner".
+ dilation (int): The dilation rate of embedding conv. Default: 1.
+ bias (bool): Bias of embed conv. Default: True.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: None.
+ input_size (int | tuple | None): The size of input, which will be
+ used to calculate the out size. Only work when `dynamic_size`
+ is False. Default: None.
+ init_cfg (`mmengine.ConfigDict`, optional): The Config for
+ initialization. Default: None.
+ """
+
+ def __init__(self,
+ in_channels: int = 3,
+ embed_dims: int = 768,
+ conv_type: str = 'Conv2d',
+ kernel_size: int = 16,
+ stride: int = 16,
+ padding: Union[int, tuple, str] = 'corner',
+ dilation: int = 1,
+ bias: bool = True,
+ norm_cfg: OptConfigType = None,
+ input_size: Union[int, tuple] = None,
+ init_cfg: OptConfigType = None) -> None:
+ super(PatchEmbed, self).__init__(init_cfg=init_cfg)
+
+ self.embed_dims = embed_dims
+ if stride is None:
+ stride = kernel_size
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ dilation = to_2tuple(dilation)
+
+ if isinstance(padding, str):
+ self.adap_padding = AdaptivePadding(
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=padding)
+ # disable the padding of conv
+ padding = 0
+ else:
+ self.adap_padding = None
+ padding = to_2tuple(padding)
+
+ self.projection = build_conv_layer(
+ dict(type=conv_type),
+ in_channels=in_channels,
+ out_channels=embed_dims,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias)
+
+ if norm_cfg is not None:
+ self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+ else:
+ self.norm = None
+
+ if input_size:
+ input_size = to_2tuple(input_size)
+ # `init_out_size` would be used outside to
+ # calculate the num_patches
+ # when `use_abs_pos_embed` outside
+ self.init_input_size = input_size
+ if self.adap_padding:
+ pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
+ input_h, input_w = input_size
+ input_h = input_h + pad_h
+ input_w = input_w + pad_w
+ input_size = (input_h, input_w)
+
+ # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+ h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
+ (kernel_size[0] - 1) - 1) // stride[0] + 1
+ w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
+ (kernel_size[1] - 1) - 1) // stride[1] + 1
+ self.init_out_size = (h_out, w_out)
+ else:
+ self.init_input_size = None
+ self.init_out_size = None
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int]]:
+ """
+ Args:
+ x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
+
+ Returns:
+ tuple: Contains merged results and its spatial shape.
+
+ - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
+ - out_size (tuple[int]): Spatial shape of x, arrange as
+ (out_h, out_w).
+ """
+
+ if self.adap_padding:
+ x = self.adap_padding(x)
+
+ x = self.projection(x)
+ out_size = (x.shape[2], x.shape[3])
+ x = x.flatten(2).transpose(1, 2)
+ if self.norm is not None:
+ x = self.norm(x)
+ return x, out_size
+
+
+class PatchMerging(BaseModule):
+ """Merge patch feature map.
+
+ This layer groups feature map by kernel_size, and applies norm and linear
+ layers to the grouped feature map. Our implementation uses `nn.Unfold` to
+ merge patch, which is about 25% faster than original implementation.
+ Instead, we need to modify pretrained models for compatibility.
+
+ Args:
+ in_channels (int): The num of input channels.
+ to gets fully covered by filter and stride you specified..
+ Default: True.
+ out_channels (int): The num of output channels.
+ kernel_size (int | tuple, optional): the kernel size in the unfold
+ layer. Defaults to 2.
+ stride (int | tuple, optional): the stride of the sliding blocks in the
+ unfold layer. Default: None. (Would be set as `kernel_size`)
+ padding (int | tuple | string ): The padding length of
+ embedding conv. When it is a string, it means the mode
+ of adaptive padding, support "same" and "corner" now.
+ Default: "corner".
+ dilation (int | tuple, optional): dilation parameter in the unfold
+ layer. Default: 1.
+ bias (bool, optional): Whether to add bias in linear layer or not.
+ Defaults: False.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: dict(type='LN').
+ init_cfg (dict, optional): The extra config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Optional[Union[int, tuple]] = 2,
+ stride: Optional[Union[int, tuple]] = None,
+ padding: Union[int, tuple, str] = 'corner',
+ dilation: Optional[Union[int, tuple]] = 1,
+ bias: Optional[bool] = False,
+ norm_cfg: OptConfigType = dict(type='LN'),
+ init_cfg: OptConfigType = None) -> None:
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ if stride:
+ stride = stride
+ else:
+ stride = kernel_size
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ dilation = to_2tuple(dilation)
+
+ if isinstance(padding, str):
+ self.adap_padding = AdaptivePadding(
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=padding)
+ # disable the padding of unfold
+ padding = 0
+ else:
+ self.adap_padding = None
+
+ padding = to_2tuple(padding)
+ self.sampler = nn.Unfold(
+ kernel_size=kernel_size,
+ dilation=dilation,
+ padding=padding,
+ stride=stride)
+
+ sample_dim = kernel_size[0] * kernel_size[1] * in_channels
+
+ if norm_cfg is not None:
+ self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
+ else:
+ self.norm = None
+
+ self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
+
+ def forward(self, x: Tensor,
+ input_size: Tuple[int]) -> Tuple[Tensor, Tuple[int]]:
+ """
+ Args:
+ x (Tensor): Has shape (B, H*W, C_in).
+ input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
+ Default: None.
+
+ Returns:
+ tuple: Contains merged results and its spatial shape.
+
+ - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
+ - out_size (tuple[int]): Spatial shape of x, arrange as
+ (Merged_H, Merged_W).
+ """
+ B, L, C = x.shape
+ assert isinstance(input_size, Sequence), f'Expect ' \
+ f'input_size is ' \
+ f'`Sequence` ' \
+ f'but get {input_size}'
+
+ H, W = input_size
+ assert L == H * W, 'input feature has wrong size'
+
+ x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
+ # Use nn.Unfold to merge patch. About 25% faster than original method,
+ # but need to modify pretrained model for compatibility
+
+ if self.adap_padding:
+ x = self.adap_padding(x)
+ H, W = x.shape[-2:]
+
+ x = self.sampler(x)
+ # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)
+
+ out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
+ (self.sampler.kernel_size[0] - 1) -
+ 1) // self.sampler.stride[0] + 1
+ out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
+ (self.sampler.kernel_size[1] - 1) -
+ 1) // self.sampler.stride[1] + 1
+
+ output_size = (out_h, out_w)
+ x = x.transpose(1, 2) # B, H/2*W/2, 4*C
+ x = self.norm(x) if self.norm else x
+ x = self.reduction(x)
+ return x, output_size
+
+
+def inverse_sigmoid(x: Tensor, eps: float = 1e-5) -> Tensor:
+ """Inverse function of sigmoid.
+
+ Args:
+ x (Tensor): The tensor to do the inverse.
+ eps (float): EPS avoid numerical overflow. Defaults 1e-5.
+ Returns:
+ Tensor: The x has passed the inverse function of sigmoid, has the same
+ shape with input.
+ """
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
+
+
+class MLP(BaseModule):
+ """Very simple multi-layer perceptron (also called FFN) with relu. Mostly
+ used in DETR series detectors.
+
+ Args:
+ input_dim (int): Feature dim of the input tensor.
+ hidden_dim (int): Feature dim of the hidden layer.
+ output_dim (int): Feature dim of the output tensor.
+ num_layers (int): Number of FFN layers. As the last
+ layer of MLP only contains FFN (Linear).
+ """
+
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
+ num_layers: int) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = ModuleList(
+ Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward function of MLP.
+
+ Args:
+ x (Tensor): The input feature, has shape
+ (num_queries, bs, input_dim).
+ Returns:
+ Tensor: The output feature, has shape
+ (num_queries, bs, output_dim).
+ """
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+def gen_sine_embed_for_ref(reference: Tensor):
+ # n_query, bs, _ = pos_tensor.size()
+ scale = 2 * math.pi
+ dim_t = torch.arange(128, dtype=torch.float32, device=reference.device)
+ dim_t = 10000**(2 * (dim_t // 2) / 128)
+ x_embed = reference[:, :, 0] * scale
+ y_embed = reference[:, :, 1] * scale
+ pos_x = x_embed[:, :, None] / dim_t
+ pos_y = y_embed[:, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()),
+ dim=3).flatten(2)
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()),
+ dim=3).flatten(2)
+ if reference.size(-1) == 2:
+ pos = torch.cat((pos_y, pos_x), dim=2)
+ elif reference.size(-1) == 4:
+ w_embed = reference[:, :, 2] * scale
+ pos_w = w_embed[:, :, None] / dim_t
+ pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()),
+ dim=3).flatten(2)
+
+ h_embed = reference[:, :, 3] * scale
+ pos_h = h_embed[:, :, None] / dim_t
+ pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()),
+ dim=3).flatten(2)
+
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
+ else:
+ raise ValueError('Unknown pos_tensor shape(-1):{}'.format(
+ reference.size(-1)))
+ return pos
+
+
+@MODELS.register_module()
+class DynamicConv(BaseModule):
+ """Implements Dynamic Convolution.
+
+ This module generate parameters for each sample and
+ use bmm to implement 1*1 convolution. Code is modified
+ from the `official github repo `_ .
+
+ Args:
+ in_channels (int): The input feature channel.
+ Defaults to 256.
+ feat_channels (int): The inner feature channel.
+ Defaults to 64.
+ out_channels (int, optional): The output feature channel.
+ When not specified, it will be set to `in_channels`
+ by default
+ input_feat_shape (int): The shape of input feature.
+ Defaults to 7.
+ with_proj (bool): Project two-dimentional feature to
+ one-dimentional feature. Default to True.
+ act_cfg (dict): The activation config for DynamicConv.
+ norm_cfg (dict): Config dict for normalization layer. Default
+ layer normalization.
+ init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels: int = 256,
+ feat_channels: int = 64,
+ out_channels: Optional[int] = None,
+ input_feat_shape: int = 7,
+ with_proj: bool = True,
+ act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
+ norm_cfg: OptConfigType = dict(type='LN'),
+ init_cfg: OptConfigType = None) -> None:
+ super(DynamicConv, self).__init__(init_cfg)
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.out_channels_raw = out_channels
+ self.input_feat_shape = input_feat_shape
+ self.with_proj = with_proj
+ self.act_cfg = act_cfg
+ self.norm_cfg = norm_cfg
+ self.out_channels = out_channels if out_channels else in_channels
+
+ self.num_params_in = self.in_channels * self.feat_channels
+ self.num_params_out = self.out_channels * self.feat_channels
+ self.dynamic_layer = nn.Linear(
+ self.in_channels, self.num_params_in + self.num_params_out)
+
+ self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
+ self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+ self.activation = build_activation_layer(act_cfg)
+
+ num_output = self.out_channels * input_feat_shape**2
+ if self.with_proj:
+ self.fc_layer = nn.Linear(num_output, self.out_channels)
+ self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+ def forward(self, param_feature: Tensor, input_feature: Tensor) -> Tensor:
+ """Forward function for `DynamicConv`.
+
+ Args:
+ param_feature (Tensor): The feature can be used
+ to generate the parameter, has shape
+ (num_all_proposals, in_channels).
+ input_feature (Tensor): Feature that
+ interact with parameters, has shape
+ (num_all_proposals, in_channels, H, W).
+
+ Returns:
+ Tensor: The output feature has shape
+ (num_all_proposals, out_channels).
+ """
+ input_feature = input_feature.flatten(2).permute(2, 0, 1)
+
+ input_feature = input_feature.permute(1, 0, 2)
+ parameters = self.dynamic_layer(param_feature)
+
+ param_in = parameters[:, :self.num_params_in].view(
+ -1, self.in_channels, self.feat_channels)
+ param_out = parameters[:, -self.num_params_out:].view(
+ -1, self.feat_channels, self.out_channels)
+
+ # input_feature has shape (num_all_proposals, H*W, in_channels)
+ # param_in has shape (num_all_proposals, in_channels, feat_channels)
+ # feature has shape (num_all_proposals, H*W, feat_channels)
+ features = torch.bmm(input_feature, param_in)
+ features = self.norm_in(features)
+ features = self.activation(features)
+
+ # param_out has shape (batch_size, feat_channels, out_channels)
+ features = torch.bmm(features, param_out)
+ features = self.norm_out(features)
+ features = self.activation(features)
+
+ if self.with_proj:
+ features = features.flatten(1)
+ features = self.fc_layer(features)
+ features = self.fc_norm(features)
+ features = self.activation(features)
+
+ return features
diff --git a/mmdet/models/task_modules/assigners/hungarian_assigner.py b/mmdet/models/task_modules/assigners/hungarian_assigner.py
index c28aff07da4..aad6732aeb6 100644
--- a/mmdet/models/task_modules/assigners/hungarian_assigner.py
+++ b/mmdet/models/task_modules/assigners/hungarian_assigner.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
import torch
from mmdet.registry import TASK_UTILS
@@ -18,6 +19,133 @@
@TASK_UTILS.register_module()
class HungarianAssigner(BaseAssigner):
+ """Computes one-to-one matching between predictions and ground truth. This
+ class computes an assignment between the targets and the predictions based
+ on the costs. The costs are weighted sum of some components. For DETR the
+ costs are weighted sum of classification cost, regression L1 cost and
+ regression iou cost. The targets don't include the no_object, so generally
+ there are more predictions than targets. After the one-to-one matching, the
+ un-matched are treated as backgrounds. Thus each query prediction will be
+ assigned with `0` or a positive integer indicating the ground truth index:
+
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+ Args:
+ match_costs (:obj:`ConfigDict` or dict or \
+ List[Union[:obj:`ConfigDict`, dict]]): Match cost configs.
+ """
+
+ def __init__(
+ self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
+ ConfigDict]
+ ) -> None:
+
+ if isinstance(match_costs, dict):
+ match_costs = [match_costs]
+ elif isinstance(match_costs, list):
+ assert len(match_costs) > 0, \
+ 'match_costs must not be a empty list.'
+
+ self.match_costs = [
+ TASK_UTILS.build(match_cost) for match_cost in match_costs
+ ]
+
+ def assign(self,
+ pred_instances: InstanceData,
+ gt_instances: InstanceData,
+ img_meta: Optional[dict] = None,
+ **kwargs) -> AssignResult:
+ """Computes one-to-one matching based on the weighted costs.
+
+ This method assign each query prediction to a ground truth or
+ background. The `assigned_gt_inds` with -1 means don't care,
+ 0 means negative sample, and positive number is the index (1-based)
+ of assigned gt.
+ The assignment is done in the following steps, the order matters.
+ 1. assign every prediction to -1
+ 2. compute the weighted costs
+ 3. do Hungarian matching on CPU based on the costs
+ 4. assign all to 0 (background) first, then for each matched pair
+ between predictions and gts, treat this prediction as foreground
+ and assign the corresponding gt index (plus 1) to it.
+ Args:
+ pred_instances (:obj:`InstanceData`): Instances of model
+ predictions. It includes ``priors``, and the priors can
+ be anchors or points, or the bboxes predicted by the
+ previous stage, has shape (n, 4). The bboxes predicted by
+ the current model or stage will be named ``bboxes``,
+ ``labels``, and ``scores``, the same as the ``InstanceData``
+ in other places. It may includes ``masks``, with shape
+ (n, h, w) or (n, l).
+ gt_instances (:obj:`InstanceData`): Ground truth of instance
+ annotations. It usually includes ``bboxes``, with shape (k, 4),
+ ``labels``, with shape (k, ) and ``masks``, with shape
+ (k, h, w) or (k, l).
+ img_meta (dict): Image information.
+ Returns:
+ :obj:`AssignResult`: The assigned result.
+ """
+ assert isinstance(gt_instances.labels, Tensor)
+ num_gts, num_preds = len(gt_instances), len(pred_instances)
+ gt_labels = gt_instances.labels
+ device = gt_labels.device
+
+ # 1. assign -1 by default
+ assigned_gt_inds = torch.full((num_preds, ),
+ -1,
+ dtype=torch.long,
+ device=device)
+ assigned_labels = torch.full((num_preds, ),
+ -1,
+ dtype=torch.long,
+ device=device)
+
+ if num_gts == 0 or num_preds == 0:
+ # No ground truth or boxes, return empty assignment
+ if num_gts == 0:
+ # No ground truth, assign all to background
+ assigned_gt_inds[:] = 0
+ return AssignResult(
+ num_gts=num_gts,
+ gt_inds=assigned_gt_inds,
+ max_overlaps=None,
+ labels=assigned_labels)
+
+ # 2. compute weighted cost
+ cost_list = []
+ for match_cost in self.match_costs:
+ cost = match_cost(
+ pred_instances=pred_instances,
+ gt_instances=gt_instances,
+ img_meta=img_meta)
+ cost_list.append(cost)
+ cost = torch.stack(cost_list).sum(dim=0)
+
+ # 3. do Hungarian matching on CPU using linear_sum_assignment
+ cost = cost.detach().cpu()
+ if linear_sum_assignment is None:
+ raise ImportError('Please run "pip install scipy" '
+ 'to install scipy first.')
+
+ matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
+ matched_row_inds = torch.from_numpy(matched_row_inds).to(device)
+ matched_col_inds = torch.from_numpy(matched_col_inds).to(device)
+
+ # 4. assign backgrounds and foregrounds
+ # assign all indices to backgrounds first
+ assigned_gt_inds[:] = 0
+ # assign foregrounds based on matching results
+ assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
+ assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
+ return AssignResult(
+ num_gts=num_gts,
+ gt_inds=assigned_gt_inds,
+ max_overlaps=None,
+ labels=assigned_labels)
+
+
+@TASK_UTILS.register_module()
+class GHungarianAssigner(BaseAssigner):
"""Computes one-to-one matching between predictions and ground truth.
This class computes an assignment between the targets and the predictions
@@ -37,10 +165,10 @@ class HungarianAssigner(BaseAssigner):
List[Union[:obj:`ConfigDict`, dict]]): Match cost configs.
"""
- def __init__(
- self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
- ConfigDict]
- ) -> None:
+ def __init__(self,
+ match_costs: Union[List[Union[dict, ConfigDict]], dict,
+ ConfigDict],
+ group_detr=1) -> None:
if isinstance(match_costs, dict):
match_costs = [match_costs]
@@ -51,6 +179,7 @@ def __init__(
self.match_costs = [
TASK_UTILS.build(match_cost) for match_cost in match_costs
]
+ self.group_detr = group_detr
def assign(self,
pred_instances: InstanceData,
@@ -132,7 +261,23 @@ def assign(self,
raise ImportError('Please run "pip install scipy" '
'to install scipy first.')
- matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
+ # indices = []
+ g_num_queries = num_preds // self.group_detr
+ cost_list = cost.split(g_num_queries, dim=0)
+ for g_i in range(self.group_detr):
+ cost_g = cost_list[g_i]
+ matched_row_inds_g, matched_col_inds_g = linear_sum_assignment(
+ cost_g)
+ if g_i == 0:
+ matched_row_inds, matched_col_inds = \
+ matched_row_inds_g, matched_col_inds_g
+ else:
+ matched_row_inds = np.concatenate([
+ matched_row_inds, matched_row_inds_g + g_num_queries * g_i
+ ])
+ matched_col_inds = np.concatenate(
+ [matched_col_inds, matched_col_inds_g])
+
matched_row_inds = torch.from_numpy(matched_row_inds).to(device)
matched_col_inds = torch.from_numpy(matched_col_inds).to(device)
diff --git a/mmdet/models/task_modules/assigners/match_cost.py b/mmdet/models/task_modules/assigners/match_cost.py
index 64efc809873..95fe89fa932 100644
--- a/mmdet/models/task_modules/assigners/match_cost.py
+++ b/mmdet/models/task_modules/assigners/match_cost.py
@@ -206,7 +206,7 @@ def __call__(self,
Args:
pred_instances (:obj:`InstanceData`): ``scores`` inside is
predicted classification logits, of shape
- (num_query, num_class).
+ (num_queries, num_class).
gt_instances (:obj:`InstanceData`): ``labels`` inside should have
shape (num_gt, ).
img_meta (Optional[dict]): _description_. Defaults to None.
@@ -253,7 +253,7 @@ def _focal_loss_cost(self, cls_pred: Tensor, gt_labels: Tensor) -> Tensor:
"""
Args:
cls_pred (Tensor): Predicted classification logits, shape
- (num_query, num_class).
+ (num_queries, num_class).
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
Returns:
@@ -272,13 +272,13 @@ def _mask_focal_loss_cost(self, cls_pred, gt_labels) -> Tensor:
"""
Args:
cls_pred (Tensor): Predicted classification logits.
- in shape (num_query, d1, ..., dn), dtype=torch.float32.
+ in shape (num_queries, d1, ..., dn), dtype=torch.float32.
gt_labels (Tensor): Ground truth in shape (num_gt, d1, ..., dn),
dtype=torch.long. Labels should be binary.
Returns:
Tensor: Focal cost matrix with weight in shape\
- (num_query, num_gt).
+ (num_queries, num_gt).
"""
cls_pred = cls_pred.flatten(1)
gt_labels = gt_labels.flatten(1).float()
@@ -349,13 +349,13 @@ def _binary_mask_dice_loss(self, mask_preds: Tensor,
gt_masks: Tensor) -> Tensor:
"""
Args:
- mask_preds (Tensor): Mask prediction in shape (num_query, *).
+ mask_preds (Tensor): Mask prediction in shape (num_queries, *).
gt_masks (Tensor): Ground truth in shape (num_gt, *)
store 0 or 1, 0 for negative class and 1 for
positive class.
Returns:
- Tensor: Dice cost matrix in shape (num_query, num_gt).
+ Tensor: Dice cost matrix in shape (num_queries, num_gt).
"""
mask_preds = mask_preds.flatten(1)
gt_masks = gt_masks.flatten(1).float()
@@ -415,13 +415,13 @@ def _binary_cross_entropy(self, cls_pred: Tensor,
gt_labels: Tensor) -> Tensor:
"""
Args:
- cls_pred (Tensor): The prediction with shape (num_query, 1, *) or
- (num_query, *).
+ cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or
+ (num_queries, *).
gt_labels (Tensor): The learning label of prediction with
shape (num_gt, *).
Returns:
- Tensor: Cross entropy cost matrix in shape (num_query, num_gt).
+ Tensor: Cross entropy cost matrix in shape (num_queries, num_gt).
"""
cls_pred = cls_pred.flatten(1).float()
gt_labels = gt_labels.flatten(1).float()
diff --git a/tests/test_models/test_dense_heads/test_deformable_detr_head.py b/tests/test_models/test_dense_heads/test_deformable_detr_head.py
deleted file mode 100644
index ef149afe7c2..00000000000
--- a/tests/test_models/test_dense_heads/test_deformable_detr_head.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from unittest import TestCase
-
-import torch
-from mmengine.config import ConfigDict
-from mmengine.structures import InstanceData
-
-from mmdet.models.dense_heads import DeformableDETRHead
-from mmdet.utils import register_all_modules
-
-
-class TestDeformableDETRHead(TestCase):
-
- def setUp(self):
- register_all_modules()
-
- def test_detr_head_loss(self):
- """Tests transformer head loss when truth is empty and non-empty."""
- s = 256
- img_metas = [{
- 'img_shape': (s, s),
- 'scale_factor': (1, 1),
- 'pad_shape': (s, s),
- 'batch_input_shape': (s, s)
- }]
- config = ConfigDict(
- dict(
- num_classes=4,
- in_channels=2048,
- sync_cls_avg_factor=True,
- as_two_stage=False,
- transformer=dict(
- type='DeformableDetrTransformer',
- encoder=dict(
- type='DetrTransformerEncoder',
- num_layers=6,
- transformerlayers=dict(
- type='BaseTransformerLayer',
- attn_cfgs=dict(
- type='MultiScaleDeformableAttention',
- embed_dims=256),
- feedforward_channels=1024,
- ffn_dropout=0.1,
- operation_order=('self_attn', 'norm', 'ffn',
- 'norm'))),
- decoder=dict(
- type='DeformableDetrTransformerDecoder',
- num_layers=6,
- return_intermediate=True,
- transformerlayers=dict(
- type='DetrTransformerDecoderLayer',
- attn_cfgs=[
- dict(
- type='MultiheadAttention',
- embed_dims=256,
- num_heads=8,
- dropout=0.1),
- dict(
- type='MultiScaleDeformableAttention',
- embed_dims=256)
- ],
- feedforward_channels=1024,
- ffn_dropout=0.1,
- operation_order=('self_attn', 'norm', 'cross_attn',
- 'norm', 'ffn', 'norm')))),
- positional_encoding=dict(
- type='SinePositionalEncoding',
- num_feats=128,
- normalize=True,
- offset=-0.5),
- loss_cls=dict(
- type='FocalLoss',
- use_sigmoid=True,
- gamma=2.0,
- alpha=0.25,
- loss_weight=2.0),
- loss_bbox=dict(type='L1Loss', loss_weight=5.0),
- loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
- train_cfg=dict(
- assigner=dict(
- type='HungarianAssigner',
- match_costs=[
- dict(type='FocalLossCost', weight=2.0),
- dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
- dict(type='IoUCost', iou_mode='giou', weight=2.0)
- ])),
- test_cfg=dict(max_per_img=100))
-
- deformable_detr_head = DeformableDETRHead(**config)
- deformable_detr_head.init_weights()
- feat = [
- torch.rand(1, 256, s // stride, s // stride)
- for stride in [8, 16, 32, 64]
- ]
- outs = deformable_detr_head.forward(feat, img_metas)
- # Test that empty ground truth encourages the network to
- # predict background
- gt_instances = InstanceData()
- gt_instances.bboxes = torch.empty((0, 4))
- gt_instances.labels = torch.LongTensor([])
- empty_gt_losses = deformable_detr_head.loss_by_feat(
- *outs, [gt_instances], img_metas)
- # When there is no truth, the cls loss should be nonzero but there
- # should be no box loss.
- for key, loss in empty_gt_losses.items():
- if 'cls' in key:
- self.assertGreater(loss.item(), 0,
- 'cls loss should be non-zero')
- elif 'bbox' in key:
- self.assertEqual(
- loss.item(), 0,
- 'there should be no box loss when no ground true boxes')
- elif 'iou' in key:
- self.assertEqual(
- loss.item(), 0,
- 'there should be no iou loss when no ground true boxes')
-
- # When truth is non-empty then both cls and box loss should be nonzero
- # for random inputs
- gt_instances = InstanceData()
- gt_instances.bboxes = torch.Tensor(
- [[23.6667, 23.8757, 238.6326, 151.8874]])
- gt_instances.labels = torch.LongTensor([2])
- one_gt_losses = deformable_detr_head.loss_by_feat(
- *outs, [gt_instances], img_metas)
- for loss in one_gt_losses.values():
- self.assertGreater(
- loss.item(), 0,
- 'cls loss, or box loss, or iou loss should be non-zero')
-
- # test predict_by_feat
- deformable_detr_head.predict_by_feat(*outs, img_metas, rescale=False)
diff --git a/tests/test_models/test_dense_heads/test_detr_head.py b/tests/test_models/test_dense_heads/test_detr_head.py
deleted file mode 100644
index 537bbaf1a85..00000000000
--- a/tests/test_models/test_dense_heads/test_detr_head.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from unittest import TestCase
-
-import torch
-from mmengine.config import ConfigDict
-from mmengine.structures import InstanceData
-
-from mmdet.models.dense_heads import DETRHead
-from mmdet.structures import DetDataSample
-from mmdet.utils import register_all_modules
-
-
-class TestDETRHead(TestCase):
-
- def setUp(self) -> None:
- register_all_modules()
-
- def test_detr_head_loss(self):
- """Tests transformer head loss when truth is empty and non-empty."""
- s = 256
- img_metas = [{
- 'img_shape': (s, s),
- 'scale_factor': (1, 1),
- 'pad_shape': (s, s),
- 'batch_input_shape': (s, s)
- }]
- config = ConfigDict(
- dict(
- num_classes=4,
- in_channels=200,
- transformer=dict(
- type='Transformer',
- encoder=dict(
- type='DetrTransformerEncoder',
- num_layers=6,
- transformerlayers=dict(
- type='BaseTransformerLayer',
- attn_cfgs=[
- dict(
- type='MultiheadAttention',
- embed_dims=256,
- num_heads=8,
- dropout=0.1)
- ],
- feedforward_channels=2048,
- ffn_dropout=0.1,
- operation_order=('self_attn', 'norm', 'ffn',
- 'norm'))),
- decoder=dict(
- type='DetrTransformerDecoder',
- return_intermediate=True,
- num_layers=6,
- transformerlayers=dict(
- type='DetrTransformerDecoderLayer',
- attn_cfgs=dict(
- type='MultiheadAttention',
- embed_dims=256,
- num_heads=8,
- dropout=0.1),
- feedforward_channels=2048,
- ffn_dropout=0.1,
- operation_order=('self_attn', 'norm', 'cross_attn',
- 'norm', 'ffn', 'norm')),
- )),
- positional_encoding=dict(
- type='SinePositionalEncoding',
- num_feats=128,
- normalize=True),
- loss_cls=dict(
- type='CrossEntropyLoss',
- bg_cls_weight=0.1,
- use_sigmoid=False,
- loss_weight=1.0,
- class_weight=1.0),
- loss_bbox=dict(type='L1Loss', loss_weight=5.0),
- loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
- train_cfg=dict(
- assigner=dict(
- type='HungarianAssigner',
- match_costs=[
- dict(type='ClassificationCost', weight=1.),
- dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
- dict(type='IoUCost', iou_mode='giou', weight=2.0)
- ])),
- test_cfg=dict(max_per_img=100))
-
- detr_head = DETRHead(**config)
- detr_head.init_weights()
- feat = [torch.rand(1, 200, 10, 10)]
- cls_scores, bbox_preds = detr_head.forward(feat, img_metas)
- # Test that empty ground truth encourages the network to
- # predict background
- gt_instances = InstanceData()
- gt_instances.bboxes = torch.empty((0, 4))
- gt_instances.labels = torch.LongTensor([])
- empty_gt_losses = detr_head.loss_by_feat(cls_scores, bbox_preds,
- [gt_instances], img_metas)
- # When there is no truth, the cls loss should be nonzero but there
- # should be no box loss.
- for key, loss in empty_gt_losses.items():
- if 'cls' in key:
- self.assertGreater(loss.item(), 0,
- 'cls loss should be non-zero')
- elif 'bbox' in key:
- self.assertEqual(
- loss.item(), 0,
- 'there should be no box loss when no ground true boxes')
- elif 'iou' in key:
- self.assertEqual(
- loss.item(), 0,
- 'there should be no iou loss when there are no true boxes')
-
- # When truth is non-empty then both cls and box loss should be nonzero
- # for random inputs
- gt_instances = InstanceData()
- gt_instances.bboxes = torch.Tensor(
- [[23.6667, 23.8757, 238.6326, 151.8874]])
- gt_instances.labels = torch.LongTensor([2])
- one_gt_losses = detr_head.loss_by_feat(cls_scores, bbox_preds,
- [gt_instances], img_metas)
- for loss in one_gt_losses.values():
- self.assertGreater(
- loss.item(), 0,
- 'cls loss, or box loss, or iou loss should be non-zero')
-
- # test loss
- samples = DetDataSample()
- samples.set_metainfo(img_metas[0])
- samples.gt_instances = gt_instances
- detr_head.loss(feat, [samples])
- # test loss and predict
- detr_head.loss_and_predict(feat, [samples])
- # test only predict
- detr_head.predict(feat, [samples], rescale=True)
diff --git a/tests/test_models/test_detectors/test_deformable_detr.py b/tests/test_models/test_detectors/test_deformable_detr.py
new file mode 100644
index 00000000000..4c626eb1ba5
--- /dev/null
+++ b/tests/test_models/test_detectors/test_deformable_detr.py
@@ -0,0 +1,98 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+from mmengine.structures import InstanceData
+
+from mmdet.models import build_detector
+from mmdet.structures import DetDataSample
+from mmdet.testing import get_detector_cfg
+from mmdet.utils import register_all_modules
+
+
+class TestDeformableDETR(TestCase):
+
+ def setUp(self):
+ register_all_modules()
+
+ def test_detr_head_loss(self):
+ """Tests transformer head loss when truth is empty and non-empty."""
+ s = 256
+ metainfo = {
+ 'img_shape': (s, s),
+ 'scale_factor': (1, 1),
+ 'pad_shape': (s, s),
+ 'batch_input_shape': (s, s)
+ }
+ img_metas = DetDataSample()
+ img_metas.set_metainfo(metainfo)
+ batch_data_samples = []
+ batch_data_samples.append(img_metas)
+
+ configs = [
+ get_detector_cfg(
+ 'deformable_detr/deformable-detr_r50_16xb2-50e_coco.py'),
+ get_detector_cfg(
+ 'deformable_detr/deformable-detr_refine_r50_16xb2-50e_coco.py' # noqa
+ ),
+ get_detector_cfg(
+ 'deformable_detr/deformable-detr_refine_twostage_r50_16xb2-50e_coco.py' # noqa
+ )
+ ]
+
+ for config in configs:
+ model = build_detector(config)
+ model.init_weights()
+ random_image = torch.rand(1, 3, s, s)
+
+ # Test that empty ground truth encourages the network to
+ # predict background
+ gt_instances = InstanceData()
+ gt_instances.bboxes = torch.empty((0, 4))
+ gt_instances.labels = torch.LongTensor([])
+ img_metas.gt_instances = gt_instances
+ batch_data_samples1 = []
+ batch_data_samples1.append(img_metas)
+ empty_gt_losses = model.loss(
+ random_image, batch_data_samples=batch_data_samples1)
+ # When there is no truth, the cls loss should be nonzero but there
+ # should be no box loss.
+ for key, loss in empty_gt_losses.items():
+ if 'cls' in key:
+ self.assertGreater(loss.item(), 0,
+ 'cls loss should be non-zero')
+ elif 'bbox' in key:
+ self.assertEqual(
+ loss.item(), 0,
+ 'there should be no box loss when no ground true boxes'
+ )
+ elif 'iou' in key:
+ self.assertEqual(
+ loss.item(), 0,
+ 'there should be no iou loss when no ground true boxes'
+ )
+
+ # When truth is non-empty then both cls and box loss should
+ # be nonzero for random inputs
+ gt_instances = InstanceData()
+ gt_instances.bboxes = torch.Tensor(
+ [[23.6667, 23.8757, 238.6326, 151.8874]])
+ gt_instances.labels = torch.LongTensor([2])
+ img_metas.gt_instances = gt_instances
+ batch_data_samples2 = []
+ batch_data_samples2.append(img_metas)
+ one_gt_losses = model.loss(
+ random_image, batch_data_samples=batch_data_samples2)
+ for loss in one_gt_losses.values():
+ self.assertGreater(
+ loss.item(), 0,
+ 'cls loss, or box loss, or iou loss should be non-zero')
+
+ # test _forward
+ model._forward(
+ random_image, batch_data_samples=batch_data_samples2)
+ # test only predict
+ model.predict(
+ random_image,
+ batch_data_samples=batch_data_samples2,
+ rescale=True)
diff --git a/tests/test_models/test_detectors/test_detr.py b/tests/test_models/test_detectors/test_detr.py
new file mode 100644
index 00000000000..19d04be914a
--- /dev/null
+++ b/tests/test_models/test_detectors/test_detr.py
@@ -0,0 +1,83 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+from mmengine.structures import InstanceData
+
+from mmdet.models import build_detector
+from mmdet.structures import DetDataSample
+from mmdet.testing import get_detector_cfg
+from mmdet.utils import register_all_modules
+
+
+class TestDETR(TestCase):
+
+ def setUp(self) -> None:
+ register_all_modules()
+
+ def test_detr_head_loss(self):
+ """Tests transformer head loss when truth is empty and non-empty."""
+ s = 256
+ metainfo = {
+ 'img_shape': (s, s),
+ 'scale_factor': (1, 1),
+ 'pad_shape': (s, s),
+ 'batch_input_shape': (s, s)
+ }
+ img_metas = DetDataSample()
+ img_metas.set_metainfo(metainfo)
+ batch_data_samples = []
+ batch_data_samples.append(img_metas)
+
+ config = get_detector_cfg('detr/detr_r50_8xb2-150e_coco.py')
+
+ model = build_detector(config)
+ model.init_weights()
+ random_image = torch.rand(1, 3, s, s)
+
+ # Test that empty ground truth encourages the network to
+ # predict background
+ gt_instances = InstanceData()
+ gt_instances.bboxes = torch.empty((0, 4))
+ gt_instances.labels = torch.LongTensor([])
+ img_metas.gt_instances = gt_instances
+ batch_data_samples1 = []
+ batch_data_samples1.append(img_metas)
+ empty_gt_losses = model.loss(
+ random_image, batch_data_samples=batch_data_samples1)
+ # When there is no truth, the cls loss should be nonzero but there
+ # should be no box loss.
+ for key, loss in empty_gt_losses.items():
+ if 'cls' in key:
+ self.assertGreater(loss.item(), 0,
+ 'cls loss should be non-zero')
+ elif 'bbox' in key:
+ self.assertEqual(
+ loss.item(), 0,
+ 'there should be no box loss when no ground true boxes')
+ elif 'iou' in key:
+ self.assertEqual(
+ loss.item(), 0,
+ 'there should be no iou loss when there are no true boxes')
+
+ # When truth is non-empty then both cls and box loss should be nonzero
+ # for random inputs
+ gt_instances = InstanceData()
+ gt_instances.bboxes = torch.Tensor(
+ [[23.6667, 23.8757, 238.6326, 151.8874]])
+ gt_instances.labels = torch.LongTensor([2])
+ img_metas.gt_instances = gt_instances
+ batch_data_samples2 = []
+ batch_data_samples2.append(img_metas)
+ one_gt_losses = model.loss(
+ random_image, batch_data_samples=batch_data_samples2)
+ for loss in one_gt_losses.values():
+ self.assertGreater(
+ loss.item(), 0,
+ 'cls loss, or box loss, or iou loss should be non-zero')
+
+ # test _forward
+ model._forward(random_image, batch_data_samples=batch_data_samples2)
+ # test only predict
+ model.predict(
+ random_image, batch_data_samples=batch_data_samples2, rescale=True)
diff --git a/tests/test_models/test_detectors/test_maskformer.py b/tests/test_models/test_detectors/test_maskformer.py
deleted file mode 100644
index 4c4a7c77089..00000000000
--- a/tests/test_models/test_detectors/test_maskformer.py
+++ /dev/null
@@ -1,236 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import unittest
-
-import torch
-from parameterized import parameterized
-
-from mmdet.models import build_detector
-from mmdet.structures import DetDataSample
-from mmdet.testing._utils import demo_mm_inputs, get_detector_cfg
-from mmdet.utils import register_all_modules
-
-
-class TestMaskFormer(unittest.TestCase):
-
- def setUp(self):
- register_all_modules()
-
- def _create_model_cfg(self):
- cfg_path = 'maskformer/maskformer_r50_ms-16xb1-75e_coco.py'
- model_cfg = get_detector_cfg(cfg_path)
- base_channels = 32
- model_cfg.backbone.depth = 18
- model_cfg.backbone.init_cfg = None
- model_cfg.backbone.base_channels = base_channels
- model_cfg.panoptic_head.in_channels = [
- base_channels * 2**i for i in range(4)
- ]
- model_cfg.panoptic_head.feat_channels = base_channels
- model_cfg.panoptic_head.out_channels = base_channels
- model_cfg.panoptic_head.pixel_decoder.encoder.\
- transformerlayers.attn_cfgs.embed_dims = base_channels
- model_cfg.panoptic_head.pixel_decoder.encoder.\
- transformerlayers.ffn_cfgs.embed_dims = base_channels
- model_cfg.panoptic_head.pixel_decoder.encoder.\
- transformerlayers.ffn_cfgs.feedforward_channels = base_channels * 8
- model_cfg.panoptic_head.pixel_decoder.\
- positional_encoding.num_feats = base_channels // 2
- model_cfg.panoptic_head.positional_encoding.\
- num_feats = base_channels // 2
- model_cfg.panoptic_head.transformer_decoder.\
- transformerlayers.attn_cfgs.embed_dims = base_channels
- model_cfg.panoptic_head.transformer_decoder.\
- transformerlayers.ffn_cfgs.embed_dims = base_channels
- model_cfg.panoptic_head.transformer_decoder.\
- transformerlayers.ffn_cfgs.feedforward_channels = base_channels * 8
- model_cfg.panoptic_head.transformer_decoder.\
- transformerlayers.feedforward_channels = base_channels * 8
- return model_cfg
-
- def test_init(self):
- model_cfg = self._create_model_cfg()
- detector = build_detector(model_cfg)
- detector.init_weights()
- assert detector.backbone
- assert detector.panoptic_head
-
- @parameterized.expand([('cpu', ), ('cuda', )])
- def test_forward_loss_mode(self, device):
- model_cfg = self._create_model_cfg()
- detector = build_detector(model_cfg)
-
- if device == 'cuda' and not torch.cuda.is_available():
- return unittest.skip('test requires GPU and torch+cuda')
- detector = detector.to(device)
-
- packed_inputs = demo_mm_inputs(
- 2,
- image_shapes=[(3, 128, 127), (3, 91, 92)],
- sem_seg_output_strides=1,
- with_mask=True,
- with_semantic=True)
- data = detector.data_preprocessor(packed_inputs, True)
- # Test loss mode
- losses = detector.forward(**data, mode='loss')
- self.assertIsInstance(losses, dict)
-
- @parameterized.expand([('cpu', ), ('cuda', )])
- def test_forward_predict_mode(self, device):
- model_cfg = self._create_model_cfg()
- detector = build_detector(model_cfg)
- if device == 'cuda' and not torch.cuda.is_available():
- return unittest.skip('test requires GPU and torch+cuda')
- detector = detector.to(device)
- packed_inputs = demo_mm_inputs(
- 2,
- image_shapes=[(3, 128, 127), (3, 91, 92)],
- sem_seg_output_strides=1,
- with_mask=True,
- with_semantic=True)
- data = detector.data_preprocessor(packed_inputs, False)
- # Test forward test
- detector.eval()
- with torch.no_grad():
- batch_results = detector.forward(**data, mode='predict')
- self.assertEqual(len(batch_results), 2)
- self.assertIsInstance(batch_results[0], DetDataSample)
-
- @parameterized.expand([('cpu', ), ('cuda', )])
- def test_forward_tensor_mode(self, device):
- model_cfg = self._create_model_cfg()
- detector = build_detector(model_cfg)
- if device == 'cuda' and not torch.cuda.is_available():
- return unittest.skip('test requires GPU and torch+cuda')
- detector = detector.to(device)
-
- packed_inputs = demo_mm_inputs(
- 2, [[3, 128, 128], [3, 125, 130]],
- sem_seg_output_strides=1,
- with_mask=True,
- with_semantic=True)
- data = detector.data_preprocessor(packed_inputs, False)
- out = detector.forward(**data, mode='tensor')
- self.assertIsInstance(out, tuple)
-
-
-class TestMask2Former(unittest.TestCase):
-
- def setUp(self):
- register_all_modules()
-
- def _create_model_cfg(self, cfg_path):
- model_cfg = get_detector_cfg(cfg_path)
- base_channels = 32
- model_cfg.backbone.depth = 18
- model_cfg.backbone.init_cfg = None
- model_cfg.backbone.base_channels = base_channels
- model_cfg.panoptic_head.in_channels = [
- base_channels * 2**i for i in range(4)
- ]
- model_cfg.panoptic_head.feat_channels = base_channels
- model_cfg.panoptic_head.out_channels = base_channels
- model_cfg.panoptic_head.pixel_decoder.encoder.\
- transformerlayers.attn_cfgs.embed_dims = base_channels
- model_cfg.panoptic_head.pixel_decoder.encoder.\
- transformerlayers.ffn_cfgs.embed_dims = base_channels
- model_cfg.panoptic_head.pixel_decoder.encoder.\
- transformerlayers.ffn_cfgs.feedforward_channels = base_channels * 4
- model_cfg.panoptic_head.pixel_decoder.\
- positional_encoding.num_feats = base_channels // 2
- model_cfg.panoptic_head.positional_encoding.\
- num_feats = base_channels // 2
- model_cfg.panoptic_head.transformer_decoder.\
- transformerlayers.attn_cfgs.embed_dims = base_channels
- model_cfg.panoptic_head.transformer_decoder.\
- transformerlayers.ffn_cfgs.embed_dims = base_channels
- model_cfg.panoptic_head.transformer_decoder.\
- transformerlayers.ffn_cfgs.feedforward_channels = base_channels * 8
- model_cfg.panoptic_head.transformer_decoder.\
- transformerlayers.feedforward_channels = base_channels * 8
-
- return model_cfg
-
- def test_init(self):
- model_cfg = self._create_model_cfg(
- 'mask2former/mask2former_r50_8xb2-lsj-50e_coco-panoptic.py')
- detector = build_detector(model_cfg)
- detector.init_weights()
- assert detector.backbone
- assert detector.panoptic_head
-
- @parameterized.expand([
- ('cpu', 'mask2former/mask2former_r50_8xb2-lsj-50e_coco-panoptic.py'),
- ('cpu', 'mask2former/mask2former_r50_8xb2-lsj-50e_coco.py'),
- ('cuda', 'mask2former/mask2former_r50_8xb2-lsj-50e_coco-panoptic.py'),
- ('cuda', 'mask2former/mask2former_r50_8xb2-lsj-50e_coco.py')
- ])
- def test_forward_loss_mode(self, device, cfg_path):
- print(device, cfg_path)
- with_semantic = 'panoptic' in cfg_path
- model_cfg = self._create_model_cfg(cfg_path)
- detector = build_detector(model_cfg)
-
- if device == 'cuda' and not torch.cuda.is_available():
- return unittest.skip('test requires GPU and torch+cuda')
- detector = detector.to(device)
-
- packed_inputs = demo_mm_inputs(
- 2,
- image_shapes=[(3, 128, 127), (3, 91, 92)],
- sem_seg_output_strides=1,
- with_mask=True,
- with_semantic=with_semantic)
- data = detector.data_preprocessor(packed_inputs, True)
- # Test loss mode
- losses = detector.forward(**data, mode='loss')
- self.assertIsInstance(losses, dict)
-
- @parameterized.expand([
- ('cpu', 'mask2former/mask2former_r50_8xb2-lsj-50e_coco-panoptic.py'),
- ('cpu', 'mask2former/mask2former_r50_8xb2-lsj-50e_coco.py'),
- ('cuda', 'mask2former/mask2former_r50_8xb2-lsj-50e_coco-panoptic.py'),
- ('cuda', 'mask2former/mask2former_r50_8xb2-lsj-50e_coco.py')
- ])
- def test_forward_predict_mode(self, device, cfg_path):
- with_semantic = 'panoptic' in cfg_path
- model_cfg = self._create_model_cfg(cfg_path)
- detector = build_detector(model_cfg)
- if device == 'cuda' and not torch.cuda.is_available():
- return unittest.skip('test requires GPU and torch+cuda')
- detector = detector.to(device)
- packed_inputs = demo_mm_inputs(
- 2,
- image_shapes=[(3, 128, 127), (3, 91, 92)],
- sem_seg_output_strides=1,
- with_mask=True,
- with_semantic=with_semantic)
- data = detector.data_preprocessor(packed_inputs, False)
- # Test forward test
- detector.eval()
- with torch.no_grad():
- batch_results = detector.forward(**data, mode='predict')
- self.assertEqual(len(batch_results), 2)
- self.assertIsInstance(batch_results[0], DetDataSample)
-
- @parameterized.expand([
- ('cpu', 'mask2former/mask2former_r50_8xb2-lsj-50e_coco-panoptic.py'),
- ('cpu', 'mask2former/mask2former_r50_8xb2-lsj-50e_coco.py'),
- ('cuda', 'mask2former/mask2former_r50_8xb2-lsj-50e_coco-panoptic.py'),
- ('cuda', 'mask2former/mask2former_r50_8xb2-lsj-50e_coco.py')
- ])
- def test_forward_tensor_mode(self, device, cfg_path):
- with_semantic = 'panoptic' in cfg_path
- model_cfg = self._create_model_cfg(cfg_path)
- detector = build_detector(model_cfg)
- if device == 'cuda' and not torch.cuda.is_available():
- return unittest.skip('test requires GPU and torch+cuda')
- detector = detector.to(device)
-
- packed_inputs = demo_mm_inputs(
- 2, [[3, 128, 128], [3, 125, 130]],
- sem_seg_output_strides=1,
- with_mask=True,
- with_semantic=with_semantic)
- data = detector.data_preprocessor(packed_inputs, False)
- out = detector.forward(**data, mode='tensor')
- self.assertIsInstance(out, tuple)
diff --git a/tests/test_models/test_layers/test_plugins.py b/tests/test_models/test_layers/test_plugins.py
deleted file mode 100644
index b1e57bb32eb..00000000000
--- a/tests/test_models/test_layers/test_plugins.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import unittest
-
-import pytest
-import torch
-from mmengine.config import ConfigDict
-
-from mmdet.models.layers import DropBlock
-from mmdet.registry import MODELS
-from mmdet.utils import register_all_modules
-
-register_all_modules()
-
-
-def test_dropblock():
- feat = torch.rand(1, 1, 11, 11)
- drop_prob = 1.0
- dropblock = DropBlock(drop_prob, block_size=11, warmup_iters=0)
- out_feat = dropblock(feat)
- assert (out_feat == 0).all() and out_feat.shape == feat.shape
- drop_prob = 0.5
- dropblock = DropBlock(drop_prob, block_size=5, warmup_iters=0)
- out_feat = dropblock(feat)
- assert out_feat.shape == feat.shape
-
- # drop_prob must be (0,1]
- with pytest.raises(AssertionError):
- DropBlock(1.5, 3)
-
- # block_size cannot be an even number
- with pytest.raises(AssertionError):
- DropBlock(0.5, 2)
-
- # warmup_iters cannot be less than 0
- with pytest.raises(AssertionError):
- DropBlock(0.5, 3, -1)
-
-
-class TestPixelDecoder(unittest.TestCase):
-
- def test_forward(self):
- base_channels = 64
- pixel_decoder_cfg = ConfigDict(
- dict(
- type='PixelDecoder',
- in_channels=[base_channels * 2**i for i in range(4)],
- feat_channels=base_channels,
- out_channels=base_channels,
- norm_cfg=dict(type='GN', num_groups=32),
- act_cfg=dict(type='ReLU')))
- self = MODELS.build(pixel_decoder_cfg)
- self.init_weights()
- img_metas = [{}, {}]
- feats = [
- torch.rand(
- (2, base_channels * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
- for i in range(4)
- ]
- mask_feature, memory = self(feats, img_metas)
-
- assert (memory == feats[-1]).all()
- assert mask_feature.shape == feats[0].shape
-
-
-class TestTransformerEncoderPixelDecoder(unittest.TestCase):
-
- def test_forward(self):
- base_channels = 64
- pixel_decoder_cfg = ConfigDict(
- dict(
- type='TransformerEncoderPixelDecoder',
- in_channels=[base_channels * 2**i for i in range(4)],
- feat_channels=base_channels,
- out_channels=base_channels,
- norm_cfg=dict(type='GN', num_groups=32),
- act_cfg=dict(type='ReLU'),
- encoder=dict(
- type='DetrTransformerEncoder',
- num_layers=6,
- transformerlayers=dict(
- type='BaseTransformerLayer',
- attn_cfgs=dict(
- type='MultiheadAttention',
- embed_dims=base_channels,
- num_heads=8,
- attn_drop=0.1,
- proj_drop=0.1,
- dropout_layer=None,
- batch_first=False),
- ffn_cfgs=dict(
- embed_dims=base_channels,
- feedforward_channels=base_channels * 8,
- num_fcs=2,
- act_cfg=dict(type='ReLU', inplace=True),
- ffn_drop=0.1,
- dropout_layer=None,
- add_identity=True),
- operation_order=('self_attn', 'norm', 'ffn', 'norm'),
- norm_cfg=dict(type='LN'),
- init_cfg=None,
- batch_first=False),
- init_cfg=None),
- positional_encoding=dict(
- type='SinePositionalEncoding',
- num_feats=base_channels // 2,
- normalize=True)))
- self = MODELS.build(pixel_decoder_cfg)
- self.init_weights()
- img_metas = [{
- 'batch_input_shape': (128, 160),
- 'img_shape': (120, 160),
- }, {
- 'batch_input_shape': (128, 160),
- 'img_shape': (125, 160),
- }]
- feats = [
- torch.rand(
- (2, base_channels * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
- for i in range(4)
- ]
- mask_feature, memory = self(feats, img_metas)
-
- assert memory.shape[-2:] == feats[-1].shape[-2:]
- assert mask_feature.shape == feats[0].shape
-
-
-class TestMSDeformAttnPixelDecoder(unittest.TestCase):
-
- def test_forward(self):
- base_channels = 64
- pixel_decoder_cfg = ConfigDict(
- dict(
- type='MSDeformAttnPixelDecoder',
- in_channels=[base_channels * 2**i for i in range(4)],
- strides=[4, 8, 16, 32],
- feat_channels=base_channels,
- out_channels=base_channels,
- num_outs=3,
- norm_cfg=dict(type='GN', num_groups=32),
- act_cfg=dict(type='ReLU'),
- encoder=dict(
- type='DetrTransformerEncoder',
- num_layers=6,
- transformerlayers=dict(
- type='BaseTransformerLayer',
- attn_cfgs=dict(
- type='MultiScaleDeformableAttention',
- embed_dims=base_channels,
- num_heads=8,
- num_levels=3,
- num_points=4,
- im2col_step=64,
- dropout=0.0,
- batch_first=False,
- norm_cfg=None,
- init_cfg=None),
- ffn_cfgs=dict(
- type='FFN',
- embed_dims=base_channels,
- feedforward_channels=base_channels * 4,
- num_fcs=2,
- ffn_drop=0.0,
- act_cfg=dict(type='ReLU', inplace=True)),
- operation_order=('self_attn', 'norm', 'ffn', 'norm')),
- init_cfg=None),
- positional_encoding=dict(
- type='SinePositionalEncoding',
- num_feats=base_channels // 2,
- normalize=True),
- init_cfg=None), )
- self = MODELS.build(pixel_decoder_cfg)
- self.init_weights()
- feats = [
- torch.rand(
- (2, base_channels * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
- for i in range(4)
- ]
- mask_feature, multi_scale_features = self(feats)
-
- assert mask_feature.shape == feats[0].shape
- assert len(multi_scale_features) == 3
- multi_scale_features = multi_scale_features[::-1]
- for i in range(3):
- assert multi_scale_features[i].shape[-2:] == feats[i +
- 1].shape[-2:]
diff --git a/tests/test_models/test_layers/test_transformer.py b/tests/test_models/test_layers/test_transformer.py
index 9151e308424..c261de622e7 100644
--- a/tests/test_models/test_layers/test_transformer.py
+++ b/tests/test_models/test_layers/test_transformer.py
@@ -6,8 +6,7 @@
from mmdet.models.layers.transformer import (AdaptivePadding,
DetrTransformerDecoder,
DetrTransformerEncoder,
- PatchEmbed, PatchMerging,
- Transformer)
+ PatchEmbed, PatchMerging)
def test_adaptive_padding():
@@ -466,105 +465,40 @@ def test_patch_merging():
assert x_out.size(1) == out_size[0] * out_size[1]
-def test_detr_transformer_dencoder_encoder_layer():
+def test_detr_transformer_encoder_decoder():
config = ConfigDict(
- dict(
- return_intermediate=True,
- num_layers=6,
- transformerlayers=dict(
- type='DetrTransformerDecoderLayer',
- attn_cfgs=dict(
- type='MultiheadAttention',
- embed_dims=256,
- num_heads=8,
- dropout=0.1),
+ num_layers=6,
+ layer_cfg=dict( # DetrTransformerDecoderLayer
+ self_attn_cfg=dict( # MultiheadAttention
+ embed_dims=256,
+ num_heads=8,
+ dropout=0.1),
+ cross_attn_cfg=dict( # MultiheadAttention
+ embed_dims=256,
+ num_heads=8,
+ dropout=0.1),
+ ffn_cfg=dict(
+ embed_dims=256,
feedforward_channels=2048,
- ffn_dropout=0.1,
- operation_order=(
- 'norm',
- 'self_attn',
- 'norm',
- 'cross_attn',
- 'norm',
- 'ffn',
- ))))
- assert DetrTransformerDecoder(**config).layers[0].pre_norm
+ num_fcs=2,
+ ffn_drop=0.1,
+ act_cfg=dict(type='ReLU', inplace=True))))
assert len(DetrTransformerDecoder(**config).layers) == 6
-
- DetrTransformerDecoder(**config)
- with pytest.raises(AssertionError):
- config = ConfigDict(
- dict(
- return_intermediate=True,
- num_layers=6,
- transformerlayers=[
- dict(
- type='DetrTransformerDecoderLayer',
- attn_cfgs=dict(
- type='MultiheadAttention',
- embed_dims=256,
- num_heads=8,
- dropout=0.1),
- feedforward_channels=2048,
- ffn_dropout=0.1,
- operation_order=('self_attn', 'norm', 'cross_attn',
- 'norm', 'ffn', 'norm'))
- ] * 5))
- DetrTransformerDecoder(**config)
+ assert DetrTransformerDecoder(**config)
config = ConfigDict(
dict(
num_layers=6,
- transformerlayers=dict(
- type='DetrTransformerDecoderLayer',
- attn_cfgs=dict(
- type='MultiheadAttention',
+ layer_cfg=dict( # DetrTransformerEncoderLayer
+ self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.1),
- feedforward_channels=2048,
- ffn_dropout=0.1,
- operation_order=('norm', 'self_attn', 'norm', 'cross_attn',
- 'norm', 'ffn', 'norm'))))
-
- with pytest.raises(AssertionError):
- # len(operation_order) == 6
- DetrTransformerEncoder(**config)
-
-
-def test_transformer():
- config = ConfigDict(
- dict(
- encoder=dict(
- type='DetrTransformerEncoder',
- num_layers=6,
- transformerlayers=dict(
- type='BaseTransformerLayer',
- attn_cfgs=[
- dict(
- type='MultiheadAttention',
- embed_dims=256,
- num_heads=8,
- dropout=0.1)
- ],
- feedforward_channels=2048,
- ffn_dropout=0.1,
- operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
- decoder=dict(
- type='DetrTransformerDecoder',
- return_intermediate=True,
- num_layers=6,
- transformerlayers=dict(
- type='DetrTransformerDecoderLayer',
- attn_cfgs=dict(
- type='MultiheadAttention',
- embed_dims=256,
- num_heads=8,
- dropout=0.1),
+ ffn_cfg=dict(
+ embed_dims=256,
feedforward_channels=2048,
- ffn_dropout=0.1,
- operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
- 'ffn', 'norm')),
- )))
- transformer = Transformer(**config)
- transformer.init_weights()
+ num_fcs=2,
+ ffn_drop=0.1,
+ act_cfg=dict(type='ReLU', inplace=True)))))
+ assert len(DetrTransformerEncoder(**config).layers) == 6
+ assert DetrTransformerEncoder(**config)