Commit 1100c04

No public description
PiperOrigin-RevId: 720640505
1 parent 5f3aa11 commit 1100c04

1 file changed: +121 −77 lines

Diff for: official/projects/detr/tasks/detection.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """DETR detection task definition."""
+
 from typing import Optional
 
 from absl import logging
@@ -47,21 +48,25 @@ class DetectionTask(base_task.Task):
   def build_model(self):
     """Build DETR model."""
 
-    input_specs = tf_keras.layers.InputSpec(shape=[None] +
-                                            self._task_config.model.input_size)
+    input_specs = tf_keras.layers.InputSpec(
+        shape=[None] + self._task_config.model.input_size
+    )
 
     backbone = backbones.factory.build_backbone(
         input_specs=input_specs,
         backbone_config=self._task_config.model.backbone,
-        norm_activation_config=self._task_config.model.norm_activation)
-
-    model = detr.DETR(backbone,
-                      self._task_config.model.backbone_endpoint_name,
-                      self._task_config.model.num_queries,
-                      self._task_config.model.hidden_size,
-                      self._task_config.model.num_classes,
-                      self._task_config.model.num_encoder_layers,
-                      self._task_config.model.num_decoder_layers)
+        norm_activation_config=self._task_config.model.norm_activation,
+    )
+
+    model = detr.DETR(
+        backbone,
+        self._task_config.model.backbone_endpoint_name,
+        self._task_config.model.num_queries,
+        self._task_config.model.hidden_size,
+        self._task_config.model.num_classes,
+        self._task_config.model.num_encoder_layers,
+        self._task_config.model.num_decoder_layers,
+    )
     return model
 
   def initialize(self, model: tf_keras.Model):
@@ -84,12 +89,13 @@ def initialize(self, model: tf_keras.Model):
     status = ckpt.restore(ckpt_dir_or_file)
     status.expect_partial().assert_existing_objects_matched()
 
-    logging.info('Finished loading pretrained checkpoint from %s',
-                 ckpt_dir_or_file)
+    logging.info(
+        'Finished loading pretrained checkpoint from %s', ckpt_dir_or_file
+    )
 
-  def build_inputs(self,
-                   params,
-                   input_context: Optional[tf.distribute.InputContext] = None):
+  def build_inputs(
+      self, params, input_context: Optional[tf.distribute.InputContext] = None
+  ):
     """Build input dataset."""
     if isinstance(params, coco.COCODataConfig):
       dataset = coco.COCODataLoader(params).load(input_context)
@@ -100,14 +106,17 @@ def build_inputs(self,
       decoder_cfg = params.decoder.get()
       if params.decoder.type == 'simple_decoder':
         decoder = tf_example_decoder.TfExampleDecoder(
-            regenerate_source_id=decoder_cfg.regenerate_source_id)
+            regenerate_source_id=decoder_cfg.regenerate_source_id
+        )
       elif params.decoder.type == 'label_map_decoder':
         decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
             label_map=decoder_cfg.label_map,
-            regenerate_source_id=decoder_cfg.regenerate_source_id)
+            regenerate_source_id=decoder_cfg.regenerate_source_id,
+        )
       else:
-        raise ValueError('Unknown decoder type: {}!'.format(
-            params.decoder.type))
+        raise ValueError(
+            'Unknown decoder type: {}!'.format(params.decoder.type)
+        )
 
       parser = detr_input.Parser(
           class_offset=self._task_config.losses.class_offset,
@@ -118,7 +127,8 @@ def build_inputs(self,
           params,
           dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
           decoder_fn=decoder.decode,
-          parser_fn=parser.parse_fn(params.is_training))
+          parser_fn=parser.parse_fn(params.is_training),
+      )
       dataset = reader.read(input_context=input_context)
 
     return dataset
@@ -128,35 +138,44 @@ def _compute_cost(self, cls_outputs, box_outputs, cls_targets, box_targets):
     # The 1 is a constant that doesn't change the matching, it can be omitted.
     # background: 0
     cls_cost = self._task_config.losses.lambda_cls * tf.gather(
-        -tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1)
+        -tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1
+    )
 
     # Compute the L1 cost between boxes.
     paired_differences = self._task_config.losses.lambda_box * tf.abs(
-        tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1))
+        tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1)
+    )
     box_cost = tf.reduce_sum(paired_differences, axis=-1)
 
     # Compute the giou cost between boxes.
-    giou_cost = self._task_config.losses.lambda_giou * -box_ops.bbox_generalized_overlap(
-        box_ops.cycxhw_to_yxyx(box_outputs),
-        box_ops.cycxhw_to_yxyx(box_targets))
+    giou_cost = (
+        self._task_config.losses.lambda_giou
+        * -box_ops.bbox_generalized_overlap(
+            box_ops.cycxhw_to_yxyx(box_outputs),
+            box_ops.cycxhw_to_yxyx(box_targets),
+        )
+    )
 
     total_cost = cls_cost + box_cost + giou_cost
 
     max_cost = (
-        self._task_config.losses.lambda_cls * 0.0 +
-        self._task_config.losses.lambda_box * 4. +
-        self._task_config.losses.lambda_giou * 0.0)
+        self._task_config.losses.lambda_cls * 0.0
+        + self._task_config.losses.lambda_box * 4.0
+        + self._task_config.losses.lambda_giou * 0.0
+    )
 
     # Set pads to large constant.
     valid = tf.expand_dims(
-        tf.cast(tf.not_equal(cls_targets, 0), dtype=total_cost.dtype), axis=1)
+        tf.cast(tf.not_equal(cls_targets, 0), dtype=total_cost.dtype), axis=1
+    )
     total_cost = (1 - valid) * max_cost + valid * total_cost
 
     # Set inf or nan to large constant.
     total_cost = tf.where(
         tf.logical_or(tf.math.is_nan(total_cost), tf.math.is_inf(total_cost)),
         max_cost * tf.ones_like(total_cost, dtype=total_cost.dtype),
-        total_cost)
+        total_cost,
+    )
 
     return total_cost
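The interesting piece of `_compute_cost` is the shape arithmetic: the pairwise L1 term is built purely by broadcasting, and padded targets (class 0) are later overwritten with the constant `max_cost` so padding never wins a match. A minimal, self-contained sketch of that broadcast; shapes and values here are invented for illustration, not taken from the task config:

    import tensorflow as tf

    # Hypothetical shapes: 2 images, 3 queries, 5 padded targets, 4 box coords.
    box_outputs = tf.random.uniform([2, 3, 4])
    box_targets = tf.random.uniform([2, 5, 4])

    # [2, 3, 1, 4] - [2, 1, 5, 4] broadcasts to [2, 3, 5, 4]; reducing the
    # last axis gives one L1 cost per (query, target) pair, as in the diff.
    paired_differences = tf.abs(
        tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1)
    )
    box_cost = tf.reduce_sum(paired_differences, axis=-1)  # shape [2, 3, 5]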
@@ -168,7 +187,8 @@ def build_losses(self, outputs, labels, aux_losses=None):
     box_targets = labels['boxes']
 
     cost = self._compute_cost(
-        cls_outputs, box_outputs, cls_targets, box_targets)
+        cls_outputs, box_outputs, cls_targets, box_targets
+    )
 
     _, indices = matchers.hungarian_matching(cost)
     indices = tf.stop_gradient(indices)
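`matchers.hungarian_matching` solves a bipartite assignment over that cost matrix, and `tf.stop_gradient` keeps the discrete matching decision out of backprop. For intuition only, the same assignment on a single example can be reproduced with SciPy's `linear_sum_assignment`; this is an offline analogue with made-up numbers, not what the Model Garden runs inside the graph:

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    # Invented 3x3 cost: rows are queries, columns are targets.
    cost = np.array([[0.9, 0.1, 0.5],
                     [0.4, 0.8, 0.2],
                     [0.3, 0.6, 0.7]])
    rows, cols = linear_sum_assignment(cost)  # minimal-cost one-to-one match
    print(list(zip(rows, cols)))              # [(0, 1), (1, 2), (2, 0)]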
@@ -179,45 +199,53 @@ def build_losses(self, outputs, labels, aux_losses=None):
 
     background = tf.equal(cls_targets, 0)
     num_boxes = tf.reduce_sum(
-        tf.cast(tf.logical_not(background), tf.float32), axis=-1)
+        tf.cast(tf.logical_not(background), tf.float32), axis=-1
+    )
 
     # Down-weight background to account for class imbalance.
     xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
-        labels=cls_targets, logits=cls_assigned)
+        labels=cls_targets, logits=cls_assigned
+    )
     cls_loss = self._task_config.losses.lambda_cls * tf.where(
-        background, self._task_config.losses.background_cls_weight * xentropy,
-        xentropy)
+        background,
+        self._task_config.losses.background_cls_weight * xentropy,
+        xentropy,
+    )
     cls_weights = tf.where(
         background,
         self._task_config.losses.background_cls_weight * tf.ones_like(cls_loss),
-        tf.ones_like(cls_loss))
+        tf.ones_like(cls_loss),
+    )
 
     # Box loss is only calculated on non-background class.
     l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
     box_loss = self._task_config.losses.lambda_box * tf.where(
-        background, tf.zeros_like(l_1), l_1)
+        background, tf.zeros_like(l_1), l_1
+    )
 
     # Giou loss is only calculated on non-background class.
-    giou = tf.linalg.diag_part(1.0 - box_ops.bbox_generalized_overlap(
-        box_ops.cycxhw_to_yxyx(box_assigned),
-        box_ops.cycxhw_to_yxyx(box_targets)
-    ))
+    giou = tf.linalg.diag_part(
+        1.0
+        - box_ops.bbox_generalized_overlap(
+            box_ops.cycxhw_to_yxyx(box_assigned),
+            box_ops.cycxhw_to_yxyx(box_targets),
+        )
+    )
     giou_loss = self._task_config.losses.lambda_giou * tf.where(
-        background, tf.zeros_like(giou), giou)
+        background, tf.zeros_like(giou), giou
+    )
 
     # Consider doing all reduce once in train_step to speed up.
     num_boxes_per_replica = tf.reduce_sum(num_boxes)
     cls_weights_per_replica = tf.reduce_sum(cls_weights)
     replica_context = tf.distribute.get_replica_context()
     num_boxes_sum, cls_weights_sum = replica_context.all_reduce(
         tf.distribute.ReduceOp.SUM,
-        [num_boxes_per_replica, cls_weights_per_replica])
-    cls_loss = tf.math.divide_no_nan(
-        tf.reduce_sum(cls_loss), cls_weights_sum)
-    box_loss = tf.math.divide_no_nan(
-        tf.reduce_sum(box_loss), num_boxes_sum)
-    giou_loss = tf.math.divide_no_nan(
-        tf.reduce_sum(giou_loss), num_boxes_sum)
+        [num_boxes_per_replica, cls_weights_per_replica],
+    )
+    cls_loss = tf.math.divide_no_nan(tf.reduce_sum(cls_loss), cls_weights_sum)
+    box_loss = tf.math.divide_no_nan(tf.reduce_sum(box_loss), num_boxes_sum)
+    giou_loss = tf.math.divide_no_nan(tf.reduce_sum(giou_loss), num_boxes_sum)
 
     aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0
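The normalization at the end all-reduces the foreground-box count and class weights first, so every replica divides by the same global denominator, and `tf.math.divide_no_nan` keeps a step with zero foreground boxes from producing NaNs. A toy single-replica illustration with invented values:

    import tensorflow as tf

    box_loss = tf.constant([0.5, 0.0, 1.5])  # per-query losses, background zeroed
    num_boxes_sum = tf.constant(2.0)         # global foreground-box count

    normalized = tf.math.divide_no_nan(tf.reduce_sum(box_loss), num_boxes_sum)
    print(normalized.numpy())                # 1.0

    # With no foreground boxes anywhere, the denominator is 0 and the result
    # is defined as 0.0 instead of NaN.
    print(tf.math.divide_no_nan(tf.constant(0.0), tf.constant(0.0)).numpy())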
@@ -236,7 +264,8 @@ def build_metrics(self, training=True):
           annotation_file=self._task_config.annotation_file,
           include_mask=False,
           need_rescale_bboxes=True,
-          per_category_metrics=self._task_config.per_category_metrics)
+          per_category_metrics=self._task_config.per_category_metrics,
+      )
     return metrics
 
   def train_step(self, inputs, model, optimizer, metrics=None):
@@ -262,8 +291,11 @@ def train_step(self, inputs, model, optimizer, metrics=None):
 
       for output in outputs:
         # Computes per-replica loss.
-        layer_loss, layer_cls_loss, layer_box_loss, layer_giou_loss = self.build_losses(
-            outputs=output, labels=labels, aux_losses=model.losses)
+        layer_loss, layer_cls_loss, layer_box_loss, layer_giou_loss = (
+            self.build_losses(
+                outputs=output, labels=labels, aux_losses=model.losses
+            )
+        )
         loss += layer_loss
         cls_loss += layer_cls_loss
         box_loss += layer_box_loss
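This loop exists because DETR emits one prediction set per decoder layer and applies the same Hungarian loss to each (the auxiliary decoding losses from the DETR paper); `train_step` simply sums them. A self-contained sketch of the accumulation pattern, with a stub and invented numbers standing in for `build_losses` and the per-layer outputs:

    # Stub standing in for DetectionTask.build_losses above.
    def stub_build_losses(outputs):
      return outputs['total'], outputs['cls'], outputs['box'], outputs['giou']

    # One invented dict per decoder layer.
    layer_outputs = [
        {'total': 1.0, 'cls': 0.4, 'box': 0.3, 'giou': 0.3},
        {'total': 0.8, 'cls': 0.3, 'box': 0.3, 'giou': 0.2},
    ]

    loss = cls_loss = box_loss = giou_loss = 0.0
    for output in layer_outputs:
      layer_loss, layer_cls, layer_box, layer_giou = stub_build_losses(output)
      loss += layer_loss
      cls_loss += layer_cls
      box_loss += layer_box
      giou_loss += layer_giou

    print(loss, cls_loss, box_loss, giou_loss)  # 1.8 0.7 0.6 0.5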
@@ -323,7 +355,8 @@ def validation_step(self, inputs, model, metrics=None):
 
     outputs = model(features, training=False)[-1]
     loss, cls_loss, box_loss, giou_loss = self.build_losses(
-        outputs=outputs, labels=labels, aux_losses=model.losses)
+        outputs=outputs, labels=labels, aux_losses=model.losses
+    )
 
     # Multiply for logging.
     # Since we expect the gradient replica sum to happen in the optimizer,
@@ -341,35 +374,46 @@ def validation_step(self, inputs, model, metrics=None):
     # This is for backward compatibility.
     if 'detection_boxes' not in outputs:
       detection_boxes = box_ops.cycxhw_to_yxyx(
-          outputs['box_outputs']) * tf.expand_dims(
-              tf.concat([
-                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
-                                                                        1],
-                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
-                                                                        1]
-              ],
-              axis=1),
-          axis=1)
+          outputs['box_outputs']
+      ) * tf.expand_dims(
+          tf.concat(
+              [
+                  labels['image_info'][:, 1:2, 0],
+                  labels['image_info'][:, 1:2, 1],
+                  labels['image_info'][:, 1:2, 0],
+                  labels['image_info'][:, 1:2, 1],
+              ],
+              axis=1,
+          ),
+          axis=1,
+      )
     else:
       detection_boxes = outputs['detection_boxes']
 
-    detection_scores = tf.math.reduce_max(
-        tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
-    ) if 'detection_scores' not in outputs else outputs['detection_scores']
+    if 'detection_scores' not in outputs:
+      detection_scores = tf.math.reduce_max(
+          tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
+      )
+    else:
+      detection_scores = outputs['detection_scores']
 
     if 'detection_classes' not in outputs:
-      detection_classes = tf.math.argmax(
-          outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
+      detection_classes = (
+          tf.math.argmax(outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
+      )
     else:
       detection_classes = outputs['detection_classes']
 
     if 'num_detections' not in outputs:
       num_detections = tf.reduce_sum(
           tf.cast(
               tf.math.greater(
-                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0),
-              tf.int32),
-          axis=-1)
+                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0
+              ),
+              tf.int32,
+          ),
+          axis=-1,
+      )
     else:
       num_detections = outputs['num_detections']
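When the model returns raw `cls_outputs` instead of post-processed detections, this fallback path derives scores and classes itself: slice off the background column (class 0) before taking the max, then shift argmax indices by +1 back into the full label space. A standalone sketch with invented shapes (91 foreground classes plus background; the real class count comes from the task config):

    import tensorflow as tf

    cls_outputs = tf.random.normal([2, 100, 92])  # [batch, queries, classes]

    probs = tf.nn.softmax(cls_outputs)            # per-class probabilities
    detection_scores = tf.math.reduce_max(probs[:, :, 1:], axis=-1)  # drop bg

    # argmax over logits[:, :, 1:] indexes foreground classes from 0, so +1
    # restores the original class ids (softmax is monotone, so logits and
    # probabilities give the same argmax).
    detection_classes = tf.math.argmax(cls_outputs[:, :, 1:], axis=-1) + 1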
@@ -379,21 +423,21 @@ def validation_step(self, inputs, model, metrics=None):
         'detection_classes': detection_classes,
         'num_detections': num_detections,
         'source_id': labels['id'],
-        'image_info': labels['image_info']
+        'image_info': labels['image_info'],
     }
 
     ground_truths = {
         'source_id': labels['id'],
         'height': labels['image_info'][:, 0:1, 0],
         'width': labels['image_info'][:, 0:1, 1],
         'num_detections': tf.reduce_sum(
-            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1),
+            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1
+        ),
         'boxes': labels['gt_boxes'],
         'classes': labels['classes'],
-        'is_crowds': labels['is_crowd']
+        'is_crowds': labels['is_crowd'],
     }
-    logs.update({'predictions': predictions,
-                 'ground_truths': ground_truths})
+    logs.update({'predictions': predictions, 'ground_truths': ground_truths})
 
     all_losses = {
         'cls_loss': cls_loss,
@@ -413,8 +457,8 @@ def aggregate_logs(self, state=None, step_outputs=None):
       state = self.coco_metric
 
     state.update_state(
-        step_outputs['ground_truths'],
-        step_outputs['predictions'])
+        step_outputs['ground_truths'], step_outputs['predictions']
+    )
     return state
 
   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
