Skip to content

Commit ee1e156

Browse files
Sllambiasstarostka
authored andcommitted
add Torch_Blur
1 parent 9021ff2 commit ee1e156

1 file changed

Lines changed: 41 additions & 0 deletions

File tree

  • yucca/modules/data/augmentation/transforms

yucca/modules/data/augmentation/transforms/Blur.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
from typing import Tuple
44
from yucca.functional.transforms import blur
5+
from yucca.functional.transforms.torch.blur import torch_blur
56

67

78
class Blur(YuccaTransform):
@@ -42,3 +43,43 @@ def __call__(self, packed_data_dict=None, **unpacked_data_dict):
4243
sigma = self.get_params(self.sigma)
4344
data_dict[self.data_key][b] = self.__blur__(data_dict[self.data_key][b], sigma)
4445
return data_dict
46+
47+
48+
class Torch_Blur(YuccaTransform):
49+
def __init__(
50+
self,
51+
data_key="image",
52+
p_per_sample: float = 1.0,
53+
p_per_channel: float = 0.5,
54+
sigma=(0.5, 1.0),
55+
clip_to_input_range=False,
56+
):
57+
self.data_key = data_key
58+
self.p_per_sample = p_per_sample
59+
self.p_per_channel = p_per_channel
60+
self.sigma = sigma
61+
self.clip_to_input_range = clip_to_input_range
62+
63+
@staticmethod
64+
def get_params(sigma: Tuple[float]):
65+
sigma = np.random.uniform(*sigma)
66+
return sigma
67+
68+
def __blur__(self, image, sigma):
69+
for c in range(image.shape[0]):
70+
if np.random.uniform() < self.p_per_channel:
71+
image[c] = torch_blur(image[c], sigma, clip_to_input_range=self.clip_to_input_range)
72+
return image
73+
74+
def __call__(self, packed_data_dict=None, **unpacked_data_dict):
75+
data_dict = packed_data_dict if packed_data_dict else unpacked_data_dict
76+
assert (
77+
len(data_dict[self.data_key].shape) == 5 or len(data_dict[self.data_key].shape) == 4
78+
), f"Incorrect data size or shape.\
79+
\nShould be (b, c, x, y, z) or (b, c, x, y) and is: {data_dict[self.data_key].shape}"
80+
81+
for b in range(data_dict[self.data_key].shape[0]):
82+
if np.random.uniform() < self.p_per_sample:
83+
sigma = self.get_params(self.sigma)
84+
data_dict[self.data_key][b] = self.__blur__(data_dict[self.data_key][b], sigma)
85+
return data_dict

0 commit comments

Comments
 (0)