
Commit 8cd3db3

add yolox pai
1 parent f35c723 commit 8cd3db3

13 files changed: +646 −4 lines

README.md

+1

@@ -184,6 +184,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
         <li><a href="configs/deformable_detr">Deformable DETR (ICLR'2021)</a></li>
         <li><a href="configs/tood">TOOD (ICCV'2021)</a></li>
         <li><a href="configs/ddod">DDOD (ACM MM'2021)</a></li>
+        <li><a href="configs/yoloxpai">YOLOX-PAI (ArXiv'2022)</a></li>
       </ul>
     </td>
     <td>

configs/yoloxpai/README.md

+21

@@ -0,0 +1,21 @@
# YOLOX-PAI

> [YOLOX-PAI: An Improved YOLOX, Stronger and Faster than YOLOv6](https://arxiv.org/abs/2208.13040)

<!-- [ALGORITHM] -->

## Abstract

We develop an all-in-one computer vision toolbox named EasyCV to facilitate the use of various SOTA computer vision methods. Recently, we added YOLOX-PAI, an improved version of YOLOX, to EasyCV. We conduct ablation studies to investigate the influence of some detection methods on YOLOX. We also provide easy-to-use support for PAI-Blade, which accelerates the inference process based on BladeDISC and TensorRT. Finally, we achieve 42.8 mAP on the COCO dataset within 1.0 ms on a single NVIDIA V100 GPU, which is slightly faster than YOLOv6. A simple but efficient predictor API is also designed in EasyCV to conduct end-to-end object detection. Codes and models are now available at: this https URL.

<div align=center>
<img src="https://user-images.githubusercontent.com/24734142/189808824-094c66f7-f95c-4e31-8a1e-50515fce545d.png"/>
</div>

## Results and Models

|    Model    | ASFF | TOOD | box AP |                                                          Config                                                          |         Download         |
| :---------: | :--: | :--: | :----: | :-----------------------------------------------------------------------------------------------------------------------: | :----------------------: |
| YOLOX-PAI-s |  N   |  N   |  41.8  | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yoloxpai/yolox_pai_s_8x8_300e_coco.py) | [model](<>) \| [log](<>) |
| YOLOX-PAI-s |  Y   |  N   |  42.8  | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yoloxpai/yolox_pai_asff_s_8x8_300e_coco.py) | [model](<>) \| [log](<>) |
| YOLOX-PAI-s |  Y   |  Y   |  43.6  | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yoloxpai/yolox_pai_asff_tood_s_8x8_300e_coco.py) | [model](<>) \| [log](<>) |
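
A hedged usage sketch for these configs with the standard mmdet 2.x high-level API. The model/log links above are still placeholders, so the checkpoint filename below is hypothetical, and it assumes this commit plus mmcls are installed:

```python
# Minimal inference sketch (mmdet 2.x high-level API).
# The checkpoint path is a placeholder; the table above does not yet link real weights.
from mmdet.apis import init_detector, inference_detector

config_file = 'configs/yoloxpai/yolox_pai_asff_tood_s_8x8_300e_coco.py'
checkpoint_file = 'yolox_pai_asff_tood_s_8x8_300e_coco.pth'  # hypothetical local checkpoint

model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')  # demo image shipped with mmdetection
model.show_result('demo/demo.jpg', result, out_file='yolox_pai_result.jpg')
```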

configs/yoloxpai/metafile.yml

+58

@@ -0,0 +1,58 @@
Collections:
  - Name: YOLOX-PAI
    Metadata:
      Training Data: COCO
      Training Techniques:
        - SGD with Nesterov
        - Weight Decay
        - Cosine Annealing Lr Updater
      Training Resources: 8x TITANXp GPUs
      Architecture:
        - RepVGG
        - PAFPN
    Paper:
      URL: https://arxiv.org/abs/2208.13040
      Title: 'YOLOX-PAI: An Improved YOLOX, Stronger and Faster than YOLOv6'
    README: configs/yoloxpai/README.md
    Code:
      URL:
      Version:

Models:
  - Name: yolox_pai_s_8x8_300e_coco
    In Collection: YOLOX-PAI
    Config: configs/yoloxpai/yolox_pai_s_8x8_300e_coco.py
    Metadata:
      Training Memory (GB):
      Epochs: 300
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 41.8
    Weights:
  - Name: yolox_pai_asff_s_8x8_300e_coco
    In Collection: YOLOX-PAI
    Config: configs/yoloxpai/yolox_pai_asff_s_8x8_300e_coco.py
    Metadata:
      Training Memory (GB):
      Epochs: 300
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 42.8
    Weights:
  - Name: yolox_pai_asff_tood_s_8x8_300e_coco
    In Collection: YOLOX-PAI
    Config: configs/yoloxpai/yolox_pai_asff_tood_s_8x8_300e_coco.py
    Metadata:
      Training Memory (GB):
      Epochs: 300
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 43.6
    Weights:
configs/yoloxpai/yolox_pai_asff_s_8x8_300e_coco.py

+3

@@ -0,0 +1,3 @@
_base_ = './yolox_pai_s_8x8_300e_coco.py'

model = dict(neck=dict(type='YOLOXASFFPAFPN'))
@@ -0,0 +1,4 @@
_base_ = './yolox_pai_s_8x8_300e_coco.py'

model = dict(bbox_head=dict(type='YOLOXTOODHead'))
find_unused_parameters = True
configs/yoloxpai/yolox_pai_s_8x8_300e_coco.py

+19

@@ -0,0 +1,19 @@
_base_ = '../yolox/yolox_s_8x8_300e_coco.py'
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False)

model = dict(
    backbone=dict(
        _delete_=True,
        type='mmcls.RepVGG',
        arch=dict(
            num_blocks=[3, 5, 7, 3],
            base_channels=32,
            width_factor=[1, 1, 1, 1],
            group_layer_map=None,
            se_cfg=None),
        add_ppf=True,
        norm_cfg=dict(type='BN', eps=0.001, momentum=0.03),
        out_indices=(1, 2, 3),
    ),
    neck=dict(act_cfg=dict(type='SiLU')),
    bbox_head=dict(act_cfg=dict(type='SiLU')))
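
The key step in this config is deleting the inherited CSPDarknet backbone and pulling in `mmcls.RepVGG` through `custom_imports`. A minimal sketch for checking that the config assembles, assuming mmdet 2.x plus an mmcls version whose RepVGG accepts the `add_ppf` option, run from the repository root so the relative `_base_` path resolves:

```python
# Sketch: build the detector from the config to confirm the cross-library
# backbone swap works. Paths and versions are assumptions, not part of this commit.
from mmcv import Config
from mmcv.utils import import_modules_from_strings
from mmdet.models import build_detector

cfg = Config.fromfile('configs/yoloxpai/yolox_pai_s_8x8_300e_coco.py')
# Recent mmcv versions trigger `custom_imports` inside Config.fromfile already;
# the explicit call just makes the mmcls registration step visible.
import_modules_from_strings(**cfg['custom_imports'])

model = build_detector(cfg.model)
print(type(model.backbone).__name__)  # expected: RepVGG (from mmcls)
```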

mmdet/models/dense_heads/__init__.py

+2 −1

@@ -41,6 +41,7 @@
 from .yolo_head import YOLOV3Head
 from .yolof_head import YOLOFHead
 from .yolox_head import YOLOXHead
+from .yolox_tood_head import YOLOXTOODHead

 __all__ = [
     'AnchorFreeHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption',
@@ -54,5 +55,5 @@
     'DETRHead', 'YOLOFHead', 'DeformableDETRHead', 'SOLOHead',
     'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
     'DecoupledSOLOLightHead', 'LADHead', 'TOODHead', 'MaskFormerHead',
-    'Mask2FormerHead', 'SOLOV2Head', 'DDODHead'
+    'Mask2FormerHead', 'SOLOV2Head', 'DDODHead', 'YOLOXTOODHead'
 ]
mmdet/models/dense_heads/yolox_tood_head.py

+104

@@ -0,0 +1,104 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmdet.core import multi_apply
from mmdet.models.builder import HEADS
from mmdet.models.dense_heads import YOLOXHead
from mmdet.models.dense_heads.tood_head import TaskDecomposition


@HEADS.register_module()
class YOLOXTOODHead(YOLOXHead):
    """YOLOXTOODHead used in `YOLOX-PAI
    <https://arxiv.org/abs/2208.13040>`_.

    Args:
        tood_stacked_convs (int): Number of conv layers in TOOD head.
            Default: 3.
        la_down_rate (int): Downsample rate of layer attention.
            Default: 32.
        tood_norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='GN', num_groups=32, requires_grad=True).
    """

    def __init__(self,
                 *args,
                 tood_stacked_convs=3,
                 la_down_rate=32,
                 tood_norm_cfg=dict(
                     type='GN', num_groups=32, requires_grad=True),
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.tood_stacked_convs = tood_stacked_convs
        self.la_down_rate = la_down_rate
        self.tood_norm_cfg = tood_norm_cfg

        self._init_tood_layers()

    def _init_tood_layers(self):
        self.multi_level_cls_decomps = nn.ModuleList()
        self.multi_level_reg_decomps = nn.ModuleList()
        for _ in self.strides:
            self.multi_level_cls_decomps.append(
                TaskDecomposition(self.in_channels, self.tood_stacked_convs,
                                  self.tood_stacked_convs * self.la_down_rate,
                                  self.conv_cfg, self.tood_norm_cfg))
            self.multi_level_reg_decomps.append(
                TaskDecomposition(self.in_channels, self.tood_stacked_convs,
                                  self.tood_stacked_convs * self.la_down_rate,
                                  self.conv_cfg, self.tood_norm_cfg))

        self.inter_convs = nn.ModuleList()
        for _ in range(self.tood_stacked_convs):
            self.inter_convs.append(
                ConvModule(
                    self.in_channels,
                    self.in_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.tood_norm_cfg))

    def forward_single(self, x, cls_convs, reg_convs, conv_cls, conv_reg,
                       conv_obj, cls_decomp, reg_decomp):
        """Forward feature of a single scale level."""

        inter_feats = []
        for inter_conv in self.inter_convs:
            x = inter_conv(x)
            inter_feats.append(x)
        feat = torch.cat(inter_feats, 1)

        avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
        cls_x = cls_decomp(feat, avg_feat)
        reg_x = reg_decomp(feat, avg_feat)

        cls_feat = cls_convs(cls_x)
        reg_feat = reg_convs(reg_x)

        cls_score = conv_cls(cls_feat)
        bbox_pred = conv_reg(reg_feat)
        objectness = conv_obj(reg_feat)

        return cls_score, bbox_pred, objectness

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
        Returns:
            tuple[Tensor]: A tuple of multi-level prediction maps, each is a
                4D-tensor of shape (batch_size, 5+num_classes, height, width).
        """

        return multi_apply(
            self.forward_single, feats, self.multi_level_cls_convs,
            self.multi_level_reg_convs, self.multi_level_conv_cls,
            self.multi_level_conv_reg, self.multi_level_conv_obj,
            self.multi_level_cls_decomps, self.multi_level_reg_decomps)
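
To make the shape contract of this head concrete, here is a hedged standalone sketch of a dummy forward pass. The class counts, channel sizes, and feature-map sizes mirror YOLOX-S-style defaults and are assumptions of the sketch, not values from this commit; it also presumes a checkout where `YOLOXTOODHead` is registered (i.e. this commit applied):

```python
# Hedged sketch: dummy forward pass to inspect per-level output shapes.
# num_classes / in_channels / feature sizes are assumed YOLOX-S-style values.
import torch
from mmdet.models.dense_heads import YOLOXTOODHead

head = YOLOXTOODHead(num_classes=80, in_channels=128, feat_channels=128)
head.eval()

# Three FPN levels for a 640x640 input at strides 8, 16 and 32.
feats = tuple(torch.rand(1, 128, s, s) for s in (80, 40, 20))

cls_scores, bbox_preds, objectnesses = head(feats)
print([t.shape for t in cls_scores])    # 80 class channels per level
print([t.shape for t in bbox_preds])    # 4 box channels per level
print([t.shape for t in objectnesses])  # 1 objectness channel per level
```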

mmdet/models/necks/__init__.py

+2 −1

@@ -14,10 +14,11 @@
 from .rfp import RFP
 from .ssd_neck import SSDNeck
 from .yolo_neck import YOLOV3Neck
+from .yolox_asff_pafpn import YOLOXASFFPAFPN
 from .yolox_pafpn import YOLOXPAFPN

 __all__ = [
     'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN',
     'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'DilatedEncoder',
-    'CTResNetNeck', 'SSDNeck', 'YOLOXPAFPN', 'DyHead'
+    'CTResNetNeck', 'SSDNeck', 'YOLOXPAFPN', 'DyHead', 'YOLOXASFFPAFPN'
 ]
