@@ -151,6 +151,7 @@ def __init__(
151151 norm_layer = ABN ,
152152 norm_act = "relu" ,
153153 antialias = False ,
154+ keep_prob = 1 ,
154155 ):
155156 super (BasicBlock , self ).__init__ ()
156157 antialias = antialias and stride == 2
@@ -167,6 +168,7 @@ def __init__(
167168 self .downsample = downsample
168169 self .blurpool = BlurPool (channels = planes ) if antialias else nn .Identity ()
169170 self .antialias = antialias
171+ self .drop_connect = DropConnect (keep_prob ) if keep_prob < 1 else nn .Identity ()
170172
171173 def forward (self , x ):
172174 residual = x
@@ -180,11 +182,11 @@ def forward(self, x):
180182 if self .antialias :
181183 out = self .blurpool (out )
182184 out = self .conv2 (out )
183- # avoid 2 inplace ops by chaining into one long op. Neede for inplaceabn
185+ # avoid 2 inplace ops by chaining into one long op. Needed for inplaceabn
184186 if self .se_module is not None :
185- out = self .se_module (self .bn2 (out )) + residual
187+ out = self .drop_connect ( self . se_module (self .bn2 (out ) )) + residual
186188 else :
187- out = self .bn2 (out ) + residual
189+ out = self .drop_connect ( self . bn2 (out ) ) + residual
188190 return self .final_act (out )
189191
190192
@@ -204,6 +206,7 @@ def __init__(
204206 norm_layer = ABN ,
205207 norm_act = "relu" ,
206208 antialias = False ,
209+ keep_prob = 1 , # for drop connect
207210 ):
208211 super (Bottleneck , self ).__init__ ()
209212 antialias = antialias and stride == 2
@@ -222,6 +225,7 @@ def __init__(
222225 self .downsample = downsample
223226 self .blurpool = BlurPool (channels = width ) if antialias else nn .Identity ()
224227 self .antialias = antialias
228+ self .drop_connect = DropConnect (keep_prob ) if keep_prob < 1 else nn .Identity ()
225229
226230 def forward (self , x ):
227231 residual = x
@@ -241,9 +245,9 @@ def forward(self, x):
241245 out = self .conv3 (out )
242246 # avoid 2 inplace ops by chaining into one long op
243247 if self .se_module is not None :
244- out = self .se_module (self .bn3 (out )) + residual
248+ out = self .drop_connect ( self . se_module (self .bn3 (out ) )) + residual
245249 else :
246- out = self .bn3 (out ) + residual
250+ out = self .drop_connect ( self . bn3 (out ) ) + residual
247251 return self .final_act (out )
248252
249253# TResnet models use slightly modified versions of BasicBlock and Bottleneck
@@ -292,5 +296,5 @@ def forward(self, x):
292296
293297 out = self .conv3 (out )
294298 # avoid 2 inplace ops by chaining into one long op
295- out = self .bn3 (out ) + residual
299+ out = self .drop_connect ( self . bn3 (out ) ) + residual
296300 return self .final_act (out )
0 commit comments