Skip to content

Commit e47b0cb

Browse files
committed
torch blur transform separable 1d convolution, slice and volume dimensions
1 parent 792d169 commit e47b0cb

1 file changed

Lines changed: 66 additions & 0 deletions

File tree

  • yucca/functional/transforms/torch
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import torch
2+
import torchvision.transforms.functional as TF
3+
4+
def torch_blur(
5+
image: torch.tensor,
6+
input_dims: torch.tensor,
7+
sigma: float,
8+
clip_to_input_range: bool = True
9+
):
10+
img_min = image.min()
11+
img_max = image.max()
12+
13+
if input_dims == 2:
14+
image = blur_2D_case_from_2D(image, sigma)
15+
elif input_dims == 3 and image.shape[0] <= 3:
16+
image = blur_2D_case_from_3D(image, sigma)
17+
elif input_dims == 3:
18+
image = blur_3D_case_from_3D(image, sigma)
19+
else:
20+
raise ValueError(f"Unsupported image shape for blur: {image.shape}")
21+
22+
if clip_to_input_range:
23+
image = torch.clamp(image, min=img_min, max=img_max)
24+
return image.cpu().numpy() if not isinstance(image, torch.Tensor) else image
25+
26+
def blur_2D_case_from_2D(image: torch.tensor, sigma: float) -> torch.tensor:
27+
image = image.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
28+
image = TF.gaussian_blur(image, kernel_size=int(2 * round(3 * sigma) + 1), sigma=sigma)
29+
return image.squeeze(0).squeeze(0)
30+
31+
def blur_2D_case_from_3D(image: torch.tensor, sigma: float) -> torch.tensor:
32+
image = image.unsqueeze(0) # (1, C, H, W)
33+
image = TF.gaussian_blur(image, kernel_size=int(2 * round(3 * sigma) + 1), sigma=sigma)
34+
return image.squeeze(0)
35+
36+
def blur_3D_case_from_3D(image: torch.Tensor, sigma: float) -> torch.Tensor:
37+
kernel_size = int(2 * round(3 * sigma) + 1)
38+
if kernel_size % 2 == 0:
39+
kernel_size += 1
40+
41+
coords = torch.arange(kernel_size, device=image.device) - kernel_size // 2
42+
kernel = torch.exp(-(coords.float() ** 2) / (2 * sigma ** 2))
43+
kernel = kernel / kernel.sum()
44+
kernel = kernel.view(1, 1, -1) # single-channel filter
45+
46+
# We are dealing with a voxel volume here but we'd like to support batches of
47+
# volumes and multi-modal MRI so adding batch and channel dimensions should support that.
48+
image = image.unsqueeze(0).unsqueeze(0)
49+
50+
# the permutation ordering is moving the target spatial dimension slice to the last dimension since
51+
# this is where conv1d applies the filter.
52+
for axis in range(2, 5):
53+
permute_order = list(range(image.dim()))
54+
permute_order[axis], permute_order[-1] = permute_order[-1], permute_order[axis]
55+
image = image.permute(permute_order).contiguous()
56+
orig_shape = image.shape
57+
58+
image = image.reshape(-1, 1, orig_shape[-1])
59+
image = torch.nn.functional.conv1d(image, kernel, padding=kernel.shape[-1] // 2)
60+
61+
image = image.reshape(orig_shape)
62+
image = image.permute(permute_order).contiguous()
63+
image = image.squeeze(0).squeeze(0)
64+
if image.shape[0] != image.shape[1]:
65+
image = image.permute(1, 2, 0)
66+
return image

0 commit comments

Comments
 (0)