@@ -13,6 +13,7 @@
 # limitations under the License.

 """DETR detection task definition."""
+
 from typing import Optional

 from absl import logging
@@ -47,21 +48,25 @@ class DetectionTask(base_task.Task):
   def build_model(self):
     """Build DETR model."""

-    input_specs = tf_keras.layers.InputSpec(shape=[None] +
-                                            self._task_config.model.input_size)
+    input_specs = tf_keras.layers.InputSpec(
+        shape=[None] + self._task_config.model.input_size
+    )

     backbone = backbones.factory.build_backbone(
         input_specs=input_specs,
         backbone_config=self._task_config.model.backbone,
-        norm_activation_config=self._task_config.model.norm_activation)
-
-    model = detr.DETR(backbone,
-                      self._task_config.model.backbone_endpoint_name,
-                      self._task_config.model.num_queries,
-                      self._task_config.model.hidden_size,
-                      self._task_config.model.num_classes,
-                      self._task_config.model.num_encoder_layers,
-                      self._task_config.model.num_decoder_layers)
+        norm_activation_config=self._task_config.model.norm_activation,
+    )
+
+    model = detr.DETR(
+        backbone,
+        self._task_config.model.backbone_endpoint_name,
+        self._task_config.model.num_queries,
+        self._task_config.model.hidden_size,
+        self._task_config.model.num_classes,
+        self._task_config.model.num_encoder_layers,
+        self._task_config.model.num_decoder_layers,
+    )
     return model

   def initialize(self, model: tf_keras.Model):
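For readers unfamiliar with `InputSpec`, here is a minimal standalone sketch of what the reshaped call above constructs. The 640×640×3 `input_size` is a hypothetical stand-in for `task_config.model.input_size`, and the sketch uses the plain `tf.keras` namespace rather than the file's `tf_keras` alias:

```python
import tensorflow as tf

# Hypothetical input_size; the real value comes from task_config.model.input_size.
input_size = [640, 640, 3]

# [None] leaves the batch dimension unconstrained, exactly as in build_model.
input_specs = tf.keras.layers.InputSpec(shape=[None] + input_size)
print(input_specs)  # InputSpec(shape=(None, 640, 640, 3), ndim=4)
```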
@@ -84,12 +89,13 @@ def initialize(self, model: tf_keras.Model):
     status = ckpt.restore(ckpt_dir_or_file)
     status.expect_partial().assert_existing_objects_matched()

-    logging.info('Finished loading pretrained checkpoint from %s',
-                 ckpt_dir_or_file)
+    logging.info(
+        'Finished loading pretrained checkpoint from %s', ckpt_dir_or_file
+    )

-  def build_inputs(self,
-                   params,
-                   input_context: Optional[tf.distribute.InputContext] = None):
+  def build_inputs(
+      self, params, input_context: Optional[tf.distribute.InputContext] = None
+  ):
     """Build input dataset."""
     if isinstance(params, coco.COCODataConfig):
       dataset = coco.COCODataLoader(params).load(input_context)
@@ -100,14 +106,17 @@ def build_inputs(self,
       decoder_cfg = params.decoder.get()
       if params.decoder.type == 'simple_decoder':
         decoder = tf_example_decoder.TfExampleDecoder(
-            regenerate_source_id=decoder_cfg.regenerate_source_id)
+            regenerate_source_id=decoder_cfg.regenerate_source_id
+        )
       elif params.decoder.type == 'label_map_decoder':
         decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
             label_map=decoder_cfg.label_map,
-            regenerate_source_id=decoder_cfg.regenerate_source_id)
+            regenerate_source_id=decoder_cfg.regenerate_source_id,
+        )
       else:
-        raise ValueError('Unknown decoder type: {}!'.format(
-            params.decoder.type))
+        raise ValueError(
+            'Unknown decoder type: {}!'.format(params.decoder.type)
+        )

       parser = detr_input.Parser(
           class_offset=self._task_config.losses.class_offset,
@@ -118,7 +127,8 @@ def build_inputs(self,
           params,
           dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
           decoder_fn=decoder.decode,
-          parser_fn=parser.parse_fn(params.is_training))
+          parser_fn=parser.parse_fn(params.is_training),
+      )
       dataset = reader.read(input_context=input_context)

     return dataset
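The `input_context` parameter threaded through `build_inputs` is what `tf.distribute` hands each per-replica input pipeline. A minimal sketch of how a task's `build_inputs` is typically driven under a strategy; `task` and `params` are assumed to be constructed elsewhere, as in the surrounding framework code:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def dataset_fn(input_context: tf.distribute.InputContext) -> tf.data.Dataset:
  # The strategy supplies one InputContext per pipeline; build_inputs forwards
  # it to the reader so file shards are split across workers/replicas.
  return task.build_inputs(params, input_context)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
```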
@@ -128,35 +138,44 @@ def _compute_cost(self, cls_outputs, box_outputs, cls_targets, box_targets):
     # The 1 is a constant that doesn't change the matching, it can be omitted.
     # background: 0
     cls_cost = self._task_config.losses.lambda_cls * tf.gather(
-        -tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1)
+        -tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1
+    )

     # Compute the L1 cost between boxes.
     paired_differences = self._task_config.losses.lambda_box * tf.abs(
-        tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1))
+        tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1)
+    )
     box_cost = tf.reduce_sum(paired_differences, axis=-1)

     # Compute the giou cost between boxes.
-    giou_cost = self._task_config.losses.lambda_giou * -box_ops.bbox_generalized_overlap(
-        box_ops.cycxhw_to_yxyx(box_outputs),
-        box_ops.cycxhw_to_yxyx(box_targets))
+    giou_cost = (
+        self._task_config.losses.lambda_giou
+        * -box_ops.bbox_generalized_overlap(
+            box_ops.cycxhw_to_yxyx(box_outputs),
+            box_ops.cycxhw_to_yxyx(box_targets),
+        )
+    )

     total_cost = cls_cost + box_cost + giou_cost

     max_cost = (
-        self._task_config.losses.lambda_cls * 0.0 +
-        self._task_config.losses.lambda_box * 4. +
-        self._task_config.losses.lambda_giou * 0.0)
+        self._task_config.losses.lambda_cls * 0.0
+        + self._task_config.losses.lambda_box * 4.0
+        + self._task_config.losses.lambda_giou * 0.0
+    )

     # Set pads to large constant
     valid = tf.expand_dims(
-        tf.cast(tf.not_equal(cls_targets, 0), dtype=total_cost.dtype), axis=1)
+        tf.cast(tf.not_equal(cls_targets, 0), dtype=total_cost.dtype), axis=1
+    )
     total_cost = (1 - valid) * max_cost + valid * total_cost

     # Set inf or nan to large constant
     total_cost = tf.where(
         tf.logical_or(tf.math.is_nan(total_cost), tf.math.is_inf(total_cost)),
         max_cost * tf.ones_like(total_cost, dtype=total_cost.dtype),
-        total_cost)
+        total_cost,
+    )

     return total_cost

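The cost construction above relies on broadcasting to build a per-image pairwise cost matrix. A standalone sketch with toy shapes (batch 2, 5 queries, 3 padded targets) showing how the L1 term ends up as `[batch, num_queries, num_targets]`; the sizes are illustrative only:

```python
import tensorflow as tf

box_outputs = tf.random.uniform([2, 5, 4])  # [batch, num_queries, 4]
box_targets = tf.random.uniform([2, 3, 4])  # [batch, num_targets, 4]

# expand_dims gives [2, 5, 1, 4] and [2, 1, 3, 4]; subtraction broadcasts to
# [2, 5, 3, 4]; summing the coordinate axis yields the pairwise L1 cost.
paired = tf.abs(tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1))
box_cost = tf.reduce_sum(paired, axis=-1)
print(box_cost.shape)  # (2, 5, 3), one cost per (query, target) pair
```

Columns belonging to padded targets (class 0) are then overwritten with `max_cost`, so the Hungarian matcher stays well defined without preferring pads.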
@@ -168,7 +187,8 @@ def build_losses(self, outputs, labels, aux_losses=None):
     box_targets = labels['boxes']

     cost = self._compute_cost(
-        cls_outputs, box_outputs, cls_targets, box_targets)
+        cls_outputs, box_outputs, cls_targets, box_targets
+    )

     _, indices = matchers.hungarian_matching(cost)
     indices = tf.stop_gradient(indices)
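The `tf.stop_gradient` on the matching indices reflects that the Hungarian assignment is a discrete, non-differentiable choice: gradients flow through the losses computed on the matched pairs, not through the assignment itself. A small self-contained illustration of the same idea using `tf.argsort` as a stand-in for the discrete decision:

```python
import tensorflow as tf

x = tf.Variable([1.0, 2.0])
with tf.GradientTape() as tape:
  idx = tf.stop_gradient(tf.argsort(x))  # discrete decision, treated as constant
  y = tf.reduce_sum(tf.gather(x, idx) * [2.0, 3.0])
# Gradients flow through the gather, not through the sorting decision.
print(tape.gradient(y, x).numpy())  # [2. 3.]
```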
@@ -179,45 +199,53 @@ def build_losses(self, outputs, labels, aux_losses=None):

     background = tf.equal(cls_targets, 0)
     num_boxes = tf.reduce_sum(
-        tf.cast(tf.logical_not(background), tf.float32), axis=-1)
+        tf.cast(tf.logical_not(background), tf.float32), axis=-1
+    )

     # Down-weight background to account for class imbalance.
     xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
-        labels=cls_targets, logits=cls_assigned)
+        labels=cls_targets, logits=cls_assigned
+    )
     cls_loss = self._task_config.losses.lambda_cls * tf.where(
-        background, self._task_config.losses.background_cls_weight * xentropy,
-        xentropy)
+        background,
+        self._task_config.losses.background_cls_weight * xentropy,
+        xentropy,
+    )
     cls_weights = tf.where(
         background,
         self._task_config.losses.background_cls_weight * tf.ones_like(cls_loss),
-        tf.ones_like(cls_loss))
+        tf.ones_like(cls_loss),
+    )

     # Box loss is only calculated on non-background class.
     l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
     box_loss = self._task_config.losses.lambda_box * tf.where(
-        background, tf.zeros_like(l_1), l_1)
+        background, tf.zeros_like(l_1), l_1
+    )

     # Giou loss is only calculated on non-background class.
-    giou = tf.linalg.diag_part(1.0 - box_ops.bbox_generalized_overlap(
-        box_ops.cycxhw_to_yxyx(box_assigned),
-        box_ops.cycxhw_to_yxyx(box_targets)
-    ))
+    giou = tf.linalg.diag_part(
+        1.0
+        - box_ops.bbox_generalized_overlap(
+            box_ops.cycxhw_to_yxyx(box_assigned),
+            box_ops.cycxhw_to_yxyx(box_targets),
+        )
+    )
     giou_loss = self._task_config.losses.lambda_giou * tf.where(
-        background, tf.zeros_like(giou), giou)
+        background, tf.zeros_like(giou), giou
+    )

     # Consider doing all reduce once in train_step to speed up.
     num_boxes_per_replica = tf.reduce_sum(num_boxes)
     cls_weights_per_replica = tf.reduce_sum(cls_weights)
     replica_context = tf.distribute.get_replica_context()
     num_boxes_sum, cls_weights_sum = replica_context.all_reduce(
         tf.distribute.ReduceOp.SUM,
-        [num_boxes_per_replica, cls_weights_per_replica])
-    cls_loss = tf.math.divide_no_nan(
-        tf.reduce_sum(cls_loss), cls_weights_sum)
-    box_loss = tf.math.divide_no_nan(
-        tf.reduce_sum(box_loss), num_boxes_sum)
-    giou_loss = tf.math.divide_no_nan(
-        tf.reduce_sum(giou_loss), num_boxes_sum)
+        [num_boxes_per_replica, cls_weights_per_replica],
+    )
+    cls_loss = tf.math.divide_no_nan(tf.reduce_sum(cls_loss), cls_weights_sum)
+    box_loss = tf.math.divide_no_nan(tf.reduce_sum(box_loss), num_boxes_sum)
+    giou_loss = tf.math.divide_no_nan(tf.reduce_sum(giou_loss), num_boxes_sum)

     aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0

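The tail of this hunk is the cross-replica normalization: each replica sums its local box count and class weights, an `all_reduce` makes those sums global, and `divide_no_nan` guards against steps with zero ground-truth boxes. A hedged sketch of the pattern in isolation; it must execute inside a replica context (e.g. within `strategy.run`), not standalone:

```python
import tensorflow as tf

def normalized_box_loss(box_loss, num_boxes):
  # Runs inside a replica context (e.g. within strategy.run).
  ctx = tf.distribute.get_replica_context()
  # Global box count across replicas; every replica sees the same sum, so the
  # per-replica loss sums are all divided by a common denominator.
  num_boxes_sum = ctx.all_reduce(tf.distribute.ReduceOp.SUM,
                                 tf.reduce_sum(num_boxes))
  # divide_no_nan returns 0 when there are no boxes anywhere in the step.
  return tf.math.divide_no_nan(tf.reduce_sum(box_loss), num_boxes_sum)
```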
@@ -236,7 +264,8 @@ def build_metrics(self, training=True):
         annotation_file=self._task_config.annotation_file,
         include_mask=False,
         need_rescale_bboxes=True,
-        per_category_metrics=self._task_config.per_category_metrics)
+        per_category_metrics=self._task_config.per_category_metrics,
+    )
     return metrics

   def train_step(self, inputs, model, optimizer, metrics=None):
@@ -262,8 +291,11 @@ def train_step(self, inputs, model, optimizer, metrics=None):

       for output in outputs:
         # Computes per-replica loss.
-        layer_loss, layer_cls_loss, layer_box_loss, layer_giou_loss = self.build_losses(
-            outputs=output, labels=labels, aux_losses=model.losses)
+        layer_loss, layer_cls_loss, layer_box_loss, layer_giou_loss = (
+            self.build_losses(
+                outputs=output, labels=labels, aux_losses=model.losses
+            )
+        )
         loss += layer_loss
         cls_loss += layer_cls_loss
         box_loss += layer_box_loss
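The loop this hunk reformats implements DETR's auxiliary decoding losses (deep supervision): the model emits one output dict per decoder layer, and the same matching loss is computed for each layer and accumulated. A hedged, self-contained sketch of that pattern; `compute_layer_loss` is a hypothetical stand-in for `build_losses`:

```python
import tensorflow as tf

def total_loss(per_layer_outputs, labels, compute_layer_loss):
  # Every decoder layer's outputs get the same loss; the per-layer losses are
  # summed to form the quantity the optimizer differentiates.
  loss = tf.constant(0.0)
  for layer_outputs in per_layer_outputs:  # one dict per decoder layer
    loss += compute_layer_loss(layer_outputs, labels)
  return loss
```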
@@ -323,7 +355,8 @@ def validation_step(self, inputs, model, metrics=None):

     outputs = model(features, training=False)[-1]
     loss, cls_loss, box_loss, giou_loss = self.build_losses(
-        outputs=outputs, labels=labels, aux_losses=model.losses)
+        outputs=outputs, labels=labels, aux_losses=model.losses
+    )

     # Multiply for logging.
     # Since we expect the gradient replica sum to happen in the optimizer,
@@ -341,35 +374,46 @@ def validation_step(self, inputs, model, metrics=None):
     # This is for backward compatibility.
     if 'detection_boxes' not in outputs:
       detection_boxes = box_ops.cycxhw_to_yxyx(
-          outputs['box_outputs']) * tf.expand_dims(
-              tf.concat([
-                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
-                                                                        1],
-                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
-                                                                        1]
+          outputs['box_outputs']
+      ) * tf.expand_dims(
+          tf.concat(
+              [
+                  labels['image_info'][:, 1:2, 0],
+                  labels['image_info'][:, 1:2, 1],
+                  labels['image_info'][:, 1:2, 0],
+                  labels['image_info'][:, 1:2, 1],
               ],
-                        axis=1),
-              axis=1)
+              axis=1,
+          ),
+          axis=1,
+      )
     else:
       detection_boxes = outputs['detection_boxes']

-    detection_scores = tf.math.reduce_max(
-        tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
-    ) if 'detection_scores' not in outputs else outputs['detection_scores']
+    if 'detection_scores' not in outputs:
+      detection_scores = tf.math.reduce_max(
+          tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
+      )
+    else:
+      detection_scores = outputs['detection_scores']

     if 'detection_classes' not in outputs:
-      detection_classes = tf.math.argmax(
-          outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
+      detection_classes = (
+          tf.math.argmax(outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
+      )
     else:
       detection_classes = outputs['detection_classes']

     if 'num_detections' not in outputs:
       num_detections = tf.reduce_sum(
           tf.cast(
               tf.math.greater(
-                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0),
-              tf.int32),
-          axis=-1)
+                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0
+              ),
+              tf.int32,
+          ),
+          axis=-1,
+      )
     else:
       num_detections = outputs['num_detections']

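The box branch above converts DETR's normalized `[cy, cx, h, w]` boxes to `[y1, x1, y2, x2]` corners and scales them by the height/width stored in row 1 of `image_info` (the rescaled input size). A toy worked example with a hypothetical 480×640 image; the numbers are illustrative, not from the task:

```python
import tensorflow as tf

# One normalized box centered in the image, half the image in each dimension.
cy, cx, h, w = 0.5, 0.5, 0.5, 0.5
y1, x1, y2, x2 = cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2  # 0.25 ... 0.75

# The tf.concat in the diff builds a [height, width, height, width] scale so
# that multiplying a [y1, x1, y2, x2] box denormalizes it to pixel coordinates.
height, width = 480.0, 640.0
scale = tf.constant([[height, width, height, width]])
box = tf.constant([[y1, x1, y2, x2]])
print((box * scale).numpy())  # [[120. 160. 360. 480.]]
```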
@@ -379,21 +423,21 @@ def validation_step(self, inputs, model, metrics=None):
         'detection_classes': detection_classes,
         'num_detections': num_detections,
         'source_id': labels['id'],
-        'image_info': labels['image_info']
+        'image_info': labels['image_info'],
     }

     ground_truths = {
         'source_id': labels['id'],
         'height': labels['image_info'][:, 0:1, 0],
         'width': labels['image_info'][:, 0:1, 1],
         'num_detections': tf.reduce_sum(
-            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1),
+            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1
+        ),
         'boxes': labels['gt_boxes'],
         'classes': labels['classes'],
-        'is_crowds': labels['is_crowd']
+        'is_crowds': labels['is_crowd'],
     }
-    logs.update({'predictions': predictions,
-                 'ground_truths': ground_truths})
+    logs.update({'predictions': predictions, 'ground_truths': ground_truths})

     all_losses = {
         'cls_loss': cls_loss,
@@ -413,8 +457,8 @@ def aggregate_logs(self, state=None, step_outputs=None):
       state = self.coco_metric

     state.update_state(
-        step_outputs['ground_truths'],
-        step_outputs['predictions'])
+        step_outputs['ground_truths'], step_outputs['predictions']
+    )
     return state

   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):