File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1616import pytorch_lightning as pl
1717import torch
1818from einops import rearrange
19- from pytorch_lightning .callbacks import ModelCheckpoint
19+ from pytorch_lightning .callbacks import EarlyStopping , ModelCheckpoint
2020from pytorch_lightning .loggers import CSVLogger , WandbLogger
2121from torch import Tensor , nn , optim
2222from torch .utils .data import DataLoader
5757 "class_weights" : None ,
5858 "total_weight" : None ,
5959 "accumulate_grad_batches" : 1 ,
60+ "early_stopping" : False ,
61+ "patience" : 5 ,
62+ "monitor" : "val_loss" ,
63+ "mode" : "min" ,
6064}
6165
6266
@@ -545,6 +549,16 @@ def train_on_dataset(
545549 else :
546550 raise Exception ("Checkpoint type must be a bool or dict" )
547551
552+ # Early stopping
553+ if self .train_params ["early_stopping" ]:
554+ checkpoint_callbacks .append (
555+ EarlyStopping (
556+ monitor = self .train_params ["monitor" ],
557+ patience = self .train_params ["patience" ],
558+ mode = self .train_params ["mode" ],
559+ )
560+ )
561+
548562 # Get device
549563 accelerator , devices = self .parse_devices (self .train_params ["devices" ])
550564
You can’t perform that action at this time.
0 commit comments