Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit 6742323

Browse files
Borda, krshrimali, ethanwharris
authored
refactoring Img Segm augmentation with albumentations (#1313)
Co-authored-by: Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com> Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
1 parent c05a3ea commit 6742323

18 files changed

Lines changed: 131 additions & 186 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4444

4545
- Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
4646

47+
- Changed the transforms in `SemanticSegmentationData` to use albumentations instead of Kornia ([#1313](https://github.com/PyTorchLightning/lightning-flash/pull/1313))
48+
4749
### Deprecated
4850

4951
### Removed

docs/source/api/data.rst

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,6 @@ __________________________
119119
:template: classtemplate.rst
120120

121121
~flash.core.data.transforms.ApplyToKeys
122-
~flash.core.data.transforms.KorniaParallelTransforms
123-
124-
.. autosummary::
125-
:toctree: generated/
126-
:nosignatures:
127-
128-
~flash.core.data.transforms.kornia_collate
129122

130123
flash.core.data.utils
131124
_____________________

flash/core/data/transforms.py

Lines changed: 42 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,51 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any, Dict, Mapping, Sequence, Union
14+
from typing import Any, Mapping, Sequence, Union
1515

16-
import torch
16+
import numpy as np
1717
from torch import nn
1818

19-
from flash.core.data.utilities.collate import default_collate
19+
from flash.core.data.io.input import DataKeys
2020
from flash.core.data.utils import convert_to_modules
21+
from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, requires
22+
23+
if _ALBUMENTATIONS_AVAILABLE:
24+
from albumentations import BasicTransform, Compose
25+
else:
26+
BasicTransform, Compose = object, object
27+
28+
29+
class AlbumentationsAdapter(nn.Module):
    """Adapter that lets albumentations transforms be applied to Flash samples.

    The given transform(s) are wrapped in a single albumentations ``Compose``. When called with a
    ``dict`` sample, the entries whose keys appear in the mapping are converted to numpy arrays,
    passed to the composed transform in one call under their albumentations argument names, and
    the results are written back into the same ``dict`` under the original keys. Entries whose keys
    are not in the mapping are left untouched. Any non-``dict`` input is treated as a single image.

    Args:
        transform: A single albumentations transform, or a list / tuple of them.
        mapping: Mapping from albumentations argument names to sample keys. When ``None`` or
            empty, ``TRANSFORM_INPUT_MAPPING`` is used.
    """

    # mapping from albumentations to Flash
    TRANSFORM_INPUT_MAPPING = {"image": DataKeys.INPUT, "mask": DataKeys.TARGET}

    @requires("albumentations")
    def __init__(
        self,
        transform: Union[BasicTransform, Sequence[BasicTransform]],
        mapping: Union[dict, None] = None,
    ):
        super().__init__()
        if not isinstance(transform, (list, tuple)):
            transform = [transform]
        self.transform = Compose(list(transform))
        if not mapping:
            mapping = self.TRANSFORM_INPUT_MAPPING
        # albumentations argument name -> sample key (used to write results back)
        self._mapping_rev = mapping
        # sample key -> albumentations argument name (used to build the call arguments)
        self._mapping = {v: k for k, v in mapping.items()}

    def forward(self, x: Any) -> Any:
        """Apply the composed transform to ``x``.

        Args:
            x: Either a ``dict`` sample or a single image.

        Returns:
            The same ``dict`` object updated in place with the transformed entries, or the
            transformed image when ``x`` is not a ``dict``.
        """
        if isinstance(x, dict):
            # Only mapped keys are forwarded; everything else stays in the sample as it is.
            x_ = {self._mapping[key]: np.array(value) for key, value in x.items() if key in self._mapping}
            x_ = self.transform(**x_)
            x.update({key: x_[name] for name, key in self._mapping_rev.items() if name in x_})
            return x
        # Convert bare inputs to numpy as well, so both branches hand the transform the same kind of data.
        return self.transform(image=np.array(x))["image"]
2159

2260

2361
class ApplyToKeys(nn.Sequential):
@@ -49,9 +87,7 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
4987
try:
5088
outputs = super().forward(inputs)
5189
except TypeError as e:
52-
raise Exception(
53-
"Failed to apply transforms to multiple keys at the same time, try using KorniaParallelTransforms."
54-
) from e
90+
raise Exception("Failed to apply transforms to multiple keys at the same time.") from e
5591

5692
for i, key in enumerate(keys):
5793
result[key] = outputs[i]
@@ -66,54 +102,3 @@ def __repr__(self):
66102
transform = transform[0] if len(transform) == 1 else transform
67103

68104
return f"{self.__class__.__name__}(keys={repr(keys)}, transform={repr(transform)})"
69-
70-
71-
class KorniaParallelTransforms(nn.Sequential):
72-
"""The ``KorniaParallelTransforms`` class is an ``nn.Sequential`` which will apply the given transforms to each
73-
input (to ``.forward``) in parallel, whilst sharing the random state (``._params``). This should be used when
74-
multiple elements need to be augmented in the same way (e.g. an image and corresponding segmentation mask).
75-
76-
Args:
77-
args: The transforms, passed to the ``nn.Sequential`` super constructor.
78-
"""
79-
80-
def __init__(self, *args):
81-
super().__init__(*(convert_to_modules(arg) for arg in args))
82-
83-
def forward(self, inputs: Any):
84-
result = list(inputs) if isinstance(inputs, Sequence) else [inputs]
85-
for transform in self.children():
86-
inputs = result
87-
88-
# we enforce the first time to sample random params
89-
result[0] = transform(inputs[0])
90-
91-
if hasattr(transform, "_params") and bool(transform._params):
92-
params = transform._params
93-
else:
94-
params = None
95-
96-
# apply transforms from (1, n)
97-
for i, input in enumerate(inputs[1:]):
98-
if params is not None:
99-
result[i + 1] = transform(input, params)
100-
else: # case for non-random transforms
101-
result[i + 1] = transform(input)
102-
if hasattr(transform, "_params") and bool(transform._params):
103-
transform._params = None
104-
return result
105-
106-
107-
def kornia_collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
108-
"""Kornia transforms add batch dimension which need to be removed.
109-
110-
This function removes that dimension and then
111-
applies ``torch.utils.data._utils.collate.default_collate``.
112-
"""
113-
if len(samples) == 1 and isinstance(samples[0], list):
114-
samples = samples[0]
115-
for sample in samples:
116-
for key in sample.keys():
117-
if torch.is_tensor(sample[key]) and sample[key].ndim == 4:
118-
sample[key] = sample[key].squeeze(0)
119-
return default_collate(samples)

flash/core/data/utilities/collate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020

2121

2222
def _wrap_collate(collate: Callable, batch: List[Any]) -> Any:
23+
# Needed for learn2learn integration
24+
if len(batch) == 1 and isinstance(batch[0], list):
25+
batch = batch[0]
26+
2327
metadata = [sample.pop(DataKeys.METADATA, None) if isinstance(sample, Mapping) else None for sample in batch]
2428
metadata = metadata if any(m is not None for m in metadata) else None
2529

flash/core/utilities/imports.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ class Image:
156156
_TORCHVISION_AVAILABLE,
157157
_TIMM_AVAILABLE,
158158
_PIL_AVAILABLE,
159-
_KORNIA_AVAILABLE,
159+
_ALBUMENTATIONS_AVAILABLE,
160160
_PYSTICHE_AVAILABLE,
161161
_SEGMENTATION_MODELS_AVAILABLE,
162162
]

flash/image/classification/input_transform.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from dataclasses import dataclass
15-
from typing import Callable, Tuple, Union
15+
from typing import Tuple, Union
1616

1717
import torch
1818
from torch import nn
1919

2020
from flash.core.data.io.input import DataKeys
2121
from flash.core.data.io.input_transform import InputTransform
22-
from flash.core.data.transforms import ApplyToKeys, kornia_collate
22+
from flash.core.data.transforms import ApplyToKeys
2323
from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, _TORCHVISION_AVAILABLE, requires
2424

2525
if _TORCHVISION_AVAILABLE:
@@ -76,7 +76,3 @@ def train_per_sample_transform(self):
7676
ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
7777
]
7878
)
79-
80-
def collate(self) -> Callable:
81-
# TODO: Remove kornia collate for default_collate
82-
return kornia_collate

flash/image/classification/integrations/learn2learn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121

2222
import pytorch_lightning as pl
2323
from torch.utils.data import IterableDataset
24-
from torch.utils.data._utils.collate import default_collate
2524
from torch.utils.data._utils.worker import get_worker_info
2625

26+
from flash.core.data.utilities.collate import default_collate
2727
from flash.core.utilities.imports import requires
2828

2929

@@ -109,7 +109,6 @@ def __init__(
109109
self.epoch_length = epoch_length
110110
self.seed = seed
111111
self.iteration = 0
112-
self.iteration = 0
113112
self.requires_divisible = requires_divisible
114113
self.counter = 0
115114

flash/image/instance_segmentation/data.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
from flash.core.data.utilities.sort import sorted_alphanumeric
2323
from flash.core.integrations.icevision.data import IceVisionInput
2424
from flash.core.integrations.icevision.transforms import IceVisionInputTransform
25-
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_EXTRAS_TESTING, _KORNIA_AVAILABLE
25+
from flash.core.utilities.imports import (
26+
_ICEVISION_AVAILABLE,
27+
_IMAGE_EXTRAS_TESTING,
28+
_TORCHVISION_AVAILABLE,
29+
_TORCHVISION_GREATER_EQUAL_0_9,
30+
)
2631
from flash.core.utilities.stages import RunningStage
2732
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
2833

@@ -34,8 +39,15 @@
3439
VOCMaskParser = object
3540
Parser = object
3641

37-
if _KORNIA_AVAILABLE:
38-
import kornia as K
42+
if _TORCHVISION_AVAILABLE:
43+
from torchvision import transforms as T
44+
45+
if _TORCHVISION_GREATER_EQUAL_0_9:
46+
from torchvision.transforms import InterpolationMode
47+
else:
48+
49+
class InterpolationMode:
50+
NEAREST = "nearest"
3951

4052

4153
# Skip doctests if requirements aren't available
@@ -45,7 +57,7 @@
4557

4658
class InstanceSegmentationOutputTransform(OutputTransform):
4759
def per_sample_transform(self, sample: Any) -> Any:
48-
resize = K.geometry.Resize(sample[DataKeys.METADATA]["size"], interpolation="nearest")
60+
resize = T.Resize(sample[DataKeys.METADATA]["size"], interpolation=InterpolationMode.NEAREST)
4961
sample[DataKeys.PREDS]["masks"] = [resize(tensor(mask)) for mask in sample[DataKeys.PREDS]["masks"]]
5062
return sample[DataKeys.PREDS]
5163

flash/image/segmentation/input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def load_data(
9494

9595
def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    """Load the target mask of ``sample`` (when present) and defer the rest to the parent class.

    Args:
        sample: The sample dict; ``sample[DataKeys.TARGET]``, if present, is passed to ``load_image``.

    Returns:
        Whatever the parent ``load_sample`` returns for the updated sample.
    """
    if DataKeys.TARGET in sample:
        # The loaded image is moved to channels-first, so the first channel must be picked on
        # axis 0. Indexing ``[:, :, 0]`` after the transpose would instead return the first
        # column of every channel, i.e. a (C, H) array rather than the (H, W) mask.
        # NOTE(review): assumes ``load_image`` yields a 3-dimensional (H, W, C) image - confirm.
        mask = np.array(load_image(sample[DataKeys.TARGET])).transpose((2, 0, 1))
        sample[DataKeys.TARGET] = mask[0, :, :]
    return super().load_sample(sample)
9999

100100

flash/image/segmentation/input_transform.py

Lines changed: 29 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from dataclasses import dataclass
15-
from typing import Any, Callable, Dict, Tuple, Union
16-
17-
import torch
15+
from typing import Any, Callable, Dict, Tuple
1816

1917
from flash.core.data.io.input import DataKeys
2018
from flash.core.data.io.input_transform import InputTransform
21-
from flash.core.data.transforms import ApplyToKeys, kornia_collate, KorniaParallelTransforms
22-
from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHVISION_AVAILABLE, requires
19+
from flash.core.data.transforms import AlbumentationsAdapter, ApplyToKeys
20+
from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, _TORCHVISION_AVAILABLE, requires
2321

24-
if _KORNIA_AVAILABLE:
25-
import kornia as K
22+
if _ALBUMENTATIONS_AVAILABLE:
23+
import albumentations as alb
24+
else:
25+
alb = None
2626

2727
if _TORCHVISION_AVAILABLE:
2828
from torchvision import transforms as T
@@ -31,16 +31,16 @@
3131
def prepare_target(batch: Dict[str, Any]) -> Dict[str, Any]:
    """Convert the target mask to long and remove the channel dimension.

    Singleton dimensions after the first one are dropped. The first (batch) dimension is always
    kept, so a batch that contains a single sample is not collapsed.

    Args:
        batch: The collated batch; only ``batch[DataKeys.TARGET]`` is modified, if present.

    Returns:
        The same ``batch`` dict.
    """
    if DataKeys.TARGET in batch:
        target = batch[DataKeys.TARGET]
        if target.ndim > 0:
            # A bare ``squeeze()`` would also remove the batch axis when the batch holds one
            # sample, so only the axes after the first are squeezed.
            kept = [size for size in target.shape[1:] if size != 1]
            target = target.reshape(target.shape[0], *kept)
        batch[DataKeys.TARGET] = target.long()
    return batch
3636

3737

38-
def target_as_tensor(sample: Dict[str, Any]) -> Dict[str, Any]:
38+
def permute_target(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Move the leading axis of the target to the last position.

    A 2-dimensional target first receives a leading singleton axis, so the result always has the
    former first axis at the end. Samples without a target are returned unchanged.
    """
    if DataKeys.TARGET not in sample:
        return sample
    mask = sample[DataKeys.TARGET]
    # Give a 2D mask a leading singleton axis so it can be permuted like a 3D one.
    mask = mask[None, :, :] if mask.ndim == 2 else mask
    sample[DataKeys.TARGET] = mask.transpose((1, 2, 0))
    return sample
4545

4646

@@ -53,62 +53,48 @@ def remove_extra_dimensions(batch: Dict[str, Any]):
5353

5454
@dataclass
5555
class SemanticSegmentationInputTransform(InputTransform):
56+
# https://albumentations.ai/docs/examples/pytorch_semantic_segmentation
5657

5758
image_size: Tuple[int, int] = (128, 128)
58-
mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
59-
std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)
59+
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
60+
std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
6061

6162
@requires("image")
6263
def train_per_sample_transform(self) -> Callable:
6364
return T.Compose(
6465
[
66+
permute_target,
67+
AlbumentationsAdapter(
68+
[
69+
alb.Resize(*self.image_size),
70+
alb.HorizontalFlip(p=0.5),
71+
alb.Normalize(mean=self.mean, std=self.std),
72+
]
73+
),
6574
ApplyToKeys(
6675
DataKeys.INPUT,
6776
T.ToTensor(),
6877
),
69-
target_as_tensor,
70-
ApplyToKeys(
71-
[DataKeys.INPUT, DataKeys.TARGET],
72-
KorniaParallelTransforms(
73-
K.geometry.Resize(self.image_size, interpolation="nearest"),
74-
K.augmentation.RandomHorizontalFlip(p=0.5),
75-
),
76-
),
77-
ApplyToKeys([DataKeys.INPUT], K.augmentation.Normalize(mean=self.mean, std=self.std)),
7878
]
7979
)
8080

8181
@requires("image")
8282
def per_sample_transform(self) -> Callable:
8383
return T.Compose(
8484
[
85+
permute_target,
86+
AlbumentationsAdapter(
87+
[
88+
alb.Resize(*self.image_size),
89+
alb.Normalize(mean=self.mean, std=self.std),
90+
]
91+
),
8592
ApplyToKeys(
8693
DataKeys.INPUT,
8794
T.ToTensor(),
8895
),
89-
target_as_tensor,
90-
ApplyToKeys(
91-
[DataKeys.INPUT, DataKeys.TARGET],
92-
KorniaParallelTransforms(K.geometry.Resize(self.image_size, interpolation="nearest")),
93-
),
94-
ApplyToKeys([DataKeys.INPUT], K.augmentation.Normalize(mean=self.mean, std=self.std)),
9596
]
9697
)
9798

98-
@requires("image")
99-
def predict_per_sample_transform(self) -> Callable:
100-
return ApplyToKeys(
101-
DataKeys.INPUT,
102-
T.ToTensor(),
103-
K.geometry.Resize(
104-
self.image_size,
105-
interpolation="nearest",
106-
),
107-
K.augmentation.Normalize(mean=self.mean, std=self.std),
108-
)
109-
110-
def collate(self) -> Callable:
111-
return kornia_collate
112-
11399
def per_batch_transform(self) -> Callable:
114100
return T.Compose([prepare_target, remove_extra_dimensions])

0 commit comments

Comments
 (0)