Skip to content

Commit f125752

Browse files
Alex Sablayrollesfacebook-github-bot
Alex Sablayrolles
authored andcommitted
Fix unfold2d and add test (#443)
Summary: Fixes #442. TL;DR: F.pad has an unintuitive syntax (we need to indicated padding from last dimensions to first dimensions). This PR fixes that in unfold2d and adds tests with non-symmetric pad/dilation/kernel_size/stride. Pull Request resolved: #443 Reviewed By: karthikprasad Differential Revision: D37451293 Pulled By: alexandresablayrolles fbshipit-source-id: 0587325f4dcaa4d1dcd3068d9046825964b9d49f
1 parent 181c6b7 commit f125752

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

opacus/tests/grad_samples/conv2d_test.py

+51
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import torch
2020
import torch.nn as nn
2121
from hypothesis import given, settings
22+
from opacus.utils.tensor_utils import unfold2d
23+
from torch.testing import assert_allclose
2224

2325
from .common import GradSampleHooks_test, expander, shrinker
2426

@@ -68,3 +70,52 @@ def test_conv2d(
6870
groups=groups,
6971
)
7072
self.run_test(x, conv, batch_first=True, atol=10e-5, rtol=10e-4)
73+
74+
@given(
75+
B=st.integers(1, 4),
76+
C=st.sampled_from([1, 3, 32]),
77+
H=st.integers(11, 17),
78+
W=st.integers(11, 17),
79+
k_w=st.integers(2, 3),
80+
k_h=st.integers(2, 3),
81+
stride_w=st.integers(1, 2),
82+
stride_h=st.integers(1, 2),
83+
pad_h=st.sampled_from([0, 2]),
84+
pad_w=st.sampled_from([0, 2]),
85+
dilation_w=st.integers(1, 3),
86+
dilation_h=st.integers(1, 3),
87+
)
88+
@settings(deadline=10000)
89+
def test_unfold2d(
90+
self,
91+
B: int,
92+
C: int,
93+
H: int,
94+
W: int,
95+
k_w: int,
96+
k_h: int,
97+
pad_w: int,
98+
pad_h: int,
99+
stride_w: int,
100+
stride_h: int,
101+
dilation_w: int,
102+
dilation_h: int,
103+
):
104+
X = torch.randn(B, C, H, W)
105+
X_unfold_torch = torch.nn.functional.unfold(
106+
X,
107+
kernel_size=(k_h, k_w),
108+
padding=(pad_h, pad_w),
109+
stride=(stride_w, stride_h),
110+
dilation=(dilation_w, dilation_h),
111+
)
112+
113+
X_unfold_opacus = unfold2d(
114+
X,
115+
kernel_size=(k_h, k_w),
116+
padding=(pad_h, pad_w),
117+
stride=(stride_w, stride_h),
118+
dilation=(dilation_w, dilation_h),
119+
)
120+
121+
assert_allclose(X_unfold_torch, X_unfold_opacus, atol=0, rtol=0)

opacus/utils/tensor_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ def unfold2d(
130130
W_effective = (
131131
W + 2 * padding[1] - (kernel_size[1] + (kernel_size[1] - 1) * (dilation[1] - 1))
132132
) // stride[1] + 1
133-
input = F.pad(input, (padding[0], padding[0], padding[1], padding[1]))
133+
# F.pad's first argument is the padding of the *last* dimension
134+
input = F.pad(input, (padding[1], padding[1], padding[0], padding[0]))
134135
*shape_pad, H_pad, W_pad = input.shape
135136
strides = list(input.stride())
136137
strides = strides[:-2] + [

0 commit comments

Comments
 (0)