Commit 2bababf

Add a GrayscaleToRgb transform that can expand channels to 3 (#8247)
1 parent fa82fd3 commit 2bababf

6 files changed: +94 -1 lines changed

docs/source/transforms.rst (+3 -1)

@@ -347,6 +347,7 @@ Color
     v2.RandomChannelPermutation
     v2.RandomPhotometricDistort
     v2.Grayscale
+    v2.RGB
     v2.RandomGrayscale
     v2.GaussianBlur
     v2.RandomInvert
@@ -364,6 +365,7 @@ Functionals

     v2.functional.permute_channels
     v2.functional.rgb_to_grayscale
+    v2.functional.grayscale_to_rgb
     v2.functional.to_grayscale
     v2.functional.gaussian_blur
     v2.functional.invert
@@ -584,7 +586,7 @@ Conversion
 while performing the conversion, while some may not do any scaling. By
 scaling, we mean e.g. that a ``uint8`` -> ``float32`` would map the [0,
 255] range into [0, 1] (and vice-versa). See :ref:`range_and_dtype`.
-
+
 .. autosummary::
     :toctree: generated/
     :template: class.rst
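
The Conversion paragraph shown as context above describes the dtype scaling convention. A small illustration of what that mapping means, sketched here (not part of the commit) with the v2 ToDtype transform and scale=True:

import torch
from torchvision.transforms import v2

u8 = torch.tensor([[[0, 128, 255]]], dtype=torch.uint8)  # [1, 1, 3] image-like tensor

f32 = v2.ToDtype(torch.float32, scale=True)(u8)   # uint8 [0, 255] -> float32 [0, 1]
print(f32)                                        # values approx. [0.0, 0.502, 1.0]

back = v2.ToDtype(torch.uint8, scale=True)(f32)   # and vice-versa: floats map back to [0, 255]
print(back)                                       # approx. [0, 128, 255], up to rounding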

test/test_transforms_v2.py (+48)

@@ -5005,6 +5005,54 @@ def test_random_transform_correctness(self, num_input_channels):
         assert_equal(actual, expected, rtol=0, atol=1)


+class TestGrayscaleToRgb:
+    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_image(self, dtype, device):
+        check_kernel(F.grayscale_to_rgb_image, make_image(dtype=dtype, device=device))
+
+    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    def test_functional(self, make_input):
+        check_functional(F.grayscale_to_rgb, make_input())
+
+    @pytest.mark.parametrize(
+        ("kernel", "input_type"),
+        [
+            (F.rgb_to_grayscale_image, torch.Tensor),
+            (F._rgb_to_grayscale_image_pil, PIL.Image.Image),
+            (F.rgb_to_grayscale_image, tv_tensors.Image),
+        ],
+    )
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.grayscale_to_rgb, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    def test_transform(self, make_input):
+        check_transform(transforms.RGB(), make_input(color_space="GRAY"))
+
+    @pytest.mark.parametrize("fn", [F.grayscale_to_rgb, transform_cls_to_functional(transforms.RGB)])
+    def test_image_correctness(self, fn):
+        image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")
+
+        actual = fn(image)
+        expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image)))
+
+        assert_equal(actual, expected, rtol=0, atol=1)
+
+    def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
+        image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")
+
+        output_image = F.grayscale_to_rgb(image)
+        assert_equal(output_image[0][0][0], output_image[1][0][0])
+        output_image[0][0][0] = output_image[0][0][0] + 1
+        assert output_image[0][0][0] != output_image[1][0][0]
+
+    def test_rgb_image_is_unchanged(self):
+        image = make_image(dtype=torch.uint8, device="cpu", color_space="RGB")
+        assert_equal(image.shape[-3], 3)
+        assert_equal(F.grayscale_to_rgb(image), image)
+
+
 class TestRandomZoomOut:
     # Tests are light because this largely relies on the already tested `pad` kernels.
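
For context on test_expanded_channels_are_not_views_into_the_same_underlying_tensor above: torch.Tensor.expand broadcasts a size-1 channel dimension into three channels that alias the same storage, while repeat materialises independent copies. A minimal standalone sketch (not part of the commit) of the difference the test guards against:

import torch

gray = torch.zeros(1, 2, 2, dtype=torch.uint8)   # [1, H, W] grayscale image

aliased = gray.expand(3, -1, -1)   # three channels, all views of the same memory
copied = gray.repeat(3, 1, 1)      # three channels with independent storage

gray[0, 0, 0] = 5
print(aliased[2, 0, 0].item())     # 5 -- the write is visible in every aliased channel
print(copied[2, 0, 0].item())      # 0 -- the copies are unaffected

# The new kernel should behave like the copy: mutating one output channel must not
# change the others, which is exactly what the test asserts.
copied[0, 0, 0] += 1
assert copied[0, 0, 0] != copied[1, 0, 0]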

torchvision/transforms/v2/__init__.py (+1)

@@ -18,6 +18,7 @@
     RandomPhotometricDistort,
     RandomPosterize,
     RandomSolarize,
+    RGB,
 )
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
 from ._geometry import (
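
With RGB re-exported here from ._color, the new transform becomes reachable from the public v2 namespace, which is how the tests above construct it:

from torchvision.transforms import v2

to_rgb = v2.RGB()   # available thanks to the re-export added in this file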

torchvision/transforms/v2/_color.py (+14)

@@ -54,6 +54,20 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])


+class RGB(Transform):
+    """Convert images or videos to RGB (if they are not already RGB).
+
+    If the input is a :class:`torch.Tensor`, it is expected to have
+    [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return self._call_kernel(F.grayscale_to_rgb, inpt)
+
+
 class ColorJitter(Transform):
     """Randomly change the brightness, contrast, saturation and hue of an image or video.
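
A minimal usage sketch of the new transform (not part of the commit), assuming a single-channel PIL input; it composes with the existing v2 transforms like any other:

import PIL.Image
import torch
from torchvision.transforms import v2

pipeline = v2.Compose([
    v2.RGB(),                               # 1-channel inputs become 3-channel; RGB inputs pass through
    v2.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),
])

gray = PIL.Image.new("L", (8, 8))           # single-channel ("L" mode) image
out = pipeline(gray)
print(out.shape)                            # torch.Size([3, 8, 8])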

torchvision/transforms/v2/functional/__init__.py (+2)

@@ -63,6 +63,8 @@
     equalize,
     equalize_image,
     equalize_video,
+    grayscale_to_rgb,
+    grayscale_to_rgb_image,
     invert,
     invert_image,
     invert_video,

torchvision/transforms/v2/functional/_color.py (+26)

@@ -65,6 +65,32 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int
     return _FP.to_grayscale(image, num_output_channels=num_output_channels)


+def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
+    """See :class:`~torchvision.transforms.v2.RGB` for details."""
+    if torch.jit.is_scripting():
+        return grayscale_to_rgb_image(inpt)
+
+    _log_api_usage_once(grayscale_to_rgb)
+
+    kernel = _get_kernel(grayscale_to_rgb, type(inpt))
+    return kernel(inpt)
+
+
+@_register_kernel_internal(grayscale_to_rgb, torch.Tensor)
+@_register_kernel_internal(grayscale_to_rgb, tv_tensors.Image)
+def grayscale_to_rgb_image(image: torch.Tensor) -> torch.Tensor:
+    if image.shape[-3] >= 3:
+        # Image already has RGB channels. We don't need to do anything.
+        return image
+    # rgb_to_grayscale can be used to add channels so we reuse that function.
+    return _rgb_to_grayscale_image(image, num_output_channels=3, preserve_dtype=True)
+
+
+@_register_kernel_internal(grayscale_to_rgb, PIL.Image.Image)
+def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
+    return image.convert(mode="RGB")
+
+
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
     ratio = float(ratio)
     fp = image1.is_floating_point()
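
A short sketch (not part of the commit) of the new functional on the input types the registered kernels cover: a [1, H, W] tensor gains channels, a tensor that already has three channels is returned untouched, and PIL images go through Image.convert("RGB"):

import torch
import torchvision.transforms.v2.functional as F

gray = torch.randint(0, 256, (1, 4, 4), dtype=torch.uint8)   # [1, H, W] grayscale tensor
rgb = F.grayscale_to_rgb(gray)
print(rgb.shape)                                             # torch.Size([3, 4, 4])

already_rgb = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
assert torch.equal(F.grayscale_to_rgb(already_rgb), already_rgb)   # shape[-3] >= 3: no-op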
