-
Notifications
You must be signed in to change notification settings - Fork 640
Expand file tree
/
Copy pathoriented_standard_roi_head.py
More file actions
188 lines (163 loc) · 7.78 KB
/
oriented_standard_roi_head.py
File metadata and controls
188 lines (163 loc) · 7.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmrotate.core import rbbox2roi
from ..builder import ROTATED_HEADS
from .rotate_standard_roi_head import RotatedStandardRoIHead
@ROTATED_HEADS.register_module()
class OrientedStandardRoIHead(RotatedStandardRoIHead):
    """Oriented RCNN roi head including one bbox head."""

    def forward_dummy(self, x, proposals):
        """Dummy forward function.

        Args:
            x (list[Tensors]): list of multi-level img features.
            proposals: region proposals. They are wrapped in a one-element
                list before being converted to RoIs, so this is presumably
                the proposals of a single image -- TODO confirm with callers.

        Returns:
            tuple: ``(cls_score, bbox_pred)`` produced by the bbox head, or
            an empty tuple when ``self.with_bbox`` is false.
        """
        outs = ()
        rois = rbbox2roi([proposals])
        if self.with_bbox:
            bbox_results = self._bbox_forward(x, rois)
            outs = outs + (bbox_results['cls_score'],
                           bbox_results['bbox_pred'])
        return outs

    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        """Assign/sample proposals and compute the bbox head losses.

        Args:
            x (list[Tensor]): list of multi-level img features.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            proposal_list (list[Tensors]): list of region proposals, one
                entry per image.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 5) in [cx, cy, w, h, a] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (None | Tensor) : true segmentation masks for each box
                used if the architecture supports a segmentation task. Always
                set to None. Not used by this method.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # assign gts and sample proposals
        if self.with_bbox:
            num_imgs = len(img_metas)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            for i in range(num_imgs):
                assign_result = self.bbox_assigner.assign(
                    proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
                    gt_labels[i])
                sampling_result = self.bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes[i],
                    gt_labels[i],
                    feats=[lvl_feat[i][None] for lvl_feat in x])

                # Overwrite ``pos_gt_bboxes`` with the rotated GT boxes
                # matched to the positive samples.
                if gt_bboxes[i].numel() == 0:
                    # No GT in this image: use an empty (0, box_dim)
                    # placeholder. ``new_zeros`` is required here because
                    # ``Tensor.new(<tuple>)`` interprets a plain tuple as
                    # *data*, which would yield a 1-D tensor of two zeros
                    # instead of an empty 2-D tensor.
                    # The box width is read from the first image's GT
                    # tensor, as the original code did.
                    box_dim = gt_bboxes[0].size(-1)
                    sampling_result.pos_gt_bboxes = gt_bboxes[i].new_zeros(
                        (0, box_dim))
                else:
                    sampling_result.pos_gt_bboxes = \
                        gt_bboxes[i][sampling_result.pos_assigned_gt_inds, :]
                sampling_results.append(sampling_result)

        losses = dict()
        # bbox head forward and loss
        if self.with_bbox:
            bbox_results = self._bbox_forward_train(x, sampling_results,
                                                    gt_bboxes, gt_labels,
                                                    img_metas)
            losses.update(bbox_results['loss_bbox'])

        return losses

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
        """Run forward function and calculate loss for box head in training.

        Args:
            x (list[Tensor]): list of multi-level img features.
            sampling_results (list): per-image sampling results; the
                ``bboxes`` attribute of each entry is used to build the RoIs.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 5) in [cx, cy, w, h, a] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                Not used by this method.

        Returns:
            dict[str, Tensor]: the bbox head outputs returned by
            ``_bbox_forward`` with an added ``'loss_bbox'`` entry.
        """
        rois = rbbox2roi([res.bboxes for res in sampling_results])
        bbox_results = self._bbox_forward(x, rois)

        bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
                                                  gt_labels, self.train_cfg)
        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
                                        bbox_results['bbox_pred'], rois,
                                        *bbox_targets)

        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results

    def simple_test_bboxes(self,
                           x,
                           img_metas,
                           proposals,
                           rcnn_test_cfg,
                           rescale=False):
        """Test only det bboxes without augmentation.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            img_metas (list[dict]): Image meta info.
            proposals (List[Tensor]): Region proposals.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            tuple[list[Tensor], list[Tensor]]: The first list contains \
                the boxes of the corresponding image in a batch, each \
                tensor has the shape (num_boxes, 6) and last dimension \
                6 represent (cx, cy, w, h, a, score). Each Tensor \
                in the second list is the labels with shape (num_boxes, ). \
                The length of both lists should be equal to batch_size.
        """
        rois = rbbox2roi(proposals)
        bbox_results = self._bbox_forward(x, rois)
        img_shapes = tuple(meta['img_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        # split batch bbox prediction back to each image
        cls_score = bbox_results['cls_score']
        bbox_pred = bbox_results['bbox_pred']
        num_proposals_per_img = tuple(len(p) for p in proposals)
        rois = rois.split(num_proposals_per_img, 0)
        cls_score = cls_score.split(num_proposals_per_img, 0)

        # some detector with_reg is False, bbox_pred will be None
        if bbox_pred is not None:
            # the bbox prediction of some detectors like SABL is not Tensor
            if isinstance(bbox_pred, torch.Tensor):
                bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
            else:
                bbox_pred = self.bbox_head.bbox_pred_split(
                    bbox_pred, num_proposals_per_img)
        else:
            bbox_pred = (None, ) * len(proposals)

        # apply bbox post-processing to each image individually
        det_bboxes = []
        det_labels = []
        for i in range(len(proposals)):
            det_bbox, det_label = self.bbox_head.get_bboxes(
                rois[i],
                cls_score[i],
                bbox_pred[i],
                img_shapes[i],
                scale_factors[i],
                rescale=rescale,
                cfg=rcnn_test_cfg)
            det_bboxes.append(det_bbox)
            det_labels.append(det_label)
        return det_bboxes, det_labels