|
2 | 2 | import numpy as np |
3 | 3 | from typing import Tuple |
4 | 4 | from yucca.functional.transforms import blur |
| 5 | +from yucca.functional.transforms.torch.blur import torch_blur |
5 | 6 |
|
6 | 7 |
|
7 | 8 | class Blur(YuccaTransform): |
@@ -42,3 +43,43 @@ def __call__(self, packed_data_dict=None, **unpacked_data_dict): |
42 | 43 | sigma = self.get_params(self.sigma) |
43 | 44 | data_dict[self.data_key][b] = self.__blur__(data_dict[self.data_key][b], sigma) |
44 | 45 | return data_dict |
| 46 | + |
| 47 | + |
| 48 | +class Torch_Blur(YuccaTransform): |
| 49 | + def __init__( |
| 50 | + self, |
| 51 | + data_key="image", |
| 52 | + p_per_sample: float = 1.0, |
| 53 | + p_per_channel: float = 0.5, |
| 54 | + sigma=(0.5, 1.0), |
| 55 | + clip_to_input_range=False, |
| 56 | + ): |
| 57 | + self.data_key = data_key |
| 58 | + self.p_per_sample = p_per_sample |
| 59 | + self.p_per_channel = p_per_channel |
| 60 | + self.sigma = sigma |
| 61 | + self.clip_to_input_range = clip_to_input_range |
| 62 | + |
| 63 | + @staticmethod |
| 64 | + def get_params(sigma: Tuple[float]): |
| 65 | + sigma = np.random.uniform(*sigma) |
| 66 | + return sigma |
| 67 | + |
| 68 | + def __blur__(self, image, sigma): |
| 69 | + for c in range(image.shape[0]): |
| 70 | + if np.random.uniform() < self.p_per_channel: |
| 71 | + image[c] = torch_blur(image[c], sigma, clip_to_input_range=self.clip_to_input_range) |
| 72 | + return image |
| 73 | + |
| 74 | + def __call__(self, packed_data_dict=None, **unpacked_data_dict): |
| 75 | + data_dict = packed_data_dict if packed_data_dict else unpacked_data_dict |
| 76 | + assert ( |
| 77 | + len(data_dict[self.data_key].shape) == 5 or len(data_dict[self.data_key].shape) == 4 |
| 78 | + ), f"Incorrect data size or shape.\ |
| 79 | + \nShould be (b, c, x, y, z) or (b, c, x, y) and is: {data_dict[self.data_key].shape}" |
| 80 | + |
| 81 | + for b in range(data_dict[self.data_key].shape[0]): |
| 82 | + if np.random.uniform() < self.p_per_sample: |
| 83 | + sigma = self.get_params(self.sigma) |
| 84 | + data_dict[self.data_key][b] = self.__blur__(data_dict[self.data_key][b], sigma) |
| 85 | + return data_dict |
0 commit comments