This repository was archived by the owner on Mar 10, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 327
Expand file tree
/
Copy pathfaster_rcnn.py
More file actions
471 lines (436 loc) · 17.6 KB
/
faster_rcnn.py
File metadata and controls
471 lines (436 loc) · 17.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import tree
from absl import logging
import keras_cv
from keras_cv import bounding_box
from keras_cv import layers as cv_layers
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.bounding_box.converters import _decode_deltas_to_boxes
# All the imports from legacy
from keras_cv.bounding_box.utils import _clip_boxes
from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator
from keras_cv.layers.object_detection.box_matcher import BoxMatcher
from keras_cv.layers.object_detection.roi_align import _ROIAligner
from keras_cv.layers.object_detection.roi_generator import ROIGenerator
from keras_cv.layers.object_detection.roi_sampler import _ROISampler
from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder
from keras_cv.models.backbones.backbone_presets import backbone_presets
from keras_cv.models.backbones.backbone_presets import (
backbone_presets_with_weights,
)
from keras_cv.models.object_detection.__internal__ import unpack_input
from keras_cv.models.object_detection.faster_rcnn.feature_pyramid import (
FeaturePyramid,
)
from keras_cv.models.object_detection.faster_rcnn.rcnn_head import RCNNHead
from keras_cv.models.object_detection.faster_rcnn.rpn_head import RPNHead
from keras_cv.models.task import Task
from keras_cv.utils.python_utils import classproperty
from keras_cv.utils.train import get_feature_extractor
BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2]
# TODO(tanzheny): add more configurations
@keras_cv_export("keras_cv.models.FasterRCNN")
class FasterRCNN(Task):
    """A Keras model implementing the FasterRCNN architecture.

    Implements the FasterRCNN architecture for object detection. The constructor
    requires `num_classes`, `bounding_box_format` and a `backbone`.

    References:
        - [FasterRCNN](https://arxiv.org/pdf/1506.01497.pdf)

    Args:
        backbone: `keras.Model`. Must implement the
            `pyramid_level_inputs` property with keys "P2", "P3", "P4", and "P5"
            and layer names as values.
        num_classes: the number of classes in your dataset excluding the
            background class. classes should be represented by integers in the
            range [0, num_classes).
        bounding_box_format: The format of bounding boxes of model output. Refer
            [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
            for more details on supported bounding box formats.
        anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. It is
            used in the model to match ground truth boxes and labels with
            anchors, or with region proposals. By default it uses the sizes and
            ratios from the paper, that is optimized for image size between
            [640, 800]. The users should pass their own anchor generator if the
            input image size differs from paper. For now, only anchor generator
            with per level dict output is supported,
        label_encoder: (Optional) a keras.Layer that accepts an anchors Tensor,
            a bounding box Tensor and a bounding box class Tensor to its
            `call()` method, and returns RetinaNet training targets. It returns
            box and class targets as well as sample weights.
        rcnn_head: (Optional) a `keras.layers.Layer` that takes input feature
            map and returns a box delta prediction (in reference to rois) and
            multi-class prediction (all foreground classes + one background
            class). By default it uses the rcnn head from paper, which is 2 FC
            layer with 1024 dimension, 1 box regressor and 1 softmax classifier.
        prediction_decoder: (Optional) a `keras.layers.Layer` that takes input
            box prediction and softmaxed score prediction, and returns NMSed box
            prediction, NMSed softmaxed score prediction, NMSed class
            prediction, and NMSed valid detection.

    Examples:
    ```python
    images = np.ones((1, 512, 512, 3))
    labels = {
        "boxes": [
            [
                [0, 0, 100, 100],
                [100, 100, 200, 200],
                [300, 300, 100, 100],
            ]
        ],
        "classes": [[1, 1, 1]],
    }
    model = keras_cv.models.FasterRCNN(
        num_classes=20,
        bounding_box_format="xywh",
        backbone=keras_cv.models.ResNet50Backbone.from_preset(
            "resnet50_imagenet"
        )
    )

    # Evaluate model without box decoding and NMS
    model(images)

    # Prediction with box decoding and NMS
    model.predict(images)

    # Train model
    model.compile(
        classification_loss='focal',
        box_loss='smoothl1',
        optimizer=keras.optimizers.SGD(global_clipnorm=10.0),
        jit_compile=False,
    )
    model.fit(images, labels)
    ```
    """  # noqa: E501

    def __init__(
        self,
        backbone,
        num_classes,
        bounding_box_format,
        anchor_generator=None,
        label_encoder=None,
        rcnn_head=None,
        prediction_decoder=None,
        **kwargs,
    ):
        self.bounding_box_format = bounding_box_format
        super().__init__(**kwargs)
        # BUGFIX: `num_classes` must be stored on the instance; `get_config()`
        # reads `self.num_classes`, and previously this attribute was never
        # assigned, so serialization raised an AttributeError.
        self.num_classes = num_classes
        scales = [2**x for x in [0]]
        aspect_ratios = [0.5, 1.0, 2.0]
        # Default anchors follow the paper's configuration, tuned for input
        # images in the [640, 800] size range.
        self.anchor_generator = anchor_generator or AnchorGenerator(
            bounding_box_format="yxyx",
            sizes={
                "P2": 32.0,
                "P3": 64.0,
                "P4": 128.0,
                "P5": 256.0,
                "P6": 512.0,
            },
            scales=scales,
            aspect_ratios=aspect_ratios,
            strides={f"P{i}": 2**i for i in range(2, 7)},
            clip_boxes=True,
        )
        self.rpn_head = RPNHead(
            num_anchors_per_location=len(scales) * len(aspect_ratios)
        )
        # Proposal generation applies NMS per level; score thresholding is
        # disabled (-inf) so only NMS/top-k filter proposals.
        self.roi_generator = ROIGenerator(
            bounding_box_format="yxyx",
            nms_score_threshold_train=float("-inf"),
            nms_score_threshold_test=float("-inf"),
        )
        # Match values: -2 = ignore, -1 = negative (background), 1 = positive.
        self.box_matcher = BoxMatcher(
            thresholds=[0.0, 0.5], match_values=[-2, -1, 1]
        )
        # The background class is encoded as index `num_classes` (one past the
        # last foreground class).
        self.roi_sampler = _ROISampler(
            bounding_box_format="yxyx",
            roi_matcher=self.box_matcher,
            background_class=num_classes,
            num_sampled_rois=512,
        )
        self.roi_pooler = _ROIAligner(bounding_box_format="yxyx")
        self.rcnn_head = rcnn_head or RCNNHead(num_classes)
        self.backbone = backbone or keras_cv.models.ResNet50Backbone()
        extractor_levels = ["P2", "P3", "P4", "P5"]
        extractor_layer_names = [
            self.backbone.pyramid_level_inputs[i] for i in extractor_levels
        ]
        self.feature_extractor = get_feature_extractor(
            self.backbone, extractor_layer_names, extractor_levels
        )
        self.feature_pyramid = FeaturePyramid()
        self.rpn_labeler = label_encoder or _RpnLabelEncoder(
            anchor_format="yxyx",
            ground_truth_box_format="yxyx",
            positive_threshold=0.7,
            negative_threshold=0.3,
            samples_per_image=256,
            positive_fraction=0.5,
            box_variance=BOX_VARIANCE,
        )
        self._prediction_decoder = (
            prediction_decoder
            or cv_layers.NonMaxSuppression(
                bounding_box_format=bounding_box_format,
                from_logits=False,
                iou_threshold=0.5,
                confidence_threshold=0.5,
                max_detections=100,
            )
        )

    def _call_rpn(self, images, anchors, training=None):
        """Run backbone + FPN + RPN and generate region proposals.

        Returns:
            rois: clipped region proposals in "yxyx" format.
            feature_map: FPN feature maps (per-level dict).
            rpn_boxes: concatenated raw RPN box deltas, [BS, num_anchors, 4].
            rpn_scores: concatenated RPN objectness logits,
                [BS, num_anchors, 1].
        """
        image_shape = ops.shape(images[0])
        backbone_outputs = self.feature_extractor(images, training=training)
        feature_map = self.feature_pyramid(backbone_outputs, training=training)
        # [BS, num_anchors, 4], [BS, num_anchors, 1]
        rpn_boxes, rpn_scores = self.rpn_head(feature_map, training=training)
        # the decoded format is center_xywh, convert to yxyx
        decoded_rpn_boxes = _decode_deltas_to_boxes(
            anchors=anchors,
            boxes_delta=rpn_boxes,
            anchor_format="yxyx",
            box_format="yxyx",
            variance=BOX_VARIANCE,
        )
        rois, _ = self.roi_generator(
            decoded_rpn_boxes, rpn_scores, training=training
        )
        rois = _clip_boxes(rois, "yxyx", image_shape)
        # Flatten the per-level dicts into single tensors for loss computation.
        rpn_boxes = ops.concatenate(tree.flatten(rpn_boxes), axis=1)
        rpn_scores = ops.concatenate(tree.flatten(rpn_scores), axis=1)
        return rois, feature_map, rpn_boxes, rpn_scores

    def _call_rcnn(self, rois, feature_map, training=None):
        """ROI-align the feature map and run the second-stage head.

        Returns:
            rcnn_box_pred: [BS, H*W*K, 4] box deltas relative to the rois.
            rcnn_cls_pred: [BS, H*W*K, num_classes + 1] class scores
                (foreground classes + background).
        """
        feature_map = self.roi_pooler(feature_map, rois)
        # [BS, H*W*K, pool_shape*C]
        feature_map = ops.reshape(
            feature_map, ops.concatenate([ops.shape(rois)[:2], [-1]], axis=0)
        )
        # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1]
        rcnn_box_pred, rcnn_cls_pred = self.rcnn_head(
            feature_map, training=training
        )
        return rcnn_box_pred, rcnn_cls_pred

    def call(self, images, training=None):
        """Forward pass: raw deltas when training, decoded boxes otherwise."""
        image_shape = ops.shape(images[0])
        anchors = self.anchor_generator(image_shape=image_shape)
        rois, feature_map, _, _ = self._call_rpn(
            images, anchors, training=training
        )
        box_pred, cls_pred = self._call_rcnn(
            rois, feature_map, training=training
        )
        if not training:
            # box_pred is on "center_yxhw" format, convert to target format.
            box_pred = _decode_deltas_to_boxes(
                anchors=rois,
                boxes_delta=box_pred,
                anchor_format="yxyx",
                box_format=self.bounding_box_format,
                # Use the shared module constant instead of repeating the
                # literal [0.1, 0.1, 0.2, 0.2], keeping encode/decode in sync.
                variance=BOX_VARIANCE,
            )
        return box_pred, cls_pred

    # TODO(tanzhenyu): Support compile with metrics.
    def compile(
        self,
        box_loss=None,
        classification_loss=None,
        rpn_box_loss=None,
        rpn_classification_loss=None,
        weight_decay=0.0001,
        loss=None,
        **kwargs,
    ):
        """Configure training with the four FasterRCNN losses.

        All losses must be (or resolve to) `keras.losses.Loss` instances with
        `SUM` reduction; the RPN classification loss must be from logits.
        """
        # TODO(tanzhenyu): Add metrics support once COCOMap issue is addressed.
        # https://github.com/keras-team/keras-cv/issues/915
        if "metrics" in kwargs.keys():
            raise ValueError(
                "`FasterRCNN` does not currently support the use of "
                "`metrics` due to performance and distribution concerns. "
                "Please use the `PyCOCOCallback` to evaluate COCO metrics."
            )
        if loss is not None:
            raise ValueError(
                "`FasterRCNN` does not accept a `loss` to `compile()`. "
                "Instead, please pass `box_loss` and `classification_loss`. "
                "`loss` will be ignored during training."
            )
        box_loss = _validate_and_get_loss(box_loss, "box_loss")
        classification_loss = _validate_and_get_loss(
            classification_loss, "classification_loss"
        )
        rpn_box_loss = _validate_and_get_loss(rpn_box_loss, "rpn_box_loss")
        if rpn_classification_loss == "BinaryCrossentropy":
            rpn_classification_loss = keras.losses.BinaryCrossentropy(
                from_logits=True, reduction=keras.losses.Reduction.SUM
            )
        rpn_classification_loss = _validate_and_get_loss(
            rpn_classification_loss, "rpn_cls_loss"
        )
        if not rpn_classification_loss.from_logits:
            raise ValueError(
                "`rpn_classification_loss` must come with `from_logits`=True"
            )

        self.rpn_box_loss = rpn_box_loss
        self.rpn_cls_loss = rpn_classification_loss
        self.box_loss = box_loss
        self.cls_loss = classification_loss
        self.weight_decay = weight_decay
        losses = {
            "box": self.box_loss,
            "classification": self.cls_loss,
            "rpn_box": self.rpn_box_loss,
            "rpn_classification": self.rpn_cls_loss,
        }
        super().compile(loss=losses, **kwargs)

    def compute_loss(self, images, boxes, classes, training):
        """Compute the combined RPN + RCNN training loss dict.

        Targets come from `rpn_labeler` (first stage) and `roi_sampler`
        (second stage); sample weights are normalized per batch so each loss
        averages over the sampled anchors/rois.
        """
        # NOTE(review): relies on a static (known) batch dimension; fails for
        # dynamic batch sizes — confirm against training setup.
        local_batch = images.get_shape().as_list()[0]
        anchors = self.anchor_generator(image_shape=tuple(images[0].shape))
        (
            rpn_box_targets,
            rpn_box_weights,
            rpn_cls_targets,
            rpn_cls_weights,
        ) = self.rpn_labeler(
            ops.concatenate(tree.flatten(anchors), axis=0), boxes, classes
        )
        # 0.25 factor: box loss is averaged over the ~25% positive samples.
        rpn_box_weights /= (
            self.rpn_labeler.samples_per_image * local_batch * 0.25
        )
        rpn_cls_weights /= self.rpn_labeler.samples_per_image * local_batch
        rois, feature_map, rpn_box_pred, rpn_cls_pred = self._call_rpn(
            images, anchors, training=training
        )
        # Proposals are treated as fixed inputs to the second stage; gradients
        # must not flow back into the RPN through the rois.
        rois = ops.stop_gradient(rois)
        (
            rois,
            box_targets,
            box_weights,
            cls_targets,
            cls_weights,
        ) = self.roi_sampler(rois, boxes, classes)
        box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25
        cls_weights /= self.roi_sampler.num_sampled_rois * local_batch
        box_pred, cls_pred = self._call_rcnn(
            rois, feature_map, training=training
        )
        y_true = {
            "rpn_box": rpn_box_targets,
            "rpn_classification": rpn_cls_targets,
            "box": box_targets,
            "classification": cls_targets,
        }
        y_pred = {
            "rpn_box": rpn_box_pred,
            "rpn_classification": rpn_cls_pred,
            "box": box_pred,
            "classification": cls_pred,
        }
        weights = {
            "rpn_box": rpn_box_weights,
            "rpn_classification": rpn_cls_weights,
            "box": box_weights,
            "classification": cls_weights,
        }
        return super().compute_loss(
            x=images, y=y_true, y_pred=y_pred, sample_weight=weights
        )

    def train_step(self, *args):
        # The data argument is always last; unpack it into (x, y) and defer to
        # the base implementation.
        data = args[-1]
        args = args[:-1]
        x, y = unpack_input(data)
        return super().train_step(*args, (x, y))

    def test_step(self, *args):
        data = args[-1]
        args = args[:-1]
        x, y = unpack_input(data)
        return super().test_step(*args, (x, y))

    def predict_step(self, *args):
        """Run prediction and decode raw outputs into NMSed detections."""
        outputs = super().predict_step(*args)
        # Use isinstance rather than an exact-type check (idiom fix; also
        # accepts tuple subclasses).
        if isinstance(outputs, tuple):
            return self.decode_predictions(outputs[0], args[-1]), outputs[1]
        else:
            return self.decode_predictions(outputs, args[-1])

    @property
    def prediction_decoder(self):
        return self._prediction_decoder

    @prediction_decoder.setter
    def prediction_decoder(self, prediction_decoder):
        self._prediction_decoder = prediction_decoder
        # The predict function caches the decoder, so it must be rebuilt.
        self.make_predict_function(force=True)

    def decode_predictions(self, predictions, images):
        """Convert raw (box, score) predictions into NMSed detections.

        Boxes are converted to the decoder's format, decoded (background
        column `scores_pred[..., -1]` is dropped), then converted back to
        `self.bounding_box_format`.
        """
        # no-op if default decoder is used.
        box_pred, scores_pred = predictions
        box_pred = bounding_box.convert_format(
            box_pred,
            source=self.bounding_box_format,
            target=self.prediction_decoder.bounding_box_format,
            images=images,
        )
        y_pred = self.prediction_decoder(box_pred, scores_pred[..., :-1])
        box_pred = bounding_box.convert_format(
            y_pred["boxes"],
            source=self.prediction_decoder.bounding_box_format,
            target=self.bounding_box_format,
            images=images,
        )
        y_pred["boxes"] = box_pred
        return y_pred

    def get_config(self):
        return {
            "num_classes": self.num_classes,
            "bounding_box_format": self.bounding_box_format,
            "backbone": self.backbone,
            "anchor_generator": self.anchor_generator,
            "label_encoder": self.rpn_labeler,
            "prediction_decoder": self._prediction_decoder,
            "feature_pyramid": self.feature_pyramid,
            "rcnn_head": self.rcnn_head,
        }

    @classproperty
    def presets(cls):
        """Dictionary of preset names and configurations."""
        # return copy.deepcopy({**backbone_presets, **fasterrcnn_presets})
        return copy.deepcopy({**backbone_presets})

    @classproperty
    def presets_with_weights(cls):
        """Dictionary of preset names and configurations that include
        weights."""
        return copy.deepcopy(
            # {**backbone_presets_with_weights, **fasterrcnn_presets}
            {
                **backbone_presets_with_weights,
            }
        )

    @classproperty
    def backbone_presets(cls):
        """Dictionary of preset names and configurations of compatible
        backbones."""
        return copy.deepcopy(backbone_presets)
def _validate_and_get_loss(loss, loss_name):
    """Resolve and validate a loss for FasterRCNN training.

    Accepts either a loss identifier string (resolved via
    `keras.losses.get`) or a `keras.losses.Loss` instance. Anything else
    (including `None`) is rejected. Losses not using `SUM` reduction are
    converted in place, with an informational log message.

    Args:
        loss: str identifier or `keras.losses.Loss` instance.
        loss_name: name used in the error message for this loss slot.

    Returns:
        A `keras.losses.Loss` with `SUM` reduction.

    Raises:
        ValueError: if `loss` does not resolve to a `keras.losses.Loss`.
    """
    resolved = keras.losses.get(loss) if isinstance(loss, str) else loss
    # `isinstance(None, Loss)` is False, so this also rejects None.
    if not isinstance(resolved, keras.losses.Loss):
        raise ValueError(
            f"FasterRCNN only accepts `keras.losses.Loss` for {loss_name}, "
            f"got {resolved}"
        )
    if resolved.reduction != keras.losses.Reduction.SUM:
        logging.info(
            f"FasterRCNN only accepts `SUM` reduction, got "
            f"{resolved.reduction}, automatically converted."
        )
        resolved.reduction = keras.losses.Reduction.SUM
    return resolved