@@ -156,7 +156,7 @@ def on_epoch_end(self):
156156
157157
158158class EarlyStopping (Callback ):
159- """Stop training when a monitored quantity (training loss) has stopped improving.
159+ """Stop training when a monitored quantity (training or testing loss) has stopped improving.
160160 Only checked at validation step according to ``display_every`` in ``Model.train``.
161161
162162 Args:
@@ -169,12 +169,14 @@ class EarlyStopping(Callback):
169169 baseline: Baseline value for the monitored quantity to reach.
170170 Training will stop if the model doesn't show improvement
171171 over the baseline.
172+ monitor: The loss function that is monitored. Either 'loss_train' or 'loss_test'
172173 """
173174
174- def __init__ (self , min_delta = 0 , patience = 0 , baseline = None ):
175+ def __init__ (self , min_delta = 0 , patience = 0 , baseline = None , monitor = "loss_train" ):
175176 super (EarlyStopping , self ).__init__ ()
176177
177178 self .baseline = baseline
179+ self .monitor = monitor
178180 self .patience = patience
179181 self .min_delta = min_delta
180182 self .wait = 0
@@ -208,7 +210,14 @@ def on_train_end(self):
208210 print ("Epoch {}: early stopping" .format (self .stopped_epoch ))
209211
210212 def get_monitor_value (self ):
211- return sum (self .model .train_state .loss_train )
213+ if self .monitor == "loss_train" :
214+ result = sum (self .model .train_state .loss_train )
215+ elif self .monitor == "loss_test" :
216+ result = sum (self .model .train_state .loss_test )
217+ else :
218+ raise ValueError ("The specified monitor function is incorrect." )
219+
220+ return result
212221
213222
214223class Timer (Callback ):
0 commit comments