
Commit 32b8076

Author: Jordan Stomps
Commit message: adding an EarlyStopper class for managing that functionality
Parent: 5e050cf

File tree

4 files changed: +87 −18 lines


models/SSML/ShadowCNN.py
Lines changed: 18 additions & 1 deletion

@@ -13,7 +13,7 @@
 import shadow.utils
 from shadow.utils import set_seed
 # diagnostics
-from scripts.utils import run_hyperopt
+from scripts.utils import EarlyStopper, run_hyperopt
 import joblib

@@ -322,6 +322,9 @@ def train(self, trainx, trainy, Ux, testx=None, testy=None):
         # labels for unlabeled data are always "-1"
         xEnt = torch.nn.CrossEntropyLoss(ignore_index=-1)

+        # generate early-stopping watchdog
+        # TODO: allow a user of ShadowCNN to specify EarlyStopper's params
+        stopper = EarlyStopper(patience=3, min_delta=0)
         n_epochs = 100
         self.eaat.to(self.device)
         losscurve = []

@@ -345,6 +348,20 @@ def train(self, trainx, trainy, Ux, testx=None, testy=None):
                 pred, acc = self.predict(testx, testy)
                 evalcurve.append(acc)

+                self.eaat.train()
+                # test for early stopping
+                x_val = torch.FloatTensor(
+                    testx.copy()[:, ::self.params['binning']])
+                x_val = x_val.reshape((x_val.shape[0],
+                                       1,
+                                       x_val.shape[1])).to(self.device)
+                y_val = torch.LongTensor(testy).to(self.device)
+                out = self.eaat(x_val)
+                val_loss = xEnt(out, y_val) + \
+                    self.eaat.get_technique_cost(x_val)
+                if stopper.early_stop(val_loss):
+                    break
+
         # optionally return the training accuracy if test data was provided
         return losscurve, evalcurve
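For anyone reading the new validation block: the spectra are first rebinned with a stride slice and then given a singleton channel axis, presumably because the model's 1-d convolution layers take input of shape (batch, channels, length). A standalone sketch of that reshape, with made-up sizes (32 spectra, 1000 bins, a binning of 2; none of these values come from this commit):

import torch

x = torch.randn(32, 1000)                   # 32 spectra, 1000 energy bins (invented sizes)
x = x[:, ::2]                               # stride slice, as with params['binning'] = 2
x = x.reshape((x.shape[0], 1, x.shape[1]))  # add the channel axis Conv1d expects
print(x.shape)                              # torch.Size([32, 1, 500])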

models/SSML/ShadowNN.py
Lines changed: 27 additions & 15 deletions

@@ -9,7 +9,7 @@
 import shadow.utils
 from shadow.utils import set_seed
 # diagnostics
-from scripts.utils import run_hyperopt
+from scripts.utils import EarlyStopper, run_hyperopt
 import joblib

@@ -199,12 +199,15 @@ def train(self, trainx, trainy, Ux, testx=None, testy=None):
         n_epochs = 100
         xt = torch.Tensor(xtens).to(self.device)
         yt = torch.LongTensor(ytens).to(self.device)
+        # generate early-stopping watchdog
+        # TODO: allow a user of ShadowNN to specify EarlyStopper's params
+        stopper = EarlyStopper(patience=3, min_delta=0)
         # saves history for max accuracy
         acc_history = []
-        # set the model into training mode
-        # NOTE: change this to .eval() mode for testing and back again
-        self.eaat.train()
         for epoch in range(n_epochs):
+            # set the model into training mode
+            # NOTE: change this to .eval() mode for testing and back again
+            self.eaat.train()
             # Forward/backward pass for training semi-supervised model
             out = self.eaat(xt)
             # supervised + unsupervised loss

@@ -214,20 +217,26 @@ def train(self, trainx, trainy, Ux, testx=None, testy=None):
             self.eaat_opt.step()

             if testx is not None and testy is not None:
+                x_val = torch.FloatTensor(
+                    testx.copy()
+                )[:, ::self.params['binning']].to(self.device)
+                y_val = torch.LongTensor(testy.copy()).to(self.device)
+
                 self.eaat.eval()
-                eaat_pred = torch.max(self.eaat(
-                    torch.FloatTensor(
-                        testx.copy()[:,
-                                     ::self.params[
-                                         'binning']
-                                     ]
-                    )
-                ), 1)[-1]
+                eaat_pred = torch.max(self.eaat(x_val), 1)[-1]
                 acc = shadow.losses.accuracy(eaat_pred,
-                                             torch.LongTensor(testy.copy())
+                                             y_val
                                             ).data.item()
                 acc_history.append(acc)

+                self.eaat.train()
+                # test for early stopping
+                out = self.eaat(x_val)
+                val_loss = self.xEnt(out, y_val) + \
+                    self.eaat.get_technique_cost(x_val)
+                if stopper.early_stop(val_loss):
+                    break
+
         # optionally return the training accuracy if test data was provided
         return acc_history

@@ -245,15 +254,18 @@ def predict(self, testx, testy=None):
         eaat_pred = torch.max(self.eaat(
             torch.FloatTensor(
                 testx.copy()[:, ::self.params['binning']]
-            )
+            ).to(self.device)
         ), 1)[-1]

         acc = None
         if testy is not None:
             acc = shadow.losses.accuracy(eaat_pred,
-                                         torch.LongTensor(testy.copy())
+                                         torch.LongTensor(
+                                             testy.copy()).to(self.device)
                                         ).data.item()

+        # return tensor to cpu if on gpu and convert to numpy for return
+        eaat_pred = eaat_pred.cpu().numpy()
         return eaat_pred, acc

     def save(self, filename):
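The predict() change above is the standard PyTorch device round-trip: inputs must be moved to the model's device before the forward pass, and the resulting prediction tensor must come back to the CPU before .numpy(), which raises on CUDA tensors. A minimal sketch with a placeholder model, not code from this repository:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = torch.nn.Linear(1000, 2).to(device)   # placeholder standing in for self.eaat
x = torch.randn(8, 1000).to(device)         # inputs on the same device as the model

pred = torch.max(net(x), 1)[-1]             # predicted class indices, still on `device`
pred = pred.cpu().numpy()                   # back to host memory for NumPy consumers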

scripts/utils.py
Lines changed: 41 additions & 0 deletions

@@ -11,6 +11,47 @@
 from sklearn.decomposition import PCA


+class EarlyStopper:
+    '''
+    Early stopping mechanism for neural networks.
+    Code adapted from user "isle_of_gods" from StackOverflow:
+    https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch
+    Use this class to break a training loop when validation loss stops improving.
+    Inputs:
+    patience: integer; forces a stop if validation loss has not improved
+        for some time
+    min_delta: "fudge value" for how much loss to tolerate before stopping
+    '''
+
+    def __init__(self, patience=1, min_delta=0):
+        self.patience = patience
+        self.min_delta = min_delta
+        self.counter = 0
+        self.min_validation_loss = np.inf
+
+    def early_stop(self, validation_loss):
+        '''
+        Tests for the early stopping condition: whether the validation loss
+        has not improved for a certain period of time (patience).
+        Inputs:
+        validation_loss: typically a float value for the loss function of
+            a neural network training loop
+        '''
+
+        if validation_loss < self.min_validation_loss:
+            # keep track of the smallest validation loss;
+            # if it has been beaten, restart patience
+            self.min_validation_loss = validation_loss
+            self.counter = 0
+        elif validation_loss > (self.min_validation_loss + self.min_delta):
+            # count epochs where validation loss has failed to improve
+            # by a tolerable amount
+            self.counter += 1
+            if self.counter >= self.patience:
+                return True
+        return False
+
+
 def run_hyperopt(space, model, data_dict, max_evals=50, verbose=True):
     '''
     Runs hyperparameter optimization on a model given a parameter space.
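The new class can be exercised in isolation. With patience=2 and min_delta=0.05, the counter only advances on epochs whose loss exceeds the best loss seen so far by more than the tolerance, and the stop triggers once the counter reaches the patience threshold (a quick sketch using invented loss values):

from scripts.utils import EarlyStopper

stopper = EarlyStopper(patience=2, min_delta=0.05)
for epoch, loss in enumerate([1.0, 0.8, 0.9, 0.87]):
    if stopper.early_stop(loss):
        # best loss so far is 0.8; both 0.9 and 0.87 exceed 0.8 + 0.05,
        # so the counter hits 2 here and the loop breaks
        print(f'stopping at epoch {epoch}, loss {loss}')
        break

Note that a loss between the best value and best + min_delta neither resets nor advances the counter.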

tests/test_models.py
Lines changed: 1 addition & 2 deletions

@@ -337,8 +337,7 @@ def test_ShadowNN():
     # rather than decimals
     # uninteresting test if Shadow predicts all one class
     # TODO: make the default params test meaningful
-    # NOTE: .numpy() needed because model.predict() returns a tensor
-    assert np.count_nonzero(pred.numpy() == y_test) > 0
+    assert np.count_nonzero(pred == y_test) > 0

     # testing hyperopt optimize methods
     space = {'hidden_layer': 10,
