Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 7afbd9e

Browse files
mannatsinghfacebook-github-bot
authored andcommitted
Open source AutoAugment (#655)
Summary: Pull Request resolved: #655 Open source the autoaugment implementation Reviewed By: vreis Differential Revision: D24940146 fbshipit-source-id: 93bf8d61afcdc1c623a697776efcab52971984e3
1 parent 8bc1903 commit 7afbd9e

File tree

3 files changed

+319
-0
lines changed

3 files changed

+319
-0
lines changed

NOTICE

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
=======================================================================
2+
AutoAugment's MIT license
3+
=======================================================================
4+
We modified and utilize the AutoAugment implementation from
5+
https://github.com/DeepVoltaire/AutoAugment. The license is as follows:
6+
7+
MIT License
8+
9+
Copyright (c) 2018 Philip Popien
10+
11+
Permission is hereby granted, free of charge, to any person obtaining a copy
12+
of this software and associated documentation files (the "Software"), to deal
13+
in the Software without restriction, including without limitation the rights
14+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15+
copies of the Software, and to permit persons to whom the Software is
16+
furnished to do so, subject to the following conditions:
17+
18+
The above copyright notice and this permission notice shall be included in all
19+
copies or substantial portions of the Software.
20+
21+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27+
SOFTWARE.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
#!/usr/bin/env python3
2+
# Portions Copyright (c) Facebook, Inc. and its affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# MIT License
8+
#
9+
# Copyright (c) 2018 Philip Popien
10+
#
11+
# Permission is hereby granted, free of charge, to any person obtaining a copy
12+
# of this software and associated documentation files (the "Software"), to deal
13+
# in the Software without restriction, including without limitation the rights
14+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15+
# copies of the Software, and to permit persons to whom the Software is
16+
# furnished to do so, subject to the following conditions:
17+
#
18+
# The above copyright notice and this permission notice shall be included in all
19+
# copies or substantial portions of the Software.
20+
#
21+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27+
# SOFTWARE.
28+
29+
# Code modified from
30+
# https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
31+
32+
import random
33+
import random
34+
from enum import Enum, auto
35+
from functools import partial
36+
from typing import Any
37+
from typing import Tuple, Any, NamedTuple, Sequence, Callable
38+
39+
import numpy as np
40+
from classy_vision.dataset.transforms import ClassyTransform, register_transform
41+
from PIL import Image, ImageEnhance, ImageOps
42+
43+
44+
MIDDLE_GRAY = (128, 128, 128)
45+
46+
47+
class ImageOp(Enum):
48+
SHEAR_X = auto()
49+
SHEAR_Y = auto()
50+
TRANSLATE_X = auto()
51+
TRANSLATE_Y = auto()
52+
ROTATE = auto()
53+
AUTO_CONTRAST = auto()
54+
INVERT = auto()
55+
EQUALIZE = auto()
56+
SOLARIZE = auto()
57+
POSTERIZE = auto()
58+
CONTRAST = auto()
59+
COLOR = auto()
60+
BRIGHTNESS = auto()
61+
SHARPNESS = auto()
62+
63+
64+
class ImageOpSetting(NamedTuple):
65+
ranges: Sequence
66+
function: Callable
67+
68+
69+
def shear_x(img: Any, magnitude: int, fillcolor: Any = None) -> Any:
70+
return img.transform(
71+
img.size,
72+
Image.AFFINE,
73+
(1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
74+
Image.BICUBIC,
75+
fillcolor=fillcolor,
76+
)
77+
78+
79+
def shear_y(img: Any, magnitude: int, fillcolor: Any = None) -> Any:
80+
return img.transform(
81+
img.size,
82+
Image.AFFINE,
83+
(1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
84+
Image.BICUBIC,
85+
fillcolor=fillcolor,
86+
)
87+
88+
89+
def translate_x(img: Any, magnitude: int, fillcolor: Any = None) -> Any:
90+
return img.transform(
91+
img.size,
92+
Image.AFFINE,
93+
(1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
94+
fillcolor=fillcolor,
95+
)
96+
97+
98+
def translate_y(img: Any, magnitude: int, fillcolor: Any = None) -> Any:
99+
return img.transform(
100+
img.size,
101+
Image.AFFINE,
102+
(1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
103+
fillcolor=fillcolor,
104+
)
105+
106+
107+
# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand # noqa
108+
def rotate_with_fill(img: Any, magnitude: int) -> Any:
109+
rot = img.convert("RGBA").rotate(magnitude)
110+
return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(
111+
img.mode
112+
)
113+
114+
115+
def color(img: Any, magnitude: int) -> Any:
116+
return ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1]))
117+
118+
119+
def posterize(img: Any, magnitude: int) -> Any:
120+
return ImageOps.posterize(img, magnitude)
121+
122+
123+
def solarize(img: Any, magnitude: int) -> Any:
124+
return ImageOps.solarize(img, magnitude)
125+
126+
127+
def contrast(img: Any, magnitude: int) -> Any:
128+
return ImageEnhance.Contrast(img).enhance(1 + magnitude * random.choice([-1, 1]))
129+
130+
131+
def sharpness(img: Any, magnitude: int) -> Any:
132+
return ImageEnhance.Sharpness(img).enhance(1 + magnitude * random.choice([-1, 1]))
133+
134+
135+
def brightness(img: Any, magnitude: int) -> Any:
136+
return ImageEnhance.Brightness(img).enhance(1 + magnitude * random.choice([-1, 1]))
137+
138+
139+
def auto_contrast(img: Any, magnitude: int) -> Any:
140+
return ImageOps.autocontrast(img)
141+
142+
143+
def equalize(img: Any, magnitude: int) -> Any:
144+
return ImageOps.equalize(img)
145+
146+
147+
def invert(img: Any, magnitude: int) -> Any:
148+
return ImageOps.invert(img)
149+
150+
151+
def get_image_op_settings(
152+
image_op: ImageOp, fillcolor: Tuple[int, int, int] = MIDDLE_GRAY
153+
):
154+
return {
155+
ImageOp.SHEAR_X: ImageOpSetting(
156+
np.linspace(0, 0.3, 10), partial(shear_x, fillcolor=fillcolor)
157+
),
158+
ImageOp.SHEAR_Y: ImageOpSetting(
159+
np.linspace(0, 0.3, 10), partial(shear_y, fillcolor=fillcolor)
160+
),
161+
ImageOp.TRANSLATE_X: ImageOpSetting(
162+
np.linspace(0, 150 / 331, 10), partial(translate_x, fillcolor=fillcolor)
163+
),
164+
ImageOp.TRANSLATE_Y: ImageOpSetting(
165+
np.linspace(0, 150 / 331, 10), partial(translate_y, fillcolor=fillcolor)
166+
),
167+
ImageOp.ROTATE: ImageOpSetting(np.linspace(0, 30, 10), rotate_with_fill),
168+
ImageOp.COLOR: ImageOpSetting(np.linspace(0.0, 0.9, 10), color),
169+
ImageOp.POSTERIZE: ImageOpSetting(
170+
np.round(np.linspace(8, 4, 10), 0).astype(np.int), posterize
171+
),
172+
ImageOp.SOLARIZE: ImageOpSetting(np.linspace(256, 0, 10), solarize),
173+
ImageOp.CONTRAST: ImageOpSetting(np.linspace(0.0, 0.9, 10), contrast),
174+
ImageOp.SHARPNESS: ImageOpSetting(np.linspace(0.0, 0.9, 10), sharpness),
175+
ImageOp.BRIGHTNESS: ImageOpSetting(np.linspace(0.0, 0.9, 10), brightness),
176+
ImageOp.AUTO_CONTRAST: ImageOpSetting([0] * 10, auto_contrast),
177+
ImageOp.EQUALIZE: ImageOpSetting([0] * 10, equalize),
178+
ImageOp.INVERT: ImageOpSetting([0] * 10, invert),
179+
}[image_op]
180+
181+
182+
class SubPolicy:
183+
def __init__(
184+
self,
185+
operation1: ImageOp,
186+
magnitude_idx1: int,
187+
p1: float,
188+
operation2: ImageOp,
189+
magnitude_idx2: int,
190+
p2: float,
191+
fillcolor: Tuple[int, int, int] = MIDDLE_GRAY,
192+
) -> None:
193+
operation1_settings = get_image_op_settings(operation1, fillcolor)
194+
self.operation1 = operation1_settings.function
195+
self.magnitude1 = operation1_settings.ranges[magnitude_idx1]
196+
self.p1 = p1
197+
198+
operation2_settings = get_image_op_settings(operation2, fillcolor)
199+
self.operation2 = operation2_settings.function
200+
self.magnitude2 = operation2_settings.ranges[magnitude_idx2]
201+
self.p2 = p2
202+
203+
def __call__(self, img: Any) -> Any:
204+
if random.random() < self.p1:
205+
img = self.operation1(img, self.magnitude1)
206+
if random.random() < self.p2:
207+
img = self.operation2(img, self.magnitude2)
208+
return img
209+
210+
211+
@register_transform("imagenet_autoaugment")
212+
class ImagenetAutoAugment(ClassyTransform):
213+
"""Randomly choose one of the best 24 Sub-policies on ImageNet.
214+
215+
Example:
216+
>>> policy = ImageNetPolicy()
217+
>>> transformed = policy(image)
218+
219+
Example as a PyTorch Transform:
220+
>>> transform=transforms.Compose([
221+
>>> transforms.Resize(256),
222+
>>> ImageNetPolicy(),
223+
>>> transforms.ToTensor()])
224+
"""
225+
226+
def __init__(self, fillcolor: Tuple[int, int, int] = MIDDLE_GRAY) -> None:
227+
self.policies = [
228+
SubPolicy(ImageOp.POSTERIZE, 8, 0.4, ImageOp.ROTATE, 9, 0.6, fillcolor),
229+
SubPolicy(
230+
ImageOp.SOLARIZE, 5, 0.6, ImageOp.AUTO_CONTRAST, 5, 0.6, fillcolor
231+
),
232+
SubPolicy(ImageOp.EQUALIZE, 8, 0.8, ImageOp.EQUALIZE, 3, 0.6, fillcolor),
233+
SubPolicy(ImageOp.POSTERIZE, 7, 0.6, ImageOp.POSTERIZE, 6, 0.6, fillcolor),
234+
SubPolicy(ImageOp.EQUALIZE, 7, 0.4, ImageOp.SOLARIZE, 4, 0.2, fillcolor),
235+
SubPolicy(ImageOp.EQUALIZE, 4, 0.4, ImageOp.ROTATE, 8, 0.8, fillcolor),
236+
SubPolicy(ImageOp.SOLARIZE, 3, 0.6, ImageOp.EQUALIZE, 7, 0.6, fillcolor),
237+
SubPolicy(ImageOp.POSTERIZE, 5, 0.8, ImageOp.EQUALIZE, 2, 1.0, fillcolor),
238+
SubPolicy(ImageOp.ROTATE, 3, 0.2, ImageOp.SOLARIZE, 8, 0.6, fillcolor),
239+
SubPolicy(ImageOp.EQUALIZE, 8, 0.6, ImageOp.POSTERIZE, 6, 0.4, fillcolor),
240+
SubPolicy(ImageOp.ROTATE, 8, 0.8, ImageOp.COLOR, 0, 0.4, fillcolor),
241+
SubPolicy(ImageOp.ROTATE, 9, 0.4, ImageOp.EQUALIZE, 2, 0.6, fillcolor),
242+
SubPolicy(ImageOp.EQUALIZE, 7, 0.0, ImageOp.EQUALIZE, 8, 0.8, fillcolor),
243+
SubPolicy(ImageOp.INVERT, 4, 0.6, ImageOp.EQUALIZE, 8, 1.0, fillcolor),
244+
SubPolicy(ImageOp.COLOR, 4, 0.6, ImageOp.CONTRAST, 8, 1.0, fillcolor),
245+
SubPolicy(ImageOp.ROTATE, 8, 0.8, ImageOp.COLOR, 2, 1.0, fillcolor),
246+
SubPolicy(ImageOp.COLOR, 8, 0.8, ImageOp.SOLARIZE, 7, 0.8, fillcolor),
247+
SubPolicy(ImageOp.SHARPNESS, 7, 0.4, ImageOp.INVERT, 8, 0.6, fillcolor),
248+
SubPolicy(ImageOp.SHEAR_X, 5, 0.6, ImageOp.EQUALIZE, 9, 1.0, fillcolor),
249+
SubPolicy(ImageOp.COLOR, 0, 0.4, ImageOp.EQUALIZE, 3, 0.6, fillcolor),
250+
SubPolicy(ImageOp.EQUALIZE, 7, 0.4, ImageOp.SOLARIZE, 4, 0.2, fillcolor),
251+
SubPolicy(
252+
ImageOp.SOLARIZE, 5, 0.6, ImageOp.AUTO_CONTRAST, 5, 0.6, fillcolor
253+
),
254+
SubPolicy(ImageOp.INVERT, 4, 0.6, ImageOp.EQUALIZE, 8, 1.0, fillcolor),
255+
SubPolicy(ImageOp.COLOR, 4, 0.6, ImageOp.CONTRAST, 8, 1.0, fillcolor),
256+
]
257+
258+
def __call__(self, img: Any) -> Any:
259+
policy_idx = random.randint(0, len(self.policies) - 1)
260+
return self.policies[policy_idx](img)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from classy_vision.dataset.core.random_image_datasets import (
10+
RandomImageBinaryClassDataset,
11+
)
12+
from classy_vision.dataset.transforms.autoaugment import ImagenetAutoAugment # noqa
13+
from classy_vision.dataset.transforms.util import build_field_transform_default_imagenet
14+
15+
16+
class AutoaugmentTransformTest(unittest.TestCase):
17+
def get_test_image_dataset(self):
18+
return RandomImageBinaryClassDataset(
19+
crop_size=224, class_ratio=0.5, num_samples=100, seed=0
20+
)
21+
22+
def test_imagenet_autoaugment_transform_no_errors(self):
23+
"""
24+
Tests that the imagenet autoaugment transform runs without any errors.
25+
"""
26+
dataset = self.get_test_image_dataset()
27+
28+
config = [{"name": "imagenet_autoaugment"}]
29+
transform = build_field_transform_default_imagenet(config)
30+
sample = dataset[0]
31+
# test that imagenet autoaugment has been registered and runs without errors
32+
transform(sample)

0 commit comments

Comments
 (0)