diff --git a/docs/source/tv_tensors.rst b/docs/source/tv_tensors.rst index cb8a3c45fa9..d292012fdf8 100644 --- a/docs/source/tv_tensors.rst +++ b/docs/source/tv_tensors.rst @@ -21,6 +21,7 @@ info. Image Video + KeyPoints BoundingBoxFormat BoundingBoxes Mask diff --git a/gallery/transforms/plot_tv_tensors.py b/gallery/transforms/plot_tv_tensors.py index 5bce37aa374..2c6ebbf9031 100644 --- a/gallery/transforms/plot_tv_tensors.py +++ b/gallery/transforms/plot_tv_tensors.py @@ -46,11 +46,12 @@ # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function # for the input data. # -# :mod:`torchvision.tv_tensors` supports four types of TVTensors: +# :mod:`torchvision.tv_tensors` supports five types of TVTensors: # # * :class:`~torchvision.tv_tensors.Image` # * :class:`~torchvision.tv_tensors.Video` # * :class:`~torchvision.tv_tensors.BoundingBoxes` +# * :class:`~torchvision.tv_tensors.KeyPoints` # * :class:`~torchvision.tv_tensors.Mask` # # What can I do with a TVTensor? @@ -96,6 +97,7 @@ # :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the # corresponding image (``canvas_size``) alongside the actual values. These # metadata are required to properly transform the bounding boxes. +# In a similar fashion, :class:`~torchvision.tv_tensors.KeyPoints` also require the ``canvas_size`` metadata to be added. bboxes = tv_tensors.BoundingBoxes( [[17, 16, 344, 495], [0, 10, 0, 10]], @@ -104,6 +106,13 @@ ) print(bboxes) + +keypoints = tv_tensors.KeyPoints( + [[17, 16], [344, 495], [0, 10], [0, 10]], + canvas_size=image.shape[-2:] +) +print(keypoints) + # %% # Using ``tv_tensors.wrap()`` # ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/test/common_utils.py b/test/common_utils.py index b3a26dfd441..600cb5a13b7 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -9,6 +9,7 @@ import sys import tempfile import warnings +from collections.abc import Sequence from subprocess import CalledProcessError, check_output, STDOUT import numpy as np @@ -400,6 +401,20 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) +def make_keypoints( + canvas_size: tuple[int, int] = DEFAULT_SIZE, *, num_points: int | Sequence[int] = 4, dtype=None, device="cpu" +) -> tv_tensors.KeyPoints: + """Make the KeyPoints for testing purposes""" + if isinstance(num_points, int): + num_points = [num_points] + single_coord_shape: tuple[int, ...] = tuple(num_points) + (1,) + y = torch.randint(0, canvas_size[0] - 1, single_coord_shape, dtype=dtype, device=device) + x = torch.randint(0, canvas_size[1] - 1, single_coord_shape, dtype=dtype, device=device) + points = torch.cat((x, y), dim=-1) + keypoints = tv_tensors.KeyPoints(points, canvas_size=canvas_size) + return keypoints + + def make_bounding_boxes( canvas_size=DEFAULT_SIZE, *, diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 94d90b9e2f6..c3a9692a664 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -31,6 +31,7 @@ make_image, make_image_pil, make_image_tensor, + make_keypoints, make_segmentation_mask, make_video, make_video_tensor, @@ -230,9 +231,7 @@ def check_functional_kernel_signature_match(functional, *, kernel, input_type): if issubclass(input_type, tv_tensors.TVTensor): # We filter out metadata that is implicitly passed to the functional through the input tv_tensor, but has to be # explicitly passed to the kernel. 
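# Reviewer note (illustrative, not part of the original comment): for KeyPoints the only
# such implicit metadata is ``canvas_size`` -- unlike BoundingBoxes there is no ``format`` --
# which is why the mapping below adds ``tv_tensors.KeyPoints: {"canvas_size"}``. Mirroring the
# calls used in the tests further down, the functional is invoked as
# ``F.resize(keypoints, size=size)`` while the kernel needs the metadata passed explicitly,
# e.g. ``F.resize_keypoints(tensor, size=size, canvas_size=keypoints.canvas_size)``.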
- explicit_metadata = { - tv_tensors.BoundingBoxes: {"format", "canvas_size"}, - } + explicit_metadata = {tv_tensors.BoundingBoxes: {"format", "canvas_size"}, tv_tensors.KeyPoints: {"canvas_size"}} kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())] functional_params = iter(functional_params) @@ -336,6 +335,21 @@ def _make_transform_sample(transform, *, image_or_video, adapter): canvas_size=size, device=device, ), + keypoints=make_keypoints(canvas_size=size), + keypoints_degenerate=tv_tensors.KeyPoints( + [ + [0, 1], # left edge + [1, 0], # top edge + [0, 0], # top left corner + [size[1], 1], # right edge + [size[1], 0], # top right corner + [1, size[0]], # bottom edge + [0, size[0]], # bottom left corner + [size[1], size[0]], # bottom right corner + ], + canvas_size=size, + device=device, + ), detection_mask=make_detection_masks(size, device=device), segmentation_mask=make_segmentation_mask(size, device=device), int=0, @@ -560,6 +574,45 @@ def affine_bounding_boxes(bounding_boxes): ) +def reference_affine_keypoints_helper(keypoints, *, affine_matrix, new_canvas_size=None, clamp=True): + canvas_size = new_canvas_size or keypoints.canvas_size + + def affine_keypoints(keypoints): + dtype = keypoints.dtype + device = keypoints.device + + # Go to float before converting to prevent precision loss + x, y = keypoints.to(dtype=torch.float64, device="cpu", copy=True).squeeze(0).tolist() + + points = np.array([[x, y, 1.0]]) + transformed_points = np.matmul(points, affine_matrix.astype(points.dtype).T) + + output = torch.Tensor( + [ + float(transformed_points[0, 0]), + float(transformed_points[0, 1]), + ] + ) + + if clamp: + # It is important to clamp before casting, especially for CXCYWH format, dtype=int64 + output = F.clamp_keypoints( + output, + canvas_size=canvas_size, + ) + else: + # We leave the bounding box as float64 so the caller gets the full precision to perform any additional + # operation + dtype = output.dtype + + return output.to(dtype=dtype, device=device) + + return tv_tensors.KeyPoints( + torch.cat([affine_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape(keypoints.shape), + canvas_size=canvas_size, + ) + + class TestResize: INPUT_SIZE = (17, 11) OUTPUT_SIZES = [17, [17], (17,), None, [12, 13], (12, 13)] @@ -659,6 +712,28 @@ def test_kernel_bounding_boxes(self, format, size, use_max_size, dtype, device): check_scripted_vs_eager=not isinstance(size, int), ) + @pytest.mark.parametrize("size", OUTPUT_SIZES) + @pytest.mark.parametrize("use_max_size", [True, False]) + @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel_keypoints(self, size, use_max_size, dtype, device): + if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)): + return + + keypoints = make_keypoints( + canvas_size=self.INPUT_SIZE, + dtype=dtype, + device=device, + ) + check_kernel( + F.resize_keypoints, + keypoints, + canvas_size=keypoints.canvas_size, + size=size, + **max_size_kwarg, + check_scripted_vs_eager=not isinstance(size, int), + ) + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks]) def test_kernel_mask(self, make_mask): check_kernel(F.resize_mask, make_mask(self.INPUT_SIZE), size=self.OUTPUT_SIZES[-1]) @@ -689,6 +764,7 @@ def test_functional(self, size, make_input): (F.resize_image, torch.Tensor), (F._geometry._resize_image_pil, PIL.Image.Image), (F.resize_image, 
tv_tensors.Image), + (F.resize_keypoints, tv_tensors.KeyPoints), (F.resize_bounding_boxes, tv_tensors.BoundingBoxes), (F.resize_mask, tv_tensors.Mask), (F.resize_video, tv_tensors.Video), @@ -766,6 +842,28 @@ def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=Non new_canvas_size=(new_height, new_width), ) + def _reference_resize_keypoints(self, keypoints, *, size, max_size=None): + old_height, old_width = keypoints.canvas_size + new_height, new_width = self._compute_output_size( + input_size=keypoints.canvas_size, size=size, max_size=max_size + ) + + if (old_height, old_width) == (new_height, new_width): + return keypoints + + affine_matrix = np.array( + [ + [new_width / old_width, 0, 0], + [0, new_height / old_height, 0], + ], + ) + + return reference_affine_keypoints_helper( + keypoints, + affine_matrix=affine_matrix, + new_canvas_size=(new_height, new_width), + ) + @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) @pytest.mark.parametrize("size", OUTPUT_SIZES) @pytest.mark.parametrize("use_max_size", [True, False]) @@ -782,6 +880,21 @@ def test_bounding_boxes_correctness(self, format, size, use_max_size, fn): self._check_output_size(bounding_boxes, actual, size=size, **max_size_kwarg) torch.testing.assert_close(actual, expected) + @pytest.mark.parametrize("size", OUTPUT_SIZES) + @pytest.mark.parametrize("use_max_size", [True, False]) + @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)]) + def test_keypoints_correctness(self, size, use_max_size, fn): + if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)): + return + + keypoints = make_keypoints(canvas_size=self.INPUT_SIZE) + + actual = fn(keypoints, size=size, **max_size_kwarg) + expected = self._reference_resize_keypoints(keypoints, size=size, **max_size_kwarg) + + self._check_output_size(keypoints, actual, size=size, **max_size_kwarg) + torch.testing.assert_close(actual, expected) + @pytest.mark.parametrize("interpolation", set(transforms.InterpolationMode) - set(INTERPOLATION_MODES)) @pytest.mark.parametrize( "make_input", @@ -1024,6 +1137,16 @@ def test_kernel_bounding_boxes(self, format, dtype, device): canvas_size=bounding_boxes.canvas_size, ) + @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel_keypoints(self, dtype, device): + keypoints = make_keypoints(dtype=dtype, device=device) + check_kernel( + F.horizontal_flip_keypoints, + keypoints, + canvas_size=keypoints.canvas_size, + ) + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks]) def test_kernel_mask(self, make_mask): check_kernel(F.horizontal_flip_mask, make_mask()) @@ -1044,6 +1167,7 @@ def test_functional(self, make_input): (F.horizontal_flip_image, torch.Tensor), (F._geometry._horizontal_flip_image_pil, PIL.Image.Image), (F.horizontal_flip_image, tv_tensors.Image), + (F.horizontal_flip_keypoints, tv_tensors.KeyPoints), (F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes), (F.horizontal_flip_mask, tv_tensors.Mask), (F.horizontal_flip_video, tv_tensors.Video), @@ -1081,6 +1205,16 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes): return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix) + def _reference_horizontal_flip_keypoints(self, keypoints): + affine_matrix = np.array( + [ + [-1, 0, keypoints.canvas_size[1]], + [0, 1, 0], + ], + ) + + return 
reference_affine_keypoints_helper(keypoints, affine_matrix=affine_matrix) + @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) @pytest.mark.parametrize( "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)] @@ -1093,6 +1227,17 @@ def test_bounding_boxes_correctness(self, format, fn): torch.testing.assert_close(actual, expected) + @pytest.mark.parametrize( + "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)] + ) + def test_keypoints_correctness(self, fn): + keypoints = make_keypoints() + + actual = fn(keypoints) + expected = self._reference_horizontal_flip_keypoints(keypoints) + + torch.testing.assert_close(actual, expected) + @pytest.mark.parametrize( "make_input", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], @@ -1194,6 +1339,24 @@ def test_kernel_bounding_boxes(self, param, value, format, dtype, device): check_scripted_vs_eager=not (param == "shear" and isinstance(value, (int, float))), ) + @param_value_parametrization( + angle=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"], + translate=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["translate"], + shear=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"], + center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"], + ) + @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel_keypoints(self, param, value, dtype, device): + keypoints = make_keypoints(dtype=dtype, device=device) + self._check_kernel( + F.affine_keypoints, + keypoints, + canvas_size=keypoints.canvas_size, + **{param: value}, + check_scripted_vs_eager=not (param == "shear" and isinstance(value, (int, float))), + ) + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks]) def test_kernel_mask(self, make_mask): self._check_kernel(F.affine_mask, make_mask()) @@ -1214,6 +1377,7 @@ def test_functional(self, make_input): (F.affine_image, torch.Tensor), (F._geometry._affine_image_pil, PIL.Image.Image), (F.affine_image, tv_tensors.Image), + (F.affine_keypoints, tv_tensors.KeyPoints), (F.affine_bounding_boxes, tv_tensors.BoundingBoxes), (F.affine_mask, tv_tensors.Mask), (F.affine_video, tv_tensors.Video), @@ -1329,6 +1493,17 @@ def _reference_affine_bounding_boxes(self, bounding_boxes, *, angle, translate, ), ) + def _reference_affine_keypoints(self, keypoints, *, angle, translate, scale, shear, center): + if center is None: + center = [s * 0.5 for s in keypoints.canvas_size[::-1]] + + return reference_affine_keypoints_helper( + keypoints, + affine_matrix=self._compute_affine_matrix( + angle=angle, translate=translate, scale=scale, shear=shear, center=center + ), + ) + @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"]) @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"]) @@ -1375,6 +1550,50 @@ def test_transform_bounding_boxes_correctness(self, format, center, seed): torch.testing.assert_close(actual, expected) + @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"]) + @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"]) + @pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"]) + @pytest.mark.parametrize("shear", _CORRECTNESS_AFFINE_KWARGS["shear"]) + @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) + def test_functional_keypoints_correctness(self, angle, translate, scale, shear, center): + 
keypoints = make_keypoints() + + actual = F.affine( + keypoints, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + expected = self._reference_affine_keypoints( + keypoints, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + + torch.testing.assert_close(actual, expected) + + @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) + @pytest.mark.parametrize("seed", list(range(5))) + def test_transform_keypoints_correctness(self, center, seed): + keypoints = make_keypoints() + + transform = transforms.RandomAffine(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, center=center) + + torch.manual_seed(seed) + params = transform.make_params([keypoints]) + + torch.manual_seed(seed) + actual = transform(keypoints) + + expected = self._reference_affine_keypoints(keypoints, **params, center=center) + + torch.testing.assert_close(actual, expected) + @pytest.mark.parametrize("degrees", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["degrees"]) @pytest.mark.parametrize("translate", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["translate"]) @pytest.mark.parametrize("scale", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["scale"]) @@ -1476,6 +1695,16 @@ def test_kernel_bounding_boxes(self, format, dtype, device): canvas_size=bounding_boxes.canvas_size, ) + @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel_keypoints(self, dtype, device): + keypoints = make_keypoints(dtype=dtype, device=device) + check_kernel( + F.vertical_flip_keypoints, + keypoints, + canvas_size=keypoints.canvas_size, + ) + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks]) def test_kernel_mask(self, make_mask): check_kernel(F.vertical_flip_mask, make_mask()) @@ -1496,6 +1725,7 @@ def test_functional(self, make_input): (F.vertical_flip_image, torch.Tensor), (F._geometry._vertical_flip_image_pil, PIL.Image.Image), (F.vertical_flip_image, tv_tensors.Image), + (F.vertical_flip_keypoints, tv_tensors.KeyPoints), (F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes), (F.vertical_flip_mask, tv_tensors.Mask), (F.vertical_flip_video, tv_tensors.Video), @@ -1531,6 +1761,16 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes): return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix) + def _reference_vertical_flip_keypoints(self, keypoints): + affine_matrix = np.array( + [ + [1, 0, 0], + [0, -1, keypoints.canvas_size[0]], + ], + ) + + return reference_affine_keypoints_helper(keypoints, affine_matrix=affine_matrix) + @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)]) def test_bounding_boxes_correctness(self, format, fn): @@ -1541,6 +1781,15 @@ def test_bounding_boxes_correctness(self, format, fn): torch.testing.assert_close(actual, expected) + @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)]) + def test_keypoints_correctness(self, fn): + keypoints = make_keypoints() + + actual = fn(keypoints) + expected = self._reference_vertical_flip_keypoints(keypoints) + + torch.testing.assert_close(actual, expected) + @pytest.mark.parametrize( "make_input", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], @@ -1618,6 +1867,27 @@ def test_kernel_bounding_boxes(self, param, 
value, format, dtype, device): **kwargs, ) + @param_value_parametrization( + angle=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"], + expand=[False, True], + center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"], + ) + @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel_keypoints(self, param, value, dtype, device): + kwargs = {param: value} + if param != "angle": + kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"] + + keypoints = make_keypoints(dtype=dtype, device=device) + + check_kernel( + F.rotate_keypoints, + keypoints, + canvas_size=keypoints.canvas_size, + **kwargs, + ) + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks]) def test_kernel_mask(self, make_mask): check_kernel(F.rotate_mask, make_mask(), **self._MINIMAL_AFFINE_KWARGS) @@ -1638,6 +1908,7 @@ def test_functional(self, make_input): (F.rotate_image, torch.Tensor), (F._geometry._rotate_image_pil, PIL.Image.Image), (F.rotate_image, tv_tensors.Image), + (F.rotate_keypoints, tv_tensors.KeyPoints), (F.rotate_bounding_boxes, tv_tensors.BoundingBoxes), (F.rotate_mask, tv_tensors.Mask), (F.rotate_video, tv_tensors.Video), @@ -1804,6 +2075,71 @@ def test_transform_bounding_boxes_correctness(self, format, expand, center, seed torch.testing.assert_close(actual, expected) torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0) + def _recenter_keypoints_after_expand(self, keypoints, *, recenter_xy): + x, y = recenter_xy + translate = [x, y] + return tv_tensors.wrap( + (keypoints.to(torch.float64) - torch.tensor(translate)).to(keypoints.dtype), like=keypoints + ) + + def _reference_rotate_keypoints(self, keypoints, *, angle, expand, center): + if center is None: + center = [s * 0.5 for s in keypoints.canvas_size[::-1]] + cx, cy = center + + a = np.cos(angle * np.pi / 180.0) + b = np.sin(angle * np.pi / 180.0) + affine_matrix = np.array( + [ + [a, b, cx - cx * a - b * cy], + [-b, a, cy + cx * b - a * cy], + ], + ) + + new_canvas_size, recenter_xy = self._compute_output_canvas_size( + expand=expand, canvas_size=keypoints.canvas_size, affine_matrix=affine_matrix + ) + + output = reference_affine_keypoints_helper( + keypoints, + affine_matrix=affine_matrix, + new_canvas_size=new_canvas_size, + clamp=False, + ) + + return F.clamp_keypoints(self._recenter_keypoints_after_expand(output, recenter_xy=recenter_xy)).to(keypoints) + + @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"]) + @pytest.mark.parametrize("expand", [False, True]) + @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) + def test_functional_keypoints_correctness(self, angle, expand, center): + keypoints = make_keypoints() + + actual = F.rotate(keypoints, angle=angle, expand=expand, center=center) + expected = self._reference_rotate_keypoints(keypoints, angle=angle, expand=expand, center=center) + + torch.testing.assert_close(actual, expected) + torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0) + + @pytest.mark.parametrize("expand", [False, True]) + @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) + @pytest.mark.parametrize("seed", list(range(5))) + def test_transform_keypoints_correctness(self, expand, center, seed): + keypoints = make_keypoints() + + transform = transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, expand=expand, center=center) + + torch.manual_seed(seed) + params = 
transform.make_params([keypoints]) + + torch.manual_seed(seed) + actual = transform(keypoints) + + expected = self._reference_rotate_keypoints(keypoints, **params, expand=expand, center=center) + + torch.testing.assert_close(actual, expected) + torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0) + @pytest.mark.parametrize("degrees", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["degrees"]) @pytest.mark.parametrize("seed", list(range(10))) def test_transformmake_params_bounds(self, degrees, seed): @@ -2343,7 +2679,9 @@ def test_error(self, T): F.to_pil_image(imgs[0]), tv_tensors.Mask(torch.rand(12, 12)), tv_tensors.BoundingBoxes(torch.rand(2, 4), format="XYXY", canvas_size=12), + tv_tensors.KeyPoints(torch.rand(2, 2), canvas_size=(12, 12)), ): + print(type(input_with_bad_type), cutmix_mixup) with pytest.raises(ValueError, match="does not support PIL images, "): cutmix_mixup(input_with_bad_type) @@ -2719,6 +3057,18 @@ def test_kernel_bounding_boxes(self, format, dtype, device): displacement=self._make_displacement(bounding_boxes), ) + @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel_keypoints(self, dtype, device): + keypoints = make_keypoints(dtype=dtype, device=device) + + check_kernel( + F.elastic_keypoints, + keypoints, + canvas_size=keypoints.canvas_size, + displacement=self._make_displacement(keypoints), + ) + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks]) def test_kernel_mask(self, make_mask): mask = make_mask() @@ -2752,7 +3102,15 @@ def test_functional_signature(self, kernel, input_type): @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], + [ + make_image_tensor, + make_image_pil, + make_image, + make_bounding_boxes, + make_segmentation_mask, + make_video, + make_keypoints, + ], ) def test_displacement_error(self, make_input): input = make_input() @@ -2765,7 +3123,15 @@ def test_displacement_error(self, make_input): @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], + [ + make_image_tensor, + make_image_pil, + make_image, + make_bounding_boxes, + make_segmentation_mask, + make_video, + make_keypoints, + ], ) # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image @pytest.mark.parametrize("size", [(163, 163), (72, 333), (313, 95)]) @@ -2835,7 +3201,7 @@ def test_kernel_image(self, kwargs, dtype, device): @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel_bounding_box(self, kwargs, format, dtype, device): + def test_kernel_bounding_boxes(self, kwargs, format, dtype, device): bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device) check_kernel(F.crop_bounding_boxes, bounding_boxes, format=format, **kwargs) @@ -3020,6 +3386,54 @@ def test_transform_bounding_boxes_correctness(self, output_size, format, dtype, assert_equal(actual, expected) assert_equal(F.get_size(actual), F.get_size(expected)) + def _reference_crop_keypoints(self, keypoints, *, top, left, height, width): + affine_matrix = np.array( + [ + [1, 0, -left], + [0, 1, -top], + ], + ) + return reference_affine_keypoints_helper( + keypoints, 
affine_matrix=affine_matrix, new_canvas_size=(height, width) + ) + + @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS) + @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_functional_keypoints_correctness(self, kwargs, dtype, device): + keypoints = make_keypoints(self.INPUT_SIZE, dtype=dtype, device=device) + + actual = F.crop(keypoints, **kwargs) + expected = self._reference_crop_keypoints(keypoints, **kwargs) + + assert_equal(actual, expected, atol=1, rtol=0) + assert_equal(F.get_size(actual), F.get_size(expected)) + + @pytest.mark.parametrize("output_size", [(17, 11), (11, 17), (11, 11)]) + @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("seed", list(range(5))) + def test_transform_keypoints_correctness(self, output_size, dtype, device, seed): + input_size = (output_size[0] * 2, output_size[1] * 2) + keypoints = make_keypoints(input_size, dtype=dtype, device=device) + + transform = transforms.RandomCrop(output_size) + + with freeze_rng_state(): + torch.manual_seed(seed) + params = transform.make_params([keypoints]) + assert not params.pop("needs_pad") + del params["padding"] + assert params.pop("needs_crop") + + torch.manual_seed(seed) + actual = transform(keypoints) + + expected = self._reference_crop_keypoints(keypoints, **params) + + assert_equal(actual, expected) + assert_equal(F.get_size(actual), F.get_size(expected)) + def test_errors(self): with pytest.raises(ValueError, match="Please provide only two dimensions"): transforms.RandomCrop([10, 12, 14]) @@ -3471,7 +3885,7 @@ def _sample_input_adapter(self, transform, input, device): adapted_input = {} image_or_video_found = False for key, value in input.items(): - if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)): + if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, tv_tensors.Mask)): # AA transforms don't support bounding boxes or masks continue elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)): @@ -3758,6 +4172,31 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h new_canvas_size=size, ) + def _reference_resized_crop_keypoints(self, keypoints, *, top, left, height, width, size): + new_height, new_width = size + + crop_affine_matrix = np.array( + [ + [1, 0, -left], + [0, 1, -top], + [0, 0, 1], + ], + ) + resize_affine_matrix = np.array( + [ + [new_width / width, 0, 0], + [0, new_height / height, 0], + [0, 0, 1], + ], + ) + affine_matrix = (resize_affine_matrix @ crop_affine_matrix)[:2, :] + + return reference_affine_keypoints_helper( + keypoints, + affine_matrix=affine_matrix, + new_canvas_size=size, + ) + @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) def test_functional_bounding_boxes_correctness(self, format): bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format) @@ -3770,6 +4209,15 @@ def test_functional_bounding_boxes_correctness(self, format): assert_equal(actual, expected) assert_equal(F.get_size(actual), F.get_size(expected)) + def test_functional_keypoints_correctness(self): + keypoints = make_keypoints(self.INPUT_SIZE) + + actual = F.resized_crop(keypoints, **self.CROP_KWARGS, size=self.OUTPUT_SIZE) + expected = self._reference_resized_crop_keypoints(keypoints, **self.CROP_KWARGS, size=self.OUTPUT_SIZE) + + assert_equal(actual, expected) + assert_equal(F.get_size(actual), F.get_size(expected)) + def 
test_transform_errors_warnings(self): with pytest.raises(ValueError, match="provide only two dimensions"): transforms.RandomResizedCrop(size=(1, 2, 3)) @@ -3855,6 +4303,26 @@ def test_kernel_bounding_boxes_errors(self, padding_mode): padding_mode=padding_mode, ) + def test_kernel_keypoints(self): + keypoints = make_keypoints() + check_kernel( + F.pad_keypoints, + keypoints, + canvas_size=keypoints.canvas_size, + padding=[1], + ) + + @pytest.mark.parametrize("padding_mode", ["symmetric", "edge", "reflect"]) + def test_kernel_keypoints_errors(self, padding_mode): + keypoints = make_keypoints() + with pytest.raises(ValueError, match=f"'{padding_mode}' is not supported"): + F.pad_keypoints( + keypoints, + canvas_size=keypoints.canvas_size, + padding=[1], + padding_mode=padding_mode, + ) + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks]) def test_kernel_mask(self, make_mask): check_kernel(F.pad_mask, make_mask(), padding=[1]) @@ -3998,6 +4466,17 @@ def test_kernel_bounding_boxes(self, output_size, format): check_scripted_vs_eager=not isinstance(output_size, int), ) + @pytest.mark.parametrize("output_size", OUTPUT_SIZES) + def test_kernel_keypoints(self, output_size): + keypoints = make_keypoints(self.INPUT_SIZE) + check_kernel( + F.center_crop_keypoints, + keypoints, + canvas_size=keypoints.canvas_size, + output_size=output_size, + check_scripted_vs_eager=not isinstance(output_size, int), + ) + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks]) def test_kernel_mask(self, make_mask): check_kernel(F.center_crop_mask, make_mask(), output_size=self.OUTPUT_SIZES[0]) @@ -4077,6 +4556,37 @@ def test_bounding_boxes_correctness(self, output_size, format, dtype, device, fn assert_equal(actual, expected) + def _reference_center_crop_keypoints(self, keypoints, output_size): + image_height, image_width = keypoints.canvas_size + if isinstance(output_size, int): + output_size = (output_size, output_size) + elif len(output_size) == 1: + output_size *= 2 + crop_height, crop_width = output_size + + top = int(round((image_height - crop_height) / 2)) + left = int(round((image_width - crop_width) / 2)) + + affine_matrix = np.array( + [ + [1, 0, -left], + [0, 1, -top], + ], + ) + return reference_affine_keypoints_helper(keypoints, affine_matrix=affine_matrix, new_canvas_size=output_size) + + @pytest.mark.parametrize("output_size", OUTPUT_SIZES) + @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)]) + def test_keypoints_correctness(self, output_size, dtype, device, fn): + keypoints = make_keypoints(self.INPUT_SIZE, dtype=dtype, device=device) + + actual = fn(keypoints, output_size) + expected = self._reference_center_crop_keypoints(keypoints, output_size) + + assert_equal(actual, expected) + class TestPerspective: COEFFICIENTS = [ @@ -4164,6 +4674,39 @@ def test_kernel_bounding_boxes_error(self): coefficients=[0.0] * 8, ) + @param_value_parametrization( + coefficients=COEFFICIENTS, + start_end_points=START_END_POINTS, + ) + def test_kernel_keypoints(self, param, value): + if param == "start_end_points": + kwargs = dict(zip(["startpoints", "endpoints"], value)) + else: + kwargs = {"startpoints": None, "endpoints": None, param: value} + + keypoints = make_keypoints() + + check_kernel( + F.perspective_keypoints, + keypoints, + canvas_size=keypoints.canvas_size, + **kwargs, + ) + + def 
test_kernel_keypoints_error(self): + keypoints = make_keypoints() + canvas_size = keypoints.canvas_size + keypoints = keypoints.as_subclass(torch.Tensor) + + with pytest.raises(RuntimeError, match="Denominator is zero"): + F.perspective_keypoints( + keypoints, + canvas_size=canvas_size, + startpoints=None, + endpoints=None, + coefficients=[0.0] * 8, + ) + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks]) def test_kernel_mask(self, make_mask): check_kernel(F.perspective_mask, make_mask(), **self.MINIMAL_KWARGS) @@ -4321,6 +4864,67 @@ def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, fo assert_close(actual, expected, rtol=0, atol=1) + def _reference_perspective_keypoints(self, keypoints, *, startpoints, endpoints): + canvas_size = keypoints.canvas_size + dtype = keypoints.dtype + device = keypoints.device + + coefficients = _get_perspective_coeffs(endpoints, startpoints) + + def perspective_keypoints(keypoints): + m1 = np.array( + [ + [coefficients[0], coefficients[1], coefficients[2]], + [coefficients[3], coefficients[4], coefficients[5]], + ] + ) + m2 = np.array( + [ + [coefficients[6], coefficients[7], 1.0], + [coefficients[6], coefficients[7], 1.0], + ] + ) + + # Go to float before converting to prevent precision loss + x, y = keypoints.to(dtype=torch.float64, device="cpu", copy=True).squeeze(0).tolist() + + points = np.array([[x, y, 1.0]]) + + numerator = points @ m1.T + denominator = points @ m2.T + transformed_points = numerator / denominator + + output = torch.Tensor( + [ + float(transformed_points[0, 0]), + float(transformed_points[0, 1]), + ] + ) + + # It is important to clamp before casting, especially for CXCYWH format, dtype=int64 + return F.clamp_keypoints( + output, + canvas_size=canvas_size, + ).to(dtype=dtype, device=device) + + return tv_tensors.KeyPoints( + torch.cat([perspective_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape( + keypoints.shape + ), + canvas_size=canvas_size, + ) + + @pytest.mark.parametrize(("startpoints", "endpoints"), START_END_POINTS) + @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_correctness_perspective_keypoints(self, startpoints, endpoints, dtype, device): + keypoints = make_keypoints(dtype=dtype, device=device) + + actual = F.perspective(keypoints, startpoints=startpoints, endpoints=endpoints) + expected = self._reference_perspective_keypoints(keypoints, startpoints=startpoints, endpoints=endpoints) + + assert_close(actual, expected, rtol=0, atol=1) + class TestEqualize: @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) @@ -6271,3 +6875,41 @@ def test_different_sizes(self, make_input1, make_input2, query): def test_no_valid_input(self, query): with pytest.raises(TypeError, match="No image"): query(["blah"]) + + @pytest.mark.parametrize( + "boxes", [ + tv_tensors.BoundingBoxes(torch.tensor([[1., 1., 2., 2.]]), format="XYXY", canvas_size=(4, 4)), # [boxes0] + tv_tensors.BoundingBoxes(torch.tensor([[1., 1., 1., 1.]]), format="XYWH", canvas_size=(4, 4)), # [boxes1] + tv_tensors.BoundingBoxes(torch.tensor([[1.5, 1.5, 1., 1.]]), format="CXCYWH", canvas_size=(4, 4)), # [boxes2] + tv_tensors.BoundingBoxes(torch.tensor([[1.5, 1.5, 1., 1., 45]]), format="CXCYWHR", canvas_size=(4, 4)), # [boxes3] + tv_tensors.BoundingBoxes(torch.tensor([[1., 1., 1., 1., 45.]]), format="XYWHR", canvas_size=(4, 4)), # [boxes4] + tv_tensors.BoundingBoxes(torch.tensor([[1., 1., 1., 2., 2., 2., 
2., 1.]]), format="XY" * 4, canvas_size=(4, 4)), # [boxes5] + ] + ) + def test_convert_bounding_boxes_to_points(self, boxes: tv_tensors.BoundingBoxes): + kp = F.convert_bounding_boxes_to_points(boxes) + assert kp.shape == (boxes.shape[0], 4, 2) + assert kp.dtype == boxes.dtype + # kp is a list of A, B, C, D polygons. + + if F._meta.is_rotated_bounding_box_format(boxes.format): + # In the rotated case + # If we convert to XYXYXYXY format, we should get what we want. + reconverted = kp.reshape(-1, 8) + reconverted_bbox = F.convert_bounding_box_format( + tv_tensors.BoundingBoxes(reconverted, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, canvas_size=kp.canvas_size), + new_format=boxes.format + ) + assert ((reconverted_bbox - boxes).abs() < 1e-5).all(), ( # Rotational computations mean that we can't ensure exactitude. + f"Invalid reconversion :\n\tGot: {reconverted_bbox}\n\tFrom: {boxes}\n\t" + f"Diff: {reconverted_bbox - boxes}" + ) + else: + # In the unrotated case + # If we use A | C, we should get back the XYXY format of bounding box + reconverted = torch.cat([kp[..., 0, :], kp[..., 2, :]], dim=-1) + reconverted_bbox = F.convert_bounding_box_format( + tv_tensors.BoundingBoxes(reconverted, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=kp.canvas_size), + new_format=boxes.format, + ) + assert (reconverted_bbox == boxes).all(), f"Invalid reconversion :\n\tGot: {reconverted_bbox}\n\tFrom: {boxes}" diff --git a/test/test_transforms_v2_utils.py b/test/test_transforms_v2_utils.py index 53222c6a2c8..dab6d525a38 100644 --- a/test/test_transforms_v2_utils.py +++ b/test/test_transforms_v2_utils.py @@ -4,7 +4,7 @@ import torch import torchvision.transforms.v2._utils -from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image +from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image, make_keypoints from torchvision import tv_tensors from torchvision.transforms.v2._utils import has_all, has_any @@ -14,29 +14,32 @@ IMAGE = make_image(DEFAULT_SIZE, color_space="RGB") BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY) MASK = make_detection_masks(DEFAULT_SIZE) +KEYPOINTS = make_keypoints(DEFAULT_SIZE) @pytest.mark.parametrize( ("sample", "types", "expected"), [ - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True), - ((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False), - ((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False), - ((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image,), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes,), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask,), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.BoundingBoxes), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.KeyPoints,), True), + ((MASK,), (tv_tensors.Image, 
tv_tensors.BoundingBoxes, tv_tensors.KeyPoints), False), + ((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask, tv_tensors.KeyPoints), False), + ((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), False), + ((KEYPOINTS,), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False), ( - (IMAGE, BOUNDING_BOX, MASK), - (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), + (IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), + (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), True, ), - ((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False), - ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, tv_tensors.Image),), True), - ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), - ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True), + ((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), False), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda obj: isinstance(obj, tv_tensors.Image),), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda _: False,), False), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda _: True,), True), ((IMAGE,), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True), ( (torch.Tensor(IMAGE),), @@ -57,15 +60,22 @@ def test_has_any(sample, types, expected): @pytest.mark.parametrize( ("sample", "types", "expected"), [ - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image,), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes,), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask,), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.BoundingBoxes), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask, tv_tensors.KeyPoints), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints), True), ( - (IMAGE, BOUNDING_BOX, MASK), - (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), + (IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), + (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), + True, + ), + ( + (IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), + (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), True, ), ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False), diff --git a/test/test_tv_tensors.py b/test/test_tv_tensors.py index a8e59ab7531..0c06bc9c929 100644 --- a/test/test_tv_tensors.py +++ b/test/test_tv_tensors.py @@ -2,7 +2,14 @@ import pytest import torch -from common_utils import assert_equal, make_bounding_boxes, make_image, make_segmentation_mask, make_video +from common_utils import ( + assert_equal, + make_bounding_boxes, + make_image, + make_keypoints, + make_segmentation_mask, + make_video, +) from PIL import Image from torchvision import tv_tensors @@ -49,6 +56,39 @@ def test_bbox_dim_error(): 
tv_tensors.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32)) +@pytest.mark.parametrize( + "data", + [ + torch.randint(0, 32, size=(5, 2)), + [ + [ + 0, + 0, + ], + [ + 2, + 2, + ], + ], + [ + 1, + 2, + ], + ], +) +def test_keypoints_instance(data): + kpoint = tv_tensors.KeyPoints(data, canvas_size=(32, 32)) + assert isinstance(kpoint, tv_tensors.KeyPoints) + assert type(kpoint) is tv_tensors.KeyPoints + assert kpoint.shape[-1] == 2 + + +def test_keypoints_shape_error(): + data_3d = [(0, 1, 2)] + with pytest.raises(ValueError, match="shape"): + tv_tensors.KeyPoints(torch.tensor(data_3d), canvas_size=(11, 7)) + + @pytest.mark.parametrize( ("data", "input_requires_grad", "expected_requires_grad"), [ @@ -68,7 +108,9 @@ def test_new_requires_grad(data, input_requires_grad, expected_requires_grad): assert tv_tensor.requires_grad is expected_requires_grad -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize( + "make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints] +) def test_isinstance(make_input): assert isinstance(make_input(), torch.Tensor) @@ -80,7 +122,9 @@ def test_wrapping_no_copy(): assert image.data_ptr() == tensor.data_ptr() -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize( + "make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints] +) def test_to_wrapping(make_input): dp = make_input() @@ -90,7 +134,9 @@ def test_to_wrapping(make_input): assert dp_to.dtype is torch.float64 -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize( + "make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints] +) @pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"]) def test_to_tv_tensor_reference(make_input, return_type): tensor = torch.rand((3, 16, 16), dtype=torch.float64) @@ -104,7 +150,9 @@ def test_to_tv_tensor_reference(make_input, return_type): assert type(tensor) is torch.Tensor -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize( + "make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints] +) @pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"]) def test_clone_wrapping(make_input, return_type): dp = make_input() @@ -116,7 +164,9 @@ def test_clone_wrapping(make_input, return_type): assert dp_clone.data_ptr() != dp.data_ptr() -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize( + "make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints] +) @pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"]) def test_requires_grad__wrapping(make_input, return_type): dp = make_input(dtype=torch.float) @@ -131,7 +181,9 @@ def test_requires_grad__wrapping(make_input, return_type): assert dp_requires_grad.requires_grad -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize( + "make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints] +) @pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"]) def test_detach_wrapping(make_input, 
return_type): dp = make_input(dtype=torch.float).requires_grad_(True) @@ -148,18 +200,25 @@ def test_force_subclass_with_metadata(return_type): # Largely the same as above, we additionally check that the metadata is preserved format, canvas_size = "XYXY", (32, 32) bbox = tv_tensors.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size) + kpoints = tv_tensors.KeyPoints([[0, 0], [2, 2]], canvas_size=canvas_size) tv_tensors.set_return_type(return_type) bbox = bbox.clone() + kpoints = kpoints.clone() if return_type == "TVTensor": + assert kpoints.canvas_size == canvas_size assert bbox.format, bbox.canvas_size == (format, canvas_size) bbox = bbox.to(torch.float64) + kpoints = kpoints.to(torch.float64) if return_type == "TVTensor": + assert kpoints.canvas_size == canvas_size assert bbox.format, bbox.canvas_size == (format, canvas_size) bbox = bbox.detach() + kpoints = kpoints.detach() if return_type == "TVTensor": + assert kpoints.canvas_size == canvas_size assert bbox.format, bbox.canvas_size == (format, canvas_size) if torch.cuda.is_available(): @@ -168,14 +227,20 @@ def test_force_subclass_with_metadata(return_type): assert bbox.format, bbox.canvas_size == (format, canvas_size) assert not bbox.requires_grad + assert not kpoints.requires_grad bbox.requires_grad_(True) + kpoints.requires_grad_(True) if return_type == "TVTensor": + assert kpoints.canvas_size == canvas_size assert bbox.format, bbox.canvas_size == (format, canvas_size) assert bbox.requires_grad + assert kpoints.requires_grad tv_tensors.set_return_type("tensor") -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize( + "make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints] +) @pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"]) def test_other_op_no_wrapping(make_input, return_type): dp = make_input() @@ -187,7 +252,9 @@ def test_other_op_no_wrapping(make_input, return_type): assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor) -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize( + "make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints] +) @pytest.mark.parametrize( "op", [ @@ -204,7 +271,9 @@ def test_no_tensor_output_op_no_wrapping(make_input, op): assert type(output) is not type(dp) -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize( + "make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints] +) @pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"]) def test_inplace_op_no_wrapping(make_input, return_type): dp = make_input() diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 02a487cabd3..980e27647f7 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -90,7 +90,7 @@ def __init__( self._log_ratio = torch.log(torch.tensor(self.ratio)) def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: - if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)): + if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, tv_tensors.Mask)): warnings.warn( f"{type(self).__name__}() is currently passing through inputs of type " 
f"tv_tensors.{type(inpt).__name__}. This will likely change in the future." @@ -157,7 +157,7 @@ def forward(self, *inputs): flat_inputs, spec = tree_flatten(inputs) needs_transform_list = self._needs_transform_list(flat_inputs) - if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask): + if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints): raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.") labels = self._labels_getter(inputs) diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py index c743eb40775..52707af1f2e 100644 --- a/torchvision/transforms/v2/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -46,7 +46,7 @@ def _get_random_item(self, dct: dict[str, tuple[Callable, bool]]) -> tuple[str, def _flatten_and_extract_image_or_video( self, inputs: Any, - unsupported_types: tuple[type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask), + unsupported_types: tuple[type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), ) -> tuple[tuple[list[Any], TreeSpec, int], ImageOrVideo]: flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) needs_transform_list = self._needs_transform_list(flat_inputs) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 86c00e28a66..e1ed436ba36 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -357,7 +357,7 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None: self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: - if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)): + if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, tv_tensors.Mask)): warnings.warn( f"{type(self).__name__}() is currently passing through inputs of type " f"tv_tensors.{type(inpt).__name__}. This will likely change in the future." @@ -402,7 +402,7 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) self.vertical_flip = vertical_flip def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: - if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)): + if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, tv_tensors.Mask)): warnings.warn( f"{type(self).__name__}() is currently passing through inputs of type " f"tv_tensors.{type(inpt).__name__}. This will likely change in the future." diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index dfd521b13be..d6d61fa6d6c 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -341,9 +341,9 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any: class SanitizeBoundingBoxes(Transform): - """Remove degenerate/invalid bounding boxes and their corresponding labels and masks. + """Remove degenerate/invalid bounding boxes and their corresponding labels, masks and keypoints. - This transform removes bounding boxes and their associated labels/masks that: + This transform removes bounding boxes and their associated labels, masks and keypoints that: - are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1. 
- have any coordinate outside of their corresponding image. You may want to @@ -359,6 +359,14 @@ class SanitizeBoundingBoxes(Transform): may modify bounding boxes but once at the end should be enough in most cases. + .. note:: + This transform requires that any :class:`~torchvision.tv_tensor.KeyPoints` or + :class:`~torchvision.tv_tensor.Mask` provided has to match the bounding boxes in shape. + + If the bounding boxes are of shape ``[N, K]``, then the + KeyPoints have to be of shape ``[N, ..., 2]`` or ``[N, 2]`` + and the masks have to be of shape ``[N, ..., H, W]`` or ``[N, H, W]`` + Args: min_size (float, optional): The size below which bounding boxes are removed. Default is 1. min_area (float, optional): The area below which bounding boxes are removed. Default is 1. @@ -438,10 +446,15 @@ def forward(self, *inputs: Any) -> Any: return tree_unflatten(flat_outputs, spec) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: + # For every object in the flattened input of the `forward` method, we apply transform + # The params contain the list of valid indices of the (N, K) bbox set + + # We suppose here that any KeyPoints or Masks TVTensors is of shape (N, ..., 2) and (N, ..., H, W) respectively + # TODO: check this. is_label = params["labels"] is not None and any(inpt is label for label in params["labels"]) - is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)) + is_bbox_mask_or_kpoints = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints)) - if not (is_label or is_bounding_boxes_or_mask): + if not (is_label or is_bbox_mask_or_kpoints): return inpt output = inpt[params["valid"]] diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index c4371ce0953..fd41b222b19 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -165,6 +165,18 @@ def get_bounding_boxes(flat_inputs: list[Any]) -> tv_tensors.BoundingBoxes: raise ValueError("No bounding boxes were found in the sample") +def get_keypoints(flat_inputs: list[Any]) -> tv_tensors.KeyPoints: + """Returns the KeyPoints in the input. 
+ + Assumes only one ``KeyPoints`` object is present + """ + generator = (inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.KeyPoints)) + try: + return next(generator) + except StopIteration: + raise ValueError("No Keypoints were found in the sample.") + + def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: """Return Channel, Height, and Width.""" chws = { @@ -194,6 +206,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Video, tv_tensors.Mask, tv_tensors.BoundingBoxes, + tv_tensors.KeyPoints, ), ) } diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index d5705d55c4b..e651bbd9257 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -4,7 +4,9 @@ from ._meta import ( clamp_bounding_boxes, + clamp_keypoints, convert_bounding_box_format, + convert_bounding_boxes_to_points, get_dimensions_image, get_dimensions_video, get_dimensions, @@ -15,6 +17,7 @@ get_num_channels_video, get_num_channels, get_size_bounding_boxes, + get_size_keypoints, get_size_image, get_size_mask, get_size_video, @@ -69,21 +72,25 @@ affine, affine_bounding_boxes, affine_image, + affine_keypoints, affine_mask, affine_video, center_crop, center_crop_bounding_boxes, center_crop_image, + center_crop_keypoints, center_crop_mask, center_crop_video, crop, crop_bounding_boxes, crop_image, + crop_keypoints, crop_mask, crop_video, elastic, elastic_bounding_boxes, elastic_image, + elastic_keypoints, elastic_mask, elastic_transform, elastic_video, @@ -94,31 +101,37 @@ horizontal_flip, horizontal_flip_bounding_boxes, horizontal_flip_image, + horizontal_flip_keypoints, horizontal_flip_mask, horizontal_flip_video, pad, pad_bounding_boxes, pad_image, + pad_keypoints, pad_mask, pad_video, perspective, perspective_bounding_boxes, perspective_image, + perspective_keypoints, perspective_mask, perspective_video, resize, resize_bounding_boxes, resize_image, + resize_keypoints, resize_mask, resize_video, resized_crop, resized_crop_bounding_boxes, resized_crop_image, + resized_crop_keypoints, resized_crop_mask, resized_crop_video, rotate, rotate_bounding_boxes, rotate_image, + rotate_keypoints, rotate_mask, rotate_video, ten_crop, @@ -127,6 +140,7 @@ vertical_flip, vertical_flip_bounding_boxes, vertical_flip_image, + vertical_flip_keypoints, vertical_flip_mask, vertical_flip_video, vflip, @@ -143,6 +157,7 @@ normalize_image, normalize_video, sanitize_bounding_boxes, + sanitize_keypoints, to_dtype, to_dtype_image, to_dtype_video, diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 8303019e011..448199dbe0c 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -23,7 +23,7 @@ from torchvision.utils import _log_api_usage_once -from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format +from ._meta import _get_size_image_pil, clamp_bounding_boxes, clamp_keypoints, convert_bounding_box_format from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal @@ -66,6 +66,19 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: return horizontal_flip_image(mask) +def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]): + shape = keypoints.shape + keypoints = keypoints.clone().reshape(-1, 2) + keypoints[..., 0] = keypoints[..., 
0].sub_(canvas_size[1]).neg_() + return keypoints.reshape(shape) + + +@_register_kernel_internal(horizontal_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _horizontal_flip_keypoints_dispatch(keypoints: tv_tensors.KeyPoints): + out = horizontal_flip_keypoints(keypoints.as_subclass(torch.Tensor), canvas_size=keypoints.canvas_size) + return tv_tensors.wrap(out, like=keypoints) + + def horizontal_flip_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: tuple[int, int] ) -> torch.Tensor: @@ -123,6 +136,13 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: return vertical_flip_image(mask) +def vertical_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> torch.Tensor: + shape = keypoints.shape + keypoints = keypoints.clone().reshape(-1, 2) + keypoints[..., 1] = keypoints[..., 1].sub_(canvas_size[0]).neg_() + return keypoints.reshape(shape) + + def vertical_flip_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: tuple[int, int] ) -> torch.Tensor: @@ -140,6 +160,12 @@ def vertical_flip_bounding_boxes( return bounding_boxes.reshape(shape) +@_register_kernel_internal(vertical_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _vertical_flip_keypoints_dispatch(inpt: tv_tensors.KeyPoints) -> tv_tensors.KeyPoints: + output = vertical_flip_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size) + return tv_tensors.wrap(output, like=inpt) + + @_register_kernel_internal(vertical_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) def _vertical_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes: output = vertical_flip_bounding_boxes( @@ -334,6 +360,41 @@ def _resize_mask_dispatch( return tv_tensors.wrap(output, like=inpt) +def resize_keypoints( + keypoints: torch.Tensor, + size: Optional[list[int]], + canvas_size: tuple[int, int], + max_size: Optional[int] = None, +): + old_height, old_width = canvas_size + new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size) + + if (new_height, new_width) == (old_height, old_width): + return keypoints, canvas_size + + w_ratio = new_width / old_width + h_ratio = new_height / old_height + ratios = torch.tensor([w_ratio, h_ratio], device=keypoints.device) + keypoints = keypoints.mul(ratios).to(keypoints.dtype) + + return keypoints, (new_height, new_width) + + +@_register_kernel_internal(resize, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _resize_keypoints_dispatch( + keypoints: tv_tensors.KeyPoints, + size: Optional[list[int]], + max_size: Optional[int] = None, +) -> tv_tensors.KeyPoints: + out, canvas_size = resize_keypoints( + keypoints.as_subclass(torch.Tensor), + size, + canvas_size=keypoints.canvas_size, + max_size=max_size, + ) + return tv_tensors.wrap(out, like=keypoints, canvas_size=canvas_size) + + def resize_bounding_boxes( bounding_boxes: torch.Tensor, canvas_size: tuple[int, int], @@ -759,6 +820,122 @@ def _affine_image_pil( return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill) +def _affine_keypoints_with_expand( + keypoints: torch.Tensor, + canvas_size: tuple[int, int], + angle: Union[int, float], + translate: list[float], + scale: float, + shear: list[float], + center: Optional[list[float]] = None, + expand: bool = False, +) -> tuple[torch.Tensor, tuple[int, int]]: + if keypoints.numel() == 0: + return keypoints, canvas_size + + original_dtype = keypoints.dtype + original_shape 
= keypoints.shape + keypoints = keypoints.clone() if keypoints.is_floating_point() else keypoints.float() + dtype = keypoints.dtype + device = keypoints.device + + angle, translate, shear, center = _affine_parse_args( + angle, translate, scale, shear, InterpolationMode.NEAREST, center + ) + + if center is None: + height, width = canvas_size + center = [width * 0.5, height * 0.5] + + affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False) + transposed_affine_matrix = ( + torch.tensor( + affine_vector, + dtype=dtype, + device=device, + ) + .reshape(2, 3) + .T + ) + + # 1) We transform points into a tensor of points with shape (N, 3), where N is the number of points. + points = keypoints.reshape(-1, 2) + points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1) + # 2) Now let's transform the points using affine matrix + transformed_points = torch.matmul(points, transposed_affine_matrix) + + if expand: + # Compute minimum point for transformed image frame: + # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. + height, width = canvas_size + points = torch.tensor( + [ + [0.0, 0.0, 1.0], + [0.0, float(height), 1.0], + [float(width), float(height), 1.0], + [float(width), 0.0, 1.0], + ], + dtype=dtype, + device=device, + ) + new_points = torch.matmul(points, transposed_affine_matrix) + tr = torch.amin(new_points, dim=0, keepdim=True) + # Translate keypoints + transformed_points.sub_(tr) + # Estimate meta-data for image with inverted=True + affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear) + new_width, new_height = _compute_affine_output_size(affine_vector, width, height) + canvas_size = (new_height, new_width) + + out_kkpoints = clamp_keypoints(transformed_points, canvas_size=canvas_size).reshape(original_shape) + out_kkpoints = out_kkpoints.to(original_dtype) + + return out_kkpoints, canvas_size + + +def affine_keypoints( + keypoints: torch.Tensor, + canvas_size: tuple[int, int], + angle: Union[int, float], + translate: list[float], + scale: float, + shear: list[float], + center: Optional[list[float]] = None, +): + return _affine_keypoints_with_expand( + keypoints=keypoints, + canvas_size=canvas_size, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + expand=False, + ) + + +@_register_kernel_internal(affine, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _affine_keypoints_dispatch( + inpt: tv_tensors.KeyPoints, + angle: Union[int, float], + translate: list[float], + scale: float, + shear: list[float], + center: Optional[list[float]] = None, + **kwargs, +) -> tv_tensors.KeyPoints: + output, canvas_size = affine_keypoints( + inpt.as_subclass(torch.Tensor), + canvas_size=inpt.canvas_size, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + def _affine_bounding_boxes_with_expand( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -1056,6 +1233,35 @@ def _rotate_image_pil( ) +def rotate_keypoints( + keypoints: torch.Tensor, + canvas_size: tuple[int, int], + angle: float, + expand: bool = False, + center: Optional[list[float]] = None, +) -> tuple[torch.Tensor, tuple[int, int]]: + return _affine_keypoints_with_expand( + keypoints=keypoints, + canvas_size=canvas_size, + angle=-angle, + translate=[0.0, 0.0], + scale=1.0, + shear=[0.0, 0.0], + center=center, + expand=expand, + ) + + 
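+# The dispatcher below reads ``canvas_size`` from the ``KeyPoints`` tv_tensor, calls the kernel above,
+# and re-wraps the result so that the (possibly expanded) canvas size is propagated to the output.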
+@_register_kernel_internal(rotate, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _rotate_keypoints_dispatch( + inpt: tv_tensors.KeyPoints, angle: float, expand: bool = False, center: Optional[list[float]] = None, **kwargs +) -> tv_tensors.KeyPoints: + output, canvas_size = rotate_keypoints( + inpt, canvas_size=inpt.canvas_size, angle=angle, center=center, expand=expand + ) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + def rotate_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -1319,6 +1525,35 @@ def pad_mask( return output +def pad_keypoints( + keypoints: torch.Tensor, canvas_size: tuple[int, int], padding: list[int], padding_mode: str = "constant" +): + SUPPORTED_MODES = ["constant"] + if padding_mode not in SUPPORTED_MODES: + # TODO: add support of other padding modes + raise ValueError( + f"Padding mode '{padding_mode}' is not supported with KeyPoints" + f" (supported modes are {', '.join(SUPPORTED_MODES)})" + ) + left, right, top, bottom = _parse_pad_padding(padding) + pad = torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device) + canvas_size = (canvas_size[0] + top + bottom, canvas_size[1] + left + right) + return clamp_keypoints(keypoints + pad, canvas_size), canvas_size + + +@_register_kernel_internal(pad, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _pad_keypoints_dispatch( + keypoints: tv_tensors.KeyPoints, padding: list[int], padding_mode: str = "constant", **kwargs +) -> tv_tensors.KeyPoints: + output, canvas_size = pad_keypoints( + keypoints.as_subclass(torch.Tensor), + canvas_size=keypoints.canvas_size, + padding=padding, + padding_mode=padding_mode, + ) + return tv_tensors.wrap(output, like=keypoints, canvas_size=canvas_size) + + def pad_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -1405,6 +1640,28 @@ def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int _register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil) +def crop_keypoints( + keypoints: torch.Tensor, + top: int, + left: int, + height: int, + width: int, +) -> tuple[torch.Tensor, tuple[int, int]]: + + keypoints = keypoints - torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device) + canvas_size = (height, width) + + return clamp_keypoints(keypoints, canvas_size=canvas_size), canvas_size + + +@_register_kernel_internal(crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _crop_keypoints_dispatch( + inpt: tv_tensors.KeyPoints, top: int, left: int, height: int, width: int +) -> tv_tensors.KeyPoints: + output, canvas_size = crop_keypoints(inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + def crop_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -1578,6 +1835,56 @@ def _perspective_image_pil( return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill) +def perspective_keypoints( + keypoints: torch.Tensor, + canvas_size: tuple[int, int], + startpoints: Optional[list[list[int]]], + endpoints: Optional[list[list[int]]], + coefficients: Optional[list[float]] = None, +): + if keypoints.numel() == 0: + return keypoints + dtype = keypoints.dtype if torch.is_floating_point(keypoints) else torch.float32 + device = keypoints.device + original_shape = keypoints.shape + + keypoints = keypoints.clone().reshape(-1, 2) + perspective_coeffs = 
_perspective_coefficients(startpoints, endpoints, coefficients) + + denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3] + if denom == 0: + raise RuntimeError( + f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform keypoints. " + f"Denominator is zero, denom={denom}" + ) + + theta1, theta2 = _compute_perspective_thetas(perspective_coeffs, dtype, device, denom) + points = torch.cat([keypoints, torch.ones(keypoints.shape[0], 1, device=keypoints.device)], dim=-1) + + numer_points = torch.matmul(points, theta1.T) + denom_points = torch.matmul(points, theta2.T) + transformed_points = numer_points.div_(denom_points) + return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size).reshape(original_shape) + + +@_register_kernel_internal(perspective, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _perspective_keypoints_dispatch( + inpt: tv_tensors.KeyPoints, + startpoints: Optional[list[list[int]]], + endpoints: Optional[list[list[int]]], + coefficients: Optional[list[float]] = None, + **kwargs, +) -> tv_tensors.KeyPoints: + output = perspective_keypoints( + inpt.as_subclass(torch.Tensor), + canvas_size=inpt.canvas_size, + startpoints=startpoints, + endpoints=endpoints, + coefficients=coefficients, + ) + return tv_tensors.wrap(output, like=inpt) + + def perspective_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -1619,26 +1926,7 @@ def perspective_bounding_boxes( f"Denominator is zero, denom={denom}" ) - inv_coeffs = [ - (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom, - (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom, - (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom, - (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom, - (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom, - (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom, - (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom, - (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom, - ] - - theta1 = torch.tensor( - [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]], - dtype=dtype, - device=device, - ) - - theta2 = torch.tensor( - [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device - ) + theta1, theta2 = _compute_perspective_thetas(perspective_coeffs, dtype, device, denom) # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners). 
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes @@ -1672,6 +1960,36 @@ def perspective_bounding_boxes( ).reshape(original_shape) +def _compute_perspective_thetas( + perspective_coeffs: list[float], + dtype: torch.dtype, + device: torch.device, + denom: float, +) -> tuple[torch.Tensor, torch.Tensor]: + inv_coeffs = [ + (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom, + (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom, + (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom, + (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom, + (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom, + (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom, + (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom, + (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom, + ] + + theta1 = torch.tensor( + [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]], + dtype=dtype, + device=device, + ) + + theta2 = torch.tensor( + [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device + ) + + return theta1, theta2 + + @_register_kernel_internal(perspective, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) def _perspective_bounding_boxes_dispatch( inpt: tv_tensors.BoundingBoxes, @@ -1832,6 +2150,48 @@ def _create_identity_grid(size: tuple[int, int], device: torch.device, dtype: to return base_grid +def elastic_keypoints( + keypoints: torch.Tensor, canvas_size: tuple[int, int], displacement: torch.Tensor +) -> torch.Tensor: + expected_shape = (1, canvas_size[0], canvas_size[1], 2) + if not isinstance(displacement, torch.Tensor): + raise TypeError("Argument displacement should be a Tensor") + elif displacement.shape != expected_shape: + raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") + + if keypoints.numel() == 0: + return keypoints + + device = keypoints.device + dtype = keypoints.dtype if torch.is_floating_point(keypoints) else torch.float32 + + if displacement.dtype != dtype or displacement.device != device: + displacement = displacement.to(dtype=dtype, device=device) + + original_shape = keypoints.shape + keypoints = keypoints.clone().reshape(-1, 2) + + id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype) + inv_grid = id_grid.sub_(displacement) + + index_xy = keypoints.to(dtype=torch.long) + index_x, index_y = index_xy[:, 0], index_xy[:, 1] + # Unlike bounding boxes, this may not work well. 
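+    # The keypoint coordinates were truncated to integer indices above; clamping keeps those indices
+    # inside the displacement grid, so sub-pixel keypoint locations are only approximated here.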
+ index_x.clamp_(0, inv_grid.shape[2] - 1) + index_y.clamp_(0, inv_grid.shape[1] - 1) + + t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype) + transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5) + + return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size).reshape(original_shape) + + +@_register_kernel_internal(elastic, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _elastic_keypoints_dispatch(inpt: tv_tensors.KeyPoints, displacement: torch.Tensor, **kwargs): + output = elastic_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, displacement=displacement) + return tv_tensors.wrap(output, like=inpt) + + def elastic_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -2012,6 +2372,20 @@ def _center_crop_image_pil(image: PIL.Image.Image, output_size: list[int]) -> PI return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width) +def center_crop_keypoints(inpt: torch.Tensor, canvas_size: tuple[int, int], output_size: list[int]): + crop_height, crop_width = _center_crop_parse_output_size(output_size) + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size) + return crop_keypoints(inpt, top=crop_top, left=crop_left, height=crop_height, width=crop_width) + + +@_register_kernel_internal(center_crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _center_crop_keypoints_dispatch(inpt: tv_tensors.KeyPoints, output_size: list[int]) -> tv_tensors.KeyPoints: + output, canvas_size = center_crop_keypoints( + inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, output_size=output_size + ) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + def center_crop_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -2147,6 +2521,28 @@ def _resized_crop_image_pil_dispatch( ) +def resized_crop_keypoints( + keypoints: torch.Tensor, + top: int, + left: int, + height: int, + width: int, + size: list[int], +) -> tuple[torch.Tensor, tuple[int, int]]: + keypoints, canvas_size = crop_keypoints(keypoints, top, left, height, width) + return resize_keypoints(keypoints, size=size, canvas_size=canvas_size) + + +@_register_kernel_internal(resized_crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _resized_crop_keypoints_dispatch( + inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: list[int], **kwargs +): + output, canvas_size = resized_crop_keypoints( + inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size + ) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + def resized_crop_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -2281,9 +2677,7 @@ def five_crop_video( return five_crop_image(video, size) -def ten_crop( - inpt: torch.Tensor, size: list[int], vertical_flip: bool = False -) -> tuple[ +def ten_crop(inpt: torch.Tensor, size: list[int], vertical_flip: bool = False) -> tuple[ torch.Tensor, torch.Tensor, torch.Tensor, @@ -2307,9 +2701,7 @@ def ten_crop( @_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor) @_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Image) -def ten_crop_image( - image: torch.Tensor, size: list[int], vertical_flip: bool = False -) -> tuple[ +def ten_crop_image(image: torch.Tensor, size: list[int], vertical_flip: bool = False) -> tuple[ torch.Tensor, 
torch.Tensor, torch.Tensor, @@ -2334,9 +2726,7 @@ def ten_crop_image( @_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image) -def _ten_crop_image_pil( - image: PIL.Image.Image, size: list[int], vertical_flip: bool = False -) -> tuple[ +def _ten_crop_image_pil(image: PIL.Image.Image, size: list[int], vertical_flip: bool = False) -> tuple[ PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, @@ -2361,9 +2751,7 @@ def _ten_crop_image_pil( @_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Video) -def ten_crop_video( - video: torch.Tensor, size: list[int], vertical_flip: bool = False -) -> tuple[ +def ten_crop_video(video: torch.Tensor, size: list[int], vertical_flip: bool = False) -> tuple[ torch.Tensor, torch.Tensor, torch.Tensor, diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 31dae9a1a81..d6699235572 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -121,6 +121,11 @@ def get_size_bounding_boxes(bounding_box: tv_tensors.BoundingBoxes) -> list[int] return list(bounding_box.canvas_size) +@_register_kernel_internal(get_size, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def get_size_keypoints(keypoints: tv_tensors.KeyPoints) -> list[int]: + return list(keypoints.canvas_size) + + def get_num_frames(inpt: torch.Tensor) -> int: if torch.jit.is_scripting(): return get_num_frames_video(inpt) @@ -176,6 +181,49 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor: return xyxy +def _xyxy_to_keypoints(bounding_boxes: torch.Tensor) -> torch.Tensor: + return bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]] + + +def _xyxyxyxy_to_keypoints(bounding_boxes: torch.Tensor) -> torch.Tensor: + return bounding_boxes[:, [[0, 1], [2, 3], [4, 5], [6, 7]]] + + +def convert_bounding_boxes_to_points(bounding_boxes: tv_tensors.BoundingBoxes) -> tv_tensors.KeyPoints: + """Converts a set of bounding boxes to its edge points. + + .. note:: + + This handles rotated :class:`tv_tensors.BoundingBoxes` formats + by first converting them to XYXYXYXY format. + + Due to floating-point approximation, this may not be an exact computation. 
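+
+    For axis-aligned boxes in ``XYXY``-style formats, the four points of each box are returned in
+    top-left, top-right, bottom-right, bottom-left order, following the conversion above.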
+ + Args: + bounding_boxes (tv_tensors.BoundingBoxes): A set of ``N`` bounding boxes (of shape ``[N, 4]``) + + Returns: + tv_tensors.KeyPoints: The edges, as a polygon of shape ``[N, 4, 2]`` + """ + if is_rotated_bounding_box_format(bounding_boxes.format): + # We are working on a rotated bounding box + bbox = _convert_bounding_box_format( + bounding_boxes.as_subclass(torch.Tensor), + old_format=bounding_boxes.format, + new_format=BoundingBoxFormat.XYXYXYXY, + inplace=False, + ) + return tv_tensors.KeyPoints(_xyxyxyxy_to_keypoints(bbox), canvas_size=bounding_boxes.canvas_size) + + bbox = _convert_bounding_box_format( + bounding_boxes.as_subclass(torch.Tensor), + old_format=bounding_boxes.format, + new_format=BoundingBoxFormat.XYXY, + inplace=False, + ) + return tv_tensors.KeyPoints(_xyxy_to_keypoints(bbox), canvas_size=bounding_boxes.canvas_size) + + def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor: if not inplace: cxcywhr = cxcywhr.clone() @@ -360,6 +408,14 @@ def _clamp_bounding_boxes( return out_boxes.to(in_dtype) +def _clamp_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> torch.Tensor: + dtype = keypoints.dtype + keypoints = keypoints.clone() if keypoints.is_floating_point() else keypoints.float() + keypoints[..., 0].clamp_(min=0, max=canvas_size[1]) + keypoints[..., 1].clamp_(min=0, max=canvas_size[0]) + return keypoints.to(dtype=dtype) + + def clamp_bounding_boxes( inpt: torch.Tensor, format: Optional[BoundingBoxFormat] = None, @@ -383,3 +439,25 @@ def clamp_bounding_boxes( raise TypeError( f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead." ) + + +def clamp_keypoints( + inpt: torch.Tensor, + canvas_size: Optional[tuple[int, int]] = None, +) -> torch.Tensor: + """See :func:`~torchvision.transforms.v2.ClampKeyPoints` for details.""" + if not torch.jit.is_scripting(): + _log_api_usage_once(clamp_keypoints) + + if torch.jit.is_scripting() or is_pure_tensor(inpt): + + if canvas_size is None: + raise ValueError("For pure tensor inputs, `canvas_size` have to be passed.") + return _clamp_keypoints(inpt, canvas_size=canvas_size) + elif isinstance(inpt, tv_tensors.KeyPoints): + if canvas_size is not None: + raise ValueError("For keypoints tv_tensor inputs, `canvas_size` must not be passed.") + output = _clamp_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size) + return tv_tensors.wrap(output, like=inpt) + else: + raise TypeError(f"Input can either be a plain tensor or a keypoints tv_tensor, but got {type(inpt)} instead.") diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 7e167d788e6..35fc7e3110d 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -320,6 +320,7 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: return to_dtype_image(video, dtype, scale=scale) +@_register_kernel_internal(to_dtype, tv_tensors.KeyPoints, tv_tensor_wrapper=False) @_register_kernel_internal(to_dtype, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) @_register_kernel_internal(to_dtype, tv_tensors.Mask, tv_tensor_wrapper=False) def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor: @@ -327,6 +328,79 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo return inpt.to(dtype) +def sanitize_keypoints( + keypoints: torch.Tensor, canvas_size: Optional[tuple[int, int]] = None 
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Removes degenerate/invalid keypoints and returns the corresponding indexing mask.
+
+    This removes the keypoints that are outside of their corresponding image.
+
+    It is recommended to call it at the end of a pipeline, before passing the
+    input to the models. It is critical to call this transform if
+    :class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
+    If you want to be extra careful, you may call it after all transforms that
+    may modify the keypoints, but calling it once at the end should be enough in most
+    cases.
+
+    .. note::
+
+        Points that touch the edge of the canvas are removed, unlike for :func:`sanitize_bounding_boxes`.
+
+    Raises:
+        ValueError: If the keypoints are not passed as a two-dimensional tensor.
+
+    Args:
+        keypoints (torch.Tensor or :class:`~torchvision.tv_tensors.KeyPoints`): The KeyPoints being sanitized.
+            Should be of shape ``[N, 2]``.
+        canvas_size (Optional[tuple[int, int]], optional): The canvas_size of the keypoints
+            (size of the corresponding image/video).
+            Must be left to ``None`` if ``keypoints`` is a :class:`~torchvision.tv_tensors.KeyPoints` object.
+
+    Returns:
+        out (tuple of Tensors): The subset of valid keypoints, and the corresponding indexing mask.
+        The mask can then be used to subset other tensors (e.g. labels) that are associated with the keypoints.
+    """
+    if not keypoints.ndim == 2:
+        if keypoints.ndim < 2:
+            raise ValueError("Cannot sanitize a single Keypoint")
+        raise ValueError(
+            "Cannot sanitize a KeyPoints structure that is not 2D. "
+            f"Expected shape to be (N, 2), got {keypoints.shape} ({keypoints.ndim=}, not 2)"
+        )
+    if torch.jit.is_scripting() or is_pure_tensor(keypoints):
+        if canvas_size is None:
+            raise ValueError(
+                "canvas_size cannot be None if keypoints is a pure tensor. "
+                f"Got canvas_size={canvas_size}. "
+                "Set it to appropriate values or pass keypoints as a tv_tensors.KeyPoints object."
+            )
+        valid = _get_sanitize_keypoints_mask(
+            keypoints,
+            canvas_size=canvas_size,
+        )
+        return keypoints[valid], valid
+
+    if not isinstance(keypoints, tv_tensors.KeyPoints):
+        raise ValueError("keypoints must be a tv_tensors.KeyPoints instance or a pure tensor.")
+
+    valid = _get_sanitize_keypoints_mask(
+        keypoints,
+        canvas_size=keypoints.canvas_size,
+    )
+    return tv_tensors.wrap(keypoints[valid], like=keypoints), valid
+
+
+def _get_sanitize_keypoints_mask(
+    keypoints: torch.Tensor,
+    canvas_size: tuple[int, int],
+) -> torch.Tensor:
+    image_h, image_w = canvas_size
+    x = keypoints[:, 0]
+    y = keypoints[:, 1]
+
+    return (0 < x) & (x < image_w) & (0 < y) & (y < image_h)
+
+
 def sanitize_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: Optional[tv_tensors.BoundingBoxFormat] = None,
diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py
index 1ba47f60a36..e1c6b2202df 100644
--- a/torchvision/tv_tensors/__init__.py
+++ b/torchvision/tv_tensors/__init__.py
@@ -1,18 +1,24 @@
+from typing import TypeVar
+
 import torch
 
 from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat
 from ._image import Image
+from ._keypoints import KeyPoints
 from ._mask import Mask
 from ._torch_function_helpers import set_return_type
 from ._tv_tensor import TVTensor
 from ._video import Video
 
+_WRAP_LIKE_T = TypeVar("_WRAP_LIKE_T", bound=TVTensor)
+
+
 # TODO: Fix this. We skip this method as it leads to
 # RecursionError: maximum recursion depth exceeded while calling a Python object
 # Until `disable` is removed, there will be graph breaks after all calls to functional transforms
 @torch.compiler.disable
-def wrap(wrappee, *, like, **kwargs):
+def wrap(wrappee: torch.Tensor, *, like: _WRAP_LIKE_T, **kwargs) -> _WRAP_LIKE_T:
     """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.
 
     If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of
@@ -26,10 +32,25 @@ def wrap(wrappee, *, like, **kwargs):
         Ignored otherwise.
     """
     if isinstance(like, BoundingBoxes):
-        return BoundingBoxes._wrap(
+        return BoundingBoxes._wrap(  # type: ignore
             wrappee,
             format=kwargs.get("format", like.format),
             canvas_size=kwargs.get("canvas_size", like.canvas_size),
         )
+    elif isinstance(like, KeyPoints):
+        return KeyPoints(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size))  # type: ignore
     else:
         return wrappee.as_subclass(type(like))
+
+
+__all__: list[str] = [
+    "wrap",
+    "KeyPoints",
+    "Video",
+    "TVTensor",
+    "set_return_type",
+    "Mask",
+    "Image",
+    "BoundingBoxFormat",
+    "BoundingBoxes",
+]
diff --git a/torchvision/tv_tensors/_keypoints.py b/torchvision/tv_tensors/_keypoints.py
new file mode 100644
index 00000000000..8e0b1a502fc
--- /dev/null
+++ b/torchvision/tv_tensors/_keypoints.py
@@ -0,0 +1,107 @@
+from __future__ import annotations
+
+from typing import Any, Mapping, MutableSequence, Optional, Sequence, Tuple, TYPE_CHECKING, Union
+
+import torch
+from torch.utils._pytree import tree_flatten
+
+from ._tv_tensor import TVTensor
+
+
+class KeyPoints(TVTensor):
+    """:class:`torch.Tensor` subclass for tensors with shape ``[..., 2]`` that represent points in an image.
+
+    Each point is represented by its XY coordinates.
+
+    KeyPoints can be converted from :class:`torchvision.tv_tensors.BoundingBoxes`
+    by :func:`torchvision.transforms.v2.functional.convert_bounding_boxes_to_points`.
+
+    KeyPoints may represent any object that can be represented by sequences of 2D points:
+
+    - Polygonal chains, including polylines, Bézier curves, etc.,
+      which should be of shape ``[N_chains, N_points, 2]``, i.e. ``[N_chains, N_segments + 1, 2]``
+    - Polygons, which should be of shape ``[N_polygons, N_points, 2]``, i.e. ``[N_polygons, N_sides, 2]``
+    - Skeletons, which could be of shape ``[N_skeletons, N_bones, 2, 2]`` for pose-estimation models
+
+    .. note::
+
+        Like for :class:`torchvision.tv_tensors.BoundingBoxes`, there should only ever be a single
+        instance of the :class:`torchvision.tv_tensors.KeyPoints` class per sample,
+        e.g. ``{"img": img, "points_of_interest": KeyPoints(...)}``,
+        although one :class:`torchvision.tv_tensors.KeyPoints` object can contain multiple key points.
+
+    Args:
+        data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
+        canvas_size (two-tuple of ints): Height and width of the corresponding image or video.
+        dtype (torch.dtype, optional): Desired data type of the key points. If omitted, will be inferred from
+            ``data``.
+        device (torch.device, optional): Desired device of the key points. If omitted and ``data`` is a
+            :class:`torch.Tensor`, the device is taken from it. Otherwise, the key points are constructed on the CPU.
+        requires_grad (bool, optional): Whether autograd should record operations on the key points. If omitted and
+            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
+    """
+
+    canvas_size: Tuple[int, int]
+
+    def __new__(
+        cls,
+        data: Any,
+        *,
+        canvas_size: Tuple[int, int],
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: Optional[bool] = None,
+    ):
+        tensor: torch.Tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
+        if tensor.ndim == 1:
+            tensor = tensor.unsqueeze(0)
+        elif tensor.shape[-1] != 2:
+            raise ValueError(f"Expected a tensor of shape (..., 2), not {tensor.shape}")
+        points = tensor.as_subclass(cls)
+        points.canvas_size = canvas_size
+        return points
+
+    if TYPE_CHECKING:
+        # Present only so that mypy, Pylance, and other type checkers accept the constructor
+        # arguments when initializing the TVTensor.
+        # Not read or defined at runtime (only at type-checking time).
+        # TODO: BoundingBoxes needs something similar.
+        def __init__(
+            self,
+            data: Any,
+            *,
+            canvas_size: Tuple[int, int],
+            dtype: Optional[torch.dtype] = None,
+            device: Optional[Union[torch.device, str, int]] = None,
+            requires_grad: Optional[bool] = None,
+        ):
+            pass
+
+    @classmethod
+    def _wrap_output(
+        cls,
+        output: Any,
+        args: Sequence[Any] = (),
+        kwargs: Optional[Mapping[str, Any]] = None,
+    ) -> Any:
+        # Mostly copied over from the BoundingBoxes TVTensor; this propagates the metadata to the output.
+        # For BoundingBoxes the metadata includes ``format``, but KeyPoints only carries ``canvas_size``.
+        flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ()))  # type: ignore[operator]
+        first_keypoints_from_args = next(x for x in flat_params if isinstance(x, KeyPoints))
+        canvas_size: Tuple[int, int] = first_keypoints_from_args.canvas_size
+
+        if isinstance(output, torch.Tensor) and not isinstance(output, KeyPoints):
+            output = KeyPoints(output, canvas_size=canvas_size)
+        elif isinstance(output, MutableSequence):
+            # For lists and other mutable sequences we don't create a new container, we set the values in place.
+            # This preserves the type of list-like objects that may not follow the initialization API of lists.
+            for i, part in enumerate(output):
+                output[i] = KeyPoints(part, canvas_size=canvas_size)
+        elif isinstance(output, Sequence):
+            # Immutable sequences (like tuples) are handled here.
+            # We rebuild them as a tuple, since we know its initialization API, unlike that of ``output``.
+            output = tuple(KeyPoints(part, canvas_size=canvas_size) for part in output)
+        return output
+
+    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
+        return self._make_repr(canvas_size=self.canvas_size)
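
A minimal usage sketch of the new ``KeyPoints`` tv_tensor with the v2 transforms (assuming this patch is
applied; the image size, coordinates, and transform choices below are only illustrative)::

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    # An image and a set of keypoints that share the same canvas (H=480, W=640).
    image = tv_tensors.Image(torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8))
    points = tv_tensors.KeyPoints([[10.0, 20.0], [300.0, 150.0]], canvas_size=(480, 640))

    # With the kernels registered above, the v2 transforms dispatch on KeyPoints like any other tv_tensor.
    transform = v2.Compose([v2.Resize(size=(240, 320)), v2.RandomHorizontalFlip(p=1.0)])
    out_image, out_points = transform(image, points)
    print(out_points.canvas_size)  # (240, 320); the coordinates are rescaled and flipped accordingly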