Skip to content

Commit 0d69e35

Browse files
authored
Add tests for prototype <-> legacy transforms consistency (#6514)
* add consistency checks for prototype and legacy transforms * fix Resize
1 parent d11556a commit 0d69e35

File tree

2 files changed

+167
-1
lines changed

2 files changed

+167
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import itertools
2+
3+
import pytest
4+
import torch.testing
5+
from test_prototype_transforms_functional import make_images
6+
from torchvision import transforms as legacy_transforms
7+
from torchvision.prototype import features, transforms as prototype_transforms
8+
from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor
9+
10+
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
11+
12+
13+
class ArgsKwargs:
14+
def __init__(self, *args, **kwargs):
15+
self.args = args
16+
self.kwargs = kwargs
17+
18+
def __iter__(self):
19+
yield self.args
20+
yield self.kwargs
21+
22+
def __str__(self):
23+
return ", ".join(
24+
itertools.chain(
25+
[repr(arg) for arg in self.args],
26+
[f"{param}={repr(kwarg)}" for param, kwarg in self.kwargs.items()],
27+
)
28+
)
29+
30+
31+
class ConsistencyConfig:
32+
def __init__(
33+
self, prototype_cls, legacy_cls, transform_args_kwargs=None, make_images_kwargs=None, supports_pil=True
34+
):
35+
self.prototype_cls = prototype_cls
36+
self.legacy_cls = legacy_cls
37+
self.transform_args_kwargs = transform_args_kwargs or [((), dict())]
38+
self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
39+
self.supports_pil = supports_pil
40+
41+
def parametrization(self):
42+
return [
43+
pytest.param(
44+
self.prototype_cls,
45+
self.legacy_cls,
46+
args_kwargs,
47+
self.make_images_kwargs,
48+
self.supports_pil,
49+
id=f"{self.prototype_cls.__name__}({args_kwargs})",
50+
)
51+
for args_kwargs in self.transform_args_kwargs
52+
]
53+
54+
55+
CONSISTENCY_CONFIGS = [
56+
ConsistencyConfig(
57+
prototype_transforms.Normalize,
58+
legacy_transforms.Normalize,
59+
[
60+
ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
61+
],
62+
supports_pil=False,
63+
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
64+
),
65+
ConsistencyConfig(
66+
prototype_transforms.Resize,
67+
legacy_transforms.Resize,
68+
[
69+
ArgsKwargs(32),
70+
ArgsKwargs((32, 29)),
71+
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
72+
],
73+
),
74+
ConsistencyConfig(
75+
prototype_transforms.CenterCrop,
76+
legacy_transforms.CenterCrop,
77+
[
78+
ArgsKwargs(18),
79+
ArgsKwargs((18, 13)),
80+
],
81+
),
82+
]
83+
84+
85+
@pytest.mark.parametrize(
86+
("prototype_transform_cls", "legacy_transform_cls", "args_kwargs", "make_images_kwargs", "supports_pil"),
87+
itertools.chain.from_iterable(config.parametrization() for config in CONSISTENCY_CONFIGS),
88+
)
89+
def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, make_images_kwargs, supports_pil):
90+
args, kwargs = args_kwargs
91+
92+
try:
93+
legacy = legacy_transform_cls(*args, **kwargs)
94+
except Exception as exc:
95+
raise pytest.UsageError(
96+
f"Initializing the legacy transform failed with the error above. "
97+
f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
98+
) from exc
99+
100+
try:
101+
prototype = prototype_transform_cls(*args, **kwargs)
102+
except Exception as exc:
103+
raise AssertionError(
104+
"Initializing the prototype transform failed with the error above. "
105+
"This means there is a consistency bug in the constructor."
106+
) from exc
107+
108+
for image in make_images(**make_images_kwargs):
109+
image_tensor = torch.Tensor(image)
110+
image_pil = to_image_pil(image) if image.ndim == 3 and supports_pil else None
111+
112+
try:
113+
output_legacy_tensor = legacy(image_tensor)
114+
except Exception as exc:
115+
raise pytest.UsageError(
116+
f"Transforming a tensor image with shape {tuple(image.shape)} failed with the error above. "
117+
"This means that you need to specify the parameters passed to `make_images` through the "
118+
"`make_images_kwargs` of the `ConsistencyConfig`."
119+
) from exc
120+
121+
try:
122+
output_prototype_tensor = prototype(image_tensor)
123+
except Exception as exc:
124+
raise AssertionError(
125+
f"Transforming a tensor image with shape {tuple(image.shape)} failed with the error above. "
126+
f"This means there is a consistency bug either in `_get_params` "
127+
f"or in the `is_simple_tensor` path in `_transform`."
128+
) from exc
129+
130+
torch.testing.assert_close(
131+
output_prototype_tensor,
132+
output_legacy_tensor,
133+
atol=0,
134+
rtol=0,
135+
msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
136+
)
137+
138+
try:
139+
output_prototype_image = prototype(image)
140+
except Exception as exc:
141+
raise AssertionError(
142+
f"Transforming a feature image with shape {tuple(image.shape)} failed with the error above. "
143+
f"This means there is a consistency bug either in `_get_params` "
144+
f"or in the `features.Image` path in `_transform`."
145+
) from exc
146+
147+
torch.testing.assert_close(
148+
torch.Tensor(output_prototype_image),
149+
output_prototype_tensor,
150+
atol=0,
151+
rtol=0,
152+
msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
153+
)
154+
155+
if image_pil is not None:
156+
torch.testing.assert_close(
157+
to_image_tensor(prototype(image_pil)),
158+
to_image_tensor(legacy(image_pil)),
159+
atol=0,
160+
rtol=0,
161+
msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
162+
)

torchvision/prototype/transforms/_geometry.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@ def __init__(
4444
) -> None:
4545
super().__init__()
4646

47-
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
47+
self.size = (
48+
[size]
49+
if isinstance(size, int)
50+
else _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
51+
)
4852
self.interpolation = interpolation
4953
self.max_size = max_size
5054
self.antialias = antialias

0 commit comments

Comments
 (0)