Skip to content

Commit 16722b1

Browse files
committed
Minor spatial fix
1 parent 118adfb commit 16722b1

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "yucca"
3-
version = "2.3.4"
3+
version = "2.3.5"
44
authors = [
55
{ name="Sebastian Llambias", email="llambias@live.com" },
66
{ name="Asbjørn Munk", email="9844416+asbjrnmunk@users.noreply.github.com" },

yucca/functional/transforms/torch/spatial.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,20 @@ def _prepare(x):
5252

5353
coords = _create_zero_centered_coordinate_matrix(patch_size).to(device, dtype)
5454
if torch.rand(1) < p_deform:
55-
noise = torch.randn(1, ndim, *patch_size, device=device, dtype=dtype)
5655
if ndim == 2:
5756
# Separable 2D blur
57+
noise = torch.randn(1, 1, *patch_size, device=device, dtype=dtype)
5858
ksize = 21
5959
ax = torch.arange(ksize, device=device, dtype=dtype) - ksize // 2
6060
k = torch.exp(-0.5 * (ax / sigma) ** 2)
6161
k /= k.sum()
6262
ky = k.view(1, 1, -1, 1)
6363
kx = k.view(1, 1, 1, -1)
64-
noise = F.conv2d(noise, ky, padding=(ksize // 2, 0), groups=ndim)
65-
noise = F.conv2d(noise, kx, padding=(0, ksize // 2), groups=ndim)
64+
noise = F.conv2d(noise, ky, padding=(ksize // 2, 0), groups=1)
65+
noise = F.conv2d(noise, kx, padding=(0, ksize // 2), groups=1)
6666
else:
6767
# Separable 3D blur
68+
noise = torch.randn(1, ndim, *patch_size, device=device, dtype=dtype)
6869
ksize = 9
6970
ax = torch.arange(ksize, device=device, dtype=dtype) - ksize // 2
7071
k = torch.exp(-0.5 * (ax / sigma) ** 2)

0 commit comments

Comments
 (0)