Skip to content

Commit fe8246d

Browse files
authored
Merge pull request #178 from Genentech/early-stopping
DG: Add early stopping support to LightningModel
2 parents ad0d622 + 3b42a4a commit fe8246d

1 file changed

Lines changed: 15 additions & 1 deletion

File tree

src/grelu/lightning/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
 import pytorch_lightning as pl
 import torch
 from einops import rearrange
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import CSVLogger, WandbLogger
 from torch import Tensor, nn, optim
 from torch.utils.data import DataLoader
@@ -57,6 +57,10 @@
     "class_weights": None,
     "total_weight": None,
     "accumulate_grad_batches": 1,
+    "early_stopping": False,
+    "patience": 5,
+    "monitor": "val_loss",
+    "mode": "min",
 }

6266

@@ -545,6 +549,16 @@ def train_on_dataset(
         else:
             raise Exception("Checkpoint type must be a bool or dict")

+        # Early stopping
+        if self.train_params["early_stopping"]:
+            checkpoint_callbacks.append(
+                EarlyStopping(
+                    monitor=self.train_params["monitor"],
+                    patience=self.train_params["patience"],
+                    mode=self.train_params["mode"],
+                )
+            )
+
         # Get device
         accelerator, devices = self.parse_devices(self.train_params["devices"])

0 commit comments

Comments (0)