Skip to content

[WIP]: Add DINO on 3.x #8820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions .circleci/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,13 @@ jobs:
- run: pip install "protobuf <= 3.20.1" && sudo apt-get update && sudo apt-get -y install libprotobuf-dev protobuf-compiler cmake
- run:
name: Install mmdet dependencies
# numpy may be downgraded after building pycocotools, which causes `ImportError: numpy.core.multiarray failed to import`
# force reinstall pycocotools to ensure pycocotools being built under the currenct numpy
command: |
python -m pip install git+ssh://[email protected]/open-mmlab/mmengine.git@main
python -m pip install << parameters.mmcv >>
pip install -r requirements/tests.txt -r requirements/optional.txt
pip install --force-reinstall pycocotools
pip install albumentations>=0.3.2 --no-binary imgaug,albumentations
pip install git+https://github.com/cocodataset/panopticapi.git
- run:
Expand Down Expand Up @@ -111,17 +114,18 @@ jobs:
command: |
docker build .circleci/docker -t mmdetection:gpu --build-arg PYTORCH=<< parameters.torch >> --build-arg CUDA=<< parameters.cuda >> --build-arg CUDNN=<< parameters.cudnn >>
docker run --gpus all -t -d -v /home/circleci/project:/mmdetection -v /home/circleci/mmengine:/mmengine -w /mmdetection --name mmdetection mmdetection:gpu
docker exec mmdetection apt-get install -y git
- run:
name: Install mmdet dependencies
# pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch${{matrix.torch_version}}/index.html
command: |
docker exec mmdetection pip install -e /mmengine
docker exec mmdetection pip install << parameters.mmcv >>
pip install -r requirements/tests.txt -r requirements/optional.txt
pip install pycocotools
pip install albumentations>=0.3.2 --no-binary imgaug,albumentations
pip install git+https://github.com/cocodataset/panopticapi.git
python -c 'import mmcv; print(mmcv.__version__)'
docker exec mmdetection pip install -r requirements/tests.txt -r requirements/optional.txt
docker exec mmdetection pip install pycocotools
docker exec mmdetection pip install albumentations>=0.3.2 --no-binary imgaug,albumentations
docker exec mmdetection pip install git+https://github.com/cocodataset/panopticapi.git
docker exec mmdetection python -c 'import mmcv; print(mmcv.__version__)'
- run:
name: Build and install
command: |
Expand Down
67 changes: 26 additions & 41 deletions configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
]
model = dict(
type='DeformableDETR',
num_query=300,
with_box_refine=False,
as_two_stage=False,
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
Expand All @@ -27,50 +30,29 @@
act_cfg=None,
norm_cfg=dict(type='GN', num_groups=32),
num_outs=4),
encoder_cfg=dict(
num_layers=6,
layer_cfg=dict(
self_attn_cfg=dict(embed_dims=256),
ffn_cfg=dict(
embed_dims=256, feedforward_channels=1024, ffn_drop=0.1))),
decoder_cfg=dict(
num_layers=6,
return_intermediate=True,
layer_cfg=dict(
self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.1),
cross_attn_cfg=dict(embed_dims=256),
ffn_cfg=dict(
embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)),
post_norm_cfg=None),
positional_encoding_cfg=dict(num_feats=128, normalize=True, offset=-0.5),
bbox_head=dict(
type='DeformableDETRHead',
num_query=300,
num_pred=6, # TODO: modify this
with_box_refine=False,
as_two_stage=False,
num_classes=80,
in_channels=2048,
sync_cls_avg_factor=True,
as_two_stage=False,
transformer=dict(
type='DeformableDetrTransformer',
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention', embed_dims=256),
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
decoder=dict(
type='DeformableDetrTransformerDecoder',
num_layers=6,
return_intermediate=True,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1),
dict(
type='MultiScaleDeformableAttention',
embed_dims=256)
],
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm')))),
positional_encoding=dict(
type='SinePositionalEncoding',
num_feats=128,
normalize=True,
offset=-0.5),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
Expand Down Expand Up @@ -150,7 +132,10 @@
# learning policy
max_epochs = 50
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
# type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
type='IterBasedTrainLoop',
max_iters=max_epochs,
val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
_base_ = 'deformable-detr_r50_16xb2-50e_coco.py'
model = dict(bbox_head=dict(with_box_refine=True))
model = dict(with_box_refine=True, bbox_head=dict(with_box_refine=True))
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
_base_ = 'deformable-detr_refine_r50_16xb2-50e_coco.py'
model = dict(bbox_head=dict(as_two_stage=True))
model = dict(as_two_stage=True, bbox_head=dict(num_pred=7, as_two_stage=True))
2 changes: 1 addition & 1 deletion configs/detr/detr_r18_8xb2-500e_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
backbone=dict(
depth=18,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18')),
neck=dict(in_channels=[64, 128, 256, 512]))
bbox_head=dict(in_channels=512))
61 changes: 25 additions & 36 deletions configs/detr/detr_r50_8xb2-150e_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
]
model = dict(
type='DETR',
num_query=100,
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
Expand All @@ -19,45 +20,33 @@
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
encoder_cfg=dict(
num_layers=6,
layer_cfg=dict(
self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.1),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True)))),
decoder_cfg=dict(
num_layers=6,
layer_cfg=dict(
self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.1),
cross_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.1),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True))),
return_intermediate=True),
positional_encoding_cfg=dict(num_feats=128, normalize=True),
bbox_head=dict(
type='DETRHead',
num_classes=80,
in_channels=2048,
transformer=dict(
type='Transformer',
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1)
],
feedforward_channels=2048,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
decoder=dict(
type='DetrTransformerDecoder',
return_intermediate=True,
num_layers=6,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1),
feedforward_channels=2048,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm')),
)),
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True),
embed_dims=256,
loss_cls=dict(
type='CrossEntropyLoss',
bg_cls_weight=0.1,
Expand Down
164 changes: 164 additions & 0 deletions configs/dino/dino_4scale_r50_8x2_12e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
_base_ = [
'../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
]
model = dict(
type='DINO',
num_query=900,
with_box_refine=True,
as_two_stage=True,
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=1),
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='ChannelMapper',
in_channels=[512, 1024, 2048],
kernel_size=1,
out_channels=256,
act_cfg=None,
norm_cfg=dict(type='GN', num_groups=32),
num_outs=4),
encoder_cfg=dict(
num_layers=6,
layer_cfg=dict(
self_attn_cfg=dict(embed_dims=256, num_levels=4,
dropout=0.0), # 0.1 for DeformDETR
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048, # 1024 for DeformDETR
ffn_drop=0.0))), # 0.1 for DeformDETR
decoder_cfg=dict(
num_layers=6,
return_intermediate=True,
layer_cfg=dict(
self_attn_cfg=dict(embed_dims=256, num_heads=8,
dropout=0.0), # 0.1 for DeformDETR
cross_attn_cfg=dict(embed_dims=256, num_levels=4,
dropout=0.0), # 0.1 for DeformDETR
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048, # 1024 for DeformDETR
ffn_drop=0.0)), # 0.1 for DeformDETR
post_norm_cfg=None),
positional_encoding_cfg=dict(
num_feats=128,
normalize=True,
offset=0.0, # -0.5 for DeformDETR
temperature=20), # 10000 for DeformDETR
bbox_head=dict(
type='DINOHead',
num_pred=7, # TODO: modify this
# num_feature_levels=4,
with_box_refine=True,
as_two_stage=True,
num_classes=80,
sync_cls_avg_factor=True,
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0), # 2.0 in DeformDETR
loss_bbox=dict(type='L1Loss', loss_weight=5.0),
loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
# training and testing settings
train_cfg=dict(
assigner=dict(
type='HungarianAssigner',
match_costs=[
dict(type='FocalLossCost', weight=2.0),
dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
dict(type='IoUCost', iou_mode='giou', weight=2.0)
])),
test_cfg=dict(max_per_img=300)) # 100 for DeformDETR

# train_pipeline, NOTE the img_scale and the Pad's size_divisor is different
# from the default setting in mmdet.
train_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args={{_base_.file_client_args}}),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomChoice',
transforms=[
[
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
],
[
dict(
type='RandomChoiceResize',
# The radio of all image in train dataset < 7
# follow the original implement
scales=[(400, 4200), (500, 4200), (600, 4200)],
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
]
]),
dict(type='PackDetInputs')
]
train_dataloader = dict(
dataset=dict(
filter_cfg=dict(filter_empty_gt=False), pipeline=train_pipeline))

# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='AdamW',
lr=0.0001, # 0.0002 for DeformDETR
weight_decay=0.0001),
clip_grad=dict(max_norm=0.1, norm_type=2),
paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)})
) # custom_keys contains sampling_offsets and reference_points in DeformDETR # noqa

# learning policy
max_epochs = 12
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[11],
gamma=0.1)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16)
2 changes: 1 addition & 1 deletion mmdet/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,7 +1136,7 @@ def transform(self, results: dict) -> dict:
if patch[2] == patch[0] or patch[3] == patch[1]:
continue
overlaps = boxes.overlaps(
HorizontalBoxes(patch.reshape(-1, 4)),
HorizontalBoxes(patch.reshape(-1, 4).astype(np.float32)),
boxes).numpy().reshape(-1)
if len(overlaps) > 0 and overlaps.min() < min_iou:
continue
Expand Down
Loading