Skip to content

Commit 55c2ea2

Browse files
fix: pass train_bn to freeze in BackboneFinetuning.freeze_before_training
freeze_before_training() called self.freeze(pl_module.backbone) without forwarding self.train_bn, so BatchNorm layers stayed trainable during the initial frozen phase regardless of the train_bn setting. The unfreezing path already passed the flag correctly.
1 parent bb7820f commit 55c2ea2

2 files changed

Lines changed: 32 additions & 1 deletion

File tree

src/lightning/pytorch/callbacks/finetuning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
455455

456456
@override
457457
def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
458-
self.freeze(pl_module.backbone)
458+
self.freeze(pl_module.backbone, train_bn=self.train_bn)
459459

460460
@override
461461
def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer) -> None:

tests/tests_pytorch/callbacks/test_finetuning_callback.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,37 @@ def train_dataloader(self):
7373
assert model.backbone.has_been_used
7474

7575

76+
def test_finetuning_callback_train_bn_false(tmp_path):
77+
"""Test that BackboneFinetuning respects train_bn=False during the initial freeze phase."""
78+
seed_everything(42)
79+
80+
class FinetuningBoringModel(BoringModel):
81+
def __init__(self):
82+
super().__init__()
83+
self.backbone = nn.Sequential(nn.Linear(32, 32, bias=False), nn.BatchNorm1d(32), nn.ReLU())
84+
self.layer = nn.Linear(32, 2)
85+
86+
def forward(self, x):
87+
return self.layer(self.backbone(x))
88+
89+
def configure_optimizers(self):
90+
return torch.optim.SGD(self.layer.parameters(), lr=0.1)
91+
92+
def train_dataloader(self):
93+
return DataLoader(RandomDataset(32, 64), batch_size=2)
94+
95+
model = FinetuningBoringModel()
96+
callback = BackboneFinetuning(unfreeze_backbone_at_epoch=3, train_bn=False, verbose=False)
97+
98+
trainer = Trainer(limit_train_batches=4, default_root_dir=tmp_path, callbacks=[callback], max_epochs=1)
99+
trainer.fit(model)
100+
101+
# With train_bn=False, BatchNorm should be fully frozen (not trainable, no running stats)
102+
assert not model.backbone[1].weight.requires_grad
103+
assert not model.backbone[1].bias.requires_grad
104+
assert not model.backbone[1].track_running_stats
105+
106+
76107
class TestBackboneFinetuningWarningCallback(BackboneFinetuning):
77108
def finetune_function(self, pl_module, epoch: int, optimizer):
78109
"""Called when the epoch begins."""

0 commit comments

Comments
 (0)