Skip to content

Commit 5c64196

Browse files
authored
EarlyStopping callback can monitor testing loss (#501)
1 parent f7ba50e commit 5c64196

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

deepxde/callbacks.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def on_epoch_end(self):
156156

157157

158158
class EarlyStopping(Callback):
159-
"""Stop training when a monitored quantity (training loss) has stopped improving.
159+
"""Stop training when a monitored quantity (training or testing loss) has stopped improving.
160160
Only checked at validation step according to ``display_every`` in ``Model.train``.
161161
162162
Args:
@@ -169,12 +169,14 @@ class EarlyStopping(Callback):
169169
baseline: Baseline value for the monitored quantity to reach.
170170
Training will stop if the model doesn't show improvement
171171
over the baseline.
172+
monitor: The loss function that is monitored. Either 'loss_train' or 'loss_test'
172173
"""
173174

174-
def __init__(self, min_delta=0, patience=0, baseline=None):
175+
def __init__(self, min_delta=0, patience=0, baseline=None, monitor="loss_train"):
175176
super(EarlyStopping, self).__init__()
176177

177178
self.baseline = baseline
179+
self.monitor = monitor
178180
self.patience = patience
179181
self.min_delta = min_delta
180182
self.wait = 0
@@ -208,7 +210,14 @@ def on_train_end(self):
208210
print("Epoch {}: early stopping".format(self.stopped_epoch))
209211

210212
def get_monitor_value(self):
211-
return sum(self.model.train_state.loss_train)
213+
if self.monitor == "loss_train":
214+
result = sum(self.model.train_state.loss_train)
215+
elif self.monitor == "loss_test":
216+
result = sum(self.model.train_state.loss_test)
217+
else:
218+
raise ValueError("The specified monitor function is incorrect.")
219+
220+
return result
212221

213222

214223
class Timer(Callback):

0 commit comments

Comments
 (0)