Skip to content

Commit 36f91ea

Browse files
committed
change: narrow to 2d slice and 3d volume, ignore 2d w. channels
1 parent 177f398 commit 36f91ea

1 file changed

Lines changed: 25 additions & 33 deletions

File tree

  • yucca/functional/transforms/torch

yucca/functional/transforms/torch/blur.py

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,61 +6,53 @@ def torch_blur(
66
input_dims: torch.tensor,
77
sigma: float,
88
clip_to_input_range: bool = True
9-
):
9+
) -> torch.tensor:
1010
img_min = image.min()
1111
img_max = image.max()
1212

1313
if input_dims == 2:
14-
image = blur_2D_case_from_2D(image, sigma)
15-
elif input_dims == 3 and image.shape[0] <= 3:
16-
image = blur_2D_case_from_3D(image, sigma)
14+
image = blur_2D(image, sigma)
1715
elif input_dims == 3:
18-
image = blur_3D_case_from_3D(image, sigma)
16+
image = blur_3D(image, sigma)
1917
else:
2018
raise ValueError(f"Unsupported image shape for blur: {image.shape}")
2119

2220
if clip_to_input_range:
2321
image = torch.clamp(image, min=img_min, max=img_max)
24-
return image.cpu().numpy() if not isinstance(image, torch.Tensor) else image
22+
return image
23+
24+
def blur_2D(image: torch.tensor, sigma: float) -> torch.tensor:
25+
assert image.ndim == 2, "Expected [H, W] tensor"
2526

26-
def blur_2D_case_from_2D(image: torch.tensor, sigma: float) -> torch.tensor:
27-
image = image.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
27+
image = image.unsqueeze(0).unsqueeze(0)
2828
image = TF.gaussian_blur(image, kernel_size=int(2 * round(3 * sigma) + 1), sigma=sigma)
2929
return image.squeeze(0).squeeze(0)
3030

31-
def blur_2D_case_from_3D(image: torch.tensor, sigma: float) -> torch.tensor:
32-
image = image.unsqueeze(0) # (1, C, H, W)
33-
image = TF.gaussian_blur(image, kernel_size=int(2 * round(3 * sigma) + 1), sigma=sigma)
34-
return image.squeeze(0)
31+
def blur_3D(image: torch.Tensor, sigma: float) -> torch.Tensor:
32+
assert image.ndim == 3, "Expected [D, H, W] tensor"
3533

36-
def blur_3D_case_from_3D(image: torch.Tensor, sigma: float) -> torch.Tensor:
3734
kernel_size = int(2 * round(3 * sigma) + 1)
3835
if kernel_size % 2 == 0:
3936
kernel_size += 1
4037

41-
coords = torch.arange(kernel_size, device=image.device) - kernel_size // 2
38+
# : filter single-channel, [out_channels, in_channels, kernel_size]
39+
coords = torch.arange(kernel_size) - kernel_size // 2
4240
kernel = torch.exp(-(coords.float() ** 2) / (2 * sigma ** 2))
4341
kernel = kernel / kernel.sum()
44-
kernel = kernel.view(1, 1, -1) # single-channel filter
45-
46-
# We are dealing with a voxel volume here but we'd like to support batches of
47-
# volumes and multi-modal MRI so adding batch and channel dimensions should support that.
48-
image = image.unsqueeze(0).unsqueeze(0)
42+
kernel = kernel.view(1, 1, -1)
4943

5044
# the permutation ordering is moving the target spatial dimension slice to the last dimension since
5145
# this is where conv1d applies the filter.
52-
for axis in range(2, 5):
53-
permute_order = list(range(image.dim()))
46+
volume = image
47+
for axis in range(3):
48+
permute_order = list(range(3))
5449
permute_order[axis], permute_order[-1] = permute_order[-1], permute_order[axis]
55-
image = image.permute(permute_order).contiguous()
56-
orig_shape = image.shape
57-
58-
image = image.reshape(-1, 1, orig_shape[-1])
59-
image = torch.nn.functional.conv1d(image, kernel, padding=kernel.shape[-1] // 2)
60-
61-
image = image.reshape(orig_shape)
62-
image = image.permute(permute_order).contiguous()
63-
image = image.squeeze(0).squeeze(0)
64-
if image.shape[0] != image.shape[1]:
65-
image = image.permute(1, 2, 0)
66-
return image
50+
volume = volume.permute(permute_order).contiguous()
51+
52+
shape = volume.shape
53+
volume = volume.view(-1, 1, shape[-1])
54+
volume = torch.nn.functional.conv1d(volume, kernel, padding=kernel.shape[-1] // 2)
55+
volume = volume.view(shape)
56+
volume = volume.permute(permute_order).contiguous()
57+
58+
return volume

0 commit comments

Comments
 (0)