diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 94d90b9e2f6..c87feef11c7 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -19,6 +19,7 @@ import torch import torchvision.ops import torchvision.transforms.v2 as transforms +from torchvision.transforms.v2 import RandomPosterize from common_utils import ( assert_equal, @@ -6270,4 +6271,20 @@ def test_different_sizes(self, make_input1, make_input2, query): @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw]) def test_no_valid_input(self, query): with pytest.raises(TypeError, match="No image"): - query(["blah"]) + query(["blah"] + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bool, torch.complex64]) +def test_random_posterize_dtype_error(dtype): + rp = RandomPosterize(bits=3, p=1.0) + tensor = torch.zeros((1, 3, 5, 5), dtype=dtype) + with pytest.raises(TypeError) as excinfo: + rp(tensor) + assert "Number of value bits is only defined for integer dtypes" in str(excinfo.value) + + +def test_random_posterize_uint8_pass(): + rp = RandomPosterize(bits=4, p=1.0) + tensor = torch.randint(0, 255, (1, 3, 5, 5), dtype=torch.uint8) + out = rp(tensor) + assert isinstance(out, torch.Tensor) + assert out.dtype == torch.uint8 diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index bf4ae55d232..198e03398fe 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -3,12 +3,26 @@ from typing import Any, Optional, Union import torch +from torch import Tensor from torchvision import transforms as _transforms from torchvision.transforms.v2 import functional as F, Transform + from ._transform import _RandomApplyTransform from ._utils import query_chw +def _ensure_integer_dtype(tensor: Tensor) -> None: + """ + Checks that the tensor's dtype is integer. + Throws TypeError for float, complex, bool, etc. + """ + try: + torch.iinfo(tensor.dtype) + except (ValueError, TypeError): + raise TypeError( + f"Number of value bits is only defined for integer dtypes, but got {tensor.dtype}" + ) + class Grayscale(Transform): """Convert images or videos to grayscale. @@ -306,9 +320,11 @@ def __init__(self, bits: int, p: float = 0.5) -> None: self.bits = bits def transform(self, inpt: Any, params: dict[str, Any]) -> Any: + # Check that the tensor is integer + if isinstance(inpt, Tensor): + _ensure_integer_dtype(inpt) return self._call_kernel(F.posterize, inpt, bits=self.bits) - class RandomSolarize(_RandomApplyTransform): """Solarize the image or video with a given probability by inverting all pixel values above a threshold.