19 | 19 | import torch
20 | 20 | import torch.nn as nn
21 | 21 | from hypothesis import given, settings
   | 22 | +from opacus.utils.tensor_utils import unfold2d
   | 23 | +from torch.testing import assert_allclose
22 | 24 |
23 | 25 | from .common import GradSampleHooks_test, expander, shrinker
24 | 26 |
@@ -68,3 +70,52 @@ def test_conv2d(
68 | 70 |             groups=groups,
69 | 71 |         )
70 | 72 |         self.run_test(x, conv, batch_first=True, atol=10e-5, rtol=10e-4)
   | 73 | +
   | 74 | +    @given(
   | 75 | +        B=st.integers(1, 4),
   | 76 | +        C=st.sampled_from([1, 3, 32]),
   | 77 | +        H=st.integers(11, 17),
   | 78 | +        W=st.integers(11, 17),
   | 79 | +        k_w=st.integers(2, 3),
   | 80 | +        k_h=st.integers(2, 3),
   | 81 | +        stride_w=st.integers(1, 2),
   | 82 | +        stride_h=st.integers(1, 2),
   | 83 | +        pad_h=st.sampled_from([0, 2]),
   | 84 | +        pad_w=st.sampled_from([0, 2]),
   | 85 | +        dilation_w=st.integers(1, 3),
   | 86 | +        dilation_h=st.integers(1, 3),
   | 87 | +    )
   | 88 | +    @settings(deadline=10000)
   | 89 | +    def test_unfold2d(
   | 90 | +        self,
   | 91 | +        B: int,
   | 92 | +        C: int,
   | 93 | +        H: int,
   | 94 | +        W: int,
   | 95 | +        k_w: int,
   | 96 | +        k_h: int,
   | 97 | +        pad_w: int,
   | 98 | +        pad_h: int,
   | 99 | +        stride_w: int,
   | 100 | +        stride_h: int,
   | 101 | +        dilation_w: int,
   | 102 | +        dilation_h: int,
   | 103 | +    ):
   | 104 | +        X = torch.randn(B, C, H, W)
   | 105 | +        X_unfold_torch = torch.nn.functional.unfold(
   | 106 | +            X,
   | 107 | +            kernel_size=(k_h, k_w),
   | 108 | +            padding=(pad_h, pad_w),
   | 109 | +            stride=(stride_w, stride_h),
   | 110 | +            dilation=(dilation_w, dilation_h),
   | 111 | +        )
   | 112 | +
   | 113 | +        X_unfold_opacus = unfold2d(
   | 114 | +            X,
   | 115 | +            kernel_size=(k_h, k_w),
   | 116 | +            padding=(pad_h, pad_w),
   | 117 | +            stride=(stride_w, stride_h),
   | 118 | +            dilation=(dilation_w, dilation_h),
   | 119 | +        )
   | 120 | +
   | 121 | +        assert_allclose(X_unfold_torch, X_unfold_opacus, atol=0, rtol=0)
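The added property-based test checks that Opacus's `unfold2d` reproduces `torch.nn.functional.unfold` bit-for-bit (hence `atol=0, rtol=0`) across random batch sizes, channel counts, spatial sizes, kernel shapes, strides, paddings, and dilations. Outside of hypothesis, the same comparison can be run standalone; this is a minimal sketch assuming `unfold2d` keeps the keyword signature used in the test above, with example shapes chosen here for illustration only.

```python
import torch
import torch.nn.functional as F

from opacus.utils.tensor_utils import unfold2d

# Example input (B, C, H, W); shapes and hyperparameters are illustrative.
x = torch.randn(2, 3, 12, 12)

expected = F.unfold(x, kernel_size=(2, 3), padding=(1, 0), stride=(2, 1), dilation=(1, 2))
actual = unfold2d(x, kernel_size=(2, 3), padding=(1, 0), stride=(2, 1), dilation=(1, 2))

# Both calls return a (B, C * k_h * k_w, L) tensor of sliding local blocks.
# Unfolding only copies values (no arithmetic), so the match can be exact.
assert torch.equal(actual, expected)
```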