@@ -6,61 +6,53 @@ def torch_blur(
66 input_dims : torch .tensor ,
77 sigma : float ,
88 clip_to_input_range : bool = True
9- ):
9+ ) -> torch . tensor :
1010 img_min = image .min ()
1111 img_max = image .max ()
1212
1313 if input_dims == 2 :
14- image = blur_2D_case_from_2D (image , sigma )
15- elif input_dims == 3 and image .shape [0 ] <= 3 :
16- image = blur_2D_case_from_3D (image , sigma )
14+ image = blur_2D (image , sigma )
1715 elif input_dims == 3 :
18- image = blur_3D_case_from_3D (image , sigma )
16+ image = blur_3D (image , sigma )
1917 else :
2018 raise ValueError (f"Unsupported image shape for blur: { image .shape } " )
2119
2220 if clip_to_input_range :
2321 image = torch .clamp (image , min = img_min , max = img_max )
24- return image .cpu ().numpy () if not isinstance (image , torch .Tensor ) else image
22+ return image
23+
24+ def blur_2D (image : torch .tensor , sigma : float ) -> torch .tensor :
25+ assert image .ndim == 2 , "Expected [H, W] tensor"
2526
26- def blur_2D_case_from_2D (image : torch .tensor , sigma : float ) -> torch .tensor :
27- image = image .unsqueeze (0 ).unsqueeze (0 ) # (1, 1, H, W)
27+ image = image .unsqueeze (0 ).unsqueeze (0 )
2828 image = TF .gaussian_blur (image , kernel_size = int (2 * round (3 * sigma ) + 1 ), sigma = sigma )
2929 return image .squeeze (0 ).squeeze (0 )
3030
31- def blur_2D_case_from_3D (image : torch .tensor , sigma : float ) -> torch .tensor :
32- image = image .unsqueeze (0 ) # (1, C, H, W)
33- image = TF .gaussian_blur (image , kernel_size = int (2 * round (3 * sigma ) + 1 ), sigma = sigma )
34- return image .squeeze (0 )
31+ def blur_3D (image : torch .Tensor , sigma : float ) -> torch .Tensor :
32+ assert image .ndim == 3 , "Expected [D, H, W] tensor"
3533
36- def blur_3D_case_from_3D (image : torch .Tensor , sigma : float ) -> torch .Tensor :
3734 kernel_size = int (2 * round (3 * sigma ) + 1 )
3835 if kernel_size % 2 == 0 :
3936 kernel_size += 1
4037
41- coords = torch .arange (kernel_size , device = image .device ) - kernel_size // 2
38+ # : filter single-channel, [out_channels, in_channels, kernel_size]
39+ coords = torch .arange (kernel_size ) - kernel_size // 2
4240 kernel = torch .exp (- (coords .float () ** 2 ) / (2 * sigma ** 2 ))
4341 kernel = kernel / kernel .sum ()
44- kernel = kernel .view (1 , 1 , - 1 ) # single-channel filter
45-
46- # We are dealing with a voxel volume here but we'd like to support batches of
47- # volumes and multi-modal MRI so adding batch and channel dimensions should support that.
48- image = image .unsqueeze (0 ).unsqueeze (0 )
42+ kernel = kernel .view (1 , 1 , - 1 )
4943
5044 # the permutation ordering is moving the target spatial dimension slice to the last dimension since
5145 # this is where conv1d applies the filter.
52- for axis in range (2 , 5 ):
53- permute_order = list (range (image .dim ()))
46+ volume = image
47+ for axis in range (3 ):
48+ permute_order = list (range (3 ))
5449 permute_order [axis ], permute_order [- 1 ] = permute_order [- 1 ], permute_order [axis ]
55- image = image .permute (permute_order ).contiguous ()
56- orig_shape = image .shape
57-
58- image = image .reshape (- 1 , 1 , orig_shape [- 1 ])
59- image = torch .nn .functional .conv1d (image , kernel , padding = kernel .shape [- 1 ] // 2 )
60-
61- image = image .reshape (orig_shape )
62- image = image .permute (permute_order ).contiguous ()
63- image = image .squeeze (0 ).squeeze (0 )
64- if image .shape [0 ] != image .shape [1 ]:
65- image = image .permute (1 , 2 , 0 )
66- return image
50+ volume = volume .permute (permute_order ).contiguous ()
51+
52+ shape = volume .shape
53+ volume = volume .view (- 1 , 1 , shape [- 1 ])
54+ volume = torch .nn .functional .conv1d (volume , kernel , padding = kernel .shape [- 1 ] // 2 )
55+ volume = volume .view (shape )
56+ volume = volume .permute (permute_order ).contiguous ()
57+
58+ return volume
0 commit comments