Skip to content

Commit db50116

Browse files
Guide shape inference in sorted_non_max_suppression_padded.
PiperOrigin-RevId: 513848412
1 parent b645ada commit db50116

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

Diff for: official/legacy/detection/ops/nms.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ def _suppression_loop_body(boxes, iou_threshold, output_size, idx):
6767
output_size: the updated output_size.
6868
idx: the updated induction variable.
6969
"""
70-
num_tiles = tf.shape(boxes)[1] // NMS_TILE_SIZE
71-
batch_size = tf.shape(boxes)[0]
70+
boxes_shape = tf.shape(boxes)
71+
num_tiles = boxes_shape[1] // NMS_TILE_SIZE
72+
batch_size = boxes_shape[0]
7273

7374
# Iterates over tiles that can possibly suppress the current tile.
7475
box_slice = tf.slice(boxes, [0, idx * NMS_TILE_SIZE, 0],
@@ -97,7 +98,7 @@ def _suppression_loop_body(boxes, iou_threshold, output_size, idx):
9798
boxes = tf.tile(tf.expand_dims(
9899
box_slice, [1]), [1, num_tiles, 1, 1]) * mask + tf.reshape(
99100
boxes, [batch_size, num_tiles, NMS_TILE_SIZE, 4]) * (1 - mask)
100-
boxes = tf.reshape(boxes, [batch_size, -1, 4])
101+
boxes = tf.reshape(boxes, boxes_shape)
101102

102103
# Updates output_size.
103104
output_size += tf.reduce_sum(

Diff for: official/vision/ops/nms.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ def _suppression_loop_body(boxes, iou_threshold, output_size, idx):
6565
output_size: the updated output_size.
6666
idx: the updated induction variable.
6767
"""
68-
num_tiles = tf.shape(boxes)[1] // NMS_TILE_SIZE
69-
batch_size = tf.shape(boxes)[0]
68+
boxes_shape = tf.shape(boxes)
69+
num_tiles = boxes_shape[1] // NMS_TILE_SIZE
70+
batch_size = boxes_shape[0]
7071

7172
# Iterates over tiles that can possibly suppress the current tile.
7273
box_slice = tf.slice(boxes, [0, idx * NMS_TILE_SIZE, 0],
@@ -95,7 +96,7 @@ def _suppression_loop_body(boxes, iou_threshold, output_size, idx):
9596
boxes = tf.tile(tf.expand_dims(
9697
box_slice, [1]), [1, num_tiles, 1, 1]) * mask + tf.reshape(
9798
boxes, [batch_size, num_tiles, NMS_TILE_SIZE, 4]) * (1 - mask)
98-
boxes = tf.reshape(boxes, [batch_size, -1, 4])
99+
boxes = tf.reshape(boxes, boxes_shape)
99100

100101
# Updates output_size.
101102
output_size += tf.reduce_sum(

0 commit comments

Comments
 (0)