Skip to content

Commit d8d7b4b

Browse files
committed
add stochastic depth to resnet
1 parent 0d923f0 commit d8d7b4b

File tree

6 files changed

+60
-8
lines changed

6 files changed

+60
-8
lines changed

pytorch_tools/models/resnet.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ class ResNet(nn.Module):
7070
Flag to overwrite forward pass to return 5 tensors with different resolutions. Defaults to False.
7171
drop_rate (float):
7272
Dropout probability before classifier, for training. Defaults to 0.0.
73+
drop_connect_rate (float):
74+
Drop rate for StochasticDepth. Randomly removes samples each block. Used as regularization during training.
75+
Keep probability decreases linearly with the block index, starting at 1 for the first block and approaching 1 - drop_connect_rate for the last block. Ref: https://arxiv.org/abs/1603.09382
7376
global_pool (str):
7477
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'. Defaults to 'avg'.
7578
init_bn0 (bool):
@@ -95,6 +98,7 @@ def __init__(
9598
antialias=False,
9699
encoder=False,
97100
drop_rate=0.0,
101+
drop_connect_rate=0.0,
98102
global_pool="avg",
99103
init_bn0=True,
100104
):
@@ -108,6 +112,9 @@ def __init__(
108112
self.block = block
109113
self.expansion = block.expansion
110114
self.norm_act = norm_act
115+
self.block_idx = 0
116+
self.num_blocks = sum(layers)
117+
self.drop_connect_rate = drop_connect_rate
111118
super(ResNet, self).__init__()
112119

113120
if deep_stem:
@@ -185,6 +192,7 @@ def _make_layer(
185192
norm_layer=norm_layer,
186193
norm_act=norm_act,
187194
antialias=antialias,
195+
keep_prob=self.keep_prob,
188196
)
189197
]
190198

@@ -201,6 +209,7 @@ def _make_layer(
201209
norm_layer=norm_layer,
202210
norm_act=norm_act,
203211
antialias=antialias,
212+
keep_prob=self.keep_prob,
204213
)
205214
)
206215
return nn.Sequential(*layers)
@@ -266,6 +275,11 @@ def load_state_dict(self, state_dict, **kwargs):
266275
state_dict[k.replace("layer0.", "")] = state_dict.pop(k)
267276
super().load_state_dict(state_dict, **kwargs)
268277

278+
@property
def keep_prob(self):
    """Return the keep probability for the current block and advance the block counter.

    The value is ``1 - drop_connect_rate * block_idx / num_blocks``.
    Reading this property has a side effect: ``block_idx`` is incremented
    on every access, so each read yields the value for the next block.
    """
    current_idx = self.block_idx
    # Same evaluation order as before: (rate * idx) / num_blocks.
    prob = 1 - self.drop_connect_rate * current_idx / self.num_blocks
    # Advance only after the value was computed successfully.
    self.block_idx = current_idx + 1
    return prob
269283

270284
# fmt: off
271285
CFGS = {

pytorch_tools/models/tresnet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ class TResNet(ResNet):
5151
Flag to overwrite forward pass to return 5 tensors with different resolutions. Defaults to False.
5252
drop_rate (float):
5353
Dropout probability before classifier, for training. Defaults to 0.0.
54+
drop_connect_rate (float):
55+
Drop rate for StochasticDepth. Randomly removes samples each block. Used as regularization during training. Ref: https://arxiv.org/abs/1603.09382
5456
"""
5557

5658
def __init__(
@@ -65,6 +67,7 @@ def __init__(
6567
norm_act="leaky_relu",
6668
encoder=False,
6769
drop_rate=0.0,
70+
drop_connect_rate=0.0,
6871
):
6972
nn.Module.__init__(self)
7073
stem_width = int(64 * width_factor)
@@ -74,6 +77,9 @@ def __init__(
7477
self.groups = 1 # not really used but needed inside _make_layer
7578
self.base_width = 64 # used inside _make_layer
7679
self.norm_act = norm_act
80+
self.block_idx = 0
81+
self.num_blocks = sum(layers)
82+
self.drop_connect_rate = drop_connect_rate
7783

7884
# in the paper they use conv1x1 but in code conv3x3 (which seems better)
7985
self.conv1 = nn.Sequential(SpaceToDepth(), conv3x3(in_channels * 16, stem_width))

pytorch_tools/modules/residual.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def __init__(
151151
norm_layer=ABN,
152152
norm_act="relu",
153153
antialias=False,
154+
keep_prob=1,
154155
):
155156
super(BasicBlock, self).__init__()
156157
antialias = antialias and stride == 2
@@ -167,6 +168,7 @@ def __init__(
167168
self.downsample = downsample
168169
self.blurpool = BlurPool(channels=planes) if antialias else nn.Identity()
169170
self.antialias = antialias
171+
self.drop_connect = DropConnect(keep_prob) if keep_prob < 1 else nn.Identity()
170172

171173
def forward(self, x):
172174
residual = x
@@ -180,11 +182,11 @@ def forward(self, x):
180182
if self.antialias:
181183
out = self.blurpool(out)
182184
out = self.conv2(out)
183-
# avoid 2 inplace ops by chaining into one long op. Neede for inplaceabn
185+
# avoid 2 inplace ops by chaining into one long op. Needed for inplaceabn
184186
if self.se_module is not None:
185-
out = self.se_module(self.bn2(out)) + residual
187+
out = self.drop_connect(self.se_module(self.bn2(out))) + residual
186188
else:
187-
out = self.bn2(out) + residual
189+
out = self.drop_connect(self.bn2(out)) + residual
188190
return self.final_act(out)
189191

190192

@@ -204,6 +206,7 @@ def __init__(
204206
norm_layer=ABN,
205207
norm_act="relu",
206208
antialias=False,
209+
keep_prob=1, # for drop connect
207210
):
208211
super(Bottleneck, self).__init__()
209212
antialias = antialias and stride == 2
@@ -222,6 +225,7 @@ def __init__(
222225
self.downsample = downsample
223226
self.blurpool = BlurPool(channels=width) if antialias else nn.Identity()
224227
self.antialias = antialias
228+
self.drop_connect = DropConnect(keep_prob) if keep_prob < 1 else nn.Identity()
225229

226230
def forward(self, x):
227231
residual = x
@@ -241,9 +245,9 @@ def forward(self, x):
241245
out = self.conv3(out)
242246
# avoid 2 inplace ops by chaining into one long op
243247
if self.se_module is not None:
244-
out = self.se_module(self.bn3(out)) + residual
248+
out = self.drop_connect(self.se_module(self.bn3(out))) + residual
245249
else:
246-
out = self.bn3(out) + residual
250+
out = self.drop_connect(self.bn3(out)) + residual
247251
return self.final_act(out)
248252

249253
# TResnet models use slightly modified versions of BasicBlock and Bottleneck
@@ -292,5 +296,5 @@ def forward(self, x):
292296

293297
out = self.conv3(out)
294298
# avoid 2 inplace ops by chaining into one long op
295-
out = self.bn3(out) + residual
299+
out = self.drop_connect(self.bn3(out)) + residual
296300
return self.final_act(out)

pytorch_tools/segmentation_models/unet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
self.layer3 = UnetDecoderBlock(in_channels[2], out_channels[2], **bn_params)
3939
self.layer4 = UnetDecoderBlock(in_channels[3], out_channels[3], **bn_params)
4040
self.layer5 = UnetDecoderBlock(in_channels[4], out_channels[4], **bn_params)
41-
self.dropout = nn.Dropout2d(drop_rate, inplace=True)
41+
self.dropout = nn.Dropout2d(drop_rate, inplace=False) # inplace=True raises a backprop error
4242
self.final_conv = conv1x1(out_channels[4], final_channels)
4343

4444
initialize(self)

tests/models/test_models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ def test_dilation(arch, output_stride):
7373
W, H = INP.shape[-2:]
7474
assert res.shape[-2:] == (W // output_stride, H // output_stride)
7575

76+
@pytest.mark.parametrize("arch", TEST_MODEL_NAMES)
def test_drop_connect(arch):
    """Build each test architecture with ``drop_connect_rate=0.2`` and run the shared forward check."""
    model_factory = models.__dict__[arch]
    _test_forward(model_factory(drop_connect_rate=0.2))
80+
7681
NUM_PARAMS = {
7782
"tresnetm": 31389032,
7883
"tresnetl": 55989256,

tests/models/test_weights.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
## test that imagenet pretrained weights are valid and can correctly classify the cat and the dog
22

3+
import torch
4+
import pytest
35
import numpy as np
46
from PIL import Image
5-
import pytest
67

78
from pytorch_tools.utils.preprocessing import get_preprocessing_fn
89
from pytorch_tools.utils.visualization import tensor_from_rgb_image
@@ -53,3 +54,25 @@ def test_imagenet_pretrain(arch):
5354
im = im.view(1, *im.shape).float()
5455
pred_cls = m(im).argmax()
5556
assert pred_cls == im_cls
57+
58+
# Check that the mean of the model output for a fixed input stays the same.
# Reference output means, keyed by architecture name.
MODEL_MEAN = {
    "resnet34": 7.6799e-06,
    "se_resnet50": -2.6095e-06,
    "efficientnet_b0": 0.0070,
}

# Architectures covered by the output-mean check (same order as the references).
MODEL_NAMES2 = list(MODEL_MEAN)


@pytest.mark.parametrize("arch", MODEL_NAMES2)
def test_output_mean(arch):
    """Compare the mean output for an all-ones input against the stored reference.

    The model is built with ``pretrained="imagenet"``, switched to eval mode
    and evaluated without gradient tracking on a ``(1, 3, 256, 256)`` tensor of ones.
    """
    model = models.__dict__[arch](pretrained="imagenet")
    model.eval()
    ones_input = torch.ones(1, 3, 256, 256)
    with torch.no_grad():
        mean_output = model(ones_input).mean().numpy()
    expected_mean = MODEL_MEAN[arch]
    # NOTE(review): atol=1e-4 is larger than the resnet34 / se_resnet50 references,
    # so for those two the check only bounds the mean near zero - confirm this is intended.
    assert np.allclose(mean_output, expected_mean, rtol=1e-4, atol=1e-4)

0 commit comments

Comments
 (0)