@@ -235,6 +235,7 @@ def __init__(
235235 base_width_and_cardinality : Optional [Union [Tuple , List ]] = None ,
236236 basic_layer : bool = False ,
237237 final_bn_relu : bool = True ,
238+ bn_weight_decay : Optional [bool ] = False ,
238239 ):
239240 """
240241 Implementation of `ResNeXt <https://arxiv.org/pdf/1611.05431.pdf>`_.
@@ -251,6 +252,7 @@ def __init__(
251252 assert all (is_pos_int (n ) for n in num_blocks )
252253 assert is_pos_int (init_planes ) and is_pos_int (reduction )
253254 assert type (small_input ) == bool
255+ assert type (bn_weight_decay ) == bool
254256 assert (
255257 type (zero_init_bn_residuals ) == bool
256258 ), "zero_init_bn_residuals must be a boolean, set to true if gamma of last\
@@ -262,9 +264,11 @@ def __init__(
262264 and is_pos_int (base_width_and_cardinality [1 ])
263265 )
264266
265- # we apply weight decay to batch norm if the model is a ResNeXt and we don't if
266- # it is a ResNet
267- self .bn_weight_decay = base_width_and_cardinality is not None
267+ # Chooses whether to apply weight decay to batch norm
268+ # parameters. This improves results in some situations,
269+ # e.g. ResNeXt models trained / evaluated using the Imagenet
270+ # dataset, but can cause worse performance in other scenarios
271+ self .bn_weight_decay = bn_weight_decay
268272
269273 # initial convolutional block:
270274 self .num_blocks = num_blocks
@@ -374,6 +378,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
374378 "basic_layer" : config .get ("basic_layer" , False ),
375379 "final_bn_relu" : config .get ("final_bn_relu" , True ),
376380 "zero_init_bn_residuals" : config .get ("zero_init_bn_residuals" , False ),
381+ "bn_weight_decay" : config .get ("bn_weight_decay" , False ),
377382 }
378383 return cls (** config )
379384
@@ -476,6 +481,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
476481 return cls ()
477482
478483
484+ # Note, the ResNeXt models all have weight decay enabled for the batch
485+ # norm parameters. We have found empirically that this gives better
486+ # results when training on ImageNet (~0.5pp of top-1 acc) and brings
487+ # our results on track with reported ImageNet results...but for
488+ # training on other datasets, we have observed losses in accuracy (for
489+ # example, the dataset used in https://arxiv.org/abs/1805.00932).
479490@register_model ("resnext50_32x4d" )
480491class ResNeXt50 (ResNeXt ):
481492 def __init__ (self ):
@@ -484,6 +495,7 @@ def __init__(self):
484495 basic_layer = False ,
485496 zero_init_bn_residuals = True ,
486497 base_width_and_cardinality = (4 , 32 ),
498+ bn_weight_decay = True ,
487499 )
488500
489501 @classmethod
@@ -499,6 +511,7 @@ def __init__(self):
499511 basic_layer = False ,
500512 zero_init_bn_residuals = True ,
501513 base_width_and_cardinality = (4 , 32 ),
514+ bn_weight_decay = True ,
502515 )
503516
504517 @classmethod
@@ -514,6 +527,7 @@ def __init__(self):
514527 basic_layer = False ,
515528 zero_init_bn_residuals = True ,
516529 base_width_and_cardinality = (4 , 32 ),
530+ bn_weight_decay = True ,
517531 )
518532
519533 @classmethod
0 commit comments