Added the KeyPoints TVTensor #8817

Open · wants to merge 32 commits into main
Changes from 13 commits

Commits (32)
8253305
Added Keypoints to the library
Alexandre-SCHOEPP Dec 12, 2024
484561d
Improved KeyPoints to be exported
Alexandre-SCHOEPP Dec 13, 2024
3255890
Added kernels to support the keypoints
Alexandre-SCHOEPP Dec 13, 2024
7436636
Added tests for keypoints
Alexandre-SCHOEPP Dec 13, 2024
b35cba6
Applied ufmt formatting
Alexandre-SCHOEPP Dec 13, 2024
a19ec0b
Fixed the bugs found while testing
Alexandre-SCHOEPP Dec 16, 2024
5f4b188
Improved documentation to take KeyPoints into account
Alexandre-SCHOEPP Dec 17, 2024
cabce1c
Applied ufmt check
Alexandre-SCHOEPP Dec 17, 2024
d1b27ad
Fixed the hflip not being along the right coordinate
Alexandre-SCHOEPP Dec 17, 2024
6fa38f4
Merge branch 'main' into main
Alexandre-SCHOEPP Dec 18, 2024
05e4ad6
Merge branch 'main' into main
Alexandre-SCHOEPP Feb 10, 2025
03dc6c8
Merge branch 'main' into main
Alexandre-SCHOEPP Feb 20, 2025
d4d087c
Merge branch 'main' into main
Alexandre-SCHOEPP Mar 4, 2025
5a8c5b4
Fixed order of arguments
Alex-S-H-P Apr 30, 2025
dea31e2
Reworked logic of the conditions to better handle mutable/non mutable…
Alex-S-H-P Apr 30, 2025
71e20a5
Renamed out variable to be more similar with _resized_crop_bounding_b…
Alex-S-H-P Apr 30, 2025
2f77527
renamed _xyxy_to_points to _xyxy_to_keypoints for consistency
Alex-S-H-P Apr 30, 2025
517a6de
clarified _xyxy_to_points and changed the name of its caller for the …
Alex-S-H-P Apr 30, 2025
63ed4a5
Renamed half_point to more explicit single_coord_shape
Alex-S-H-P Apr 30, 2025
166c1ec
Integrated KeyPoints better in the transforms. It now warns alongside…
Alex-S-H-P Apr 30, 2025
fcfd597
Merge branch 'main' into main
Alexandre-SCHOEPP Apr 30, 2025
1cc3b6f
Fixed _geometry.py post botched merge request
Alex-S-H-P Apr 30, 2025
841de77
Review python 3.9 type hint and lint
AntoineSimoulin May 3, 2025
ff6ab48
Add specific keypoint tests
AntoineSimoulin May 3, 2025
0de59e7
Adjust variable names
AntoineSimoulin May 4, 2025
4b62ef4
Improved documentation inside of the KeyPoints class definition
Alex-S-H-P May 5, 2025
e99b82a
Improved convert_bounding_boxes_to_points to handle rotated bounding …
Alex-S-H-P May 5, 2025
a869f39
Applied ufmt
Alex-S-H-P May 5, 2025
6007b2c
Adding a type:ignore[override] on KeyPoints__repr__ as it also exist …
Alex-S-H-P May 5, 2025
b68b57b
Fixed flake8 compliance on "..." present in the line of the __init__ …
Alex-S-H-P May 5, 2025
801e24d
get_all_keypoints is now get_keypoints and returns the only keypoints…
Alex-S-H-P May 5, 2025
73a40a8
Fixed docstring on sanitize_keypoints
Alex-S-H-P May 5, 2025
1 change: 1 addition & 0 deletions docs/source/tv_tensors.rst
@@ -21,6 +21,7 @@ info.

Image
Video
KeyPoints
BoundingBoxFormat
BoundingBoxes
Mask
11 changes: 10 additions & 1 deletion gallery/transforms/plot_tv_tensors.py
@@ -46,11 +46,12 @@
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# :mod:`torchvision.tv_tensors` supports four types of TVTensors:
# :mod:`torchvision.tv_tensors` supports five types of TVTensors:
#
# * :class:`~torchvision.tv_tensors.Image`
# * :class:`~torchvision.tv_tensors.Video`
# * :class:`~torchvision.tv_tensors.BoundingBoxes`
# * :class:`~torchvision.tv_tensors.KeyPoints`
# * :class:`~torchvision.tv_tensors.Mask`
#
# What can I do with a TVTensor?
@@ -96,6 +97,7 @@
# :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the
# corresponding image (``canvas_size``) alongside the actual values. These
# metadata are required to properly transform the bounding boxes.
# In a similar fashion, :class:`~torchvision.tv_tensors.KeyPoints` also requires the ``canvas_size`` metadata.

bboxes = tv_tensors.BoundingBoxes(
[[17, 16, 344, 495], [0, 10, 0, 10]],
@@ -104,6 +106,13 @@
)
print(bboxes)


keypoints = tv_tensors.KeyPoints(
[[17, 16], [344, 495], [0, 10], [0, 10]],
canvas_size=image.shape[-2:]
)
print(keypoints)

# %%
# Using ``tv_tensors.wrap()``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
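Note: the ``tv_tensors.wrap()`` section this hunk leads into applies to the new class as well. Below is a minimal sketch, assuming ``wrap()`` treats ``KeyPoints`` like any other TVTensor; the shift by 2 is just a placeholder operation:

import torch
from torchvision import tv_tensors

keypoints = tv_tensors.KeyPoints([[17, 16], [344, 495]], canvas_size=(512, 512))

# Most tensor ops unwrap TVTensors to plain tensors; wrap() re-wraps the
# result as the same subclass, carrying metadata (canvas_size) from `like`.
shifted = tv_tensors.wrap(keypoints + 2, like=keypoints)
print(type(shifted), shifted.canvas_size)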
15 changes: 15 additions & 0 deletions test/common_utils.py
@@ -8,6 +8,7 @@
import shutil
import sys
import tempfile
from typing import Sequence, Tuple, Union
import warnings
from subprocess import CalledProcessError, check_output, STDOUT

@@ -402,6 +403,20 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_keypoints(
    canvas_size: Tuple[int, int] = DEFAULT_SIZE, *, num_points: Union[int, Sequence[int]] = 4, dtype=None, device="cpu"
) -> tv_tensors.KeyPoints:
    """Make a KeyPoints TVTensor of shape (*num_points, 2) for testing purposes."""
    if isinstance(num_points, int):
        num_points = [num_points]
    # Each coordinate tensor has shape (*num_points, 1) so that x and y concatenate along the last dim.
    single_coord_shape: Tuple[int, ...] = tuple(num_points) + (1,)
    y = torch.randint(0, canvas_size[0] - 1, single_coord_shape, dtype=dtype, device=device)
    x = torch.randint(0, canvas_size[1] - 1, single_coord_shape, dtype=dtype, device=device)
    points = torch.cat((x, y), dim=-1)
    return tv_tensors.KeyPoints(points, canvas_size=canvas_size)


def make_bounding_boxes(
canvas_size=DEFAULT_SIZE,
*,
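For illustration, a short sketch of how the helper might be exercised in a test (``tv_tensors`` is already imported in this module; the ``canvas_size`` attribute is assumed to mirror ``BoundingBoxes``):

kp = make_keypoints(canvas_size=(32, 32), num_points=(2, 3))
assert isinstance(kp, tv_tensors.KeyPoints)
assert kp.shape == (2, 3, 2)       # one (x, y) pair per point
assert kp.canvas_size == (32, 32)  # metadata carried by the TVTensor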
32 changes: 28 additions & 4 deletions test/test_transforms_v2.py
@@ -31,6 +31,7 @@
make_image,
make_image_pil,
make_image_tensor,
make_keypoints,
make_segmentation_mask,
make_video,
make_video_tensor,
@@ -232,6 +233,7 @@ def check_functional_kernel_signature_match(functional, *, kernel, input_type):
# explicitly passed to the kernel.
explicit_metadata = {
tv_tensors.BoundingBoxes: {"format", "canvas_size"},
tv_tensors.KeyPoints: {"canvas_size"},
}
kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

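In other words, the functional reads ``canvas_size`` off the ``KeyPoints`` input, while the low-level kernel takes it explicitly on a plain tensor. A sketch of the equivalence this encodes, using names imported in this test module; the kernel's keyword is assumed to mirror the bounding-box kernels:

kp = make_keypoints(canvas_size=(32, 32))

# Functional: metadata comes from the TVTensor itself.
out_functional = F.horizontal_flip(kp)

# Kernel: the same metadata must be passed explicitly.
out_kernel = F.horizontal_flip_keypoints(kp.as_subclass(torch.Tensor), canvas_size=kp.canvas_size)

torch.testing.assert_close(out_functional, tv_tensors.wrap(out_kernel, like=kp))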
@@ -336,6 +338,18 @@ def _make_transform_sample(transform, *, image_or_video, adapter):
canvas_size=size,
device=device,
),
keypoints=make_keypoints(canvas_size=size),
keypoints_degenerate=tv_tensors.KeyPoints(
[
[0, 1], # left edge
[1, 0], # top edge
[0, 0], # top left corner
[size[1], 1], # right edge
[size[1], 0], # top right corner
[1, size[0]], # bottom edge
[0, size[0]], # bottom left corner
[size[1], size[0]] # bottom right corner
], canvas_size=size, device=device
),
detection_mask=make_detection_masks(size, device=device),
segmentation_mask=make_segmentation_mask(size, device=device),
int=0,
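The ``keypoints_degenerate`` entries above sit on or just past the canvas boundary. Assuming valid pixel coordinates run ``0 <= x < W`` and ``0 <= y < H`` (the usual torchvision convention), a point such as ``[size[1], 0]`` is already outside the canvas:

H, W = size

def in_canvas(x, y):  # assumed validity convention on an (H, W) canvas
    return 0 <= x < W and 0 <= y < H

in_canvas(0, 0)  # "top left corner": inside
in_canvas(W, 0)  # "top right corner": outside, x == W is one past the last column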
@@ -689,6 +703,7 @@ def test_functional(self, size, make_input):
(F.resize_image, torch.Tensor),
(F._geometry._resize_image_pil, PIL.Image.Image),
(F.resize_image, tv_tensors.Image),
(F.resize_keypoints, tv_tensors.KeyPoints),
(F.resize_bounding_boxes, tv_tensors.BoundingBoxes),
(F.resize_mask, tv_tensors.Mask),
(F.resize_video, tv_tensors.Video),
@@ -1044,6 +1059,7 @@ def test_functional(self, make_input):
(F.horizontal_flip_image, torch.Tensor),
(F._geometry._horizontal_flip_image_pil, PIL.Image.Image),
(F.horizontal_flip_image, tv_tensors.Image),
(F.horizontal_flip_keypoints, tv_tensors.KeyPoints),
(F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.horizontal_flip_mask, tv_tensors.Mask),
(F.horizontal_flip_video, tv_tensors.Video),
@@ -1212,6 +1228,7 @@ def test_functional(self, make_input):
(F.affine_image, torch.Tensor),
(F._geometry._affine_image_pil, PIL.Image.Image),
(F.affine_image, tv_tensors.Image),
(F.affine_keypoints, tv_tensors.KeyPoints),
(F.affine_bounding_boxes, tv_tensors.BoundingBoxes),
(F.affine_mask, tv_tensors.Mask),
(F.affine_video, tv_tensors.Video),
@@ -1494,6 +1511,7 @@ def test_functional(self, make_input):
(F.vertical_flip_image, torch.Tensor),
(F._geometry._vertical_flip_image_pil, PIL.Image.Image),
(F.vertical_flip_image, tv_tensors.Image),
(F.vertical_flip_keypoints, tv_tensors.KeyPoints),
(F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.vertical_flip_mask, tv_tensors.Mask),
(F.vertical_flip_video, tv_tensors.Video),
@@ -1636,6 +1654,7 @@ def test_functional(self, make_input):
(F.rotate_image, torch.Tensor),
(F._geometry._rotate_image_pil, PIL.Image.Image),
(F.rotate_image, tv_tensors.Image),
(F.rotate_keypoints, tv_tensors.KeyPoints),
(F.rotate_bounding_boxes, tv_tensors.BoundingBoxes),
(F.rotate_mask, tv_tensors.Mask),
(F.rotate_video, tv_tensors.Video),
@@ -2341,7 +2360,9 @@ def test_error(self, T):
F.to_pil_image(imgs[0]),
tv_tensors.Mask(torch.rand(12, 12)),
tv_tensors.BoundingBoxes(torch.rand(2, 4), format="XYXY", canvas_size=12),
tv_tensors.KeyPoints(torch.rand(2, 2), canvas_size=(12, 12)),
):
with pytest.raises(ValueError, match="does not support PIL images, "):
cutmix_mixup(input_with_bad_type)

@@ -2749,8 +2770,9 @@ def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.elastic, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
"make_input", [
make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints
],
)
def test_displacement_error(self, make_input):
input = make_input()
@@ -2762,8 +2784,10 @@
F.elastic(input, displacement=torch.rand(F.get_size(input)))

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
"make_input", [
make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video,
make_keypoints
],
)
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
@pytest.mark.parametrize("size", [(163, 163), (72, 333), (313, 95)])
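For contrast with the error case above, a sketch of a well-formed call: the v2 elastic API expects a per-pixel displacement of shape ``(1, H, W, 2)``, and keypoint support for ``F.get_size``/``F.elastic`` is assumed from this PR:

kp = make_keypoints(canvas_size=(163, 163))
H, W = F.get_size(kp)
displacement = torch.randn(1, H, W, 2)  # per-pixel (dx, dy) offsets
out = F.elastic(kp, displacement=displacement)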
54 changes: 30 additions & 24 deletions test/test_transforms_v2_utils.py
@@ -4,7 +4,7 @@
import torch

import torchvision.transforms.v2._utils
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image, make_keypoints

from torchvision import tv_tensors
from torchvision.transforms.v2._utils import has_all, has_any
@@ -14,29 +14,32 @@
IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY)
MASK = make_detection_masks(DEFAULT_SIZE)
KEYPOINTS = make_keypoints(DEFAULT_SIZE)


@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False),
((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.KeyPoints,), True),
((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints), False),
((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask, tv_tensors.KeyPoints), False),
((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), False),
((KEYPOINTS,), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
(
(IMAGE, BOUNDING_BOX, MASK),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
(IMAGE, BOUNDING_BOX, MASK, KEYPOINTS),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints),
True,
),
((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), False),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda _: True,), True),
((IMAGE,), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
(
(torch.Tensor(IMAGE),),
@@ -57,15 +60,18 @@ def test_has_any(sample, types, expected):
@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask, tv_tensors.KeyPoints), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), True),
(
(IMAGE, BOUNDING_BOX, MASK),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
(IMAGE, BOUNDING_BOX, MASK, KEYPOINTS),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints),
True,
),
((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
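To make the parametrizations concrete, a small sketch of the two predicates with the module-level samples defined above; the varargs call style is assumed from how the tests splat ``types``:

sample = [IMAGE, BOUNDING_BOX, KEYPOINTS]

has_any(sample, tv_tensors.KeyPoints)                    # True: at least one entry matches
has_all(sample, tv_tensors.Image, tv_tensors.KeyPoints)  # True: every queried type is present
has_all(sample, tv_tensors.Mask)                         # False: no Mask in the sample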