|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Portions Copyright (c) Facebook, Inc. and its affiliates. |
| 3 | +# |
| 4 | +# This source code is licensed under the MIT license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# MIT License |
| 8 | +# |
| 9 | +# Copyright (c) 2018 Philip Popien |
| 10 | +# |
| 11 | +# Permission is hereby granted, free of charge, to any person obtaining a copy |
| 12 | +# of this software and associated documentation files (the "Software"), to deal |
| 13 | +# in the Software without restriction, including without limitation the rights |
| 14 | +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 15 | +# copies of the Software, and to permit persons to whom the Software is |
| 16 | +# furnished to do so, subject to the following conditions: |
| 17 | +# |
| 18 | +# The above copyright notice and this permission notice shall be included in all |
| 19 | +# copies or substantial portions of the Software. |
| 20 | +# |
| 21 | +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 22 | +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 23 | +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 24 | +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 25 | +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 26 | +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 27 | +# SOFTWARE. |
| 28 | + |
| 29 | +# Code modified from |
| 30 | +# https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py |
| 31 | + |
| 32 | +import random |
| 33 | +import random |
| 34 | +from enum import Enum, auto |
| 35 | +from functools import partial |
| 36 | +from typing import Any |
| 37 | +from typing import Tuple, Any, NamedTuple, Sequence, Callable |
| 38 | + |
| 39 | +import numpy as np |
| 40 | +from classy_vision.dataset.transforms import ClassyTransform, register_transform |
| 41 | +from PIL import Image, ImageEnhance, ImageOps |
| 42 | + |
| 43 | + |
| 44 | +MIDDLE_GRAY = (128, 128, 128) |
| 45 | + |
| 46 | + |
| 47 | +class ImageOp(Enum): |
| 48 | + SHEAR_X = auto() |
| 49 | + SHEAR_Y = auto() |
| 50 | + TRANSLATE_X = auto() |
| 51 | + TRANSLATE_Y = auto() |
| 52 | + ROTATE = auto() |
| 53 | + AUTO_CONTRAST = auto() |
| 54 | + INVERT = auto() |
| 55 | + EQUALIZE = auto() |
| 56 | + SOLARIZE = auto() |
| 57 | + POSTERIZE = auto() |
| 58 | + CONTRAST = auto() |
| 59 | + COLOR = auto() |
| 60 | + BRIGHTNESS = auto() |
| 61 | + SHARPNESS = auto() |
| 62 | + |
| 63 | + |
| 64 | +class ImageOpSetting(NamedTuple): |
| 65 | + ranges: Sequence |
| 66 | + function: Callable |
| 67 | + |
| 68 | + |
| 69 | +def shear_x(img: Any, magnitude: int, fillcolor: Any = None) -> Any: |
| 70 | + return img.transform( |
| 71 | + img.size, |
| 72 | + Image.AFFINE, |
| 73 | + (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), |
| 74 | + Image.BICUBIC, |
| 75 | + fillcolor=fillcolor, |
| 76 | + ) |
| 77 | + |
| 78 | + |
| 79 | +def shear_y(img: Any, magnitude: int, fillcolor: Any = None) -> Any: |
| 80 | + return img.transform( |
| 81 | + img.size, |
| 82 | + Image.AFFINE, |
| 83 | + (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), |
| 84 | + Image.BICUBIC, |
| 85 | + fillcolor=fillcolor, |
| 86 | + ) |
| 87 | + |
| 88 | + |
| 89 | +def translate_x(img: Any, magnitude: int, fillcolor: Any = None) -> Any: |
| 90 | + return img.transform( |
| 91 | + img.size, |
| 92 | + Image.AFFINE, |
| 93 | + (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), |
| 94 | + fillcolor=fillcolor, |
| 95 | + ) |
| 96 | + |
| 97 | + |
| 98 | +def translate_y(img: Any, magnitude: int, fillcolor: Any = None) -> Any: |
| 99 | + return img.transform( |
| 100 | + img.size, |
| 101 | + Image.AFFINE, |
| 102 | + (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), |
| 103 | + fillcolor=fillcolor, |
| 104 | + ) |
| 105 | + |
| 106 | + |
| 107 | +# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand # noqa |
| 108 | +def rotate_with_fill(img: Any, magnitude: int) -> Any: |
| 109 | + rot = img.convert("RGBA").rotate(magnitude) |
| 110 | + return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert( |
| 111 | + img.mode |
| 112 | + ) |
| 113 | + |
| 114 | + |
| 115 | +def color(img: Any, magnitude: int) -> Any: |
| 116 | + return ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])) |
| 117 | + |
| 118 | + |
| 119 | +def posterize(img: Any, magnitude: int) -> Any: |
| 120 | + return ImageOps.posterize(img, magnitude) |
| 121 | + |
| 122 | + |
| 123 | +def solarize(img: Any, magnitude: int) -> Any: |
| 124 | + return ImageOps.solarize(img, magnitude) |
| 125 | + |
| 126 | + |
| 127 | +def contrast(img: Any, magnitude: int) -> Any: |
| 128 | + return ImageEnhance.Contrast(img).enhance(1 + magnitude * random.choice([-1, 1])) |
| 129 | + |
| 130 | + |
| 131 | +def sharpness(img: Any, magnitude: int) -> Any: |
| 132 | + return ImageEnhance.Sharpness(img).enhance(1 + magnitude * random.choice([-1, 1])) |
| 133 | + |
| 134 | + |
| 135 | +def brightness(img: Any, magnitude: int) -> Any: |
| 136 | + return ImageEnhance.Brightness(img).enhance(1 + magnitude * random.choice([-1, 1])) |
| 137 | + |
| 138 | + |
| 139 | +def auto_contrast(img: Any, magnitude: int) -> Any: |
| 140 | + return ImageOps.autocontrast(img) |
| 141 | + |
| 142 | + |
| 143 | +def equalize(img: Any, magnitude: int) -> Any: |
| 144 | + return ImageOps.equalize(img) |
| 145 | + |
| 146 | + |
| 147 | +def invert(img: Any, magnitude: int) -> Any: |
| 148 | + return ImageOps.invert(img) |
| 149 | + |
| 150 | + |
| 151 | +def get_image_op_settings( |
| 152 | + image_op: ImageOp, fillcolor: Tuple[int, int, int] = MIDDLE_GRAY |
| 153 | +): |
| 154 | + return { |
| 155 | + ImageOp.SHEAR_X: ImageOpSetting( |
| 156 | + np.linspace(0, 0.3, 10), partial(shear_x, fillcolor=fillcolor) |
| 157 | + ), |
| 158 | + ImageOp.SHEAR_Y: ImageOpSetting( |
| 159 | + np.linspace(0, 0.3, 10), partial(shear_y, fillcolor=fillcolor) |
| 160 | + ), |
| 161 | + ImageOp.TRANSLATE_X: ImageOpSetting( |
| 162 | + np.linspace(0, 150 / 331, 10), partial(translate_x, fillcolor=fillcolor) |
| 163 | + ), |
| 164 | + ImageOp.TRANSLATE_Y: ImageOpSetting( |
| 165 | + np.linspace(0, 150 / 331, 10), partial(translate_y, fillcolor=fillcolor) |
| 166 | + ), |
| 167 | + ImageOp.ROTATE: ImageOpSetting(np.linspace(0, 30, 10), rotate_with_fill), |
| 168 | + ImageOp.COLOR: ImageOpSetting(np.linspace(0.0, 0.9, 10), color), |
| 169 | + ImageOp.POSTERIZE: ImageOpSetting( |
| 170 | + np.round(np.linspace(8, 4, 10), 0).astype(np.int), posterize |
| 171 | + ), |
| 172 | + ImageOp.SOLARIZE: ImageOpSetting(np.linspace(256, 0, 10), solarize), |
| 173 | + ImageOp.CONTRAST: ImageOpSetting(np.linspace(0.0, 0.9, 10), contrast), |
| 174 | + ImageOp.SHARPNESS: ImageOpSetting(np.linspace(0.0, 0.9, 10), sharpness), |
| 175 | + ImageOp.BRIGHTNESS: ImageOpSetting(np.linspace(0.0, 0.9, 10), brightness), |
| 176 | + ImageOp.AUTO_CONTRAST: ImageOpSetting([0] * 10, auto_contrast), |
| 177 | + ImageOp.EQUALIZE: ImageOpSetting([0] * 10, equalize), |
| 178 | + ImageOp.INVERT: ImageOpSetting([0] * 10, invert), |
| 179 | + }[image_op] |
| 180 | + |
| 181 | + |
| 182 | +class SubPolicy: |
| 183 | + def __init__( |
| 184 | + self, |
| 185 | + operation1: ImageOp, |
| 186 | + magnitude_idx1: int, |
| 187 | + p1: float, |
| 188 | + operation2: ImageOp, |
| 189 | + magnitude_idx2: int, |
| 190 | + p2: float, |
| 191 | + fillcolor: Tuple[int, int, int] = MIDDLE_GRAY, |
| 192 | + ) -> None: |
| 193 | + operation1_settings = get_image_op_settings(operation1, fillcolor) |
| 194 | + self.operation1 = operation1_settings.function |
| 195 | + self.magnitude1 = operation1_settings.ranges[magnitude_idx1] |
| 196 | + self.p1 = p1 |
| 197 | + |
| 198 | + operation2_settings = get_image_op_settings(operation2, fillcolor) |
| 199 | + self.operation2 = operation2_settings.function |
| 200 | + self.magnitude2 = operation2_settings.ranges[magnitude_idx2] |
| 201 | + self.p2 = p2 |
| 202 | + |
| 203 | + def __call__(self, img: Any) -> Any: |
| 204 | + if random.random() < self.p1: |
| 205 | + img = self.operation1(img, self.magnitude1) |
| 206 | + if random.random() < self.p2: |
| 207 | + img = self.operation2(img, self.magnitude2) |
| 208 | + return img |
| 209 | + |
| 210 | + |
| 211 | +@register_transform("imagenet_autoaugment") |
| 212 | +class ImagenetAutoAugment(ClassyTransform): |
| 213 | + """Randomly choose one of the best 24 Sub-policies on ImageNet. |
| 214 | +
|
| 215 | + Example: |
| 216 | + >>> policy = ImageNetPolicy() |
| 217 | + >>> transformed = policy(image) |
| 218 | +
|
| 219 | + Example as a PyTorch Transform: |
| 220 | + >>> transform=transforms.Compose([ |
| 221 | + >>> transforms.Resize(256), |
| 222 | + >>> ImageNetPolicy(), |
| 223 | + >>> transforms.ToTensor()]) |
| 224 | + """ |
| 225 | + |
| 226 | + def __init__(self, fillcolor: Tuple[int, int, int] = MIDDLE_GRAY) -> None: |
| 227 | + self.policies = [ |
| 228 | + SubPolicy(ImageOp.POSTERIZE, 8, 0.4, ImageOp.ROTATE, 9, 0.6, fillcolor), |
| 229 | + SubPolicy( |
| 230 | + ImageOp.SOLARIZE, 5, 0.6, ImageOp.AUTO_CONTRAST, 5, 0.6, fillcolor |
| 231 | + ), |
| 232 | + SubPolicy(ImageOp.EQUALIZE, 8, 0.8, ImageOp.EQUALIZE, 3, 0.6, fillcolor), |
| 233 | + SubPolicy(ImageOp.POSTERIZE, 7, 0.6, ImageOp.POSTERIZE, 6, 0.6, fillcolor), |
| 234 | + SubPolicy(ImageOp.EQUALIZE, 7, 0.4, ImageOp.SOLARIZE, 4, 0.2, fillcolor), |
| 235 | + SubPolicy(ImageOp.EQUALIZE, 4, 0.4, ImageOp.ROTATE, 8, 0.8, fillcolor), |
| 236 | + SubPolicy(ImageOp.SOLARIZE, 3, 0.6, ImageOp.EQUALIZE, 7, 0.6, fillcolor), |
| 237 | + SubPolicy(ImageOp.POSTERIZE, 5, 0.8, ImageOp.EQUALIZE, 2, 1.0, fillcolor), |
| 238 | + SubPolicy(ImageOp.ROTATE, 3, 0.2, ImageOp.SOLARIZE, 8, 0.6, fillcolor), |
| 239 | + SubPolicy(ImageOp.EQUALIZE, 8, 0.6, ImageOp.POSTERIZE, 6, 0.4, fillcolor), |
| 240 | + SubPolicy(ImageOp.ROTATE, 8, 0.8, ImageOp.COLOR, 0, 0.4, fillcolor), |
| 241 | + SubPolicy(ImageOp.ROTATE, 9, 0.4, ImageOp.EQUALIZE, 2, 0.6, fillcolor), |
| 242 | + SubPolicy(ImageOp.EQUALIZE, 7, 0.0, ImageOp.EQUALIZE, 8, 0.8, fillcolor), |
| 243 | + SubPolicy(ImageOp.INVERT, 4, 0.6, ImageOp.EQUALIZE, 8, 1.0, fillcolor), |
| 244 | + SubPolicy(ImageOp.COLOR, 4, 0.6, ImageOp.CONTRAST, 8, 1.0, fillcolor), |
| 245 | + SubPolicy(ImageOp.ROTATE, 8, 0.8, ImageOp.COLOR, 2, 1.0, fillcolor), |
| 246 | + SubPolicy(ImageOp.COLOR, 8, 0.8, ImageOp.SOLARIZE, 7, 0.8, fillcolor), |
| 247 | + SubPolicy(ImageOp.SHARPNESS, 7, 0.4, ImageOp.INVERT, 8, 0.6, fillcolor), |
| 248 | + SubPolicy(ImageOp.SHEAR_X, 5, 0.6, ImageOp.EQUALIZE, 9, 1.0, fillcolor), |
| 249 | + SubPolicy(ImageOp.COLOR, 0, 0.4, ImageOp.EQUALIZE, 3, 0.6, fillcolor), |
| 250 | + SubPolicy(ImageOp.EQUALIZE, 7, 0.4, ImageOp.SOLARIZE, 4, 0.2, fillcolor), |
| 251 | + SubPolicy( |
| 252 | + ImageOp.SOLARIZE, 5, 0.6, ImageOp.AUTO_CONTRAST, 5, 0.6, fillcolor |
| 253 | + ), |
| 254 | + SubPolicy(ImageOp.INVERT, 4, 0.6, ImageOp.EQUALIZE, 8, 1.0, fillcolor), |
| 255 | + SubPolicy(ImageOp.COLOR, 4, 0.6, ImageOp.CONTRAST, 8, 1.0, fillcolor), |
| 256 | + ] |
| 257 | + |
| 258 | + def __call__(self, img: Any) -> Any: |
| 259 | + policy_idx = random.randint(0, len(self.policies) - 1) |
| 260 | + return self.policies[policy_idx](img) |
0 commit comments