
Commit 013eb9a

Author: Jordan Stomps

adding cross validation implementation

1 parent 32b8076 commit 013eb9a

File tree

scripts/utils.py
tests/test_models.py

2 files changed: 127 additions & 0 deletions

scripts/utils.py

Lines changed: 95 additions & 0 deletions
@@ -9,6 +9,8 @@
 # pca
 from sklearn.preprocessing import StandardScaler
 from sklearn.decomposition import PCA
+# Cross Validation
+from sklearn.model_selection import KFold, StratifiedKFold
 
 
 class EarlyStopper:
@@ -96,6 +98,99 @@ def run_hyperopt(space, model, data_dict, max_evals=50, verbose=True):
     return best, worst
 
 
+def cross_validation(model, X, y, params, n_splits=3,
+                     stratified=False, random_state=None):
+    '''
+    Perform K-Fold cross validation using sklearn and a given model.
+    The model *must* have a fresh_start method (see models in RadClass/models).
+    fresh_start() is used instead of train() to be agnostic to the data needed
+    for training (fresh_start requires a data_dict whereas each model's
+    train could take different combinations of labeled & unlabeled data).
+    This also avoids the need to do hyperparameter optimization (and
+    therefore many training epochs) for every K-Fold.
+    NOTE: fresh_start returns the model and results in a dictionary but
+    does not overwrite/save the model to the respective class instance.
+    You can manually overwrite the instance with the returned model
+    (i.e. assign it to model.model).
+    Hyperparameter optimization (model.optimize) can be done before or after
+    cross validation to specify the (optimal) parameters used by the model,
+    since they are required here.
+    NOTE: data is always shuffled when making cross validation splits
+    (fixed default; see the sklearn cross validation docs for more info).
+    NOTE: unlabeled data, if provided, will always be included in the training
+    dataset. This means that this cross validation implementation is
+    susceptible to bias in the unlabeled data distribution. To test for
+    this bias, a user can manually run cross validation as a parent to
+    calling this function, splitting the unlabeled data and adding
+    different folds into X.
+    Inputs:
+    model: ML model class object (e.g. from RadClass/models).
+        Must have a fresh_start() method.
+        NOTE: if the model expects unlabeled data but unlabeled data is not
+        provided in X/y, an error will likely be thrown when training the
+        model through fresh_start.
+    X: array of feature vectors (rows of individual instances, columns of
+        features). This should include all data for training and testing
+        (since the testing subset will be split by cross validation),
+        including unlabeled data if needed/used.
+    y: array/vector of labels for X. If including unlabeled data, use -1.
+        This should have the same order as X; that is, each row index in X
+        has an associated label with the same index in y.
+    params: dictionary of hyperparameters. Will depend on the model used.
+        Alternatively, use model.params for models in RadClass/models.
+    n_splits: int; number of splits for K-Fold cross validation.
+    stratified: bool; if True, balance the K-Folds to have roughly the same
+        proportion of samples from each class.
+    random_state: seed for reproducibility.
+    '''
+
+    # lists for accuracies and full results from each fold
+    accs = []
+    reports = []
+
+    if stratified:
+        cv = StratifiedKFold(n_splits=n_splits, random_state=random_state,
+                             shuffle=True)
+    else:
+        cv = KFold(n_splits=n_splits, random_state=random_state,
+                   shuffle=True)
+
+    # separate unlabeled data if included
+    Ux = None
+    Uy = None
+    if -1 in y:
+        U_idx = np.where(y == -1)[0]
+        L_idx = np.where(y != -1)[0]
+        Ux = X[U_idx]
+        Uy = y[U_idx]
+        Lx = X[L_idx]
+        Ly = y[L_idx]
+    else:
+        Lx = X
+        Ly = y
+    # conduct K-Fold cross validation
+    cv.get_n_splits(Lx, Ly)
+    for train_idx, test_idx in cv.split(Lx, Ly):
+        trainx, testx = Lx[train_idx], Lx[test_idx]
+        trainy, testy = Ly[train_idx], Ly[test_idx]
+
+        # construct data dictionary for training in fresh_start
+        data_dict = {'trainx': trainx, 'trainy': trainy,
+                     'testx': testx, 'testy': testy}
+        if Ux is not None:
+            data_dict['Ux'] = Ux
+            data_dict['Uy'] = Uy
+        results = model.fresh_start(params, data_dict)
+        accs = np.append(accs, results['accuracy'])
+        reports = np.append(reports, results)
+
+    # report cross validation results
+    print('Average accuracy:', np.mean(accs))
+    print('Max accuracy:', np.max(accs))
+    print('All accuracy:', accs)
+    # return the results of fresh_start for the max accuracy model
+    return reports[np.argmax(accs)]
+
+
 def pca(Lx, Ly, Ux, Uy, filename):
     '''
     A function for computing and plotting 2D PCA.
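
For reference, a minimal sketch of how the new helper is meant to be called. The TinyModel class below is hypothetical: it only mimics the fresh_start(params, data_dict) contract that cross_validation relies on (returning a dictionary with at least an 'accuracy' key), standing in for the real classes in RadClass/models such as LogReg and LabelProp used in the tests below. The import path assumes the scripts/utils.py layout shown in this diff.

import numpy as np
from sklearn.linear_model import LogisticRegression

from scripts import utils


class TinyModel:
    # Hypothetical stand-in for a RadClass/models class; the only
    # contract cross_validation needs is fresh_start(params, data_dict)
    # returning a dict with (at least) an 'accuracy' key.
    def fresh_start(self, params, data_dict):
        clf = LogisticRegression(**params)
        clf.fit(data_dict['trainx'], data_dict['trainy'])
        acc = clf.score(data_dict['testx'], data_dict['testy'])
        return {'accuracy': acc, 'model': clf}


rng = np.random.default_rng(0)
X = rng.normal(size=(60, 5))       # 60 instances, 5 features
y = rng.integers(0, 2, size=60)    # binary labels, no -1s (fully supervised)

best = utils.cross_validation(model=TinyModel(), X=X, y=y,
                              params={'max_iter': 1000}, n_splits=3,
                              stratified=True, random_state=0)
print(best['accuracy'])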

tests/test_models.py

Lines changed: 32 additions & 0 deletions
@@ -46,6 +46,38 @@ def test_utils():
                                     labels,
                                     test_size=0.5,
                                     random_state=0)
+    Uy = np.full_like(Uy, -1)
+
+    # test cross validation for supervised data using LogReg
+    params = {'max_iter': 2022, 'tol': 0.5, 'C': 5.0}
+    model = LogReg(params=params)
+    max_acc_model = utils.cross_validation(model=model,
+                                           X=X,
+                                           y=y,
+                                           params=params)
+    assert max_acc_model['accuracy'] >= 0.5
+
+    # test cross validation for supervised data and StratifiedKFold with LogReg
+    params = {'max_iter': 2022, 'tol': 0.5, 'C': 5.0}
+    model = LogReg(params=params)
+    max_acc_model = utils.cross_validation(model=model,
+                                           X=X,
+                                           y=y,
+                                           params=params,
+                                           stratified=True)
+    assert max_acc_model['accuracy'] >= 0.5
+
+    # test cross validation for SSML with LabelProp
+    params = {'gamma': 10, 'n_neighbors': 15, 'max_iter': 2022, 'tol': 0.5}
+    model = LabelProp(params=params)
+    max_acc_model = utils.cross_validation(model=model,
+                                           X=np.append(X, Ux, axis=0),
+                                           y=np.append(y, Uy, axis=0),
+                                           params=params,
+                                           stratified=True)
+    assert max_acc_model['accuracy'] >= 0.5
+
+    # data split for data visualization
     X_train, X_test, y_train, y_test = train_test_split(X,
                                                         y,
                                                         test_size=0.2,
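
The bias NOTE in the cross_validation docstring can be exercised with an outer split over the unlabeled data. The loop below is only a sketch of that "parent" cross validation, not part of this commit; it assumes Lx/Ly (labeled) and Ux (unlabeled) arrays plus a model and params set up as in the LabelProp test above.

import numpy as np
from sklearn.model_selection import KFold

from scripts import utils

# Outer loop over folds of the unlabeled data: each pass feeds a
# different unlabeled subset into the inner K-Fold cross validation.
outer = KFold(n_splits=3, shuffle=True, random_state=0)
fold_best = []
for _, u_idx in outer.split(Ux):
    X_cv = np.append(Lx, Ux[u_idx], axis=0)
    y_cv = np.append(Ly, np.full(len(u_idx), -1), axis=0)
    best = utils.cross_validation(model=model, X=X_cv, y=y_cv,
                                  params=params, stratified=True)
    fold_best.append(best['accuracy'])

# A wide spread here suggests results are sensitive to which unlabeled
# samples were included, i.e. bias in the unlabeled distribution.
print('Best accuracy per unlabeled fold:', fold_best)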
