[prototype] Rewrite the meta dimension methods (#6722)
* Rewrite `get_dimensions`, `get_num_channels` and `get_spatial_size`

* Remove `get_chw`

* Remove comments

* Make `get_spatial_size` support non-image input

* Reduce the unnecessary use of `get_dimensions*`

* Fix linters

* Fix merge bug

* Linter

* Fix linter
datumbox authored Oct 7, 2022
1 parent 4c049ca commit 6e203b4
Showing 6 changed files with 71 additions and 37 deletions.
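Before the per-file diffs, a quick orientation: this commit removes the private `get_chw` helper and routes everything through three public compound kernels. A minimal usage sketch, assuming a torchvision build that contains this commit (the prototype API has since kept changing):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# Wrap a plain CxHxW tensor as a prototype Image feature.
image = features.Image(torch.rand(3, 32, 64))

print(F.get_dimensions(image))    # [3, 32, 64] (C, H, W as a List[int])
print(F.get_num_channels(image))  # 3
print(F.get_spatial_size(image))  # [32, 64] (H, W as a List[int])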
6 changes: 5 additions & 1 deletion torchvision/prototype/features/_mask.py
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Any, List, Optional, Union
+from typing import Any, cast, List, Optional, Tuple, Union

 import torch
 from torchvision.transforms import InterpolationMode
@@ -32,6 +32,10 @@ def wrap_like(
     ) -> Mask:
         return cls._wrap(tensor)

+    @property
+    def image_size(self) -> Tuple[int, int]:
+        return cast(Tuple[int, int], tuple(self.shape[-2:]))
+
     def horizontal_flip(self) -> Mask:
         output = self._F.horizontal_flip_mask(self)
         return Mask.wrap_like(self, output)
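The new property gives `Mask` the same `image_size` metadata that other features expose, read straight off the trailing two dimensions. A small sketch of what it reports, assuming this build:

import torch
from torchvision.prototype import features

mask = features.Mask(torch.zeros(28, 42, dtype=torch.uint8))
print(mask.image_size)  # (28, 42), a Tuple[int, int] taken from shape[-2:]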
10 changes: 5 additions & 5 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -7,7 +7,7 @@
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.prototype import features
 from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
-from torchvision.prototype.transforms.functional._meta import get_chw
+from torchvision.prototype.transforms.functional._meta import get_spatial_size

 from ._utils import _isinstance, _setup_fill_arg

@@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         id, image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(image_or_video)
+        height, width = get_spatial_size(image_or_video)

         policy = self._policies[int(torch.randint(len(self._policies), ()))]
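The same mechanical edit repeats at every call site in this file: the channel count returned by `get_chw` was discarded here, so the unpacking shrinks from three values to two. For illustration, a sketch with a stand-in input tensor:

import torch
from torchvision.prototype.transforms.functional import get_spatial_size

image_or_video = torch.rand(3, 224, 224)  # stand-in for any supported input
# Before this commit: _, height, width = get_chw(image_or_video)
height, width = get_spatial_size(image_or_video)  # unpacks the [H, W] list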

@@ -349,7 +349,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         id, image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(image_or_video)
+        height, width = get_spatial_size(image_or_video)

         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -403,7 +403,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         id, image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(image_or_video)
+        height, width = get_spatial_size(image_or_video)

         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

@@ -473,7 +473,7 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
         id, orig_image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(orig_image_or_video)
+        height, width = get_spatial_size(orig_image_or_video)

         if isinstance(orig_image_or_video, torch.Tensor):
             image_or_video = orig_image_or_video
7 changes: 4 additions & 3 deletions torchvision/prototype/transforms/_utils.py
@@ -10,7 +10,7 @@
 from torchvision.prototype import features
 from torchvision.prototype.features._feature import FillType

-from torchvision.prototype.transforms.functional._meta import get_chw
+from torchvision.prototype.transforms.functional._meta import get_dimensions
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401

 from typing_extensions import Literal
@@ -80,15 +80,16 @@ def query_bounding_box(sample: Any) -> features.BoundingBox:
 def query_chw(sample: Any) -> Tuple[int, int, int]:
     flat_sample, _ = tree_flatten(sample)
     chws = {
-        get_chw(item)
+        tuple(get_dimensions(item))
         for item in flat_sample
         if isinstance(item, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(item)
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")
     elif len(chws) > 1:
         raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
-    return chws.pop()
+    c, h, w = chws.pop()
+    return c, h, w


 def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
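Two details in this hunk deserve a note. First, `get_dimensions` now returns a `List[int]`, and lists are unhashable, so each result has to pass through `tuple(...)` before it can enter the `chws` set. Second, the set therefore holds `Tuple[int, ...]` values, and the explicit `c, h, w` unpacking likely exists to narrow the popped value back to the declared `Tuple[int, int, int]` return type. The hashability point in isolation:

# {[3, 32, 64]} raises TypeError: unhashable type: 'list'
# ...but tuples are hashable, so duplicates collapse as intended:
chws = {tuple([3, 32, 64]), tuple([3, 32, 64])}
print(chws)  # {(3, 32, 64)}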
6 changes: 6 additions & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -8,9 +8,15 @@
     convert_color_space_image_pil,
     convert_color_space_video,
     convert_color_space,
+    get_dimensions_image_tensor,
+    get_dimensions_image_pil,
+    get_dimensions,
     get_image_num_channels,
     get_num_channels_image_tensor,
     get_num_channels_image_pil,
     get_num_channels,
+    get_spatial_size_image_tensor,
+    get_spatial_size_image_pil,
+    get_spatial_size,
 )  # usort: skip

21 changes: 13 additions & 8 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -21,7 +21,12 @@
     interpolate,
 )

-from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor
+from ._meta import (
+    convert_format_bounding_box,
+    get_dimensions_image_tensor,
+    get_spatial_size_image_pil,
+    get_spatial_size_image_tensor,
+)

 horizontal_flip_image_tensor = _FT.hflip
 horizontal_flip_image_pil = _FP.hflip
@@ -323,7 +328,7 @@ def affine_image_pil(
     # it is visually better to estimate the center without 0.5 offset
     # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
     if center is None:
-        _, height, width = get_dimensions_image_pil(image)
+        height, width = get_spatial_size_image_pil(image)
         center = [width * 0.5, height * 0.5]
     matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

@@ -1189,13 +1194,13 @@ def _center_crop_compute_crop_anchor(

 def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = get_dimensions_image_tensor(image)
+    image_height, image_width = get_spatial_size_image_tensor(image)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         image = pad_image_tensor(image, padding_ltrb, fill=0)

-        _, image_height, image_width = get_dimensions_image_tensor(image)
+        image_height, image_width = get_spatial_size_image_tensor(image)
         if crop_width == image_width and crop_height == image_height:
             return image

@@ -1206,13 +1211,13 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
 @torch.jit.unused
 def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = get_dimensions_image_pil(image)
+    image_height, image_width = get_spatial_size_image_pil(image)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         image = pad_image_pil(image, padding_ltrb, fill=0)

-        _, image_height, image_width = get_dimensions_image_pil(image)
+        image_height, image_width = get_spatial_size_image_pil(image)
         if crop_width == image_width and crop_height == image_height:
             return image

@@ -1365,7 +1370,7 @@ def five_crop_image_tensor(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    _, image_height, image_width = get_dimensions_image_tensor(image)
+    image_height, image_width = get_spatial_size_image_tensor(image)

     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
@@ -1385,7 +1390,7 @@ def five_crop_image_pil(
     image: PIL.Image.Image, size: List[int]
 ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    _, image_height, image_width = get_dimensions_image_pil(image)
+    image_height, image_width = get_spatial_size_image_pil(image)

     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
58 changes: 38 additions & 20 deletions torchvision/prototype/transforms/functional/_meta.py
@@ -6,48 +6,66 @@
 from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT


 get_dimensions_image_tensor = _FT.get_dimensions
 get_dimensions_image_pil = _FP.get_dimensions


-# TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init?
-def get_chw(image: features.ImageOrVideoTypeJIT) -> Tuple[int, int, int]:
+def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
     if isinstance(image, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
     ):
-        channels, height, width = get_dimensions_image_tensor(image)
+        return get_dimensions_image_tensor(image)
     elif isinstance(image, (features.Image, features.Video)):
         channels = image.num_channels
         height, width = image.image_size
-    else:  # isinstance(image, PIL.Image.Image)
-        channels, height, width = get_dimensions_image_pil(image)
-    return channels, height, width
-
-
-# The three functions below are here for BC. Whether we want to have two different kernels and how they and the
-# compound version should be named is still under discussion: https://github.com/pytorch/vision/issues/6491
-# Given that these kernels should also support boxes, masks, and videos, it is unlikely that there name will stay.
-# They will either be deprecated or simply aliased to the new kernels if we have reached consensus about the issue
-# detailed above.
+        return [channels, height, width]
+    else:
+        return get_dimensions_image_pil(image)
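The rewritten compound kernel dispatches three ways: plain tensors (and any tensor under `torch.jit.is_scripting()`) go to the tensor kernel, `Image`/`Video` features answer from their cached metadata without touching pixel data, and anything else is treated as a PIL image. A sketch of all three paths, assuming this build:

import torch
import PIL.Image
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import get_dimensions

print(get_dimensions(torch.rand(3, 32, 64)))                  # tensor kernel -> [3, 32, 64]
print(get_dimensions(features.Image(torch.rand(3, 32, 64))))  # metadata -> [3, 32, 64]
print(get_dimensions(PIL.Image.new("RGB", (64, 32))))         # PIL kernel -> [3, 32, 64]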


-def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
-    return list(get_chw(image))
+get_num_channels_image_tensor = _FT.get_image_num_channels
+get_num_channels_image_pil = _FP.get_image_num_channels


 def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
-    num_channels, *_ = get_chw(image)
-    return num_channels
+    if isinstance(image, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
+    ):
+        return _FT.get_image_num_channels(image)
+    elif isinstance(image, (features.Image, features.Video)):
+        return image.num_channels
+    else:
+        return _FP.get_image_num_channels(image)


 # We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
 # deprecating the old names.
 get_image_num_channels = get_num_channels


-def get_spatial_size(image: features.ImageOrVideoTypeJIT) -> List[int]:
-    _, *size = get_chw(image)
-    return size
+def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
+    width, height = _FT.get_image_size(image)
+    return [height, width]
+
+
+@torch.jit.unused
+def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
+    width, height = _FP.get_image_size(image)
+    return [height, width]
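Note the axis swap in both new kernels: the stable `get_image_size` helpers they delegate to report `[width, height]`, while the prototype API standardizes on `[height, width]`, hence the reversed unpacking. A quick check:

import torch
from torchvision.transforms import functional_tensor as _FT

img = torch.rand(3, 32, 64)     # H=32, W=64
print(_FT.get_image_size(img))  # [64, 32], i.e. (W, H) order from the stable helper
# get_spatial_size_image_tensor(img) above would report [32, 64] in (H, W) order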


+def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return get_spatial_size_image_tensor(inpt)
+    elif isinstance(inpt, features._Feature):
+        image_size = getattr(inpt, "image_size", None)
+        if image_size is not None:
+            return list(image_size)
+        else:
+            raise ValueError(f"Type {inpt.__class__} doesn't have spatial size.")
+    else:
+        return get_spatial_size_image_pil(inpt)
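This branch is what delivers the "Make `get_spatial_size` support non-image input" bullet from the commit message: any `_Feature` carrying an `image_size` attribute (now including `Mask` via the property added above, as well as `BoundingBox`) resolves through the metadata path. A sketch, assuming the Oct 2022 prototype constructor signatures:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import get_spatial_size

bbox = features.BoundingBox(
    torch.tensor([0.0, 0.0, 10.0, 10.0]),
    format=features.BoundingBoxFormat.XYXY,
    image_size=(32, 64),
)
print(get_spatial_size(bbox))  # [32, 64], read from the feature's image_size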


 def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
