Skip to content

Sat-SlideMix Augmentation #2742

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tests/transforms/test_augmentations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import kornia.augmentation as K
import pytest
import torch

from torchgeo.transforms import SatSlideMix


def test_sat_slidemix() -> None:
    """Check SatSlideMix output shapes and its ``gamma`` validation."""
    batch_size, channels, height, width = 2, 3, 64, 64
    gamma = 2
    image = torch.randn(size=(batch_size, channels, height, width))
    mask = torch.randint(low=0, high=2, size=(batch_size, 1, height, width))

    transform = SatSlideMix(gamma=gamma, beta=(0.0, 1.0), p=1.0)
    pipeline = K.AugmentationSequential(transform, data_keys=None)
    result = pipeline({'image': image, 'mask': mask})

    # Each input sample is expected to yield ``gamma`` augmented samples.
    expected_batch = batch_size * gamma
    assert result['image'].shape == (expected_batch, channels, height, width)
    assert result['mask'].shape == (expected_batch, 1, height, width)

    # A non-positive gamma must be rejected at construction time.
    with pytest.raises(AssertionError, match='gamma must be a positive integer'):
        SatSlideMix(gamma=0)
2 changes: 2 additions & 0 deletions torchgeo/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""TorchGeo transforms."""

from .augmentations import SatSlideMix
from .color import RandomGrayscale
from .indices import (
AppendBNDVI,
Expand Down Expand Up @@ -39,4 +40,5 @@
'AppendTriBandNormalizedDifferenceIndex',
'AugmentationSequential',
'RandomGrayscale',
'SatSlideMix',
)
115 changes: 115 additions & 0 deletions torchgeo/transforms/augmentations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""TorchGeo augmentations."""

from typing import Any

import kornia.augmentation as K
import torch
from torch import Tensor


class SatSlideMix(K.GeometricAugmentationBase2D):
    """Applies the Sat-SlideMix augmentation to a batch of images and masks.

    Sat-SlideMix rolls (circularly shifts) images along either the height
    or width axis by a random amount.

    If you use this method in your research, please cite the following paper:

    * https://doi.org/10.1609/aaai.v39i27.35028

    .. versionadded:: 0.8
    """

    def __init__(
        self,
        gamma: int = 1,
        beta: Tensor | float | tuple[float, float] | list[float] = (0.0, 1.0),
        p: float = 0.5,
    ) -> None:
        """Initialize a new SatSlideMix instance.

        Args:
            gamma: The number of augmented samples to create for each
                input image. The output batch size will be gamma * B.
            beta: The range of percentage (0.0 to 1.0) of the image
                dimension (height or width) to shift.
            p: Probability to apply the augmentation on each sample

        Raises:
            AssertionError: If `gamma` is not a positive integer.
        """
        super().__init__(p=p)
        assert isinstance(gamma, int) and gamma > 0, 'gamma must be a positive integer'
        # Three independent uniform samples are drawn per output sample:
        # the shift fraction, the axis selector and the direction selector.
        self._param_generator = K.random_generator.PlainUniformGenerator(
            (beta, 'beta', 0.5, (0.0, 1.0)),
            ((0.0, 1.0), 'dim', 0.5, (0.0, 1.0)),
            ((0.0, 1.0), 'direction', 0.5, (0.0, 1.0)),
        )
        self.flags = {'gamma': gamma}

    def generate_parameters(self, batch_shape: tuple[int, ...]) -> dict[str, Tensor]:
        """Generate parameters for the batch.

        Args:
            batch_shape: the (B, C, H, W) shape of the batch to augment

        Returns:
            the sampled parameters, one entry per output sample (B * gamma)
        """
        B, C, H, W = batch_shape
        # Every input sample is repeated gamma times in apply_transform, so one
        # parameter set is sampled per *output* sample.
        # NOTE(review): confirm how the base class merges outputs when only part
        # of the batch is selected (p < 1) and gamma > 1.
        batch_shape = torch.Size((B * self.flags['gamma'], C, H, W))
        params: dict[str, Tensor] = self._param_generator(
            batch_shape, self.same_on_batch
        )
        return params

    def compute_transformation(
        self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any]
    ) -> Tensor:
        """Compute the transformation.

        Args:
            input: the input tensor
            params: generated parameters
            flags: static parameters

        Returns:
            the transformation
        """
        # A circular shift is not expressed as a matrix here; the identity
        # matrix is returned instead.
        out: Tensor = self.identity_matrix(input)
        return out

    def apply_transform(
        self,
        input: Tensor,
        params: dict[str, Tensor],
        flags: dict[str, Any],
        transform: Tensor | None = None,
    ) -> Tensor:
        """Apply the transform to the input image or mask.

        Args:
            input: the input tensor image or mask
            params: generated parameters
            flags: static parameters
            transform: the geometric transformation tensor

        Returns:
            the augmented input
        """
        # Map the uniform samples onto discrete choices.
        directions = (params['direction'].round() * 2.0 - 1.0).to(
            torch.int
        )  # convert to -1 or 1
        dims = params['dim'].round().to(torch.int) + 2  # convert to 2 (H) or 3 (W)

        # Build the shape lookup on the same device as the sampled parameters;
        # a CPU tensor here makes index_select fail when the parameters are
        # on an accelerator.
        shape = torch.tensor(input.shape, device=dims.device)
        sizes = torch.index_select(shape, dim=0, index=dims)

        # Compute all shifts at once and move them to the host in one transfer
        # instead of converting a 0-d tensor on every loop iteration.
        shifts = torch.round(params['beta'] * sizes * directions)
        shifts = shifts.to(torch.int64).tolist()
        # Each out[i] is a 3D tensor (C, H, W), so the batch axis is gone and
        # the roll dimension is dim - 1.
        roll_dims = (dims - 1).tolist()

        # Repeat each image gamma times (B*gamma, C, H, W)
        out = input.repeat_interleave(flags['gamma'], dim=0)

        # Shifts and dims vary per sample, so each image is rolled individually.
        for i, (shift, dim) in enumerate(zip(shifts, roll_dims, strict=True)):
            out[i] = torch.roll(out[i], shifts=shift, dims=dim)
        return out
Loading