@@ -52,19 +52,20 @@ def _prepare(x):
5252
5353 coords = _create_zero_centered_coordinate_matrix (patch_size ).to (device , dtype )
5454 if torch .rand (1 ) < p_deform :
55- noise = torch .randn (1 , ndim , * patch_size , device = device , dtype = dtype )
5655 if ndim == 2 :
5756 # Separable 2D blur
57+ noise = torch .randn (1 , 1 , * patch_size , device = device , dtype = dtype )
5858 ksize = 21
5959 ax = torch .arange (ksize , device = device , dtype = dtype ) - ksize // 2
6060 k = torch .exp (- 0.5 * (ax / sigma ) ** 2 )
6161 k /= k .sum ()
6262 ky = k .view (1 , 1 , - 1 , 1 )
6363 kx = k .view (1 , 1 , 1 , - 1 )
64- noise = F .conv2d (noise , ky , padding = (ksize // 2 , 0 ), groups = ndim )
65- noise = F .conv2d (noise , kx , padding = (0 , ksize // 2 ), groups = ndim )
64+ noise = F .conv2d (noise , ky , padding = (ksize // 2 , 0 ), groups = 1 )
65+ noise = F .conv2d (noise , kx , padding = (0 , ksize // 2 ), groups = 1 )
6666 else :
6767 # Separable 3D blur
68+ noise = torch .randn (1 , ndim , * patch_size , device = device , dtype = dtype )
6869 ksize = 9
6970 ax = torch .arange (ksize , device = device , dtype = dtype ) - ksize // 2
7071 k = torch .exp (- 0.5 * (ax / sigma ) ** 2 )
0 commit comments