diff --git a/docs/source/ops.rst b/docs/source/ops.rst
index 7124c85bb79..541b5c30c15 100644
--- a/docs/source/ops.rst
+++ b/docs/source/ops.rst
@@ -50,8 +50,10 @@ These utility functions perform various operations on bounding boxes.
:template: function.rst
box_area
+ box_area_center
box_convert
box_iou
+ box_iou_center
clip_boxes_to_image
complete_box_iou
distance_box_iou
diff --git a/test/test_ops.py b/test/test_ops.py
index 88124f7ba17..4b94f5018dc 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1451,6 +1451,41 @@ def test_box_area_jit(self):
torch.testing.assert_close(scripted_area, expected)
+class TestBoxAreaCenter:
+ def area_check(self, box, expected, atol=1e-4):
+ out = ops.box_area_center(box)
+ torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol)
+
+ @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
+ def test_int_boxes(self, dtype):
+ box_tensor = ops.box_convert(torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype),
+ in_fmt="xyxy", out_fmt="cxcywh")
+ expected = torch.tensor([10000, 0], dtype=torch.int32)
+ self.area_check(box_tensor, expected)
+
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
+ def test_float_boxes(self, dtype):
+ box_tensor = ops.box_convert(torch.tensor(FLOAT_BOXES, dtype=dtype), in_fmt="xyxy", out_fmt="cxcywh")
+ expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype)
+ self.area_check(box_tensor, expected)
+
+ def test_float16_box(self):
+ box_tensor = ops.box_convert(torch.tensor(
+ [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16
+ ), in_fmt="xyxy", out_fmt="cxcywh")
+
+ expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
+ self.area_check(box_tensor, expected, atol=0.01)
+
+ def test_box_area_jit(self):
+ box_tensor = ops.box_convert(torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float),
+ in_fmt="xyxy", out_fmt="cxcywh")
+ expected = ops.box_area_center(box_tensor)
+ scripted_fn = torch.jit.script(ops.box_area_center)
+ scripted_area = scripted_fn(box_tensor)
+ torch.testing.assert_close(scripted_area, expected)
+
+
INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
FLOAT_BOXES = [
@@ -1459,6 +1494,14 @@ def test_box_area_jit(self):
[279.2440, 197.9812, 1189.4746, 849.2019],
]
+INT_BOXES_CXCYWH = [[50, 50, 100, 100], [25, 25, 50, 50], [250, 250, 100, 100], [10, 10, 20, 20]]
+INT_BOXES2_CXCYWH = [[50, 50, 100, 100], [25, 25, 50, 50], [250, 250, 100, 100]]
+FLOAT_BOXES_CXCYWH = [
+ [739.4324, 518.5154, 908.1572, 665.8793],
+ [738.8228, 519.9021, 907.3512, 662.3295],
+ [734.3593, 523.5916, 910.2306, 651.2207]
+]
+
def gen_box(size, dtype=torch.float):
xy1 = torch.rand((size, 2), dtype=dtype)
@@ -1525,6 +1568,65 @@ def test_iou_cartesian(self):
self._run_cartesian_test(ops.box_iou)
+class TestIouCenterBase:
+ @staticmethod
+ def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
+ for dtype in dtypes:
+ actual_box1 = torch.tensor(actual_box1, dtype=dtype)
+ actual_box2 = torch.tensor(actual_box2, dtype=dtype)
+ expected_box = torch.tensor(expected)
+ out = target_fn(actual_box1, actual_box2)
+ torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)
+
+ @staticmethod
+ def _run_jit_test(target_fn: Callable, actual_box: List):
+ box_tensor = torch.tensor(actual_box, dtype=torch.float)
+ expected = target_fn(box_tensor, box_tensor)
+ scripted_fn = torch.jit.script(target_fn)
+ scripted_out = scripted_fn(box_tensor, box_tensor)
+ torch.testing.assert_close(scripted_out, expected)
+
+ @staticmethod
+ def _cartesian_product(boxes1, boxes2, target_fn: Callable):
+ N = boxes1.size(0)
+ M = boxes2.size(0)
+ result = torch.zeros((N, M))
+ for i in range(N):
+ for j in range(M):
+ result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
+ return result
+
+ @staticmethod
+ def _run_cartesian_test(target_fn: Callable):
+ boxes1 = ops.box_convert(gen_box(5), in_fmt="xyxy", out_fmt="cxcywh")
+ boxes2 = ops.box_convert(gen_box(7), in_fmt="xyxy", out_fmt="cxcywh")
+ a = TestIouCenterBase._cartesian_product(boxes1, boxes2, target_fn)
+ b = target_fn(boxes1, boxes2)
+ torch.testing.assert_close(a, b)
+
+
+class TestBoxIouCenter(TestIouBase):
+ int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.04, 0.16, 0.0]]
+ float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
+
+ @pytest.mark.parametrize(
+ "actual_box1, actual_box2, dtypes, atol, expected",
+ [
+ pytest.param(INT_BOXES_CXCYWH, INT_BOXES2_CXCYWH, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
+ pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float16], 0.002, float_expected),
+ pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float32, torch.float64], 1e-3, float_expected),
+ ],
+ )
+ def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
+ self._run_test(ops.box_iou_center, actual_box1, actual_box2, dtypes, atol, expected)
+
+ def test_iou_jit(self):
+ self._run_jit_test(ops.box_iou_center, INT_BOXES_CXCYWH)
+
+ def test_iou_cartesian(self):
+ self._run_cartesian_test(ops.box_iou_center)
+
+
class TestGeneralizedBoxIou(TestIouBase):
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py
index 827505b842d..456bde2d036 100644
--- a/torchvision/ops/__init__.py
+++ b/torchvision/ops/__init__.py
@@ -2,8 +2,10 @@
from .boxes import (
batched_nms,
box_area,
+ box_area_center,
box_convert,
box_iou,
+ box_iou_center,
clip_boxes_to_image,
complete_box_iou,
distance_box_iou,
@@ -40,7 +42,9 @@
"clip_boxes_to_image",
"box_convert",
"box_area",
+ "box_area_center",
"box_iou",
+ "box_iou_center",
"generalized_box_iou",
"distance_box_iou",
"complete_box_iou",
diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py
index 48df4d85cc7..a16085acaae 100644
--- a/torchvision/ops/boxes.py
+++ b/torchvision/ops/boxes.py
@@ -291,6 +291,25 @@ def box_area(boxes: Tensor) -> Tensor:
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+def box_area_center(boxes: Tensor) -> Tensor:
+ """
+ Computes the area of a set of bounding boxes, which are specified by their
+ (cx, cy, w, h) coordinates.
+
+ Args:
+ boxes (Tensor[N, 4]): boxes for which the area will be computed. They
+ are expected to be in (cx, cy, w, h) format with
+ ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``.
+
+ Returns:
+ Tensor[N]: the area for each box
+ """
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ _log_api_usage_once(box_area_center)
+ boxes = _upcast(boxes)
+ return boxes[:, 2] * boxes[:, 3]
+
+
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
@@ -329,6 +348,42 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
return iou
+def _box_inter_union_center(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
+ area1 = box_area_center(boxes1)
+ area2 = box_area_center(boxes2)
+
+ lt = torch.max(boxes1[:, None, :2] - boxes1[:, None, 2:] / 2, boxes2[:, :2] - boxes2[:, 2:] / 2) # [N,M,2]
+ rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2]
+
+ wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
+
+ union = area1[:, None] + area2 - inter
+
+ return inter, union
+
+
+def box_iou_center(boxes1: Tensor, boxes2: Tensor) -> Tensor:
+ """
+ Return intersection-over-union (Jaccard index) between two sets of boxes.
+
+ Both sets of boxes are expected to be in ``(cx, cy, w, h)`` format with
+ ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``.
+
+ Args:
+ boxes1 (Tensor[N, 4]): first set of boxes
+ boxes2 (Tensor[M, 4]): second set of boxes
+
+ Returns:
+ Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
+ """
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ _log_api_usage_once(box_iou_center)
+ inter, union = _box_inter_union_center(boxes1, boxes2)
+ iou = inter / union
+ return iou
+
+
# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
"""