Skip to content

Commit ac15b94

Browse files
TillBeemelmannsTill Beemelmanns
and
Till Beemelmanns
authored
Bug Fix for CenterNetBoxLoss (#2432)
* fix CenterNetBoxLoss, add Testcase * cleanup testcase * test case code clean up * fix Pytorch pipeline * make test_heading_regression_loss framework agnostic --------- Co-authored-by: Till Beemelmanns <[email protected]>
1 parent dee0634 commit ac15b94

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

keras_cv/src/losses/centernet_box_loss.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ def __init__(self, num_heading_bins, anchor_size, **kwargs):
6565
self.anchor_size = anchor_size
6666

6767
def heading_regression_loss(self, heading_true, heading_pred):
68+
heading_pred = ops.convert_to_tensor(heading_pred)
69+
6870
# Set the heading to within 0 -> 2pi
69-
heading_true = ops.floor(ops.mod(heading_true, 2 * math.pi))
71+
heading_true = ops.mod(heading_true, 2 * math.pi)
7072

7173
# Divide 2pi into bins. shifted by 0.5 * angle_per_class.
7274
angle_per_class = (2 * math.pi) / self.num_heading_bins
73-
shift_angle = ops.floor(
74-
ops.mod(heading_true + angle_per_class / 2, 2 * math.pi)
75-
)
75+
shift_angle = ops.mod(heading_true + angle_per_class / 2, 2 * math.pi)
7676

7777
heading_bin_label_float = ops.floor(
7878
ops.divide(shift_angle, angle_per_class)

keras_cv/src/losses/centernet_box_loss_test.py

+26
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from absl.testing import parameterized
1717

1818
import keras_cv
19+
from keras_cv.src.backend import ops
1920
from keras_cv.src.tests.test_case import TestCase
2021

2122

@@ -42,3 +43,28 @@ def test_proper_output_shapes(self, reduction, target_size):
4243
y_pred=np.random.uniform(size=(2, 10, 6 + 2 * 4)),
4344
)
4445
self.assertEqual(result.shape, target_size)
46+
47+
def test_heading_regression_loss(self):
48+
num_heading_bins = 4
49+
loss = keras_cv.losses.CenterNetBoxLoss(
50+
num_heading_bins=num_heading_bins, anchor_size=[1.0, 1.0, 1.0]
51+
)
52+
heading_true = np.array(
53+
[[0, (1 / 2.0) * np.pi, np.pi, (3.0 / 2.0) * np.pi]]
54+
)
55+
heading_pred = np.array(
56+
[
57+
[
58+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
59+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
60+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
61+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
62+
]
63+
]
64+
)
65+
heading_loss = loss.heading_regression_loss(
66+
heading_true=heading_true, heading_pred=heading_pred
67+
)
68+
ce_loss = -np.log(np.exp(1) / np.exp([1, 0, 0, 0]).sum())
69+
expected_loss = ce_loss * num_heading_bins
70+
self.assertAllClose(ops.sum(heading_loss), expected_loss)

0 commit comments

Comments
 (0)