|
| 1 | +import torch |
| 2 | +import torchvision.transforms.functional as TF |
| 3 | + |
| 4 | +def torch_blur( |
| 5 | + image: torch.tensor, |
| 6 | + input_dims: torch.tensor, |
| 7 | + sigma: float, |
| 8 | + clip_to_input_range: bool = True |
| 9 | + ): |
| 10 | + img_min = image.min() |
| 11 | + img_max = image.max() |
| 12 | + |
| 13 | + if input_dims == 2: |
| 14 | + image = blur_2D_case_from_2D(image, sigma) |
| 15 | + elif input_dims == 3 and image.shape[0] <= 3: |
| 16 | + image = blur_2D_case_from_3D(image, sigma) |
| 17 | + elif input_dims == 3: |
| 18 | + image = blur_3D_case_from_3D(image, sigma) |
| 19 | + else: |
| 20 | + raise ValueError(f"Unsupported image shape for blur: {image.shape}") |
| 21 | + |
| 22 | + if clip_to_input_range: |
| 23 | + image = torch.clamp(image, min=img_min, max=img_max) |
| 24 | + return image.cpu().numpy() if not isinstance(image, torch.Tensor) else image |
| 25 | + |
| 26 | +def blur_2D_case_from_2D(image: torch.tensor, sigma: float) -> torch.tensor: |
| 27 | + image = image.unsqueeze(0).unsqueeze(0) # (1, 1, H, W) |
| 28 | + image = TF.gaussian_blur(image, kernel_size=int(2 * round(3 * sigma) + 1), sigma=sigma) |
| 29 | + return image.squeeze(0).squeeze(0) |
| 30 | + |
| 31 | +def blur_2D_case_from_3D(image: torch.tensor, sigma: float) -> torch.tensor: |
| 32 | + image = image.unsqueeze(0) # (1, C, H, W) |
| 33 | + image = TF.gaussian_blur(image, kernel_size=int(2 * round(3 * sigma) + 1), sigma=sigma) |
| 34 | + return image.squeeze(0) |
| 35 | + |
| 36 | +def blur_3D_case_from_3D(image: torch.Tensor, sigma: float) -> torch.Tensor: |
| 37 | + kernel_size = int(2 * round(3 * sigma) + 1) |
| 38 | + if kernel_size % 2 == 0: |
| 39 | + kernel_size += 1 |
| 40 | + |
| 41 | + coords = torch.arange(kernel_size, device=image.device) - kernel_size // 2 |
| 42 | + kernel = torch.exp(-(coords.float() ** 2) / (2 * sigma ** 2)) |
| 43 | + kernel = kernel / kernel.sum() |
| 44 | + kernel = kernel.view(1, 1, -1) # single-channel filter |
| 45 | + |
| 46 | + # We are dealing with a voxel volume here but we'd like to support batches of |
| 47 | + # volumes and multi-modal MRI so adding batch and channel dimensions should support that. |
| 48 | + image = image.unsqueeze(0).unsqueeze(0) |
| 49 | + |
| 50 | + # the permutation ordering is moving the target spatial dimension slice to the last dimension since |
| 51 | + # this is where conv1d applies the filter. |
| 52 | + for axis in range(2, 5): |
| 53 | + permute_order = list(range(image.dim())) |
| 54 | + permute_order[axis], permute_order[-1] = permute_order[-1], permute_order[axis] |
| 55 | + image = image.permute(permute_order).contiguous() |
| 56 | + orig_shape = image.shape |
| 57 | + |
| 58 | + image = image.reshape(-1, 1, orig_shape[-1]) |
| 59 | + image = torch.nn.functional.conv1d(image, kernel, padding=kernel.shape[-1] // 2) |
| 60 | + |
| 61 | + image = image.reshape(orig_shape) |
| 62 | + image = image.permute(permute_order).contiguous() |
| 63 | + image = image.squeeze(0).squeeze(0) |
| 64 | + if image.shape[0] != image.shape[1]: |
| 65 | + image = image.permute(1, 2, 0) |
| 66 | + return image |
0 commit comments