-
Notifications
You must be signed in to change notification settings - Fork 9.6k
/
Copy pathtest_mixins.py
311 lines (280 loc) · 13.2 KB
/
test_mixins.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
# Copyright (c) OpenMMLab. All rights reserved.
import sys
import warnings
import numpy as np
import torch
from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_bboxes,
merge_aug_masks, multiclass_nms)
if sys.version_info >= (3, 7):
from mmdet.utils.contextmanagers import completed
class BBoxTestMixin:
if sys.version_info >= (3, 7):
async def async_test_bboxes(self,
x,
img_metas,
proposals,
rcnn_test_cfg,
rescale=False,
**kwargs):
"""Asynchronized test for box head without augmentation."""
rois = bbox2roi(proposals)
roi_feats = self.bbox_roi_extractor(
x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
if self.with_shared_head:
roi_feats = self.shared_head(roi_feats)
sleep_interval = rcnn_test_cfg.get('async_sleep_interval', 0.017)
async with completed(
__name__, 'bbox_head_forward',
sleep_interval=sleep_interval):
cls_score, bbox_pred = self.bbox_head(roi_feats)
img_shape = img_metas[0]['img_shape']
scale_factor = img_metas[0]['scale_factor']
det_bboxes, det_labels = self.bbox_head.get_bboxes(
rois,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=rescale,
cfg=rcnn_test_cfg)
return det_bboxes, det_labels
def simple_test_bboxes(self,
x,
img_metas,
proposals,
rcnn_test_cfg,
rescale=False):
"""Test only det bboxes without augmentation.
Args:
x (tuple[Tensor]): Feature maps of all scale level.
img_metas (list[dict]): Image meta info.
proposals (List[Tensor]): Region proposals.
rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
rescale (bool): If True, return boxes in original image space.
Default: False.
Returns:
tuple[list[Tensor], list[Tensor]]: The first list contains
the boxes of the corresponding image in a batch, each
tensor has the shape (num_boxes, 5) and last dimension
5 represent (tl_x, tl_y, br_x, br_y, score). Each Tensor
in the second list is the labels with shape (num_boxes, ).
The length of both lists should be equal to batch_size.
"""
rois = bbox2roi(proposals)
if rois.shape[0] == 0:
batch_size = len(proposals)
det_bbox = rois.new_zeros(0, 5)
det_label = rois.new_zeros((0, ), dtype=torch.long)
if rcnn_test_cfg is None:
det_bbox = det_bbox[:, :4]
det_label = rois.new_zeros(
(0, self.bbox_head.fc_cls.out_features))
# There is no proposal in the whole batch
return [det_bbox] * batch_size, [det_label] * batch_size
bbox_results = self._bbox_forward(x, rois)
img_shapes = tuple(meta['img_shape'] for meta in img_metas)
scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
# split batch bbox prediction back to each image
cls_score = bbox_results['cls_score']
bbox_pred = bbox_results['bbox_pred']
num_proposals_per_img = tuple(len(p) for p in proposals)
rois = rois.split(num_proposals_per_img, 0)
cls_score = cls_score.split(num_proposals_per_img, 0)
# some detector with_reg is False, bbox_pred will be None
if bbox_pred is not None:
# TODO move this to a sabl_roi_head
# the bbox prediction of some detectors like SABL is not Tensor
if isinstance(bbox_pred, torch.Tensor):
bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
else:
bbox_pred = self.bbox_head.bbox_pred_split(
bbox_pred, num_proposals_per_img)
else:
bbox_pred = (None, ) * len(proposals)
# apply bbox post-processing to each image individually
det_bboxes = []
det_labels = []
for i in range(len(proposals)):
if rois[i].shape[0] == 0:
# There is no proposal in the single image
det_bbox = rois[i].new_zeros(0, 5)
det_label = rois[i].new_zeros((0, ), dtype=torch.long)
if rcnn_test_cfg is None:
det_bbox = det_bbox[:, :4]
det_label = rois[i].new_zeros(
(0, self.bbox_head.fc_cls.out_features))
else:
det_bbox, det_label = self.bbox_head.get_bboxes(
rois[i],
cls_score[i],
bbox_pred[i],
img_shapes[i],
scale_factors[i],
rescale=rescale,
cfg=rcnn_test_cfg)
det_bboxes.append(det_bbox)
det_labels.append(det_label)
return det_bboxes, det_labels
def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
"""Test det bboxes with test time augmentation."""
aug_bboxes = []
aug_scores = []
for x, img_meta in zip(feats, img_metas):
# only one image in the batch
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']
flip_direction = img_meta[0]['flip_direction']
# TODO more flexible
proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
scale_factor, flip, flip_direction)
rois = bbox2roi([proposals])
bbox_results = self._bbox_forward(x, rois)
bboxes, scores = self.bbox_head.get_bboxes(
rois,
bbox_results['cls_score'],
bbox_results['bbox_pred'],
img_shape,
scale_factor,
rescale=False,
cfg=None)
aug_bboxes.append(bboxes)
aug_scores.append(scores)
# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = merge_aug_bboxes(
aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
if merged_bboxes.shape[0] == 0:
# There is no proposal in the single image
det_bboxes = merged_bboxes.new_zeros(0, 5)
det_labels = merged_bboxes.new_zeros((0, ), dtype=torch.long)
else:
det_bboxes, det_labels = multiclass_nms(merged_bboxes,
merged_scores,
rcnn_test_cfg.score_thr,
rcnn_test_cfg.nms,
rcnn_test_cfg.max_per_img)
return det_bboxes, det_labels
class MaskTestMixin:
if sys.version_info >= (3, 7):
async def async_test_mask(self,
x,
img_metas,
det_bboxes,
det_labels,
rescale=False,
mask_test_cfg=None):
"""Asynchronized test for mask head without augmentation."""
# image shape of the first image in the batch (only one)
ori_shape = img_metas[0]['ori_shape']
scale_factor = img_metas[0]['scale_factor']
if det_bboxes.shape[0] == 0:
segm_result = [[] for _ in range(self.mask_head.num_classes)]
else:
if rescale and not isinstance(scale_factor,
(float, torch.Tensor)):
scale_factor = det_bboxes.new_tensor(scale_factor)
_bboxes = (
det_bboxes[:, :4] *
scale_factor if rescale else det_bboxes)
mask_rois = bbox2roi([_bboxes])
mask_feats = self.mask_roi_extractor(
x[:len(self.mask_roi_extractor.featmap_strides)],
mask_rois)
if self.with_shared_head:
mask_feats = self.shared_head(mask_feats)
if mask_test_cfg and mask_test_cfg.get('async_sleep_interval'):
sleep_interval = mask_test_cfg['async_sleep_interval']
else:
sleep_interval = 0.035
async with completed(
__name__,
'mask_head_forward',
sleep_interval=sleep_interval):
mask_pred = self.mask_head(mask_feats)
segm_result = self.mask_head.get_seg_masks(
mask_pred, _bboxes, det_labels, self.test_cfg, ori_shape,
scale_factor, rescale)
return segm_result
def simple_test_mask(self,
x,
img_metas,
det_bboxes,
det_labels,
rescale=False):
"""Simple test for mask head without augmentation."""
# image shapes of images in the batch
ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
if isinstance(scale_factors[0], float):
warnings.warn(
'Scale factor in img_metas should be a '
'ndarray with shape (4,) '
'arrange as (factor_w, factor_h, factor_w, factor_h), '
'The scale_factor with float type has been deprecated. ')
scale_factors = np.array([scale_factors] * 4, dtype=np.float32).T
num_imgs = len(det_bboxes)
if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
segm_results = [[[] for _ in range(self.mask_head.num_classes)]
for _ in range(num_imgs)]
else:
# if det_bboxes is rescaled to the original image size, we need to
# rescale it back to the testing scale to obtain RoIs.
if rescale:
scale_factors = [
torch.from_numpy(scale_factor).to(det_bboxes[0].device)
for scale_factor in scale_factors
]
_bboxes = [
det_bboxes[i][:, :4] *
scale_factors[i] if rescale else det_bboxes[i][:, :4]
for i in range(len(det_bboxes))
]
mask_rois = bbox2roi(_bboxes)
mask_results = self._mask_forward(x, mask_rois)
mask_pred = mask_results['mask_pred']
# split batch mask prediction back to each image
num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
# apply mask post-processing to each image individually
segm_results = []
for i in range(num_imgs):
if det_bboxes[i].shape[0] == 0:
segm_results.append(
[[] for _ in range(self.mask_head.num_classes)])
else:
segm_result = self.mask_head.get_seg_masks(
mask_preds[i], _bboxes[i], det_labels[i],
self.test_cfg, ori_shapes[i], scale_factors[i],
rescale)
segm_results.append(segm_result)
return segm_results
def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
"""Test for mask head with test time augmentation."""
if det_bboxes.shape[0] == 0:
segm_result = [[] for _ in range(self.mask_head.num_classes)]
else:
aug_masks = []
for x, img_meta in zip(feats, img_metas):
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']
flip_direction = img_meta[0]['flip_direction']
_bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
scale_factor, flip, flip_direction)
mask_rois = bbox2roi([_bboxes])
mask_results = self._mask_forward(x, mask_rois)
# convert to numpy array to save memory
aug_masks.append(
mask_results['mask_pred'].sigmoid().cpu().numpy())
merged_masks = merge_aug_masks(aug_masks, img_metas, self.test_cfg)
ori_shape = img_metas[0][0]['ori_shape']
scale_factor = det_bboxes.new_ones(4)
segm_result = self.mask_head.get_seg_masks(
merged_masks,
det_bboxes,
det_labels,
self.test_cfg,
ori_shape,
scale_factor=scale_factor,
rescale=False)
return segm_result