
Commit ba5351e

Authored by Richard-mei and VVsssssk
add gfl_trt (open-mmlab#124)
* add gfl_trt
* add gfl_head.py
* add batch_integral
* lint code
* add gfl unit test
* fix unit test
* add gfl benchmark
* fix unit test bug
* Update gfl_head.py
* Update __init__.py remove '**_forward_single'
* fix lint error and ut error
* fix docs and benchmark

Co-authored-by: VVsssssk <[email protected]>
1 parent e89becd · commit ba5351e

6 files changed: +311 -1 lines changed

docs/en/benchmark.md (+14 lines)
@@ -996,6 +996,20 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut
 <td align="center">-</td>
 <td>$MMDET_DIR/configs/cascade_rcnn/cascade_rcnn_r50_caffe_fpn_1x_coco.py</td>
 </tr>
+<tr>
+<td align="center">GFL</td>
+<td align="center">Object Detection</td>
+<td align="center">COCO2017</td>
+<td align="center">box AP</td>
+<td align="center">40.2</td>
+<td align="center">-</td>
+<td align="center">40.2</td>
+<td align="center">40.2</td>
+<td align="center">40.0</td>
+<td align="center">-</td>
+<td align="center">-</td>
+<td>$MMDET_DIR/configs/gfl/gfl_r50_fpn_1x_coco.py</td>
+</tr>
 <tr>
 <td align="center" rowspan="2">Mask R-CNN</td>
 <td align="center" rowspan="2">Instance Segmentation</td>

docs/en/codebases/mmdet.md (+1 line)
@@ -22,6 +22,7 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/
 | Cascade R-CNN | ObjectDetection | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
 | Faster R-CNN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
 | Faster R-CNN + DCN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
+| GFL | ObjectDetection | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
 | Cascade Mask R-CNN | InstanceSegmentation | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
 | Mask R-CNN | InstanceSegmentation | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) |

docs/en/supported_models.md (+1 line)
@@ -14,6 +14,7 @@ The table below lists the models that are guaranteed to be exportable to other b
 | SSD[*](#note) | MMDetection | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) |
 | FoveaBox | MMDetection | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) |
 | ATSS | MMDetection | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) |
+| GFL | MMDetection | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
 | Cascade R-CNN | MMDetection | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
 | Cascade Mask R-CNN | MMDetection | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
 | VFNet | MMDetection | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) |

mmdeploy/codebase/mmdet/models/dense_heads/__init__.py (+3 -1 lines)
@@ -2,6 +2,7 @@
 from .base_dense_head import (base_dense_head__get_bbox,
                               base_dense_head__get_bboxes__ncnn)
 from .fovea_head import fovea_head__get_bboxes
+from .gfl_head import gfl_head__get_bbox
 from .rpn_head import rpn_head__get_bboxes, rpn_head__get_bboxes__ncnn
 from .ssd_head import ssd_head__get_bboxes__ncnn
 from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn
@@ -12,5 +13,6 @@
     'yolov3_head__get_bboxes', 'yolov3_head__get_bboxes__ncnn',
     'yolox_head__get_bboxes', 'base_dense_head__get_bbox',
     'fovea_head__get_bboxes', 'base_dense_head__get_bboxes__ncnn',
-    'ssd_head__get_bboxes__ncnn', 'yolox_head__get_bboxes__ncnn'
+    'ssd_head__get_bboxes__ncnn', 'yolox_head__get_bboxes__ncnn',
+    'gfl_head__get_bbox'
 ]
mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py (new file, +185 lines)
@@ -0,0 +1,185 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F

from mmdeploy.codebase.mmdet import (get_post_processing_params,
                                     multiclass_nms, pad_with_value)
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.dense_heads.gfl_head.'
    'GFLHead.get_bboxes')
def gfl_head__get_bbox(ctx,
                       self,
                       cls_scores,
                       bbox_preds,
                       score_factors=None,
                       img_metas=None,
                       cfg=None,
                       rescale=False,
                       with_nms=True,
                       **kwargs):
    """Rewrite `get_bboxes` of `GFLHead` for default backend.

    Rewrite this function to deploy the model: transform network output for a
    batch into bbox predictions.

    Args:
        ctx (ContextCaller): The context with additional information.
        self: The instance of the original class.
        cls_scores (list[Tensor]): Classification scores for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * 4, H, W).
        score_factors (list[Tensor], Optional): Score factors for
            all scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * 1, H, W). Default None.
        img_metas (list[dict], Optional): Image meta info. Default None.
        cfg (mmcv.Config, Optional): Test / postprocessing configuration,
            if None, test_cfg would be used. Default None.
        rescale (bool): If True, return boxes in original image space.
            Default False.
        with_nms (bool): If True, do nms before return boxes.
            Default True.

    Returns:
        If with_nms == True:
            tuple[Tensor, Tensor]: (dets, labels),
            `dets` of shape [N, num_det, 5] and `labels` of shape
            [N, num_det].
        Else:
            tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
                batch_mlvl_scores, batch_mlvl_centerness
    """
    deploy_cfg = ctx.cfg
    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
    backend = get_backend(deploy_cfg)
    num_levels = len(cls_scores)

    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    mlvl_priors = self.prior_generator.grid_priors(
        featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device)

    mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
    mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
    if score_factors is None:
        with_score_factors = False
        mlvl_score_factor = [None for _ in range(num_levels)]
    else:
        with_score_factors = True
        mlvl_score_factor = [
            score_factors[i].detach() for i in range(num_levels)
        ]
        mlvl_score_factors = []
    assert img_metas is not None
    img_shape = img_metas[0]['img_shape']

    assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors)
    batch_size = cls_scores[0].shape[0]
    cfg = self.test_cfg
    pre_topk = cfg.get('nms_pre', -1)

    mlvl_valid_bboxes = []
    mlvl_valid_scores = []
    mlvl_valid_priors = []

    for cls_score, bbox_pred, score_factors, priors, stride in zip(
            mlvl_cls_scores, mlvl_bbox_preds, mlvl_score_factor, mlvl_priors,
            self.prior_generator.strides):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        assert stride[0] == stride[1]

        scores = cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
                                                       self.cls_out_channels)
        if self.use_sigmoid_cls:
            scores = scores.sigmoid()
            nms_pre_score = scores
        else:
            scores = scores.softmax(-1)
            nms_pre_score = scores
        if with_score_factors:
            score_factors = score_factors.permute(0, 2, 3,
                                                  1).reshape(batch_size,
                                                             -1).sigmoid()
            score_factors = score_factors.unsqueeze(2)
        bbox_pred = batched_integral(self.integral,
                                     bbox_pred.permute(0, 2, 3, 1)) * stride[0]
        if not is_dynamic_flag:
            priors = priors.data
        priors = priors.expand(batch_size, -1, priors.size(-1))
        if pre_topk > 0:
            if with_score_factors:
                nms_pre_score = nms_pre_score * score_factors
            if backend == Backend.TENSORRT:
                priors = pad_with_value(priors, 1, pre_topk)
                bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
                scores = pad_with_value(scores, 1, pre_topk, 0.)
                nms_pre_score = pad_with_value(nms_pre_score, 1, pre_topk, 0.)
                if with_score_factors:
                    score_factors = pad_with_value(score_factors, 1, pre_topk,
                                                   0.)

            # Get maximum scores for foreground classes.
            if self.use_sigmoid_cls:
                max_scores, _ = nms_pre_score.max(-1)
            else:
                max_scores, _ = nms_pre_score[..., :-1].max(-1)
            _, topk_inds = max_scores.topk(pre_topk)
            batch_inds = torch.arange(
                batch_size,
                device=bbox_pred.device).view(-1, 1).expand_as(topk_inds)
            priors = priors[batch_inds, topk_inds, :]
            bbox_pred = bbox_pred[batch_inds, topk_inds, :]
            scores = scores[batch_inds, topk_inds, :]
            if with_score_factors:
                score_factors = score_factors[batch_inds, topk_inds, :]

        mlvl_valid_bboxes.append(bbox_pred)
        mlvl_valid_scores.append(scores)
        priors = self.anchor_center(priors)
        mlvl_valid_priors.append(priors)
        if with_score_factors:
            mlvl_score_factors.append(score_factors)

    batch_mlvl_bboxes_pred = torch.cat(mlvl_valid_bboxes, dim=1)
    batch_scores = torch.cat(mlvl_valid_scores, dim=1)
    batch_priors = torch.cat(mlvl_valid_priors, dim=1)
    batch_bboxes = self.bbox_coder.decode(
        batch_priors, batch_mlvl_bboxes_pred, max_shape=img_shape)
    if with_score_factors:
        batch_score_factors = torch.cat(mlvl_score_factors, dim=1)

    if not self.use_sigmoid_cls:
        batch_scores = batch_scores[..., :self.num_classes]

    if with_score_factors:
        batch_scores = batch_scores * batch_score_factors
    if not with_nms:
        return batch_bboxes, batch_scores
    post_params = get_post_processing_params(deploy_cfg)
    max_output_boxes_per_class = post_params.max_output_boxes_per_class
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.pre_top_k
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
    return multiclass_nms(
        batch_bboxes,
        batch_scores,
        max_output_boxes_per_class,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        pre_top_k=pre_top_k,
        keep_top_k=keep_top_k)


def batched_integral(intergral, x):
    batch_size = x.size(0)
    x = F.softmax(x.reshape(batch_size, -1, intergral.reg_max + 1), dim=2)
    x = F.linear(x,
                 intergral.project.type_as(x).unsqueeze(0)).reshape(
                     batch_size, -1, 4)
    return x
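For readers skimming the diff, `batched_integral` is the piece that turns GFL's discretized box-edge distributions into expected offsets: a softmax over each edge's (reg_max + 1) bins followed by a dot product with the bin values. Below is a minimal, self-contained sketch of that computation; `FakeIntegral` and its attributes are stand-ins invented here for illustration only and are not the mmdet `Integral` API.

# Illustrative sketch only; not part of the commit.
import torch
import torch.nn.functional as F


class FakeIntegral:
    """Stand-in exposing the two attributes batched_integral reads."""

    def __init__(self, reg_max=3):
        self.reg_max = reg_max
        # Discrete support {0, 1, ..., reg_max} of each edge's distribution.
        self.project = torch.linspace(0, reg_max, reg_max + 1)


def batched_integral(intergral, x):
    # Same body as in gfl_head.py above: softmax over each edge's
    # (reg_max + 1) logits, then the expectation via a linear projection.
    batch_size = x.size(0)
    x = F.softmax(x.reshape(batch_size, -1, intergral.reg_max + 1), dim=2)
    x = F.linear(x,
                 intergral.project.type_as(x).unsqueeze(0)).reshape(
                     batch_size, -1, 4)
    return x


if __name__ == '__main__':
    integral = FakeIntegral(reg_max=3)
    # One image, one prior, 4 edges x (reg_max + 1) logits, i.e. the layout
    # produced by bbox_pred.permute(0, 2, 3, 1) before batched_integral.
    logits = torch.randn(1, 1, 16)
    offsets = batched_integral(integral, logits)
    print(offsets.shape)  # torch.Size([1, 1, 4]): one (l, t, r, b) per prior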

tests/test_codebase/test_mmdet/test_mmdet_models.py (+107 lines)
@@ -157,6 +157,31 @@ def get_single_roi_extractor():
     return model
 
 
+def get_gfl_head_model():
+    test_cfg = mmcv.Config(
+        dict(
+            nms_pre=1000,
+            min_bbox_size=0,
+            score_thr=0.05,
+            nms=dict(type='nms', iou_threshold=0.6),
+            max_per_img=100))
+    anchor_generator = dict(
+        type='AnchorGenerator',
+        scales_per_octave=1,
+        octave_base_scale=8,
+        ratios=[1.0],
+        strides=[8, 16, 32, 64, 128])
+    from mmdet.models.dense_heads import GFLHead
+    model = GFLHead(
+        num_classes=3,
+        in_channels=256,
+        reg_max=3,
+        test_cfg=test_cfg,
+        anchor_generator=anchor_generator)
+    model.requires_grad_(False)
+    return model
+
+
 def test_focus_forward_ncnn():
     backend_type = Backend.NCNN
     check_backend(backend_type)
@@ -349,6 +374,88 @@ def test_get_bboxes_of_rpn_head(backend_type: Backend):
     assert rewrite_outputs is not None
 
 
+@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
+def test_get_bboxes_of_gfl_head(backend_type):
+    check_backend(backend_type)
+    head = get_gfl_head_model()
+    head.cpu().eval()
+    s = 4
+    img_metas = [{
+        'scale_factor': np.ones(4),
+        'pad_shape': (s, s, 3),
+        'img_shape': (s, s, 3)
+    }]
+    output_names = ['dets']
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(type=backend_type.value),
+            onnx_config=dict(output_names=output_names, input_shape=None),
+            codebase_config=dict(
+                type='mmdet',
+                task='ObjectDetection',
+                model_type='ncnn_end2end',
+                post_processing=dict(
+                    score_threshold=0.05,
+                    iou_threshold=0.5,
+                    max_output_boxes_per_class=200,
+                    pre_top_k=5000,
+                    keep_top_k=100,
+                    background_label_id=-1,
+                ))))
+
+    seed_everything(1234)
+    cls_score = [
+        torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
+    ]
+    seed_everything(5678)
+    bboxes = [torch.rand(1, 16, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
+
+    # to get outputs of onnx model after rewrite
+    img_metas[0]['img_shape'] = torch.Tensor([s, s])
+    wrapped_model = WrapModel(
+        head, 'get_bboxes', img_metas=img_metas, with_nms=True)
+    rewrite_inputs = {
+        'cls_scores': cls_score,
+        'bbox_preds': bboxes,
+    }
+    # do not run with ncnn backend
+    run_with_backend = False if backend_type in [Backend.NCNN] else True
+    rewrite_outputs, is_backend_output = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg,
+        run_with_backend=run_with_backend)
+    assert rewrite_outputs is not None
+
+
+@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
+def test_forward_of_gfl_head(backend_type):
+    check_backend(backend_type)
+    head = get_gfl_head_model()
+    head.cpu().eval()
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(type=backend_type.value),
+            onnx_config=dict(input_shape=None)))
+    feats = [torch.rand(1, 256, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
+    model_outputs = [head.forward(feats)]
+    wrapped_model = WrapModel(head, 'forward')
+    rewrite_inputs = {
+        'feats': feats,
+    }
+    rewrite_outputs, is_backend_output = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg)
+    model_outputs[0] = [*model_outputs[0][0], *model_outputs[0][1]]
+    for model_output, rewrite_output in zip(model_outputs[0],
+                                            rewrite_outputs[0]):
+        model_output = model_output.squeeze().cpu().numpy()
+        rewrite_output = rewrite_output.squeeze()
+        assert np.allclose(
+            model_output, rewrite_output, rtol=1e-03, atol=1e-05)
+
+
 def _replace_r50_with_r18(model):
     """Replace ResNet50 with ResNet18 in config."""
     model = copy.deepcopy(model)

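With these additions in place, the two new cases can be selected on their own with pytest's keyword filter, e.g. "pytest tests/test_codebase/test_mmdet/test_mmdet_models.py -k gfl_head" (assuming mmdet and the onnxruntime backend are installed in the test environment).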