Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 553c18c

Browse files
Aaron Adcockfacebook-github-bot
authored andcommitted
Make bn weight decay configurable (#65)
Summary: Pull Request resolved: fairinternal/ClassyVision#65 Make the bn weight decay configurable, for some datasets it might be desirable to turn it off. Reviewed By: vreis Differential Revision: D20140487 fbshipit-source-id: 77debf2c4600a080081668565d70b7a3ddc788f4
1 parent e47a18d commit 553c18c

1 file changed

Lines changed: 17 additions & 3 deletions

File tree

classy_vision/models/resnext.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def __init__(
235235
base_width_and_cardinality: Optional[Union[Tuple, List]] = None,
236236
basic_layer: bool = False,
237237
final_bn_relu: bool = True,
238+
bn_weight_decay: Optional[bool] = False,
238239
):
239240
"""
240241
Implementation of `ResNeXt <https://arxiv.org/pdf/1611.05431.pdf>`_.
@@ -251,6 +252,7 @@ def __init__(
251252
assert all(is_pos_int(n) for n in num_blocks)
252253
assert is_pos_int(init_planes) and is_pos_int(reduction)
253254
assert type(small_input) == bool
255+
assert type(bn_weight_decay) == bool
254256
assert (
255257
type(zero_init_bn_residuals) == bool
256258
), "zero_init_bn_residuals must be a boolean, set to true if gamma of last\
@@ -262,9 +264,11 @@ def __init__(
262264
and is_pos_int(base_width_and_cardinality[1])
263265
)
264266

265-
# we apply weight decay to batch norm if the model is a ResNeXt and we don't if
266-
# it is a ResNet
267-
self.bn_weight_decay = base_width_and_cardinality is not None
267+
# Chooses whether to apply weight decay to batch norm
268+
# parameters. This improves results in some situations,
269+
# e.g. ResNeXt models trained / evaluated using the Imagenet
270+
# dataset, but can cause worse performance in other scenarios
271+
self.bn_weight_decay = bn_weight_decay
268272

269273
# initial convolutional block:
270274
self.num_blocks = num_blocks
@@ -374,6 +378,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
374378
"basic_layer": config.get("basic_layer", False),
375379
"final_bn_relu": config.get("final_bn_relu", True),
376380
"zero_init_bn_residuals": config.get("zero_init_bn_residuals", False),
381+
"bn_weight_decay": config.get("bn_weight_decay", False),
377382
}
378383
return cls(**config)
379384

@@ -476,6 +481,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
476481
return cls()
477482

478483

484+
# Note, the ResNeXt models all have weight decay enabled for the batch
485+
# norm parameters. We have found empirically that this gives better
486+
# results when training on ImageNet (~0.5pp of top-1 acc) and brings
487+
# our results on track with reported ImageNet results...but for
488+
# training on other datasets, we have observed losses in accuracy (for
489+
# example, the dataset used in https://arxiv.org/abs/1805.00932).
479490
@register_model("resnext50_32x4d")
480491
class ResNeXt50(ResNeXt):
481492
def __init__(self):
@@ -484,6 +495,7 @@ def __init__(self):
484495
basic_layer=False,
485496
zero_init_bn_residuals=True,
486497
base_width_and_cardinality=(4, 32),
498+
bn_weight_decay=True,
487499
)
488500

489501
@classmethod
@@ -499,6 +511,7 @@ def __init__(self):
499511
basic_layer=False,
500512
zero_init_bn_residuals=True,
501513
base_width_and_cardinality=(4, 32),
514+
bn_weight_decay=True,
502515
)
503516

504517
@classmethod
@@ -514,6 +527,7 @@ def __init__(self):
514527
basic_layer=False,
515528
zero_init_bn_residuals=True,
516529
base_width_and_cardinality=(4, 32),
530+
bn_weight_decay=True,
517531
)
518532

519533
@classmethod

0 commit comments

Comments
 (0)